use super::{PolicyContext, PolicyDecision, PolicyEvaluator};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Per-tenant resource quotas.
///
/// Every cap is optional; [`TenantLimits::unlimited`] sets all of them to
/// `None`, i.e. `None` means "no cap configured" for that resource.
/// `TenantPolicy::evaluate` in this module does not read these values.
/// NOTE(review): enforcement presumably lives elsewhere — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TenantLimits {
/// Maximum number of executions that may run at the same time.
pub max_concurrent_executions: Option<usize>,
/// Maximum number of executions per day (the accounting window is not
/// defined in this file).
pub max_executions_per_day: Option<usize>,
/// Maximum number of tokens per day (window not defined in this file).
pub max_tokens_per_day: Option<usize>,
/// Maximum storage, in bytes.
pub max_storage_bytes: Option<usize>,
}
impl Default for TenantLimits {
    /// Baseline quotas: 10 concurrent executions, 1 000 executions and
    /// 1 000 000 tokens per day, and 1 GiB (2^30 bytes) of storage.
    fn default() -> Self {
        const GIB: usize = 1 << 30;
        let max_storage_bytes = Some(GIB);
        Self {
            max_storage_bytes,
            max_tokens_per_day: Some(1_000_000),
            max_executions_per_day: Some(1_000),
            max_concurrent_executions: Some(10),
        }
    }
}
impl TenantLimits {
    /// One mebibyte (2^20 bytes).
    const MIB: usize = 1 << 20;
    /// One gibibyte (2^30 bytes); still representable in a 32-bit `usize`.
    const GIB: usize = 1 << 30;

    /// Limits with every cap set to `None`, i.e. no quota on any resource.
    pub fn unlimited() -> Self {
        Self {
            max_concurrent_executions: None,
            max_executions_per_day: None,
            max_tokens_per_day: None,
            max_storage_bytes: None,
        }
    }

    /// Free-tier quotas: 2 concurrent executions, 100 executions and
    /// 50 000 tokens per day, and 100 MiB of storage.
    pub fn free_tier() -> Self {
        Self {
            max_concurrent_executions: Some(2),
            max_executions_per_day: Some(100),
            max_tokens_per_day: Some(50_000),
            max_storage_bytes: Some(100 * Self::MIB),
        }
    }

    /// Pro-tier quotas: 20 concurrent executions, 10 000 executions and
    /// 10 000 000 tokens per day, and 10 GiB of storage.
    ///
    /// 10 GiB does not fit in a `usize` narrower than 64 bits. On such targets
    /// the storage cap saturates to `usize::MAX` instead of failing to compile.
    /// The plain `10 * 1024 * 1024 * 1024` trips the deny-by-default
    /// `arithmetic_overflow` lint there.
    pub fn pro_tier() -> Self {
        Self {
            max_concurrent_executions: Some(20),
            max_executions_per_day: Some(10_000),
            max_tokens_per_day: Some(10_000_000),
            max_storage_bytes: Some(10usize.saturating_mul(Self::GIB)),
        }
    }
}
/// A set of named feature switches with separate enable and disable sets.
///
/// A feature counts as on only when it is in `enabled` and absent from
/// `disabled` (see `FeatureFlags::is_enabled`), so `disabled` acts as an
/// override. The `Default` value has both sets empty, which means no feature is on.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct FeatureFlags {
/// Features explicitly turned on.
pub enabled: HashSet<String>,
/// Features explicitly turned off; takes precedence over `enabled`.
pub disabled: HashSet<String>,
}
impl FeatureFlags {
    /// Creates an empty flag set in which no feature is enabled.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `feature` to the enabled set and returns the updated flags.
    ///
    /// Consumes `self` (builder style), so ignoring the result discards the change.
    #[must_use]
    pub fn enable(mut self, feature: impl Into<String>) -> Self {
        self.enabled.insert(feature.into());
        self
    }

    /// Adds `feature` to the disabled set and returns the updated flags.
    ///
    /// A disabled entry overrides an enabled one in [`FeatureFlags::is_enabled`],
    /// regardless of whether `enable` was called before or after.
    #[must_use]
    pub fn disable(mut self, feature: impl Into<String>) -> Self {
        self.disabled.insert(feature.into());
        self
    }

    /// Returns `true` only if `feature` was enabled and has not been disabled.
    ///
    /// Features that were never mentioned are reported as off.
    #[must_use]
    pub fn is_enabled(&self, feature: &str) -> bool {
        !self.disabled.contains(feature) && self.enabled.contains(feature)
    }

