1use serde_json::{Map, Value};
6use std::collections::HashMap;
7use std::collections::HashSet;
8
9pub fn dict_int_op<F>(
44 left: &HashMap<String, UsageValue>,
45 right: &HashMap<String, UsageValue>,
46 op: F,
47 default: i64,
48 max_depth: usize,
49) -> Result<HashMap<String, UsageValue>, UsageError>
50where
51 F: Fn(i64, i64) -> i64 + Copy,
52{
53 dict_int_op_impl(left, right, op, default, 0, max_depth)
54}
55
56fn dict_int_op_impl<F>(
57 left: &HashMap<String, UsageValue>,
58 right: &HashMap<String, UsageValue>,
59 op: F,
60 default: i64,
61 depth: usize,
62 max_depth: usize,
63) -> Result<HashMap<String, UsageValue>, UsageError>
64where
65 F: Fn(i64, i64) -> i64 + Copy,
66{
67 if depth >= max_depth {
68 return Err(UsageError::MaxDepthExceeded(max_depth));
69 }
70
71 let mut combined = HashMap::new();
72 let all_keys: std::collections::HashSet<_> = left.keys().chain(right.keys()).cloned().collect();
73
74 for k in all_keys {
75 let left_val = left.get(&k);
76 let right_val = right.get(&k);
77
78 match (left_val, right_val) {
79 (Some(UsageValue::Int(l)), Some(UsageValue::Int(r))) => {
80 combined.insert(k, UsageValue::Int(op(*l, *r)));
81 }
82 (Some(UsageValue::Int(l)), None) => {
83 combined.insert(k, UsageValue::Int(op(*l, default)));
84 }
85 (None, Some(UsageValue::Int(r))) => {
86 combined.insert(k, UsageValue::Int(op(default, *r)));
87 }
88 (Some(UsageValue::Dict(l)), Some(UsageValue::Dict(r))) => {
89 let nested = dict_int_op_impl(l, r, op, default, depth + 1, max_depth)?;
90 combined.insert(k, UsageValue::Dict(nested));
91 }
92 (Some(UsageValue::Dict(l)), None) => {
93 let empty = HashMap::new();
94 let nested = dict_int_op_impl(l, &empty, op, default, depth + 1, max_depth)?;
95 combined.insert(k, UsageValue::Dict(nested));
96 }
97 (None, Some(UsageValue::Dict(r))) => {
98 let empty = HashMap::new();
99 let nested = dict_int_op_impl(&empty, r, op, default, depth + 1, max_depth)?;
100 combined.insert(k, UsageValue::Dict(nested));
101 }
102 (Some(l), Some(r)) => {
103 return Err(UsageError::TypeMismatch {
104 key: k,
105 left_type: l.type_name().to_string(),
106 right_type: r.type_name().to_string(),
107 });
108 }
109 (None, None) => unreachable!(),
110 }
111 }
112
113 Ok(combined)
114}
115
116pub fn dict_int_add(
127 left: &HashMap<String, UsageValue>,
128 right: &HashMap<String, UsageValue>,
129) -> Result<HashMap<String, UsageValue>, UsageError> {
130 dict_int_op(left, right, |a, b| a + b, 0, 100)
131}
132
133pub fn dict_int_sub(
144 left: &HashMap<String, UsageValue>,
145 right: &HashMap<String, UsageValue>,
146) -> Result<HashMap<String, UsageValue>, UsageError> {
147 dict_int_op(left, right, |a, b| a - b, 0, 100)
148}
149
/// A usage value: either a single integer count or a nested map of further
/// usage values keyed by name.
#[derive(Clone, Debug, PartialEq)]
pub enum UsageValue {
    /// A single integer count.
    Int(i64),
    /// A nested map from key to usage value.
    Dict(HashMap<String, UsageValue>),
}

impl UsageValue {
    /// Short name of the variant: `"int"` or `"dict"`. Used when building
    /// [`UsageError::TypeMismatch`] values.
    pub fn type_name(&self) -> &'static str {
        match *self {
            Self::Int(..) => "int",
            Self::Dict(..) => "dict",
        }
    }

    /// Returns the count if this is [`UsageValue::Int`], otherwise `None`.
    pub fn as_int(&self) -> Option<i64> {
        match *self {
            Self::Int(count) => Some(count),
            Self::Dict(_) => None,
        }
    }

    /// Borrows the nested map if this is [`UsageValue::Dict`], otherwise `None`.
    pub fn as_dict(&self) -> Option<&HashMap<String, UsageValue>> {
        match self {
            Self::Dict(entries) => Some(entries),
            Self::Int(_) => None,
        }
    }
}

