1use chio_core::capability::MonetaryAmount;
7use serde::{Deserialize, Serialize};
8
9use crate::cost::CostMetadata;
10
/// Hard ceiling on how many matching records a single cost query aggregates.
///
/// A query whose `limit` is unset, or larger than this value, is clamped to it.
pub const MAX_COST_QUERY_LIMIT: usize = 500;
13
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
19pub struct CostQuery {
20 #[serde(default, skip_serializing_if = "Option::is_none")]
22 pub session_id: Option<String>,
23
24 #[serde(default, skip_serializing_if = "Option::is_none")]
26 pub agent_id: Option<String>,
27
28 #[serde(default, skip_serializing_if = "Option::is_none")]
30 pub tool_server: Option<String>,
31
32 #[serde(default, skip_serializing_if = "Option::is_none")]
34 pub tool_name: Option<String>,
35
36 #[serde(default, skip_serializing_if = "Option::is_none")]
38 pub since: Option<u64>,
39
40 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub until: Option<u64>,
43
44 #[serde(default, skip_serializing_if = "Option::is_none")]
46 pub currency: Option<String>,
47
48 #[serde(default, skip_serializing_if = "Option::is_none")]
50 pub limit: Option<usize>,
51
52 #[serde(default)]
54 pub group_by: GroupBy,
55}
56
57#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
59#[serde(rename_all = "snake_case")]
60pub enum GroupBy {
61 #[default]
63 None,
64 Session,
66 Agent,
68 Tool,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct CostQueryResult {
75 pub summary: CostSummary,
77 #[serde(default, skip_serializing_if = "Vec::is_empty")]
79 pub groups: Vec<CostGroup>,
80 pub truncated: bool,
82}
83
84#[derive(Debug, Clone, Default, Serialize, Deserialize)]
86pub struct CostSummary {
87 pub receipt_count: u64,
89 pub total_compute_time_ms: u64,
91 pub total_data_bytes: u64,
93 #[serde(default, skip_serializing_if = "Option::is_none")]
95 pub total_monetary_cost: Option<MonetaryAmount>,
96 pub distinct_agents: u64,
98 pub distinct_tools: u64,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct CostGroup {
105 pub key: String,
107 pub receipt_count: u64,
109 pub total_compute_time_ms: u64,
111 pub total_data_bytes: u64,
113 #[serde(default, skip_serializing_if = "Option::is_none")]
115 pub total_monetary_cost: Option<MonetaryAmount>,
116}
117
118pub fn execute_cost_query(records: &[CostMetadata], query: &CostQuery) -> CostQueryResult {
123 let limit = query
124 .limit
125 .unwrap_or(MAX_COST_QUERY_LIMIT)
126 .min(MAX_COST_QUERY_LIMIT);
127
128 let filtered: Vec<&CostMetadata> = records
129 .iter()
130 .filter(|record| record_matches_query(record, query))
131 .collect();
132
133 let truncated = filtered.len() > limit;
134 let capped: Vec<&CostMetadata> = filtered.iter().take(limit).copied().collect();
135
136 let mut agents = std::collections::HashSet::new();
138 let mut tools = std::collections::HashSet::new();
139 let mut total_compute = 0u64;
140 let mut total_data = 0u64;
141 let mut total_money_units = 0u64;
142 let mut money_currency: Option<String> = None;
143 let mut mixed_currency = false;
144
145 for r in &capped {
146 agents.insert(&r.agent_id);
147 tools.insert(format!("{}:{}", r.tool_server, r.tool_name));
148 total_compute = total_compute.saturating_add(r.total_compute_time_ms());
149 total_data = total_data.saturating_add(r.total_data_bytes());
150 if let Some(ref cost) = r.total_monetary_cost {
151 match &money_currency {
152 None => {
153 money_currency = Some(cost.currency.clone());
154 total_money_units = cost.units;
155 }
156 Some(c) if c == &cost.currency => {
157 total_money_units = total_money_units.saturating_add(cost.units);
158 }
159 _ => {
160 mixed_currency = true;
161 }
162 }
163 }
164 }
165
166 let total_monetary_cost = if mixed_currency {
167 None
168 } else {
169 money_currency.map(|c| MonetaryAmount {
170 units: total_money_units,
171 currency: c,
172 })
173 };
174
175 let summary = CostSummary {
176 receipt_count: capped.len() as u64,
177 total_compute_time_ms: total_compute,
178 total_data_bytes: total_data,
179 total_monetary_cost,
180 distinct_agents: agents.len() as u64,
181 distinct_tools: tools.len() as u64,
182 };
183
184 let groups = match query.group_by {
186 GroupBy::None => vec![],
187 GroupBy::Session | GroupBy::Agent | GroupBy::Tool => build_groups(&capped, &query.group_by),
188 };
189
190 CostQueryResult {
191 summary,
192 groups,
193 truncated,
194 }
195}
196
197fn record_matches_query(record: &CostMetadata, query: &CostQuery) -> bool {
198 if let Some(ref session_id) = query.session_id {
199 if record.session_id.as_ref() != Some(session_id) {
200 return false;
201 }
202 }
203 if let Some(ref agent_id) = query.agent_id {
204 if &record.agent_id != agent_id {
205 return false;
206 }
207 }
208 if let Some(ref tool_server) = query.tool_server {
209 if &record.tool_server != tool_server {
210 return false;
211 }
212 }
213 if let Some(ref tool_name) = query.tool_name {
214 if &record.tool_name != tool_name {
215 return false;
216 }
217 }
218 if let Some(since) = query.since {
219 if record.timestamp < since {
220 return false;
221 }
222 }
223 if let Some(until) = query.until {
224 if record.timestamp >= until {
225 return false;
226 }
227 }
228 if let Some(ref currency) = query.currency {
229 match record.total_monetary_cost.as_ref() {
230 Some(cost) if &cost.currency == currency => {}
231 _ => return false,
232 }
233 }
234 true
235}
236
237fn build_groups(records: &[&CostMetadata], group_by: &GroupBy) -> Vec<CostGroup> {
238 use std::collections::BTreeMap;
239
240 let mut map: BTreeMap<String, (u64, u64, u64, Option<String>, u64)> = BTreeMap::new();
241
242 for r in records {
243 let key = match group_by {
244 GroupBy::Session => r
245 .session_id
246 .clone()
247 .unwrap_or_else(|| "<no-session>".to_string()),
248 GroupBy::Agent => r.agent_id.clone(),
249 GroupBy::Tool => format!("{}:{}", r.tool_server, r.tool_name),
250 GroupBy::None => continue,
251 };
252
253 let entry = map.entry(key).or_insert_with(|| (0, 0, 0, None, 0));
254
255 entry.0 = entry.0.saturating_add(1);
256 entry.1 = entry.1.saturating_add(r.total_compute_time_ms());
257 entry.2 = entry.2.saturating_add(r.total_data_bytes());
258
259 if let Some(ref cost) = r.total_monetary_cost {
260 if entry.3.is_none() {
261 entry.3 = Some(cost.currency.clone());
262 }
263 if entry.3.as_ref() == Some(&cost.currency) {
264 entry.4 = entry.4.saturating_add(cost.units);
265 }
266 }
267 }
268
269 map.into_iter()
270 .map(|(key, (count, compute, data, currency, money))| CostGroup {
271 key,
272 receipt_count: count,
273 total_compute_time_ms: compute,
274 total_data_bytes: data,
275 total_monetary_cost: currency.map(|c| MonetaryAmount {
276 units: money,
277 currency: c,
278 }),
279 })
280 .collect()
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cost::{CostDimension, CostMetadata};

    /// Builds a record in session `sess-1` with the following dimensions:
    /// - 100 ms of compute time;
    /// - 500 bytes read and 200 bytes written;
    /// - one USD API cost of `cost_units`.
    ///
    /// Its total monetary cost is computed before returning.
    fn make_record(
        id: &str,
        ts: u64,
        agent: &str,
        server: &str,
        tool: &str,
        cost_units: u64,
    ) -> CostMetadata {
        let mut m = CostMetadata::new(
            id.to_string(),
            ts,
            agent.to_string(),
            server.to_string(),
            tool.to_string(),
        );
        m.add_dimension(CostDimension::ComputeTime { duration_ms: 100 });
        m.add_dimension(CostDimension::DataVolume {
            bytes_read: 500,
            bytes_written: 200,
        });
        m.add_dimension(CostDimension::ApiCost {
            amount: MonetaryAmount {
                units: cost_units,
                currency: "USD".to_string(),
            },
            provider: "test".to_string(),
        });
        m.session_id = Some("sess-1".to_string());
        m.compute_total_monetary_cost();
        m
    }

    #[test]
    fn query_no_filter() {
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a2", "s1", "t2", 100),
        ];
        let result = execute_cost_query(&records, &CostQuery::default());
        assert_eq!(result.summary.receipt_count, 2);
        assert_eq!(result.summary.total_compute_time_ms, 200);
        assert_eq!(result.summary.total_data_bytes, 1400);
        assert_eq!(
            result.summary.total_monetary_cost.as_ref().unwrap().units,
            150
        );
        assert!(!result.truncated);
    }

    #[test]
    fn query_filter_by_agent() {
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a2", "s1", "t2", 100),
        ];
        let query = CostQuery {
            agent_id: Some("a1".to_string()),
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert_eq!(result.summary.receipt_count, 1);
    }

    #[test]
    fn query_filter_by_time_range() {
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a1", "s1", "t1", 100),
            make_record("r3", 3000, "a1", "s1", "t1", 200),
        ];
        let query = CostQuery {
            since: Some(1500),
            until: Some(2500),
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert_eq!(result.summary.receipt_count, 1);
    }

    #[test]
    fn query_time_range_is_half_open() {
        // `since` is inclusive and `until` exclusive: [1000, 3000) keeps r1 and r2.
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a1", "s1", "t1", 100),
            make_record("r3", 3000, "a1", "s1", "t1", 200),
        ];
        let query = CostQuery {
            since: Some(1000),
            until: Some(3000),
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert_eq!(result.summary.receipt_count, 2);
        assert_eq!(
            result.summary.total_monetary_cost.as_ref().unwrap().units,
            150
        );
    }

    #[test]
    fn query_group_by_agent() {
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a2", "s1", "t2", 100),
            make_record("r3", 3000, "a1", "s1", "t1", 75),
        ];
        let query = CostQuery {
            group_by: GroupBy::Agent,
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert_eq!(result.groups.len(), 2);

        let a1_group = result.groups.iter().find(|g| g.key == "a1").unwrap();
        assert_eq!(a1_group.receipt_count, 2);
        assert_eq!(a1_group.total_monetary_cost.as_ref().unwrap().units, 125);
    }

    #[test]
    fn query_group_by_tool() {
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a1", "s1", "t2", 100),
        ];
        let query = CostQuery {
            group_by: GroupBy::Tool,
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert_eq!(result.groups.len(), 2);
        assert!(result.groups.iter().any(|g| g.key == "s1:t1"));
        assert!(result.groups.iter().any(|g| g.key == "s1:t2"));
    }

    #[test]
    fn query_without_grouping_returns_no_groups() {
        let records = vec![make_record("r1", 1000, "a1", "s1", "t1", 50)];
        let result = execute_cost_query(&records, &CostQuery::default());
        assert!(result.groups.is_empty());
    }

    #[test]
    fn query_truncation() {
        let records: Vec<CostMetadata> = (0..600)
            .map(|i| make_record(&format!("r{i}"), i as u64, "a1", "s1", "t1", 1))
            .collect();
        let result = execute_cost_query(&records, &CostQuery::default());
        assert!(result.truncated);
        assert_eq!(result.summary.receipt_count, MAX_COST_QUERY_LIMIT as u64);
    }

    #[test]
    fn query_explicit_limit_caps_and_flags_truncation() {
        let records: Vec<CostMetadata> = (0..5)
            .map(|i| make_record(&format!("r{i}"), i as u64, "a1", "s1", "t1", 10))
            .collect();
        let query = CostQuery {
            limit: Some(3),
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert!(result.truncated);
        assert_eq!(result.summary.receipt_count, 3);
        assert_eq!(
            result.summary.total_monetary_cost.as_ref().unwrap().units,
            30
        );
    }

    #[test]
    fn query_limit_above_max_is_clamped() {
        let records: Vec<CostMetadata> = (0..600)
            .map(|i| make_record(&format!("r{i}"), i as u64, "a1", "s1", "t1", 1))
            .collect();
        let query = CostQuery {
            limit: Some(10_000),
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert!(result.truncated);
        assert_eq!(result.summary.receipt_count, MAX_COST_QUERY_LIMIT as u64);
    }

    #[test]
    fn query_empty_records() {
        let records: Vec<CostMetadata> = vec![];
        let result = execute_cost_query(&records, &CostQuery::default());
        assert_eq!(result.summary.receipt_count, 0);
        assert_eq!(result.summary.total_compute_time_ms, 0);
        assert_eq!(result.summary.total_data_bytes, 0);
        assert!(result.summary.total_monetary_cost.is_none());
        assert!(!result.truncated);
    }

    #[test]
    fn query_group_by_session() {
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a1", "s1", "t1", 100),
        ];
        let query = CostQuery {
            group_by: GroupBy::Session,
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert_eq!(result.groups.len(), 1);
        assert_eq!(result.groups[0].key, "sess-1");
        assert_eq!(result.groups[0].receipt_count, 2);
    }

    #[test]
    fn query_group_by_session_without_session_uses_placeholder() {
        let mut record = make_record("r1", 1000, "a1", "s1", "t1", 50);
        record.session_id = None;
        let query = CostQuery {
            group_by: GroupBy::Session,
            ..Default::default()
        };
        let result = execute_cost_query(&[record], &query);
        assert_eq!(result.groups.len(), 1);
        assert_eq!(result.groups[0].key, "<no-session>");
    }

    #[test]
    fn query_filter_by_tool_server() {
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a1", "s2", "t1", 100),
        ];
        let query = CostQuery {
            tool_server: Some("s1".to_string()),
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert_eq!(result.summary.receipt_count, 1);
    }

    #[test]
    fn query_filter_by_tool_name() {
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a1", "s1", "t2", 100),
        ];
        let query = CostQuery {
            tool_name: Some("t2".to_string()),
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert_eq!(result.summary.receipt_count, 1);
        assert_eq!(
            result.summary.total_monetary_cost.as_ref().unwrap().units,
            100
        );
    }

    #[test]
    fn query_counts_distinct_agents_and_tools() {
        let records = vec![
            make_record("r1", 1000, "a1", "s1", "t1", 50),
            make_record("r2", 2000, "a2", "s1", "t2", 100),
            make_record("r3", 3000, "a1", "s1", "t1", 75),
        ];
        let result = execute_cost_query(&records, &CostQuery::default());
        assert_eq!(result.summary.distinct_agents, 2);
        assert_eq!(result.summary.distinct_tools, 2);
    }

    #[test]
    fn query_currency_filter() {
        let mut r1 = make_record("r1", 1000, "a1", "s1", "t1", 50);
        r1.total_monetary_cost = Some(MonetaryAmount {
            units: 50,
            currency: "USD".to_string(),
        });
        let mut r2 = make_record("r2", 2000, "a1", "s1", "t1", 100);
        r2.total_monetary_cost = Some(MonetaryAmount {
            units: 100,
            currency: "EUR".to_string(),
        });
        let records = vec![r1, r2];
        let query = CostQuery {
            currency: Some("EUR".to_string()),
            ..Default::default()
        };
        let result = execute_cost_query(&records, &query);
        assert_eq!(result.summary.receipt_count, 1);
        assert_eq!(
            result
                .summary
                .total_monetary_cost
                .as_ref()
                .unwrap()
                .currency,
            "EUR"
        );
    }

    #[test]
    fn query_mixed_currency_summary_has_no_monetary_total() {
        let r1 = make_record("r1", 1000, "a1", "s1", "t1", 50);
        let mut r2 = make_record("r2", 2000, "a1", "s1", "t1", 100);
        r2.total_monetary_cost = Some(MonetaryAmount {
            units: 100,
            currency: "EUR".to_string(),
        });
        let result = execute_cost_query(&[r1, r2], &CostQuery::default());
        assert_eq!(result.summary.receipt_count, 2);
        assert!(result.summary.total_monetary_cost.is_none());
    }
}