a2a_protocol_server/
tenant_resolver.rs1use std::future::Future;
33use std::pin::Pin;
34use std::sync::Arc;
35
36use crate::call_context::CallContext;
37
38pub trait TenantResolver: Send + Sync + 'static {
54 fn resolve<'a>(
58 &'a self,
59 ctx: &'a CallContext,
60 ) -> Pin<Box<dyn Future<Output = Option<String>> + Send + 'a>>;
61}
62
/// Tenant resolver configured with the name of a single HTTP header.
#[derive(Debug, Clone)]
pub struct HeaderTenantResolver {
    /// Header name to look up; always stored in ASCII lowercase.
    header_name: String,
}

impl HeaderTenantResolver {
    /// Creates a resolver for the header called `header_name`.
    ///
    /// ASCII letters in the name are lowercased before it is stored, so
    /// `"X-Org-Id"` and `"x-org-id"` produce equivalent resolvers. Non-ASCII
    /// characters are left untouched.
    #[must_use]
    pub fn new(header_name: impl Into<String>) -> Self {
        let mut normalized: String = header_name.into();
        normalized.make_ascii_lowercase();
        Self {
            header_name: normalized,
        }
    }
}

impl Default for HeaderTenantResolver {
    /// Builds a resolver for the `x-tenant-id` header.
    fn default() -> Self {
        // The literal is already lowercase, so no normalization is required.
        Self {
            header_name: String::from("x-tenant-id"),
        }
    }
}
103
104impl TenantResolver for HeaderTenantResolver {
105 fn resolve<'a>(
106 &'a self,
107 ctx: &'a CallContext,
108 ) -> Pin<Box<dyn Future<Output = Option<String>> + Send + 'a>> {
109 Box::pin(async move { ctx.http_headers().get(&self.header_name).cloned() })
110 }
111}
112
/// Shareable, thread-safe callback that turns a bearer token into an optional
/// tenant identifier.
type TokenMapper = Arc<dyn Fn(&str) -> Option<String> + Send + Sync + 'static>;

/// Tenant resolver driven by the bearer token of the `authorization` header.
pub struct BearerTokenTenantResolver {
    /// Optional token-to-tenant mapping. When it is `None`, the token itself
    /// is returned as the tenant identifier.
    mapper: Option<TokenMapper>,
}

impl std::fmt::Debug for BearerTokenTenantResolver {
    /// Formats the resolver without trying to print the mapper closure; only
    /// a `has_mapper` flag is emitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_mapper = self.mapper.is_some();
        let mut repr = f.debug_struct("BearerTokenTenantResolver");
        repr.field("has_mapper", &has_mapper);
        repr.finish()
    }
}

impl BearerTokenTenantResolver {
    /// Creates a resolver with no mapper installed.
    #[must_use]
    pub fn new() -> Self {
        Self { mapper: None }
    }

    /// Creates a resolver that passes each token through `mapper`.
    ///
    /// The closure receives the token and returns the tenant identifier, or
    /// `None` when the token does not correspond to a tenant.
    #[must_use]
    pub fn with_mapper<F>(mapper: F) -> Self
    where
        F: Fn(&str) -> Option<String> + Send + Sync + 'static,
    {
        let shared: TokenMapper = Arc::new(mapper);
        Self {
            mapper: Some(shared),
        }
    }
}

impl Default for BearerTokenTenantResolver {
    /// Equivalent to [`BearerTokenTenantResolver::new`]: no mapper installed.
    fn default() -> Self {
        Self { mapper: None }
    }
}
178
179impl TenantResolver for BearerTokenTenantResolver {
180 fn resolve<'a>(
181 &'a self,
182 ctx: &'a CallContext,
183 ) -> Pin<Box<dyn Future<Output = Option<String>> + Send + 'a>> {
184 Box::pin(async move {
185 let auth = ctx.http_headers().get("authorization")?;
186 let token = auth
187 .strip_prefix("Bearer ")
188 .or_else(|| auth.strip_prefix("bearer "))?;
189
190 if token.is_empty() {
191 return None;
192 }
193
194 self.mapper
195 .as_ref()
196 .map_or_else(|| Some(token.to_owned()), |mapper| mapper(token))
197 })
198 }
199}
200
/// Tenant resolver that picks one segment out of the request path.
#[derive(Debug, Clone)]
pub struct PathSegmentTenantResolver {
    /// Zero-based index of the segment to use, counted over the non-empty
    /// `/`-separated segments of the path.
    segment_index: usize,
}

impl PathSegmentTenantResolver {
    /// Creates a resolver that uses the segment at `segment_index`
    /// (zero-based, empty segments are not counted).
    #[must_use]
    pub const fn new(segment_index: usize) -> Self {
        Self { segment_index }
    }

