Skip to main content

agentforge_core/
cost.rs

1use serde::{Deserialize, Serialize};
2
3/// Token usage for a single trace or model call.
4#[derive(Debug, Clone, Default, Serialize, Deserialize)]
5pub struct TokenUsage {
6    pub input_tokens: u32,
7    pub output_tokens: u32,
8    pub total_tokens: u32,
9}
10
11impl TokenUsage {
12    pub fn new(input: u32, output: u32) -> Self {
13        Self {
14            input_tokens: input,
15            output_tokens: output,
16            total_tokens: input + output,
17        }
18    }
19}
20
21/// USD cost breakdown for a trace.
22#[derive(Debug, Clone, Default, Serialize, Deserialize)]
23pub struct CostBreakdown {
24    /// Estimated total cost in USD.
25    pub total_usd: f64,
26    /// Cost attributed to input tokens.
27    pub input_usd: f64,
28    /// Cost attributed to output tokens.
29    pub output_usd: f64,
30    /// Model that generated this cost.
31    pub model: String,
32    /// Provider (openai, anthropic, etc.)
33    pub provider: String,
34}
35
36/// A cost optimization recommendation: downgrade a model for a scenario subset.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct CostRecommendation {
39    /// Current model in use.
40    pub current_model: String,
41    /// Recommended cheaper model.
42    pub recommended_model: String,
43    /// Estimated savings per 1000 scenarios (USD).
44    pub estimated_savings_usd: f64,
45    /// Fraction of scenarios where the cheaper model scored equivalently.
46    pub equivalent_score_fraction: f64,
47    /// Aggregate score of recommended model on the test scenarios.
48    pub candidate_aggregate_score: f64,
49    /// Aggregate score of current model on the same scenarios.
50    pub current_aggregate_score: f64,
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f64` fields.
    const EPS: f64 = 1e-9;

    /// True when `actual` lies strictly within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Checks all three counters of `usage`, in input/output/total order.
    #[track_caller]
    fn assert_usage(usage: &TokenUsage, input: u32, output: u32, total: u32) {
        assert_eq!(usage.input_tokens, input);
        assert_eq!(usage.output_tokens, output);
        assert_eq!(usage.total_tokens, total);
    }

    /// Serializes `value` to a JSON string and parses it back.
    fn json_roundtrip<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let text = serde_json::to_string(value).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    /// Builds a `CostBreakdown` fixture from plain values.
    fn breakdown(
        total_usd: f64,
        input_usd: f64,
        output_usd: f64,
        model: &str,
        provider: &str,
    ) -> CostBreakdown {
        CostBreakdown {
            total_usd,
            input_usd,
            output_usd,
            model: String::from(model),
            provider: String::from(provider),
        }
    }

    /// Builds a gpt-4o -> gpt-4o-mini recommendation with the given numbers.
    fn mini_recommendation(
        savings: f64,
        fraction: f64,
        candidate: f64,
        current: f64,
    ) -> CostRecommendation {
        CostRecommendation {
            current_model: String::from("gpt-4o"),
            recommended_model: String::from("gpt-4o-mini"),
            estimated_savings_usd: savings,
            equivalent_score_fraction: fraction,
            candidate_aggregate_score: candidate,
            current_aggregate_score: current,
        }
    }

    // ── TokenUsage ───────────────────────────────────────────────────────────

    #[test]
    fn token_usage_total() {
        assert_eq!(TokenUsage::new(100, 50).total_tokens, 150);
    }

    #[test]
    fn token_usage_zero_inputs() {
        assert_usage(&TokenUsage::new(0, 0), 0, 0, 0);
    }

    #[test]
    fn token_usage_default_is_zero() {
        assert_usage(&TokenUsage::default(), 0, 0, 0);
    }

    #[test]
    fn token_usage_large_values() {
        let usage = TokenUsage::new(1_000_000, 500_000);
        assert_eq!(usage.total_tokens, 1_500_000);
    }

    #[test]
    fn token_usage_input_only() {
        let usage = TokenUsage::new(200, 0);
        assert_eq!(usage.total_tokens, 200);
        assert_eq!(usage.output_tokens, 0);
    }

    #[test]
    fn token_usage_output_only() {
        let usage = TokenUsage::new(0, 300);
        assert_eq!(usage.total_tokens, 300);
        assert_eq!(usage.input_tokens, 0);
    }

    #[test]
    fn token_usage_serde_roundtrip() {
        let restored: TokenUsage = json_roundtrip(&TokenUsage::new(128, 64));
        assert_usage(&restored, 128, 64, 192);
    }

    #[test]
    fn token_usage_new_sets_fields_correctly() {
        assert_usage(&TokenUsage::new(512, 256), 512, 256, 768);
    }

    // ── CostBreakdown ────────────────────────────────────────────────────────

    #[test]
    fn cost_breakdown_default_is_zero() {
        let empty = CostBreakdown::default();
        assert_eq!(empty.total_usd, 0.0);
        assert_eq!(empty.input_usd, 0.0);
        assert_eq!(empty.output_usd, 0.0);
        assert!(empty.model.is_empty());
        assert!(empty.provider.is_empty());
    }

    #[test]
    fn cost_breakdown_stores_fields() {
        let stored = breakdown(0.05, 0.02, 0.03, "gpt-4o", "openai");
        assert!(close(stored.total_usd, 0.05));
        assert_eq!(stored.model, "gpt-4o");
        assert_eq!(stored.provider, "openai");
    }

    #[test]
    fn cost_breakdown_serde_roundtrip() {
        let original = breakdown(0.123, 0.100, 0.023, "claude-3", "anthropic");
        let restored: CostBreakdown = json_roundtrip(&original);
        assert!(close(restored.total_usd, 0.123));
        assert_eq!(restored.provider, "anthropic");
    }

    // ── CostRecommendation ───────────────────────────────────────────────────

    #[test]
    fn cost_recommendation_stores_all_fields() {
        let stored = mini_recommendation(1.5, 0.95, 0.82, 0.84);
        assert_eq!(stored.current_model, "gpt-4o");
        assert_eq!(stored.recommended_model, "gpt-4o-mini");
        assert!(close(stored.equivalent_score_fraction, 0.95));
    }

    #[test]
    fn cost_recommendation_serde_roundtrip() {
        let original = mini_recommendation(2.0, 0.90, 0.80, 0.83);
        let restored: CostRecommendation = json_roundtrip(&original);
        assert!(close(restored.estimated_savings_usd, 2.0));
        assert_eq!(restored.recommended_model, "gpt-4o-mini");
    }
}