use std::fmt;
use tiktoken_rs::{CoreBPE, bpe_for_model};
use crate::config::CommitConfig;
/// Creates a [`TokenCounter`] from the API base URL, optional API key, and
/// analysis model held in `config`.
pub fn create_token_counter(config: &CommitConfig) -> TokenCounter {
    let api_key = config.api_key.as_deref();
    let base_url = &config.api_base_url;
    let model = &config.analysis_model;
    TokenCounter::new(base_url, api_key, model)
}
/// Counts tokens in a piece of text.
///
/// The async `count` method first asks a remote `count_tokens` endpoint and
/// falls back to `count_sync`, which uses a local tiktoken encoder when one is
/// available for the model and a length-based estimate otherwise.
pub struct TokenCounter {
/// HTTP client used to send the remote `count_tokens` request.
client: reqwest::Client,
/// Base URL to which `/messages/count_tokens` is appended.
api_base_url: String,
/// Key sent in the `x-api-key` header; when `None`, no remote request is made.
api_key: Option<String>,
/// Model name sent in the request body as `"model"`.
model: String,
/// Local BPE encoder for `model`; `None` when `bpe_for_model` returned an error.
tiktoken: Option<&'static CoreBPE>,
}
impl fmt::Debug for TokenCounter {
    /// Formats the counter showing only the model name and whether a local
    /// tiktoken encoder is present; the remaining fields (including the API
    /// key) are left out and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_tiktoken = self.tiktoken.is_some();
        let mut builder = f.debug_struct("TokenCounter");
        builder.field("model", &self.model);
        builder.field("has_tiktoken", &has_tiktoken);
        builder.finish_non_exhaustive()
    }
}
impl TokenCounter {
pub fn new(api_base_url: &str, api_key: Option<&str>, model: &str) -> Self {
Self {
client: reqwest::Client::new(),
api_base_url: api_base_url.to_string(),
api_key: api_key.map(String::from),
model: model.to_string(),
tiktoken: bpe_for_model(model).ok(),
}
}
pub async fn count(&self, text: &str) -> usize {
if let Some(count) = self.try_api_count(text).await {
return count;
}
self.count_sync(text)
}
pub fn count_sync(&self, text: &str) -> usize {
if let Some(encoder) = &self.tiktoken {
encoder.encode_with_special_tokens(text).len()
} else {
text.len() / 4
}
}
async fn try_api_count(&self, text: &str) -> Option<usize> {
let api_key = self.api_key.as_ref()?;
if self.api_base_url.contains("openai.com") {
return None;
}
let resp = self
.client
.post(format!("{}/messages/count_tokens", self.api_base_url))
.header("x-api-key", api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&serde_json::json!({
"model": self.model,
"messages": [{"role": "user", "content": text}]
}))
.send()
.await
.ok()?;
let body: serde_json::Value = resp.json().await.ok()?;
body["input_tokens"].as_u64().map(|n| n as usize)
}
}