// agent_chain_core/callbacks/usage.rs

//! Callback Handler that tracks AIMessage.usage_metadata.
//!
//! This module provides a callback handler for tracking token usage
//! across chat model calls, following the Python LangChain UsageMetadataCallbackHandler pattern.

6use std::collections::HashMap;
7use std::fmt;
8use std::ops::{Deref, DerefMut};
9use std::sync::{Arc, Mutex};
10
11use uuid::Uuid;
12
13use crate::messages::{BaseMessage, UsageMetadata};
14use crate::outputs::ChatResult;
15
16use super::base::{
17    BaseCallbackHandler, CallbackManagerMixin, ChainManagerMixin, LLMManagerMixin,
18    RetrieverManagerMixin, RunManagerMixin, ToolManagerMixin,
19};
20
21/// Add two usage metadata objects together.
22///
23/// This function combines the token counts from two usage metadata objects,
24/// returning a new object with the summed values.
25pub fn add_usage(left: &UsageMetadata, right: &UsageMetadata) -> UsageMetadata {
26    left.add(right)
27}
28
/// Callback Handler that tracks AIMessage.usage_metadata.
///
/// This handler collects token usage metadata from chat model responses,
/// aggregating the usage by model name. It is thread-safe and can be used
/// across multiple concurrent LLM calls.
///
/// Cloning is shallow: every clone shares the same underlying map through an
/// `Arc`, so usage recorded via one clone is visible to all others.
///
/// # Example
///
/// ```ignore
/// use agent_chain_core::callbacks::UsageMetadataCallbackHandler;
/// use std::sync::Arc;
///
/// let handler = UsageMetadataCallbackHandler::new();
///
/// // Use with a callback manager
/// let mut manager = CallbackManager::new();
/// manager.add_handler(Arc::new(handler.clone()), true);
///
/// // After LLM calls complete
/// let usage = handler.usage_metadata();
/// for (model, metadata) in usage.iter() {
///     println!("{}: {} tokens", model, metadata.total_tokens);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct UsageMetadataCallbackHandler {
    /// Aggregated usage metadata keyed by model name. Wrapped in
    /// `Arc<Mutex<_>>` so clones share state and updates are thread-safe.
    usage_metadata: Arc<Mutex<HashMap<String, UsageMetadata>>>,
}
58
59impl Default for UsageMetadataCallbackHandler {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl UsageMetadataCallbackHandler {
66    /// Create a new UsageMetadataCallbackHandler.
67    pub fn new() -> Self {
68        Self {
69            usage_metadata: Arc::new(Mutex::new(HashMap::new())),
70        }
71    }
72
73    /// Get the collected usage metadata.
74    ///
75    /// Returns a clone of the current usage metadata map, keyed by model name.
76    pub fn usage_metadata(&self) -> HashMap<String, UsageMetadata> {
77        self.usage_metadata.lock().unwrap().clone()
78    }
79}
80
81impl fmt::Display for UsageMetadataCallbackHandler {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        write!(f, "{:?}", self.usage_metadata.lock().unwrap())
84    }
85}
86
87impl LLMManagerMixin for UsageMetadataCallbackHandler {
88    fn on_llm_end(&mut self, response: &ChatResult, _run_id: Uuid, _parent_run_id: Option<Uuid>) {
89        // Extract usage metadata from the first generation's message
90        let first_generation = response.generations.first();
91
92        let (usage_metadata, model_name) = match first_generation {
93            Some(generation) => {
94                // Try to get usage from the AIMessage
95                let usage = match &generation.message {
96                    BaseMessage::AI(ai_msg) => ai_msg.usage_metadata().cloned(),
97                    _ => None,
98                };
99
100                // Try to get model name from llm_output or response_metadata
101                let model = response
102                    .llm_output
103                    .as_ref()
104                    .and_then(|output| output.get("model"))
105                    .and_then(|v| v.as_str())
106                    .map(|s| s.to_string())
107                    .or_else(|| {
108                        generation
109                            .message
110                            .response_metadata()
111                            .and_then(|meta| meta.get("model"))
112                            .and_then(|v| v.as_str())
113                            .map(|s| s.to_string())
114                    });
115
116                (usage, model)
117            }
118            None => (None, None),
119        };
120
121        // Update shared state behind lock
122        if let (Some(usage), Some(model)) = (usage_metadata, model_name) {
123            let mut guard = self.usage_metadata.lock().unwrap();
124            if let Some(existing) = guard.get(&model) {
125                let combined = add_usage(existing, &usage);
126                guard.insert(model, combined);
127            } else {
128                guard.insert(model, usage);
129            }
130        }
131    }
132}
133
// The remaining callback mixins keep their default no-op implementations:
// this handler only needs to observe `on_llm_end`.
impl ChainManagerMixin for UsageMetadataCallbackHandler {}
impl ToolManagerMixin for UsageMetadataCallbackHandler {}
impl RetrieverManagerMixin for UsageMetadataCallbackHandler {}
impl CallbackManagerMixin for UsageMetadataCallbackHandler {}
impl RunManagerMixin for UsageMetadataCallbackHandler {}
139
impl BaseCallbackHandler for UsageMetadataCallbackHandler {
    /// Stable identifier used to reference this handler type.
    fn name(&self) -> &str {
        "UsageMetadataCallbackHandler"
    }
}
145
/// Guard type for the `get_usage_metadata_callback` function.
///
/// This guard provides access to the underlying `UsageMetadataCallbackHandler`
/// and can be used with a callback manager to track usage metadata.
///
/// The guard simply owns its handler; dropping it requires no explicit
/// cleanup.
///
/// # Example
///
/// ```ignore
/// use agent_chain_core::callbacks::get_usage_metadata_callback;
///
/// let callback_guard = get_usage_metadata_callback();
/// // Use callback_guard.handler() with your callback manager
/// // The handler can be cloned and added to managers
///
/// let usage = callback_guard.usage_metadata();
/// for (model, metadata) in usage.iter() {
///     println!("{}: {} tokens", model, metadata.total_tokens);
/// }
/// ```
pub struct UsageMetadataCallbackGuard {
    // Owned handler, exposed through `Deref`/`DerefMut` and `handler()`.
    handler: UsageMetadataCallbackHandler,
}
170
171impl UsageMetadataCallbackGuard {
172    /// Create a new usage metadata callback guard.
173    fn new() -> Self {
174        Self {
175            handler: UsageMetadataCallbackHandler::new(),
176        }
177    }
178
179    /// Get a reference to the underlying handler.
180    pub fn handler(&self) -> &UsageMetadataCallbackHandler {
181        &self.handler
182    }
183
184    /// Get a mutable reference to the underlying handler.
185    pub fn handler_mut(&mut self) -> &mut UsageMetadataCallbackHandler {
186        &mut self.handler
187    }
188
189    /// Get the collected usage metadata.
190    pub fn usage_metadata(&self) -> HashMap<String, UsageMetadata> {
191        self.handler.usage_metadata()
192    }
193
194    /// Get an Arc-wrapped handler suitable for use with callback managers.
195    pub fn as_arc_handler(&self) -> Arc<dyn BaseCallbackHandler> {
196        Arc::new(self.handler.clone()) as Arc<dyn BaseCallbackHandler>
197    }
198}
199
// `Deref`/`DerefMut` let callers invoke handler methods directly on the
// guard (e.g. `guard.usage_metadata()`).
impl Deref for UsageMetadataCallbackGuard {
    type Target = UsageMetadataCallbackHandler;

