agent_chain_core/callbacks/
usage.rs1use std::collections::HashMap;
7use std::fmt;
8use std::ops::{Deref, DerefMut};
9use std::sync::{Arc, Mutex};
10
11use uuid::Uuid;
12
13use crate::messages::{BaseMessage, UsageMetadata};
14use crate::outputs::ChatResult;
15
16use super::base::{
17 BaseCallbackHandler, CallbackManagerMixin, ChainManagerMixin, LLMManagerMixin,
18 RetrieverManagerMixin, RunManagerMixin, ToolManagerMixin,
19};
20
21pub fn add_usage(left: &UsageMetadata, right: &UsageMetadata) -> UsageMetadata {
26 left.add(right)
27}
28
29#[derive(Debug, Clone)]
54pub struct UsageMetadataCallbackHandler {
55 usage_metadata: Arc<Mutex<HashMap<String, UsageMetadata>>>,
57}
58
59impl Default for UsageMetadataCallbackHandler {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl UsageMetadataCallbackHandler {
66 pub fn new() -> Self {
68 Self {
69 usage_metadata: Arc::new(Mutex::new(HashMap::new())),
70 }
71 }
72
73 pub fn usage_metadata(&self) -> HashMap<String, UsageMetadata> {
77 self.usage_metadata.lock().unwrap().clone()
78 }
79}
80
81impl fmt::Display for UsageMetadataCallbackHandler {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 write!(f, "{:?}", self.usage_metadata.lock().unwrap())
84 }
85}
86
87impl LLMManagerMixin for UsageMetadataCallbackHandler {
88 fn on_llm_end(&mut self, response: &ChatResult, _run_id: Uuid, _parent_run_id: Option<Uuid>) {
89 let first_generation = response.generations.first();
91
92 let (usage_metadata, model_name) = match first_generation {
93 Some(generation) => {
94 let usage = match &generation.message {
96 BaseMessage::AI(ai_msg) => ai_msg.usage_metadata().cloned(),
97 _ => None,
98 };
99
100 let model = response
102 .llm_output
103 .as_ref()
104 .and_then(|output| output.get("model"))
105 .and_then(|v| v.as_str())
106 .map(|s| s.to_string())
107 .or_else(|| {
108 generation
109 .message
110 .response_metadata()
111 .and_then(|meta| meta.get("model"))
112 .and_then(|v| v.as_str())
113 .map(|s| s.to_string())
114 });
115
116 (usage, model)
117 }
118 None => (None, None),
119 };
120
121 if let (Some(usage), Some(model)) = (usage_metadata, model_name) {
123 let mut guard = self.usage_metadata.lock().unwrap();
124 if let Some(existing) = guard.get(&model) {
125 let combined = add_usage(existing, &usage);
126 guard.insert(model, combined);
127 } else {
128 guard.insert(model, usage);
129 }
130 }
131 }
132}
133
134impl ChainManagerMixin for UsageMetadataCallbackHandler {}
135impl ToolManagerMixin for UsageMetadataCallbackHandler {}
136impl RetrieverManagerMixin for UsageMetadataCallbackHandler {}
137impl CallbackManagerMixin for UsageMetadataCallbackHandler {}
138impl RunManagerMixin for UsageMetadataCallbackHandler {}
139
140impl BaseCallbackHandler for UsageMetadataCallbackHandler {
141 fn name(&self) -> &str {
142 "UsageMetadataCallbackHandler"
143 }
144}
145
146pub struct UsageMetadataCallbackGuard {
168 handler: UsageMetadataCallbackHandler,
169}
170
171impl UsageMetadataCallbackGuard {
172 fn new() -> Self {
174 Self {
175 handler: UsageMetadataCallbackHandler::new(),
176 }
177 }
178
179 pub fn handler(&self) -> &UsageMetadataCallbackHandler {
181 &self.handler
182 }
183
184 pub fn handler_mut(&mut self) -> &mut UsageMetadataCallbackHandler {
186 &mut self.handler
187 }
188
189 pub fn usage_metadata(&self) -> HashMap<String, UsageMetadata> {
191 self.handler.usage_metadata()
192 }
193
194 pub fn as_arc_handler(&self) -> Arc<dyn BaseCallbackHandler> {
196 Arc::new(self.handler.clone()) as Arc<dyn BaseCallbackHandler>
197 }
198}
199
200impl Deref for UsageMetadataCallbackGuard {
201 type Target = UsageMetadataCallbackHandler;
202
203 fn deref(&self) -> &Self::Target {
204 &self.handler
205 }
206}
207
208impl DerefMut for UsageMetadataCallbackGuard {
209 fn deref_mut(&mut self) -> &mut Self::Target {
210 &mut self.handler
211 }
212}
213
214pub fn get_usage_metadata_callback() -> UsageMetadataCallbackGuard {
244 UsageMetadataCallbackGuard::new()
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::AIMessage;
    use crate::outputs::ChatGeneration;
    use serde_json::json;

    /// Builds a one-generation `ChatResult` whose AI message carries the given
    /// token counts and whose `llm_output` records `model` under `"model"`.
    fn create_chat_result_with_usage(
        content: &str,
        model: &str,
        input_tokens: u32,
        output_tokens: u32,
    ) -> ChatResult {
        // `i64::from` is a lossless widening, unlike an `as` cast.
        let ai_msg = AIMessage::new(content).with_usage_metadata(UsageMetadata::new(
            i64::from(input_tokens),
            i64::from(output_tokens),
        ));

        let generation = ChatGeneration::new(ai_msg.into());

        let mut llm_output = HashMap::new();
        llm_output.insert("model".to_string(), json!(model));

        ChatResult {
            generations: vec![generation],
            llm_output: Some(llm_output),
        }
    }

    #[test]
    fn test_usage_handler_creation() {
        let handler = UsageMetadataCallbackHandler::new();
        assert!(handler.usage_metadata().is_empty());
        assert_eq!(handler.name(), "UsageMetadataCallbackHandler");
    }

    #[test]
    fn test_add_usage() {
        let usage1 = UsageMetadata::new(10, 20);
        let usage2 = UsageMetadata::new(5, 15);
        let combined = add_usage(&usage1, &usage2);

        assert_eq!(combined.input_tokens, 15);
        assert_eq!(combined.output_tokens, 35);
        assert_eq!(combined.total_tokens, 50);
    }

    #[test]
    fn test_on_llm_end_collects_usage() {
        let mut handler = UsageMetadataCallbackHandler::new();

        let result = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);

        handler.on_llm_end(&result, Uuid::new_v4(), None);

        let usage = handler.usage_metadata();
        assert_eq!(usage.len(), 1);

        let gpt4_usage = usage.get("gpt-4").unwrap();
        assert_eq!(gpt4_usage.input_tokens, 10);
        assert_eq!(gpt4_usage.output_tokens, 20);
        assert_eq!(gpt4_usage.total_tokens, 30);
    }

    #[test]
    fn test_on_llm_end_accumulates_usage() {
        let mut handler = UsageMetadataCallbackHandler::new();

        let result1 = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);
        let result2 = create_chat_result_with_usage("World", "gpt-4", 5, 15);

        handler.on_llm_end(&result1, Uuid::new_v4(), None);
        handler.on_llm_end(&result2, Uuid::new_v4(), None);

        let usage = handler.usage_metadata();
        assert_eq!(usage.len(), 1);

        let gpt4_usage = usage.get("gpt-4").unwrap();
        assert_eq!(gpt4_usage.input_tokens, 15);
        assert_eq!(gpt4_usage.output_tokens, 35);
        assert_eq!(gpt4_usage.total_tokens, 50);
    }

    #[test]
    fn test_on_llm_end_multiple_models() {
        let mut handler = UsageMetadataCallbackHandler::new();

        let result1 = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);
        let result2 = create_chat_result_with_usage("Hello", "claude-3", 8, 25);

        handler.on_llm_end(&result1, Uuid::new_v4(), None);
        handler.on_llm_end(&result2, Uuid::new_v4(), None);

        let usage = handler.usage_metadata();
        assert_eq!(usage.len(), 2);

        let gpt4_usage = usage.get("gpt-4").unwrap();
        assert_eq!(gpt4_usage.total_tokens, 30);

        let claude_usage = usage.get("claude-3").unwrap();
        assert_eq!(claude_usage.total_tokens, 33);
    }

    #[test]
    fn test_on_llm_end_ignores_empty_generations() {
        let mut handler = UsageMetadataCallbackHandler::new();

        let result = ChatResult {
            generations: vec![],
            llm_output: None,
        };

        handler.on_llm_end(&result, Uuid::new_v4(), None);

        assert!(handler.usage_metadata().is_empty());
    }

    #[test]
    fn test_on_llm_end_ignores_message_without_usage() {
        let mut handler = UsageMetadataCallbackHandler::new();

        let mut llm_output = HashMap::new();
        llm_output.insert("model".to_string(), json!("gpt-4"));
        let result = ChatResult {
            generations: vec![ChatGeneration::new(AIMessage::new("Hello").into())],
            llm_output: Some(llm_output),
        };

        handler.on_llm_end(&result, Uuid::new_v4(), None);

        assert!(handler.usage_metadata().is_empty());
    }

    #[test]
    fn test_on_llm_end_ignores_missing_model_name() {
        let mut handler = UsageMetadataCallbackHandler::new();

        // Usage is present, but no model name is available anywhere.
        let mut result = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);
        result.llm_output = None;

        handler.on_llm_end(&result, Uuid::new_v4(), None);

        assert!(handler.usage_metadata().is_empty());
    }

    #[test]
    fn test_clone_shares_state() {
        let mut handler1 = UsageMetadataCallbackHandler::new();
        let handler2 = handler1.clone();

        let result = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);

        handler1.on_llm_end(&result, Uuid::new_v4(), None);

        assert_eq!(handler1.usage_metadata(), handler2.usage_metadata());
    }

    #[test]
    fn test_guard_sees_usage_recorded_by_handler_clone() {
        let guard = get_usage_metadata_callback();
        let mut handler = guard.handler().clone();

        let result = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);
        handler.on_llm_end(&result, Uuid::new_v4(), None);

        let usage = guard.usage_metadata();
        assert_eq!(usage.get("gpt-4").map(|u| u.total_tokens), Some(30));
    }

    #[test]
    fn test_display_lists_recorded_models() {
        let mut handler = UsageMetadataCallbackHandler::new();

        let result = create_chat_result_with_usage("Hello", "gpt-4", 10, 20);
        handler.on_llm_end(&result, Uuid::new_v4(), None);

        assert!(handler.to_string().contains("gpt-4"));
    }
}