1use chrono::Utc;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Mutex;
6
7#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct CostStore {
9 pub agents: HashMap<String, AgentCost>,
10 pub tools: HashMap<String, ToolCost>,
11 pub sessions: Vec<SessionCostSnapshot>,
12 pub updated_at: Option<String>,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct AgentCost {
17 pub agent_id: String,
18 pub agent_type: String,
19 #[serde(default)]
20 pub model_key: Option<String>,
21 #[serde(default)]
22 pub pricing_match: Option<String>,
23 pub total_input_tokens: u64,
24 pub total_output_tokens: u64,
25 pub total_cached_tokens: u64,
26 pub total_calls: u64,
27 pub cost_usd: f64,
28 pub tools_used: HashMap<String, u64>,
29 pub first_seen: Option<String>,
30 pub last_seen: Option<String>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, Default)]
34pub struct ToolCost {
35 pub tool_name: String,
36 pub total_input_tokens: u64,
37 pub total_output_tokens: u64,
38 #[serde(default)]
39 pub total_cached_tokens: u64,
40 pub total_calls: u64,
41 pub avg_input_tokens: f64,
42 pub avg_output_tokens: f64,
43 #[serde(default)]
44 pub avg_cached_tokens: f64,
45 pub cost_usd: f64,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct SessionCostSnapshot {
50 pub timestamp: String,
51 pub agent_id: String,
52 #[serde(default)]
53 pub model_key: Option<String>,
54 pub total_input: u64,
55 pub total_output: u64,
56 #[serde(default)]
57 pub total_cached: u64,
58 pub total_saved: u64,
59 pub cost_usd: f64,
60 pub duration_secs: u64,
61}
62
63pub fn estimate_cost(model_key: Option<&str>, input: u64, output: u64, cached: u64) -> f64 {
64 let pricing = crate::core::gain::model_pricing::ModelPricing::load();
65 let quote = pricing.quote(model_key);
66 quote.cost.estimate_usd(input, output, 0, cached)
67}
68
69static COST_BUFFER: Mutex<Option<CostStore>> = Mutex::new(None);
70
71impl CostStore {
72 pub fn load() -> Self {
73 let mut guard = COST_BUFFER
74 .lock()
75 .unwrap_or_else(std::sync::PoisonError::into_inner);
76 if let Some(ref store) = *guard {
77 return store.clone();
78 }
79
80 let store = load_from_disk();
81 *guard = Some(store.clone());
82 store
83 }
84
85 pub fn record_tool_call(
86 &mut self,
87 agent_id: &str,
88 agent_type: &str,
89 tool_name: &str,
90 input_tokens: u64,
91 output_tokens: u64,
92 cached_tokens: u64,
93 ) {
94 let now = Utc::now().to_rfc3339();
95 let pricing = crate::core::gain::model_pricing::ModelPricing::load();
96 let quote = pricing.quote_from_env_or_agent_type(agent_type);
97 let cost = quote
98 .cost
99 .estimate_usd(input_tokens, output_tokens, 0, cached_tokens);
100
101 let agent = self
102 .agents
103 .entry(agent_id.to_string())
104 .or_insert_with(|| AgentCost {
105 agent_id: agent_id.to_string(),
106 agent_type: agent_type.to_string(),
107 first_seen: Some(now.clone()),
108 ..Default::default()
109 });
110 agent.total_input_tokens += input_tokens;
111 agent.total_output_tokens += output_tokens;
112 agent.total_cached_tokens += cached_tokens;
113 agent.total_calls += 1;
114 agent.cost_usd += cost;
115 agent.last_seen = Some(now.clone());
116 agent.model_key = Some(quote.model_key.clone());
117 agent.pricing_match = Some(format!("{:?}", quote.match_kind));
118 *agent.tools_used.entry(tool_name.to_string()).or_insert(0) += 1;
119
120 let tool = self
121 .tools
122 .entry(tool_name.to_string())
123 .or_insert_with(|| ToolCost {
124 tool_name: tool_name.to_string(),
125 ..Default::default()
126 });
127 tool.total_input_tokens += input_tokens;
128 tool.total_output_tokens += output_tokens;
129 tool.total_cached_tokens += cached_tokens;
130 tool.total_calls += 1;
131 tool.cost_usd += cost;
132 if tool.total_calls > 0 {
133 tool.avg_input_tokens = tool.total_input_tokens as f64 / tool.total_calls as f64;
134 tool.avg_output_tokens = tool.total_output_tokens as f64 / tool.total_calls as f64;
135 tool.avg_cached_tokens = tool.total_cached_tokens as f64 / tool.total_calls as f64;
136 }
137
138 self.updated_at = Some(now);
139 }
140
141 pub fn save(&self) -> std::io::Result<()> {
142 save_to_disk(self)?;
143 let mut guard = COST_BUFFER
144 .lock()
145 .unwrap_or_else(std::sync::PoisonError::into_inner);
146 *guard = Some(self.clone());
147 Ok(())
148 }
149
150 pub fn top_agents(&self, limit: usize) -> Vec<&AgentCost> {
151 let mut agents: Vec<_> = self.agents.values().collect();
152 agents.sort_by(|a, b| {
153 b.cost_usd
154 .partial_cmp(&a.cost_usd)
155 .unwrap_or(std::cmp::Ordering::Equal)
156 });
157 agents.truncate(limit);
158 agents
159 }
160
161 pub fn top_tools(&self, limit: usize) -> Vec<&ToolCost> {
162 let mut tools: Vec<_> = self.tools.values().collect();
163 tools.sort_by(|a, b| {
164 b.cost_usd
165 .partial_cmp(&a.cost_usd)
166 .unwrap_or(std::cmp::Ordering::Equal)
167 });
168 tools.truncate(limit);
169 tools
170 }
171
172 pub fn total_cost(&self) -> f64 {
173 self.agents.values().map(|a| a.cost_usd).sum()
174 }
175
176 pub fn total_tokens(&self) -> (u64, u64, u64) {
177 let input: u64 = self.agents.values().map(|a| a.total_input_tokens).sum();
178 let output: u64 = self.agents.values().map(|a| a.total_output_tokens).sum();
179 let cached: u64 = self.agents.values().map(|a| a.total_cached_tokens).sum();
180 (input, output, cached)
181 }
182
183 pub fn add_session_snapshot(
184 &mut self,
185 agent_id: &str,
186 input: u64,
187 output: u64,
188 saved: u64,
189 duration_secs: u64,
190 ) {
191 let model_key = self
192 .agents
193 .get(agent_id)
194 .and_then(|a| a.model_key.as_deref())
195 .map(std::string::ToString::to_string);
196 let cost = estimate_cost(model_key.as_deref(), input, output, 0);
197 self.sessions.push(SessionCostSnapshot {
198 timestamp: Utc::now().to_rfc3339(),
199 agent_id: agent_id.to_string(),
200 model_key,
201 total_input: input,
202 total_output: output,
203 total_cached: 0,
204 total_saved: saved,
205 cost_usd: cost,
206 duration_secs,
207 });
208
209 if self.sessions.len() > 500 {
210 self.sessions.drain(0..self.sessions.len() - 500);
211 }
212 }
213}
214
215fn cost_store_path() -> Option<PathBuf> {
216 crate::core::data_dir::lean_ctx_data_dir()
217 .ok()
218 .map(|d| d.join("cost_attribution.json"))
219}
220
221fn load_from_disk() -> CostStore {
222 let Some(path) = cost_store_path() else {
223 return CostStore::default();
224 };
225 match std::fs::read_to_string(&path) {
226 Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
227 Err(_) => CostStore::default(),
228 }
229}
230
231fn save_to_disk(store: &CostStore) -> std::io::Result<()> {
232 let Some(path) = cost_store_path() else {
233 return Err(std::io::Error::new(
234 std::io::ErrorKind::NotFound,
235 "no home dir",
236 ));
237 };
238
239 if let Some(parent) = path.parent() {
240 std::fs::create_dir_all(parent)?;
241 }
242
243 let json = serde_json::to_string(store).map_err(std::io::Error::other)?;
244 let tmp = path.with_extension("tmp");
245 std::fs::write(&tmp, &json)?;
246 std::fs::rename(&tmp, &path)?;
247 Ok(())
248}
249
250pub fn format_cost_report(store: &CostStore, limit: usize) -> String {
251 let mut lines = Vec::new();
252 let (total_in, total_out, total_cached) = store.total_tokens();
253 let total_cost = store.total_cost();
254
255 lines.push(format!(
256 "Cost Attribution Report ({} agents, {} tools)",
257 store.agents.len(),
258 store.tools.len()
259 ));
260 lines.push(format!(
261 "Total: {total_in} input + {total_out} output + {total_cached} cached tokens = ${total_cost:.4}"
262 ));
263 if let Ok(m) = std::env::var("LEAN_CTX_MODEL").or_else(|_| std::env::var("LCTX_MODEL")) {
264 if !m.trim().is_empty() {
265 let pricing = crate::core::gain::model_pricing::ModelPricing::load();
266 let q = pricing.quote(Some(&m));
267 lines.push(format!(
268 "Pricing: model={} ({:?}) in=${:.2}/M out=${:.2}/M cacheR=${:.3}/M",
269 q.model_key,
270 q.match_kind,
271 q.cost.input_per_m,
272 q.cost.output_per_m,
273 q.cost.cache_read_per_m
274 ));
275 }
276 }
277 lines.push(String::new());
278
279 let top_agents = store.top_agents(limit);
280 if !top_agents.is_empty() {
281 lines.push("Top Agents by Cost:".to_string());
282 for (i, agent) in top_agents.iter().enumerate() {
283 lines.push(format!(
284 " {}. {} ({}) — {} calls, {} in + {} out + {} cached tok, ${:.4}{}",
285 i + 1,
286 agent.agent_id,
287 agent.agent_type,
288 agent.total_calls,
289 agent.total_input_tokens,
290 agent.total_output_tokens,
291 agent.total_cached_tokens,
292 agent.cost_usd,
293 agent
294 .model_key
295 .as_deref()
296 .map(|m| format!(" [{m}]"))
297 .unwrap_or_default()
298 ));
299 }
300 lines.push(String::new());
301 }
302
303 let top_tools = store.top_tools(limit);
304 if !top_tools.is_empty() {
305 lines.push("Top Tools by Cost:".to_string());
306 for (i, tool) in top_tools.iter().enumerate() {
307 lines.push(format!(
308 " {}. {} — {} calls, avg {:.0} in + {:.0} out + {:.0} cached tok, ${:.4}",
309 i + 1,
310 tool.tool_name,
311 tool.total_calls,
312 tool.avg_input_tokens,
313 tool.avg_output_tokens,
314 tool.avg_cached_tokens,
315 tool.cost_usd
316 ));
317 }
318 }
319
320 lines.join("\n")
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory store and records each `(agent, type, tool, in, out, cached)` call.
    fn store_with_calls(calls: &[(&str, &str, &str, u64, u64, u64)]) -> CostStore {
        let mut costs = CostStore::default();
        for &(agent, kind, tool, input, output, cached) in calls {
            costs.record_tool_call(agent, kind, tool, input, output, cached);
        }
        costs
    }

    /// Non-zero token counts priced under `fallback-blended` yield a positive estimate.
    #[test]
    fn cost_estimation() {
        let estimated = estimate_cost(Some("fallback-blended"), 1000, 100, 500);
        assert!(estimated > 0.0);
    }

    /// Recorded calls are aggregated per agent and per tool, and ranking follows cost.
    #[test]
    fn record_and_query() {
        let costs = store_with_calls(&[
            ("agent-1", "mcp", "ctx_read", 5000, 200, 0),
            ("agent-1", "mcp", "ctx_read", 3000, 150, 500),
            ("agent-2", "cursor", "ctx_shell", 1000, 100, 0),
        ]);

        assert_eq!(costs.agents.len(), 2);
        assert_eq!(costs.tools.len(), 2);

        let first_agent = &costs.agents["agent-1"];
        assert_eq!(first_agent.total_calls, 2);
        assert_eq!(first_agent.total_input_tokens, 8000);
        assert_eq!(first_agent.total_cached_tokens, 500);
        assert_eq!(first_agent.tools_used.get("ctx_read"), Some(&2));

        let ranking = costs.top_agents(5);
        assert_eq!(ranking[0].agent_id, "agent-1");
    }

    /// The rendered report mentions the title, an agent id and a tool name.
    #[test]
    fn format_report() {
        let costs = store_with_calls(&[
            ("agent-a", "mcp", "ctx_read", 10000, 500, 1000),
            ("agent-b", "cursor", "ctx_shell", 2000, 100, 0),
        ]);

        let rendered = format_cost_report(&costs, 5);
        for needle in ["Cost Attribution Report", "agent-a", "ctx_read"] {
            assert!(rendered.contains(needle));
        }
    }

    /// A snapshot for an unknown agent is stored and still gets a positive cost.
    #[test]
    fn session_snapshots() {
        let mut costs = CostStore::default();
        costs.add_session_snapshot("agent-a", 50000, 5000, 30000, 120);
        assert_eq!(costs.sessions.len(), 1);
        assert!(costs.sessions[0].cost_usd > 0.0);
    }
}