1use chrono::Utc;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Mutex;
6
7#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct CostStore {
9 pub agents: HashMap<String, AgentCost>,
10 pub tools: HashMap<String, ToolCost>,
11 pub sessions: Vec<SessionCostSnapshot>,
12 pub updated_at: Option<String>,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct AgentCost {
17 pub agent_id: String,
18 pub agent_type: String,
19 #[serde(default)]
20 pub model_key: Option<String>,
21 #[serde(default)]
22 pub pricing_match: Option<String>,
23 pub total_input_tokens: u64,
24 pub total_output_tokens: u64,
25 pub total_cached_tokens: u64,
26 pub total_calls: u64,
27 pub cost_usd: f64,
28 pub tools_used: HashMap<String, u64>,
29 pub first_seen: Option<String>,
30 pub last_seen: Option<String>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, Default)]
34pub struct ToolCost {
35 pub tool_name: String,
36 pub total_input_tokens: u64,
37 pub total_output_tokens: u64,
38 pub total_calls: u64,
39 pub avg_input_tokens: f64,
40 pub avg_output_tokens: f64,
41 pub cost_usd: f64,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct SessionCostSnapshot {
46 pub timestamp: String,
47 pub agent_id: String,
48 #[serde(default)]
49 pub model_key: Option<String>,
50 pub total_input: u64,
51 pub total_output: u64,
52 pub total_saved: u64,
53 pub cost_usd: f64,
54 pub duration_secs: u64,
55}
56
57pub fn estimate_cost(model_key: Option<&str>, input: u64, output: u64, cached: u64) -> f64 {
58 let pricing = crate::core::gain::model_pricing::ModelPricing::load();
59 let quote = pricing.quote(model_key);
60 quote.cost.estimate_usd(input, output, 0, cached)
61}
62
/// Process-wide cached copy of the cost store.
///
/// `None` until the first `CostStore::load`, which fills it from disk; after a
/// successful `CostStore::save` it holds a clone of the saved store.
static COST_BUFFER: Mutex<Option<CostStore>> = Mutex::new(None);
64
65impl CostStore {
66 pub fn load() -> Self {
67 let mut guard = COST_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
68 if let Some(ref store) = *guard {
69 return store.clone();
70 }
71
72 let store = load_from_disk();
73 *guard = Some(store.clone());
74 store
75 }
76
77 pub fn record_tool_call(
78 &mut self,
79 agent_id: &str,
80 agent_type: &str,
81 tool_name: &str,
82 input_tokens: u64,
83 output_tokens: u64,
84 ) {
85 let now = Utc::now().to_rfc3339();
86 let pricing = crate::core::gain::model_pricing::ModelPricing::load();
87 let quote = pricing.quote_from_env_or_agent_type(agent_type);
88 let cost = quote.cost.estimate_usd(input_tokens, output_tokens, 0, 0);
89
90 let agent = self
91 .agents
92 .entry(agent_id.to_string())
93 .or_insert_with(|| AgentCost {
94 agent_id: agent_id.to_string(),
95 agent_type: agent_type.to_string(),
96 first_seen: Some(now.clone()),
97 ..Default::default()
98 });
99 agent.total_input_tokens += input_tokens;
100 agent.total_output_tokens += output_tokens;
101 agent.total_calls += 1;
102 agent.cost_usd += cost;
103 agent.last_seen = Some(now.clone());
104 agent.model_key = Some(quote.model_key.clone());
105 agent.pricing_match = Some(format!("{:?}", quote.match_kind));
106 *agent.tools_used.entry(tool_name.to_string()).or_insert(0) += 1;
107
108 let tool = self
109 .tools
110 .entry(tool_name.to_string())
111 .or_insert_with(|| ToolCost {
112 tool_name: tool_name.to_string(),
113 ..Default::default()
114 });
115 tool.total_input_tokens += input_tokens;
116 tool.total_output_tokens += output_tokens;
117 tool.total_calls += 1;
118 tool.cost_usd += cost;
119 if tool.total_calls > 0 {
120 tool.avg_input_tokens = tool.total_input_tokens as f64 / tool.total_calls as f64;
121 tool.avg_output_tokens = tool.total_output_tokens as f64 / tool.total_calls as f64;
122 }
123
124 self.updated_at = Some(now);
125 }
126
127 pub fn save(&self) -> std::io::Result<()> {
128 save_to_disk(self)?;
129 let mut guard = COST_BUFFER.lock().unwrap_or_else(|e| e.into_inner());
130 *guard = Some(self.clone());
131 Ok(())
132 }
133
134 pub fn top_agents(&self, limit: usize) -> Vec<&AgentCost> {
135 let mut agents: Vec<_> = self.agents.values().collect();
136 agents.sort_by(|a, b| {
137 b.cost_usd
138 .partial_cmp(&a.cost_usd)
139 .unwrap_or(std::cmp::Ordering::Equal)
140 });
141 agents.truncate(limit);
142 agents
143 }
144
145 pub fn top_tools(&self, limit: usize) -> Vec<&ToolCost> {
146 let mut tools: Vec<_> = self.tools.values().collect();
147 tools.sort_by(|a, b| {
148 b.cost_usd
149 .partial_cmp(&a.cost_usd)
150 .unwrap_or(std::cmp::Ordering::Equal)
151 });
152 tools.truncate(limit);
153 tools
154 }
155
156 pub fn total_cost(&self) -> f64 {
157 self.agents.values().map(|a| a.cost_usd).sum()
158 }
159
160 pub fn total_tokens(&self) -> (u64, u64) {
161 let input: u64 = self.agents.values().map(|a| a.total_input_tokens).sum();
162 let output: u64 = self.agents.values().map(|a| a.total_output_tokens).sum();
163 (input, output)
164 }
165
166 pub fn add_session_snapshot(
167 &mut self,
168 agent_id: &str,
169 input: u64,
170 output: u64,
171 saved: u64,
172 duration_secs: u64,
173 ) {
174 let model_key = self
175 .agents
176 .get(agent_id)
177 .and_then(|a| a.model_key.as_deref())
178 .map(|s| s.to_string());
179 let cost = estimate_cost(model_key.as_deref(), input, output, 0);
180 self.sessions.push(SessionCostSnapshot {
181 timestamp: Utc::now().to_rfc3339(),
182 agent_id: agent_id.to_string(),
183 model_key,
184 total_input: input,
185 total_output: output,
186 total_saved: saved,
187 cost_usd: cost,
188 duration_secs,
189 });
190
191 if self.sessions.len() > 500 {
192 self.sessions.drain(0..self.sessions.len() - 500);
193 }
194 }
195}
196
197fn cost_store_path() -> Option<PathBuf> {
198 crate::core::data_dir::lean_ctx_data_dir()
199 .ok()
200 .map(|d| d.join("cost_attribution.json"))
201}
202
203fn load_from_disk() -> CostStore {
204 let path = match cost_store_path() {
205 Some(p) => p,
206 None => return CostStore::default(),
207 };
208 match std::fs::read_to_string(&path) {
209 Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
210 Err(_) => CostStore::default(),
211 }
212}
213
214fn save_to_disk(store: &CostStore) -> std::io::Result<()> {
215 let path = match cost_store_path() {
216 Some(p) => p,
217 None => {
218 return Err(std::io::Error::new(
219 std::io::ErrorKind::NotFound,
220 "no home dir",
221 ))
222 }
223 };
224
225 if let Some(parent) = path.parent() {
226 std::fs::create_dir_all(parent)?;
227 }
228
229 let json = serde_json::to_string(store).map_err(std::io::Error::other)?;
230 let tmp = path.with_extension("tmp");
231 std::fs::write(&tmp, &json)?;
232 std::fs::rename(&tmp, &path)?;
233 Ok(())
234}
235
236pub fn format_cost_report(store: &CostStore, limit: usize) -> String {
237 let mut lines = Vec::new();
238 let (total_in, total_out) = store.total_tokens();
239 let total_cost = store.total_cost();
240
241 lines.push(format!(
242 "Cost Attribution Report ({} agents, {} tools)",
243 store.agents.len(),
244 store.tools.len()
245 ));
246 lines.push(format!(
247 "Total: {total_in} input + {total_out} output tokens = ${total_cost:.4}"
248 ));
249 if let Ok(m) = std::env::var("LEAN_CTX_MODEL").or_else(|_| std::env::var("LCTX_MODEL")) {
250 if !m.trim().is_empty() {
251 let pricing = crate::core::gain::model_pricing::ModelPricing::load();
252 let q = pricing.quote(Some(&m));
253 lines.push(format!(
254 "Pricing: model={} ({:?}) in=${:.2}/M out=${:.2}/M cacheR=${:.3}/M",
255 q.model_key,
256 q.match_kind,
257 q.cost.input_per_m,
258 q.cost.output_per_m,
259 q.cost.cache_read_per_m
260 ));
261 }
262 }
263 lines.push(String::new());
264
265 let top_agents = store.top_agents(limit);
266 if !top_agents.is_empty() {
267 lines.push("Top Agents by Cost:".to_string());
268 for (i, agent) in top_agents.iter().enumerate() {
269 lines.push(format!(
270 " {}. {} ({}) — {} calls, {} in + {} out tok, ${:.4}{}",
271 i + 1,
272 agent.agent_id,
273 agent.agent_type,
274 agent.total_calls,
275 agent.total_input_tokens,
276 agent.total_output_tokens,
277 agent.cost_usd,
278 agent
279 .model_key
280 .as_deref()
281 .map(|m| format!(" [{m}]"))
282 .unwrap_or_default()
283 ));
284 }
285 lines.push(String::new());
286 }
287
288 let top_tools = store.top_tools(limit);
289 if !top_tools.is_empty() {
290 lines.push("Top Tools by Cost:".to_string());
291 for (i, tool) in top_tools.iter().enumerate() {
292 lines.push(format!(
293 " {}. {} — {} calls, avg {:.0} in + {:.0} out tok, ${:.4}",
294 i + 1,
295 tool.tool_name,
296 tool.total_calls,
297 tool.avg_input_tokens,
298 tool.avg_output_tokens,
299 tool.cost_usd
300 ));
301 }
302 }
303
304 lines.join("\n")
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// A known fallback model with non-zero token counts must yield a positive cost.
    #[test]
    fn cost_estimation() {
        let usd = estimate_cost(Some("fallback-blended"), 1000, 100, 500);
        assert!(usd > 0.0, "expected a positive estimate, got {usd}");
    }

    /// Recording calls aggregates per agent and per tool, and ranking reflects it.
    #[test]
    fn record_and_query() {
        let mut store = CostStore::default();
        let calls = [
            ("agent-1", "mcp", "ctx_read", 5000, 200),
            ("agent-1", "mcp", "ctx_read", 3000, 150),
            ("agent-2", "cursor", "ctx_shell", 1000, 100),
        ];
        for (agent, kind, tool, input, output) in calls {
            store.record_tool_call(agent, kind, tool, input, output);
        }

        assert_eq!(store.agents.len(), 2);
        assert_eq!(store.tools.len(), 2);

        let first = store.agents.get("agent-1").expect("agent-1 was recorded");
        assert_eq!(first.total_calls, 2);
        assert_eq!(first.total_input_tokens, 8000);
        assert_eq!(first.tools_used.get("ctx_read").copied(), Some(2));

        let ranked = store.top_agents(5);
        assert_eq!(ranked[0].agent_id, "agent-1");
    }

    /// The rendered report mentions its title, recorded agents, and tools.
    #[test]
    fn format_report() {
        let mut store = CostStore::default();
        store.record_tool_call("agent-a", "mcp", "ctx_read", 10000, 500);
        store.record_tool_call("agent-b", "cursor", "ctx_shell", 2000, 100);

        let report = format_cost_report(&store, 5);
        for needle in ["Cost Attribution Report", "agent-a", "ctx_read"] {
            assert!(report.contains(needle), "report is missing {needle:?}");
        }
    }

    /// A snapshot is appended and priced above zero.
    #[test]
    fn session_snapshots() {
        let mut store = CostStore::default();
        store.add_session_snapshot("agent-a", 50000, 5000, 30000, 120);

        assert_eq!(store.sessions.len(), 1);
        let snapshot = &store.sessions[0];
        assert!(snapshot.cost_usd > 0.0);
    }
}