1use crate::model_registry::get_model_info;
2
3use crate::consts::*;
4use crate::models::TokenUsage;
5
6pub async fn calculate_cost(model_id: &str, usage: &TokenUsage) -> Option<f64> {
7 if let Some(cost) = usage.cost {
8 return Some(cost);
9 }
10 let info = get_model_info(model_id).await?;
11 calculate_cost_prefetched(&info, usage)
12}
13
14pub fn calculate_cost_prefetched(
15 info: &crate::model_registry::ModelInfo,
16 usage: &TokenUsage,
17) -> Option<f64> {
18 let mut total_cost = 0.0;
19 let mut has_cost = false;
20
21 if let Some(input_cost) = info.input_cost_per_token {
22 total_cost += usage.prompt_tokens as f64 * input_cost;
23 has_cost = true;
24 }
25
26 if let Some(output_cost) = info.output_cost_per_token {
27 total_cost += usage.completion_tokens as f64 * output_cost;
28 has_cost = true;
29 }
30
31 if has_cost { Some(total_cost) } else { None }
32}
33
/// Accumulates the byte lengths of text fragments and converts the running
/// total into an approximate token count (one token per four bytes, rounded
/// up).
///
/// Fragments can be recorded directly with [`HeuristicCounter::add_str`] or by
/// using the counter as a [`std::fmt::Write`] sink, e.g. via `write!`.
pub struct HeuristicCounter {
    /// Total number of UTF-8 bytes recorded so far.
    total_bytes: usize,
}

impl Default for HeuristicCounter {
    /// Equivalent to [`HeuristicCounter::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl HeuristicCounter {
    /// Creates a counter with no bytes recorded.
    pub fn new() -> Self {
        Self { total_bytes: 0 }
    }

    /// Returns the estimated token count for everything recorded so far:
    /// the byte total divided by four, rounded up.
    ///
    /// The division happens in `usize` before narrowing, so totals above
    /// `u32::MAX` bytes are not truncated. An estimate that does not fit in
    /// `u32` saturates at `u32::MAX`.
    pub fn count(&self) -> u32 {
        let estimate = self.total_bytes.div_ceil(4);
        // The clamp guarantees the value fits, so this cast is lossless.
        estimate.min(u32::MAX as usize) as u32
    }

    /// Records the byte length of `s`.
    ///
    /// The running total saturates at `usize::MAX` instead of overflowing.
    pub fn add_str(&mut self, s: &str) {
        self.total_bytes = self.total_bytes.saturating_add(s.len());
    }
}

impl std::fmt::Write for HeuristicCounter {
    /// Records the byte length of `s` without storing the text; never fails.
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.add_str(s);
        Ok(())
    }
}
64
/// Estimates the token count of `text` as its UTF-8 byte length divided by
/// four, rounded up.
///
/// The division is performed in `usize` before narrowing so that very long
/// inputs are not truncated; an estimate that does not fit in `u32` saturates
/// at `u32::MAX`.
pub fn count_heuristic(text: &str) -> u32 {
    let estimate = text.len().div_ceil(4);
    // The clamp guarantees the value fits, so this cast is lossless.
    estimate.min(u32::MAX as usize) as u32
}
68
69pub fn count_tokens_for_messages(messages: &[&str]) -> u32 {
70 let mut counter = HeuristicCounter::new();
71 for message in messages {
72 counter.add_str(message);
73 }
74 counter.count()
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    // Inputs whose byte length is an exact multiple of four yield len / 4.
    #[test]
    fn test_count_heuristic() {
        let cases = [("abcd", 1), ("abcdefgh", 2)];
        for (text, expected_tokens) in cases {
            assert_eq!(count_heuristic(text), expected_tokens);
        }
    }

    // A cost already recorded on the usage must be returned as-is. The model
    // id is presumably unknown to the registry, so a lookup could not have
    // produced the value.
    #[tokio::test]
    async fn test_calculate_cost_priority() {
        let reported_cost = 1.234;
        let usage_with_cost = TokenUsage {
            prompt_tokens: 10,
            completion_tokens: 10,
            total_tokens: 20,
            cached_tokens: None,
            reasoning_tokens: None,
            cost: Some(reported_cost),
        };

        let resolved = calculate_cost("non-existent-model", &usage_with_cost).await;
        assert_eq!(resolved, Some(reported_cost));
    }
}
102
103pub const SYSTEM_TOKEN_COUNT: u32 =
104 ((DEFAULT_SYSTEM_PROMPT.len() + DIFF_MODE_INSTRUCTIONS.len()) as u32).div_ceil(4);
105
106pub const MAX_ALIGNMENT_TOKENS: u32 = {
107 let conv = ((ALIGNMENT_CONVERSATION_USER.len() + ALIGNMENT_CONVERSATION_ASSISTANT.len())
108 as u32)
109 .div_ceil(4);
110 let diff = ((ALIGNMENT_DIFF_USER.len() + ALIGNMENT_DIFF_ASSISTANT.len()) as u32).div_ceil(4);
111 let anchor_tokens = ((STATIC_CONTEXT_INTRO.len()
112 + STATIC_CONTEXT_ANCHOR.len()
113 + FLOATING_CONTEXT_INTRO.len()
114 + FLOATING_CONTEXT_ANCHOR.len()) as u32)
115 .div_ceil(4);
116
117 let base_max = if conv > diff { conv } else { diff };
118 base_max + anchor_tokens
119};