mockforge_intelligence/ai_studio/
budget_manager.rs1use crate::ai_studio::org_controls::OrgControls;
8use chrono::{DateTime, Utc};
9use mockforge_foundation::Result;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15pub struct BudgetManager {
17 config: BudgetConfig,
19 usage_tracker: Arc<RwLock<HashMap<String, WorkspaceUsage>>>,
21 org_controls: Option<Arc<OrgControls>>,
23}
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
27#[serde(rename_all = "snake_case")]
28pub enum AiFeature {
29 MockAi,
31 ContractDiff,
33 PersonaGeneration,
35 DebugAnalysis,
37 GenerativeSchema,
39 VoiceInterface,
41 GeneralChat,
43}
44
45impl AiFeature {
46 pub fn display_name(&self) -> &'static str {
48 match self {
49 AiFeature::MockAi => "MockAI",
50 AiFeature::ContractDiff => "Contract Diff",
51 AiFeature::PersonaGeneration => "Persona Generation",
52 AiFeature::DebugAnalysis => "Debug Analysis",
53 AiFeature::GenerativeSchema => "Generative Schema",
54 AiFeature::VoiceInterface => "Voice Interface",
55 AiFeature::GeneralChat => "General Chat",
56 }
57 }
58}
59
60#[derive(Debug, Clone, Default, Serialize, Deserialize)]
62pub struct FeatureUsage {
63 pub tokens_used: u64,
65 pub cost_usd: f64,
67 pub calls_made: u64,
69}
70
71#[derive(Debug, Clone)]
73struct WorkspaceUsage {
74 tokens_used: u64,
76 cost_usd: f64,
78 calls_made: u64,
80 #[allow(dead_code)]
82 last_reset: DateTime<Utc>,
83 daily_calls: HashMap<chrono::NaiveDate, u64>,
85 feature_usage: HashMap<AiFeature, FeatureUsage>,
87}
88
89impl BudgetManager {
90 pub fn new(config: BudgetConfig) -> Self {
92 Self {
93 config,
94 usage_tracker: Arc::new(RwLock::new(HashMap::new())),
95 org_controls: None,
96 }
97 }
98
99 pub fn with_org_controls(config: BudgetConfig, org_controls: Arc<OrgControls>) -> Self {
101 Self {
102 config,
103 usage_tracker: Arc::new(RwLock::new(HashMap::new())),
104 org_controls: Some(org_controls),
105 }
106 }
107
108 pub async fn get_usage(&self, workspace_id: &str) -> Result<UsageStats> {
110 let tracker = self.usage_tracker.read().await;
111 let usage = tracker.get(workspace_id).cloned().unwrap_or_else(|| WorkspaceUsage {
112 tokens_used: 0,
113 cost_usd: 0.0,
114 calls_made: 0,
115 last_reset: Utc::now(),
116 daily_calls: HashMap::new(),
117 feature_usage: HashMap::new(),
118 });
119
120 let usage_percentage = if self.config.max_tokens_per_workspace > 0 {
121 (usage.tokens_used as f64 / self.config.max_tokens_per_workspace as f64).min(1.0)
122 } else {
123 0.0
124 };
125
126 let feature_breakdown: HashMap<String, FeatureUsage> = usage
128 .feature_usage
129 .iter()
130 .map(|(feature, usage)| (format!("{:?}", feature), usage.clone()))
131 .collect();
132
133 Ok(UsageStats {
134 tokens_used: usage.tokens_used,
135 cost_usd: usage.cost_usd,
136 calls_made: usage.calls_made,
137 budget_limit: self.config.max_tokens_per_workspace,
138 usage_percentage,
139 feature_breakdown: Some(feature_breakdown),
140 })
141 }
142
143 pub async fn check_budget(
148 &self,
149 org_id: Option<&str>,
150 workspace_id: &str,
151 estimated_tokens: u64,
152 ) -> Result<bool> {
153 if let (Some(org_id), Some(ref org_controls)) = (org_id, &self.org_controls) {
155 let budget_result =
156 org_controls.check_budget(org_id, Some(workspace_id), estimated_tokens).await?;
157 if !budget_result.allowed {
158 return Ok(false);
159 }
160 }
161
162 let tracker = self.usage_tracker.read().await;
164 let usage = tracker.get(workspace_id);
165
166 if let Some(usage) = usage {
168 if self.config.max_tokens_per_workspace > 0
169 && usage.tokens_used + estimated_tokens > self.config.max_tokens_per_workspace
170 {
171 return Ok(false);
172 }
173 }
174
175 let today = Utc::now().date_naive();
177 if let Some(usage) = usage {
178 let today_calls = usage.daily_calls.get(&today).copied().unwrap_or(0);
179 if today_calls >= self.config.max_ai_calls_per_day {
180 return Ok(false);
181 }
182 }
183
184 Ok(true)
187 }
188
189 pub async fn check_rate_limit(&self, org_id: Option<&str>, workspace_id: &str) -> Result<bool> {
191 if let (Some(org_id), Some(ref org_controls)) = (org_id, &self.org_controls) {
193 let rate_limit_result =
194 org_controls.check_rate_limit(org_id, Some(workspace_id)).await?;
195 if !rate_limit_result.allowed {
196 return Ok(false);
197 }
198 }
199
200 Ok(true)
203 }
204
205 pub async fn is_feature_enabled(
207 &self,
208 org_id: Option<&str>,
209 workspace_id: &str,
210 feature: &str,
211 ) -> Result<bool> {
212 if let (Some(org_id), Some(ref org_controls)) = (org_id, &self.org_controls) {
214 return org_controls.is_feature_enabled(org_id, Some(workspace_id), feature).await;
215 }
216
217 Ok(true)
219 }
220
221 pub async fn record_usage(
223 &self,
224 org_id: Option<&str>,
225 workspace_id: &str,
226 user_id: Option<&str>,
227 tokens: u64,
228 cost_usd: f64,
229 ) -> Result<()> {
230 self.record_usage_with_feature(org_id, workspace_id, user_id, tokens, cost_usd, None)
231 .await
232 }
233
234 pub async fn record_usage_with_feature(
238 &self,
239 org_id: Option<&str>,
240 workspace_id: &str,
241 user_id: Option<&str>,
242 tokens: u64,
243 cost_usd: f64,
244 feature: Option<AiFeature>,
245 ) -> Result<()> {
246 if let (Some(org_id), Some(ref org_controls)) = (org_id, &self.org_controls) {
248 if let Some(feature) = feature {
249 let _feature_name = match feature {
250 AiFeature::MockAi => "mock_generation",
251 AiFeature::ContractDiff => "contract_diff",
252 AiFeature::PersonaGeneration => "persona_generation",
253 AiFeature::DebugAnalysis => "debug_analysis",
254 AiFeature::GenerativeSchema => "generative_schema",
255 AiFeature::VoiceInterface => "voice_interface",
256 AiFeature::GeneralChat => "free_form_generation",
257 };
258 org_controls
259 .record_usage(
260 org_id,
261 Some(workspace_id),
262 user_id,
263 feature,
264 tokens,
265 cost_usd,
266 None,
267 )
268 .await?;
269 }
270 }
271
272 let mut tracker = self.usage_tracker.write().await;
274 let usage = tracker.entry(workspace_id.to_string()).or_insert_with(|| WorkspaceUsage {
275 tokens_used: 0,
276 cost_usd: 0.0,
277 calls_made: 0,
278 last_reset: Utc::now(),
279 daily_calls: HashMap::new(),
280 feature_usage: HashMap::new(),
281 });
282
283 usage.tokens_used += tokens;
284 usage.cost_usd += cost_usd;
285 usage.calls_made += 1;
286
287 if let Some(feature) = feature {
289 let feature_usage =
290 usage.feature_usage.entry(feature).or_insert_with(FeatureUsage::default);
291 feature_usage.tokens_used += tokens;
292 feature_usage.cost_usd += cost_usd;
293 feature_usage.calls_made += 1;
294 }
295
296 let today = Utc::now().date_naive();
298 *usage.daily_calls.entry(today).or_insert(0) += 1;
299
300 Ok(())
301 }
302
303 pub async fn reset_usage(&self, workspace_id: &str) -> Result<()> {
305 let mut tracker = self.usage_tracker.write().await;
306 tracker.remove(workspace_id);
307 Ok(())
308 }
309
310 pub fn calculate_cost(provider: &str, model: &str, tokens: u64) -> f64 {
318 let tokens_k = tokens as f64 / 1000.0;
319
320 let price_per_1k = if provider.to_lowercase() == "ollama" {
322 0.0 } else if model.contains("gpt-4") {
324 0.03 } else if model.contains("gpt-3.5") || model.contains("gpt-3") {
326 0.002 } else if provider.to_lowercase() == "anthropic" {
328 0.008 } else {
330 0.002 };
332
333 tokens_k * price_per_1k
334 }
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct BudgetConfig {
340 pub max_tokens_per_workspace: u64,
342
343 pub max_ai_calls_per_day: u64,
345
346 pub rate_limit_per_minute: u64,
348}
349
350impl Default for BudgetConfig {
351 fn default() -> Self {
352 Self {
353 max_tokens_per_workspace: 100_000,
354 max_ai_calls_per_day: 1_000,
355 rate_limit_per_minute: 10,
356 }
357 }
358}
359
360#[derive(Debug, Clone, Serialize, Deserialize)]
362pub struct UsageStats {
363 pub tokens_used: u64,
365
366 pub cost_usd: f64,
368
369 pub calls_made: u64,
371
372 pub budget_limit: u64,
374
375 pub usage_percentage: f64,
377
378 #[serde(skip_serializing_if = "Option::is_none")]
380 pub feature_breakdown: Option<HashMap<String, FeatureUsage>>,
381}