1use crate::{
35 BudgetLimit, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider,
36 LlmRequest, LlmResponse, LlmStream, ModelPricing, Result, StreamingLlmProvider,
37};
38use async_trait::async_trait;
39use std::collections::HashMap;
40use std::sync::{Arc, Mutex};
41
42#[derive(Debug, Clone)]
44pub struct WorkflowStats {
45 pub workflow_id: String,
47 pub total_requests: u64,
49 pub total_tokens: u64,
51 pub prompt_tokens: u64,
53 pub completion_tokens: u64,
55 pub total_cost_cents: u64,
57 pub total_cost_usd: f64,
59 pub budget_limit: Option<BudgetLimit>,
61 pub budget_remaining_cents: Option<u64>,
63 pub budget_remaining_percent: Option<f64>,
65}
66
67#[derive(Clone)]
69pub struct WorkflowTracker {
70 workflows: Arc<Mutex<HashMap<String, WorkflowData>>>,
71}
72
73#[derive(Debug, Clone)]
74struct WorkflowData {
75 requests: u64,
76 tokens: u64,
77 prompt_tokens: u64,
78 completion_tokens: u64,
79 cost_cents: u64,
80 budget: Option<BudgetLimit>,
81}
82
83impl WorkflowTracker {
84 pub fn new() -> Self {
86 Self {
87 workflows: Arc::new(Mutex::new(HashMap::new())),
88 }
89 }
90
91 pub fn set_budget(&self, workflow_id: &str, budget: BudgetLimit) {
93 let mut workflows = self.workflows.lock().unwrap();
94 workflows
95 .entry(workflow_id.to_string())
96 .or_insert_with(|| WorkflowData {
97 requests: 0,
98 tokens: 0,
99 prompt_tokens: 0,
100 completion_tokens: 0,
101 cost_cents: 0,
102 budget: None,
103 })
104 .budget = Some(budget);
105 }
106
107 pub fn record_usage(
109 &self,
110 workflow_id: &str,
111 prompt_tokens: u32,
112 completion_tokens: u32,
113 cost_cents: u64,
114 ) {
115 let mut workflows = self.workflows.lock().unwrap();
116 let data = workflows
117 .entry(workflow_id.to_string())
118 .or_insert_with(|| WorkflowData {
119 requests: 0,
120 tokens: 0,
121 prompt_tokens: 0,
122 completion_tokens: 0,
123 cost_cents: 0,
124 budget: None,
125 });
126
127 data.requests += 1;
128 data.prompt_tokens += prompt_tokens as u64;
129 data.completion_tokens += completion_tokens as u64;
130 data.tokens += (prompt_tokens + completion_tokens) as u64;
131 data.cost_cents += cost_cents;
132 }
133
134 pub fn can_afford(&self, workflow_id: &str, estimated_cost_cents: u64) -> bool {
136 let workflows = self.workflows.lock().unwrap();
137 if let Some(data) = workflows.get(workflow_id) {
138 if let Some(budget) = &data.budget {
139 let budget_cents = budget.as_cents();
140 return data.cost_cents + estimated_cost_cents <= budget_cents;
141 }
142 }
143 true }
145
146 pub fn get_stats(&self, workflow_id: &str) -> WorkflowStats {
148 let workflows = self.workflows.lock().unwrap();
149 if let Some(data) = workflows.get(workflow_id) {
150 let budget_remaining_cents = data.budget.as_ref().map(|b| {
151 let budget_cents = b.as_cents();
152 budget_cents.saturating_sub(data.cost_cents)
153 });
154
155 let budget_remaining_percent = data.budget.as_ref().map(|b| {
156 let budget_cents = b.as_cents();
157 if budget_cents == 0 {
158 0.0
159 } else {
160 let remaining = budget_cents.saturating_sub(data.cost_cents);
161 (remaining as f64 / budget_cents as f64) * 100.0
162 }
163 });
164
165 WorkflowStats {
166 workflow_id: workflow_id.to_string(),
167 total_requests: data.requests,
168 total_tokens: data.tokens,
169 prompt_tokens: data.prompt_tokens,
170 completion_tokens: data.completion_tokens,
171 total_cost_cents: data.cost_cents,
172 total_cost_usd: data.cost_cents as f64 / 100.0,
173 budget_limit: data.budget.clone(),
174 budget_remaining_cents,
175 budget_remaining_percent,
176 }
177 } else {
178 WorkflowStats {
179 workflow_id: workflow_id.to_string(),
180 total_requests: 0,
181 total_tokens: 0,
182 prompt_tokens: 0,
183 completion_tokens: 0,
184 total_cost_cents: 0,
185 total_cost_usd: 0.0,
186 budget_limit: None,
187 budget_remaining_cents: None,
188 budget_remaining_percent: None,
189 }
190 }
191 }
192
193 pub fn get_all_stats(&self) -> Vec<WorkflowStats> {
195 let workflows = self.workflows.lock().unwrap();
196 workflows
197 .iter()
198 .map(|(workflow_id, data)| {
199 let budget_remaining_cents = data.budget.as_ref().map(|b| {
200 let budget_cents = b.as_cents();
201 budget_cents.saturating_sub(data.cost_cents)
202 });
203
204 let budget_remaining_percent = data.budget.as_ref().map(|b| {
205 let budget_cents = b.as_cents();
206 if budget_cents == 0 {
207 0.0
208 } else {
209 let remaining = budget_cents.saturating_sub(data.cost_cents);
210 (remaining as f64 / budget_cents as f64) * 100.0
211 }
212 });
213
214 WorkflowStats {
215 workflow_id: workflow_id.to_string(),
216 total_requests: data.requests,
217 total_tokens: data.tokens,
218 prompt_tokens: data.prompt_tokens,
219 completion_tokens: data.completion_tokens,
220 total_cost_cents: data.cost_cents,
221 total_cost_usd: data.cost_cents as f64 / 100.0,
222 budget_limit: data.budget.clone(),
223 budget_remaining_cents,
224 budget_remaining_percent,
225 }
226 })
227 .collect()
228 }
229
230 pub fn reset(&self, workflow_id: &str) {
232 let mut workflows = self.workflows.lock().unwrap();
233 if let Some(data) = workflows.get_mut(workflow_id) {
234 let budget = data.budget.clone();
235 *data = WorkflowData {
236 requests: 0,
237 tokens: 0,
238 prompt_tokens: 0,
239 completion_tokens: 0,
240 cost_cents: 0,
241 budget,
242 };
243 }
244 }
245
246 pub fn reset_all(&self) {
248 let mut workflows = self.workflows.lock().unwrap();
249 for data in workflows.values_mut() {
250 let budget = data.budget.clone();
251 *data = WorkflowData {
252 requests: 0,
253 tokens: 0,
254 prompt_tokens: 0,
255 completion_tokens: 0,
256 cost_cents: 0,
257 budget,
258 };
259 }
260 }
261
262 pub fn remove(&self, workflow_id: &str) {
264 let mut workflows = self.workflows.lock().unwrap();
265 workflows.remove(workflow_id);
266 }
267}
268
269impl Default for WorkflowTracker {
270 fn default() -> Self {
271 Self::new()
272 }
273}
274
275pub struct WorkflowProvider<P> {
277 provider: P,
278 tracker: WorkflowTracker,
279 workflow_id: String,
280 pricing: ModelPricing,
281}
282
283impl<P> WorkflowProvider<P> {
284 pub fn new(
286 provider: P,
287 tracker: WorkflowTracker,
288 workflow_id: String,
289 budget: Option<BudgetLimit>,
290 ) -> Self {
291 if let Some(ref budget_limit) = budget {
293 tracker.set_budget(&workflow_id, budget_limit.clone());
294 }
295
296 Self {
297 provider,
298 tracker,
299 workflow_id,
300 pricing: ModelPricing::gpt4(),
301 }
302 }
303
304 pub fn with_pricing(mut self, pricing: ModelPricing) -> Self {
306 self.pricing = pricing;
307 self
308 }
309
310 pub fn tracker(&self) -> &WorkflowTracker {
312 &self.tracker
313 }
314
315 pub fn workflow_id(&self) -> &str {
317 &self.workflow_id
318 }
319
320 pub fn stats(&self) -> WorkflowStats {
322 self.tracker.get_stats(&self.workflow_id)
323 }
324}
325
326#[async_trait]
327impl<P: LlmProvider> LlmProvider for WorkflowProvider<P> {
328 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
329 let estimated_tokens =
331 request.prompt.len() / 4 + request.max_tokens.unwrap_or(1000) as usize;
332 let estimated_cost_cents = self
333 .pricing
334 .calculate_cost((estimated_tokens / 2) as u32, (estimated_tokens / 2) as u32)
335 as u64;
336
337 if !self
339 .tracker
340 .can_afford(&self.workflow_id, estimated_cost_cents)
341 {
342 return Err(LlmError::Other(format!(
343 "Workflow '{}' has exceeded its budget limit",
344 self.workflow_id
345 )));
346 }
347
348 let response = self.provider.complete(request).await?;
350
351 if let Some(usage) = &response.usage {
353 let actual_cost_cents = self
354 .pricing
355 .calculate_cost(usage.prompt_tokens, usage.completion_tokens)
356 as u64;
357 self.tracker.record_usage(
358 &self.workflow_id,
359 usage.prompt_tokens,
360 usage.completion_tokens,
361 actual_cost_cents,
362 );
363 }
364
365 Ok(response)
366 }
367}
368
369#[async_trait]
370impl<P: StreamingLlmProvider> StreamingLlmProvider for WorkflowProvider<P> {
371 async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
372 let estimated_tokens =
374 request.prompt.len() / 4 + request.max_tokens.unwrap_or(1000) as usize;
375 let estimated_cost_cents = self
376 .pricing
377 .calculate_cost((estimated_tokens / 2) as u32, (estimated_tokens / 2) as u32)
378 as u64;
379
380 if !self
382 .tracker
383 .can_afford(&self.workflow_id, estimated_cost_cents)
384 {
385 return Err(LlmError::Other(format!(
386 "Workflow '{}' has exceeded its budget limit",
387 self.workflow_id
388 )));
389 }
390
391 self.provider.complete_stream(request).await
394 }
395}
396
397pub struct WorkflowEmbeddingProvider<P> {
399 provider: P,
400 tracker: WorkflowTracker,
401 workflow_id: String,
402 pricing: ModelPricing,
403}
404
405impl<P> WorkflowEmbeddingProvider<P> {
406 pub fn new(provider: P, tracker: WorkflowTracker, workflow_id: String) -> Self {
408 Self {
409 provider,
410 tracker,
411 workflow_id,
412 pricing: ModelPricing::ada_embedding(),
413 }
414 }
415
416 pub fn with_pricing(mut self, pricing: ModelPricing) -> Self {
418 self.pricing = pricing;
419 self
420 }
421
422 pub fn tracker(&self) -> &WorkflowTracker {
424 &self.tracker
425 }
426
427 pub fn stats(&self) -> WorkflowStats {
429 self.tracker.get_stats(&self.workflow_id)
430 }
431}
432
433#[async_trait]
434impl<P: EmbeddingProvider> EmbeddingProvider for WorkflowEmbeddingProvider<P> {
435 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
436 let response = self.provider.embed(request).await?;
437
438 if let Some(usage) = &response.usage {
440 let cost_cents = self.pricing.calculate_cost(usage.prompt_tokens, 0) as u64;
441 self.tracker
442 .record_usage(&self.workflow_id, usage.prompt_tokens, 0, cost_cents);
443 }
444
445 Ok(response)
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two derived floating-point values agree to within rounding error.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_workflow_tracker_new() {
        let tracker = WorkflowTracker::new();
        // Unknown workflows report zeroed stats rather than failing.
        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.workflow_id, "workflow-1");
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.total_cost_cents, 0);
        assert_eq!(stats.budget_limit, None);
        // Querying must not create an entry.
        assert!(tracker.get_all_stats().is_empty());
    }

    #[test]
    fn test_default_tracker_is_empty() {
        let tracker = WorkflowTracker::default();
        assert!(tracker.get_all_stats().is_empty());
    }

    #[test]
    fn test_workflow_tracker_record_usage() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.record_usage("workflow-1", 200, 100, 20);

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.prompt_tokens, 300);
        assert_eq!(stats.completion_tokens, 150);
        assert_eq!(stats.total_tokens, 450);
        assert_eq!(stats.total_cost_cents, 30);
        assert_close(stats.total_cost_usd, 0.30);
    }

    #[test]
    fn test_cloned_tracker_shares_state() {
        let tracker = WorkflowTracker::new();
        let clone = tracker.clone();

        clone.record_usage("workflow-1", 10, 5, 1);

        assert_eq!(tracker.get_stats("workflow-1").total_requests, 1);
    }

    #[test]
    fn test_workflow_tracker_multiple_workflows() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.record_usage("workflow-2", 200, 100, 20);
        tracker.record_usage("workflow-1", 50, 25, 5);

        let stats1 = tracker.get_stats("workflow-1");
        assert_eq!(stats1.total_requests, 2);
        assert_eq!(stats1.total_cost_cents, 15);

        let stats2 = tracker.get_stats("workflow-2");
        assert_eq!(stats2.total_requests, 1);
        assert_eq!(stats2.total_cost_cents, 20);

        let all_stats = tracker.get_all_stats();
        assert_eq!(all_stats.len(), 2);
    }

    #[test]
    fn test_workflow_budget() {
        let tracker = WorkflowTracker::new();

        tracker.set_budget("workflow-1", BudgetLimit::cents(100));
        tracker.record_usage("workflow-1", 100, 50, 30);

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.budget_limit, Some(BudgetLimit::cents(100)));
        assert_eq!(stats.budget_remaining_cents, Some(70));
        assert_close(
            stats.budget_remaining_percent.expect("budget was set"),
            70.0,
        );
    }

    #[test]
    fn test_set_budget_keeps_existing_usage() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.set_budget("workflow-1", BudgetLimit::cents(100));

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_cost_cents, 10);
        assert_eq!(stats.budget_remaining_cents, Some(90));
    }

    #[test]
    fn test_workflow_can_afford() {
        let tracker = WorkflowTracker::new();

        tracker.set_budget("workflow-1", BudgetLimit::cents(100));

        assert!(tracker.can_afford("workflow-1", 50));

        tracker.record_usage("workflow-1", 100, 50, 80);

        // The limit is inclusive: spending exactly up to the budget is allowed.
        assert!(tracker.can_afford("workflow-1", 20));
        assert!(!tracker.can_afford("workflow-1", 21));

        // Workflows without a budget can afford anything.
        assert!(tracker.can_afford("workflow-2", 1000000));
    }

    #[test]
    fn test_workflow_reset() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.set_budget("workflow-1", BudgetLimit::cents(100));

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.total_requests, 1);

        tracker.reset("workflow-1");

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_cost_cents, 0);
        // Resetting counters keeps the budget.
        assert_eq!(stats.budget_limit, Some(BudgetLimit::cents(100)));
        assert_eq!(stats.budget_remaining_cents, Some(100));
    }

    #[test]
    fn test_reset_unknown_workflow_is_noop() {
        let tracker = WorkflowTracker::new();

        tracker.reset("missing");

        assert!(tracker.get_all_stats().is_empty());
    }

    #[test]
    fn test_workflow_reset_all() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.record_usage("workflow-2", 200, 100, 20);

        tracker.reset_all();

        let stats1 = tracker.get_stats("workflow-1");
        let stats2 = tracker.get_stats("workflow-2");
        assert_eq!(stats1.total_requests, 0);
        assert_eq!(stats2.total_requests, 0);
    }

    #[test]
    fn test_workflow_remove() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);
        tracker.record_usage("workflow-2", 200, 100, 20);

        tracker.remove("workflow-1");

        let all_stats = tracker.get_all_stats();
        assert_eq!(all_stats.len(), 1);
        assert_eq!(all_stats[0].workflow_id, "workflow-2");
    }

    #[test]
    fn test_workflow_budget_exceeded() {
        let tracker = WorkflowTracker::new();

        tracker.set_budget("workflow-1", BudgetLimit::cents(50));
        tracker.record_usage("workflow-1", 100, 50, 60);

        let stats = tracker.get_stats("workflow-1");
        // Overspending clamps the remaining budget at zero instead of underflowing.
        assert_eq!(stats.budget_remaining_cents, Some(0));
        assert_close(
            stats.budget_remaining_percent.expect("budget was set"),
            0.0,
        );
        // Once over budget, even a zero-cost request is refused.
        assert!(!tracker.can_afford("workflow-1", 0));
    }

    #[test]
    fn test_workflow_stats_no_budget() {
        let tracker = WorkflowTracker::new();

        tracker.record_usage("workflow-1", 100, 50, 10);

        let stats = tracker.get_stats("workflow-1");
        assert_eq!(stats.budget_limit, None);
        assert_eq!(stats.budget_remaining_cents, None);
        assert_eq!(stats.budget_remaining_percent, None);
    }
}