llmkit-core 0.1.0

Core traits, types, and errors for llmkit-rs — no I/O, runtime-agnostic
Documentation
//! Token usage and per-model cost estimation.

use serde::{Deserialize, Serialize};

/// Normalised token counts for one request.
///
/// Prompt and completion tokens are kept separate because [`ModelPricing`]
/// prices input and output tokens at different rates. Usages can be combined
/// with `+` and `+=`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Input (prompt) tokens.
    pub prompt: u32,
    /// Output (completion) tokens.
    pub completion: u32,
}

impl TokenUsage {
    /// Construct from prompt and completion counts.
    #[must_use]
    pub const fn new(prompt: u32, completion: u32) -> Self {
        Self { prompt, completion }
    }

    /// Total tokens (prompt + completion).
    ///
    /// Saturates at `u32::MAX` rather than overflowing. A plain `+` would
    /// panic in debug builds and silently wrap to a small number in release
    /// builds whenever the two counts together exceed `u32::MAX`.
    #[must_use]
    pub const fn total(&self) -> u32 {
        self.prompt.saturating_add(self.completion)
    }
}

/// Field-wise sum of two usages, saturating each count at `u32::MAX`.
///
/// Saturation keeps an oversized running total (e.g. when summing usage over
/// many requests) pinned at the maximum. A plain `+` would instead panic in
/// debug builds or wrap around to a small value in release builds.
impl std::ops::Add for TokenUsage {
    type Output = TokenUsage;
    fn add(self, rhs: TokenUsage) -> TokenUsage {
        TokenUsage {
            prompt: self.prompt.saturating_add(rhs.prompt),
            completion: self.completion.saturating_add(rhs.completion),
        }
    }
}

/// In-place field-wise sum; saturates exactly like the `Add` impl.
impl std::ops::AddAssign for TokenUsage {
    fn add_assign(&mut self, rhs: TokenUsage) {
        self.prompt = self.prompt.saturating_add(rhs.prompt);
        self.completion = self.completion.saturating_add(rhs.completion);
    }
}

/// USD pricing per 1M tokens.
///
/// Input and output tokens are priced independently; use
/// [`ModelPricing::cost_for`] to apply this pricing to a [`TokenUsage`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
    /// USD per 1M input (prompt) tokens.
    pub input_per_mtok: f64,
    /// USD per 1M output (completion) tokens.
    pub output_per_mtok: f64,
}

impl ModelPricing {
    /// Price `usage` under this pricing, keeping the input and output
    /// components separate.
    ///
    /// Each component is `(tokens / 1_000_000) * rate`, i.e. the rates are
    /// read as USD per million tokens.
    pub fn cost_for(&self, usage: TokenUsage) -> CostEstimate {
        const TOKENS_PER_MTOK: f64 = 1_000_000.0;

        // `u32 -> f64` is lossless, so `f64::from` is equivalent to `as f64`.
        let input_mtok = f64::from(usage.prompt) / TOKENS_PER_MTOK;
        let output_mtok = f64::from(usage.completion) / TOKENS_PER_MTOK;

        CostEstimate {
            input_usd: input_mtok * self.input_per_mtok,
            output_usd: output_mtok * self.output_per_mtok,
        }
    }
}

/// Computed cost breakdown in USD.
///
/// Produced by [`ModelPricing::cost_for`]. Input and output costs are stored
/// separately; [`CostEstimate::total_usd`] sums them.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
pub struct CostEstimate {
    /// Cost of input (prompt) tokens, in USD.
    pub input_usd: f64,
    /// Cost of output (completion) tokens, in USD.
    pub output_usd: f64,
}

impl CostEstimate {
    /// Overall cost in USD: the input component plus the output component.
    pub fn total_usd(&self) -> f64 {
        let Self { input_usd, output_usd } = *self;
        input_usd + output_usd
    }
}

/// Component-wise sum of two cost estimates.
impl std::ops::Add for CostEstimate {
    type Output = CostEstimate;
    fn add(self, rhs: CostEstimate) -> CostEstimate {
        CostEstimate {
            input_usd: self.input_usd + rhs.input_usd,
            output_usd: self.output_usd + rhs.output_usd,
        }
    }
}

/// In-place component-wise sum, so costs can be accumulated with `+=` just
/// like `TokenUsage`.
impl std::ops::AddAssign for CostEstimate {
    fn add_assign(&mut self, rhs: CostEstimate) {
        self.input_usd += rhs.input_usd;
        self.output_usd += rhs.output_usd;
    }
}

/// Per-model pricing lookup. Unknown models (e.g. local Ollama) return `None`.
pub mod pricing {
    use super::ModelPricing;

    const fn p(input: f64, output: f64) -> ModelPricing {
        ModelPricing { input_per_mtok: input, output_per_mtok: output }
    }