impl From<i64> for UsageValue {
    fn from(count: i64) -> Self {
        Self::Int(count)
    }
}

impl From<i32> for UsageValue {
    fn from(count: i32) -> Self {
        // i32 -> i64 is a lossless widening, identical to an `as` cast.
        Self::Int(i64::from(count))
    }
}

impl From<HashMap<String, UsageValue>> for UsageValue {
    fn from(entries: HashMap<String, UsageValue>) -> Self {
        Self::Dict(entries)
    }
}
202
/// Errors produced while combining usage maps.
#[derive(Clone, Debug, PartialEq)]
pub enum UsageError {
    /// Nesting reached the configured limit; carries that `max_depth`.
    MaxDepthExceeded(usize),
    /// The same key holds values of incompatible types on the two sides.
    TypeMismatch {
        /// Key at which the mismatch was found.
        key: String,
        /// Type name of the left-hand value.
        left_type: String,
        /// Type name of the right-hand value.
        right_type: String,
    },
}

impl std::fmt::Display for UsageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MaxDepthExceeded(limit) => write!(
                f,
                "max_depth={} exceeded, unable to combine dicts",
                limit
            ),
            Self::TypeMismatch {
                key,
                left_type: lhs,
                right_type: rhs,
            } => write!(
                f,
                "Unknown value types for key '{}': {} and {}. Only dict and int values are supported.",
                key, lhs, rhs
            ),
        }
    }
}

