use crate::ai_studio::budget_manager::AiFeature;
use crate::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Organization-level AI controls: usage budgets, request rate limits, and
/// per-feature on/off toggles.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct OrgAiControlsConfig {
    /// Token and call budget settings.
    pub budgets: OrgBudgetConfig,
    /// Request rate limit settings.
    pub rate_limits: OrgRateLimitConfig,
    /// Map from feature name to enabled flag. In `OrgControls`' YAML-only
    /// path, a feature missing from this map is treated as enabled.
    pub feature_toggles: HashMap<String, bool>,
}
impl Default for OrgAiControlsConfig {
fn default() -> Self {
Self {
budgets: OrgBudgetConfig::default(),
rate_limits: OrgRateLimitConfig::default(),
feature_toggles: HashMap::from([
("mock_generation".to_string(), true),
("contract_diff".to_string(), true),
("persona_generation".to_string(), true),
("free_form_generation".to_string(), true),
("debug_analysis".to_string(), true),
]),
}
}
}
/// Per-period usage budget for an organization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct OrgBudgetConfig {
    /// Maximum number of tokens allowed in one budget period.
    pub max_tokens_per_period: u64,
    /// Budget period, as a free-form string (the default is `"month"`).
    /// NOTE(review): values are not validated here; confirm which are supported.
    pub period_type: String,
    /// Maximum number of AI calls allowed in one budget period.
    pub max_calls_per_period: u64,
}
impl Default for OrgBudgetConfig {
    /// One million tokens and ten thousand calls per `"month"` period.
    fn default() -> Self {
        let period_type = String::from("month");
        Self {
            period_type,
            max_calls_per_period: 10_000,
            max_tokens_per_period: 1_000_000,
        }
    }
}
/// Request rate limits for an organization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct OrgRateLimitConfig {
    /// Maximum requests per minute. This is the only limit reported by
    /// `OrgControls`' YAML-only rate-limit check.
    pub rate_limit_per_minute: u64,
    /// Optional hourly request limit (`None` when not set).
    /// NOTE(review): enforcement of this limit, if any, lives in the accessor.
    pub rate_limit_per_hour: Option<u64>,
    /// Optional daily request limit (`None` when not set).
    /// NOTE(review): enforcement of this limit, if any, lives in the accessor.
    pub rate_limit_per_day: Option<u64>,
}
impl Default for OrgRateLimitConfig {
fn default() -> Self {
Self {
rate_limit_per_minute: 100,
rate_limit_per_hour: None,
rate_limit_per_day: None,
}
}
}
/// Database-backed source of organization AI controls.
///
/// [`OrgControls`] uses this when one is supplied via
/// [`OrgControls::with_db_accessor`].
// `record_usage` takes more arguments than clippy's default
// `too_many_arguments` threshold; grouping them would change the trait API.
#[allow(clippy::too_many_arguments)]
#[async_trait]
pub trait OrgControlsAccessor: Send + Sync {
    /// Loads the stored config for `org_id`, optionally scoped to `workspace_id`.
    ///
    /// Returning `Ok(None)` makes `OrgControls::load_org_config` fall back to
    /// its YAML config.
    async fn load_org_config(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
    ) -> Result<Option<OrgAiControlsConfig>>;
    /// Checks whether a call estimated at `estimated_tokens` tokens fits the
    /// budget. `OrgControls::check_budget` returns this result as-is.
    async fn check_budget(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
        estimated_tokens: u64,
    ) -> Result<BudgetCheckResult>;
    /// Checks the request rate limit. `OrgControls::check_rate_limit`
    /// returns this result as-is.
    async fn check_rate_limit(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
    ) -> Result<RateLimitCheckResult>;
    /// Reports whether `feature` is enabled. When an accessor is present,
    /// `OrgControls` does not consult its YAML toggles for this check.
    async fn is_feature_enabled(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
        feature: &str,
    ) -> Result<bool>;
    /// Records one usage event: tokens consumed, cost (in USD, per the
    /// parameter name), and optional free-form JSON `metadata`.
    async fn record_usage(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
        user_id: Option<&str>,
        feature: AiFeature,
        tokens: u64,
        cost_usd: f64,
        metadata: Option<serde_json::Value>,
    ) -> Result<()>;
}
/// Outcome of a budget check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetCheckResult {
    /// Whether the request may proceed.
    pub allowed: bool,
    /// Tokens already consumed in the current period (always 0 in
    /// `OrgControls`' YAML-only fallback).
    pub current_tokens: u64,
    /// Token budget for the period.
    pub max_tokens: u64,
    /// Calls already made in the current period (always 0 in
    /// `OrgControls`' YAML-only fallback).
    pub current_calls: u64,
    /// Call budget for the period.
    pub max_calls: u64,
    /// Start of the current budget period, if known.
    pub period_start: Option<DateTime<Utc>>,
    /// Human-readable explanation; presumably set when `allowed` is false.
    pub reason: Option<String>,
}
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitCheckResult {
    /// Whether the request may proceed.
    pub allowed: bool,
    /// Requests already made in the current window (always 0 in
    /// `OrgControls`' YAML-only fallback).
    pub current_requests: u64,
    /// Request limit for the window.
    pub max_requests: u64,
    /// Name of the window the limit applies to (`"minute"` in the YAML-only fallback).
    pub window_type: String,
    /// Earliest time the caller may retry, if known.
    pub retry_after: Option<DateTime<Utc>>,
    /// Human-readable explanation; presumably set when `allowed` is false.
    pub reason: Option<String>,
}
/// Enforces organization AI controls from a YAML config, optionally
/// delegating to a database-backed [`OrgControlsAccessor`].
pub struct OrgControls {
    /// Base configuration, used as-is when there is no accessor or the
    /// accessor has no stored config.
    yaml_config: OrgAiControlsConfig,
    /// Optional database backend. When present, budget, rate-limit, feature,
    /// and usage calls are delegated to it.
    db_accessor: Option<Box<dyn OrgControlsAccessor>>,
}
impl OrgControls {
    /// Creates controls backed only by `yaml_config`.
    ///
    /// Without a database accessor:
    /// - budget and rate-limit checks always allow the request;
    /// - usage recording is a no-op.
    pub fn new(yaml_config: OrgAiControlsConfig) -> Self {
        Self {
            yaml_config,
            db_accessor: None,
        }
    }

