agent_chain_core/utils/
usage.rs

1//! Usage utilities.
2//!
3//! Adapted from langchain_core/utils/usage.py
4
5use serde_json::{Map, Value};
6use std::collections::HashMap;
7use std::collections::HashSet;
8
9/// Perform an integer operation on nested dictionaries.
10///
11/// This function recursively applies an operation to integer values in
12/// nested dictionaries.
13///
14/// # Arguments
15///
16/// * `left` - The first dictionary.
17/// * `right` - The second dictionary.
18/// * `op` - The operation to apply (e.g., addition, subtraction).
19/// * `default` - The default value for missing keys.
20/// * `max_depth` - Maximum recursion depth (default: 100).
21///
22/// # Returns
23///
24/// A new dictionary with the operation applied, or an error if max depth exceeded.
25///
26/// # Example
27///
28/// ```
29/// use std::collections::HashMap;
30/// use agent_chain_core::utils::usage::{dict_int_add, UsageValue};
31///
32/// let mut left = HashMap::new();
33/// left.insert("a".to_string(), UsageValue::Int(1));
34/// left.insert("b".to_string(), UsageValue::Int(2));
35///
36/// let mut right = HashMap::new();
37/// right.insert("a".to_string(), UsageValue::Int(3));
38/// right.insert("c".to_string(), UsageValue::Int(4));
39///
40/// let result = dict_int_add(&left, &right).unwrap();
41/// // result["a"] == 4, result["b"] == 2, result["c"] == 4
42/// ```
43pub fn dict_int_op<F>(
44    left: &HashMap<String, UsageValue>,
45    right: &HashMap<String, UsageValue>,
46    op: F,
47    default: i64,
48    max_depth: usize,
49) -> Result<HashMap<String, UsageValue>, UsageError>
50where
51    F: Fn(i64, i64) -> i64 + Copy,
52{
53    dict_int_op_impl(left, right, op, default, 0, max_depth)
54}
55
56fn dict_int_op_impl<F>(
57    left: &HashMap<String, UsageValue>,
58    right: &HashMap<String, UsageValue>,
59    op: F,
60    default: i64,
61    depth: usize,
62    max_depth: usize,
63) -> Result<HashMap<String, UsageValue>, UsageError>
64where
65    F: Fn(i64, i64) -> i64 + Copy,
66{
67    if depth >= max_depth {
68        return Err(UsageError::MaxDepthExceeded(max_depth));
69    }
70
71    let mut combined = HashMap::new();
72    let all_keys: std::collections::HashSet<_> = left.keys().chain(right.keys()).cloned().collect();
73
74    for k in all_keys {
75        let left_val = left.get(&k);
76        let right_val = right.get(&k);
77
78        match (left_val, right_val) {
79            (Some(UsageValue::Int(l)), Some(UsageValue::Int(r))) => {
80                combined.insert(k, UsageValue::Int(op(*l, *r)));
81            }
82            (Some(UsageValue::Int(l)), None) => {
83                combined.insert(k, UsageValue::Int(op(*l, default)));
84            }
85            (None, Some(UsageValue::Int(r))) => {
86                combined.insert(k, UsageValue::Int(op(default, *r)));
87            }
88            (Some(UsageValue::Dict(l)), Some(UsageValue::Dict(r))) => {
89                let nested = dict_int_op_impl(l, r, op, default, depth + 1, max_depth)?;
90                combined.insert(k, UsageValue::Dict(nested));
91            }
92            (Some(UsageValue::Dict(l)), None) => {
93                let empty = HashMap::new();
94                let nested = dict_int_op_impl(l, &empty, op, default, depth + 1, max_depth)?;
95                combined.insert(k, UsageValue::Dict(nested));
96            }
97            (None, Some(UsageValue::Dict(r))) => {
98                let empty = HashMap::new();
99                let nested = dict_int_op_impl(&empty, r, op, default, depth + 1, max_depth)?;
100                combined.insert(k, UsageValue::Dict(nested));
101            }
102            (Some(l), Some(r)) => {
103                return Err(UsageError::TypeMismatch {
104                    key: k,
105                    left_type: l.type_name().to_string(),
106                    right_type: r.type_name().to_string(),
107                });
108            }
109            (None, None) => unreachable!(),
110        }
111    }
112
113    Ok(combined)
114}
115
116/// Add two usage dictionaries together.
117///
118/// # Arguments
119///
120/// * `left` - The first dictionary.
121/// * `right` - The second dictionary.
122///
123/// # Returns
124///
125/// A new dictionary with values added together.
126pub fn dict_int_add(
127    left: &HashMap<String, UsageValue>,
128    right: &HashMap<String, UsageValue>,
129) -> Result<HashMap<String, UsageValue>, UsageError> {
130    dict_int_op(left, right, |a, b| a + b, 0, 100)
131}
132
133/// Subtract one usage dictionary from another.
134///
135/// # Arguments
136///
137/// * `left` - The first dictionary.
138/// * `right` - The dictionary to subtract.
139///
140/// # Returns
141///
142/// A new dictionary with values subtracted.
143pub fn dict_int_sub(
144    left: &HashMap<String, UsageValue>,
145    right: &HashMap<String, UsageValue>,
146) -> Result<HashMap<String, UsageValue>, UsageError> {
147    dict_int_op(left, right, |a, b| a - b, 0, 100)
148}
149
/// One entry of a usage dictionary: either an integer counter or a nested
/// dictionary of further entries.
#[derive(Debug, Clone, PartialEq)]
pub enum UsageValue {
    /// An integer leaf value.
    Int(i64),
    /// A nested dictionary keyed by name.
    Dict(HashMap<String, UsageValue>),
}

