use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use async_trait::async_trait;
use cognis_core::error::Result;
use cognis_core::language_models::chat_model::{
BaseChatModel, ChatStream, ModelProfile, ToolChoice,
};
use cognis_core::messages::Message;
use cognis_core::outputs::ChatResult;
use cognis_core::tools::ToolSchema;
use cognis_core::utils::tokens::estimate_token_count;
/// Per-1K-token price sheet for a single model.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    pub input_cost_per_1k: f64,
    pub output_cost_per_1k: f64,
    pub model_name: String,
}

impl ModelPricing {
    /// Creates a pricing entry for `model_name` with the given USD rates
    /// per 1K input and output tokens.
    pub fn new(
        model_name: impl Into<String>,
        input_cost_per_1k: f64,
        output_cost_per_1k: f64,
    ) -> Self {
        Self {
            model_name: model_name.into(),
            input_cost_per_1k,
            output_cost_per_1k,
        }
    }

    /// Estimated dollar cost of a call that consumed `input_tokens` and
    /// produced `output_tokens`.
    pub fn calculate_cost(&self, input_tokens: usize, output_tokens: usize) -> f64 {
        let input_cost = (input_tokens as f64 / 1000.0) * self.input_cost_per_1k;
        let output_cost = (output_tokens as f64 / 1000.0) * self.output_cost_per_1k;
        input_cost + output_cost
    }
}
/// Lookup table from model name to its pricing entry.
#[derive(Debug, Clone)]
pub struct PricingRegistry {
    prices: HashMap<String, ModelPricing>,
}

impl PricingRegistry {
    /// Creates an empty registry; `PricingRegistry::default()` returns one
    /// preloaded with common models.
    pub fn new() -> Self {
        Self {
            prices: HashMap::new(),
        }
    }

    /// Returns the pricing registered under the exact model name, if any.
    pub fn get_pricing(&self, model: &str) -> Option<&ModelPricing> {
        self.prices.get(model)
    }

    /// Inserts `pricing`, keyed by its `model_name`; replaces any existing
    /// entry under the same name.
    pub fn register(&mut self, pricing: ModelPricing) {
        let key = pricing.model_name.clone();
        self.prices.insert(key, pricing);
    }
}
impl Default for PricingRegistry {
    /// Registry preloaded with list prices (USD per 1K tokens) for a set of
    /// common OpenAI, Anthropic, and Google models.
    fn default() -> Self {
        // (model name, input $/1K, output $/1K)
        const DEFAULT_PRICES: [(&str, f64, f64); 8] = [
            ("gpt-4o", 2.50, 10.00),
            ("gpt-4o-mini", 0.15, 0.60),
            ("gpt-4-turbo", 10.00, 30.00),
            ("claude-3.5-sonnet", 3.00, 15.00),
            ("claude-3-opus", 15.00, 75.00),
            ("claude-3-haiku", 0.25, 1.25),
            ("gemini-1.5-pro", 3.50, 10.50),
            ("gemini-1.5-flash", 0.075, 0.30),
        ];
        let mut registry = Self::new();
        for (name, input, output) in DEFAULT_PRICES {
            registry.register(ModelPricing::new(name, input, output));
        }
        registry
    }
}
/// Snapshot of token consumption, with an optional cost estimate.
#[derive(Debug, Clone, PartialEq)]
pub struct TokenUsage {
    pub input_tokens: usize,
    pub output_tokens: usize,
    pub total_tokens: usize,
    pub estimated_cost: Option<f64>,
}

impl TokenUsage {
    /// Builds a usage record; `total_tokens` is derived as the sum of the
    /// input and output counts.
    pub fn new(input_tokens: usize, output_tokens: usize, estimated_cost: Option<f64>) -> Self {
        let total_tokens = input_tokens + output_tokens;
        Self {
            input_tokens,
            output_tokens,
            total_tokens,
            estimated_cost,
        }
    }

