Skip to main content

openclaw_providers/
usage.rs

1//! Usage tracking for providers.
2
3use std::collections::HashMap;
4use std::sync::RwLock;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use openclaw_core::types::TokenUsage;
8
9/// Usage tracker for monitoring token consumption.
10pub struct UsageTracker {
11    totals: RwLock<HashMap<String, ModelUsage>>,
12}
13
/// Accumulated usage counters for a single model.
///
/// Every counter is an [`AtomicU64`], which can be incremented through a
/// shared reference.
#[derive(Default, Debug)]
pub struct ModelUsage {
    /// Sum of input tokens recorded for this model.
    input_tokens: AtomicU64,
    /// Sum of output tokens recorded for this model.
    output_tokens: AtomicU64,
    /// How many requests were recorded for this model.
    request_count: AtomicU64,
}
21
22impl UsageTracker {
23    /// Create a new usage tracker.
24    #[must_use]
25    pub fn new() -> Self {
26        Self {
27            totals: RwLock::new(HashMap::new()),
28        }
29    }
30
31    /// Record token usage for a model.
32    pub fn record(&self, model: &str, usage: &TokenUsage) {
33        let mut totals = self.totals.write().unwrap();
34        let entry = totals.entry(model.to_string()).or_default();
35
36        entry
37            .input_tokens
38            .fetch_add(usage.input_tokens, Ordering::Relaxed);
39        entry
40            .output_tokens
41            .fetch_add(usage.output_tokens, Ordering::Relaxed);
42        entry.request_count.fetch_add(1, Ordering::Relaxed);
43    }
44
45    /// Get total usage for a model.
46    #[must_use]
47    pub fn get_usage(&self, model: &str) -> Option<TokenUsageSummary> {
48        let totals = self.totals.read().unwrap();
49        totals.get(model).map(|u| TokenUsageSummary {
50            input_tokens: u.input_tokens.load(Ordering::Relaxed),
51            output_tokens: u.output_tokens.load(Ordering::Relaxed),
52            request_count: u.request_count.load(Ordering::Relaxed),
53        })
54    }
55
56    /// Get total usage across all models.
57    #[must_use]
58    pub fn total_usage(&self) -> TokenUsageSummary {
59        let totals = self.totals.read().unwrap();
60        let mut summary = TokenUsageSummary::default();
61
62        for usage in totals.values() {
63            summary.input_tokens += usage.input_tokens.load(Ordering::Relaxed);
64            summary.output_tokens += usage.output_tokens.load(Ordering::Relaxed);
65            summary.request_count += usage.request_count.load(Ordering::Relaxed);
66        }
67
68        summary
69    }
70
71    /// Reset all usage statistics.
72    pub fn reset(&self) {
73        let mut totals = self.totals.write().unwrap();
74        totals.clear();
75    }
76}
77
78impl Default for UsageTracker {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
/// Summary of token usage.
///
/// A plain-value snapshot of usage counters: cheap to copy and directly
/// comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TokenUsageSummary {
    /// Total input tokens.
    pub input_tokens: u64,
    /// Total output tokens.
    pub output_tokens: u64,
    /// Total request count.
    pub request_count: u64,
}
94
95impl TokenUsageSummary {
96    /// Get total tokens (input + output).
97    #[must_use]
98    pub const fn total_tokens(&self) -> u64 {
99        self.input_tokens + self.output_tokens
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `TokenUsage` with the given token counts and no cache activity.
    fn usage(input_tokens: u64, output_tokens: u64) -> TokenUsage {
        TokenUsage {
            input_tokens,
            output_tokens,
            cache_read_tokens: None,
            cache_write_tokens: None,
        }
    }

    #[test]
    fn test_usage_tracking() {
        let tracker = UsageTracker::new();

        tracker.record("claude-3-5-sonnet", &usage(100, 50));
        tracker.record("claude-3-5-sonnet", &usage(200, 100));

        let summary = tracker.get_usage("claude-3-5-sonnet").unwrap();
        assert_eq!(summary.input_tokens, 300);
        assert_eq!(summary.output_tokens, 150);
        assert_eq!(summary.request_count, 2);
    }

    #[test]
    fn test_unknown_model_has_no_usage() {
        let tracker = UsageTracker::new();
        tracker.record("model1", &usage(1, 1));

        assert!(tracker.get_usage("model2").is_none());
    }

    #[test]
    fn test_total_usage() {
        let tracker = UsageTracker::new();

        tracker.record("model1", &usage(100, 50));
        tracker.record("model2", &usage(200, 100));

        let total = tracker.total_usage();
        assert_eq!(total.input_tokens, 300);
        assert_eq!(total.output_tokens, 150);
        assert_eq!(total.request_count, 2);
        assert_eq!(total.total_tokens(), 450);
    }

    #[test]
    fn test_total_usage_empty_tracker() {
        let total = UsageTracker::new().total_usage();

        assert_eq!(total.total_tokens(), 0);
        assert_eq!(total.request_count, 0);
    }

    #[test]
    fn test_reset_clears_all_models() {
        let tracker = UsageTracker::default();
        tracker.record("model1", &usage(100, 50));
        tracker.record("model2", &usage(200, 100));

        tracker.reset();

        assert!(tracker.get_usage("model1").is_none());
        assert!(tracker.get_usage("model2").is_none());
        assert_eq!(tracker.total_usage().request_count, 0);

        // Recording after a reset starts the counters from zero again.
        tracker.record("model1", &usage(5, 5));
        assert_eq!(tracker.get_usage("model1").unwrap().request_count, 1);
    }
}