impl UsageValue {
    /// Name of this value's variant (`"int"` or `"dict"`), as used in error messages.
    pub fn type_name(&self) -> &'static str {
        match self {
            Self::Dict(_) => "dict",
            Self::Int(_) => "int",
        }
    }

    /// Returns the integer if this is an `Int`, otherwise `None`.
    pub fn as_int(&self) -> Option<i64> {
        match self {
            Self::Int(n) => Some(*n),
            Self::Dict(_) => None,
        }
    }

    /// Borrows the nested map if this is a `Dict`, otherwise `None`.
    pub fn as_dict(&self) -> Option<&HashMap<String, UsageValue>> {
        match self {
            Self::Dict(map) => Some(map),
            Self::Int(_) => None,
        }
    }
}

impl From<i64> for UsageValue {
    fn from(n: i64) -> Self {
        Self::Int(n)
    }
}

impl From<i32> for UsageValue {
    fn from(n: i32) -> Self {
        // Lossless widening.
        Self::Int(i64::from(n))
    }
}

impl From<HashMap<String, UsageValue>> for UsageValue {
    fn from(map: HashMap<String, UsageValue>) -> Self {
        Self::Dict(map)
    }
}
202
/// Errors produced while combining usage dictionaries.
#[derive(Debug, Clone, PartialEq)]
pub enum UsageError {
    /// Recursion hit the configured limit; carries that limit.
    MaxDepthExceeded(usize),
    /// A key held values of incompatible kinds on the two sides.
    TypeMismatch {
        /// The offending key.
        key: String,
        /// Kind name of the left-hand value.
        left_type: String,
        /// Kind name of the right-hand value.
        right_type: String,
    },
}

impl std::fmt::Display for UsageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TypeMismatch {
                key,
                left_type,
                right_type,
            } => write!(
                f,
                "Unknown value types for key '{}': {} and {}. Only dict and int values are supported.",
                key, left_type, right_type
            ),
            Self::MaxDepthExceeded(limit) => {
                write!(f, "max_depth={} exceeded, unable to combine dicts", limit)
            }
        }
    }
}

