greentic_runner_host/
routing.rs

1use std::sync::Arc;
2
3use anyhow::{Result, anyhow, bail};
4use axum::extract::{FromRef, FromRequestParts};
5use axum::http::header::{AUTHORIZATION, HOST};
6use axum::http::request::Parts;
7use axum::http::{HeaderName, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use serde_json::Value;
11use serde_json::json;
12
13use crate::runner::ServerState;
14use crate::runtime::TenantRuntime;
15
16#[derive(Clone)]
17pub struct RoutingConfig {
18    pub resolver: TenantResolver,
19    pub default_tenant: String,
20}
21
22impl RoutingConfig {
23    pub fn from_env() -> Self {
24        Self::from_env_with_default("demo".into())
25    }
26
27    pub fn from_env_with_default(default_tenant: String) -> Self {
28        let default_tenant = std::env::var("DEFAULT_TENANT").unwrap_or(default_tenant);
29        let resolver = std::env::var("TENANT_RESOLVER")
30            .map(|value| TenantResolver::from_str(&value, &default_tenant))
31            .unwrap_or(Ok(TenantResolver::Env))
32            .unwrap_or_else(|err| {
33                tracing::warn!(error = %err, "invalid TENANT_RESOLVER, falling back to env");
34                TenantResolver::Env
35            });
36        Self {
37            resolver,
38            default_tenant,
39        }
40    }
41}
42
43impl Default for RoutingConfig {
44    fn default() -> Self {
45        Self {
46            resolver: TenantResolver::Env,
47            default_tenant: "demo".into(),
48        }
49    }
50}
51
52#[derive(Clone)]
53pub enum TenantResolver {
54    Host,
55    Header(HeaderName),
56    Jwt { header: HeaderName, claim: String },
57    Env,
58}
59
60impl TenantResolver {
61    fn from_str(value: &str, _default: &str) -> Result<Self> {
62        match value.to_ascii_lowercase().as_str() {
63            "host" => Ok(Self::Host),
64            "header" => Ok(Self::Header(HeaderName::from_static("x-greentic-tenant"))),
65            "jwt" => Ok(Self::Jwt {
66                header: AUTHORIZATION,
67                claim: "tenant".into(),
68            }),
69            "env" => Ok(Self::Env),
70            other => bail!("unsupported TENANT_RESOLVER `{other}`"),
71        }
72    }
73}
74
75#[derive(Clone)]
76pub struct TenantRouting {
77    resolver: TenantResolver,
78    default_tenant: String,
79}
80
81impl TenantRouting {
82    pub fn new(cfg: RoutingConfig) -> Self {
83        Self {
84            resolver: cfg.resolver,
85            default_tenant: cfg.default_tenant,
86        }
87    }
88
89    pub fn resolve(&self, parts: &Parts) -> Result<String> {
90        match &self.resolver {
91            TenantResolver::Env => Ok(self.default_tenant.clone()),
92            TenantResolver::Host => {
93                let host = parts
94                    .headers
95                    .get(HOST)
96                    .and_then(|value| value.to_str().ok())
97                    .unwrap_or_default();
98                if host.is_empty() {
99                    return Ok(self.default_tenant.clone());
100                }
101                Ok(host
102                    .split('.')
103                    .next()
104                    .map(|segment| segment.to_string())
105                    .filter(|segment| !segment.is_empty())
106                    .unwrap_or_else(|| self.default_tenant.clone()))
107            }
108            TenantResolver::Header(name) => {
109                let tenant = parts
110                    .headers
111                    .get(name)
112                    .and_then(|value| value.to_str().ok())
113                    .filter(|value| !value.is_empty())
114                    .map(|value| value.to_string())
115                    .unwrap_or_else(|| self.default_tenant.clone());
116                Ok(tenant)
117            }
118            TenantResolver::Jwt { header, claim } => {
119                let token = parts
120                    .headers
121                    .get(header)
122                    .and_then(|value| value.to_str().ok())
123                    .and_then(|value| value.strip_prefix("Bearer "))
124                    .ok_or_else(|| anyhow!("authorization header missing"))?;
125                let tenant = decode_jwt_claim(token, claim)
126                    .unwrap_or_else(|err| {
127                        tracing::warn!(error = %err, "failed to decode jwt claim");
128                        None
129                    })
130                    .unwrap_or_else(|| self.default_tenant.clone());
131                Ok(tenant)
132            }
133        }
134    }
135}
136
137fn decode_jwt_claim(token: &str, claim: &str) -> Result<Option<String>> {
138    let payload = token
139        .split('.')
140        .nth(1)
141        .ok_or_else(|| anyhow!("invalid jwt structure"))?;
142    let padded = match payload.len() % 4 {
143        2 => format!("{payload}=="),
144        3 => format!("{payload}="),
145        _ => payload.to_string(),
146    };
147    let bytes = STANDARD.decode(padded.as_bytes())?;
148    let value: Value = serde_json::from_slice(&bytes)?;
149    Ok(value
150        .get(claim)
151        .and_then(|node| node.as_str())
152        .map(|value| value.to_string()))
153}
154
155pub struct TenantRuntimeHandle {
156    pub tenant: String,
157    pub runtime: Arc<TenantRuntime>,
158}
159
160impl<S> FromRequestParts<S> for TenantRuntimeHandle
161where
162    ServerState: FromRef<S>,
163    S: Send + Sync,
164{
165    type Rejection = (StatusCode, axum::Json<Value>);
166
167    fn from_request_parts(
168        parts: &mut Parts,
169        state: &S,
170    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
171        let server_state = ServerState::from_ref(state);
172        async move {
173            let tenant = server_state.routing.resolve(parts).map_err(|err| {
174                (
175                    StatusCode::BAD_REQUEST,
176                    axum::Json(json!({ "error": err.to_string() })),
177                )
178            })?;
179            let runtime = server_state.active.load(&tenant).ok_or_else(|| {
180                (
181                    StatusCode::NOT_FOUND,
182                    axum::Json(json!({ "error": "tenant not loaded" })),
183                )
184            })?;
185            Ok(Self { tenant, runtime })
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Router with the given strategy and `demo` as the fallback tenant.
    fn routing_with(resolver: TenantResolver) -> TenantRouting {
        TenantRouting::new(RoutingConfig {
            resolver,
            default_tenant: "demo".into(),
        })
    }

    /// Request parts carrying exactly the given headers.
    fn parts_with(headers: &[(&str, &str)]) -> Parts {
        let mut builder = Request::builder().uri("http://localhost/webhook");
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(()).unwrap().into_parts().0
    }

    #[test]
    fn host_resolver_picks_subdomain() {
        let routing = routing_with(TenantResolver::Host);
        let parts = parts_with(&[("host", "foo.example.com")]);
        assert_eq!(routing.resolve(&parts).unwrap(), "foo");
    }

    #[test]
    fn host_resolver_defaults_without_host() {
        let routing = routing_with(TenantResolver::Host);
        assert_eq!(routing.resolve(&parts_with(&[])).unwrap(), "demo");
    }

    #[test]
    fn header_resolver_defaults() {
        let routing = routing_with(TenantResolver::Header(HeaderName::from_static("x-tenant")));
        assert_eq!(routing.resolve(&parts_with(&[])).unwrap(), "demo");
    }

    #[test]
    fn header_resolver_reads_header() {
        let routing = routing_with(TenantResolver::Header(HeaderName::from_static("x-tenant")));
        let parts = parts_with(&[("x-tenant", "acme")]);
        assert_eq!(routing.resolve(&parts).unwrap(), "acme");
    }

    #[test]
    fn env_resolver_ignores_request() {
        let routing = routing_with(TenantResolver::Env);
        let parts = parts_with(&[("host", "foo.example.com"), ("x-greentic-tenant", "acme")]);
        assert_eq!(routing.resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn jwt_resolver_reads_tenant_claim() {
        let routing = routing_with(TenantResolver::Jwt {
            header: AUTHORIZATION,
            claim: "tenant".into(),
        });
        let payload = STANDARD.encode(br#"{"tenant":"acme"}"#);
        let auth = format!("Bearer header.{}.signature", payload.trim_end_matches('='));
        let parts = parts_with(&[("authorization", auth.as_str())]);
        assert_eq!(routing.resolve(&parts).unwrap(), "acme");
    }

    #[test]
    fn jwt_resolver_rejects_missing_header() {
        let routing = routing_with(TenantResolver::Jwt {
            header: AUTHORIZATION,
            claim: "tenant".into(),
        });
        assert!(routing.resolve(&parts_with(&[])).is_err());
    }

    #[test]
    fn jwt_resolver_defaults_on_malformed_token() {
        let routing = routing_with(TenantResolver::Jwt {
            header: AUTHORIZATION,
            claim: "tenant".into(),
        });
        let parts = parts_with(&[("authorization", "Bearer not-a-jwt")]);
        assert_eq!(routing.resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn resolver_names_parse_case_insensitively() {
        assert!(matches!(
            TenantResolver::from_str("HOST", "demo"),
            Ok(TenantResolver::Host)
        ));
        assert!(matches!(
            TenantResolver::from_str("env", "demo"),
            Ok(TenantResolver::Env)
        ));
        assert!(matches!(
            TenantResolver::from_str("Header", "demo"),
            Ok(TenantResolver::Header(name)) if name == "x-greentic-tenant"
        ));
        assert!(matches!(
            TenantResolver::from_str("jwt", "demo"),
            Ok(TenantResolver::Jwt { header, claim }) if header == AUTHORIZATION && claim == "tenant"
        ));
        assert!(TenantResolver::from_str("cookie", "demo").is_err());
    }

    #[test]
    fn from_env_with_default_uses_override() {
        let cfg = RoutingConfig::from_env_with_default("custom".into());
        // DEFAULT_TENANT takes precedence over the argument, so only pin the fallback when the
        // test environment does not set it.
        match std::env::var("DEFAULT_TENANT") {
            Ok(value) if !value.trim().is_empty() => assert_eq!(cfg.default_tenant, value),
            Ok(_) => {}
            Err(_) => assert_eq!(cfg.default_tenant, "custom"),
        }
    }
}