Skip to main content

rustic_ai/
usage.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
6#[derive(Debug, Clone, Default, Serialize, Deserialize)]
7pub struct RequestUsage {
8    pub input_tokens: u64,
9    pub output_tokens: u64,
10    pub cache_write_tokens: u64,
11    pub cache_read_tokens: u64,
12    pub input_audio_tokens: u64,
13    pub output_audio_tokens: u64,
14    pub details: HashMap<String, u64>,
15}
16
17impl RequestUsage {
18    pub fn total_tokens(&self) -> u64 {
19        self.input_tokens + self.output_tokens
20    }
21}
22
23#[derive(Debug, Clone, Default, Serialize, Deserialize)]
24pub struct RunUsage {
25    pub requests: u64,
26    pub tool_calls: u64,
27    pub input_tokens: u64,
28    pub output_tokens: u64,
29    pub cache_write_tokens: u64,
30    pub cache_read_tokens: u64,
31    pub input_audio_tokens: u64,
32    pub output_audio_tokens: u64,
33    pub details: HashMap<String, u64>,
34}
35
36impl RunUsage {
37    pub fn total_tokens(&self) -> u64 {
38        self.input_tokens + self.output_tokens
39    }
40
41    pub fn incr_request(&mut self, request: &RequestUsage) {
42        self.requests += 1;
43        self.input_tokens += request.input_tokens;
44        self.output_tokens += request.output_tokens;
45        self.cache_write_tokens += request.cache_write_tokens;
46        self.cache_read_tokens += request.cache_read_tokens;
47        self.input_audio_tokens += request.input_audio_tokens;
48        self.output_audio_tokens += request.output_audio_tokens;
49        for (k, v) in &request.details {
50            *self.details.entry(k.clone()).or_insert(0) += v;
51        }
52    }
53
54    pub fn incr_tool_call(&mut self) {
55        self.tool_calls += 1;
56    }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct UsageLimits {
61    pub request_limit: Option<u64>,
62    pub tool_calls_limit: Option<u64>,
63    pub input_tokens_limit: Option<u64>,
64    pub output_tokens_limit: Option<u64>,
65    pub total_tokens_limit: Option<u64>,
66}
67
68impl Default for UsageLimits {
69    fn default() -> Self {
70        Self {
71            request_limit: Some(50),
72            tool_calls_limit: None,
73            input_tokens_limit: None,
74            output_tokens_limit: None,
75            total_tokens_limit: None,
76        }
77    }
78}
79
80impl UsageLimits {
81    pub fn check_request(&self, current_requests: u64) -> Result<(), UsageError> {
82        if let Some(limit) = self.request_limit
83            && current_requests >= limit
84        {
85            return Err(UsageError::RequestLimitExceeded { limit });
86        }
87        Ok(())
88    }
89
90    pub fn check_tool_call(&self, current_calls: u64) -> Result<(), UsageError> {
91        if let Some(limit) = self.tool_calls_limit
92            && current_calls >= limit
93        {
94            return Err(UsageError::ToolCallsLimitExceeded { limit });
95        }
96        Ok(())
97    }
98
99    pub fn check_after_response(&self, usage: &RunUsage) -> Result<(), UsageError> {
100        if let Some(limit) = self.input_tokens_limit
101            && usage.input_tokens > limit
102        {
103            return Err(UsageError::InputTokensLimitExceeded { limit });
104        }
105        if let Some(limit) = self.output_tokens_limit
106            && usage.output_tokens > limit
107        {
108            return Err(UsageError::OutputTokensLimitExceeded { limit });
109        }
110        if let Some(limit) = self.total_tokens_limit
111            && usage.total_tokens() > limit
112        {
113            return Err(UsageError::TotalTokensLimitExceeded { limit });
114        }
115        Ok(())
116    }
117}
118
119#[derive(Debug, Error)]
120pub enum UsageError {
121    #[error("request limit exceeded (limit {limit})")]
122    RequestLimitExceeded { limit: u64 },
123    #[error("tool call limit exceeded (limit {limit})")]
124    ToolCallsLimitExceeded { limit: u64 },
125    #[error("input token limit exceeded (limit {limit})")]
126    InputTokensLimitExceeded { limit: u64 },
127    #[error("output token limit exceeded (limit {limit})")]
128    OutputTokensLimitExceeded { limit: u64 },
129    #[error("total token limit exceeded (limit {limit})")]
130    TotalTokensLimitExceeded { limit: u64 },
131}