impl std::error::Error for UsageError {}
238
239/// Perform an integer operation on nested JSON dictionaries.
240///
241/// This function recursively applies an operation to integer values in
242/// nested JSON objects. This matches the Python `_dict_int_op` function
243/// from `langchain_core.utils.usage`.
244///
245/// # Arguments
246///
247/// * `left` - The first JSON object.
248/// * `right` - The second JSON object.
249/// * `op` - The operation to apply (e.g., addition, subtraction).
250/// * `default` - The default value for missing keys.
251/// * `max_depth` - Maximum recursion depth (default: 100).
252///
253/// # Returns
254///
255/// A new JSON object with the operation applied, or an error if max depth exceeded.
256pub fn dict_int_op_json<F>(
257    left: &Value,
258    right: &Value,
259    op: F,
260    default: i64,
261    max_depth: usize,
262) -> Result<Value, UsageError>
263where
264    F: Fn(i64, i64) -> i64 + Copy,
265{
266    dict_int_op_json_impl(left, right, op, default, 0, max_depth)
267}
268
269fn dict_int_op_json_impl<F>(
270    left: &Value,
271    right: &Value,
272    op: F,
273    default: i64,
274    depth: usize,
275    max_depth: usize,
276) -> Result<Value, UsageError>
277where
278    F: Fn(i64, i64) -> i64 + Copy,
279{
280    if depth >= max_depth {
281        return Err(UsageError::MaxDepthExceeded(max_depth));
282    }
283
284    let empty_map = Map::new();
285    let left_obj = left.as_object().unwrap_or(&empty_map);
286    let right_obj = right.as_object().unwrap_or(&empty_map);
287
288    let all_keys: HashSet<_> = left_obj.keys().chain(right_obj.keys()).cloned().collect();
289
290    let mut combined = Map::new();
291
292    for k in all_keys {
293        let left_val = left_obj.get(&k);
294        let right_val = right_obj.get(&k);
295
296        match (left_val, right_val) {
297            // Both are integers
298            (Some(Value::Number(l)), Some(Value::Number(r))) if l.is_i64() && r.is_i64() => {
299                let l_int = l.as_i64().unwrap_or(default);
300                let r_int = r.as_i64().unwrap_or(default);
301                combined.insert(k, Value::Number(op(l_int, r_int).into()));
302            }
303            // Left is int, right is missing
304            (Some(Value::Number(l)), None) if l.is_i64() => {
305                let l_int = l.as_i64().unwrap_or(default);
306                combined.insert(k, Value::Number(op(l_int, default).into()));
307            }
308            // Right is int, left is missing
309            (None, Some(Value::Number(r))) if r.is_i64() => {
310                let r_int = r.as_i64().unwrap_or(default);
311                combined.insert(k, Value::Number(op(default, r_int).into()));
312            }
313            // Both are objects
314            (Some(Value::Object(_)), Some(Value::Object(_))) => {
315                let nested = dict_int_op_json_impl(
316                    left_val.unwrap(),
317                    right_val.unwrap(),
318                    op,
319                    default,
320                    depth + 1,
321                    max_depth,
322                )?;
323                combined.insert(k, nested);
324            }
325            // Left is object, right is missing
326            (Some(Value::Object(_)), None) => {
327                let nested = dict_int_op_json_impl(
328                    left_val.unwrap(),
329                    &Value::Object(Map::new()),
330                    op,
331                    default,
332                    depth + 1,
333                    max_depth,
334                )?;
335                combined.insert(k, nested);
336            }
337            // Right is object, left is missing
338            (None, Some(Value::Object(_))) => {
339                let nested = dict_int_op_json_impl(
340                    &Value::Object(Map::new()),
341                    right_val.unwrap(),
342                    op,
343                    default,
344                    depth + 1,
345                    max_depth,
346                )?;
347                combined.insert(k, nested);
348            }
349            // Neither present (shouldn't happen due to all_keys)
350            (None, None) => {}
351            // Type mismatch or unsupported types
352            (Some(l), Some(r)) => {
353                return Err(UsageError::TypeMismatch {
354                    key: k,
355                    left_type: json_type_name(l).to_string(),
356                    right_type: json_type_name(r).to_string(),
357                });
358            }
359            // One side has unsupported type
360            (Some(v), None) | (None, Some(v)) => {
361                // Just copy over non-int/non-object values
362                combined.insert(k, v.clone());
363            }
364        }
365    }
366
367    Ok(Value::Object(combined))
368}
369
370fn json_type_name(value: &Value) -> &'static str {
371    match value {
372        Value::Null => "null",
373        Value::Bool(_) => "bool",
374        Value::Number(_) => "number",
375        Value::String(_) => "string",
376        Value::Array(_) => "array",
377        Value::Object(_) => "object",
378    }
379}
380
381/// Add two JSON usage dictionaries together.
382///
383/// # Arguments
384///
385/// * `left` - The first JSON object.
386/// * `right` - The second JSON object.
387///
388/// # Returns
389///
390/// A new JSON object with values added together.
391pub fn dict_int_add_json(left: &Value, right: &Value) -> Result<Value, UsageError> {
392    dict_int_op_json(left, right, |a, b| a + b, 0, 100)
393}
394
395/// Subtract one JSON usage dictionary from another, with floor at 0.
396///
397/// Token counts cannot be negative so the actual operation is `max(left - right, 0)`.
398///
399/// # Arguments
400///
401/// * `left` - The first JSON object.
402/// * `right` - The JSON object to subtract.
403///
404/// # Returns
405///
406/// A new JSON object with values subtracted (floored at 0).
407pub fn dict_int_sub_floor_json(left: &Value, right: &Value) -> Result<Value, UsageError> {
408    dict_int_op_json(left, right, |a, b| (a - b).max(0), 0, 100)
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_dict_int_add() {
        let mut left = HashMap::new();
        left.insert("a".to_string(), UsageValue::Int(1));
        left.insert("b".to_string(), UsageValue::Int(2));

        let mut right = HashMap::new();
        right.insert("a".to_string(), UsageValue::Int(3));
        right.insert("c".to_string(), UsageValue::Int(4));

        let result = dict_int_add(&left, &right).unwrap();

        assert_eq!(result.get("a").unwrap().as_int(), Some(4));
        assert_eq!(result.get("b").unwrap().as_int(), Some(2));
        assert_eq!(result.get("c").unwrap().as_int(), Some(4));
    }

    #[test]
    fn test_dict_int_add_nested() {
        let mut inner_left = HashMap::new();
        inner_left.insert("x".to_string(), UsageValue::Int(1));

        let mut left = HashMap::new();
        left.insert("nested".to_string(), UsageValue::Dict(inner_left));

        let mut inner_right = HashMap::new();
        inner_right.insert("x".to_string(), UsageValue::Int(2));
        inner_right.insert("y".to_string(), UsageValue::Int(3));

        let mut right = HashMap::new();
        right.insert("nested".to_string(), UsageValue::Dict(inner_right));

        let result = dict_int_add(&left, &right).unwrap();

        let nested = result.get("nested").unwrap().as_dict().unwrap();
        assert_eq!(nested.get("x").unwrap().as_int(), Some(3));
        assert_eq!(nested.get("y").unwrap().as_int(), Some(3));
    }

    #[test]
    fn test_dict_int_sub() {
        let mut left = HashMap::new();
        left.insert("a".to_string(), UsageValue::Int(5));
        left.insert("b".to_string(), UsageValue::Int(3));

        let mut right = HashMap::new();
        right.insert("a".to_string(), UsageValue::Int(2));

        let result = dict_int_sub(&left, &right).unwrap();

        assert_eq!(result.get("a").unwrap().as_int(), Some(3));
        assert_eq!(result.get("b").unwrap().as_int(), Some(3));
    }

    #[test]
    fn test_dict_int_sub_right_only_key_is_negated() {
        let left = HashMap::new();
        let mut right = HashMap::new();
        right.insert("a".to_string(), UsageValue::Int(2));

        let result = dict_int_sub(&left, &right).unwrap();

        assert_eq!(result.get("a").unwrap().as_int(), Some(-2));
    }

    #[test]
    fn test_dict_int_add_type_mismatch() {
        let mut left = HashMap::new();
        left.insert("a".to_string(), UsageValue::Int(1));

        let mut right = HashMap::new();
        right.insert("a".to_string(), UsageValue::Dict(HashMap::new()));

        assert_eq!(
            dict_int_add(&left, &right),
            Err(UsageError::TypeMismatch {
                key: "a".to_string(),
                left_type: "int".to_string(),
                right_type: "dict".to_string(),
            })
        );
    }

    #[test]
    fn test_dict_int_op_zero_max_depth_always_fails() {
        let empty: HashMap<String, UsageValue> = HashMap::new();
        assert_eq!(
            dict_int_op(&empty, &empty, |a, b| a + b, 0, 0),
            Err(UsageError::MaxDepthExceeded(0))
        );
    }

    #[test]
    fn test_max_depth_exceeded() {
        fn create_nested(depth: usize) -> HashMap<String, UsageValue> {
            if depth == 0 {
                let mut m = HashMap::new();
                m.insert("value".to_string(), UsageValue::Int(1));
                m
            } else {
                let mut m = HashMap::new();
                m.insert(
                    "nested".to_string(),
                    UsageValue::Dict(create_nested(depth - 1)),
                );
                m
            }
        }

        let left = create_nested(150);
        let right = create_nested(150);

        let result = dict_int_op(&left, &right, |a, b| a + b, 0, 100);
        assert!(matches!(result, Err(UsageError::MaxDepthExceeded(_))));
    }

    #[test]
    fn test_dict_int_add_json() {
        let left = json!({
            "a": 1,
            "b": 2
        });

        let right = json!({
            "a": 3,
            "c": 4
        });

        let result = dict_int_add_json(&left, &right).unwrap();

        assert_eq!(result["a"], 4);
        assert_eq!(result["b"], 2);
        assert_eq!(result["c"], 4);
    }

    #[test]
    fn test_dict_int_add_json_nested() {
        let left = json!({
            "input_tokens": 5,
            "output_tokens": 0,
            "total_tokens": 5,
            "input_token_details": {
                "cache_read": 3
            }
        });

        let right = json!({
            "input_tokens": 0,
            "output_tokens": 10,
            "total_tokens": 10,
            "output_token_details": {
                "reasoning": 4
            }
        });

        let result = dict_int_add_json(&left, &right).unwrap();

        assert_eq!(result["input_tokens"], 5);
        assert_eq!(result["output_tokens"], 10);
        assert_eq!(result["total_tokens"], 15);
        assert_eq!(result["input_token_details"]["cache_read"], 3);
        assert_eq!(result["output_token_details"]["reasoning"], 4);
    }

    #[test]
    fn test_dict_int_add_json_type_mismatch() {
        let result = dict_int_add_json(&json!({"a": 1}), &json!({"a": {"b": 2}}));
        assert_eq!(
            result,
            Err(UsageError::TypeMismatch {
                key: "a".to_string(),
                left_type: "number".to_string(),
                right_type: "object".to_string(),
            })
        );
    }

    #[test]
    fn test_dict_int_add_json_copies_one_sided_non_int_values() {
        let left = json!({"model": "gpt", "a": 1});
        let right = json!({"a": 2});

        let result = dict_int_add_json(&left, &right).unwrap();

        assert_eq!(result, json!({"model": "gpt", "a": 3}));
    }

    #[test]
    fn test_dict_int_add_json_non_object_input_is_empty() {
        let result = dict_int_add_json(&Value::Null, &json!({"a": 2})).unwrap();
        assert_eq!(result, json!({"a": 2}));
    }

    #[test]
    fn test_dict_int_sub_floor_json() {
        let left = json!({
            "input_tokens": 5,
            "output_tokens": 10,
            "total_tokens": 15,
            "input_token_details": {
                "cache_read": 4
            }
        });

        let right = json!({
            "input_tokens": 3,
            "output_tokens": 8,
            "total_tokens": 11,
            "output_token_details": {
                "reasoning": 4
            }
        });

        let result = dict_int_sub_floor_json(&left, &right).unwrap();

        assert_eq!(result["input_tokens"], 2);
        assert_eq!(result["output_tokens"], 2);
        assert_eq!(result["total_tokens"], 4);
        assert_eq!(result["input_token_details"]["cache_read"], 4);
        // reasoning should be 0 because 0 - 4 = -4, floored to 0
        assert_eq!(result["output_token_details"]["reasoning"], 0);
    }
}