// oxify_model/cost.rs

1//! Cost estimation for workflow execution
2//!
3//! This module provides cost estimation for LLM calls, vector operations,
4//! and overall workflow execution costs.
5
6use crate::{LlmConfig, Node, NodeKind, VectorConfig, Workflow};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Cost breakdown for a workflow execution
///
/// Produced by `CostEstimator::estimate`. All monetary figures are
/// pre-execution *estimates* in USD, not measured spend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostEstimate {
    /// Total estimated cost in USD (sum of all category totals)
    pub total_usd: f64,

    /// Cost breakdown by node, keyed by the node's stringified id
    pub node_costs: HashMap<String, NodeCost>,

    /// Cost breakdown by category
    pub category_costs: CategoryCosts,

    /// Token usage estimates
    pub token_estimates: TokenEstimates,
}
25
/// Cost for a single node
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCost {
    /// Node name
    pub node_name: String,

    /// Node type (human-readable `NodeKind` variant name, e.g. "LLM")
    pub node_type: String,

    /// Estimated cost in USD, already multiplied by `expected_executions`
    pub cost_usd: f64,

    /// Number of expected executions (for loops, retries, etc.)
    pub expected_executions: u32,

    /// Breakdown of cost components; each component's cost covers a
    /// single execution (NOT scaled by `expected_executions`)
    pub components: Vec<CostComponent>,
}
44
/// Individual cost component of a node's estimate
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostComponent {
    /// Component name (e.g., "input_tokens", "output_tokens", "api_call")
    pub name: String,

    /// Cost in USD for a single execution of the owning node
    pub cost_usd: f64,

    /// Quantity (tokens, API calls, etc.) expressed in `unit`
    pub quantity: f64,

    /// Unit (e.g., "tokens", "calls", "MB")
    pub unit: String,
}
60
/// Cost breakdown by category (all values in USD)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryCosts {
    /// Total LLM costs
    pub llm_total: f64,

    /// Total vector database costs
    pub vector_total: f64,

    /// Total code execution costs
    pub code_total: f64,

    /// Total tool/MCP costs
    pub tool_total: f64,

    /// Other costs (Start, End, control-flow nodes, etc.)
    pub other_total: f64,
}
79
/// Token usage estimates
///
/// Totals count each LLM node once; they are NOT scaled by expected
/// executions, so retry attempts add to cost but not to these counts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenEstimates {
    /// Total input tokens across all LLM nodes
    pub total_input_tokens: u64,

    /// Total output tokens across all LLM nodes
    pub total_output_tokens: u64,

    /// Total tokens (input + output)
    pub total_tokens: u64,
}
92
/// Pricing information for LLM models
#[derive(Debug, Clone)]
pub struct ModelPricing {
    /// Cost per 1M input tokens in USD
    pub input_cost_per_million: f64,

    /// Cost per 1M output tokens in USD
    pub output_cost_per_million: f64,
}

impl ModelPricing {
    /// Look up per-million-token pricing for a `(provider, model)` pair.
    ///
    /// Matching is case-insensitive and substring based, so versioned
    /// model ids (e.g. "gpt-4-turbo-preview") resolve to their family.
    /// Unknown combinations fall back to a conservative default rate.
    pub fn for_model(provider: &str, model: &str) -> Self {
        let rate = |input_cost_per_million: f64, output_cost_per_million: f64| Self {
            input_cost_per_million,
            output_cost_per_million,
        };

        let provider = provider.to_lowercase();
        let model = model.to_lowercase();

        match provider.as_str() {
            "openai" => {
                // Most specific family first: "gpt-4-turbo" ids also
                // contain "gpt-4".
                if model.contains("gpt-4-turbo") {
                    rate(10.0, 30.0)
                } else if model.contains("gpt-4") {
                    rate(30.0, 60.0)
                } else if model.contains("gpt-3.5-turbo") {
                    rate(0.5, 1.5)
                } else {
                    // Unrecognized OpenAI model: conservative default.
                    rate(5.0, 15.0)
                }
            }
            "anthropic" => {
                if model.contains("claude-3-opus") {
                    rate(15.0, 75.0)
                } else if model.contains("claude-3-sonnet") {
                    rate(3.0, 15.0)
                } else if model.contains("claude-3-haiku") {
                    rate(0.25, 1.25)
                } else {
                    // Unrecognized Anthropic model: conservative default.
                    rate(5.0, 15.0)
                }
            }
            // Local/Ollama models carry no per-token API charge.
            "ollama" | "local" => rate(0.0, 0.0),
            // Unknown provider: conservative default estimate.
            _ => rate(5.0, 15.0),
        }
    }