    /// Baseline feature set: `basic_execution`, `tool_invocation` and `streaming`.
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new()
            .enable("basic_execution")
            .enable("tool_invocation")
            .enable("streaming")
    }

    /// Every feature named in this module: the defaults plus
    /// `parallel_execution`, `nested_execution`, `custom_tools`,
    /// `mcp_integration` and `advanced_memory`.
    #[must_use]
    pub fn all_enabled() -> Self {
        // Built on `with_defaults` so the baseline list is defined in one place
        // and cannot drift from this superset.
        Self::with_defaults()
            .enable("parallel_execution")
            .enable("nested_execution")
            .enable("custom_tools")
            .enable("mcp_integration")
            .enable("advanced_memory")
    }
}
/// Policy applied to a single tenant: resource limits, feature flags,
/// model/tool allow-lists and an isolation mode.
///
/// Its `PolicyEvaluator::evaluate` implementation in this module consults only
/// `allowed_models` and `allowed_tools`.
#[derive(Debug, Clone)]
pub struct TenantPolicy {
/// Resource quotas for the tenant.
pub limits: TenantLimits,
/// Feature switches for the tenant.
pub features: FeatureFlags,
/// Models the tenant may call. An EMPTY set means every model is allowed.
pub allowed_models: HashSet<String>,
/// Tools the tenant may invoke. `None` means every tool is allowed. `Some(set)`
/// restricts invocation to that set, so `Some` of an empty set denies every tool.
/// NOTE(review): this differs from `allowed_models`, where empty means "allow all".
pub allowed_tools: Option<HashSet<String>>,
/// Isolation mode for the tenant.
pub isolation: TenantIsolation,
}
/// Isolation mode assigned to a tenant.
///
/// NOTE(review): this module only stores the mode. The concrete guarantees of
/// each variant are defined by whatever consumes it — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum TenantIsolation {
/// The default mode.
#[default]
Shared,
/// Presumably dedicated resources per tenant — TODO confirm with the consumer.
Dedicated,
/// Presumably the most restrictive mode — TODO confirm with the consumer.
Strict,
}
impl Default for TenantPolicy {
    /// Default tenant policy.
    ///
    /// It uses `TenantLimits::default()` for limits and `FeatureFlags::with_defaults()`
    /// for features. It sets no model allow-list (empty set) and no tool allow-list
    /// (`None`), and uses the default isolation mode.
    fn default() -> Self {
        let limits = TenantLimits::default();
        let features = FeatureFlags::with_defaults();
        Self {
            isolation: TenantIsolation::default(),
            allowed_tools: None,
            allowed_models: HashSet::default(),
            features,
            limits,
        }
    }
}
impl TenantPolicy {
    /// Creates a policy equal to `TenantPolicy::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns this policy with its resource limits replaced by `limits`.
    pub fn with_limits(self, limits: TenantLimits) -> Self {
        Self { limits, ..self }
    }

    /// Returns this policy with its feature flags replaced by `features`.
    pub fn with_features(self, features: FeatureFlags) -> Self {
        Self { features, ..self }
    }

    /// Adds `model` to the model allow-list.
    ///
    /// Once the list is non-empty, only listed models pass `is_model_allowed`.
    pub fn allow_model(mut self, model: impl Into<String>) -> Self {
        let name: String = model.into();
        self.allowed_models.insert(name);
        self
    }

    /// Adds `tool` to the tool allow-list, creating the list on first use.
    pub fn allow_tool(mut self, tool: impl Into<String>) -> Self {
        let tools = self.allowed_tools.get_or_insert_with(HashSet::new);
        tools.insert(tool.into());
        self
    }

    /// Returns this policy with its isolation mode replaced by `isolation`.
    pub fn with_isolation(self, isolation: TenantIsolation) -> Self {
        Self { isolation, ..self }
    }

    /// Whether `model` may be called. An empty allow-list permits every model.
    pub fn is_model_allowed(&self, model: &str) -> bool {
        if self.allowed_models.is_empty() {
            return true;
        }
        self.allowed_models.contains(model)
    }

