1use std::fmt;
9
10use tiktoken_rs::{CoreBPE, get_bpe_from_model};
11
12use crate::config::CommitConfig;
13
14pub fn create_token_counter(config: &CommitConfig) -> TokenCounter {
16 TokenCounter::new(&config.api_base_url, config.api_key.as_deref(), &config.analysis_model)
17}
18
19pub struct TokenCounter {
21 client: reqwest::Client,
22 api_base_url: String,
23 api_key: Option<String>,
24 model: String,
25 tiktoken: Option<CoreBPE>,
26}
27
28impl fmt::Debug for TokenCounter {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 f.debug_struct("TokenCounter")
31 .field("model", &self.model)
32 .field("has_tiktoken", &self.tiktoken.is_some())
33 .finish_non_exhaustive()
34 }
35}
36
37impl TokenCounter {
38 pub fn new(api_base_url: &str, api_key: Option<&str>, model: &str) -> Self {
40 Self {
41 client: reqwest::Client::new(),
42 api_base_url: api_base_url.to_string(),
43 api_key: api_key.map(String::from),
44 model: model.to_string(),
45 tiktoken: get_bpe_from_model(model).ok(),
46 }
47 }
48
49 pub async fn count(&self, text: &str) -> usize {
53 if let Some(count) = self.try_api_count(text).await {
55 return count;
56 }
57 self.count_sync(text)
59 }
60
61 pub fn count_sync(&self, text: &str) -> usize {
63 if let Some(encoder) = &self.tiktoken {
64 encoder.encode_with_special_tokens(text).len()
65 } else {
66 text.len() / 4
67 }
68 }
69
70 async fn try_api_count(&self, text: &str) -> Option<usize> {
74 let api_key = self.api_key.as_ref()?;
75
76 if self.api_base_url.contains("openai.com") {
78 return None;
79 }
80
81 let resp = self
84 .client
85 .post(format!("{}/messages/count_tokens", self.api_base_url))
86 .header("x-api-key", api_key)
87 .header("anthropic-version", "2023-06-01")
88 .header("content-type", "application/json")
89 .json(&serde_json::json!({
90 "model": self.model,
91 "messages": [{"role": "user", "content": text}]
92 }))
93 .send()
94 .await
95 .ok()?;
96
97 let body: serde_json::Value = resp.json().await.ok()?;
98 body["input_tokens"].as_u64().map(|n| n as usize)
99 }
100}