    /// Calculate cost in USD for the given input/output token counts.
    pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input_cost_per_million;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output_cost_per_million;
        input_cost + output_cost
    }
}
157
/// Cost estimator for workflows
///
/// Stateless namespace type: all functionality lives in associated
/// functions, so no instance is ever constructed.
pub struct CostEstimator;

impl CostEstimator {
    /// Estimate cost for a workflow
    ///
    /// Walks `workflow.nodes` once, accumulating per-node estimates into
    /// category totals and token totals. Entries in the returned
    /// `node_costs` map are keyed by the stringified node id.
    pub fn estimate(workflow: &Workflow) -> CostEstimate {
        let mut node_costs = HashMap::new();
        let mut llm_total = 0.0;
        let mut vector_total = 0.0;
        let mut code_total = 0.0;
        let mut tool_total = 0.0;
        let mut other_total = 0.0;
        let mut total_input_tokens = 0u64;
        let mut total_output_tokens = 0u64;

        for node in &workflow.nodes {
            let node_cost = Self::estimate_node_cost(node);

            // Update category totals
            match &node.kind {
                NodeKind::LLM(_) => llm_total += node_cost.cost_usd,
                NodeKind::Retriever(_) => vector_total += node_cost.cost_usd,
                NodeKind::Code(_) => code_total += node_cost.cost_usd,
                NodeKind::Tool(_) => tool_total += node_cost.cost_usd,
                _ => other_total += node_cost.cost_usd,
            }

            // Update token estimates.
            // NOTE(review): component quantities cover a single execution,
            // while node cost_usd is scaled by expected_executions — token
            // totals therefore exclude retry attempts even though costs
            // include them. Confirm this asymmetry is intended.
            for component in &node_cost.components {
                match component.name.as_str() {
                    "input_tokens" => total_input_tokens += component.quantity as u64,
                    "output_tokens" => total_output_tokens += component.quantity as u64,
                    _ => {}
                }
            }

            node_costs.insert(node.id.to_string(), node_cost);
        }

        let total_usd = llm_total + vector_total + code_total + tool_total + other_total;

        CostEstimate {
            total_usd,
            node_costs,
            category_costs: CategoryCosts {
                llm_total,
                vector_total,
                code_total,
                tool_total,
                other_total,
            },
            token_estimates: TokenEstimates {
                total_input_tokens,
                total_output_tokens,
                total_tokens: total_input_tokens + total_output_tokens,
            },
        }
    }

    /// Estimate cost for a single node
    ///
    /// Each pushed `CostComponent` records the cost of ONE execution; the
    /// node-level `cost_usd` multiplies the per-execution sum by
    /// `expected_executions`.
    fn estimate_node_cost(node: &Node) -> NodeCost {
        let mut components = Vec::new();
        let expected_executions = Self::estimate_executions(node);

        let cost_usd = match &node.kind {
            NodeKind::LLM(config) => {
                let (input_tokens, output_tokens) = Self::estimate_llm_tokens(config);
                let pricing = ModelPricing::for_model(&config.provider, &config.model);

                // Price input and output separately so each side can be
                // reported as its own component.
                let input_cost = pricing.calculate_cost(input_tokens, 0);
                let output_cost = pricing.calculate_cost(0, output_tokens);

                components.push(CostComponent {
                    name: "input_tokens".to_string(),
                    cost_usd: input_cost,
                    quantity: input_tokens as f64,
                    unit: "tokens".to_string(),
                });

                components.push(CostComponent {
                    name: "output_tokens".to_string(),
                    cost_usd: output_cost,
                    quantity: output_tokens as f64,
                    unit: "tokens".to_string(),
                });

                (input_cost + output_cost) * expected_executions as f64
            }
            NodeKind::Retriever(config) => {
                let vector_cost = Self::estimate_vector_cost(config);
                components.push(CostComponent {
                    name: "vector_search".to_string(),
                    cost_usd: vector_cost,
                    quantity: config.top_k as f64,
                    unit: "results".to_string(),
                });
                vector_cost * expected_executions as f64
            }
            NodeKind::Code(_) => {
                // Estimate compute cost (very rough estimate)
                let compute_cost = 0.0001; // $0.0001 per execution
                components.push(CostComponent {
                    name: "compute".to_string(),
                    cost_usd: compute_cost,
                    quantity: 1.0,
                    unit: "execution".to_string(),
                });
                compute_cost * expected_executions as f64
            }
            NodeKind::Tool(_) => {
                // Estimate tool/API call cost
                let api_cost = 0.001; // $0.001 per call
                components.push(CostComponent {
                    name: "api_call".to_string(),
                    cost_usd: api_cost,
                    quantity: 1.0,
                    unit: "call".to_string(),
                });
                api_cost * expected_executions as f64
            }
            _ => {
                // Start, End, IfElse, etc. have no cost
                0.0
            }
        };

        NodeCost {
            node_name: node.name.clone(),
            // Human-readable variant name for reporting. Kept exhaustive
            // (no `_` arm) so adding a NodeKind variant forces an update.
            node_type: match &node.kind {
                NodeKind::Start => "Start".to_string(),
                NodeKind::End => "End".to_string(),
                NodeKind::LLM(_) => "LLM".to_string(),
                NodeKind::Retriever(_) => "Retriever".to_string(),
                NodeKind::Code(_) => "Code".to_string(),
                NodeKind::IfElse(_) => "IfElse".to_string(),
                NodeKind::Tool(_) => "Tool".to_string(),
                NodeKind::Loop(_) => "Loop".to_string(),
                NodeKind::TryCatch(_) => "TryCatch".to_string(),
                NodeKind::SubWorkflow(_) => "SubWorkflow".to_string(),
                NodeKind::Switch(_) => "Switch".to_string(),
                NodeKind::Parallel(_) => "Parallel".to_string(),
                NodeKind::Approval(_) => "Approval".to_string(),
                NodeKind::Form(_) => "Form".to_string(),
                NodeKind::Vision(_) => "Vision".to_string(),
            },
            cost_usd,
            expected_executions,
            components,
        }
    }

