Skip to main content

greentic_runner_host/
routing.rs

1use std::sync::Arc;
2
3use anyhow::{Result, anyhow, bail};
4use axum::extract::{FromRef, FromRequestParts};
5use axum::http::header::{AUTHORIZATION, HOST};
6use axum::http::request::Parts;
7use axum::http::{HeaderName, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::{STANDARD, URL_SAFE, URL_SAFE_NO_PAD};
10use serde_json::Value;
11use serde_json::json;
12
13use crate::runner::ServerState;
14use crate::runtime::TenantRuntime;
15
16#[derive(Clone)]
17pub struct RoutingConfig {
18    pub resolver: TenantResolver,
19    pub default_tenant: String,
20}
21
22impl RoutingConfig {
23    pub fn from_env() -> Self {
24        Self::from_env_with_default("demo".into())
25    }
26
27    pub fn from_env_with_default(default_tenant: String) -> Self {
28        let default_tenant = std::env::var("DEFAULT_TENANT").unwrap_or(default_tenant);
29        let resolver = std::env::var("TENANT_RESOLVER")
30            .map(|value| TenantResolver::from_str(&value, &default_tenant))
31            .unwrap_or(Ok(TenantResolver::Env))
32            .unwrap_or_else(|err| {
33                tracing::warn!(error = %err, "invalid TENANT_RESOLVER, falling back to env");
34                TenantResolver::Env
35            });
36        Self {
37            resolver,
38            default_tenant,
39        }
40    }
41}
42
43impl Default for RoutingConfig {
44    fn default() -> Self {
45        Self {
46            resolver: TenantResolver::Env,
47            default_tenant: "demo".into(),
48        }
49    }
50}
51
52#[derive(Clone)]
53pub enum TenantResolver {
54    Host,
55    Header(HeaderName),
56    Jwt { header: HeaderName, claim: String },
57    Env,
58}
59
60impl TenantResolver {
61    fn from_str(value: &str, _default: &str) -> Result<Self> {
62        match value.to_ascii_lowercase().as_str() {
63            "host" => Ok(Self::Host),
64            "header" => Ok(Self::Header(HeaderName::from_static("x-greentic-tenant"))),
65            "jwt" => Ok(Self::Jwt {
66                header: AUTHORIZATION,
67                claim: "tenant".into(),
68            }),
69            "env" => Ok(Self::Env),
70            other => bail!("unsupported TENANT_RESOLVER `{other}`"),
71        }
72    }
73}
74
75#[derive(Clone)]
76pub struct TenantRouting {
77    resolver: TenantResolver,
78    default_tenant: String,
79}
80
81impl TenantRouting {
82    pub fn new(cfg: RoutingConfig) -> Self {
83        Self {
84            resolver: cfg.resolver,
85            default_tenant: cfg.default_tenant,
86        }
87    }
88
89    pub fn resolve(&self, parts: &Parts) -> Result<String> {
90        match &self.resolver {
91            TenantResolver::Env => Ok(self.default_tenant.clone()),
92            TenantResolver::Host => {
93                let host = parts
94                    .headers
95                    .get(HOST)
96                    .and_then(|value| value.to_str().ok())
97                    .unwrap_or_default();
98                if host.is_empty() {
99                    return Ok(self.default_tenant.clone());
100                }
101                Ok(host
102                    .split('.')
103                    .next()
104                    .map(|segment| segment.to_string())
105                    .filter(|segment| !segment.is_empty())
106                    .unwrap_or_else(|| self.default_tenant.clone()))
107            }
108            TenantResolver::Header(name) => {
109                let tenant = parts
110                    .headers
111                    .get(name)
112                    .and_then(|value| value.to_str().ok())
113                    .filter(|value| !value.is_empty())
114                    .map(|value| value.to_string())
115                    .unwrap_or_else(|| self.default_tenant.clone());
116                Ok(tenant)
117            }
118            TenantResolver::Jwt { header, claim } => {
119                let token = parts
120                    .headers
121                    .get(header)
122                    .and_then(|value| value.to_str().ok())
123                    .and_then(|value| value.strip_prefix("Bearer "))
124                    .ok_or_else(|| anyhow!("authorization header missing"))?;
125                let tenant = decode_jwt_claim(token, claim)
126                    .unwrap_or_else(|err| {
127                        tracing::warn!(error = %err, "failed to decode jwt claim");
128                        None
129                    })
130                    .unwrap_or_else(|| self.default_tenant.clone());
131                Ok(tenant)
132            }
133        }
134    }
135}
136
137fn decode_jwt_claim(token: &str, claim: &str) -> Result<Option<String>> {
138    let payload = token
139        .split('.')
140        .nth(1)
141        .ok_or_else(|| anyhow!("invalid jwt structure"))?;
142    let bytes = URL_SAFE_NO_PAD.decode(payload.as_bytes()).or_else(|_| {
143        let padded = match payload.len() % 4 {
144            2 => Some(format!("{payload}==")),
145            3 => Some(format!("{payload}=")),
146            _ => None,
147        };
148        if let Some(padded) = padded.as_deref() {
149            URL_SAFE
150                .decode(padded.as_bytes())
151                .or_else(|_| STANDARD.decode(padded.as_bytes()))
152        } else {
153            URL_SAFE
154                .decode(payload.as_bytes())
155                .or_else(|_| STANDARD.decode(payload.as_bytes()))
156        }
157    })?;
158    let value: Value = serde_json::from_slice(&bytes)?;
159    Ok(value
160        .get(claim)
161        .and_then(|node| node.as_str())
162        .map(|value| value.to_string()))
163}
164
165pub struct TenantRuntimeHandle {
166    pub tenant: String,
167    pub runtime: Arc<TenantRuntime>,
168}
169
170impl<S> FromRequestParts<S> for TenantRuntimeHandle
171where
172    ServerState: FromRef<S>,
173    S: Send + Sync,
174{
175    type Rejection = (StatusCode, axum::Json<Value>);
176
177    fn from_request_parts(
178        parts: &mut Parts,
179        state: &S,
180    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
181        let server_state = ServerState::from_ref(state);
182        async move {
183            let tenant = server_state.routing.resolve(parts).map_err(|err| {
184                (
185                    StatusCode::BAD_REQUEST,
186                    axum::Json(json!({ "error": err.to_string() })),
187                )
188            })?;
189            let runtime = server_state.active.load_pack(&tenant).ok_or_else(|| {
190                (
191                    StatusCode::NOT_FOUND,
192                    axum::Json(json!({ "error": "tenant not loaded" })),
193                )
194            })?;
195            Ok(Self { tenant, runtime })
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Builds a router for `resolver` with `demo` as the fallback tenant.
    fn routing_with(resolver: TenantResolver) -> TenantRouting {
        TenantRouting::new(RoutingConfig {
            resolver,
            default_tenant: "demo".into(),
        })
    }

    /// Router using the JWT strategy on `Authorization` with the `tenant` claim.
    fn jwt_routing() -> TenantRouting {
        routing_with(TenantResolver::Jwt {
            header: AUTHORIZATION,
            claim: "tenant".into(),
        })
    }

    /// Finishes a request builder with an empty body and returns only its head.
    fn head(builder: axum::http::request::Builder) -> Parts {
        builder.body(()).unwrap().into_parts().0
    }

    #[test]
    fn host_resolver_picks_subdomain() {
        let parts = head(
            Request::builder()
                .uri("http://foo.example.com/webhook")
                .header(HOST, "foo.example.com"),
        );
        let tenant = routing_with(TenantResolver::Host).resolve(&parts).unwrap();
        assert_eq!(tenant, "foo");
    }

    #[test]
    fn header_resolver_defaults() {
        let routing = routing_with(TenantResolver::Header(HeaderName::from_static("x-tenant")));
        let parts = head(Request::builder().uri("http://localhost"));
        assert_eq!(routing.resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn from_env_with_default_uses_override() {
        let expected = std::env::var("DEFAULT_TENANT").unwrap_or_else(|_| "custom".into());
        let cfg = RoutingConfig::from_env_with_default("custom".into());
        assert_eq!(cfg.default_tenant, expected);
    }

    #[test]
    fn jwt_resolver_reads_tenant_claim() {
        let payload = STANDARD.encode(br#"{"tenant":"jwt-tenant"}"#);
        let bearer = format!("Bearer ignored.{payload}.ignored");
        let parts = head(Request::builder().header(AUTHORIZATION, bearer));
        assert_eq!(jwt_routing().resolve(&parts).unwrap(), "jwt-tenant");
    }

    #[test]
    fn jwt_resolver_falls_back_on_invalid_payload() {
        let parts = head(Request::builder().header(AUTHORIZATION, "Bearer invalid.token.payload"));
        assert_eq!(jwt_routing().resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn jwt_resolver_requires_bearer_prefix() {
        let parts = head(Request::builder());
        assert!(jwt_routing().resolve(&parts).is_err());
    }
}