    /// Whether `tool` may be invoked.
    ///
    /// With no allow-list (`None`) every tool is permitted. With `Some(set)`
    /// only members of the set are permitted.
    pub fn is_tool_allowed(&self, tool: &str) -> bool {
        match &self.allowed_tools {
            Some(tools) => tools.contains(tool),
            None => true,
        }
    }
}
impl PolicyEvaluator for TenantPolicy {
    /// Checks LLM calls against the model allow-list and tool invocations
    /// against the tool allow-list. Every other action is allowed.
    ///
    /// Limits, feature flags and isolation are not consulted here.
    fn evaluate(&self, context: &PolicyContext) -> PolicyDecision {
        let denial = match &context.action {
            super::PolicyAction::LlmCall { model } if !self.is_model_allowed(model) => {
                Some(format!("Model '{}' is not allowed for this tenant", model))
            }
            super::PolicyAction::InvokeTool { tool_name } if !self.is_tool_allowed(tool_name) => {
                Some(format!("Tool '{}' is not allowed for this tenant", tool_name))
            }
            _ => None,
        };
        match denial {
            Some(reason) => PolicyDecision::Deny { reason },
            None => PolicyDecision::Allow,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `PolicyContext` for `tenant-1` with no user and empty metadata.
    /// Each evaluation test then only spells out the action under test.
    fn tenant_context(action: super::super::PolicyAction) -> PolicyContext {
        PolicyContext {
            tenant_id: Some("tenant-1".to_string()),
            user_id: None,
            action,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_tenant_limits_default() {
        let limits = TenantLimits::default();
        assert_eq!(limits.max_concurrent_executions, Some(10));
        assert_eq!(limits.max_executions_per_day, Some(1000));
        assert_eq!(limits.max_tokens_per_day, Some(1_000_000));
        assert_eq!(limits.max_storage_bytes, Some(1024 * 1024 * 1024));
    }

    #[test]
    fn test_tenant_limits_unlimited() {
        let limits = TenantLimits::unlimited();
        assert!(limits.max_concurrent_executions.is_none());
        assert!(limits.max_executions_per_day.is_none());
        assert!(limits.max_tokens_per_day.is_none());
        assert!(limits.max_storage_bytes.is_none());
    }

    #[test]
    fn test_tenant_limits_free_tier() {
        let limits = TenantLimits::free_tier();
        assert_eq!(limits.max_concurrent_executions, Some(2));
        assert_eq!(limits.max_executions_per_day, Some(100));
        assert_eq!(limits.max_tokens_per_day, Some(50_000));
        assert_eq!(limits.max_storage_bytes, Some(100 * 1024 * 1024));
    }

    #[test]
    fn test_tenant_limits_pro_tier() {
        let limits = TenantLimits::pro_tier();
        assert_eq!(limits.max_concurrent_executions, Some(20));
        assert_eq!(limits.max_executions_per_day, Some(10_000));
        assert_eq!(limits.max_tokens_per_day, Some(10_000_000));
    }

    // 10 GiB only fits in a 64-bit `usize`, so the exact value is pinned there.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_tenant_limits_pro_tier_storage() {
        let limits = TenantLimits::pro_tier();
        assert_eq!(limits.max_storage_bytes, Some(10 * 1024 * 1024 * 1024));
    }

    #[test]
    fn test_feature_flags_default() {
        let flags = FeatureFlags::default();
        assert!(flags.enabled.is_empty());
        assert!(flags.disabled.is_empty());
    }

    #[test]
    fn test_feature_flags_enable() {
        let flags = FeatureFlags::new().enable("feature_a").enable("feature_b");
        assert!(flags.is_enabled("feature_a"));
        assert!(flags.is_enabled("feature_b"));
        assert!(!flags.is_enabled("feature_c"));
    }

    #[test]
    fn test_feature_flags_disable_overrides_enable() {
        let flags = FeatureFlags::new().enable("feature_x").disable("feature_x");
        assert!(!flags.is_enabled("feature_x"));
    }

    #[test]
    fn test_feature_flags_disable_wins_regardless_of_order() {
        let flags = FeatureFlags::new().disable("feature_x").enable("feature_x");
        assert!(!flags.is_enabled("feature_x"));
    }

    #[test]
    fn test_feature_flags_disable_only_is_not_enabled() {
        let flags = FeatureFlags::new().disable("feature_y");
        assert!(!flags.is_enabled("feature_y"));
    }

    #[test]
    fn test_feature_flags_with_defaults() {
        let flags = FeatureFlags::with_defaults();
        assert!(flags.is_enabled("basic_execution"));
        assert!(flags.is_enabled("tool_invocation"));
        assert!(flags.is_enabled("streaming"));
        assert!(!flags.is_enabled("parallel_execution"));
    }

    #[test]
    fn test_feature_flags_all_enabled() {
        let flags = FeatureFlags::all_enabled();
        for feature in [
            "basic_execution",
            "tool_invocation",
            "streaming",
            "parallel_execution",
            "nested_execution",
            "custom_tools",
            "mcp_integration",
            "advanced_memory",
        ] {
            assert!(flags.is_enabled(feature), "{feature} should be enabled");
        }
    }

    #[test]
    fn test_tenant_isolation_default() {
        assert_eq!(TenantIsolation::default(), TenantIsolation::Shared);
    }

    #[test]
    fn test_tenant_policy_default() {
        let policy = TenantPolicy::default();
        assert!(policy.allowed_models.is_empty());
        assert!(policy.allowed_tools.is_none());
        assert_eq!(policy.isolation, TenantIsolation::Shared);
        assert!(policy.features.is_enabled("basic_execution"));
        assert_eq!(policy.limits.max_concurrent_executions, Some(10));
    }

    #[test]
    fn test_tenant_policy_with_limits() {
        let policy = TenantPolicy::new().with_limits(TenantLimits::free_tier());
        assert_eq!(policy.limits.max_concurrent_executions, Some(2));
    }

    #[test]
    fn test_tenant_policy_with_features() {
        let policy = TenantPolicy::new().with_features(FeatureFlags::all_enabled());
        assert!(policy.features.is_enabled("parallel_execution"));
    }

    #[test]
    fn test_tenant_policy_allow_model() {
        let policy = TenantPolicy::new()
            .allow_model("gpt-4")
            .allow_model("claude-3");
        assert!(policy.is_model_allowed("gpt-4"));
        assert!(policy.is_model_allowed("claude-3"));
        assert!(!policy.is_model_allowed("unknown-model"));
    }

    #[test]
    fn test_tenant_policy_allow_model_empty_allows_all() {
        let policy = TenantPolicy::new();
        assert!(policy.is_model_allowed("any-model"));
    }

    #[test]
    fn test_tenant_policy_allow_tool() {
        let policy = TenantPolicy::new()
            .allow_tool("web_search")
            .allow_tool("calculator");
        assert!(policy.is_tool_allowed("web_search"));
        assert!(policy.is_tool_allowed("calculator"));
        assert!(!policy.is_tool_allowed("file_system"));
    }

    #[test]
    fn test_tenant_policy_no_tool_restriction_allows_all() {
        let policy = TenantPolicy::new();
        assert!(policy.is_tool_allowed("any-tool"));
    }

    // Unlike the model list, an explicitly empty tool allow-list denies everything.
    #[test]
    fn test_tenant_policy_empty_tool_set_denies_all() {
        let mut policy = TenantPolicy::new();
        policy.allowed_tools = Some(HashSet::new());
        assert!(!policy.is_tool_allowed("calculator"));
    }

    #[test]
    fn test_tenant_policy_with_isolation() {
        let policy = TenantPolicy::new().with_isolation(TenantIsolation::Strict);
        assert_eq!(policy.isolation, TenantIsolation::Strict);
    }

    #[test]
    fn test_tenant_policy_evaluate_llm_allowed() {
        let policy = TenantPolicy::new();
        let context = tenant_context(super::super::PolicyAction::LlmCall {
            model: "gpt-4".to_string(),
        });
        assert!(policy.evaluate(&context).is_allowed());
    }

    #[test]
    fn test_tenant_policy_evaluate_llm_allowed_by_list() {
        let policy = TenantPolicy::new().allow_model("claude-3");
        let context = tenant_context(super::super::PolicyAction::LlmCall {
            model: "claude-3".to_string(),
        });
        assert!(policy.evaluate(&context).is_allowed());
    }

    #[test]
    fn test_tenant_policy_evaluate_llm_denied() {
        let policy = TenantPolicy::new().allow_model("claude-3");
        let context = tenant_context(super::super::PolicyAction::LlmCall {
            model: "gpt-4".to_string(),
        });
        assert!(policy.evaluate(&context).is_denied());
    }

    #[test]
    fn test_tenant_policy_evaluate_tool_allowed() {
        let policy = TenantPolicy::new();
        let context = tenant_context(super::super::PolicyAction::InvokeTool {
            tool_name: "any_tool".to_string(),
        });
        assert!(policy.evaluate(&context).is_allowed());
    }

    #[test]
    fn test_tenant_policy_evaluate_tool_denied() {
        let policy = TenantPolicy::new().allow_tool("calculator");
        let context = tenant_context(super::super::PolicyAction::InvokeTool {
            tool_name: "web_search".to_string(),
        });
        assert!(policy.evaluate(&context).is_denied());
    }
}