Skip to main content

mockforge_intelligence/ai_studio/
budget_manager.rs

1//! Budget manager for AI usage tracking and controls
2//!
3//! This module provides functionality to track token usage, calculate costs,
4//! and enforce budget limits. It uses in-memory tracking for local usage,
5//! and can integrate with cloud usage tracking when available.
6
7use crate::ai_studio::org_controls::OrgControls;
8use chrono::{DateTime, Utc};
9use mockforge_foundation::Result;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15/// Budget manager for AI usage
16pub struct BudgetManager {
17    /// Budget configuration (workspace-level defaults)
18    config: BudgetConfig,
19    /// In-memory usage tracking (workspace_id -> usage stats)
20    usage_tracker: Arc<RwLock<HashMap<String, WorkspaceUsage>>>,
21    /// Optional org controls for org-level enforcement
22    org_controls: Option<Arc<OrgControls>>,
23}
24
25/// AI feature types for tracking
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
27#[serde(rename_all = "snake_case")]
28pub enum AiFeature {
29    /// MockAI - Natural language mock generation
30    MockAi,
31    /// AI Contract Diff - Contract analysis and recommendations
32    ContractDiff,
33    /// Persona Generation - AI-generated personas
34    PersonaGeneration,
35    /// Debug Analysis - AI-guided debugging
36    DebugAnalysis,
37    /// Generative Schema - Schema generation from examples
38    GenerativeSchema,
39    /// Voice/LLM Interface - Voice commands and chat
40    VoiceInterface,
41    /// General chat/assistant
42    GeneralChat,
43}
44
45impl AiFeature {
46    /// Get display name for the feature
47    pub fn display_name(&self) -> &'static str {
48        match self {
49            AiFeature::MockAi => "MockAI",
50            AiFeature::ContractDiff => "Contract Diff",
51            AiFeature::PersonaGeneration => "Persona Generation",
52            AiFeature::DebugAnalysis => "Debug Analysis",
53            AiFeature::GenerativeSchema => "Generative Schema",
54            AiFeature::VoiceInterface => "Voice Interface",
55            AiFeature::GeneralChat => "General Chat",
56        }
57    }
58}
59
60/// Per-feature usage statistics
61#[derive(Debug, Clone, Default, Serialize, Deserialize)]
62pub struct FeatureUsage {
63    /// Tokens used by this feature
64    pub tokens_used: u64,
65    /// Cost in USD for this feature
66    pub cost_usd: f64,
67    /// Number of calls made for this feature
68    pub calls_made: u64,
69}
70
71/// Per-workspace usage tracking
72#[derive(Debug, Clone)]
73struct WorkspaceUsage {
74    /// Total tokens used
75    tokens_used: u64,
76    /// Total cost in USD
77    cost_usd: f64,
78    /// Number of AI calls made
79    calls_made: u64,
80    /// Last reset time
81    #[allow(dead_code)]
82    last_reset: DateTime<Utc>,
83    /// Per-day call tracking (for rate limiting)
84    daily_calls: HashMap<chrono::NaiveDate, u64>,
85    /// Per-feature usage tracking
86    feature_usage: HashMap<AiFeature, FeatureUsage>,
87}
88
89impl BudgetManager {
90    /// Create a new budget manager
91    pub fn new(config: BudgetConfig) -> Self {
92        Self {
93            config,
94            usage_tracker: Arc::new(RwLock::new(HashMap::new())),
95            org_controls: None,
96        }
97    }
98
99    /// Create a new budget manager with org controls
100    pub fn with_org_controls(config: BudgetConfig, org_controls: Arc<OrgControls>) -> Self {
101        Self {
102            config,
103            usage_tracker: Arc::new(RwLock::new(HashMap::new())),
104            org_controls: Some(org_controls),
105        }
106    }
107
108    /// Get usage statistics for a workspace
109    pub async fn get_usage(&self, workspace_id: &str) -> Result<UsageStats> {
110        let tracker = self.usage_tracker.read().await;
111        let usage = tracker.get(workspace_id).cloned().unwrap_or_else(|| WorkspaceUsage {
112            tokens_used: 0,
113            cost_usd: 0.0,
114            calls_made: 0,
115            last_reset: Utc::now(),
116            daily_calls: HashMap::new(),
117            feature_usage: HashMap::new(),
118        });
119
120        let usage_percentage = if self.config.max_tokens_per_workspace > 0 {
121            (usage.tokens_used as f64 / self.config.max_tokens_per_workspace as f64).min(1.0)
122        } else {
123            0.0
124        };
125
126        // Convert feature usage to serializable format
127        let feature_breakdown: HashMap<String, FeatureUsage> = usage
128            .feature_usage
129            .iter()
130            .map(|(feature, usage)| (format!("{:?}", feature), usage.clone()))
131            .collect();
132
133        Ok(UsageStats {
134            tokens_used: usage.tokens_used,
135            cost_usd: usage.cost_usd,
136            calls_made: usage.calls_made,
137            budget_limit: self.config.max_tokens_per_workspace,
138            usage_percentage,
139            feature_breakdown: Some(feature_breakdown),
140        })
141    }
142
143    /// Check if request is within budget
144    ///
145    /// Checks org-level limits first (if available), then workspace-level limits.
146    /// Org-level limits take precedence.
147    pub async fn check_budget(
148        &self,
149        org_id: Option<&str>,
150        workspace_id: &str,
151        estimated_tokens: u64,
152    ) -> Result<bool> {
153        // Check org-level budget first (if available)
154        if let (Some(org_id), Some(ref org_controls)) = (org_id, &self.org_controls) {
155            let budget_result =
156                org_controls.check_budget(org_id, Some(workspace_id), estimated_tokens).await?;
157            if !budget_result.allowed {
158                return Ok(false);
159            }
160        }
161
162        // Check workspace-level budget
163        let tracker = self.usage_tracker.read().await;
164        let usage = tracker.get(workspace_id);
165
166        // Check token budget
167        if let Some(usage) = usage {
168            if self.config.max_tokens_per_workspace > 0
169                && usage.tokens_used + estimated_tokens > self.config.max_tokens_per_workspace
170            {
171                return Ok(false);
172            }
173        }
174
175        // Check daily call limit
176        let today = Utc::now().date_naive();
177        if let Some(usage) = usage {
178            let today_calls = usage.daily_calls.get(&today).copied().unwrap_or(0);
179            if today_calls >= self.config.max_ai_calls_per_day {
180                return Ok(false);
181            }
182        }
183
184        // Check rate limit (per minute)
185        // Note: This is a simplified check - in production, you'd want more sophisticated rate limiting
186        Ok(true)
187    }
188
189    /// Check rate limit (org-level first, then workspace-level)
190    pub async fn check_rate_limit(&self, org_id: Option<&str>, workspace_id: &str) -> Result<bool> {
191        // Check org-level rate limit first (if available)
192        if let (Some(org_id), Some(ref org_controls)) = (org_id, &self.org_controls) {
193            let rate_limit_result =
194                org_controls.check_rate_limit(org_id, Some(workspace_id)).await?;
195            if !rate_limit_result.allowed {
196                return Ok(false);
197            }
198        }
199
200        // Workspace-level rate limiting would be handled here if needed
201        // For now, we rely on org-level rate limiting
202        Ok(true)
203    }
204
205    /// Check if a feature is enabled (org-level first, then defaults to true)
206    pub async fn is_feature_enabled(
207        &self,
208        org_id: Option<&str>,
209        workspace_id: &str,
210        feature: &str,
211    ) -> Result<bool> {
212        // Check org-level feature toggle first (if available)
213        if let (Some(org_id), Some(ref org_controls)) = (org_id, &self.org_controls) {
214            return org_controls.is_feature_enabled(org_id, Some(workspace_id), feature).await;
215        }
216
217        // Default to enabled if no org controls
218        Ok(true)
219    }
220
221    /// Record token usage and cost
222    pub async fn record_usage(
223        &self,
224        org_id: Option<&str>,
225        workspace_id: &str,
226        user_id: Option<&str>,
227        tokens: u64,
228        cost_usd: f64,
229    ) -> Result<()> {
230        self.record_usage_with_feature(org_id, workspace_id, user_id, tokens, cost_usd, None)
231            .await
232    }
233
234    /// Record token usage and cost with feature tracking
235    ///
236    /// Records usage both in-memory (workspace-level) and in org controls (if available).
237    pub async fn record_usage_with_feature(
238        &self,
239        org_id: Option<&str>,
240        workspace_id: &str,
241        user_id: Option<&str>,
242        tokens: u64,
243        cost_usd: f64,
244        feature: Option<AiFeature>,
245    ) -> Result<()> {
246        // Record in org controls (if available) for audit log
247        if let (Some(org_id), Some(ref org_controls)) = (org_id, &self.org_controls) {
248            if let Some(feature) = feature {
249                let _feature_name = match feature {
250                    AiFeature::MockAi => "mock_generation",
251                    AiFeature::ContractDiff => "contract_diff",
252                    AiFeature::PersonaGeneration => "persona_generation",
253                    AiFeature::DebugAnalysis => "debug_analysis",
254                    AiFeature::GenerativeSchema => "generative_schema",
255                    AiFeature::VoiceInterface => "voice_interface",
256                    AiFeature::GeneralChat => "free_form_generation",
257                };
258                org_controls
259                    .record_usage(
260                        org_id,
261                        Some(workspace_id),
262                        user_id,
263                        feature,
264                        tokens,
265                        cost_usd,
266                        None,
267                    )
268                    .await?;
269            }
270        }
271
272        // Record in-memory (workspace-level tracking)
273        let mut tracker = self.usage_tracker.write().await;
274        let usage = tracker.entry(workspace_id.to_string()).or_insert_with(|| WorkspaceUsage {
275            tokens_used: 0,
276            cost_usd: 0.0,
277            calls_made: 0,
278            last_reset: Utc::now(),
279            daily_calls: HashMap::new(),
280            feature_usage: HashMap::new(),
281        });
282
283        usage.tokens_used += tokens;
284        usage.cost_usd += cost_usd;
285        usage.calls_made += 1;
286
287        // Track per-feature usage
288        if let Some(feature) = feature {
289            let feature_usage =
290                usage.feature_usage.entry(feature).or_insert_with(FeatureUsage::default);
291            feature_usage.tokens_used += tokens;
292            feature_usage.cost_usd += cost_usd;
293            feature_usage.calls_made += 1;
294        }
295
296        // Track daily calls
297        let today = Utc::now().date_naive();
298        *usage.daily_calls.entry(today).or_insert(0) += 1;
299
300        Ok(())
301    }
302
303    /// Reset usage for a workspace (useful for testing or monthly resets)
304    pub async fn reset_usage(&self, workspace_id: &str) -> Result<()> {
305        let mut tracker = self.usage_tracker.write().await;
306        tracker.remove(workspace_id);
307        Ok(())
308    }
309
310    /// Calculate cost based on provider and tokens
311    ///
312    /// Uses approximate pricing for common providers:
313    /// - OpenAI GPT-3.5: ~$0.002 per 1K tokens
314    /// - OpenAI GPT-4: ~$0.03 per 1K tokens
315    /// - Anthropic Claude: ~$0.008 per 1K tokens
316    /// - Ollama: $0 (local)
317    pub fn calculate_cost(provider: &str, model: &str, tokens: u64) -> f64 {
318        let tokens_k = tokens as f64 / 1000.0;
319
320        // Approximate pricing per 1K tokens
321        let price_per_1k = if provider.to_lowercase() == "ollama" {
322            0.0 // Free local models
323        } else if model.contains("gpt-4") {
324            0.03 // GPT-4 pricing
325        } else if model.contains("gpt-3.5") || model.contains("gpt-3") {
326            0.002 // GPT-3.5 pricing
327        } else if provider.to_lowercase() == "anthropic" {
328            0.008 // Claude pricing
329        } else {
330            0.002 // Default to GPT-3.5 pricing
331        };
332
333        tokens_k * price_per_1k
334    }
335}
336
337/// Budget configuration
338#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct BudgetConfig {
340    /// Maximum tokens per workspace
341    pub max_tokens_per_workspace: u64,
342
343    /// Maximum AI calls per day
344    pub max_ai_calls_per_day: u64,
345
346    /// Rate limit per minute
347    pub rate_limit_per_minute: u64,
348}
349
350impl Default for BudgetConfig {
351    fn default() -> Self {
352        Self {
353            max_tokens_per_workspace: 100_000,
354            max_ai_calls_per_day: 1_000,
355            rate_limit_per_minute: 10,
356        }
357    }
358}
359
360/// Usage statistics
361#[derive(Debug, Clone, Serialize, Deserialize)]
362pub struct UsageStats {
363    /// Total tokens used
364    pub tokens_used: u64,
365
366    /// Total cost in USD
367    pub cost_usd: f64,
368
369    /// Number of AI calls made
370    pub calls_made: u64,
371
372    /// Budget limit
373    pub budget_limit: u64,
374
375    /// Usage percentage (0.0 to 1.0)
376    pub usage_percentage: f64,
377
378    /// Per-feature usage breakdown
379    #[serde(skip_serializing_if = "Option::is_none")]
380    pub feature_breakdown: Option<HashMap<String, FeatureUsage>>,
381}