oxify_connect_llm/
workflow.rs

1//! Per-Workflow Cost Tracking
2//!
3//! This module provides cost tracking and budget enforcement on a per-workflow basis.
4//! It allows you to track LLM usage and costs for individual workflows, set per-workflow
5//! budgets, and enforce budget limits.
6//!
7//! # Example
8//!
9//! ```rust
10//! use oxify_connect_llm::{WorkflowProvider, WorkflowTracker, BudgetLimit, OpenAIProvider};
11//!
12//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
13//! let tracker = WorkflowTracker::new();
14//! let provider = OpenAIProvider::new("api-key".to_string(), "gpt-4".to_string());
15//!
16//! // Wrap with workflow tracking
17//! let workflow_provider = WorkflowProvider::new(
18//!     provider,
19//!     tracker.clone(),
20//!     "workflow-123".to_string(),
21//!     Some(BudgetLimit::cents(1000)), // $10.00 budget
22//! );
23//!
24//! // Make requests - costs are tracked per workflow
25//! // let response = workflow_provider.complete(request).await?;
26//!
27//! // Check workflow stats
28//! let stats = tracker.get_stats("workflow-123");
29//! println!("Workflow spent: ${:.2}", stats.total_cost_usd);
30//! # Ok(())
31//! # }
32//! ```
33
34use crate::{
35    BudgetLimit, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider,
36    LlmRequest, LlmResponse, LlmStream, ModelPricing, Result, StreamingLlmProvider,
37};
38use async_trait::async_trait;
39use std::collections::HashMap;
40use std::sync::{Arc, Mutex};
41
/// Point-in-time usage and budget snapshot for a single workflow.
///
/// Values are copied out of the tracker when the snapshot is built; later
/// usage recorded on the tracker does not update an existing snapshot.
#[derive(Debug, Clone)]
pub struct WorkflowStats {
    /// Identifier of the workflow this snapshot describes
    pub workflow_id: String,
    /// Number of usage records accumulated for the workflow
    pub total_requests: u64,
    /// Total tokens used (prompt + completion)
    pub total_tokens: u64,
    /// Total prompt tokens
    pub prompt_tokens: u64,
    /// Total completion tokens
    pub completion_tokens: u64,
    /// Total recorded cost in cents
    pub total_cost_cents: u64,
    /// Total recorded cost in USD (`total_cost_cents / 100`)
    pub total_cost_usd: f64,
    /// Budget limit, or `None` when no budget is set for the workflow
    pub budget_limit: Option<BudgetLimit>,
    /// Budget remaining in cents (if budget set); clamped to 0 once the
    /// recorded cost reaches or exceeds the budget
    pub budget_remaining_cents: Option<u64>,
    /// Budget remaining percentage (0-100, if budget set); reported as 0
    /// for a zero-cent budget
    pub budget_remaining_percent: Option<f64>,
}
66
/// Tracks costs and usage per workflow
///
/// The per-workflow records live behind an `Arc<Mutex<..>>`, so cloning a
/// tracker yields another handle to the same underlying data rather than an
/// independent copy.
#[derive(Clone)]
pub struct WorkflowTracker {
    // Map from workflow ID to its accumulated usage and optional budget.
    workflows: Arc<Mutex<HashMap<String, WorkflowData>>>,
}
72
/// Internal accumulator holding the raw counters for one workflow.
#[derive(Debug, Clone)]
struct WorkflowData {
    // Number of usage records added for this workflow.
    requests: u64,
    // Running sum of prompt + completion tokens.
    tokens: u64,
    // Running sum of prompt tokens.
    prompt_tokens: u64,
    // Running sum of completion tokens.
    completion_tokens: u64,
    // Running sum of recorded cost, in cents.
    cost_cents: u64,
    // Budget attached to the workflow; `None` means no limit is enforced.
    budget: Option<BudgetLimit>,
}
82
83impl WorkflowTracker {
84    /// Create a new workflow tracker
85    pub fn new() -> Self {
86        Self {
87            workflows: Arc::new(Mutex::new(HashMap::new())),
88        }
89    }
90
91    /// Set a budget limit for a workflow
92    pub fn set_budget(&self, workflow_id: &str, budget: BudgetLimit) {
93        let mut workflows = self.workflows.lock().unwrap();
94        workflows
95            .entry(workflow_id.to_string())
96            .or_insert_with(|| WorkflowData {
97                requests: 0,
98                tokens: 0,
99                prompt_tokens: 0,
100                completion_tokens: 0,
101                cost_cents: 0,
102                budget: None,
103            })
104            .budget = Some(budget);
105    }
106
107    /// Record usage for a workflow
108    pub fn record_usage(
109        &self,
110        workflow_id: &str,
111        prompt_tokens: u32,
112        completion_tokens: u32,
113        cost_cents: u64,
114    ) {
115        let mut workflows = self.workflows.lock().unwrap();
116        let data = workflows
117            .entry(workflow_id.to_string())
118            .or_insert_with(|| WorkflowData {
119                requests: 0,
120                tokens: 0,
121                prompt_tokens: 0,
122                completion_tokens: 0,
123                cost_cents: 0,
124                budget: None,
125            });
126
127        data.requests += 1;
128        data.prompt_tokens += prompt_tokens as u64;
129        data.completion_tokens += completion_tokens as u64;
130        data.tokens += (prompt_tokens + completion_tokens) as u64;
131        data.cost_cents += cost_cents;
132    }
133
134    /// Check if a workflow can afford a request (budget check)
135    pub fn can_afford(&self, workflow_id: &str, estimated_cost_cents: u64) -> bool {
136        let workflows = self.workflows.lock().unwrap();
137        if let Some(data) = workflows.get(workflow_id) {
138            if let Some(budget) = &data.budget {
139                let budget_cents = budget.as_cents();
140                return data.cost_cents + estimated_cost_cents <= budget_cents;
141            }
142        }
143        true // No budget = always allowed
144    }
145
146    /// Get statistics for a workflow
147    pub fn get_stats(&self, workflow_id: &str) -> WorkflowStats {
148        let workflows = self.workflows.lock().unwrap();
149        if let Some(data) = workflows.get(workflow_id) {
150            let budget_remaining_cents = data.budget.as_ref().map(|b| {
151                let budget_cents = b.as_cents();
152                budget_cents.saturating_sub(data.cost_cents)
153            });
154
155            let budget_remaining_percent = data.budget.as_ref().map(|b| {
156                let budget_cents = b.as_cents();
157                if budget_cents == 0 {
158                    0.0
159                } else {
160                    let remaining = budget_cents.saturating_sub(data.cost_cents);
161                    (remaining as f64 / budget_cents as f64) * 100.0
162                }
163            });
164
165            WorkflowStats {
166                workflow_id: workflow_id.to_string(),
167                total_requests: data.requests,
168                total_tokens: data.tokens,
169                prompt_tokens: data.prompt_tokens,
170                completion_tokens: data.completion_tokens,
171                total_cost_cents: data.cost_cents,
172                total_cost_usd: data.cost_cents as f64 / 100.0,
173                budget_limit: data.budget.clone(),
174                budget_remaining_cents,
175                budget_remaining_percent,
176            }
177        } else {
178            WorkflowStats {
179                workflow_id: workflow_id.to_string(),
180                total_requests: 0,
181                total_tokens: 0,
182                prompt_tokens: 0,
183                completion_tokens: 0,
184                total_cost_cents: 0,
185                total_cost_usd: 0.0,
186                budget_limit: None,
187                budget_remaining_cents: None,
188                budget_remaining_percent: None,
189            }
190        }
191    }
192
193    /// Get statistics for all workflows
194    pub fn get_all_stats(&self) -> Vec<WorkflowStats> {
195        let workflows = self.workflows.lock().unwrap();
196        workflows
197            .iter()
198            .map(|(workflow_id, data)| {
199                let budget_remaining_cents = data.budget.as_ref().map(|b| {
200                    let budget_cents = b.as_cents();
201                    budget_cents.saturating_sub(data.cost_cents)
202                });
203
204                let budget_remaining_percent = data.budget.as_ref().map(|b| {
205                    let budget_cents = b.as_cents();
206                    if budget_cents == 0 {
207                        0.0
208                    } else {
209                        let remaining = budget_cents.saturating_sub(data.cost_cents);
210                        (remaining as f64 / budget_cents as f64) * 100.0
211                    }
212                });
213
214                WorkflowStats {
215                    workflow_id: workflow_id.to_string(),
216                    total_requests: data.requests,
217                    total_tokens: data.tokens,
218                    prompt_tokens: data.prompt_tokens,
219                    completion_tokens: data.completion_tokens,
220                    total_cost_cents: data.cost_cents,
221                    total_cost_usd: data.cost_cents as f64 / 100.0,
222                    budget_limit: data.budget.clone(),
223                    budget_remaining_cents,
224                    budget_remaining_percent,
225                }
226            })
227            .collect()
228    }
229
230    /// Reset statistics for a workflow
231    pub fn reset(&self, workflow_id: &str) {
232        let mut workflows = self.workflows.lock().unwrap();
233        if let Some(data) = workflows.get_mut(workflow_id) {
234            let budget = data.budget.clone();
235            *data = WorkflowData {
236                requests: 0,
237                tokens: 0,
238                prompt_tokens: 0,
239                completion_tokens: 0,
240                cost_cents: 0,
241                budget,
242            };
243        }
244    }
245
246    /// Reset all workflow statistics
247    pub fn reset_all(&self) {
248        let mut workflows = self.workflows.lock().unwrap();
249        for data in workflows.values_mut() {
250            let budget = data.budget.clone();
251            *data = WorkflowData {
252                requests: 0,
253                tokens: 0,
254                prompt_tokens: 0,
255                completion_tokens: 0,
256                cost_cents: 0,
257                budget,
258            };
259        }
260    }
261
262    /// Remove a workflow from tracking
263    pub fn remove(&self, workflow_id: &str) {
264        let mut workflows = self.workflows.lock().unwrap();
265        workflows.remove(workflow_id);
266    }
267}
268
269impl Default for WorkflowTracker {
270    fn default() -> Self {
271        Self::new()
272    }
273}
274
/// LLM provider wrapper that tracks costs per workflow
///
/// All usage is attributed to a single workflow ID and recorded in the
/// supplied [`WorkflowTracker`].
pub struct WorkflowProvider<P> {
    // Wrapped provider that actually services requests.
    provider: P,
    // Tracker receiving usage records and answering budget checks.
    tracker: WorkflowTracker,
    // Workflow every request made through this wrapper is attributed to.
    workflow_id: String,
    // Pricing used to turn token counts into a cost for the tracker.
    pricing: ModelPricing,
}
282
283impl<P> WorkflowProvider<P> {
284    /// Create a new workflow provider
285    pub fn new(
286        provider: P,
287        tracker: WorkflowTracker,
288        workflow_id: String,
289        budget: Option<BudgetLimit>,
290    ) -> Self {
291        // Set budget in tracker if provided
292        if let Some(ref budget_limit) = budget {
293            tracker.set_budget(&workflow_id, budget_limit.clone());
294        }
295
296        Self {
297            provider,
298            tracker,
299            workflow_id,
300            pricing: ModelPricing::gpt4(),
301        }
302    }
303
304    /// Set the pricing model for cost calculation
305    pub fn with_pricing(mut self, pricing: ModelPricing) -> Self {
306        self.pricing = pricing;
307        self
308    }
309
310    /// Get the workflow tracker
311    pub fn tracker(&self) -> &WorkflowTracker {
312        &self.tracker
313    }
314
315    /// Get the workflow ID
316    pub fn workflow_id(&self) -> &str {
317        &self.workflow_id
318    }
319
320    /// Get statistics for this workflow
321    pub fn stats(&self) -> WorkflowStats {
322        self.tracker.get_stats(&self.workflow_id)
323    }
324}
325
326#[async_trait]
327impl<P: LlmProvider> LlmProvider for WorkflowProvider<P> {
328    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
329        // Estimate cost before making request
330        let estimated_tokens =
331            request.prompt.len() / 4 + request.max_tokens.unwrap_or(1000) as usize;
332        let estimated_cost_cents = self
333            .pricing
334            .calculate_cost((estimated_tokens / 2) as u32, (estimated_tokens / 2) as u32)
335            as u64;
336
337        // Check budget
338        if !self
339            .tracker
340            .can_afford(&self.workflow_id, estimated_cost_cents)
341        {
342            return Err(LlmError::Other(format!(
343                "Workflow '{}' has exceeded its budget limit",
344                self.workflow_id
345            )));
346        }
347
348        // Make the request
349        let response = self.provider.complete(request).await?;
350
351        // Record actual usage
352        if let Some(usage) = &response.usage {
353            let actual_cost_cents = self
354                .pricing
355                .calculate_cost(usage.prompt_tokens, usage.completion_tokens)
356                as u64;
357            self.tracker.record_usage(
358                &self.workflow_id,
359                usage.prompt_tokens,
360                usage.completion_tokens,
361                actual_cost_cents,
362            );
363        }
364
365        Ok(response)
366    }
367}
368
369#[async_trait]
370impl<P: StreamingLlmProvider> StreamingLlmProvider for WorkflowProvider<P> {
371    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
372        // Estimate cost before making request
373        let estimated_tokens =
374            request.prompt.len() / 4 + request.max_tokens.unwrap_or(1000) as usize;
375        let estimated_cost_cents = self
376            .pricing
377            .calculate_cost((estimated_tokens / 2) as u32, (estimated_tokens / 2) as u32)
378            as u64;
379
380        // Check budget
381        if !self
382            .tracker
383            .can_afford(&self.workflow_id, estimated_cost_cents)
384        {
385            return Err(LlmError::Other(format!(
386                "Workflow '{}' has exceeded its budget limit",
387                self.workflow_id
388            )));
389        }
390
391        // Note: For streaming, we can't track exact usage until the stream completes
392        // This is a limitation of the streaming API
393        self.provider.complete_stream(request).await
394    }
395}
396
/// Embedding provider wrapper that tracks costs per workflow
///
/// Usage is attributed to a single workflow ID and recorded in the supplied
/// [`WorkflowTracker`].
pub struct WorkflowEmbeddingProvider<P> {
    // Wrapped provider that actually computes embeddings.
    provider: P,
    // Tracker receiving the usage records.
    tracker: WorkflowTracker,
    // Workflow every request made through this wrapper is attributed to.
    workflow_id: String,
    // Pricing used to turn token counts into a cost for the tracker.
    pricing: ModelPricing,
}
404
405impl<P> WorkflowEmbeddingProvider<P> {
406    /// Create a new workflow embedding provider
407    pub fn new(provider: P, tracker: WorkflowTracker, workflow_id: String) -> Self {
408        Self {
409            provider,
410            tracker,
411            workflow_id,
412            pricing: ModelPricing::ada_embedding(),
413        }
414    }
415
416    /// Set the pricing model for cost calculation
417    pub fn with_pricing(mut self, pricing: ModelPricing) -> Self {
418        self.pricing = pricing;
419        self
420    }
421
422    /// Get the workflow tracker
423    pub fn tracker(&self) -> &WorkflowTracker {
424        &self.tracker
425    }
426
427    /// Get statistics for this workflow
428    pub fn stats(&self) -> WorkflowStats {
429        self.tracker.get_stats(&self.workflow_id)
430    }
431}
432
433#[async_trait]
434impl<P: EmbeddingProvider> EmbeddingProvider for WorkflowEmbeddingProvider<P> {
435    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
436        let response = self.provider.embed(request).await?;
437
438        // Record usage
439        if let Some(usage) = &response.usage {
440            let cost_cents = self.pricing.calculate_cost(usage.prompt_tokens, 0) as u64;
441            self.tracker
442                .record_usage(&self.workflow_id, usage.prompt_tokens, 0, cost_cents);
443        }
444
445        Ok(response)
446    }
447}
448
// Unit tests for `WorkflowTracker`. They exercise the tracker directly; the
// provider wrappers are not covered here.
#[cfg(test)]
mod tests {
    use super::*;