    /// Creates controls that delegate to `db_accessor`.
    ///
    /// `yaml_config` is kept as the base that stored database configs are
    /// merged onto.
    pub fn with_db_accessor(
        yaml_config: OrgAiControlsConfig,
        db_accessor: Box<dyn OrgControlsAccessor>,
    ) -> Self {
        Self {
            yaml_config,
            db_accessor: Some(db_accessor),
        }
    }

    /// Returns the effective configuration for `org_id` / `workspace_id`.
    ///
    /// If an accessor is configured and returns a stored config, that config
    /// is merged onto the YAML config. Database budgets and rate limits win,
    /// and feature toggles are combined key-by-key. Otherwise the YAML config
    /// is returned.
    ///
    /// # Errors
    /// Propagates any error from the accessor's `load_org_config`.
    pub async fn load_org_config(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
    ) -> Result<OrgAiControlsConfig> {
        if let Some(accessor) = &self.db_accessor {
            if let Some(db_config) = accessor.load_org_config(org_id, workspace_id).await? {
                return Ok(self.merge_configs(self.yaml_config.clone(), db_config));
            }
        }
        Ok(self.yaml_config.clone())
    }

    /// Checks whether a call estimated at `estimated_tokens` tokens is within budget.
    ///
    /// With an accessor, its result is returned unchanged. Without one, the
    /// result always allows the request and reports the configured limits
    /// with zero usage. `estimated_tokens` is not compared against the
    /// limit on this path.
    ///
    /// # Errors
    /// Propagates errors from the accessor or from loading the config.
    pub async fn check_budget(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
        estimated_tokens: u64,
    ) -> Result<BudgetCheckResult> {
        if let Some(accessor) = &self.db_accessor {
            return accessor
                .check_budget(org_id, workspace_id, estimated_tokens)
                .await;
        }
        let config = self.load_org_config(org_id, workspace_id).await?;
        Ok(BudgetCheckResult {
            allowed: true,
            current_tokens: 0,
            max_tokens: config.budgets.max_tokens_per_period,
            current_calls: 0,
            max_calls: config.budgets.max_calls_per_period,
            period_start: None,
            reason: None,
        })
    }