    fn deref(&self) -> &Self::Target {
        &self.handler
    }
}

impl DerefMut for UsageMetadataCallbackGuard {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.handler
    }
}
213
/// Get a usage metadata callback handler.
///
/// This function creates a `UsageMetadataCallbackGuard` that provides access
/// to a `UsageMetadataCallbackHandler` for tracking token usage across chat
/// model calls.
///
/// The returned guard implements `Deref` and `DerefMut` to the underlying
/// handler, making it easy to use. Unlike Python's context manager, the guard
/// performs no registration or teardown itself — add the handler to a
/// callback manager explicitly (e.g. via `as_arc_handler`).
///
/// # Example
///
/// ```ignore
/// use agent_chain_core::callbacks::{get_usage_metadata_callback, CallbackManager};
/// use std::sync::Arc;
///
/// let callback = get_usage_metadata_callback();
///
/// // Add to a callback manager
/// let mut manager = CallbackManager::new();
/// manager.add_handler(callback.as_arc_handler(), true);
///
/// // After LLM calls complete
/// let usage = callback.usage_metadata();
/// for (model, metadata) in usage.iter() {
///     println!("{}: {} tokens", model, metadata.total_tokens);
/// }
/// ```
///
/// This is the Rust equivalent of Python's `get_usage_metadata_callback()` context manager.
pub fn get_usage_metadata_callback() -> UsageMetadataCallbackGuard {
    UsageMetadataCallbackGuard::new()
}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::AIMessage;
    use crate::outputs::ChatGeneration;
    use serde_json::json;

    /// Build a `ChatResult` whose single generation carries the given usage
    /// metadata and whose `llm_output` names the given model.
    ///
    /// Takes `i64` token counts directly — `UsageMetadata::new` consumes
    /// `i64`, so no widening casts are needed.
    fn create_chat_result_with_usage(
        content: &str,
        model: &str,
        input_tokens: i64,
        output_tokens: i64,
    ) -> ChatResult {
        let ai_msg = AIMessage::new(content)
            .with_usage_metadata(UsageMetadata::new(input_tokens, output_tokens));

        let generation = ChatGeneration::new(ai_msg.into());

        let mut llm_output = HashMap::new();
        llm_output.insert("model".to_string(), json!(model));

        ChatResult {
            generations: vec![generation],
            llm_output: Some(llm_output),
        }
    }

    #[test]
    fn test_usage_handler_creation() {
        // A fresh handler starts empty and reports its stable name.
        let handler = UsageMetadataCallbackHandler::new();
        assert!(handler.usage_metadata().is_empty());
        assert_eq!(handler.name(), "UsageMetadataCallbackHandler");
    }

    #[test]
    fn test_add_usage() {
        // Token counts are summed field-wise.
        let usage1 = UsageMetadata::new(10, 20);
        let usage2 = UsageMetadata::new(5, 15);
        let combined = add_usage(&usage1, &usage2);

        assert_eq!(combined.input_tokens, 15);
        assert_eq!(combined.output_tokens, 35);
        assert_eq!(combined.total_tokens, 50);
    }

    #[test]
    fn test_on_llm_end_collects_usage() {
        // A single call records usage under its model name.
        let mut handler = UsageMetadataCallbackHandler::new();

        let result = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);

        handler.on_llm_end(&result, Uuid::new_v4(), None);

        let usage = handler.usage_metadata();
        assert_eq!(usage.len(), 1);

        let gpt4_usage = usage.get("gpt-4").unwrap();
        assert_eq!(gpt4_usage.input_tokens, 10);
        assert_eq!(gpt4_usage.output_tokens, 20);
        assert_eq!(gpt4_usage.total_tokens, 30);
    }

    #[test]
    fn test_on_llm_end_accumulates_usage() {
        // Repeated calls for the same model sum into one entry.
        let mut handler = UsageMetadataCallbackHandler::new();

        let result1 = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);
        let result2 = create_chat_result_with_usage("World", "gpt-4", 5, 15);

        handler.on_llm_end(&result1, Uuid::new_v4(), None);
        handler.on_llm_end(&result2, Uuid::new_v4(), None);

        let usage = handler.usage_metadata();
        assert_eq!(usage.len(), 1);

        let gpt4_usage = usage.get("gpt-4").unwrap();
        assert_eq!(gpt4_usage.input_tokens, 15);
        assert_eq!(gpt4_usage.output_tokens, 35);
        assert_eq!(gpt4_usage.total_tokens, 50);
    }

    #[test]
    fn test_on_llm_end_multiple_models() {
        // Different models are tracked under separate keys.
        let mut handler = UsageMetadataCallbackHandler::new();

        let result1 = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);
        let result2 = create_chat_result_with_usage("Hello", "claude-3", 8, 25);

        handler.on_llm_end(&result1, Uuid::new_v4(), None);
        handler.on_llm_end(&result2, Uuid::new_v4(), None);

        let usage = handler.usage_metadata();
        assert_eq!(usage.len(), 2);

        let gpt4_usage = usage.get("gpt-4").unwrap();
        assert_eq!(gpt4_usage.total_tokens, 30);

        let claude_usage = usage.get("claude-3").unwrap();
        assert_eq!(claude_usage.total_tokens, 33);
    }

    #[test]
    fn test_clone_shares_state() {
        // Clones share the same Arc-backed map, so usage recorded through
        // one handler is visible through the other.
        let mut handler1 = UsageMetadataCallbackHandler::new();
        let handler2 = handler1.clone();

        let result = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);

        handler1.on_llm_end(&result, Uuid::new_v4(), None);

        assert_eq!(handler1.usage_metadata(), handler2.usage_metadata());
    }
}