greentic_runner_host/
routing.rs

1use std::sync::Arc;
2
3use anyhow::{Result, anyhow, bail};
4use axum::extract::{FromRef, FromRequestParts};
5use axum::http::header::{AUTHORIZATION, HOST};
6use axum::http::request::Parts;
7use axum::http::{HeaderName, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use serde_json::Value;
11use serde_json::json;
12
13use crate::runner::ServerState;
14use crate::runtime::TenantRuntime;
15
16#[derive(Clone)]
17pub struct RoutingConfig {
18    pub resolver: TenantResolver,
19    pub default_tenant: String,
20}
21
22impl RoutingConfig {
23    pub fn from_env() -> Self {
24        let default_tenant = std::env::var("DEFAULT_TENANT").unwrap_or_else(|_| "demo".to_string());
25        let resolver = std::env::var("TENANT_RESOLVER")
26            .map(|value| TenantResolver::from_str(&value, &default_tenant))
27            .unwrap_or(Ok(TenantResolver::Env))
28            .unwrap_or_else(|err| {
29                tracing::warn!(error = %err, "invalid TENANT_RESOLVER, falling back to env");
30                TenantResolver::Env
31            });
32        Self {
33            resolver,
34            default_tenant,
35        }
36    }
37}
38
39#[derive(Clone)]
40pub enum TenantResolver {
41    Host,
42    Header(HeaderName),
43    Jwt { header: HeaderName, claim: String },
44    Env,
45}
46
47impl TenantResolver {
48    fn from_str(value: &str, _default: &str) -> Result<Self> {
49        match value.to_ascii_lowercase().as_str() {
50            "host" => Ok(Self::Host),
51            "header" => Ok(Self::Header(HeaderName::from_static("x-greentic-tenant"))),
52            "jwt" => Ok(Self::Jwt {
53                header: AUTHORIZATION,
54                claim: "tenant".into(),
55            }),
56            "env" => Ok(Self::Env),
57            other => bail!("unsupported TENANT_RESOLVER `{other}`"),
58        }
59    }
60}
61
62#[derive(Clone)]
63pub struct TenantRouting {
64    resolver: TenantResolver,
65    default_tenant: String,
66}
67
68impl TenantRouting {
69    pub fn new(cfg: RoutingConfig) -> Self {
70        Self {
71            resolver: cfg.resolver,
72            default_tenant: cfg.default_tenant,
73        }
74    }
75
76    pub fn resolve(&self, parts: &Parts) -> Result<String> {
77        match &self.resolver {
78            TenantResolver::Env => Ok(self.default_tenant.clone()),
79            TenantResolver::Host => {
80                let host = parts
81                    .headers
82                    .get(HOST)
83                    .and_then(|value| value.to_str().ok())
84                    .unwrap_or_default();
85                if host.is_empty() {
86                    return Ok(self.default_tenant.clone());
87                }
88                Ok(host
89                    .split('.')
90                    .next()
91                    .map(|segment| segment.to_string())
92                    .filter(|segment| !segment.is_empty())
93                    .unwrap_or_else(|| self.default_tenant.clone()))
94            }
95            TenantResolver::Header(name) => {
96                let tenant = parts
97                    .headers
98                    .get(name)
99                    .and_then(|value| value.to_str().ok())
100                    .filter(|value| !value.is_empty())
101                    .map(|value| value.to_string())
102                    .unwrap_or_else(|| self.default_tenant.clone());
103                Ok(tenant)
104            }
105            TenantResolver::Jwt { header, claim } => {
106                let token = parts
107                    .headers
108                    .get(header)
109                    .and_then(|value| value.to_str().ok())
110                    .and_then(|value| value.strip_prefix("Bearer "))
111                    .ok_or_else(|| anyhow!("authorization header missing"))?;
112                let tenant = decode_jwt_claim(token, claim)
113                    .unwrap_or_else(|err| {
114                        tracing::warn!(error = %err, "failed to decode jwt claim");
115                        None
116                    })
117                    .unwrap_or_else(|| self.default_tenant.clone());
118                Ok(tenant)
119            }
120        }
121    }
122}
123
124fn decode_jwt_claim(token: &str, claim: &str) -> Result<Option<String>> {
125    let payload = token
126        .split('.')
127        .nth(1)
128        .ok_or_else(|| anyhow!("invalid jwt structure"))?;
129    let padded = match payload.len() % 4 {
130        2 => format!("{payload}=="),
131        3 => format!("{payload}="),
132        _ => payload.to_string(),
133    };
134    let bytes = STANDARD.decode(padded.as_bytes())?;
135    let value: Value = serde_json::from_slice(&bytes)?;
136    Ok(value
137        .get(claim)
138        .and_then(|node| node.as_str())
139        .map(|value| value.to_string()))
140}
141
142pub struct TenantRuntimeHandle {
143    pub tenant: String,
144    pub runtime: Arc<TenantRuntime>,
145}
146
147impl<S> FromRequestParts<S> for TenantRuntimeHandle
148where
149    ServerState: FromRef<S>,
150    S: Send + Sync,
151{
152    type Rejection = (StatusCode, axum::Json<Value>);
153
154    fn from_request_parts(
155        parts: &mut Parts,
156        state: &S,
157    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
158        let server_state = ServerState::from_ref(state);
159        async move {
160            let tenant = server_state.routing.resolve(parts).map_err(|err| {
161                (
162                    StatusCode::BAD_REQUEST,
163                    axum::Json(json!({ "error": err.to_string() })),
164                )
165            })?;
166            let runtime = server_state.active.load(&tenant).ok_or_else(|| {
167                (
168                    StatusCode::NOT_FOUND,
169                    axum::Json(json!({ "error": "tenant not loaded" })),
170                )
171            })?;
172            Ok(Self { tenant, runtime })
173        }
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Unsigned token whose payload segment decodes to `{"tenant":"acme"}`.
    const ACME_TOKEN: &str = "eyJhbGciOiJub25lIn0.eyJ0ZW5hbnQiOiJhY21lIn0.sig";

    /// Builds a router whose default tenant is `demo`.
    fn routing(resolver: TenantResolver) -> TenantRouting {
        TenantRouting::new(RoutingConfig {
            resolver,
            default_tenant: "demo".into(),
        })
    }

    /// Builds request parts for `uri` carrying the given headers.
    fn request_parts(uri: &str, headers: &[(HeaderName, &str)]) -> Parts {
        let mut builder = Request::builder().uri(uri);
        for (name, value) in headers {
            builder = builder.header(name.clone(), *value);
        }
        builder.body(()).unwrap().into_parts().0
    }

    fn jwt_resolver() -> TenantResolver {
        TenantResolver::Jwt {
            header: AUTHORIZATION,
            claim: "tenant".into(),
        }
    }

    #[test]
    fn host_resolver_picks_subdomain() {
        let parts = request_parts(
            "http://foo.example.com/webhook",
            &[(HOST, "foo.example.com")],
        );
        let router = routing(TenantResolver::Host);
        assert_eq!(router.resolve(&parts).unwrap(), "foo");
    }

    #[test]
    fn header_resolver_defaults() {
        let parts = request_parts("http://localhost", &[]);
        let router = routing(TenantResolver::Header(HeaderName::from_static("x-tenant")));
        assert_eq!(router.resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn header_resolver_reads_header() {
        let name = HeaderName::from_static("x-tenant");
        let parts = request_parts("http://localhost", &[(name.clone(), "acme")]);
        let router = routing(TenantResolver::Header(name));
        assert_eq!(router.resolve(&parts).unwrap(), "acme");
    }

    #[test]
    fn env_resolver_ignores_request() {
        let parts = request_parts("http://foo.example.com/", &[(HOST, "foo.example.com")]);
        let router = routing(TenantResolver::Env);
        assert_eq!(router.resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn jwt_resolver_reads_claim() {
        let bearer = format!("Bearer {ACME_TOKEN}");
        let parts = request_parts("http://localhost", &[(AUTHORIZATION, bearer.as_str())]);
        let router = routing(jwt_resolver());
        assert_eq!(router.resolve(&parts).unwrap(), "acme");
    }

    #[test]
    fn jwt_resolver_rejects_missing_header() {
        let parts = request_parts("http://localhost", &[]);
        let router = routing(jwt_resolver());
        assert!(router.resolve(&parts).is_err());
    }
}