aico/llm/tokens.rs

use crate::consts::*;
use crate::model_registry::get_model_info;
use crate::models::TokenUsage;

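/// Returns the cost of a request.
///
/// A cost already reported in `usage` takes precedence; otherwise the model's
/// per-token pricing is looked up in the model registry. Returns `None` when no
/// pricing information is available.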
pub async fn calculate_cost(model_id: &str, usage: &TokenUsage) -> Option<f64> {
    if let Some(cost) = usage.cost {
        return Some(cost);
    }
    let info = get_model_info(model_id).await?;
    calculate_cost_prefetched(&info, usage)
}

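/// Computes the cost from pre-fetched model pricing, avoiding a registry lookup.
///
/// Returns `Some` if at least one of the input or output per-token prices is
/// known, and `None` if the model has no pricing information.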
pub fn calculate_cost_prefetched(
    info: &crate::model_registry::ModelInfo,
    usage: &TokenUsage,
) -> Option<f64> {
    let mut total_cost = 0.0;
    let mut has_cost = false;

    if let Some(input_cost) = info.input_cost_per_token {
        total_cost += usage.prompt_tokens as f64 * input_cost;
        has_cost = true;
    }

    if let Some(output_cost) = info.output_cost_per_token {
        total_cost += usage.completion_tokens as f64 * output_cost;
        has_cost = true;
    }

    if has_cost { Some(total_cost) } else { None }
}

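/// Incremental token estimator using the rough heuristic of one token per four
/// bytes of UTF-8 text.
///
/// Text can be fed in with [`add_str`](Self::add_str) or through the
/// `std::fmt::Write` impl, so formatted output can be counted without building
/// an intermediate `String`.
///
/// Illustrative usage (a sketch; the exact import path depends on where this
/// module sits in the crate):
///
/// ```ignore
/// use std::fmt::Write;
///
/// let mut counter = HeuristicCounter::new();
/// counter.add_str("hello ");
/// write!(counter, "world {}", 42).unwrap();
/// // 14 bytes total, so the estimate is ceil(14 / 4) = 4 tokens.
/// assert_eq!(counter.count(), 4);
/// ```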
pub struct HeuristicCounter {
    total_bytes: usize,
}

impl Default for HeuristicCounter {
    fn default() -> Self {
        Self::new()
    }
}

impl HeuristicCounter {
    /// Creates an empty counter.
    pub fn new() -> Self {
        Self { total_bytes: 0 }
    }

    /// Estimated token count so far (~4 bytes per token, rounded up).
    pub fn count(&self) -> u32 {
        (self.total_bytes as u32).div_ceil(4)
    }

    /// Adds a string's byte length to the running total.
    pub fn add_str(&mut self, s: &str) {
        self.total_bytes += s.len();
    }
}

impl std::fmt::Write for HeuristicCounter {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.total_bytes += s.len();
        Ok(())
    }
}

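/// Estimates the token count of a single string (~4 bytes per token, rounded up).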
pub fn count_heuristic(text: &str) -> u32 {
    (text.len() as u32).div_ceil(4)
}

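/// Estimates the combined token count of a set of messages using the same
/// 4-bytes-per-token heuristic.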
pub fn count_tokens_for_messages(messages: &[&str]) -> u32 {
    let mut counter = HeuristicCounter::new();
    for message in messages {
        counter.add_str(message);
    }
    counter.count()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_heuristic() {
        assert_eq!(count_heuristic("abcd"), 1);
        assert_eq!(count_heuristic("abcdefgh"), 2);
    }

    #[tokio::test]
    async fn test_calculate_cost_priority() {
        let usage = TokenUsage {
            prompt_tokens: 10,
            completion_tokens: 10,
            total_tokens: 20,
            cached_tokens: None,
            reasoning_tokens: None,
            cost: Some(1.234),
        };

        let cost = calculate_cost("non-existent-model", &usage).await;
        assert_eq!(cost, Some(1.234));
    }
}

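/// Estimated token count of the default system prompt plus the diff-mode
/// instructions, using the 4-bytes-per-token heuristic.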
pub const SYSTEM_TOKEN_COUNT: u32 =
    ((DEFAULT_SYSTEM_PROMPT.len() + DIFF_MODE_INSTRUCTIONS.len()) as u32).div_ceil(4);

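/// Worst-case token estimate for the alignment examples: the larger of the
/// conversation and diff example pairs, plus the context anchor strings.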
pub const MAX_ALIGNMENT_TOKENS: u32 = {
    let conv = ((ALIGNMENT_CONVERSATION_USER.len() + ALIGNMENT_CONVERSATION_ASSISTANT.len())
        as u32)
        .div_ceil(4);
    let diff = ((ALIGNMENT_DIFF_USER.len() + ALIGNMENT_DIFF_ASSISTANT.len()) as u32).div_ceil(4);
    let anchor_tokens = ((STATIC_CONTEXT_INTRO.len()
        + STATIC_CONTEXT_ANCHOR.len()
        + FLOATING_CONTEXT_INTRO.len()
        + FLOATING_CONTEXT_ANCHOR.len()) as u32)
        .div_ceil(4);

    // `cmp::max` is not a `const fn`, so pick the larger of the two by hand.
    let base_max = if conv > diff { conv } else { diff };
    base_max + anchor_tokens
};