use std::collections::{HashMap, HashSet};
use std::sync::{LazyLock, RwLock};
use majra::ratelimit::RateLimiter;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct TenantQuota {
pub max_requests_per_sec: f64,
pub max_concurrent_flows: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TenantCtx {
pub tenant_id: String,
pub display_name: Option<String>,
pub quota: TenantQuota,
pub allowed_tools: Option<HashSet<String>>,
}
#[derive(Debug)]
pub struct TenantRegistry {
inner: RwLock<HashMap<String, TenantCtx>>,
}
impl TenantRegistry {
#[must_use]
pub fn new() -> Self {
Self {
inner: RwLock::new(HashMap::new()),
}
}
pub fn register(&self, ctx: TenantCtx) {
let mut map = self.inner.write().expect("tenant registry lock poisoned");
map.insert(ctx.tenant_id.clone(), ctx);
}
#[must_use]
pub fn get(&self, id: &str) -> Option<TenantCtx> {
let map = self.inner.read().expect("tenant registry lock poisoned");
map.get(id).cloned()
}
pub fn deregister(&self, id: &str) {
let mut map = self.inner.write().expect("tenant registry lock poisoned");
map.remove(id);
}
}
impl Default for TenantRegistry {
fn default() -> Self {
Self::new()
}
}
pub static TENANT_REGISTRY: LazyLock<TenantRegistry> = LazyLock::new(TenantRegistry::new);
static TENANT_LIMITER: LazyLock<RateLimiter> = LazyLock::new(|| RateLimiter::new(100.0, 500));
pub fn check_tenant_quota(tenant_id: &str) -> Result<(), String> {
let _ctx = match TENANT_REGISTRY.get(tenant_id) {
Some(ctx) => ctx,
None => return Ok(()),
};
if TENANT_LIMITER.check(tenant_id) {
Ok(())
} else {
Err(format!("rate limit exceeded for tenant {tenant_id}"))
}
}
pub fn check_tenant_tool_access(tenant_id: &str, tool_name: &str) -> Result<(), String> {
let ctx = match TENANT_REGISTRY.get(tenant_id) {
Some(ctx) => ctx,
None => return Ok(()),
};
match &ctx.allowed_tools {
None => Ok(()),
Some(allowed) => {
if allowed.contains(tool_name) {
Ok(())
} else {
Err(format!(
"tenant {tenant_id} is not permitted to use tool {tool_name}"
))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_ctx(id: &str, rate: f64, tools: Option<Vec<&str>>) -> TenantCtx {
TenantCtx {
tenant_id: id.to_string(),
display_name: Some(format!("Test Tenant {id}")),
quota: TenantQuota {
max_requests_per_sec: rate,
max_concurrent_flows: 5,
},
allowed_tools: tools.map(|t| t.into_iter().map(String::from).collect()),
}
}
#[test]
fn register_and_get() {
let registry = TenantRegistry::new();
let ctx = test_ctx("t1", 10.0, None);
registry.register(ctx.clone());
let got = registry.get("t1").expect("tenant should exist");
assert_eq!(got.tenant_id, "t1");
assert_eq!(got.display_name, Some("Test Tenant t1".to_string()));
}
#[test]
fn get_missing_returns_none() {
let registry = TenantRegistry::new();
assert!(registry.get("nonexistent").is_none());
}
#[test]
fn deregister_removes_tenant() {
let registry = TenantRegistry::new();
registry.register(test_ctx("t2", 10.0, None));
assert!(registry.get("t2").is_some());
registry.deregister("t2");
assert!(registry.get("t2").is_none());
}
#[test]
fn register_overwrites_existing() {
let registry = TenantRegistry::new();
registry.register(test_ctx("t3", 10.0, None));
let updated = TenantCtx {
tenant_id: "t3".to_string(),
display_name: Some("Updated".to_string()),
quota: TenantQuota {
max_requests_per_sec: 20.0,
max_concurrent_flows: 10,
},
allowed_tools: None,
};
registry.register(updated);
let got = registry.get("t3").unwrap();
assert_eq!(got.display_name, Some("Updated".to_string()));
assert_eq!(got.quota.max_concurrent_flows, 10);
}
#[test]
fn quota_check_unknown_tenant_is_permissive() {
assert!(check_tenant_quota("unknown-tenant").is_ok());
}
#[test]
fn quota_check_registered_tenant_allows_initially() {
TENANT_REGISTRY.register(test_ctx("quota-ok", 10.0, None));
assert!(check_tenant_quota("quota-ok").is_ok());
}
#[test]
fn quota_check_rejects_after_burst() {
TENANT_REGISTRY.register(test_ctx("quota-burst", 1.0, None));
let mut rejected = false;
for _ in 0..600 {
if check_tenant_quota("quota-burst").is_err() {
rejected = true;
break;
}
}
assert!(rejected, "expected rate limit rejection after burst");
}
#[test]
fn tool_access_unknown_tenant_is_permissive() {
assert!(check_tenant_tool_access("unknown", "any_tool").is_ok());
}
#[test]
fn tool_access_no_restrictions_allows_all() {
TENANT_REGISTRY.register(test_ctx("tool-all", 10.0, None));
assert!(check_tenant_tool_access("tool-all", "szal_http").is_ok());
assert!(check_tenant_tool_access("tool-all", "szal_dns_lookup").is_ok());
}
#[test]
fn tool_access_restricted_allows_listed() {
TENANT_REGISTRY.register(test_ctx(
"tool-restricted",
10.0,
Some(vec!["szal_http", "szal_dns_lookup"]),
));
assert!(check_tenant_tool_access("tool-restricted", "szal_http").is_ok());
assert!(check_tenant_tool_access("tool-restricted", "szal_dns_lookup").is_ok());
}
#[test]
fn tool_access_restricted_rejects_unlisted() {
TENANT_REGISTRY.register(test_ctx("tool-denied", 10.0, Some(vec!["szal_http"])));
let err = check_tenant_tool_access("tool-denied", "szal_port_check").unwrap_err();
assert!(
err.contains("not permitted"),
"expected 'not permitted' in: {err}"
);
assert!(err.contains("tool-denied"));
assert!(err.contains("szal_port_check"));
}
#[test]
fn global_registry_accessible() {
assert!(TENANT_REGISTRY.get("nonexistent").is_none());
}
#[test]
fn tenant_ctx_serializes() {
let ctx = test_ctx("ser-test", 50.0, Some(vec!["szal_http"]));
let json = serde_json::to_string(&ctx).expect("serialize");
let deser: TenantCtx = serde_json::from_str(&json).expect("deserialize");
assert_eq!(deser.tenant_id, "ser-test");
assert_eq!(deser.quota.max_requests_per_sec, 50.0);
assert!(deser.allowed_tools.unwrap().contains("szal_http"));
}
}