    /// Known model-slug prefixes and their USD-per-1M-token pricing.
    ///
    /// Lookup returns the first matching entry, so a more specific prefix must
    /// come before any shorter prefix it extends (e.g. `gpt-4o-mini` before
    /// `gpt-4o`, `claude-opus-4-1` before `claude-opus-4`).
    const PRICES: &[(&str, ModelPricing)] = &[
        ("gpt-4o-mini", p(0.15, 0.60)),
        ("gpt-4o", p(2.50, 10.0)),
        ("o1-mini", p(3.0, 12.0)),
        ("o1", p(15.0, 60.0)),
        ("text-embedding-3-small", p(0.02, 0.0)),
        ("text-embedding-3-large", p(0.13, 0.0)),
        // Claude Opus 4 and 4.1 are billed at $15 / $75; only Opus 4.5 and later
        // dropped to $5 / $25. Catch their aliases (`claude-opus-4-0`,
        // `claude-opus-4-1`), dated snapshots (`claude-opus-4-20250514`,
        // `claude-opus-4-1-20250805`) and the Vertex form (`claude-opus-4@…`)
        // before the generic `claude-opus-4` family prefix.
        // NOTE: `claude-opus-4-1` would also catch a hypothetical `claude-opus-4-1x`.
        ("claude-opus-4-0", p(15.0, 75.0)),
        ("claude-opus-4-1", p(15.0, 75.0)),
        ("claude-opus-4-2025", p(15.0, 75.0)),
        ("claude-opus-4@", p(15.0, 75.0)),
        ("claude-opus-4", p(5.0, 25.0)),
        ("claude-sonnet-4", p(3.0, 15.0)),
        ("claude-3-7-sonnet", p(3.0, 15.0)),
        ("claude-3-5-sonnet", p(3.0, 15.0)),
        ("claude-haiku-4", p(1.0, 5.0)),
        ("claude-3-5-haiku", p(0.80, 4.0)),
        ("claude-3-haiku", p(0.25, 1.25)),
        ("claude-3-opus", p(15.0, 75.0)),
    ];

    /// Strip an AWS Bedrock model-ID prefix, either `anthropic.` or
    /// `<region>.anthropic.` (cross-region inference profiles such as
    /// `us.anthropic.claude-sonnet-4-20250514-v1:0`). Returns `m` unchanged
    /// when neither form is present.
    fn strip_bedrock_prefix(m: &str) -> &str {
        if let Some(rest) = m.strip_prefix("anthropic.") {
            return rest;
        }
        m.split_once('.')
            .filter(|(region, _)| {
                !region.is_empty()
                    && region.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
            })
            .and_then(|(_, rest)| rest.strip_prefix("anthropic."))
            .unwrap_or(m)
    }

    /// Look up pricing for a model slug by prefix. `None` if unknown.
    ///
    /// The slug is trimmed and ASCII-lower-cased first. Any Bedrock
    /// `anthropic.` / `<region>.anthropic.` prefix is then ignored, so Bedrock
    /// model IDs resolve to the same pricing as first-party slugs.
    pub fn pricing_for(model: &str) -> Option<ModelPricing> {
        let m = model.trim().to_ascii_lowercase();
        let m = strip_bedrock_prefix(&m);
        PRICES
            .iter()
            .find(|(prefix, _)| m.starts_with(*prefix))
            .map(|&(_, pricing)| pricing)
    }

    /// Look up pricing, falling back to `default` when unknown.
    pub fn pricing_for_or(model: &str, default: ModelPricing) -> ModelPricing {
        pricing_for(model).unwrap_or(default)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ModelPricing` in assertions.
    fn mp(input: f64, output: f64) -> ModelPricing {
        ModelPricing { input_per_mtok: input, output_per_mtok: output }
    }

    #[test]
    fn usage_arithmetic() {
        let mut c = TokenUsage::new(10, 5);
        c += TokenUsage::new(3, 7);
        assert_eq!(c, TokenUsage::new(13, 12));
        assert_eq!(c.total(), 25);
        assert_eq!(TokenUsage::new(1, 2) + TokenUsage::new(3, 4), TokenUsage::new(4, 6));
    }

    #[test]
    fn cost_computation() {
        let cost = mp(5.0, 25.0).cost_for(TokenUsage::new(1_000_000, 1_000_000));
        assert!((cost.total_usd() - 30.0).abs() < 1e-9);
    }

    #[test]
    fn cost_addition() {
        let a = CostEstimate { input_usd: 1.5, output_usd: 2.0 };
        let b = CostEstimate { input_usd: 0.25, output_usd: 0.5 };
        let sum = a + b;
        assert_eq!(sum, CostEstimate { input_usd: 1.75, output_usd: 2.5 });
        assert!((sum.total_usd() - 4.25).abs() < 1e-12);
    }

    #[test]
    fn pricing_lookup() {
        assert!(pricing::pricing_for("gpt-4o-mini").is_some());
        assert!(pricing::pricing_for("claude-opus-4-8").is_some());
        assert!(pricing::pricing_for("llama3.1").is_none());
    }

    #[test]
    fn pricing_prefix_order_and_normalisation() {
        // Longer prefixes must win over the shorter ones they extend.
        assert_eq!(pricing::pricing_for("gpt-4o-mini-2024-07-18"), Some(mp(0.15, 0.60)));
        assert_eq!(pricing::pricing_for("gpt-4o-2024-08-06"), Some(mp(2.50, 10.0)));
        assert_eq!(pricing::pricing_for("o1-mini"), Some(mp(3.0, 12.0)));
        assert_eq!(pricing::pricing_for("o1-preview"), Some(mp(15.0, 60.0)));
        // Surrounding whitespace and ASCII case are ignored.
        assert_eq!(pricing::pricing_for("  GPT-4o-Mini \n"), Some(mp(0.15, 0.60)));
        // A Bedrock-style `anthropic.` prefix is stripped before matching.
        assert_eq!(
            pricing::pricing_for("anthropic.claude-3-opus-20240229-v1:0"),
            Some(mp(15.0, 75.0))
        );
    }

    #[test]
    fn pricing_fallback() {
        let fallback = mp(1.0, 2.0);
        assert_eq!(pricing::pricing_for_or("llama3.1", fallback), fallback);
        assert_eq!(pricing::pricing_for_or("gpt-4o-mini", fallback), mp(0.15, 0.60));
    }
}