Skip to main content

lean_ctx/core/a2a/
cost_attribution.rs

1use chrono::Utc;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Mutex;
6
/// Aggregated token usage and estimated cost, broken down by agent and by tool.
///
/// The store is serialized as JSON to `cost_attribution.json` inside the
/// lean-ctx data directory and mirrored in an in-process buffer.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CostStore {
    /// Per-agent totals, keyed by agent id.
    pub agents: HashMap<String, AgentCost>,
    /// Per-tool totals, keyed by tool name.
    pub tools: HashMap<String, ToolCost>,
    /// Session snapshots in insertion order; `add_session_snapshot` keeps only
    /// the 500 most recent entries.
    pub sessions: Vec<SessionCostSnapshot>,
    /// RFC 3339 UTC timestamp of the most recent `record_tool_call`, if any.
    pub updated_at: Option<String>,
}
14
/// Running usage and cost totals for a single agent.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AgentCost {
    /// Identifier of the agent; also the key under which it is stored in `CostStore::agents`.
    pub agent_id: String,
    /// Agent type string supplied with the first recorded call for this agent.
    pub agent_type: String,
    /// Model key of the pricing quote used for the most recent recorded call.
    /// Defaults to `None` when absent from the serialized data.
    #[serde(default)]
    pub model_key: Option<String>,
    /// `Debug` rendering of the pricing quote's `match_kind` for the most recent call.
    /// Defaults to `None` when absent from the serialized data.
    #[serde(default)]
    pub pricing_match: Option<String>,
    /// Sum of input tokens across all recorded calls.
    pub total_input_tokens: u64,
    /// Sum of output tokens across all recorded calls.
    pub total_output_tokens: u64,
    /// NOTE(review): nothing visible in this file updates this counter — confirm whether
    /// another module writes it.
    pub total_cached_tokens: u64,
    /// Number of recorded tool calls.
    pub total_calls: u64,
    /// Accumulated estimated cost in USD.
    pub cost_usd: f64,
    /// Number of recorded calls per tool name.
    pub tools_used: HashMap<String, u64>,
    /// RFC 3339 UTC timestamp of the first recorded call.
    pub first_seen: Option<String>,
    /// RFC 3339 UTC timestamp of the most recent recorded call.
    pub last_seen: Option<String>,
}
32
/// Running usage and cost totals for a single tool, across all agents.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolCost {
    /// Name of the tool; also the key under which it is stored in `CostStore::tools`.
    pub tool_name: String,
    /// Sum of input tokens across all recorded calls.
    pub total_input_tokens: u64,
    /// Sum of output tokens across all recorded calls.
    pub total_output_tokens: u64,
    /// Number of recorded calls.
    pub total_calls: u64,
    /// `total_input_tokens / total_calls`, recomputed on every recorded call.
    pub avg_input_tokens: f64,
    /// `total_output_tokens / total_calls`, recomputed on every recorded call.
    pub avg_output_tokens: f64,
    /// Accumulated estimated cost in USD.
    pub cost_usd: f64,
}
43
/// Point-in-time summary of one session's token usage and estimated cost.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionCostSnapshot {
    /// RFC 3339 UTC timestamp taken when the snapshot was created.
    pub timestamp: String,
    /// Identifier of the agent the session belongs to.
    pub agent_id: String,
    /// Model key copied from the agent's entry at snapshot time, if the agent was known.
    /// Defaults to `None` when absent from the serialized data.
    #[serde(default)]
    pub model_key: Option<String>,
    /// Input token count reported for the session.
    pub total_input: u64,
    /// Output token count reported for the session.
    pub total_output: u64,
    /// Caller-supplied "saved" count — presumably tokens saved; confirm against callers.
    pub total_saved: u64,
    /// Estimated cost in USD, computed from `total_input` and `total_output`.
    pub cost_usd: f64,
    /// Session duration in seconds, as supplied by the caller.
    pub duration_secs: u64,
}
56
57pub fn estimate_cost(model_key: Option<&str>, input: u64, output: u64, cached: u64) -> f64 {
58    let pricing = crate::core::gain::model_pricing::ModelPricing::load();
59    let quote = pricing.quote(model_key);
60    quote.cost.estimate_usd(input, output, 0, cached)
61}
62
/// In-process copy of the cost store.
///
/// Starts as `None`; `CostStore::load` fills it from disk on first use and
/// `CostStore::save` replaces it after each successful write.
static COST_BUFFER: Mutex<Option<CostStore>> = Mutex::new(None);
64
65impl CostStore {
66    pub fn load() -> Self {
67        let mut guard = COST_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
68        if let Some(ref store) = *guard {
69            return store.clone();
70        }
71
72        let store = load_from_disk();
73        *guard = Some(store.clone());
74        store
75    }
76
77    pub fn record_tool_call(
78        &mut self,
79        agent_id: &str,
80        agent_type: &str,
81        tool_name: &str,
82        input_tokens: u64,
83        output_tokens: u64,
84    ) {
85        let now = Utc::now().to_rfc3339();
86        let pricing = crate::core::gain::model_pricing::ModelPricing::load();
87        let quote = pricing.quote_from_env_or_agent_type(agent_type);
88        let cost = quote.cost.estimate_usd(input_tokens, output_tokens, 0, 0);
89
90        let agent = self
91            .agents
92            .entry(agent_id.to_string())
93            .or_insert_with(|| AgentCost {
94                agent_id: agent_id.to_string(),
95                agent_type: agent_type.to_string(),
96                first_seen: Some(now.clone()),
97                ..Default::default()
98            });
99        agent.total_input_tokens += input_tokens;
100        agent.total_output_tokens += output_tokens;
101        agent.total_calls += 1;
102        agent.cost_usd += cost;
103        agent.last_seen = Some(now.clone());
104        agent.model_key = Some(quote.model_key.clone());
105        agent.pricing_match = Some(format!("{:?}", quote.match_kind));
106        *agent.tools_used.entry(tool_name.to_string()).or_insert(0) += 1;
107
108        let tool = self
109            .tools
110            .entry(tool_name.to_string())
111            .or_insert_with(|| ToolCost {
112                tool_name: tool_name.to_string(),
113                ..Default::default()
114            });
115        tool.total_input_tokens += input_tokens;
116        tool.total_output_tokens += output_tokens;
117        tool.total_calls += 1;
118        tool.cost_usd += cost;
119        if tool.total_calls > 0 {
120            tool.avg_input_tokens = tool.total_input_tokens as f64 / tool.total_calls as f64;
121            tool.avg_output_tokens = tool.total_output_tokens as f64 / tool.total_calls as f64;
122        }
123
124        self.updated_at = Some(now);
125    }
126
127    pub fn save(&self) -> std::io::Result<()> {
128        save_to_disk(self)?;
129        let mut guard = COST_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
130        *guard = Some(self.clone());
131        Ok(())
132    }
133
134    pub fn top_agents(&self, limit: usize) -> Vec<&AgentCost> {
135        let mut agents: Vec<_> = self.agents.values().collect();
136        agents.sort_by(|a, b| {
137            b.cost_usd
138                .partial_cmp(&a.cost_usd)
139                .unwrap_or(std::cmp::Ordering::Equal)
140        });
141        agents.truncate(limit);
142        agents
143    }
144
145    pub fn top_tools(&self, limit: usize) -> Vec<&ToolCost> {
146        let mut tools: Vec<_> = self.tools.values().collect();
147        tools.sort_by(|a, b| {
148            b.cost_usd
149                .partial_cmp(&a.cost_usd)
150                .unwrap_or(std::cmp::Ordering::Equal)
151        });
152        tools.truncate(limit);
153        tools
154    }
155
156    pub fn total_cost(&self) -> f64 {
157        self.agents.values().map(|a| a.cost_usd).sum()
158    }
159
160    pub fn total_tokens(&self) -> (u64, u64) {
161        let input: u64 = self.agents.values().map(|a| a.total_input_tokens).sum();
162        let output: u64 = self.agents.values().map(|a| a.total_output_tokens).sum();
163        (input, output)
164    }
165
166    pub fn add_session_snapshot(
167        &mut self,
168        agent_id: &str,
169        input: u64,
170        output: u64,
171        saved: u64,
172        duration_secs: u64,
173    ) {
174        let model_key = self
175            .agents
176            .get(agent_id)
177            .and_then(|a| a.model_key.as_deref())
178            .map(|s| s.to_string());
179        let cost = estimate_cost(model_key.as_deref(), input, output, 0);
180        self.sessions.push(SessionCostSnapshot {
181            timestamp: Utc::now().to_rfc3339(),
182            agent_id: agent_id.to_string(),
183            model_key,
184            total_input: input,
185            total_output: output,
186            total_saved: saved,
187            cost_usd: cost,
188            duration_secs,
189        });
190
191        if self.sessions.len() > 500 {
192            self.sessions.drain(0..self.sessions.len() - 500);
193        }
194    }
195}
196
197fn cost_store_path() -> Option<PathBuf> {
198    crate::core::data_dir::lean_ctx_data_dir()
199        .ok()
200        .map(|d| d.join("cost_attribution.json"))
201}
202
203fn load_from_disk() -> CostStore {
204    let path = match cost_store_path() {
205        Some(p) => p,
206        None => return CostStore::default(),
207    };
208    match std::fs::read_to_string(&path) {
209        Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
210        Err(_) => CostStore::default(),
211    }
212}
213
214fn save_to_disk(store: &CostStore) -> std::io::Result<()> {
215    let path = match cost_store_path() {
216        Some(p) => p,
217        None => {
218            return Err(std::io::Error::new(
219                std::io::ErrorKind::NotFound,
220                "no home dir",
221            ))
222        }
223    };
224
225    if let Some(parent) = path.parent() {
226        std::fs::create_dir_all(parent)?;
227    }
228
229    let json = serde_json::to_string(store).map_err(std::io::Error::other)?;
230    let tmp = path.with_extension("tmp");
231    std::fs::write(&tmp, &json)?;
232    std::fs::rename(&tmp, &path)?;
233    Ok(())
234}
235
236pub fn format_cost_report(store: &CostStore, limit: usize) -> String {
237    let mut lines = Vec::new();
238    let (total_in, total_out) = store.total_tokens();
239    let total_cost = store.total_cost();
240
241    lines.push(format!(
242        "Cost Attribution Report ({} agents, {} tools)",
243        store.agents.len(),
244        store.tools.len()
245    ));
246    lines.push(format!(
247        "Total: {total_in} input + {total_out} output tokens = ${total_cost:.4}"
248    ));
249    if let Ok(m) = std::env::var("LEAN_CTX_MODEL").or_else(|_| std::env::var("LCTX_MODEL")) {
250        if !m.trim().is_empty() {
251            let pricing = crate::core::gain::model_pricing::ModelPricing::load();
252            let q = pricing.quote(Some(&m));
253            lines.push(format!(
254                "Pricing: model={} ({:?}) in=${:.2}/M out=${:.2}/M cacheR=${:.3}/M",
255                q.model_key,
256                q.match_kind,
257                q.cost.input_per_m,
258                q.cost.output_per_m,
259                q.cost.cache_read_per_m
260            ));
261        }
262    }
263    lines.push(String::new());
264
265    let top_agents = store.top_agents(limit);
266    if !top_agents.is_empty() {
267        lines.push("Top Agents by Cost:".to_string());
268        for (i, agent) in top_agents.iter().enumerate() {
269            lines.push(format!(
270                "  {}. {} ({}) — {} calls, {} in + {} out tok, ${:.4}{}",
271                i + 1,
272                agent.agent_id,
273                agent.agent_type,
274                agent.total_calls,
275                agent.total_input_tokens,
276                agent.total_output_tokens,
277                agent.cost_usd,
278                agent
279                    .model_key
280                    .as_deref()
281                    .map(|m| format!(" [{m}]"))
282                    .unwrap_or_default()
283            ));
284        }
285        lines.push(String::new());
286    }
287
288    let top_tools = store.top_tools(limit);
289    if !top_tools.is_empty() {
290        lines.push("Top Tools by Cost:".to_string());
291        for (i, tool) in top_tools.iter().enumerate() {
292            lines.push(format!(
293                "  {}. {} — {} calls, avg {:.0} in + {:.0} out tok, ${:.4}",
294                i + 1,
295                tool.tool_name,
296                tool.total_calls,
297                tool.avg_input_tokens,
298                tool.avg_output_tokens,
299                tool.cost_usd
300            ));
301        }
302    }
303
304    lines.join("\n")
305}
306
// Unit tests for cost estimation, per-agent/per-tool aggregation, report
// formatting and session snapshots. All tests work on in-memory stores and
// never call `load` or `save`.
#[cfg(test)]
mod tests {
    use super::*;

