use crate::language_models::TokenUsage;
use std::collections::HashMap;
#[derive(Debug, Clone, Default)]
pub struct UsageMetadata {
usage_by_model: HashMap<String, TokenUsage>,
}
impl UsageMetadata {
pub fn new() -> Self {
Self {
usage_by_model: HashMap::new(),
}
}
pub fn add_usage(&mut self, model: &str, usage: &TokenUsage) {
let entry = self
.usage_by_model
.entry(model.to_string())
.or_insert_with(|| TokenUsage::new(0, 0));
entry.add(usage);
}
pub fn get_total_usage(&self, model: &str) -> Option<&TokenUsage> {
self.usage_by_model.get(model)
}
pub fn get_all_usage(&self) -> &HashMap<String, TokenUsage> {
&self.usage_by_model
}
pub fn clear(&mut self) {
self.usage_by_model.clear();
}
pub fn merge(&mut self, other: UsageMetadata) {
for (model, usage) in other.usage_by_model {
self.add_usage(&model, &usage);
}
}
}
pub trait UsageCallback: Send + Sync {
fn on_usage(&mut self, model: &str, usage: &TokenUsage);
}
#[derive(Debug, Default)]
pub struct CollectingUsageCallback {
metadata: UsageMetadata,
}
impl CollectingUsageCallback {
pub fn new() -> Self {
Self {
metadata: UsageMetadata::new(),
}
}
pub fn get_metadata(&self) -> &UsageMetadata {
&self.metadata
}
pub fn into_metadata(self) -> UsageMetadata {
self.metadata
}
}
impl UsageCallback for CollectingUsageCallback {
fn on_usage(&mut self, model: &str, usage: &TokenUsage) {
self.metadata.add_usage(model, usage);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_usage_metadata_add() {
let mut usage = UsageMetadata::new();
let token_usage = TokenUsage::new(10, 20);
usage.add_usage("gpt-4o-mini", &token_usage);
let total = usage.get_total_usage("gpt-4o-mini").unwrap();
assert_eq!(total.prompt_tokens, 10);
assert_eq!(total.completion_tokens, 20);
assert_eq!(total.total_tokens, 30);
}
#[test]
fn test_usage_metadata_accumulate() {
let mut usage = UsageMetadata::new();
usage.add_usage("gpt-4o-mini", &TokenUsage::new(10, 20));
usage.add_usage("gpt-4o-mini", &TokenUsage::new(5, 10));
let total = usage.get_total_usage("gpt-4o-mini").unwrap();
assert_eq!(total.prompt_tokens, 15);
assert_eq!(total.completion_tokens, 30);
assert_eq!(total.total_tokens, 45);
}
#[test]
fn test_usage_metadata_merge() {
let mut usage1 = UsageMetadata::new();
usage1.add_usage("gpt-4o-mini", &TokenUsage::new(10, 20));
let mut usage2 = UsageMetadata::new();
usage2.add_usage("gpt-4o-mini", &TokenUsage::new(5, 10));
usage2.add_usage("claude-3", &TokenUsage::new(15, 25));
usage1.merge(usage2);
let gpt_usage = usage1.get_total_usage("gpt-4o-mini").unwrap();
assert_eq!(gpt_usage.total_tokens, 45);
let claude_usage = usage1.get_total_usage("claude-3").unwrap();
assert_eq!(claude_usage.total_tokens, 40);
}
#[test]
fn test_collecting_callback() {
let mut callback = CollectingUsageCallback::new();
callback.on_usage("gpt-4o-mini", &TokenUsage::new(10, 20));
let metadata = callback.get_metadata();
let usage = metadata.get_total_usage("gpt-4o-mini").unwrap();
assert_eq!(usage.total_tokens, 30);
}
}