use crate::{LlmConfig, Node, NodeKind, VectorConfig, Workflow};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Full cost estimate for a workflow, broken down by node and category.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostEstimate {
    /// Total estimated cost in USD.
    pub total_usd: f64,

    /// Per-node cost breakdown, keyed by node ID.
    pub node_costs: HashMap<String, NodeCost>,

    /// Totals aggregated by node category.
    pub category_costs: CategoryCosts,

    /// Estimated token usage across all LLM nodes.
    pub token_estimates: TokenEstimates,
}

/// Estimated cost of a single node, including expected retries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCost {
    /// Human-readable node name.
    pub node_name: String,

    /// Node type label, e.g. "LLM" or "Retriever".
    pub node_type: String,

    /// Estimated cost in USD across all expected executions.
    pub cost_usd: f64,

    /// Expected executions: one run plus estimated retries.
    pub expected_executions: u32,

    /// Line items contributing to this node's cost.
    pub components: Vec<CostComponent>,
}

/// A single line item within a node's cost.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostComponent {
    /// Component name, e.g. "input_tokens" or "api_call".
    pub name: String,

    /// Cost of this component in USD.
    pub cost_usd: f64,

    /// Quantity consumed, measured in `unit`.
    pub quantity: f64,

    /// Unit of measurement, e.g. "tokens".
    pub unit: String,
}

/// Cost totals per node category, in USD.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryCosts {
    pub llm_total: f64,
    pub vector_total: f64,
    pub code_total: f64,
    pub tool_total: f64,
    pub other_total: f64,
}

/// Aggregate token estimates for a workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenEstimates {
    pub total_input_tokens: u64,
    pub total_output_tokens: u64,
    pub total_tokens: u64,
}

/// Pricing for a model, in USD per million tokens.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    pub input_cost_per_million: f64,
    pub output_cost_per_million: f64,
}

impl ModelPricing {
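    /// Looks up approximate pricing for a provider/model pair.
    ///
    /// The per-million-token figures below are hardcoded, point-in-time
    /// estimates. Unknown models fall back to a mid-range default, and
    /// local providers (ollama/local) are treated as free.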
    pub fn for_model(provider: &str, model: &str) -> Self {
        match (
            provider.to_lowercase().as_str(),
            model.to_lowercase().as_str(),
        ) {
            // "gpt-4-turbo" must match before the broader "gpt-4" arm.
            ("openai", m) if m.contains("gpt-4-turbo") => Self {
                input_cost_per_million: 10.0,
                output_cost_per_million: 30.0,
            },
            ("openai", m) if m.contains("gpt-4") => Self {
                input_cost_per_million: 30.0,
                output_cost_per_million: 60.0,
            },
            ("openai", m) if m.contains("gpt-3.5-turbo") => Self {
                input_cost_per_million: 0.5,
                output_cost_per_million: 1.5,
            },
            ("anthropic", m) if m.contains("claude-3-opus") => Self {
                input_cost_per_million: 15.0,
                output_cost_per_million: 75.0,
            },
            ("anthropic", m) if m.contains("claude-3-sonnet") => Self {
                input_cost_per_million: 3.0,
                output_cost_per_million: 15.0,
            },
            ("anthropic", m) if m.contains("claude-3-haiku") => Self {
                input_cost_per_million: 0.25,
                output_cost_per_million: 1.25,
            },
            // Local inference carries no per-token charge.
            ("ollama", _) | ("local", _) => Self {
                input_cost_per_million: 0.0,
                output_cost_per_million: 0.0,
            },
            // Unknown model: assume mid-range pricing.
            _ => Self {
                input_cost_per_million: 5.0,
                output_cost_per_million: 15.0,
            },
        }
    }

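    /// Computes the USD cost for the given input and output token counts.
    ///
    /// For example, at the gpt-3.5-turbo rates above ($0.50/M in, $1.50/M out),
    /// 1,000 input tokens plus 500 output tokens cost
    /// 0.001 * 0.5 + 0.0005 * 1.5 = $0.00125.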
    pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input_cost_per_million;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output_cost_per_million;
        input_cost + output_cost
    }
}

/// Estimates workflow execution costs without running the workflow.
pub struct CostEstimator;

impl CostEstimator {
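    /// Estimates the cost of one run of `workflow` by pricing each node
    /// and aggregating per-category totals and token counts.
    ///
    /// A minimal sketch using the builder API exercised in the tests below
    /// (`llm_config` is any `LlmConfig`):
    ///
    /// ```ignore
    /// let workflow = WorkflowBuilder::new("Demo")
    ///     .start("Start")
    ///     .llm("Generate", llm_config)
    ///     .end("End")
    ///     .build();
    /// let estimate = CostEstimator::estimate(&workflow);
    /// println!("{}", estimate.format_summary());
    /// ```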
    pub fn estimate(workflow: &Workflow) -> CostEstimate {
        let mut node_costs = HashMap::new();
        let mut llm_total = 0.0;
        let mut vector_total = 0.0;
        let mut code_total = 0.0;
        let mut tool_total = 0.0;
        let mut other_total = 0.0;
        let mut total_input_tokens = 0u64;
        let mut total_output_tokens = 0u64;

        for node in &workflow.nodes {
            let node_cost = Self::estimate_node_cost(node);

            match &node.kind {
                NodeKind::LLM(_) => llm_total += node_cost.cost_usd,
                NodeKind::Retriever(_) => vector_total += node_cost.cost_usd,
                NodeKind::Code(_) => code_total += node_cost.cost_usd,
                NodeKind::Tool(_) => tool_total += node_cost.cost_usd,
                _ => other_total += node_cost.cost_usd,
            }

            for component in &node_cost.components {
                match component.name.as_str() {
                    "input_tokens" => total_input_tokens += component.quantity as u64,
                    "output_tokens" => total_output_tokens += component.quantity as u64,
                    _ => {}
                }
            }

            node_costs.insert(node.id.to_string(), node_cost);
        }

        let total_usd = llm_total + vector_total + code_total + tool_total + other_total;

        CostEstimate {
            total_usd,
            node_costs,
            category_costs: CategoryCosts {
                llm_total,
                vector_total,
                code_total,
                tool_total,
                other_total,
            },
            token_estimates: TokenEstimates {
                total_input_tokens,
                total_output_tokens,
                total_tokens: total_input_tokens + total_output_tokens,
            },
        }
    }

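    /// Prices a single node by kind, recording per-component line items.
    /// LLM and retriever costs come from their configs; the flat Code and
    /// Tool figures below are rough placeholders, not measured costs.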
    fn estimate_node_cost(node: &Node) -> NodeCost {
        let mut components = Vec::new();
        let expected_executions = Self::estimate_executions(node);

        let cost_usd = match &node.kind {
            NodeKind::LLM(config) => {
                let (input_tokens, output_tokens) = Self::estimate_llm_tokens(config);
                let pricing = ModelPricing::for_model(&config.provider, &config.model);

                let input_cost = pricing.calculate_cost(input_tokens, 0);
                let output_cost = pricing.calculate_cost(0, output_tokens);

                components.push(CostComponent {
                    name: "input_tokens".to_string(),
                    cost_usd: input_cost,
                    quantity: input_tokens as f64,
                    unit: "tokens".to_string(),
                });

                components.push(CostComponent {
                    name: "output_tokens".to_string(),
                    cost_usd: output_cost,
                    quantity: output_tokens as f64,
                    unit: "tokens".to_string(),
                });

                (input_cost + output_cost) * expected_executions as f64
            }
            NodeKind::Retriever(config) => {
                let vector_cost = Self::estimate_vector_cost(config);
                components.push(CostComponent {
                    name: "vector_search".to_string(),
                    cost_usd: vector_cost,
                    quantity: config.top_k as f64,
                    unit: "results".to_string(),
                });
                vector_cost * expected_executions as f64
            }
            NodeKind::Code(_) => {
                // Rough flat estimate for one sandboxed code execution.
                let compute_cost = 0.0001;
                components.push(CostComponent {
                    name: "compute".to_string(),
                    cost_usd: compute_cost,
                    quantity: 1.0,
                    unit: "execution".to_string(),
                });
                compute_cost * expected_executions as f64
            }
            NodeKind::Tool(_) => {
                // Rough flat estimate for one external API/tool call.
                let api_cost = 0.001;
                components.push(CostComponent {
                    name: "api_call".to_string(),
                    cost_usd: api_cost,
                    quantity: 1.0,
                    unit: "call".to_string(),
                });
                api_cost * expected_executions as f64
            }
            _ => {
                // Control-flow and UI nodes carry no direct cost.
                0.0
            }
        };

        NodeCost {
            node_name: node.name.clone(),
            node_type: match &node.kind {
                NodeKind::Start => "Start".to_string(),
                NodeKind::End => "End".to_string(),
                NodeKind::LLM(_) => "LLM".to_string(),
                NodeKind::Retriever(_) => "Retriever".to_string(),
                NodeKind::Code(_) => "Code".to_string(),
                NodeKind::IfElse(_) => "IfElse".to_string(),
                NodeKind::Tool(_) => "Tool".to_string(),
                NodeKind::Loop(_) => "Loop".to_string(),
                NodeKind::TryCatch(_) => "TryCatch".to_string(),
                NodeKind::SubWorkflow(_) => "SubWorkflow".to_string(),
                NodeKind::Switch(_) => "Switch".to_string(),
                NodeKind::Parallel(_) => "Parallel".to_string(),
                NodeKind::Approval(_) => "Approval".to_string(),
                NodeKind::Form(_) => "Form".to_string(),
                NodeKind::Vision(_) => "Vision".to_string(),
            },
            cost_usd,
            expected_executions,
            components,
        }
    }

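    /// Expected number of executions, assuming roughly 30% of configured
    /// retries are actually consumed. For `max_retries = 3` this gives
    /// 1 + ceil(3 * 0.3) = 2 expected executions.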
    fn estimate_executions(node: &Node) -> u32 {
        let mut executions = 1u32;

        if let Some(retry_config) = &node.retry_config {
            let avg_retries = (retry_config.max_retries as f32 * 0.3).ceil() as u32;
            executions += avg_retries;
        }

        executions
    }

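    /// Estimates (input, output) tokens for an LLM node. A flat 100-token
    /// allowance covers interpolated variables and message framing; output
    /// defaults to 1000 tokens when `max_tokens` is unset.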
    fn estimate_llm_tokens(config: &LlmConfig) -> (u64, u64) {
        let system_prompt_tokens = config
            .system_prompt
            .as_ref()
            .map(|s| Self::estimate_token_count(s))
            .unwrap_or(0);

        let user_prompt_tokens = Self::estimate_token_count(&config.prompt_template);
        // Flat 100-token allowance for interpolated variables and framing.
        let input_tokens = system_prompt_tokens + user_prompt_tokens + 100;
        let output_tokens = config.max_tokens.unwrap_or(1000) as u64;

        (input_tokens, output_tokens)
    }

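    /// Rough heuristic: about 4 characters per token, rounded up.
    /// "Hello, world!" (13 characters) estimates to 4 tokens.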
    fn estimate_token_count(text: &str) -> u64 {
        (text.len() as f64 / 4.0).ceil() as u64
    }

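    /// Approximate per-query cost of a vector search: scaled by `top_k`
    /// for qdrant, a nominal flat compute cost for pgvector and any
    /// unrecognized backend.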
    fn estimate_vector_cost(config: &VectorConfig) -> f64 {
        match config.db_type.to_lowercase().as_str() {
            // Scale with the number of results fetched.
            "qdrant" => (config.top_k as f64 / 1000.0) * 0.0001,
            // Self-hosted or unknown backends: nominal flat compute cost.
            "pgvector" => 0.00001,
            _ => 0.00001,
        }
    }
}

impl CostEstimate {
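    /// Renders a short human-readable summary, e.g.:
    ///
    /// ```text
    /// Total Cost: $0.0450
    /// LLM: $0.0450 | Vector: $0.0000 | Code: $0.0000 | Tools: $0.0000
    /// Tokens: 500 input, 1000 output (1500 total)
    /// ```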
    pub fn format_summary(&self) -> String {
        format!(
            "Total Cost: ${:.4}\n\
             LLM: ${:.4} | Vector: ${:.4} | Code: ${:.4} | Tools: ${:.4}\n\
             Tokens: {} input, {} output ({} total)",
            self.total_usd,
            self.category_costs.llm_total,
            self.category_costs.vector_total,
            self.category_costs.code_total,
            self.category_costs.tool_total,
            self.token_estimates.total_input_tokens,
            self.token_estimates.total_output_tokens,
            self.token_estimates.total_tokens
        )
    }

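    /// Returns up to `limit` node costs, sorted most expensive first.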
    pub fn top_expensive_nodes(&self, limit: usize) -> Vec<&NodeCost> {
        let mut costs: Vec<&NodeCost> = self.node_costs.values().collect();
        // total_cmp gives a total order, avoiding the panic that
        // partial_cmp().unwrap() would hit if a cost were ever NaN.
        costs.sort_by(|a, b| b.cost_usd.total_cmp(&a.cost_usd));
        costs.into_iter().take(limit).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::WorkflowBuilder;

    #[test]
    fn test_model_pricing_openai() {
        let pricing = ModelPricing::for_model("openai", "gpt-4");
        assert_eq!(pricing.input_cost_per_million, 30.0);
        assert_eq!(pricing.output_cost_per_million, 60.0);
    }

    #[test]
    fn test_model_pricing_anthropic() {
        let pricing = ModelPricing::for_model("anthropic", "claude-3-opus");
        assert_eq!(pricing.input_cost_per_million, 15.0);
        assert_eq!(pricing.output_cost_per_million, 75.0);
    }

    #[test]
    fn test_model_pricing_local() {
        let pricing = ModelPricing::for_model("ollama", "llama2");
        assert_eq!(pricing.input_cost_per_million, 0.0);
        assert_eq!(pricing.output_cost_per_million, 0.0);
    }

    #[test]
    fn test_calculate_cost() {
        let pricing = ModelPricing::for_model("openai", "gpt-3.5-turbo");
        let cost = pricing.calculate_cost(1000, 500);

        // 1000/1M * $0.50 + 500/1M * $1.50 = $0.00125
        assert!((cost - 0.00125).abs() < 0.0001);
    }

    #[test]
    fn test_estimate_token_count() {
        let text = "Hello, world!";
        let tokens = CostEstimator::estimate_token_count(text);
        // 13 characters / 4 per token, rounded up = 4.
        assert_eq!(tokens, 4);
    }

    #[test]
    fn test_estimate_simple_workflow() {
        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .llm(
                "Generate",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-3.5-turbo".to_string(),
                    system_prompt: Some("You are a helpful assistant".to_string()),
                    prompt_template: "Say hello".to_string(),
                    temperature: Some(0.7),
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let estimate = CostEstimator::estimate(&workflow);

        assert!(estimate.total_usd > 0.0);
        assert!(estimate.category_costs.llm_total > 0.0);
        assert_eq!(estimate.category_costs.vector_total, 0.0);
        assert!(estimate.token_estimates.total_tokens > 0);
    }

    #[test]
    fn test_estimate_with_vector() {
        let workflow = WorkflowBuilder::new("RAG")
            .start("Start")
            .retriever(
                "Search",
                VectorConfig {
                    db_type: "qdrant".to_string(),
                    collection: "docs".to_string(),
                    query: "test query".to_string(),
                    top_k: 5,
                    score_threshold: Some(0.7),
                },
            )
            .end("End")
            .build();

        let estimate = CostEstimator::estimate(&workflow);

        assert!(estimate.category_costs.vector_total > 0.0);
        assert_eq!(estimate.category_costs.llm_total, 0.0);
    }

    #[test]
    fn test_cost_estimate_summary() {
        let workflow = WorkflowBuilder::new("Test")
            .start("Start")
            .llm(
                "LLM",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-4".to_string(),
                    system_prompt: None,
                    prompt_template: "test".to_string(),
                    temperature: None,
                    max_tokens: Some(500),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let estimate = CostEstimator::estimate(&workflow);
        let summary = estimate.format_summary();

        assert!(summary.contains("Total Cost:"));
        assert!(summary.contains("Tokens:"));
    }

    #[test]
    fn test_top_expensive_nodes() {
        let workflow = WorkflowBuilder::new("Multi-LLM")
            .start("Start")
            .llm(
                "GPT4",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-4".to_string(),
                    system_prompt: None,
                    prompt_template: "expensive call".to_string(),
                    temperature: None,
                    max_tokens: Some(2000),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .llm(
                "GPT3.5",
                LlmConfig {
                    provider: "openai".to_string(),
                    model: "gpt-3.5-turbo".to_string(),
                    system_prompt: None,
                    prompt_template: "cheap call".to_string(),
                    temperature: None,
                    max_tokens: Some(100),
                    tools: vec![],
                    images: vec![],
                    extra_params: serde_json::Value::Null,
                },
            )
            .end("End")
            .build();

        let estimate = CostEstimator::estimate(&workflow);
        let top = estimate.top_expensive_nodes(1);

        assert_eq!(top.len(), 1);
        assert_eq!(top[0].node_name, "GPT4");
    }

    #[test]
    fn test_estimate_with_retry() {
        let llm_config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: None,
            max_tokens: Some(100),
            tools: vec![],
            images: vec![],
            extra_params: serde_json::Value::Null,
        };

        let node_with_retry = Node::new("LLM".to_string(), NodeKind::LLM(llm_config)).with_retry(
            crate::RetryConfig {
                max_retries: 3,
                initial_delay_ms: 1000,
                backoff_multiplier: 2.0,
                max_delay_ms: 30000,
            },
        );

        let cost = CostEstimator::estimate_node_cost(&node_with_retry);

        // 1 base run + ceil(3 * 0.3) = 2 expected executions.
        assert!(cost.expected_executions > 1);
    }
}