    /// Estimate number of executions for a node (considering retries, loops, etc.)
    ///
    /// NOTE(review): despite the doc mentioning loops, only retry_config is
    /// factored in here — loop iteration counts are not yet considered.
    fn estimate_executions(node: &Node) -> u32 {
        let mut executions = 1u32;

        // Account for retries
        if let Some(retry_config) = &node.retry_config {
            // Assume 30% failure rate requiring retries
            let avg_retries = (retry_config.max_retries as f32 * 0.3).ceil() as u32;
            executions += avg_retries;
        }

        executions
    }

    /// Estimate token usage for an LLM node
    ///
    /// Returns `(input_tokens, output_tokens)`. Output assumes the model
    /// consumes its full `max_tokens` budget (default 1000).
    fn estimate_llm_tokens(config: &LlmConfig) -> (u64, u64) {
        // Estimate input tokens based on prompt template length
        let system_prompt_tokens = config
            .system_prompt
            .as_ref()
            .map(|s| Self::estimate_token_count(s))
            .unwrap_or(0);

        let user_prompt_tokens = Self::estimate_token_count(&config.prompt_template);
        let input_tokens = system_prompt_tokens + user_prompt_tokens + 100; // +100 for context

        // Estimate output tokens
        let output_tokens = config.max_tokens.unwrap_or(1000) as u64;

        (input_tokens, output_tokens)
    }

    /// Rough estimate of token count from text (1 token ≈ 4 characters)
    fn estimate_token_count(text: &str) -> u64 {
        (text.len() as f64 / 4.0).ceil() as u64
    }

