use serde::{Deserialize, Serialize};
use std::fmt;
use std::ops::{Add, AddAssign};
/// A count of tokens, wrapped in a newtype so it cannot be confused with
/// other integer quantities.
///
/// `#[serde(transparent)]` makes it (de)serialize exactly like the inner
/// `u64`, i.e. as a bare number. The derived `Default` value is zero.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, Serialize, Deserialize,
)]
#[serde(transparent)]
pub struct TokenCount(u64);
impl TokenCount {
    /// Wraps a raw token count.
    pub const fn new(count: u64) -> Self {
        Self(count)
    }

    /// Returns a count of zero tokens (the same value as `TokenCount::default()`).
    pub const fn zero() -> Self {
        Self(0)
    }

    /// Returns the raw number of tokens.
    pub const fn as_u64(&self) -> u64 {
        self.0
    }

    /// Returns `true` when the count is zero.
    pub const fn is_zero(&self) -> bool {
        self.0 == 0
    }

    /// Renders the count in a compact, human-readable form:
    ///
    /// - below 1_000: the exact number (`500` -> `"500"`)
    /// - 1_000 up to 9_999: thousands with one decimal (`5_000` -> `"5.0K"`)
    /// - 10_000 up to 999_999: whole thousands (`50_000` -> `"50K"`)
    /// - 1_000_000 and above: millions with one decimal (`1_500_000` -> `"1.5M"`)
    ///
    /// Every tier truncates toward zero. So a value is never displayed as if it
    /// belonged to the next tier (`9_999` -> `"9.9K"`, not `"10.0K"`).
    pub fn format(&self) -> String {
        let n = self.0;
        if n < 1_000 {
            n.to_string()
        } else if n < 10_000 {
            // Integer tenths of a thousand: exact, and consistent with the
            // truncating whole-thousands tier below.
            let tenths = n / 100;
            format!("{}.{}K", tenths / 10, tenths % 10)
        } else if n < 1_000_000 {
            format!("{}K", n / 1_000)
        } else {
            // Integer tenths of a million; avoids f64 rounding and precision loss.
            let tenths = n / 100_000;
            format!("{}.{}M", tenths / 10, tenths % 10)
        }
    }

    /// Adds two counts, clamping at `u64::MAX` instead of overflowing.
    pub fn saturating_add(self, other: Self) -> Self {
        Self(self.0.saturating_add(other.0))
    }
}
impl Add for TokenCount {
type Output = Self;
fn add(self, other: Self) -> Self {
Self(self.0.saturating_add(other.0))
}
}
impl AddAssign for TokenCount {
fn add_assign(&mut self, other: Self) {
self.0 = self.0.saturating_add(other.0);
}
}
impl From<u64> for TokenCount {
fn from(n: u64) -> Self {
Self(n)
}
}
impl From<u32> for TokenCount {
fn from(n: u32) -> Self {
Self(n as u64)
}
}
impl fmt::Display for TokenCount {
    /// Writes the compact text produced by [`TokenCount::format`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.format())
    }
}
/// Token accounting used to judge how full a model's context window is.
///
/// `cache_read_tokens`, `current_input_tokens` and `cache_creation_tokens` are
/// summed by `ContextUsage::context_tokens` to measure window occupancy.
/// `total_input_tokens` and `total_output_tokens` are summed by
/// `ContextUsage::total_tokens`.
#[derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize)]
pub struct ContextUsage {
    // Summed by `total_tokens` and used as the denominator of
    // `ContextAnalyzer::cache_efficiency`; presumably cumulative over the
    // session — confirm against the code that populates it.
    pub total_input_tokens: TokenCount,
    // Summed by `total_tokens`; presumably cumulative over the session — confirm.
    pub total_output_tokens: TokenCount,
    // Window capacity in tokens; denominator of `usage_percentage`.
    // A value of 0 makes `usage_percentage` report 0.0.
    pub context_window_size: u32,
    // Counted in `context_tokens`; presumably the uncached input of the most
    // recent request — confirm.
    pub current_input_tokens: TokenCount,
    // Not read by any method shown here (excluded from `context_tokens`).
    pub current_output_tokens: TokenCount,
    // Counted in `context_tokens`; by name, tokens written to a prompt cache — confirm.
    pub cache_creation_tokens: TokenCount,
    // Counted in `context_tokens` and the numerator of
    // `ContextAnalyzer::cache_efficiency`; by name, tokens served from a prompt cache.
    pub cache_read_tokens: TokenCount,
}
impl ContextUsage {
    /// Builds a usage record for a window of `context_window_size` tokens,
    /// with every token counter set to zero.
    pub fn new(context_window_size: u32) -> Self {
        Self {
            context_window_size,
            ..Self::default()
        }
    }