    /// Returns the configured segment of `path`, or `None` when the path has
    /// too few segments.
    ///
    /// Everything from the first `?` or `#` onwards is discarded before the
    /// path is split. A request target may carry a query string, and neither
    /// the query text nor slashes inside it may be mistaken for path
    /// segments. Empty segments produced by leading, trailing or doubled
    /// slashes are skipped.
    fn extract_from_path(&self, path: &str) -> Option<String> {
        // `find` returns a byte offset on a char boundary, so slicing is safe.
        let path_only = path
            .find(|c: char| c == '?' || c == '#')
            .map_or(path, |end| &path[..end]);

        path_only
            .split('/')
            .filter(|segment| !segment.is_empty())
            .nth(self.segment_index)
            .map(str::to_owned)
    }
}
249
250impl TenantResolver for PathSegmentTenantResolver {
251 fn resolve<'a>(
252 &'a self,
253 ctx: &'a CallContext,
254 ) -> Pin<Box<dyn Future<Output = Option<String>> + Send + 'a>> {
255 Box::pin(async move {
256 let path = ctx
258 .http_headers()
259 .get(":path")
260 .or_else(|| ctx.http_headers().get("path"))?;
261 self.extract_from_path(path)
262 })
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a call context for the `message/send` method with no headers.
    fn make_ctx() -> CallContext {
        CallContext::new("message/send")
    }

    // ── HeaderTenantResolver ────────────────────────────────────────────

    #[tokio::test]
    async fn header_resolver_default_header() {
        let ctx = make_ctx().with_http_header("x-tenant-id", "acme");
        let resolver = HeaderTenantResolver::default();
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant.as_deref(), Some("acme"));
    }

    #[tokio::test]
    async fn header_resolver_custom_header() {
        // Mixed-case configuration must match the lowercase header entry.
        let ctx = make_ctx().with_http_header("x-org-id", "org-42");
        let resolver = HeaderTenantResolver::new("X-Org-Id");
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant.as_deref(), Some("org-42"));
    }

    #[tokio::test]
    async fn header_resolver_missing_header() {
        let ctx = make_ctx();
        let resolver = HeaderTenantResolver::default();
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant, None);
    }

    // ── BearerTokenTenantResolver ───────────────────────────────────────

    #[tokio::test]
    async fn bearer_resolver_raw_token() {
        let ctx = make_ctx().with_http_header("authorization", "Bearer tok_abc123");
        let resolver = BearerTokenTenantResolver::new();
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant.as_deref(), Some("tok_abc123"));
    }

    #[tokio::test]
    async fn bearer_resolver_with_mapper() {
        let ctx = make_ctx().with_http_header("authorization", "Bearer tok_abc");
        let resolver = BearerTokenTenantResolver::with_mapper(|token| {
            token.strip_prefix("tok_").map(|rest| rest.to_uppercase())
        });
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant.as_deref(), Some("ABC"));
    }

    #[tokio::test]
    async fn bearer_resolver_mapper_returns_none() {
        let ctx = make_ctx().with_http_header("authorization", "Bearer tok");
        let resolver = BearerTokenTenantResolver::with_mapper(|_| None);
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant, None);
    }

    #[tokio::test]
    async fn bearer_resolver_missing_header() {
        let ctx = make_ctx();
        let resolver = BearerTokenTenantResolver::new();
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant, None);
    }

    #[tokio::test]
    async fn bearer_resolver_non_bearer_auth() {
        let ctx = make_ctx().with_http_header("authorization", "Basic abc123");
        let resolver = BearerTokenTenantResolver::new();
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant, None);
    }

    #[tokio::test]
    async fn bearer_resolver_empty_token() {
        let ctx = make_ctx().with_http_header("authorization", "Bearer ");
        let resolver = BearerTokenTenantResolver::new();
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant, None);
    }

    // ── PathSegmentTenantResolver ───────────────────────────────────────

    #[tokio::test]
    async fn path_resolver_extracts_segment() {
        let ctx = make_ctx().with_http_header("path", "/tenants/acme/tasks");
        let resolver = PathSegmentTenantResolver::new(1);
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant.as_deref(), Some("acme"));
    }

    #[tokio::test]
    async fn path_resolver_first_segment() {
        let ctx = make_ctx().with_http_header("path", "/v1/agents");
        let resolver = PathSegmentTenantResolver::new(0);
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant.as_deref(), Some("v1"));
    }

    #[tokio::test]
    async fn path_resolver_out_of_bounds() {
        let ctx = make_ctx().with_http_header("path", "/a/b");
        let resolver = PathSegmentTenantResolver::new(10);
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant, None);
    }

    #[tokio::test]
    async fn path_resolver_prefers_pseudo_header() {
        let ctx = make_ctx()
            .with_http_header(":path", "/h2-tenant/foo")
            .with_http_header("path", "/fallback/bar");
        let resolver = PathSegmentTenantResolver::new(0);
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant.as_deref(), Some("h2-tenant"));
    }

    #[tokio::test]
    async fn path_resolver_missing_path() {
        let ctx = make_ctx();
        let resolver = PathSegmentTenantResolver::new(0);
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant, None);
    }

    // ── Additional coverage ─────────────────────────────────────────────

    #[tokio::test]
    async fn bearer_resolver_default_same_as_new() {
        let ctx = make_ctx().with_http_header("authorization", "Bearer test-token");
        let resolver = BearerTokenTenantResolver::default();
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(
            tenant,
            Some(String::from("test-token")),
            "default() should behave the same as new()"
        );
    }

    #[tokio::test]
    async fn path_resolver_uses_fallback_path_header() {
        // Only the plain `path` entry is present, so the fallback is used.
        let ctx = make_ctx().with_http_header("path", "/tenant-from-path/tasks");
        let resolver = PathSegmentTenantResolver::new(0);
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant.as_deref(), Some("tenant-from-path"));
    }

    #[tokio::test]
    async fn bearer_resolver_lowercase_bearer() {
        let ctx = make_ctx().with_http_header("authorization", "bearer lowercase_tok");
        let resolver = BearerTokenTenantResolver::new();
        let tenant = resolver.resolve(&ctx).await;
        assert_eq!(tenant.as_deref(), Some("lowercase_tok"));
    }

    #[test]
    fn bearer_resolver_debug_shows_has_mapper() {
        let without_mapper = format!("{:?}", BearerTokenTenantResolver::new());
        assert!(without_mapper.contains("BearerTokenTenantResolver"));
        assert!(without_mapper.contains("has_mapper"));
        assert!(without_mapper.contains("false"));

        let with_mapper = format!(
            "{:?}",
            BearerTokenTenantResolver::with_mapper(|token| Some(token.to_owned()))
        );
        assert!(with_mapper.contains("true"));
    }
}