greentic_runner_host/
routing.rs1use std::sync::Arc;
2
3use anyhow::{Result, anyhow, bail};
4use axum::extract::{FromRef, FromRequestParts};
5use axum::http::header::{AUTHORIZATION, HOST};
6use axum::http::request::Parts;
7use axum::http::{HeaderName, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::{STANDARD, URL_SAFE, URL_SAFE_NO_PAD};
10use serde_json::Value;
11use serde_json::json;
12
13use crate::runner::ServerState;
14use crate::runtime::TenantRuntime;
15
16#[derive(Clone)]
17pub struct RoutingConfig {
18 pub resolver: TenantResolver,
19 pub default_tenant: String,
20}
21
22impl RoutingConfig {
23 pub fn from_env() -> Self {
24 Self::from_env_with_default("demo".into())
25 }
26
27 pub fn from_env_with_default(default_tenant: String) -> Self {
28 let default_tenant = std::env::var("DEFAULT_TENANT").unwrap_or(default_tenant);
29 let resolver = std::env::var("TENANT_RESOLVER")
30 .map(|value| TenantResolver::from_str(&value, &default_tenant))
31 .unwrap_or(Ok(TenantResolver::Env))
32 .unwrap_or_else(|err| {
33 tracing::warn!(error = %err, "invalid TENANT_RESOLVER, falling back to env");
34 TenantResolver::Env
35 });
36 Self {
37 resolver,
38 default_tenant,
39 }
40 }
41}
42
43impl Default for RoutingConfig {
44 fn default() -> Self {
45 Self {
46 resolver: TenantResolver::Env,
47 default_tenant: "demo".into(),
48 }
49 }
50}
51
52#[derive(Clone)]
53pub enum TenantResolver {
54 Host,
55 Header(HeaderName),
56 Jwt { header: HeaderName, claim: String },
57 Env,
58}
59
60impl TenantResolver {
61 fn from_str(value: &str, _default: &str) -> Result<Self> {
62 match value.to_ascii_lowercase().as_str() {
63 "host" => Ok(Self::Host),
64 "header" => Ok(Self::Header(HeaderName::from_static("x-greentic-tenant"))),
65 "jwt" => Ok(Self::Jwt {
66 header: AUTHORIZATION,
67 claim: "tenant".into(),
68 }),
69 "env" => Ok(Self::Env),
70 other => bail!("unsupported TENANT_RESOLVER `{other}`"),
71 }
72 }
73}
74
75#[derive(Clone)]
76pub struct TenantRouting {
77 resolver: TenantResolver,
78 default_tenant: String,
79}
80
81impl TenantRouting {
82 pub fn new(cfg: RoutingConfig) -> Self {
83 Self {
84 resolver: cfg.resolver,
85 default_tenant: cfg.default_tenant,
86 }
87 }
88
89 pub fn resolve(&self, parts: &Parts) -> Result<String> {
90 match &self.resolver {
91 TenantResolver::Env => Ok(self.default_tenant.clone()),
92 TenantResolver::Host => {
93 let host = parts
94 .headers
95 .get(HOST)
96 .and_then(|value| value.to_str().ok())
97 .unwrap_or_default();
98 if host.is_empty() {
99 return Ok(self.default_tenant.clone());
100 }
101 Ok(host
102 .split('.')
103 .next()
104 .map(|segment| segment.to_string())
105 .filter(|segment| !segment.is_empty())
106 .unwrap_or_else(|| self.default_tenant.clone()))
107 }
108 TenantResolver::Header(name) => {
109 let tenant = parts
110 .headers
111 .get(name)
112 .and_then(|value| value.to_str().ok())
113 .filter(|value| !value.is_empty())
114 .map(|value| value.to_string())
115 .unwrap_or_else(|| self.default_tenant.clone());
116 Ok(tenant)
117 }
118 TenantResolver::Jwt { header, claim } => {
119 let token = parts
120 .headers
121 .get(header)
122 .and_then(|value| value.to_str().ok())
123 .and_then(|value| value.strip_prefix("Bearer "))
124 .ok_or_else(|| anyhow!("authorization header missing"))?;
125 let tenant = decode_jwt_claim(token, claim)
126 .unwrap_or_else(|err| {
127 tracing::warn!(error = %err, "failed to decode jwt claim");
128 None
129 })
130 .unwrap_or_else(|| self.default_tenant.clone());
131 Ok(tenant)
132 }
133 }
134 }
135}
136
137fn decode_jwt_claim(token: &str, claim: &str) -> Result<Option<String>> {
138 let payload = token
139 .split('.')
140 .nth(1)
141 .ok_or_else(|| anyhow!("invalid jwt structure"))?;
142 let bytes = URL_SAFE_NO_PAD.decode(payload.as_bytes()).or_else(|_| {
143 let padded = match payload.len() % 4 {
144 2 => Some(format!("{payload}==")),
145 3 => Some(format!("{payload}=")),
146 _ => None,
147 };
148 if let Some(padded) = padded.as_deref() {
149 URL_SAFE
150 .decode(padded.as_bytes())
151 .or_else(|_| STANDARD.decode(padded.as_bytes()))
152 } else {
153 URL_SAFE
154 .decode(payload.as_bytes())
155 .or_else(|_| STANDARD.decode(payload.as_bytes()))
156 }
157 })?;
158 let value: Value = serde_json::from_slice(&bytes)?;
159 Ok(value
160 .get(claim)
161 .and_then(|node| node.as_str())
162 .map(|value| value.to_string()))
163}
164
165pub struct TenantRuntimeHandle {
166 pub tenant: String,
167 pub runtime: Arc<TenantRuntime>,
168}
169
170impl<S> FromRequestParts<S> for TenantRuntimeHandle
171where
172 ServerState: FromRef<S>,
173 S: Send + Sync,
174{
175 type Rejection = (StatusCode, axum::Json<Value>);
176
177 fn from_request_parts(
178 parts: &mut Parts,
179 state: &S,
180 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
181 let server_state = ServerState::from_ref(state);
182 async move {
183 let tenant = server_state.routing.resolve(parts).map_err(|err| {
184 (
185 StatusCode::BAD_REQUEST,
186 axum::Json(json!({ "error": err.to_string() })),
187 )
188 })?;
189 let runtime = server_state.active.load_pack(&tenant).ok_or_else(|| {
190 (
191 StatusCode::NOT_FOUND,
192 axum::Json(json!({ "error": "tenant not loaded" })),
193 )
194 })?;
195 Ok(Self { tenant, runtime })
196 }
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Router using `resolver` with `"demo"` as the default tenant.
    fn routing(resolver: TenantResolver) -> TenantRouting {
        TenantRouting::new(RoutingConfig {
            resolver,
            default_tenant: "demo".into(),
        })
    }

    /// Router reading the `tenant` claim from the `Authorization` bearer token.
    fn jwt_routing() -> TenantRouting {
        routing(TenantResolver::Jwt {
            header: AUTHORIZATION,
            claim: "tenant".into(),
        })
    }

    /// Request parts carrying `Authorization: Bearer ignored.<payload>.ignored`.
    fn bearer_parts(payload: &str) -> Parts {
        Request::builder()
            .header(AUTHORIZATION, format!("Bearer ignored.{payload}.ignored"))
            .body(())
            .unwrap()
            .into_parts()
            .0
    }

    #[test]
    fn host_resolver_picks_subdomain() {
        let (parts, _) = Request::builder()
            .uri("http://foo.example.com/webhook")
            .header(HOST, "foo.example.com")
            .body(())
            .unwrap()
            .into_parts();
        assert_eq!(routing(TenantResolver::Host).resolve(&parts).unwrap(), "foo");
    }

    #[test]
    fn host_resolver_ignores_port() {
        let (parts, _) = Request::builder()
            .header(HOST, "foo.example.com:8080")
            .body(())
            .unwrap()
            .into_parts();
        assert_eq!(routing(TenantResolver::Host).resolve(&parts).unwrap(), "foo");
    }

    #[test]
    fn host_resolver_defaults_without_host() {
        let (parts, _) = Request::builder().body(()).unwrap().into_parts();
        assert_eq!(routing(TenantResolver::Host).resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn header_resolver_defaults() {
        let (parts, _) = Request::builder()
            .uri("http://localhost")
            .body(())
            .unwrap()
            .into_parts();
        let router = routing(TenantResolver::Header(HeaderName::from_static("x-tenant")));
        assert_eq!(router.resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn header_resolver_reads_header() {
        let (parts, _) = Request::builder()
            .header("x-tenant", "acme")
            .body(())
            .unwrap()
            .into_parts();
        let router = routing(TenantResolver::Header(HeaderName::from_static("x-tenant")));
        assert_eq!(router.resolve(&parts).unwrap(), "acme");
    }

    #[test]
    fn header_resolver_defaults_on_empty_value() {
        let (parts, _) = Request::builder()
            .header("x-tenant", "")
            .body(())
            .unwrap()
            .into_parts();
        let router = routing(TenantResolver::Header(HeaderName::from_static("x-tenant")));
        assert_eq!(router.resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn env_resolver_ignores_request() {
        let (parts, _) = Request::builder()
            .header(HOST, "foo.example.com")
            .header("x-greentic-tenant", "acme")
            .body(())
            .unwrap()
            .into_parts();
        assert_eq!(routing(TenantResolver::Env).resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn resolver_names_are_case_insensitive() {
        assert!(matches!(
            TenantResolver::from_str("HOST", "demo"),
            Ok(TenantResolver::Host)
        ));
        assert!(matches!(
            TenantResolver::from_str("Env", "demo"),
            Ok(TenantResolver::Env)
        ));
        assert!(TenantResolver::from_str("bogus", "demo").is_err());
    }

    #[test]
    fn from_env_with_default_uses_override() {
        let expected = std::env::var("DEFAULT_TENANT").unwrap_or_else(|_| "custom".into());
        let cfg = RoutingConfig::from_env_with_default("custom".into());
        assert_eq!(cfg.default_tenant, expected);
    }

    #[test]
    fn jwt_resolver_reads_tenant_claim() {
        let payload = STANDARD.encode(br#"{"tenant":"jwt-tenant"}"#);
        assert_eq!(
            jwt_routing().resolve(&bearer_parts(&payload)).unwrap(),
            "jwt-tenant"
        );
    }

    #[test]
    fn jwt_resolver_reads_url_safe_unpadded_payload() {
        // Real JWTs encode the payload as unpadded URL-safe base64, the primary decode path.
        let payload = URL_SAFE_NO_PAD.encode(br#"{"tenant":"acme"}"#);
        assert_eq!(
            jwt_routing().resolve(&bearer_parts(&payload)).unwrap(),
            "acme"
        );
    }

    #[test]
    fn jwt_resolver_defaults_when_claim_missing() {
        let payload = URL_SAFE_NO_PAD.encode(br#"{"sub":"user"}"#);
        assert_eq!(
            jwt_routing().resolve(&bearer_parts(&payload)).unwrap(),
            "demo"
        );
    }

    #[test]
    fn jwt_resolver_falls_back_on_invalid_payload() {
        let (parts, _) = Request::builder()
            .header(AUTHORIZATION, "Bearer invalid.token.payload")
            .body(())
            .unwrap()
            .into_parts();
        assert_eq!(jwt_routing().resolve(&parts).unwrap(), "demo");
    }

    #[test]
    fn jwt_resolver_errors_without_authorization_header() {
        let (parts, _) = Request::builder().body(()).unwrap().into_parts();
        assert!(jwt_routing().resolve(&parts).is_err());
    }

    #[test]
    fn jwt_resolver_requires_bearer_prefix() {
        let (parts, _) = Request::builder()
            .header(AUTHORIZATION, "Basic dXNlcjpwYXNz")
            .body(())
            .unwrap()
            .into_parts();
        assert!(jwt_routing().resolve(&parts).is_err());
    }
}