    // An untracked workflow reports all-zero statistics.
    #[test]
    fn test_workflow_tracker_new() {
        let tracker = WorkflowTracker::new();
        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost_cents, 0);
    }

    // Successive records for the same workflow accumulate.
    #[test]
    fn test_workflow_tracker_record_usage() {
        let tracker = WorkflowTracker::new();

        // Record some usage
        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.record_usage("workflow-1", 200, 100, 20);

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.prompt_tokens, 300);
        assert_eq!(stats.completion_tokens, 150);
        assert_eq!(stats.total_tokens, 450);
        assert_eq!(stats.total_cost_cents, 30);
        assert_eq!(stats.total_cost_usd, 0.30);
    }

    // Usage is kept separate per workflow ID.
    #[test]
    fn test_workflow_tracker_multiple_workflows() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.record_usage("workflow-2", 200, 100, 20);
        tracker.record_usage("workflow-1", 50, 25, 5);

        let stats1 = tracker.get_stats("workflow-1");
        assert_eq!(stats1.total_requests, 2);
        assert_eq!(stats1.total_cost_cents, 15);

        let stats2 = tracker.get_stats("workflow-2");
        assert_eq!(stats2.total_requests, 1);
        assert_eq!(stats2.total_cost_cents, 20);

        let all_stats = tracker.get_all_stats();
        assert_eq!(all_stats.len(), 2);
    }

    // Remaining budget is reported both in cents and as a percentage.
    #[test]
    fn test_workflow_budget() {
        let tracker = WorkflowTracker::new();

        // Set a budget of $1.00 (100 cents)
        tracker.set_budget("workflow-1", BudgetLimit::cents(100));

        // Record usage
        tracker.record_usage("workflow-1", 100, 50, 30);

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.budget_limit, Some(BudgetLimit::cents(100)));
        assert_eq!(stats.budget_remaining_cents, Some(70));
        assert_eq!(stats.budget_remaining_percent, Some(70.0));
    }

    // The budget boundary is inclusive: spending exactly up to it is allowed.
    #[test]
    fn test_workflow_can_afford() {
        let tracker = WorkflowTracker::new();

        // Set a budget of $1.00 (100 cents)
        tracker.set_budget("workflow-1", BudgetLimit::cents(100));

        // Can afford 50 cents
        assert!(tracker.can_afford("workflow-1", 50));

        // Record 80 cents usage
        tracker.record_usage("workflow-1", 100, 50, 80);

        // Can afford 20 cents
        assert!(tracker.can_afford("workflow-1", 20));

        // Cannot afford 21 cents (would exceed budget)
        assert!(!tracker.can_afford("workflow-1", 21));

        // Workflow without budget can always afford
        assert!(tracker.can_afford("workflow-2", 1000000));
    }

    // Resetting one workflow zeroes its counters but keeps its budget.
    #[test]
    fn test_workflow_reset() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.set_budget("workflow-1", BudgetLimit::cents(100));

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.total_requests, 1);

        // Reset workflow
        tracker.reset("workflow-1");

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_cost_cents, 0);
        // Budget should be preserved
        assert_eq!(stats.budget_limit, Some(BudgetLimit::cents(100)));
    }

    // Resetting all workflows zeroes every workflow's counters.
    #[test]
    fn test_workflow_reset_all() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.record_usage("workflow-2", 200, 100, 20);

        tracker.reset_all();

        let stats1 = tracker.get_stats("workflow-1");
        let stats2 = tracker.get_stats("workflow-2");
        assert_eq!(stats1.total_requests, 0);
        assert_eq!(stats2.total_requests, 0);
    }

    // Removing a workflow drops it from the tracker entirely.
    #[test]
    fn test_workflow_remove() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.record_usage("workflow-2", 200, 100, 20);

        tracker.remove("workflow-1");

        let all_stats = tracker.get_all_stats();
        assert_eq!(all_stats.len(), 1);
        assert_eq!(all_stats[0].workflow_id, "workflow-2");
    }

    // Overspending clamps the remaining budget to zero rather than underflowing.
    #[test]
    fn test_workflow_budget_exceeded() {
        let tracker = WorkflowTracker::new();

        // Set a budget of $0.50 (50 cents)
        tracker.set_budget("workflow-1", BudgetLimit::cents(50));

        // Record usage that exceeds budget
        tracker.record_usage("workflow-1", 100, 50, 60);

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.budget_remaining_cents, Some(0)); // Saturating sub
        assert_eq!(stats.budget_remaining_percent, Some(0.0));
    }

    // Without a budget, all budget-related fields are `None`.
    #[test]
    fn test_workflow_stats_no_budget() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.budget_limit, None);
        assert_eq!(stats.budget_remaining_cents, None);
        assert_eq!(stats.budget_remaining_percent, None);
    }
}