    /// Tokens currently occupying the context window: cache reads, cache
    /// creation and uncached input combined (output tokens are excluded).
    /// Saturates at `u64::MAX`.
    pub fn context_tokens(&self) -> TokenCount {
        // Saturating addition of unsigned values is commutative and
        // associative, so this grouping gives the same total as any other.
        let from_cache = self
            .cache_read_tokens
            .saturating_add(self.cache_creation_tokens);
        from_cache.saturating_add(self.current_input_tokens)
    }

    /// Sum of `total_input_tokens` and `total_output_tokens`, saturating at `u64::MAX`.
    pub fn total_tokens(&self) -> TokenCount {
        let input = self.total_input_tokens;
        input.saturating_add(self.total_output_tokens)
    }

    /// Share of the context window in use, as a percentage capped at `100.0`.
    ///
    /// Returns `0.0` when `context_window_size` is zero.
    pub fn usage_percentage(&self) -> f64 {
        match self.context_window_size {
            0 => 0.0,
            window => {
                let fraction = self.context_tokens().as_u64() as f64 / f64::from(window);
                (fraction * 100.0).min(100.0)
            }
        }
    }

    /// `true` once usage reaches 80% of the window.
    pub fn is_warning(&self) -> bool {
        let pct = self.usage_percentage();
        pct >= 80.0
    }

    /// `true` once usage reaches 90% of the window.
    pub fn is_critical(&self) -> bool {
        let pct = self.usage_percentage();
        pct >= 90.0
    }

    /// `true` when more than 200,000 tokens occupy the context, regardless of
    /// `context_window_size`.
    pub fn exceeds_200k(&self) -> bool {
        const LIMIT: u64 = 200_000;
        self.context_tokens().as_u64() > LIMIT
    }

    /// Tokens still free in the window; zero once the context is full or over-full.
    pub fn remaining_tokens(&self) -> TokenCount {
        let capacity = u64::from(self.context_window_size);
        let occupied = self.context_tokens().as_u64();
        TokenCount::new(capacity.saturating_sub(occupied))
    }

    /// Long form: the percentage with one decimal, followed by `used/window`
    /// in `TokenCount`'s compact notation, e.g. `13.0% (26K/200K)`.
    pub fn format(&self) -> String {
        let used = self.context_tokens().format();
        let window = TokenCount::new(u64::from(self.context_window_size)).format();
        format!("{:.1}% ({}/{})", self.usage_percentage(), used, window)
    }

    /// Short form: the whole-number percentage only, e.g. `13%`.
    pub fn format_compact(&self) -> String {
        let pct = self.usage_percentage();
        format!("{:.0}%", pct)
    }
}
impl fmt::Display for ContextUsage {
    /// Writes the long-form text produced by [`ContextUsage::format`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.format();
        f.write_str(&rendered)
    }
}
/// Severity bucket for context-window usage, as returned by `ContextAnalyzer::analyze`.
///
/// Variants are declared in increasing severity. The derived `Ord` follows
/// declaration order, so `Normal < Elevated < Warning < Critical`; do not
/// reorder them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ContextWarningLevel {
    /// Usage below 60% of the window.
    Normal,
    /// Usage of at least 60% but below 80%.
    Elevated,
    /// Usage of at least 80% but below 90%.
    Warning,
    /// Usage of 90% or more.
    Critical,
}
/// Stateless collection of heuristics over a `ContextUsage` snapshot. It has
/// no fields and is used only through its associated functions.
pub struct ContextAnalyzer;
impl ContextAnalyzer {
    /// Buckets `context` by `ContextUsage::usage_percentage`:
    /// `>= 90` is Critical, `>= 80` Warning, `>= 60` Elevated, anything else Normal.
    pub fn analyze(context: &ContextUsage) -> ContextWarningLevel {
        match context.usage_percentage() {
            pct if pct >= 90.0 => ContextWarningLevel::Critical,
            pct if pct >= 80.0 => ContextWarningLevel::Warning,
            pct if pct >= 60.0 => ContextWarningLevel::Elevated,
            _ => ContextWarningLevel::Normal,
        }
    }