    /// Usage record with all counts at zero and no cost estimate.
    pub fn zero() -> Self {
        Self::new(0, 0, None)
    }
}
/// Wrapper around a chat model that tracks estimated token usage (and,
/// when pricing is configured, estimated cost) across `_generate` calls.
/// Counts are estimates from text length via `estimate_token_count`, not
/// provider-reported usage.
pub struct TokenCountingModel {
/// The wrapped model that actually performs generation.
inner: Box<dyn BaseChatModel>,
/// Total input tokens recorded since construction or the last reset.
cumulative_input: AtomicUsize,
/// Total output tokens recorded since construction or the last reset.
cumulative_output: AtomicUsize,
/// Input tokens of the most recent `_generate` call only.
last_input: AtomicUsize,
/// Output tokens of the most recent `_generate` call only.
last_output: AtomicUsize,
/// Optional price sheet used to attach a cost estimate to reported usage.
pricing: Option<ModelPricing>,
/// Guards multi-counter updates so the `last_*`/`cumulative_*` pairs are
/// written as a unit.
_lock: Mutex<()>,
}
impl TokenCountingModel {
    /// Wraps `inner` with all counters starting at zero.
    ///
    /// When `pricing` is provided, reported usage carries an estimated
    /// dollar cost; otherwise `estimated_cost` is `None`.
    pub fn new(inner: Box<dyn BaseChatModel>, pricing: Option<ModelPricing>) -> Self {
        Self {
            inner,
            cumulative_input: AtomicUsize::new(0),
            cumulative_output: AtomicUsize::new(0),
            last_input: AtomicUsize::new(0),
            last_output: AtomicUsize::new(0),
            pricing,
            _lock: Mutex::new(()),
        }
    }

    /// Starts a builder for fluent configuration.
    pub fn builder(inner: Box<dyn BaseChatModel>) -> TokenCountingModelBuilder {
        TokenCountingModelBuilder {
            inner,
            pricing: None,
        }
    }

    /// Acquires the usage lock, recovering from poisoning: the protected
    /// data is plain counters, so a poisoned lock is still safe to use.
    fn lock(&self) -> std::sync::MutexGuard<'_, ()> {
        self._lock.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Builds a `TokenUsage` from a count pair, attaching a cost estimate
    /// when pricing is configured.
    fn usage_from(&self, input: usize, output: usize) -> TokenUsage {
        let cost = self
            .pricing
            .as_ref()
            .map(|p| p.calculate_cost(input, output));
        TokenUsage::new(input, output, cost)
    }

    /// Total usage accumulated since construction (or the last `reset_usage`).
    pub fn get_usage(&self) -> TokenUsage {
        // Hold the lock so a concurrent `record_usage` cannot yield a torn
        // snapshot (input from one call paired with output from another).
        // Previously only the writer locked, which made the lock pointless.
        let _guard = self.lock();
        let input = self.cumulative_input.load(Ordering::SeqCst);
        let output = self.cumulative_output.load(Ordering::SeqCst);
        self.usage_from(input, output)
    }

    /// Usage of the most recent `_generate` call only.
    pub fn get_last_usage(&self) -> TokenUsage {
        // Locked for the same torn-snapshot reason as `get_usage`.
        let _guard = self.lock();
        let input = self.last_input.load(Ordering::SeqCst);
        let output = self.last_output.load(Ordering::SeqCst);
        self.usage_from(input, output)
    }

    /// Resets all counters to zero.
    pub fn reset_usage(&self) {
        // Locked so a reset cannot interleave with a concurrent record,
        // which could otherwise zero one counter of a pair but not the other.
        let _guard = self.lock();
        self.cumulative_input.store(0, Ordering::SeqCst);
        self.cumulative_output.store(0, Ordering::SeqCst);
        self.last_input.store(0, Ordering::SeqCst);
        self.last_output.store(0, Ordering::SeqCst);
    }

    /// Records the token counts of one completed call: overwrites the
    /// `last_*` counters and adds to the cumulative totals.
    fn record_usage(&self, input_tokens: usize, output_tokens: usize) {
        let _guard = self.lock();
        self.last_input.store(input_tokens, Ordering::SeqCst);
        self.last_output.store(output_tokens, Ordering::SeqCst);
        self.cumulative_input
            .fetch_add(input_tokens, Ordering::SeqCst);
        self.cumulative_output
            .fetch_add(output_tokens, Ordering::SeqCst);
    }

    /// Estimated token count summed over the text of all input messages.
    fn estimate_input_tokens(&self, messages: &[Message]) -> usize {
        messages
            .iter()
            .map(|m| estimate_token_count(&m.content().text()))
            .sum()
    }

    /// Estimated token count summed over all generations in the result.
    fn estimate_output_tokens(&self, result: &ChatResult) -> usize {
        result
            .generations
            .iter()
            .map(|g| estimate_token_count(&g.text))
            .sum()
    }
}
/// Builder for `TokenCountingModel`; obtained via `TokenCountingModel::builder`.
pub struct TokenCountingModelBuilder {
/// The model the built wrapper will delegate to.
inner: Box<dyn BaseChatModel>,
/// Optional price sheet for cost estimation; `None` disables cost reporting.
pricing: Option<ModelPricing>,
}
impl TokenCountingModelBuilder {
    /// Sets the price sheet so that usage reported by the built model
    /// includes an estimated cost.
    pub fn with_pricing(self, pricing: ModelPricing) -> Self {
        Self {
            pricing: Some(pricing),
            ..self
        }
    }

