greentic_runner_host/
routing.rs1use std::sync::Arc;
2
3use anyhow::{Result, anyhow, bail};
4use axum::extract::{FromRef, FromRequestParts};
5use axum::http::header::{AUTHORIZATION, HOST};
6use axum::http::request::Parts;
7use axum::http::{HeaderName, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use serde_json::Value;
11use serde_json::json;
12
13use crate::runner::ServerState;
14use crate::runtime::TenantRuntime;
15
16#[derive(Clone)]
17pub struct RoutingConfig {
18 pub resolver: TenantResolver,
19 pub default_tenant: String,
20}
21
22impl RoutingConfig {
23 pub fn from_env() -> Self {
24 Self::from_env_with_default("demo".into())
25 }
26
27 pub fn from_env_with_default(default_tenant: String) -> Self {
28 let default_tenant = std::env::var("DEFAULT_TENANT").unwrap_or(default_tenant);
29 let resolver = std::env::var("TENANT_RESOLVER")
30 .map(|value| TenantResolver::from_str(&value, &default_tenant))
31 .unwrap_or(Ok(TenantResolver::Env))
32 .unwrap_or_else(|err| {
33 tracing::warn!(error = %err, "invalid TENANT_RESOLVER, falling back to env");
34 TenantResolver::Env
35 });
36 Self {
37 resolver,
38 default_tenant,
39 }
40 }
41}
42
43impl Default for RoutingConfig {
44 fn default() -> Self {
45 Self {
46 resolver: TenantResolver::Env,
47 default_tenant: "demo".into(),
48 }
49 }
50}
51
52#[derive(Clone)]
53pub enum TenantResolver {
54 Host,
55 Header(HeaderName),
56 Jwt { header: HeaderName, claim: String },
57 Env,
58}
59
60impl TenantResolver {
61 fn from_str(value: &str, _default: &str) -> Result<Self> {
62 match value.to_ascii_lowercase().as_str() {
63 "host" => Ok(Self::Host),
64 "header" => Ok(Self::Header(HeaderName::from_static("x-greentic-tenant"))),
65 "jwt" => Ok(Self::Jwt {
66 header: AUTHORIZATION,
67 claim: "tenant".into(),
68 }),
69 "env" => Ok(Self::Env),
70 other => bail!("unsupported TENANT_RESOLVER `{other}`"),
71 }
72 }
73}
74
75#[derive(Clone)]
76pub struct TenantRouting {
77 resolver: TenantResolver,
78 default_tenant: String,
79}
80
81impl TenantRouting {
82 pub fn new(cfg: RoutingConfig) -> Self {
83 Self {
84 resolver: cfg.resolver,
85 default_tenant: cfg.default_tenant,
86 }
87 }
88
89 pub fn resolve(&self, parts: &Parts) -> Result<String> {
90 match &self.resolver {
91 TenantResolver::Env => Ok(self.default_tenant.clone()),
92 TenantResolver::Host => {
93 let host = parts
94 .headers
95 .get(HOST)
96 .and_then(|value| value.to_str().ok())
97 .unwrap_or_default();
98 if host.is_empty() {
99 return Ok(self.default_tenant.clone());
100 }
101 Ok(host
102 .split('.')
103 .next()
104 .map(|segment| segment.to_string())
105 .filter(|segment| !segment.is_empty())
106 .unwrap_or_else(|| self.default_tenant.clone()))
107 }
108 TenantResolver::Header(name) => {
109 let tenant = parts
110 .headers
111 .get(name)
112 .and_then(|value| value.to_str().ok())
113 .filter(|value| !value.is_empty())
114 .map(|value| value.to_string())
115 .unwrap_or_else(|| self.default_tenant.clone());
116 Ok(tenant)
117 }
118 TenantResolver::Jwt { header, claim } => {
119 let token = parts
120 .headers
121 .get(header)
122 .and_then(|value| value.to_str().ok())
123 .and_then(|value| value.strip_prefix("Bearer "))
124 .ok_or_else(|| anyhow!("authorization header missing"))?;
125 let tenant = decode_jwt_claim(token, claim)
126 .unwrap_or_else(|err| {
127 tracing::warn!(error = %err, "failed to decode jwt claim");
128 None
129 })
130 .unwrap_or_else(|| self.default_tenant.clone());
131 Ok(tenant)
132 }
133 }
134 }
135}
136
137fn decode_jwt_claim(token: &str, claim: &str) -> Result<Option<String>> {
138 let payload = token
139 .split('.')
140 .nth(1)
141 .ok_or_else(|| anyhow!("invalid jwt structure"))?;
142 let padded = match payload.len() % 4 {
143 2 => format!("{payload}=="),
144 3 => format!("{payload}="),
145 _ => payload.to_string(),
146 };
147 let bytes = STANDARD.decode(padded.as_bytes())?;
148 let value: Value = serde_json::from_slice(&bytes)?;
149 Ok(value
150 .get(claim)
151 .and_then(|node| node.as_str())
152 .map(|value| value.to_string()))
153}
154
155pub struct TenantRuntimeHandle {
156 pub tenant: String,
157 pub runtime: Arc<TenantRuntime>,
158}
159
160impl<S> FromRequestParts<S> for TenantRuntimeHandle
161where
162 ServerState: FromRef<S>,
163 S: Send + Sync,
164{
165 type Rejection = (StatusCode, axum::Json<Value>);
166
167 fn from_request_parts(
168 parts: &mut Parts,
169 state: &S,
170 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
171 let server_state = ServerState::from_ref(state);
172 async move {
173 let tenant = server_state.routing.resolve(parts).map_err(|err| {
174 (
175 StatusCode::BAD_REQUEST,
176 axum::Json(json!({ "error": err.to_string() })),
177 )
178 })?;
179 let runtime = server_state.active.load(&tenant).ok_or_else(|| {
180 (
181 StatusCode::NOT_FOUND,
182 axum::Json(json!({ "error": "tenant not loaded" })),
183 )
184 })?;
185 Ok(Self { tenant, runtime })
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Builds a router around `resolver` whose fallback tenant is `demo`.
    fn demo_routing(resolver: TenantResolver) -> TenantRouting {
        TenantRouting::new(RoutingConfig {
            resolver,
            default_tenant: String::from("demo"),
        })
    }

    #[test]
    fn host_resolver_picks_subdomain() {
        let request = Request::builder()
            .uri("http://foo.example.com/webhook")
            .header(HOST, "foo.example.com")
            .body(())
            .unwrap();
        let (parts, _body) = request.into_parts();

        let tenant = demo_routing(TenantResolver::Host).resolve(&parts).unwrap();
        assert_eq!(tenant, "foo");
    }

    #[test]
    fn header_resolver_defaults() {
        // No `x-tenant` header is sent, so the fallback tenant must be used.
        let request = Request::builder()
            .uri("http://localhost")
            .body(())
            .unwrap();
        let (parts, _body) = request.into_parts();

        let resolver = TenantResolver::Header(HeaderName::from_static("x-tenant"));
        let tenant = demo_routing(resolver).resolve(&parts).unwrap();
        assert_eq!(tenant, "demo");
    }

    #[test]
    fn from_env_with_default_uses_override() {
        // The ambient environment may define DEFAULT_TENANT; when it does,
        // that value wins over the supplied default.
        let expected = match std::env::var("DEFAULT_TENANT") {
            Ok(from_env) => from_env,
            Err(_) => String::from("custom"),
        };
        let cfg = RoutingConfig::from_env_with_default(String::from("custom"));
        assert_eq!(cfg.default_tenant, expected);
    }
}
232}