Skip to main content

lean_ctx/core/a2a/
cost_attribution.rs

1use chrono::Utc;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Mutex;
6
7#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct CostStore {
9    pub agents: HashMap<String, AgentCost>,
10    pub tools: HashMap<String, ToolCost>,
11    pub sessions: Vec<SessionCostSnapshot>,
12    pub updated_at: Option<String>,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct AgentCost {
17    pub agent_id: String,
18    pub agent_type: String,
19    #[serde(default)]
20    pub model_key: Option<String>,
21    #[serde(default)]
22    pub pricing_match: Option<String>,
23    pub total_input_tokens: u64,
24    pub total_output_tokens: u64,
25    pub total_cached_tokens: u64,
26    pub total_calls: u64,
27    pub cost_usd: f64,
28    pub tools_used: HashMap<String, u64>,
29    pub first_seen: Option<String>,
30    pub last_seen: Option<String>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, Default)]
34pub struct ToolCost {
35    pub tool_name: String,
36    pub total_input_tokens: u64,
37    pub total_output_tokens: u64,
38    #[serde(default)]
39    pub total_cached_tokens: u64,
40    pub total_calls: u64,
41    pub avg_input_tokens: f64,
42    pub avg_output_tokens: f64,
43    #[serde(default)]
44    pub avg_cached_tokens: f64,
45    pub cost_usd: f64,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct SessionCostSnapshot {
50    pub timestamp: String,
51    pub agent_id: String,
52    #[serde(default)]
53    pub model_key: Option<String>,
54    pub total_input: u64,
55    pub total_output: u64,
56    #[serde(default)]
57    pub total_cached: u64,
58    pub total_saved: u64,
59    pub cost_usd: f64,
60    pub duration_secs: u64,
61}
62
63pub fn estimate_cost(model_key: Option<&str>, input: u64, output: u64, cached: u64) -> f64 {
64    let pricing = crate::core::gain::model_pricing::ModelPricing::load();
65    let quote = pricing.quote(model_key);
66    quote.cost.estimate_usd(input, output, 0, cached)
67}
68
/// In-memory copy of the cost store shared through a mutex.
/// Stays `None` until the first `CostStore::load` or `CostStore::save` populates it.
static COST_BUFFER: Mutex<Option<CostStore>> = Mutex::new(None);
70
71impl CostStore {
72    pub fn load() -> Self {
73        let mut guard = COST_BUFFER
74            .lock()
75            .unwrap_or_else(std::sync::PoisonError::into_inner);
76        if let Some(ref store) = *guard {
77            return store.clone();
78        }
79
80        let store = load_from_disk();
81        *guard = Some(store.clone());
82        store
83    }
84
85    pub fn record_tool_call(
86        &mut self,
87        agent_id: &str,
88        agent_type: &str,
89        tool_name: &str,
90        input_tokens: u64,
91        output_tokens: u64,
92        cached_tokens: u64,
93    ) {
94        let now = Utc::now().to_rfc3339();
95        let pricing = crate::core::gain::model_pricing::ModelPricing::load();
96        let quote = pricing.quote_from_env_or_agent_type(agent_type);
97        let cost = quote
98            .cost
99            .estimate_usd(input_tokens, output_tokens, 0, cached_tokens);
100
101        let agent = self
102            .agents
103            .entry(agent_id.to_string())
104            .or_insert_with(|| AgentCost {
105                agent_id: agent_id.to_string(),
106                agent_type: agent_type.to_string(),
107                first_seen: Some(now.clone()),
108                ..Default::default()
109            });
110        agent.total_input_tokens += input_tokens;
111        agent.total_output_tokens += output_tokens;
112        agent.total_cached_tokens += cached_tokens;
113        agent.total_calls += 1;
114        agent.cost_usd += cost;
115        agent.last_seen = Some(now.clone());
116        agent.model_key = Some(quote.model_key.clone());
117        agent.pricing_match = Some(format!("{:?}", quote.match_kind));
118        *agent.tools_used.entry(tool_name.to_string()).or_insert(0) += 1;
119
120        let tool = self
121            .tools
122            .entry(tool_name.to_string())
123            .or_insert_with(|| ToolCost {
124                tool_name: tool_name.to_string(),
125                ..Default::default()
126            });
127        tool.total_input_tokens += input_tokens;
128        tool.total_output_tokens += output_tokens;
129        tool.total_cached_tokens += cached_tokens;
130        tool.total_calls += 1;
131        tool.cost_usd += cost;
132        if tool.total_calls > 0 {
133            tool.avg_input_tokens = tool.total_input_tokens as f64 / tool.total_calls as f64;
134            tool.avg_output_tokens = tool.total_output_tokens as f64 / tool.total_calls as f64;
135            tool.avg_cached_tokens = tool.total_cached_tokens as f64 / tool.total_calls as f64;
136        }
137
138        self.updated_at = Some(now);
139    }
140
141    pub fn save(&self) -> std::io::Result<()> {
142        save_to_disk(self)?;
143        let mut guard = COST_BUFFER
144            .lock()
145            .unwrap_or_else(std::sync::PoisonError::into_inner);
146        *guard = Some(self.clone());
147        Ok(())
148    }
149
150    pub fn top_agents(&self, limit: usize) -> Vec<&AgentCost> {
151        let mut agents: Vec<_> = self.agents.values().collect();
152        agents.sort_by(|a, b| {
153            b.cost_usd
154                .partial_cmp(&a.cost_usd)
155                .unwrap_or(std::cmp::Ordering::Equal)
156        });
157        agents.truncate(limit);
158        agents
159    }
160
161    pub fn top_tools(&self, limit: usize) -> Vec<&ToolCost> {
162        let mut tools: Vec<_> = self.tools.values().collect();
163        tools.sort_by(|a, b| {
164            b.cost_usd
165                .partial_cmp(&a.cost_usd)
166                .unwrap_or(std::cmp::Ordering::Equal)
167        });
168        tools.truncate(limit);
169        tools
170    }
171
172    pub fn total_cost(&self) -> f64 {
173        self.agents.values().map(|a| a.cost_usd).sum()
174    }
175
176    pub fn total_tokens(&self) -> (u64, u64, u64) {
177        let input: u64 = self.agents.values().map(|a| a.total_input_tokens).sum();
178        let output: u64 = self.agents.values().map(|a| a.total_output_tokens).sum();
179        let cached: u64 = self.agents.values().map(|a| a.total_cached_tokens).sum();
180        (input, output, cached)
181    }
182
183    pub fn add_session_snapshot(
184        &mut self,
185        agent_id: &str,
186        input: u64,
187        output: u64,
188        saved: u64,
189        duration_secs: u64,
190    ) {
191        let model_key = self
192            .agents
193            .get(agent_id)
194            .and_then(|a| a.model_key.as_deref())
195            .map(std::string::ToString::to_string);
196        let cost = estimate_cost(model_key.as_deref(), input, output, 0);
197        self.sessions.push(SessionCostSnapshot {
198            timestamp: Utc::now().to_rfc3339(),
199            agent_id: agent_id.to_string(),
200            model_key,
201            total_input: input,
202            total_output: output,
203            total_cached: 0,
204            total_saved: saved,
205            cost_usd: cost,
206            duration_secs,
207        });
208
209        if self.sessions.len() > 500 {
210            self.sessions.drain(0..self.sessions.len() - 500);
211        }
212    }
213}
214
215fn cost_store_path() -> Option<PathBuf> {
216    crate::core::data_dir::lean_ctx_data_dir()
217        .ok()
218        .map(|d| d.join("cost_attribution.json"))
219}
220
221fn load_from_disk() -> CostStore {
222    let Some(path) = cost_store_path() else {
223        return CostStore::default();
224    };
225    match std::fs::read_to_string(&path) {
226        Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
227        Err(_) => CostStore::default(),
228    }
229}
230
231fn save_to_disk(store: &CostStore) -> std::io::Result<()> {
232    let Some(path) = cost_store_path() else {
233        return Err(std::io::Error::new(
234            std::io::ErrorKind::NotFound,
235            "no home dir",
236        ));
237    };
238
239    if let Some(parent) = path.parent() {
240        std::fs::create_dir_all(parent)?;
241    }
242
243    let json = serde_json::to_string(store).map_err(std::io::Error::other)?;
244    let tmp = path.with_extension("tmp");
245    std::fs::write(&tmp, &json)?;
246    std::fs::rename(&tmp, &path)?;
247    Ok(())
248}
249
250pub fn format_cost_report(store: &CostStore, limit: usize) -> String {
251    let mut lines = Vec::new();
252    let (total_in, total_out, total_cached) = store.total_tokens();
253    let total_cost = store.total_cost();
254
255    lines.push(format!(
256        "Cost Attribution Report ({} agents, {} tools)",
257        store.agents.len(),
258        store.tools.len()
259    ));
260    lines.push(format!(
261        "Total: {total_in} input + {total_out} output + {total_cached} cached tokens = ${total_cost:.4}"
262    ));
263    if let Ok(m) = std::env::var("LEAN_CTX_MODEL").or_else(|_| std::env::var("LCTX_MODEL")) {
264        if !m.trim().is_empty() {
265            let pricing = crate::core::gain::model_pricing::ModelPricing::load();
266            let q = pricing.quote(Some(&m));
267            lines.push(format!(
268                "Pricing: model={} ({:?}) in=${:.2}/M out=${:.2}/M cacheR=${:.3}/M",
269                q.model_key,
270                q.match_kind,
271                q.cost.input_per_m,
272                q.cost.output_per_m,
273                q.cost.cache_read_per_m
274            ));
275        }
276    }
277    lines.push(String::new());
278
279    let top_agents = store.top_agents(limit);
280    if !top_agents.is_empty() {
281        lines.push("Top Agents by Cost:".to_string());
282        for (i, agent) in top_agents.iter().enumerate() {
283            lines.push(format!(
284                "  {}. {} ({}) — {} calls, {} in + {} out + {} cached tok, ${:.4}{}",
285                i + 1,
286                agent.agent_id,
287                agent.agent_type,
288                agent.total_calls,
289                agent.total_input_tokens,
290                agent.total_output_tokens,
291                agent.total_cached_tokens,
292                agent.cost_usd,
293                agent
294                    .model_key
295                    .as_deref()
296                    .map(|m| format!(" [{m}]"))
297                    .unwrap_or_default()
298            ));
299        }
300        lines.push(String::new());
301    }
302
303    let top_tools = store.top_tools(limit);
304    if !top_tools.is_empty() {
305        lines.push("Top Tools by Cost:".to_string());
306        for (i, tool) in top_tools.iter().enumerate() {
307            lines.push(format!(
308                "  {}. {} — {} calls, avg {:.0} in + {:.0} out + {:.0} cached tok, ${:.4}",
309                i + 1,
310                tool.tool_name,
311                tool.total_calls,
312                tool.avg_input_tokens,
313                tool.avg_output_tokens,
314                tool.avg_cached_tokens,
315                tool.cost_usd
316            ));
317        }
318    }
319
320    lines.join("\n")
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Pricing a non-empty token mix against the fallback key yields a positive cost.
    #[test]
    fn cost_estimation() {
        let usd = estimate_cost(Some("fallback-blended"), 1000, 100, 500);
        assert!(usd > 0.0);
    }

    /// Recorded calls are aggregated per agent and per tool, and the ranking
    /// puts the heavier agent first.
    #[test]
    fn record_and_query() {
        let mut ledger = CostStore::default();
        ledger.record_tool_call("agent-1", "mcp", "ctx_read", 5000, 200, 0);
        ledger.record_tool_call("agent-1", "mcp", "ctx_read", 3000, 150, 500);
        ledger.record_tool_call("agent-2", "cursor", "ctx_shell", 1000, 100, 0);

        assert_eq!(ledger.agents.len(), 2);
        assert_eq!(ledger.tools.len(), 2);

        let first = ledger
            .agents
            .get("agent-1")
            .expect("agent-1 was recorded above");
        assert_eq!(first.total_calls, 2);
        assert_eq!(first.total_input_tokens, 8000);
        assert_eq!(first.total_cached_tokens, 500);
        assert_eq!(first.tools_used["ctx_read"], 2);

        let ranked = ledger.top_agents(5);
        assert_eq!(ranked[0].agent_id, "agent-1");
    }

    /// The rendered report mentions its title plus the recorded agent and tool.
    #[test]
    fn format_report() {
        let mut ledger = CostStore::default();
        ledger.record_tool_call("agent-a", "mcp", "ctx_read", 10000, 500, 1000);
        ledger.record_tool_call("agent-b", "cursor", "ctx_shell", 2000, 100, 0);

        let rendered = format_cost_report(&ledger, 5);
        for needle in ["Cost Attribution Report", "agent-a", "ctx_read"] {
            assert!(rendered.contains(needle));
        }
    }

    /// Adding a snapshot stores exactly one entry with a positive cost.
    #[test]
    fn session_snapshots() {
        let mut ledger = CostStore::default();
        ledger.add_session_snapshot("agent-a", 50000, 5000, 30000, 120);
        assert_eq!(ledger.sessions.len(), 1);
        assert!(ledger.sessions[0].cost_usd > 0.0);
    }
}