    /// Estimate cost for vector database operations (per single query)
    fn estimate_vector_cost(config: &VectorConfig) -> f64 {
        match config.db_type.to_lowercase().as_str() {
            "qdrant" => {
                // Qdrant cloud pricing: ~$0.0001 per 1000 searches
                // NOTE(review): this formula scales with top_k (result
                // count), so one query with top_k=5 is priced like five
                // searches — confirm this is intended.
                (config.top_k as f64 / 1000.0) * 0.0001
            }
            "pgvector" => {
                // PostgreSQL compute cost estimate
                0.00001 // $0.00001 per query
            }
            _ => 0.00001, // Default estimate
        }
    }
}
361
362impl CostEstimate {
363    /// Format cost estimate as a human-readable string
364    pub fn format_summary(&self) -> String {
365        format!(
366            "Total Cost: ${:.4}\n\
367             LLM: ${:.4} | Vector: ${:.4} | Code: ${:.4} | Tools: ${:.4}\n\
368             Tokens: {} input, {} output ({} total)",
369            self.total_usd,
370            self.category_costs.llm_total,
371            self.category_costs.vector_total,
372            self.category_costs.code_total,
373            self.category_costs.tool_total,
374            self.token_estimates.total_input_tokens,
375            self.token_estimates.total_output_tokens,
376            self.token_estimates.total_tokens
377        )
378    }
379
380    /// Get the most expensive nodes
381    pub fn top_expensive_nodes(&self, limit: usize) -> Vec<&NodeCost> {
382        let mut costs: Vec<&NodeCost> = self.node_costs.values().collect();
383        costs.sort_by(|a, b| b.cost_usd.partial_cmp(&a.cost_usd).unwrap());
384        costs.into_iter().take(limit).collect()
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WorkflowBuilder;

    // Pricing table lookups: one test per provider family.
    #[test]
    fn test_model_pricing_openai() {
        let pricing = ModelPricing::for_model("openai", "gpt-4");
        assert_eq!(pricing.input_cost_per_million, 30.0);
        assert_eq!(pricing.output_cost_per_million, 60.0);
    }

    #[test]
    fn test_model_pricing_anthropic() {
        let pricing = ModelPricing::for_model("anthropic", "claude-3-opus");
        assert_eq!(pricing.input_cost_per_million, 15.0);
        assert_eq!(pricing.output_cost_per_million, 75.0);
    }

    // Local/self-hosted models must estimate to zero API cost.
    #[test]
    fn test_model_pricing_local() {
        let pricing = ModelPricing::for_model("ollama", "llama2");
        assert_eq!(pricing.input_cost_per_million, 0.0);
        assert_eq!(pricing.output_cost_per_million, 0.0);
    }

    #[test]
    fn test_calculate_cost() {
        let pricing = ModelPricing::for_model("openai", "gpt-3.5-turbo");
        let cost = pricing.calculate_cost(1000, 500);

        // Expected: (1000/1M * 0.5) + (500/1M * 1.5) = 0.0005 + 0.00075 = 0.00125
        assert!((cost - 0.00125).abs() < 0.0001);
    }

    // Token heuristic: 4 characters per token, rounded up.
    #[test]
    fn test_estimate_token_count() {
        let text = "Hello, world!"; // 13 characters
        let tokens = CostEstimator::estimate_token_count(text);
        assert_eq!(tokens, 4); // 13 / 4 = 3.25, ceil = 4
    }

    // End-to-end: a single LLM node should yield a positive LLM total
    // and positive token counts, with no vector cost.
    #[test]
    fn test_estimate_simple_workflow() {
        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .llm(
                "Generate",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-3.5-turbo".to_string(),
                    system_prompt: Some("You are a helpful assistant".to_string()),
                    prompt_template: "Say hello".to_string(),
                    temperature: Some(0.7),
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let estimate = CostEstimator::estimate(&workflow);

        assert!(estimate.total_usd > 0.0);
        assert!(estimate.category_costs.llm_total > 0.0);
        assert_eq!(estimate.category_costs.vector_total, 0.0);
        assert!(estimate.token_estimates.total_tokens > 0);
    }

    // Retriever-only workflow: vector cost positive, LLM cost zero.
    #[test]
    fn test_estimate_with_vector() {
        let workflow = WorkflowBuilder::new("RAG")
            .start("Start")
            .retriever(
                "Search",
                VectorConfig {
                    db_type: "qdrant".to_string(),
                    collection: "docs".to_string(),
                    query: "test query".to_string(),
                    top_k: 5,
                    score_threshold: Some(0.7),
                },
            )
            .end("End")
            .build();

        let estimate = CostEstimator::estimate(&workflow);

        assert!(estimate.category_costs.vector_total > 0.0);
        assert_eq!(estimate.category_costs.llm_total, 0.0);
    }

    // Smoke-test the summary formatting (labels only, not exact numbers).
    #[test]
    fn test_cost_estimate_summary() {
        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .llm(
                "LLM",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-4".to_string(),
                    system_prompt: None,
                    prompt_template: "test".to_string(),
                    temperature: None,
                    max_tokens: Some(500),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let estimate = CostEstimator::estimate(&workflow);
        let summary = estimate.format_summary();

        assert!(summary.contains("Total Cost:"));
        assert!(summary.contains("Tokens:"));
    }

    // The pricier model (gpt-4, larger max_tokens) must rank first.
    #[test]
    fn test_top_expensive_nodes() {
        let workflow = WorkflowBuilder::new("Multi-LLM")
            .start("Start")
            .llm(
                "GPT4",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-4".to_string(),
                    system_prompt: None,
                    prompt_template: "expensive call".to_string(),
                    temperature: None,
                    max_tokens: Some(2000),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .llm(
                "GPT3.5",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-3.5-turbo".to_string(),
                    system_prompt: None,
                    prompt_template: "cheap call".to_string(),
                    temperature: None,
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let estimate = CostEstimator::estimate(&workflow);
        let top = estimate.top_expensive_nodes(1);

        assert_eq!(top.len(), 1);
        assert_eq!(top[0].node_name, "GPT4");
    }

    // max_retries=3 at the assumed 30% failure rate adds ceil(0.9)=1
    // extra execution, so expected_executions must exceed 1.
    #[test]
    fn test_estimate_with_retry() {
        let llm_config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: None,
            max_tokens: Some(100),
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let node_with_retry = Node::new("LLM".to_string(), NodeKind::LLM(llm_config)).with_retry(
            crate::RetryConfig {
                max_retries: 3,
                initial_delay_ms: 1000,
                backoff_multiplier: 2.0,
                max_delay_ms: 30000,
            },
        );

        let cost = CostEstimator::estimate_node_cost(&node_with_retry);

        // Should have higher expected executions due to retries
        assert!(cost.expected_executions > 1);
    }
}