// Message comes from `Display`; there is no underlying source error.
impl std::error::Error for UsageError {}
238
239pub fn dict_int_op_json<F>(
257 left: &Value,
258 right: &Value,
259 op: F,
260 default: i64,
261 max_depth: usize,
262) -> Result<Value, UsageError>
263where
264 F: Fn(i64, i64) -> i64 + Copy,
265{
266 dict_int_op_json_impl(left, right, op, default, 0, max_depth)
267}
268
269fn dict_int_op_json_impl<F>(
270 left: &Value,
271 right: &Value,
272 op: F,
273 default: i64,
274 depth: usize,
275 max_depth: usize,
276) -> Result<Value, UsageError>
277where
278 F: Fn(i64, i64) -> i64 + Copy,
279{
280 if depth >= max_depth {
281 return Err(UsageError::MaxDepthExceeded(max_depth));
282 }
283
284 let empty_map = Map::new();
285 let left_obj = left.as_object().unwrap_or(&empty_map);
286 let right_obj = right.as_object().unwrap_or(&empty_map);
287
288 let all_keys: HashSet<_> = left_obj.keys().chain(right_obj.keys()).cloned().collect();
289
290 let mut combined = Map::new();
291
292 for k in all_keys {
293 let left_val = left_obj.get(&k);
294 let right_val = right_obj.get(&k);
295
296 match (left_val, right_val) {
297 (Some(Value::Number(l)), Some(Value::Number(r))) if l.is_i64() && r.is_i64() => {
299 let l_int = l.as_i64().unwrap_or(default);
300 let r_int = r.as_i64().unwrap_or(default);
301 combined.insert(k, Value::Number(op(l_int, r_int).into()));
302 }
303 (Some(Value::Number(l)), None) if l.is_i64() => {
305 let l_int = l.as_i64().unwrap_or(default);
306 combined.insert(k, Value::Number(op(l_int, default).into()));
307 }
308 (None, Some(Value::Number(r))) if r.is_i64() => {
310 let r_int = r.as_i64().unwrap_or(default);
311 combined.insert(k, Value::Number(op(default, r_int).into()));
312 }
313 (Some(Value::Object(_)), Some(Value::Object(_))) => {
315 let nested = dict_int_op_json_impl(
316 left_val.unwrap(),
317 right_val.unwrap(),
318 op,
319 default,
320 depth + 1,
321 max_depth,
322 )?;
323 combined.insert(k, nested);
324 }
325 (Some(Value::Object(_)), None) => {
327 let nested = dict_int_op_json_impl(
328 left_val.unwrap(),
329 &Value::Object(Map::new()),
330 op,
331 default,
332 depth + 1,
333 max_depth,
334 )?;
335 combined.insert(k, nested);
336 }
337 (None, Some(Value::Object(_))) => {
339 let nested = dict_int_op_json_impl(
340 &Value::Object(Map::new()),
341 right_val.unwrap(),
342 op,
343 default,
344 depth + 1,
345 max_depth,
346 )?;
347 combined.insert(k, nested);
348 }
349 (None, None) => {}
351 (Some(l), Some(r)) => {
353 return Err(UsageError::TypeMismatch {
354 key: k,
355 left_type: json_type_name(l).to_string(),
356 right_type: json_type_name(r).to_string(),
357 });
358 }
359 (Some(v), None) | (None, Some(v)) => {
361 combined.insert(k, v.clone());
363 }
364 }
365 }
366
367 Ok(Value::Object(combined))
368}
369
370fn json_type_name(value: &Value) -> &'static str {
371 match value {
372 Value::Null => "null",
373 Value::Bool(_) => "bool",
374 Value::Number(_) => "number",
375 Value::String(_) => "string",
376 Value::Array(_) => "array",
377 Value::Object(_) => "object",
378 }
379}
380
381pub fn dict_int_add_json(left: &Value, right: &Value) -> Result<Value, UsageError> {
392 dict_int_op_json(left, right, |a, b| a + b, 0, 100)
393}
394
395pub fn dict_int_sub_floor_json(left: &Value, right: &Value) -> Result<Value, UsageError> {
408 dict_int_op_json(left, right, |a, b| (a - b).max(0), 0, 100)
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds `depth + 1` levels of maps: `depth` wrappers of the form
    /// `{"nested": ...}` around an innermost `{"value": 1}`.
    fn create_nested(depth: usize) -> HashMap<String, UsageValue> {
        if depth == 0 {
            let mut m = HashMap::new();
            m.insert("value".to_string(), UsageValue::Int(1));
            m
        } else {
            let mut m = HashMap::new();
            m.insert(
                "nested".to_string(),
                UsageValue::Dict(create_nested(depth - 1)),
            );
            m
        }
    }

    #[test]
    fn test_dict_int_add() {
        let mut left = HashMap::new();
        left.insert("a".to_string(), UsageValue::Int(1));
        left.insert("b".to_string(), UsageValue::Int(2));

        let mut right = HashMap::new();
        right.insert("a".to_string(), UsageValue::Int(3));
        right.insert("c".to_string(), UsageValue::Int(4));

        let result = dict_int_add(&left, &right).unwrap();

        assert_eq!(result.get("a").unwrap().as_int(), Some(4));
        assert_eq!(result.get("b").unwrap().as_int(), Some(2));
        assert_eq!(result.get("c").unwrap().as_int(), Some(4));
    }

    #[test]
    fn test_dict_int_add_nested() {
        let mut inner_left = HashMap::new();
        inner_left.insert("x".to_string(), UsageValue::Int(1));

        let mut left = HashMap::new();
        left.insert("nested".to_string(), UsageValue::Dict(inner_left));

        let mut inner_right = HashMap::new();
        inner_right.insert("x".to_string(), UsageValue::Int(2));
        inner_right.insert("y".to_string(), UsageValue::Int(3));

        let mut right = HashMap::new();
        right.insert("nested".to_string(), UsageValue::Dict(inner_right));

        let result = dict_int_add(&left, &right).unwrap();

        let nested = result.get("nested").unwrap().as_dict().unwrap();
        assert_eq!(nested.get("x").unwrap().as_int(), Some(3));
        assert_eq!(nested.get("y").unwrap().as_int(), Some(3));
    }

    #[test]
    fn test_dict_int_sub() {
        let mut left = HashMap::new();
        left.insert("a".to_string(), UsageValue::Int(5));
        left.insert("b".to_string(), UsageValue::Int(3));

        let mut right = HashMap::new();
        right.insert("a".to_string(), UsageValue::Int(2));

        let result = dict_int_sub(&left, &right).unwrap();

        assert_eq!(result.get("a").unwrap().as_int(), Some(3));
        assert_eq!(result.get("b").unwrap().as_int(), Some(3));
    }

    #[test]
    fn test_max_depth_exceeded() {
        let left = create_nested(150);
        let right = create_nested(150);

        let result = dict_int_op(&left, &right, |a, b| a + b, 0, 100);
        assert!(matches!(result, Err(UsageError::MaxDepthExceeded(_))));
    }

    #[test]
    fn test_max_depth_boundary() {
        // Three levels of maps (depths 0, 1, 2).
        let three_levels = create_nested(2);

        assert!(dict_int_op(&three_levels, &three_levels, |a, b| a + b, 0, 3).is_ok());
        assert_eq!(
            dict_int_op(&three_levels, &three_levels, |a, b| a + b, 0, 2),
            Err(UsageError::MaxDepthExceeded(2))
        );
    }

    #[test]
    fn test_dict_int_type_mismatch() {
        let mut left = HashMap::new();
        left.insert("a".to_string(), UsageValue::Int(1));

        let mut right = HashMap::new();
        right.insert("a".to_string(), UsageValue::Dict(HashMap::new()));

        assert_eq!(
            dict_int_add(&left, &right),
            Err(UsageError::TypeMismatch {
                key: "a".to_string(),
                left_type: "int".to_string(),
                right_type: "dict".to_string(),
            })
        );
    }

    #[test]
    fn test_dict_int_add_json() {
        let left = json!({
            "a": 1,
            "b": 2
        });

        let right = json!({
            "a": 3,
            "c": 4
        });

        let result = dict_int_add_json(&left, &right).unwrap();

        assert_eq!(result["a"], 4);
        assert_eq!(result["b"], 2);
        assert_eq!(result["c"], 4);
    }

    #[test]
    fn test_dict_int_add_json_nested() {
        let left = json!({
            "input_tokens": 5,
            "output_tokens": 0,
            "total_tokens": 5,
            "input_token_details": {
                "cache_read": 3
            }
        });

        let right = json!({
            "input_tokens": 0,
            "output_tokens": 10,
            "total_tokens": 10,
            "output_token_details": {
                "reasoning": 4
            }
        });

        let result = dict_int_add_json(&left, &right).unwrap();

        assert_eq!(result["input_tokens"], 5);
        assert_eq!(result["output_tokens"], 10);
        assert_eq!(result["total_tokens"], 15);
        assert_eq!(result["input_token_details"]["cache_read"], 3);
        assert_eq!(result["output_token_details"]["reasoning"], 4);
    }

    #[test]
    fn test_dict_int_sub_floor_json() {
        let left = json!({
            "input_tokens": 5,
            "output_tokens": 10,
            "total_tokens": 15,
            "input_token_details": {
                "cache_read": 4
            }
        });

        let right = json!({
            "input_tokens": 3,
            "output_tokens": 8,
            "total_tokens": 11,
            "output_token_details": {
                "reasoning": 4
            }
        });

        let result = dict_int_sub_floor_json(&left, &right).unwrap();

        assert_eq!(result["input_tokens"], 2);
        assert_eq!(result["output_tokens"], 2);
        assert_eq!(result["total_tokens"], 4);
        assert_eq!(result["input_token_details"]["cache_read"], 4);
        assert_eq!(result["output_token_details"]["reasoning"], 0);
    }

    #[test]
    fn test_dict_int_add_json_type_mismatch() {
        let left = json!({"a": 1});
        let right = json!({"a": {"b": 1}});

        assert_eq!(
            dict_int_add_json(&left, &right),
            Err(UsageError::TypeMismatch {
                key: "a".to_string(),
                left_type: "number".to_string(),
                right_type: "object".to_string(),
            })
        );
    }

    #[test]
    fn test_dict_int_add_json_passthrough_and_non_object() {
        // A non-integer, non-object value on one side is copied through.
        let left = json!({"model": "m", "tokens": 1});
        let right = json!({"tokens": 2});
        assert_eq!(
            dict_int_add_json(&left, &right).unwrap(),
            json!({"model": "m", "tokens": 3})
        );

        // A non-object top-level argument is treated as an empty object.
        assert_eq!(
            dict_int_add_json(&Value::Null, &json!({"a": 1})).unwrap(),
            json!({"a": 1})
        );
    }

    #[test]
    fn test_dict_int_json_max_depth_exceeded() {
        let mut value = json!({"value": 1});
        for _ in 0..150 {
            let mut wrapper = Map::new();
            wrapper.insert("nested".to_string(), value);
            value = Value::Object(wrapper);
        }

        assert_eq!(
            dict_int_add_json(&value, &value),
            Err(UsageError::MaxDepthExceeded(100))
        );
    }
}