/// Prompt and completion token counts, each of which may be absent.
///
/// Derives `Default`, so a default value has both fields set to `None`.
/// NOTE(review): presumably filled from usage figures reported by an LLM
/// backend — confirm against callers.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
/// Number of tokens attributed to the prompt, when known.
pub prompt_tokens: Option<u64>,
/// Number of tokens attributed to the completion, when known.
pub completion_tokens: Option<u64>,
}
/// Label recording how a token count was obtained.
///
/// Serialized and deserialized in `snake_case` (for example `"estimated"`).
/// `TokenSource::weaker` ranks the variants: `Estimated` is the weakest label
/// and `Definite` the strongest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenSource {
/// Weakest label. NOTE(review): presumably a heuristic count such as the
/// one produced by `estimate_tokens` — confirm against callers.
Estimated,
/// Middle label. NOTE(review): presumably a count supplied from outside
/// rather than computed here — confirm against callers.
Provided,
/// Strongest label; combining it with any other label yields the other.
Definite,
}
impl TokenSource {
    /// Returns the less trustworthy of `self` and `other`.
    ///
    /// The ordering, weakest first, is `Estimated`, `Provided`, `Definite`.
    /// The operation is symmetric, and combining a label with itself returns
    /// that label.
    pub fn weaker(self, other: Self) -> Self {
        match self {
            // Nothing ranks below an estimate.
            Self::Estimated => Self::Estimated,
            // Only an estimate on the other side can drag this down further.
            Self::Provided => match other {
                Self::Estimated => Self::Estimated,
                Self::Provided | Self::Definite => Self::Provided,
            },
            // `Definite` is the strongest label, so the other side decides.
            Self::Definite => other,
        }
    }
}
/// A token total paired with a label describing how it was obtained.
#[derive(Debug, Clone)]
pub struct TokenCount {
/// Number of tokens counted so far.
pub tokens: u64,
/// Label for how `tokens` was obtained.
pub source: TokenSource,
}
impl TokenCount {
    /// Creates an empty count (zero tokens) labelled with `source`.
    pub(crate) fn new(source: TokenSource) -> Self {
        Self { source, tokens: 0 }
    }

    /// Adds `tokens` to the running total and downgrades the stored label to
    /// the weaker of the current label and `source`.
    pub(crate) fn accumulate(&mut self, tokens: u64, source: TokenSource) {
        let Self {
            tokens: total,
            source: label,
        } = self;
        *total += tokens;
        // The label can only stay the same or get weaker as counts are merged.
        *label = label.weaker(source);
    }