    // A non-zero usage must produce a strictly positive estimate.
    // NOTE(review): assumes "fallback-blended" resolves to non-zero pricing — confirm
    // against the pricing table.
    #[test]
    fn cost_estimation() {
        let cost = estimate_cost(Some("fallback-blended"), 1000, 100, 500);
        assert!(cost > 0.0);
    }

    // Recording calls aggregates per agent and per tool, and the costliest agent ranks first.
    #[test]
    fn record_and_query() {
        let mut store = CostStore::default();
        store.record_tool_call("agent-1", "mcp", "ctx_read", 5000, 200);
        store.record_tool_call("agent-1", "mcp", "ctx_read", 3000, 150);
        store.record_tool_call("agent-2", "cursor", "ctx_shell", 1000, 100);

        assert_eq!(store.agents.len(), 2);
        assert_eq!(store.tools.len(), 2);

        // Both agent-1 calls are folded into a single entry.
        let agent1 = &store.agents["agent-1"];
        assert_eq!(agent1.total_calls, 2);
        assert_eq!(agent1.total_input_tokens, 8000);
        assert_eq!(*agent1.tools_used.get("ctx_read").unwrap(), 2);

        // agent-1 used far more tokens, so it is expected to be the most expensive.
        let top = store.top_agents(5);
        assert_eq!(top[0].agent_id, "agent-1");
    }

    // The rendered report contains the header plus the recorded agent and tool names.
    #[test]
    fn format_report() {
        let mut store = CostStore::default();
        store.record_tool_call("agent-a", "mcp", "ctx_read", 10000, 500);
        store.record_tool_call("agent-b", "cursor", "ctx_shell", 2000, 100);

        let report = format_cost_report(&store, 5);
        assert!(report.contains("Cost Attribution Report"));
        assert!(report.contains("agent-a"));
        assert!(report.contains("ctx_read"));
    }

    // A snapshot for an agent with no recorded calls is still stored and priced
    // (no model key is available, so `estimate_cost` receives `None`).
    #[test]
    fn session_snapshots() {
        let mut store = CostStore::default();
        store.add_session_snapshot("agent-a", 50000, 5000, 30000, 120);
        assert_eq!(store.sessions.len(), 1);
        assert!(store.sessions[0].cost_usd > 0.0);
    }
}