ai-lib-contact 1.0.0

AI-Protocol policy layer: cache, batch, routing, plugins, resilience, guardrails, tokens, telemetry
//! Token counter implementations.

use ai_lib_core::types::message::{ContentBlock, MessageContent};
use ai_lib_core::types::Message;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

pub trait TokenCounter: Send + Sync {
    fn count(&self, text: &str) -> usize;

    fn count_messages(&self, messages: &[Message]) -> usize {
        let mut total = 0;
        for message in messages {
            total += 1;
            match &message.content {
                MessageContent::Text(text) => {
                    total += self.count(text);
                }
                MessageContent::Blocks(blocks) => {
                    for block in blocks {
                        match block {
                            ContentBlock::Text { text } => {
                                total += self.count(text);
                            }
                            ContentBlock::Image { .. } => {
                                total += 85;
                            }
                            ContentBlock::Audio { .. } => {
                                total += 100;
                            }
                            ContentBlock::Document { source } => {
                                total += self.count(&source.data);
                            }
                            ContentBlock::ToolUse { input, .. } => {
                                total +=
                                    self.count(&serde_json::to_string(input).unwrap_or_default());
                            }
                            ContentBlock::ToolResult { content, .. } => {
                                total +=
                                    self.count(&serde_json::to_string(content).unwrap_or_default());
                            }
                        }
                    }
                }
            }
        }
        total + messages.len() * 3
    }

    fn truncate_to_limit(&self, text: &str, max_tokens: usize, suffix: &str) -> String {
        let current = self.count(text);
        if current <= max_tokens {
            return text.to_string();
        }
        let suffix_tokens = if suffix.is_empty() {
            0
        } else {
            self.count(suffix)
        };
        let target = max_tokens.saturating_sub(suffix_tokens);
        if target == 0 {
            return suffix.to_string();
        }
        let chars_per_token = text.len() as f64 / current as f64;
        let mut truncated: String = text
            .chars()
            .take((target as f64 * chars_per_token) as usize)
            .collect();
        while self.count(&truncated) > target && !truncated.is_empty() {
            truncated = truncated
                .chars()
                .take((truncated.len() as f64 * 0.9) as usize)
                .collect();
        }
        format!("{}{}", truncated, suffix)
    }
}

#[derive(Debug, Clone)]
pub struct CharacterEstimator {
    chars_per_token: f64,
}
impl CharacterEstimator {
    pub fn new() -> Self {
        Self::with_ratio(4.0)
    }
    pub fn with_ratio(r: f64) -> Self {
        Self { chars_per_token: r }
    }
}
impl Default for CharacterEstimator {
    fn default() -> Self {
        Self::new()
    }
}
impl TokenCounter for CharacterEstimator {
    fn count(&self, text: &str) -> usize {
        (text.len() as f64 / self.chars_per_token).ceil() as usize
    }
}

#[derive(Debug, Clone)]
pub struct AnthropicEstimator {
    chars_per_token: f64,
}
impl AnthropicEstimator {
    pub fn new() -> Self {
        Self {
            chars_per_token: 3.5,
        }
    }
}
impl Default for AnthropicEstimator {
    fn default() -> Self {
        Self::new()
    }
}
impl TokenCounter for AnthropicEstimator {
    fn count(&self, text: &str) -> usize {
        let base = (text.len() as f64 / self.chars_per_token).ceil() as usize;
        let ws = text.chars().filter(|c| c.is_whitespace()).count();
        base + (ws as f64 * 0.1) as usize
    }
}

pub struct CachingCounter {
    inner: Box<dyn TokenCounter>,
    cache: Arc<RwLock<HashMap<String, usize>>>,
    max_size: usize,
}
impl CachingCounter {
    pub fn new(inner: Box<dyn TokenCounter>, max_size: usize) -> Self {
        Self {
            inner,
            cache: Arc::new(RwLock::new(HashMap::new())),
            max_size,
        }
    }
    pub fn clear_cache(&self) {
        self.cache.write().unwrap().clear();
    }
}
impl TokenCounter for CachingCounter {
    fn count(&self, text: &str) -> usize {
        {
            let c = self.cache.read().unwrap();
            if let Some(&n) = c.get(text) {
                return n;
            }
        }
        let n = self.inner.count(text);
        {
            let mut c = self.cache.write().unwrap();
            if c.len() < self.max_size {
                c.insert(text.to_string(), n);
            }
        }
        n
    }
}

static COUNTERS: once_cell::sync::Lazy<RwLock<HashMap<String, Arc<dyn TokenCounter>>>> =
    once_cell::sync::Lazy::new(|| RwLock::new(HashMap::new()));

pub fn get_token_counter(model: &str) -> Arc<dyn TokenCounter> {
    let ml = model.to_lowercase();
    {
        let c = COUNTERS.read().unwrap();
        if let Some(x) = c.get(&ml) {
            return x.clone();
        }
    }
    let counter: Arc<dyn TokenCounter> = if ml.contains("gpt") || ml.contains("o1") {
        Arc::new(CharacterEstimator::new())
    } else if ml.contains("claude") {
        Arc::new(AnthropicEstimator::new())
    } else {
        Arc::new(CharacterEstimator::new())
    };
    {
        let mut c = COUNTERS.write().unwrap();
        c.insert(ml, counter.clone());
    }
    counter
}