    /// Renders the count as a JSON object with the keys `"tokens"` and
    /// `"source"`.
    pub(crate) fn to_json(&self) -> serde_json::Value {
        let total = self.tokens;
        let label = self.source;
        serde_json::json!({
            "tokens": total,
            "source": label,
        })
    }
}
/// Roughly estimates how many tokens `text` occupies, without a tokenizer.
///
/// Characters are split into ASCII and non-ASCII groups and each group is
/// costed separately, rounding up:
///
/// * ASCII: one token per 4 characters.
/// * non-ASCII: two tokens per 3 characters.
///
/// The two partial results are summed, so an empty string yields `0`.
pub(crate) fn estimate_tokens(text: &str) -> u64 {
    let (ascii, wide) = text.chars().fold((0u64, 0u64), |(ascii, wide), c| {
        if c.is_ascii() {
            (ascii + 1, wide)
        } else {
            (ascii, wide + 1)
        }
    });

    // Each group rounds up on its own, so mixed text can cost one token more
    // than rounding the combined total would.
    ascii.div_ceil(4) + (wide * 2).div_ceil(3)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionMetrics, ExecutionObserver, LlmQuery, QueryId};

    /// An empty string costs no tokens.
    #[test]
    fn estimate_tokens_empty() {
        let estimate = estimate_tokens("");
        assert_eq!(estimate, 0);
    }

    /// 11 ASCII characters round up to 3 tokens at 4 characters per token.
    #[test]
    fn estimate_tokens_ascii() {
        let estimate = estimate_tokens("hello world");
        assert_eq!(estimate, 3);
    }

    /// `Estimated` wins regardless of which side it appears on.
    #[test]
    fn token_source_weaker_estimated_wins() {
        let (estimated, definite) = (TokenSource::Estimated, TokenSource::Definite);
        assert_eq!(estimated.weaker(definite), estimated);
        assert_eq!(definite.weaker(estimated), estimated);
    }

    /// `Provided` ranks below `Definite`.
    #[test]
    fn token_source_weaker_provided_over_definite() {
        let combined = TokenSource::Provided.weaker(TokenSource::Definite);
        assert_eq!(combined, TokenSource::Provided);
    }

    /// Combining a label with itself returns that label.
    #[test]
    fn token_source_weaker_same_returns_same() {
        for label in [TokenSource::Definite, TokenSource::Estimated] {
            assert_eq!(label.weaker(label), label);
        }
    }

    /// Accumulating progressively weaker sources degrades the stored label
    /// while the token total keeps growing.
    #[test]
    fn token_count_accumulate_degrades_source() {
        let mut running = TokenCount::new(TokenSource::Definite);

        running.accumulate(10, TokenSource::Definite);
        assert_eq!(running.source, TokenSource::Definite);

        running.accumulate(5, TokenSource::Provided);
        assert_eq!(running.tokens, 15);
        assert_eq!(running.source, TokenSource::Provided);

        running.accumulate(3, TokenSource::Estimated);
        assert_eq!(running.tokens, 18);
        assert_eq!(running.source, TokenSource::Estimated);
    }

    /// `to_json` exposes the total under `"tokens"` and the snake_case label
    /// under `"source"`.
    #[test]
    fn token_count_to_json_format() {
        let rendered = TokenCount {
            tokens: 42,
            source: TokenSource::Provided,
        }
        .to_json();

        assert_eq!(rendered["tokens"], 42);
        assert_eq!(rendered["source"], "provided");
    }

    /// A label survives a serialize/deserialize round trip as a snake_case
    /// string.
    #[test]
    fn token_source_serde_roundtrip() {
        let original = TokenSource::Estimated;

        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, r#""estimated""#);

        let decoded: TokenSource = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    /// 3 non-ASCII characters cost 2 tokens.
    #[test]
    fn estimate_tokens_cjk() {
        let estimate = estimate_tokens("あいう");
        assert_eq!(estimate, 2);
    }

    /// 6 ASCII characters (2 tokens) plus 1 non-ASCII character (1 token).
    #[test]
    fn estimate_tokens_mixed() {
        let estimate = estimate_tokens("hello あ");
        assert_eq!(estimate, 3);
    }

    /// A single paused/fed/resumed round shows up as estimated prompt,
    /// response and total token counts in the metrics JSON.
    #[test]
    fn token_estimation_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let pending = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("Expert".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&pending);
        observer.on_response_fed(&QueryId::single(), "4", None);
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let report = metrics.to_json();
        let stats = &report["auto"];

        // NOTE(review): 5 is consistent with the prompt (3) and the system
        // text (2) being estimated separately and summed — confirm in the
        // metrics implementation.
        assert_eq!(stats["prompt_tokens"]["tokens"], 5);
        assert_eq!(stats["prompt_tokens"]["source"], "estimated");

        assert_eq!(stats["response_tokens"]["tokens"], 1);
        assert_eq!(stats["response_tokens"]["source"], "estimated");

        assert_eq!(stats["total_tokens"]["tokens"], 6);
        assert_eq!(stats["total_tokens"]["source"], "estimated");
    }

    /// Token estimates add up over several pause/feed/resume rounds.
    #[test]
    fn token_estimation_accumulates_across_rounds() {
        const ROUNDS: usize = 3;

        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let pending = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "test".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        for _ in 0..ROUNDS {
            observer.on_paused(&pending);
            observer.on_response_fed(&QueryId::single(), "reply here", None);
            observer.on_resumed();
        }
        observer.on_completed(&serde_json::json!(null));

        let report = metrics.to_json();
        let stats = &report["auto"];

        // "test" estimates to 1 token and "reply here" to 3, each over 3 rounds.
        assert_eq!(stats["prompt_tokens"]["tokens"], 3);
        assert_eq!(stats["prompt_tokens"]["source"], "estimated");

        assert_eq!(stats["response_tokens"]["tokens"], 9);
        assert_eq!(stats["response_tokens"]["source"], "estimated");
    }
}