    /// Consumes the builder and produces the counting wrapper.
    pub fn build(self) -> TokenCountingModel {
        let Self { inner, pricing } = self;
        TokenCountingModel::new(inner, pricing)
    }
}
#[async_trait]
impl BaseChatModel for TokenCountingModel {
/// Delegates generation to the wrapped model, then records estimated
/// input/output token counts for the call. On error from the inner
/// model, nothing is recorded (the `?` returns before `record_usage`).
async fn _generate(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatResult> {
let input_tokens = self.estimate_input_tokens(messages);
let result = self.inner._generate(messages, stop).await?;
let output_tokens = self.estimate_output_tokens(&result);
self.record_usage(input_tokens, output_tokens);
Ok(result)
}
/// Reports the wrapped model's type string unchanged.
fn llm_type(&self) -> &str {
self.inner.llm_type()
}
// NOTE(review): streaming bypasses token counting entirely — the stream
// is forwarded without measuring chunks, so streamed calls never show up
// in get_usage()/get_last_usage(). Confirm whether that is intentional.
async fn _stream(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatStream> {
self.inner._stream(messages, stop).await
}
/// Delegates tool binding to the inner model. Note the returned model is
/// the inner model's bound variant, not a counting wrapper.
fn bind_tools(
&self,
tools: &[ToolSchema],
tool_choice: Option<ToolChoice>,
) -> Result<Box<dyn BaseChatModel>> {
self.inner.bind_tools(tools, tool_choice)
}
/// Delegates profile reporting to the inner model.
fn profile(&self) -> ModelProfile {
self.inner.profile()
}
/// Uses the inner model's tokenizer-aware count, not this wrapper's
/// text-length estimate.
fn get_num_tokens_from_messages(&self, messages: &[Message]) -> usize {
self.inner.get_num_tokens_from_messages(messages)
}
}
/// Unit tests for token counting, cost estimation, and the pricing
/// registry, driven through an in-crate fake chat model.
#[cfg(test)]
mod tests {
use super::*;
use cognis_core::language_models::fake::FakeListChatModel;
use cognis_core::messages::{HumanMessage, Message};
// Helper: wraps text in a human message.
fn human(text: &str) -> Message {
Message::Human(HumanMessage::new(text))
}
// Helper: fake model that replays the given canned responses in order.
fn make_fake(responses: Vec<&str>) -> FakeListChatModel {
FakeListChatModel::new(responses.into_iter().map(String::from).collect())
}
// One call records nonzero counts and a consistent total.
#[tokio::test]
async fn test_token_counting_single_call() {
let model = TokenCountingModel::new(Box::new(make_fake(vec!["Hello there"])), None);
let msgs = vec![human("Hi")];
let result = model._generate(&msgs, None).await.unwrap();
assert_eq!(result.generations[0].text, "Hello there");
let usage = model.get_usage();
assert!(usage.input_tokens > 0);
assert!(usage.output_tokens > 0);
assert_eq!(usage.total_tokens, usage.input_tokens + usage.output_tokens);
}
// Usage accumulates monotonically across successive calls.
#[tokio::test]
async fn test_cumulative_tracking() {
let model = TokenCountingModel::new(
Box::new(make_fake(vec!["Response one", "Response two"])),
None,
);
let msgs = vec![human("First question")];
model._generate(&msgs, None).await.unwrap();
let usage1 = model.get_usage();
let msgs2 = vec![human("Second question")];
model._generate(&msgs2, None).await.unwrap();
let usage2 = model.get_usage();
assert!(usage2.input_tokens > usage1.input_tokens);
assert!(usage2.output_tokens > usage1.output_tokens);
assert!(usage2.total_tokens > usage1.total_tokens);
}
// With pricing configured, usage carries a positive cost estimate.
#[tokio::test]
async fn test_cost_estimation_with_pricing() {
let pricing = ModelPricing::new("test-model", 2.0, 4.0);
let model = TokenCountingModel::new(
Box::new(make_fake(vec!["Hello world response"])),
Some(pricing),
);
let msgs = vec![human("Test input")];
model._generate(&msgs, None).await.unwrap();
let usage = model.get_usage();
assert!(usage.estimated_cost.is_some());
assert!(usage.estimated_cost.unwrap() > 0.0);
}
// reset_usage zeroes every counter.
#[tokio::test]
async fn test_reset_usage() {
let model = TokenCountingModel::new(Box::new(make_fake(vec!["Response"])), None);
let msgs = vec![human("Hello")];
model._generate(&msgs, None).await.unwrap();
assert!(model.get_usage().total_tokens > 0);
model.reset_usage();
let usage = model.get_usage();
assert_eq!(usage.input_tokens, 0);
assert_eq!(usage.output_tokens, 0);
assert_eq!(usage.total_tokens, 0);
}
// get_last_usage reflects only the most recent call, not the running total.
#[tokio::test]
async fn test_last_usage_tracking() {
let model = TokenCountingModel::new(
Box::new(make_fake(vec![
"Short",
"A much longer response than the first one",
])),
None,
);
let msgs = vec![human("Q1")];
model._generate(&msgs, None).await.unwrap();
let last1 = model.get_last_usage();
let msgs2 = vec![human("Q2")];
model._generate(&msgs2, None).await.unwrap();
let last2 = model.get_last_usage();
assert!(last2.output_tokens > last1.output_tokens);
assert_ne!(last1.total_tokens, last2.total_tokens);
}
// Without pricing, estimated_cost stays None.
#[tokio::test]
async fn test_no_pricing_returns_none_cost() {
let model = TokenCountingModel::new(Box::new(make_fake(vec!["Response"])), None);
let msgs = vec![human("Hello")];
model._generate(&msgs, None).await.unwrap();
let usage = model.get_usage();
assert!(usage.estimated_cost.is_none());
}
// Default registry resolves a known model to sensible positive rates.
#[test]
fn test_pricing_registry_lookup() {
let registry = PricingRegistry::default();
let pricing = registry.get_pricing("gpt-4o");
assert!(pricing.is_some());
let p = pricing.unwrap();
assert_eq!(p.model_name, "gpt-4o");
assert!(p.input_cost_per_1k > 0.0);
assert!(p.output_cost_per_1k > 0.0);
}
// Custom entries can be registered alongside the defaults.
#[test]
fn test_pricing_registry_custom_registration() {
let mut registry = PricingRegistry::default();
let custom = ModelPricing::new("my-custom-model", 1.0, 2.0);
registry.register(custom);
let pricing = registry.get_pricing("my-custom-model");
assert!(pricing.is_some());
assert_eq!(pricing.unwrap().input_cost_per_1k, 1.0);
assert_eq!(pricing.unwrap().output_cost_per_1k, 2.0);
}
// Every model preloaded by Default is present in the registry.
#[test]
fn test_default_pricing_known_models() {
let registry = PricingRegistry::default();
let known = [
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"claude-3.5-sonnet",
"claude-3-opus",
"claude-3-haiku",
"gemini-1.5-pro",
"gemini-1.5-flash",
];
for model in &known {
assert!(
registry.get_pricing(model).is_some(),
"Missing pricing for {}",
model
);
}
}
// The builder wires pricing through to the constructed model.
#[tokio::test]
async fn test_builder_pattern() {
let pricing = ModelPricing::new("test-model", 1.0, 2.0);
let model = TokenCountingModel::builder(Box::new(make_fake(vec!["Built response"])))
.with_pricing(pricing)
.build();
let msgs = vec![human("Test")];
model._generate(&msgs, None).await.unwrap();
let usage = model.get_usage();
assert!(usage.total_tokens > 0);
assert!(usage.estimated_cost.is_some());
}
// llm_type is forwarded from the wrapped model.
#[tokio::test]
async fn test_delegates_llm_type() {
let model = TokenCountingModel::new(Box::new(make_fake(vec!["Response"])), None);
assert_eq!(model.llm_type(), "fake_list_chat_model");
}
// Unknown model names yield None rather than a fallback price.
#[test]
fn test_pricing_registry_unknown_model_returns_none() {
let registry = PricingRegistry::default();
assert!(registry.get_pricing("nonexistent-model").is_none());
}
// Spot-check the cost formula: 1000 in @ $2/1K + 500 out @ $4/1K = $4.
#[test]
fn test_model_pricing_calculate_cost() {
let pricing = ModelPricing::new("test", 2.0, 4.0);
let cost = pricing.calculate_cost(1000, 500);
assert!((cost - 4.0).abs() < f64::EPSILON);
}
}