    /// Checks the request rate limit.
    ///
    /// With an accessor, its result is returned unchanged. Without one, the
    /// result always allows the request and reports only the per-minute
    /// limit.
    ///
    /// # Errors
    /// Propagates errors from the accessor or from loading the config.
    pub async fn check_rate_limit(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
    ) -> Result<RateLimitCheckResult> {
        if let Some(accessor) = &self.db_accessor {
            return accessor.check_rate_limit(org_id, workspace_id).await;
        }
        let config = self.load_org_config(org_id, workspace_id).await?;
        Ok(RateLimitCheckResult {
            allowed: true,
            current_requests: 0,
            max_requests: config.rate_limits.rate_limit_per_minute,
            window_type: "minute".to_string(),
            retry_after: None,
            reason: None,
        })
    }

    /// Reports whether `feature` is enabled.
    ///
    /// With an accessor, the answer comes from it alone. Without one, the YAML
    /// toggles are consulted, and a feature absent from them counts as
    /// enabled.
    ///
    /// # Errors
    /// Propagates errors from the accessor or from loading the config.
    pub async fn is_feature_enabled(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
        feature: &str,
    ) -> Result<bool> {
        if let Some(accessor) = &self.db_accessor {
            return accessor
                .is_feature_enabled(org_id, workspace_id, feature)
                .await;
        }
        let config = self.load_org_config(org_id, workspace_id).await?;
        Ok(config.feature_toggles.get(feature).copied().unwrap_or(true))
    }

    /// Records one usage event via the accessor. This is a no-op without one.
    ///
    /// # Errors
    /// Propagates any error from the accessor's `record_usage`.
    // Mirrors `OrgControlsAccessor::record_usage`, whose parameter list exceeds
    // clippy's default argument-count threshold.
    #[allow(clippy::too_many_arguments)]
    pub async fn record_usage(
        &self,
        org_id: &str,
        workspace_id: Option<&str>,
        user_id: Option<&str>,
        feature: AiFeature,
        tokens: u64,
        cost_usd: f64,
        metadata: Option<serde_json::Value>,
    ) -> Result<()> {
        match &self.db_accessor {
            Some(accessor) => {
                accessor
                    .record_usage(
                        org_id,
                        workspace_id,
                        user_id,
                        feature,
                        tokens,
                        cost_usd,
                        metadata,
                    )
                    .await
            }
            None => Ok(()),
        }
    }

    /// Merges a database config onto a YAML config.
    ///
    /// Budgets and rate limits are taken entirely from `db`. Feature toggles
    /// are the union of both maps, with `db` entries overriding `yaml` entries
    /// that share a key.
    fn merge_configs(
        &self,
        yaml: OrgAiControlsConfig,
        db: OrgAiControlsConfig,
    ) -> OrgAiControlsConfig {
        // `yaml` is owned, so its toggle map is reused rather than cloned;
        // `extend` overwrites existing keys, giving `db` precedence.
        let mut feature_toggles = yaml.feature_toggles;
        feature_toggles.extend(db.feature_toggles);
        OrgAiControlsConfig {
            budgets: db.budgets,
            rate_limits: db.rate_limits,
            feature_toggles,
        }
    }
}
impl Default for OrgControls {
fn default() -> Self {
Self::new(OrgAiControlsConfig::default())
}
}