greentic_runner_host/
routing.rs1use std::sync::Arc;
2
3use anyhow::{Result, anyhow, bail};
4use axum::extract::{FromRef, FromRequestParts};
5use axum::http::header::{AUTHORIZATION, HOST};
6use axum::http::request::Parts;
7use axum::http::{HeaderName, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use serde_json::Value;
11use serde_json::json;
12
13use crate::runner::ServerState;
14use crate::runtime::TenantRuntime;
15
16#[derive(Clone)]
17pub struct RoutingConfig {
18 pub resolver: TenantResolver,
19 pub default_tenant: String,
20}
21
22impl RoutingConfig {
23 pub fn from_env() -> Self {
24 let default_tenant = std::env::var("DEFAULT_TENANT").unwrap_or_else(|_| "demo".to_string());
25 let resolver = std::env::var("TENANT_RESOLVER")
26 .map(|value| TenantResolver::from_str(&value, &default_tenant))
27 .unwrap_or(Ok(TenantResolver::Env))
28 .unwrap_or_else(|err| {
29 tracing::warn!(error = %err, "invalid TENANT_RESOLVER, falling back to env");
30 TenantResolver::Env
31 });
32 Self {
33 resolver,
34 default_tenant,
35 }
36 }
37}
38
39#[derive(Clone)]
40pub enum TenantResolver {
41 Host,
42 Header(HeaderName),
43 Jwt { header: HeaderName, claim: String },
44 Env,
45}
46
47impl TenantResolver {
48 fn from_str(value: &str, _default: &str) -> Result<Self> {
49 match value.to_ascii_lowercase().as_str() {
50 "host" => Ok(Self::Host),
51 "header" => Ok(Self::Header(HeaderName::from_static("x-greentic-tenant"))),
52 "jwt" => Ok(Self::Jwt {
53 header: AUTHORIZATION,
54 claim: "tenant".into(),
55 }),
56 "env" => Ok(Self::Env),
57 other => bail!("unsupported TENANT_RESOLVER `{other}`"),
58 }
59 }
60}
61
62#[derive(Clone)]
63pub struct TenantRouting {
64 resolver: TenantResolver,
65 default_tenant: String,
66}
67
68impl TenantRouting {
69 pub fn new(cfg: RoutingConfig) -> Self {
70 Self {
71 resolver: cfg.resolver,
72 default_tenant: cfg.default_tenant,
73 }
74 }
75
76 pub fn resolve(&self, parts: &Parts) -> Result<String> {
77 match &self.resolver {
78 TenantResolver::Env => Ok(self.default_tenant.clone()),
79 TenantResolver::Host => {
80 let host = parts
81 .headers
82 .get(HOST)
83 .and_then(|value| value.to_str().ok())
84 .unwrap_or_default();
85 if host.is_empty() {
86 return Ok(self.default_tenant.clone());
87 }
88 Ok(host
89 .split('.')
90 .next()
91 .map(|segment| segment.to_string())
92 .filter(|segment| !segment.is_empty())
93 .unwrap_or_else(|| self.default_tenant.clone()))
94 }
95 TenantResolver::Header(name) => {
96 let tenant = parts
97 .headers
98 .get(name)
99 .and_then(|value| value.to_str().ok())
100 .filter(|value| !value.is_empty())
101 .map(|value| value.to_string())
102 .unwrap_or_else(|| self.default_tenant.clone());
103 Ok(tenant)
104 }
105 TenantResolver::Jwt { header, claim } => {
106 let token = parts
107 .headers
108 .get(header)
109 .and_then(|value| value.to_str().ok())
110 .and_then(|value| value.strip_prefix("Bearer "))
111 .ok_or_else(|| anyhow!("authorization header missing"))?;
112 let tenant = decode_jwt_claim(token, claim)
113 .unwrap_or_else(|err| {
114 tracing::warn!(error = %err, "failed to decode jwt claim");
115 None
116 })
117 .unwrap_or_else(|| self.default_tenant.clone());
118 Ok(tenant)
119 }
120 }
121 }
122}
123
124fn decode_jwt_claim(token: &str, claim: &str) -> Result<Option<String>> {
125 let payload = token
126 .split('.')
127 .nth(1)
128 .ok_or_else(|| anyhow!("invalid jwt structure"))?;
129 let padded = match payload.len() % 4 {
130 2 => format!("{payload}=="),
131 3 => format!("{payload}="),
132 _ => payload.to_string(),
133 };
134 let bytes = STANDARD.decode(padded.as_bytes())?;
135 let value: Value = serde_json::from_slice(&bytes)?;
136 Ok(value
137 .get(claim)
138 .and_then(|node| node.as_str())
139 .map(|value| value.to_string()))
140}
141
142pub struct TenantRuntimeHandle {
143 pub tenant: String,
144 pub runtime: Arc<TenantRuntime>,
145}
146
147impl<S> FromRequestParts<S> for TenantRuntimeHandle
148where
149 ServerState: FromRef<S>,
150 S: Send + Sync,
151{
152 type Rejection = (StatusCode, axum::Json<Value>);
153
154 fn from_request_parts(
155 parts: &mut Parts,
156 state: &S,
157 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
158 let server_state = ServerState::from_ref(state);
159 async move {
160 let tenant = server_state.routing.resolve(parts).map_err(|err| {
161 (
162 StatusCode::BAD_REQUEST,
163 axum::Json(json!({ "error": err.to_string() })),
164 )
165 })?;
166 let runtime = server_state.active.load(&tenant).ok_or_else(|| {
167 (
168 StatusCode::NOT_FOUND,
169 axum::Json(json!({ "error": "tenant not loaded" })),
170 )
171 })?;
172 Ok(Self { tenant, runtime })
173 }
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Builds a router using `resolver`, with `demo` as the fallback tenant.
    fn routing_with(resolver: TenantResolver) -> TenantRouting {
        TenantRouting::new(RoutingConfig {
            resolver,
            default_tenant: "demo".into(),
        })
    }

    #[test]
    fn host_resolver_picks_subdomain() {
        let request = Request::builder()
            .uri("http://foo.example.com/webhook")
            .header(HOST, "foo.example.com")
            .body(())
            .unwrap();
        let (parts, _body) = request.into_parts();

        let resolved = routing_with(TenantResolver::Host).resolve(&parts).unwrap();

        assert_eq!(resolved, "foo");
    }

    #[test]
    fn header_resolver_defaults() {
        // The request carries no `x-tenant` header, so the fallback tenant is expected.
        let request = Request::builder()
            .uri("http://localhost")
            .body(())
            .unwrap();
        let (parts, _body) = request.into_parts();

        let resolver = TenantResolver::Header(HeaderName::from_static("x-tenant"));
        let resolved = routing_with(resolver).resolve(&parts).unwrap();

        assert_eq!(resolved, "demo");
    }
}