use std::future::Future;
use std::sync::Arc;
/// Identifier for a tenant, stored as an owned string.
///
/// Equality and hashing are those of the wrapped string, so the key can be
/// used directly in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TenantKey(String);

impl TenantKey {
    /// Builds a key from anything convertible into a `String`.
    pub fn new(s: impl Into<String>) -> Self {
        let owned: String = s.into();
        Self(owned)
    }

    /// Borrows the key's text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Routing context: an optional tenant plus a typed extension map.
///
/// `Debug` is derived alongside `Clone` and `Default` so the context can be
/// printed with `{:?}`, used in assertions, and embedded in other types that
/// derive `Debug`, matching what `TenantKey` already offers.
#[derive(Debug, Clone, Default)]
pub struct RouteContext {
    /// Tenant attached to this context, if any.
    tenant: Option<TenantKey>,
    /// Typed values stored through `insert` and read back through `get`.
    extensions: http::Extensions,
}
impl RouteContext {
    /// Creates a context with no tenant and no extensions (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Consumes the context and returns it with `tenant` attached, replacing
    /// any tenant that was set before. Extensions are carried over unchanged.
    pub fn with_tenant(self, tenant: TenantKey) -> Self {
        Self {
            tenant: Some(tenant),
            ..self
        }
    }

    /// Returns the attached tenant, or `None` when no tenant has been set.
    pub fn tenant(&self) -> Option<&TenantKey> {
        Option::as_ref(&self.tenant)
    }

    /// Stores `value` in the extension map.
    ///
    /// This forwards to `http::Extensions::insert`; whatever that call
    /// returns is discarded here.
    pub fn insert<T>(&mut self, value: T)
    where
        T: Clone + Send + Sync + 'static,
    {
        let _previous = self.extensions.insert(value);
    }

    /// Looks up a stored value of type `T`, returning `None` when the
    /// extension map has no such value.
    pub fn get<T>(&self) -> Option<&T>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.extensions.get()
    }
}
// Task-local slot holding the shared route context.
// It is populated only for the duration of a `scope` call; `current` reads it
// and falls back to a default context when it is unset.
tokio::task_local! {
// Wrapped in `Arc` so `current` can hand out the context by cloning the pointer.
static ROUTE_CONTEXT: Arc<RouteContext>;
}
/// Returns the route context installed for the running task.
///
/// When the task-local slot cannot be read (for example, when called outside
/// of `scope`), a freshly allocated default context is returned instead, so
/// this function never fails.
pub fn current() -> Arc<RouteContext> {
    match ROUTE_CONTEXT.try_with(Arc::clone) {
        // Cloning the `Arc` only bumps the reference count.
        Ok(active) => active,
        // Slot not accessible: fall back to an empty context.
        Err(_) => Arc::default(),
    }
}
/// Runs `fut` with `ctx` installed as the task-local route context and
/// returns the future's output.
///
/// While `fut` is executing, `current` observes `ctx`. The context is not
/// visible to tasks started with `tokio::spawn` from inside `fut`.
pub async fn scope<F>(ctx: RouteContext, fut: F) -> F::Output
where
    F: Future,
{
    let shared = Arc::new(ctx);
    ROUTE_CONTEXT.scope(shared, fut).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Outside of any `scope` call, `current` yields a context with no tenant.
    #[tokio::test]
    async fn current_is_default_when_unset() {
        let ambient = current();
        assert!(ambient.tenant().is_none());
    }

    /// The scoped context is visible inside the future and gone afterwards.
    #[tokio::test]
    async fn scope_sets_and_restores_context() {
        let tenant = TenantKey::new("acme");
        let scoped = RouteContext::new().with_tenant(tenant);
        scope(scoped, async {
            let inner = current();
            assert_eq!(inner.tenant().map(TenantKey::as_str), Some("acme"));
        })
        .await;

        let after = current();
        assert!(after.tenant().is_none());
    }

    /// A task started with `tokio::spawn` does not see the spawner's context.
    #[tokio::test]
    async fn spawned_task_does_not_inherit_context() {
        let scoped = RouteContext::new().with_tenant(TenantKey::new("acme"));
        scope(scoped, async {
            let child = tokio::spawn(async { current().tenant().cloned() });
            let observed = child.await.unwrap();
            assert!(observed.is_none());
        })
        .await;
    }

    /// Values stored with `insert` can be read back by type through `get`.
    #[tokio::test]
    async fn extensions_store_typed_values() {
        #[derive(Clone, PartialEq, Debug)]
        struct Region(&'static str);

        let mut scoped = RouteContext::new();
        scoped.insert(Region("eu"));
        scope(scoped, async {
            let inner = current();
            let expected = Region("eu");
            assert_eq!(inner.get::<Region>(), Some(&expected));
        })
        .await;
    }
}