    /// User-facing message for the current warning level, or `None` when usage
    /// is Normal. The percentage is shown rounded to a whole number.
    pub fn warning_message(context: &ContextUsage) -> Option<String> {
        let pct = context.usage_percentage();
        let message = match Self::analyze(context) {
            ContextWarningLevel::Normal => return None,
            ContextWarningLevel::Elevated => format!("Note: Context at {:.0}%.", pct),
            ContextWarningLevel::Warning => {
                format!("Warning: Context at {:.0}%. Approaching limit.", pct)
            }
            ContextWarningLevel::Critical => format!(
                "CRITICAL: Context at {:.0}%. Consider /compact or starting new conversation.",
                pct
            ),
        };
        Some(message)
    }

    /// Number of further turns that fit in the remaining window if each turn
    /// consumes `avg_tokens_per_turn` tokens, rounded down.
    ///
    /// Returns `None` when `avg_tokens_per_turn` is zero.
    pub fn estimate_remaining_turns(
        context: &ContextUsage,
        avg_tokens_per_turn: u64,
    ) -> Option<u64> {
        // `checked_div` yields `None` exactly when the divisor is zero.
        context
            .remaining_tokens()
            .as_u64()
            .checked_div(avg_tokens_per_turn)
    }

    /// Cache reads as a percentage of `total_input_tokens`; `0.0` when
    /// `total_input_tokens` is zero.
    ///
    /// NOTE(review): the result is not clamped, so it exceeds 100 if cache
    /// reads are not included in `total_input_tokens` — confirm how callers
    /// populate these fields.
    pub fn cache_efficiency(context: &ContextUsage) -> f64 {
        let cache_reads = context.cache_read_tokens.as_u64() as f64;
        match context.total_input_tokens.as_u64() {
            0 => 0.0,
            total_input => (cache_reads / total_input as f64) * 100.0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_count_formatting() {
        assert_eq!(TokenCount::new(500).format(), "500");
        assert_eq!(TokenCount::new(5_000).format(), "5.0K");
        assert_eq!(TokenCount::new(50_000).format(), "50K");
        assert_eq!(TokenCount::new(1_500_000).format(), "1.5M");
    }

    /// Tier boundaries whose rendering is exact (no fractional part to round).
    #[test]
    fn test_token_count_formatting_boundaries() {
        assert_eq!(TokenCount::new(0).format(), "0");
        assert_eq!(TokenCount::new(999).format(), "999");
        assert_eq!(TokenCount::new(1_000).format(), "1.0K");
        assert_eq!(TokenCount::new(10_000).format(), "10K");
        assert_eq!(TokenCount::new(999_999).format(), "999K");
        assert_eq!(TokenCount::new(1_000_000).format(), "1.0M");
        // Display delegates to `format`.
        assert_eq!(TokenCount::new(50_000).to_string(), "50K");
    }

    #[test]
    fn test_token_count_arithmetic_saturates() {
        let max = TokenCount::new(u64::MAX);
        assert_eq!(max + TokenCount::new(1), max);
        assert_eq!(max.saturating_add(max), max);

        let mut running = TokenCount::zero();
        running += TokenCount::new(2);
        running += TokenCount::new(3);
        assert_eq!(running.as_u64(), 5);

        assert!(TokenCount::default().is_zero());
        assert_eq!(TokenCount::from(7u32), TokenCount::new(7));
        assert_eq!(TokenCount::from(9u64), TokenCount::new(9));
    }

    #[test]
    fn test_usage_percentage_from_current_usage() {
        let usage = ContextUsage {
            cache_read_tokens: TokenCount::new(26_000),
            current_input_tokens: TokenCount::new(9),
            cache_creation_tokens: TokenCount::new(31),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert!((usage.usage_percentage() - 13.02).abs() < 0.01);
        assert_eq!(usage.context_tokens().as_u64(), 26_040);
    }

    #[test]
    fn test_usage_percentage_zero_when_current_usage_null() {
        let usage = ContextUsage {
            total_input_tokens: TokenCount::new(10_000),
            total_output_tokens: TokenCount::new(1_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(usage.usage_percentage(), 0.0);
        assert_eq!(usage.total_tokens().as_u64(), 11_000);
    }

    #[test]
    fn test_zero_window_and_overfull_context() {
        let empty = ContextUsage::new(0);
        assert_eq!(empty.usage_percentage(), 0.0);
        assert_eq!(empty.remaining_tokens().as_u64(), 0);

        let overfull = ContextUsage {
            cache_read_tokens: TokenCount::new(250_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(overfull.usage_percentage(), 100.0);
        assert_eq!(overfull.remaining_tokens().as_u64(), 0);
        assert!(overfull.is_critical());
        assert_eq!(overfull.format_compact(), "100%");
    }

    #[test]
    fn test_exceeds_200k_boundary() {
        let at_limit = ContextUsage {
            cache_read_tokens: TokenCount::new(200_000),
            ..Default::default()
        };
        assert!(!at_limit.exceeds_200k());
        let over = ContextUsage {
            cache_read_tokens: TokenCount::new(200_001),
            ..Default::default()
        };
        assert!(over.exceeds_200k());
    }

    #[test]
    fn test_display_format() {
        let usage = ContextUsage {
            cache_read_tokens: TokenCount::new(26_000),
            current_input_tokens: TokenCount::new(9),
            cache_creation_tokens: TokenCount::new(31),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(usage.to_string(), "13.0% (26K/200K)");
        assert_eq!(usage.format_compact(), "13%");
    }

    #[test]
    fn test_warning_thresholds() {
        let normal = ContextUsage {
            cache_read_tokens: TokenCount::new(100_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert!(!normal.is_warning());
        assert!(!normal.is_critical());
        assert_eq!(
            ContextAnalyzer::analyze(&normal),
            ContextWarningLevel::Normal
        );
        let warning = ContextUsage {
            cache_read_tokens: TokenCount::new(160_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert!(warning.is_warning());
        assert!(!warning.is_critical());
        assert_eq!(
            ContextAnalyzer::analyze(&warning),
            ContextWarningLevel::Warning
        );
        let critical = ContextUsage {
            cache_read_tokens: TokenCount::new(190_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert!(critical.is_warning());
        assert!(critical.is_critical());
        assert_eq!(
            ContextAnalyzer::analyze(&critical),
            ContextWarningLevel::Critical
        );
    }

    #[test]
    fn test_elevated_level_and_messages() {
        let quiet = ContextUsage {
            cache_read_tokens: TokenCount::new(20_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(ContextAnalyzer::warning_message(&quiet), None);

        let elevated = ContextUsage {
            cache_read_tokens: TokenCount::new(130_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(
            ContextAnalyzer::analyze(&elevated),
            ContextWarningLevel::Elevated
        );
        assert_eq!(
            ContextAnalyzer::warning_message(&elevated).as_deref(),
            Some("Note: Context at 65%.")
        );

        let warning = ContextUsage {
            cache_read_tokens: TokenCount::new(170_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(
            ContextAnalyzer::warning_message(&warning).as_deref(),
            Some("Warning: Context at 85%. Approaching limit.")
        );

        let critical = ContextUsage {
            cache_read_tokens: TokenCount::new(190_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(
            ContextAnalyzer::warning_message(&critical).as_deref(),
            Some("CRITICAL: Context at 95%. Consider /compact or starting new conversation.")
        );

        // Severity ordering follows declaration order.
        assert!(ContextWarningLevel::Normal < ContextWarningLevel::Elevated);
        assert!(ContextWarningLevel::Elevated < ContextWarningLevel::Warning);
        assert!(ContextWarningLevel::Warning < ContextWarningLevel::Critical);
    }

    #[test]
    fn test_remaining_tokens() {
        let usage = ContextUsage {
            cache_read_tokens: TokenCount::new(100_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(usage.remaining_tokens().as_u64(), 100_000);
    }

    #[test]
    fn test_estimate_remaining_turns() {
        let usage = ContextUsage {
            cache_read_tokens: TokenCount::new(100_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(ContextAnalyzer::estimate_remaining_turns(&usage, 0), None);
        assert_eq!(
            ContextAnalyzer::estimate_remaining_turns(&usage, 25_000),
            Some(4)
        );
        assert_eq!(
            ContextAnalyzer::estimate_remaining_turns(&usage, 30_000),
            Some(3)
        );
    }

    #[test]
    fn test_cache_efficiency() {
        assert_eq!(
            ContextAnalyzer::cache_efficiency(&ContextUsage::default()),
            0.0
        );
        let usage = ContextUsage {
            total_input_tokens: TokenCount::new(400),
            cache_read_tokens: TokenCount::new(100),
            ..Default::default()
        };
        assert_eq!(ContextAnalyzer::cache_efficiency(&usage), 25.0);
    }

    #[test]
    fn test_context_tokens_calculation() {
        let usage = ContextUsage {
            cache_read_tokens: TokenCount::new(25_000),
            current_input_tokens: TokenCount::new(500),
            cache_creation_tokens: TokenCount::new(100),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert_eq!(usage.context_tokens().as_u64(), 25_600);
    }
}