use crate::types::{EstimationConfidence, Message};
/// Output-token count assumed when the caller supplies no `max_tokens` hint.
const DEFAULT_OUTPUT_TOKENS: u32 = 512;
/// Token estimate for a single request, as produced by [`estimate`].
pub struct EstimateResult {
/// Estimated number of input tokens for the concatenated message text.
pub input_tokens: u32,
/// Assumed number of output tokens: the caller's hint (or the default when
/// no hint is given), capped at the model's maximum output when that is known.
pub output_tokens: u32,
/// Confidence the tokenizer reported for the input-token estimate.
pub confidence: EstimationConfidence,
}
/// Estimates the token usage of a request.
///
/// The text of every message is concatenated and passed, together with
/// `provider`, to `tt_tokenize::estimate_input_tokens` to obtain the input
/// token count and its confidence. The output side is not measured: it is the
/// caller's `max_tokens_hint` (or the module default when absent), capped at
/// `model_max_output` when that limit is known.
///
/// # Parameters
/// - `provider`: provider identifier forwarded unchanged to the tokenizer.
/// - `messages`: messages whose text content is counted.
/// - `max_tokens_hint`: output-token count requested by the caller, if any.
/// - `model_max_output`: the model's maximum output tokens, if known.
///
/// # Returns
/// An [`EstimateResult`] holding the input estimate, the assumed output
/// tokens and the mapped confidence level.
pub fn estimate(
    provider: &str,
    messages: &[Message],
    max_tokens_hint: Option<u32>,
    model_max_output: Option<u32>,
) -> EstimateResult {
    let prompt = concat_message_text(messages);
    let counted = tt_tokenize::estimate_input_tokens(provider, &prompt);
    let output_tokens = output_tokens(max_tokens_hint, model_max_output);

    let input_tokens = counted.tokens;
    // Translate the tokenizer crate's confidence enum into this crate's own.
    let confidence = map_confidence(counted.confidence);

    EstimateResult {
        input_tokens,
        output_tokens,
        confidence,
    }
}
/// Picks the output-token count to assume for a request.
///
/// Starts from `max_tokens_hint`, falling back to `DEFAULT_OUTPUT_TOKENS` when
/// no hint is given, then caps the value at `model_max_output` if that limit
/// is known. With no known limit the starting value is returned unchanged.
fn output_tokens(max_tokens_hint: Option<u32>, model_max_output: Option<u32>) -> u32 {
    let requested = max_tokens_hint.unwrap_or(DEFAULT_OUTPUT_TOKENS);
    model_max_output.map_or(requested, |ceiling| requested.min(ceiling))
}
/// Converts the tokenizer crate's confidence level into this crate's
/// [`EstimationConfidence`], variant for variant.
///
/// The match is deliberately exhaustive (no wildcard arm) so that a new
/// variant added to `tt_tokenize::Confidence` fails to compile here instead of
/// being silently mapped to something else.
fn map_confidence(c: tt_tokenize::Confidence) -> EstimationConfidence {
    match c {
        tt_tokenize::Confidence::Low => EstimationConfidence::Low,
        tt_tokenize::Confidence::Medium => EstimationConfidence::Medium,
        tt_tokenize::Confidence::High => EstimationConfidence::High,
    }
}
/// Flattens the textual content of `messages` into one newline-separated string.
///
/// For each message, in order:
/// - if the content is a plain string, it is appended followed by `'\n'`;
/// - otherwise, if the content is an array, every element that has a string
///   `"text"` field contributes that string followed by `'\n'`; elements
///   without one are skipped;
/// - any other content shape contributes nothing.
fn concat_message_text(messages: &[Message]) -> String {
    /// Appends `line` and a trailing newline to `buf`.
    fn push_line(buf: &mut String, line: &str) {
        buf.push_str(line);
        buf.push('\n');
    }

    let mut joined = String::new();
    for message in messages {
        let content = &message.content;
        if let Some(plain) = content.as_str() {
            push_line(&mut joined, plain);
        } else if let Some(parts) = content.as_array() {
            parts
                .iter()
                .filter_map(|part| part.get("text").and_then(|value| value.as_str()))
                .for_each(|text| push_line(&mut joined, text));
        }
    }
    joined
}
// Unit tests for `estimate`: input-token estimation per provider, the
// output-token default/clamping rules, and structured (array) message content.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Builds a user-role message whose content is a plain JSON string.
fn user(text: &str) -> Message {
Message {
role: "user".into(),
content: json!(text),
}
}
// "openai" is expected to yield a high-confidence estimate, and an explicit
// hint is returned unchanged when no model maximum is supplied.
#[test]
fn openai_uses_tiktoken_high_confidence() {
let est = estimate("openai", &[user("Hello, world.")], Some(100), None);
assert!(est.input_tokens >= 1);
assert!(matches!(est.confidence, EstimationConfidence::High));
assert_eq!(est.output_tokens, 100);
}
// "anthropic" is also expected to yield high confidence; with no hint and no
// model maximum the output falls back to DEFAULT_OUTPUT_TOKENS (512).
#[test]
fn anthropic_uses_tiktoken() {
let est = estimate("anthropic", &[user("Hello, world.")], None, None);
assert!(est.input_tokens >= 1);
assert!(matches!(est.confidence, EstimationConfidence::High));
assert_eq!(est.output_tokens, 512); }
// A provider without a dedicated tokenizer is expected to report medium
// confidence. The counted text is "abcdefgh" plus the newline appended by
// `concat_message_text` (9 characters); the expected 3 tokens presumably
// reflects a ~4-characters-per-token heuristic in `tt_tokenize` — confirm there.
#[test]
fn unknown_provider_uses_heuristic_medium() {
let est = estimate("gemini", &[user("abcdefgh")], None, None);
assert_eq!(est.input_tokens, 3);
assert!(matches!(est.confidence, EstimationConfidence::Medium));
}
// Without a known model maximum, even a very large hint passes through as-is.
#[test]
fn explicit_hint_is_honored_when_model_max_unknown() {
let est = estimate("openai", &[user("hi")], Some(99999), None);
assert_eq!(est.output_tokens, 99999);
}
// A hint above the model maximum is capped at that maximum.
#[test]
fn output_is_clamped_to_model_max_when_known() {
let est = estimate("openai", &[user("hi")], Some(99999), Some(8192));
assert_eq!(est.output_tokens, 8192);
}
// The 512-token default is capped too when the model maximum is smaller.
#[test]
fn default_output_is_clamped_to_small_model_max() {
let est = estimate("openai", &[user("hi")], None, Some(256));
assert_eq!(est.output_tokens, 256);
}
// Array-shaped content: the "text" field of each part is extracted and counted.
#[test]
fn structured_content_extracts_text_parts() {
let m = Message {
role: "user".into(),
content: json!([{"type": "text", "text": "Hello"}, {"type": "text", "text": " world"}]),
};
let est = estimate("gemini", &[m], None, None);
assert!(est.input_tokens >= 2);
}
}