/// Tokenizer family used to turn text into a token count.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Model {
    /// OpenAI GPT-4 family (`cl100k_base` encoding).
    OpenAiGpt4,
    /// OpenAI GPT-4o family (`o200k_base` encoding).
    OpenAiGpt4o,
    /// Anthropic Claude; counted with the heuristic, so only approximate.
    AnthropicClaude,
    /// Tokenizer-free heuristic counter (`heuristic_tokens`).
    Heuristic,
}

impl Model {
    /// Human-readable display name, including the encoding or an accuracy note.
    pub fn name(self) -> &'static str {
        match self {
            Model::OpenAiGpt4 => "openai-gpt4 (cl100k_base)",
            Model::OpenAiGpt4o => "openai-gpt4o (o200k_base)",
            Model::AnthropicClaude => "anthropic-claude (approx)",
            Model::Heuristic => "heuristic",
        }
    }

    /// Every model, in declaration order.
    pub fn all() -> [Model; 4] {
        [
            Model::OpenAiGpt4,
            Model::OpenAiGpt4o,
            Model::AnthropicClaude,
            Model::Heuristic,
        ]
    }

    /// Parses a model from a user-supplied name (trimmed, ASCII case-insensitive).
    ///
    /// Accepts the short aliases below (e.g. `gpt-4`, `o200k`, `claude`, `heur`)
    /// and also the exact strings returned by [`Model::name`], so
    /// `Model::from_name(m.name()) == Some(m)` holds for every model.
    /// Returns `None` for anything else.
    pub fn from_name(name: &str) -> Option<Model> {
        let key = name.trim().to_ascii_lowercase();
        let alias = match key.as_str() {
            "gpt4" | "gpt-4" | "openai-gpt4" | "cl100k" | "cl100k_base" => Some(Model::OpenAiGpt4),
            "gpt4o" | "gpt-4o" | "openai-gpt4o" | "o200k" | "o200k_base" => {
                Some(Model::OpenAiGpt4o)
            }
            "claude" | "anthropic" | "anthropic-claude" => Some(Model::AnthropicClaude),
            "heuristic" | "heur" => Some(Model::Heuristic),
            _ => None,
        };
        // Fall back to the display names so `name()` output round-trips; those
        // strings are all lowercase, so they compare directly against `key`.
        alias.or_else(|| Model::all().into_iter().find(|model| model.name() == key))
    }

    /// Whether `count` is an exact tokenizer count: true only for the OpenAI
    /// models, and only when built with the `real-tokens` feature.
    pub fn is_exact(self) -> bool {
        match self {
            Model::OpenAiGpt4 | Model::OpenAiGpt4o => cfg!(feature = "real-tokens"),
            Model::AnthropicClaude | Model::Heuristic => false,
        }
    }

    /// Counts the tokens `text` costs under this model.
    pub fn count(self, text: &str) -> usize {
        match self {
            Model::OpenAiGpt4 => count_openai(text, false),
            Model::OpenAiGpt4o => count_openai(text, true),
            Model::AnthropicClaude | Model::Heuristic => heuristic_tokens(text),
        }
    }
}
/// Exact OpenAI token count via the `tiktoken_rs` BPE tables.
///
/// `o200k` selects the `o200k_base` encoding (GPT-4o); otherwise `cl100k_base`
/// (GPT-4) is used. Each encoder is built lazily on first use and cached in a
/// `OnceLock` for the rest of the process. Text is encoded with
/// `encode_with_special_tokens`.
///
/// # Panics
/// Panics if `tiktoken_rs` fails to construct the requested encoder.
#[cfg(feature = "real-tokens")]
fn count_openai(text: &str, o200k: bool) -> usize {
    use std::sync::OnceLock;
    static CL100K: OnceLock<tiktoken_rs::CoreBPE> = OnceLock::new();
    static O200K: OnceLock<tiktoken_rs::CoreBPE> = OnceLock::new();
    let encoder = if !o200k {
        CL100K.get_or_init(|| tiktoken_rs::cl100k_base().expect("load cl100k_base"))
    } else {
        O200K.get_or_init(|| tiktoken_rs::o200k_base().expect("load o200k_base"))
    };
    let encoded = encoder.encode_with_special_tokens(text);
    encoded.len()
}
/// Fallback used when the `real-tokens` feature is off: both OpenAI encodings
/// are estimated with `heuristic_tokens`, so the encoding flag is ignored.
#[cfg(not(feature = "real-tokens"))]
fn count_openai(text: &str, _use_o200k: bool) -> usize {
    heuristic_tokens(text)
}
/// Approximates a token count without a tokenizer.
///
/// Each maximal run of alphanumeric characters counts as one token, and every
/// other character counts as one token except whitespace and `_`, which are
/// free. Those free characters still end a run, so `file_read` is two tokens.
pub fn heuristic_tokens(text: &str) -> usize {
    text.chars()
        .fold((0usize, false), |(tokens, prev_in_word), c| {
            let in_word = c.is_alphanumeric();
            // A new alphanumeric run starts here.
            let starts_word = in_word && !prev_in_word;
            // Punctuation/symbols are one token each; whitespace and '_' cost nothing.
            let is_symbol = !in_word && !c.is_whitespace() && c != '_';
            (tokens + usize::from(starts_word || is_symbol), in_word)
        })
        .0
}
/// Token cost of one agent program, broken down by where the tokens come from.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AgentCost {
    /// Tokens of context kept in the prompt; `total_over` charges it once.
    pub standing_context: usize,
    /// Tokens of program input, charged once per turn.
    pub input: usize,
    /// Tokens of program output, charged once per turn.
    pub output: usize,
    /// Retry count, added to the totals as-is.
    pub retries: usize,
}

impl AgentCost {
    /// Total over `turns` turns (0 is treated as 1) when the standing context
    /// is paid only once: `standing + (input + output) * turns + retries`.
    pub fn total_over(&self, turns: usize) -> usize {
        let turns = turns.max(1);
        let per_turn = self.input + self.output;
        self.standing_context + per_turn * turns + self.retries
    }

    /// Total over `turns` turns (0 is treated as 1) when the standing context
    /// is re-sent every turn: the no-caching upper bound.
    pub fn total_standing_per_turn(&self, turns: usize) -> usize {
        let per_turn = self.standing_context + self.input + self.output;
        per_turn * turns.max(1) + self.retries
    }
}

impl std::fmt::Display for AgentCost {
    /// Formats as `input=<n> output=<n> standing=<n> retries=<n>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            standing_context,
            input,
            output,
            retries,
        } = self;
        write!(
            f,
            "input={input} output={output} standing={standing_context} retries={retries}"
        )
    }
}
/// An agent program to be costed, built with `new` plus the `with_*` builders.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone)]
pub struct Program {
    /// Label for the program; not read by the cost functions.
    pub name: String,
    /// Program text, counted as per-turn input.
    pub source: String,
    /// Representative output, counted as per-turn output.
    pub output_sample: String,
    /// Context the program relies on (e.g. a cheatsheet), counted as standing context.
    pub standing_context: String,
    /// Retry count, copied verbatim into the evaluated cost.
    pub retries: usize,
}

impl Program {
    /// Creates a program with empty output sample and standing context, and no retries.
    pub fn new(name: impl Into<String>, source: impl Into<String>) -> Self {
        let (name, source) = (name.into(), source.into());
        Self {
            name,
            source,
            output_sample: String::new(),
            standing_context: String::new(),
            retries: 0,
        }
    }

    /// Replaces the output sample.
    pub fn with_output(self, sample: impl Into<String>) -> Self {
        Self {
            output_sample: sample.into(),
            ..self
        }
    }

    /// Replaces the standing context.
    pub fn with_standing_context(self, ctx: impl Into<String>) -> Self {
        Self {
            standing_context: ctx.into(),
            ..self
        }
    }

    /// Replaces the retry count.
    pub fn with_retries(self, retries: usize) -> Self {
        Self { retries, ..self }
    }
}
pub fn evaluate(program: &Program, model: Model) -> AgentCost {
AgentCost {
standing_context: model.count(&program.standing_context),
input: model.count(&program.source),
output: model.count(&program.output_sample),
retries: program.retries,
}
}
/// Costs `program` under every model in `Model::all()` order.
pub fn evaluate_all(program: &Program) -> Vec<(Model, AgentCost)> {
    Vec::from(Model::all().map(|model| (model, evaluate(program, model))))
}
/// Costs `program` with a caller-supplied token counter.
///
/// `count` is applied to the standing context, the source and the output
/// sample, in that order; the retry count is copied unchanged.
pub fn evaluate_with<F: Fn(&str) -> usize>(program: &Program, count: F) -> AgentCost {
    let standing_context = count(&program.standing_context);
    let input = count(&program.source);
    let output = count(&program.output_sample);
    AgentCost {
        standing_context,
        input,
        output,
        retries: program.retries,
    }
}
/// Head-to-head token cost of two programs under one model, as built by `compare`.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone)]
pub struct Comparison {
/// Tokenizer used for both programs.
pub model: Model,
/// Turn count as passed to `compare` (the totals treat 0 as 1).
pub turns: usize,
/// Cost breakdown of program A.
pub a: AgentCost,
/// Cost breakdown of program B.
pub b: AgentCost,
/// `a.total_over(turns)`.
pub a_total: usize,
/// `b.total_over(turns)`.
pub b_total: usize,
/// True when A's total is less than or equal to B's (ties favour A).
pub winner_is_a: bool,
/// Larger total divided by the smaller one; always >= 1.0.
pub ratio: f64,
}
/// Compares programs `a` and `b` under `model` over `turns` turns, with the
/// standing context charged once (`AgentCost::total_over`).
///
/// Ties go to A. `ratio` is the larger total divided by the smaller one. A
/// zero smaller total is treated as 1 token, so a free program against a
/// non-free one reports a real multiple instead of a misleading 1.0. Two zero
/// totals give 1.0.
pub fn compare(a: &Program, b: &Program, model: Model, turns: usize) -> Comparison {
    let (cost_a, cost_b) = (evaluate(a, model), evaluate(b, model));
    let (a_total, b_total) = (cost_a.total_over(turns), cost_b.total_over(turns));
    let winner_is_a = a_total <= b_total;
    let (lo, hi) = (a_total.min(b_total), a_total.max(b_total));
    let ratio = if hi == 0 {
        1.0
    } else {
        // Same `max(1)` denominator guard as `assess_cache`.
        hi as f64 / lo.max(1) as f64
    };
    Comparison {
        model,
        turns,
        a: cost_a,
        b: cost_b,
        a_total,
        b_total,
        winner_is_a,
        ratio,
    }
}
impl std::fmt::Display for Comparison {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let winner = if self.winner_is_a { "A" } else { "B" };
write!(
f,
"{}: A={} B={} over {} turns → {} wins ({:.2}x){}",
self.model.name(),
self.a_total,
self.b_total,
self.turns,
winner,
self.ratio,
if self.model.is_exact() { "" } else { " [est]" },
)
}
}
/// Ranks `programs` cheapest-first under `model` over `turns` turns.
///
/// Returns `(index into programs, total tokens)` pairs.
pub fn rank(programs: &[Program], model: Model, turns: usize) -> Vec<(usize, usize)> {
    let counter = move |text: &str| model.count(text);
    rank_with(programs, counter, turns)
}
/// Ranks `programs` cheapest-first using a custom token counter.
///
/// Each program is costed with `evaluate_with` and totalled with
/// `AgentCost::total_over(turns)`. Returns `(index into programs, total)`
/// pairs; programs with equal totals keep their input order.
pub fn rank_with<F: Fn(&str) -> usize>(
    programs: &[Program],
    count: F,
    turns: usize,
) -> Vec<(usize, usize)> {
    let mut ranked: Vec<(usize, usize)> = Vec::with_capacity(programs.len());
    for (index, program) in programs.iter().enumerate() {
        let total = evaluate_with(program, &count).total_over(turns);
        ranked.push((index, total));
    }
    // `sort_by` is stable, which preserves input order on ties.
    ranked.sort_by(|left, right| left.1.cmp(&right.1));
    ranked
}
/// Linear fit of token count against input size, as built by `assess_scaling`.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone)]
pub struct ScalingReport {
/// `(size, token count)` pairs that were fitted.
pub samples: Vec<(usize, usize)>,
/// Fitted slope: tokens added per additional item.
pub per_item: f64,
/// Fitted intercept: tokens at size zero.
pub fixed_overhead: f64,
/// True when `|per_item| < 0.5`, i.e. the count barely grows with size.
pub is_constant: bool,
}
/// Ordinary least-squares fit `y ≈ slope * x + intercept`; returns `(slope, intercept)`.
///
/// With no points, returns `(0.0, 0.0)`. When every `x` is identical (including
/// a single point), the slope is undefined and the result is `(0.0, mean(y))`.
///
/// Uses the mean-centred two-pass form. The raw-sum form `n·Σx² − (Σx)²`
/// cancels catastrophically for large `x`, and an absolute epsilon test on it
/// is scale-dependent, so that form can return a garbage slope.
fn least_squares(points: &[(usize, usize)]) -> (f64, f64) {
    let Some(&(first_x, _)) = points.first() else {
        return (0.0, 0.0);
    };
    let n = points.len() as f64;
    let mean_x = points.iter().map(|&(x, _)| x as f64).sum::<f64>() / n;
    let mean_y = points.iter().map(|&(_, y)| y as f64).sum::<f64>() / n;
    // Exact integer check for a vertical point cloud: no slope can be fitted.
    if points.iter().all(|&(x, _)| x == first_x) {
        return (0.0, mean_y);
    }
    let (sxx, sxy) = points
        .iter()
        .fold((0.0_f64, 0.0_f64), |(acc_xx, acc_xy), &(x, y)| {
            let dx = x as f64 - mean_x;
            let dy = y as f64 - mean_y;
            (acc_xx + dx * dx, acc_xy + dx * dy)
        });
    // Distinct usize values above 2^53 can still collapse to one f64.
    if sxx == 0.0 {
        return (0.0, mean_y);
    }
    let slope = sxy / sxx;
    (slope, mean_y - slope * mean_x)
}
/// Measures how a rendered output's token count grows with input size.
///
/// For each `n` in `sizes`, renders `produce(n)` and counts it with `count`.
/// It then fits `tokens ≈ per_item * n + fixed_overhead` by least squares.
/// The result is flagged as constant when the fitted slope is under half a
/// token per item in magnitude.
pub fn assess_scaling<P, C>(sizes: &[usize], produce: P, count: C) -> ScalingReport
where
    P: Fn(usize) -> String,
    C: Fn(&str) -> usize,
{
    let mut samples: Vec<(usize, usize)> = Vec::with_capacity(sizes.len());
    for &size in sizes {
        let rendered = produce(size);
        samples.push((size, count(&rendered)));
    }
    let (per_item, fixed_overhead) = least_squares(&samples);
    let is_constant = per_item.abs() < 0.5;
    ScalingReport {
        samples,
        per_item,
        fixed_overhead,
        is_constant,
    }
}
impl std::fmt::Display for ScalingReport {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{:.2} tok/item + {:.0} fixed{}",
self.per_item,
self.fixed_overhead,
if self.is_constant { " (≈O(1))" } else { "" }
)
}
}
/// Cost multiplier applied to the cached prefix on the turn that writes it to the cache.
/// NOTE(review): 1.25x appears to mirror Anthropic's prompt-cache write pricing — confirm.
pub const CACHE_WRITE_MULT: f64 = 1.25;
/// Cost multiplier applied to the cached prefix on each later turn that reads it back.
/// NOTE(review): 0.1x appears to mirror Anthropic's prompt-cache read pricing — confirm.
pub const CACHE_READ_MULT: f64 = 0.1;
/// Prompt-caching cost estimate, as built by `assess_cache`.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone)]
pub struct CacheReport {
/// Tokens that repeat unchanged at the start of every turn.
pub prefix: usize,
/// Tokens that change every turn.
pub variable: usize,
/// Turn count (at least 1).
pub turns: usize,
/// `prefix / (prefix + variable)`.
pub cacheable_ratio: f64,
/// `(prefix + variable) * turns`.
pub cost_uncached: usize,
/// Cost with the prefix written to the cache once and read back on later turns.
pub cost_cached: usize,
/// `cost_uncached / cost_cached`; below 1.0 when caching does not pay off (e.g. one turn).
pub savings_ratio: f64,
}
/// Estimates prompt-caching savings for a conversation of `turns` turns
/// (clamped to at least 1). Each turn sends `prefix` identical tokens followed
/// by `variable` changing tokens.
///
/// Costs are in input-token equivalents. The uncached cost pays
/// `prefix + variable` every turn. The cached cost pays the prefix once at
/// [`CACHE_WRITE_MULT`], then at [`CACHE_READ_MULT`] on each later turn, plus
/// `variable` every turn. `savings_ratio` is `cost_uncached / cost_cached`,
/// below 1.0 when caching does not pay off. With no tokens at all, both costs
/// are 0 and the ratio is 1.0 (no difference).
pub fn assess_cache(prefix: usize, variable: usize, turns: usize) -> CacheReport {
    let turns = turns.max(1);
    let t = turns as f64;
    let (p, v) = (prefix as f64, variable as f64);
    let cost_uncached = ((p + v) * t).round() as usize;
    // First turn writes the prefix at a premium; later turns read it at a discount.
    let cached = p * CACHE_WRITE_MULT + p * CACHE_READ_MULT * (t - 1.0) + v * t;
    // Any non-empty input costs >= 1 here (p >= 1 gives >= 1.25; v >= 1 gives >= t),
    // so a zero cached cost means there were no tokens at all.
    let cost_cached = cached.round() as usize;
    let savings_ratio = if cost_cached == 0 {
        1.0
    } else {
        cost_uncached as f64 / cost_cached as f64
    };
    CacheReport {
        prefix,
        variable,
        turns,
        // Summed in f64 so huge token counts cannot overflow usize.
        cacheable_ratio: p / (p + v).max(1.0),
        cost_uncached,
        cost_cached,
        savings_ratio,
    }
}
impl std::fmt::Display for CacheReport {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"cacheable {:.0}% → {} vs {} over {} turns ({:.2}x cheaper)",
self.cacheable_ratio * 100.0,
self.cost_cached,
self.cost_uncached,
self.turns,
self.savings_ratio
)
}
}
/// Counts, with `count`, the tokens in the longest character-level prefix
/// shared by every prompt in `prompts`.
///
/// Returns 0 for an empty slice without calling `count`. A single prompt is
/// counted in full. The common prefix is tracked as a byte length into the
/// first prompt, always on a `char` boundary, so no intermediate buffer is built.
pub fn cacheable_prefix_tokens<C: Fn(&str) -> usize>(prompts: &[&str], count: C) -> usize {
    /// Byte length of the longest common prefix of `a` and `b` (a char boundary in both).
    fn common_prefix_len(a: &str, b: &str) -> usize {
        a.char_indices()
            .zip(b.chars())
            .find(|&((_, x), y)| x != y)
            // No mismatch: the shorter string is entirely a prefix of the other.
            .map_or(a.len().min(b.len()), |((i, _), _)| i)
    }

    let Some((first, rest)) = prompts.split_first() else {
        return 0;
    };
    let mut shared = first.len();
    for prompt in rest {
        shared = common_prefix_len(&first[..shared], prompt);
        if shared == 0 {
            break;
        }
    }
    count(&first[..shared])
}
#[cfg(test)]
mod tests {
    use super::*;

    // The heuristic is a pure function, non-zero on real code, zero on empty input
    // and monotone when words are appended.
    #[test]
    fn heuristic_is_deterministic_and_sane() {
        let call = "file.read(\"README.md\")";
        let first = heuristic_tokens(call);
        assert_eq!(first, heuristic_tokens(call));
        assert!(first > 0);
        assert_eq!(heuristic_tokens(""), 0);
        assert!(heuristic_tokens("a b c") >= heuristic_tokens("a b"));
    }

    // Standing context is charged once; per-turn costs scale; turns=0 acts as 1.
    #[test]
    fn agent_cost_total_amortizes_standing_context_once() {
        let cost = AgentCost {
            standing_context: 1000,
            input: 10,
            output: 20,
            retries: 5,
        };
        assert_eq!(cost.total_over(1), 1035);
        assert_eq!(cost.total_over(10), 1000 + 300 + 5);
        assert_eq!(cost.total_over(0), cost.total_over(1));
    }

    // A terse program loses once its large cheatsheet is counted.
    #[test]
    fn standing_context_can_dominate_a_small_input_win() {
        let cheatsheet = "<a multi-kilobyte cipher cheatsheet ".repeat(120);
        let cipher = Program::new("t", "F.r x").with_standing_context(cheatsheet);
        let legible = Program::new("t", "file.read x").with_standing_context("short index");
        let cmp = compare(&legible, &cipher, Model::Heuristic, 30);
        assert!(cmp.winner_is_a, "legible wins once standing context counts");
        assert!(cmp.ratio > 1.0);
    }

    #[test]
    fn evaluate_all_covers_every_model() {
        let program = Program::new("t", "len([1,2,3])");
        let results = evaluate_all(&program);
        assert_eq!(results.len(), 4);
        for (_model, cost) in &results {
            assert!(cost.input > 0);
        }
    }

    // `_` splits words without costing a token; punctuation costs one each.
    #[test]
    fn heuristic_splits_snake_case_subwords() {
        assert_eq!(heuristic_tokens("file_read"), 2);
        assert_eq!(heuristic_tokens("a_b_c"), 3);
        assert_eq!(heuristic_tokens("file.read"), 3);
        assert_eq!(heuristic_tokens("len"), 1);
    }

    #[test]
    fn model_from_name_parses_aliases() {
        assert_eq!(Model::from_name("gpt-4"), Some(Model::OpenAiGpt4));
        assert_eq!(Model::from_name("o200k"), Some(Model::OpenAiGpt4o));
        assert_eq!(Model::from_name("CLAUDE"), Some(Model::AnthropicClaude));
        assert_eq!(Model::from_name("heur"), Some(Model::Heuristic));
        assert_eq!(Model::from_name("nope"), None);
    }

    // Index 1 (`cheap`) must come first even though it was listed second.
    #[test]
    fn rank_orders_programs_cheapest_first() {
        let cheap = Program::new("cheap", "file.read x").with_standing_context("short");
        let dear = Program::new("dear", "file.read x")
            .with_standing_context("a much longer cheatsheet ".repeat(50));
        let programs = [dear, cheap];
        let ranked = rank(&programs, Model::Heuristic, 30);
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].0, 1);
        assert!(ranked[0].1 <= ranked[1].1);
    }

    #[test]
    fn displays_are_non_empty() {
        let cost = AgentCost {
            standing_context: 10,
            input: 5,
            output: 2,
            retries: 0,
        };
        assert!(cost.to_string().contains("input=5"));
        let cmp = compare(
            &Program::new("a", "x"),
            &Program::new("b", "yy"),
            Model::Heuristic,
            10,
        );
        assert!(cmp.to_string().contains("wins"));
    }

    // A char-counting closure stands in for a tokenizer.
    #[test]
    fn evaluate_with_uses_a_custom_counter() {
        let program = Program::new("p", "abc")
            .with_output("de")
            .with_standing_context("fghi")
            .with_retries(7);
        let cost = evaluate_with(&program, |s| s.chars().count());
        assert_eq!(cost.input, 3);
        assert_eq!(cost.output, 2);
        assert_eq!(cost.standing_context, 4);
        assert_eq!(cost.retries, 7);
    }

    // Re-sending standing context each turn is never cheaper than paying it once.
    #[test]
    fn standing_per_turn_is_the_no_caching_upper_bound() {
        let cost = AgentCost {
            standing_context: 100,
            input: 10,
            output: 5,
            retries: 0,
        };
        assert_eq!(cost.total_over(10), 100 + 150);
        assert_eq!(cost.total_standing_per_turn(10), (100 + 15) * 10);
        assert!(cost.total_standing_per_turn(10) > cost.total_over(10));
    }

    #[test]
    fn rank_with_custom_counter_orders_cheapest_first() {
        let programs = [
            Program::new("long", "a much longer program body here"),
            Program::new("short", "x"),
        ];
        let ranked = rank_with(&programs, |s| s.split_whitespace().count(), 1);
        assert_eq!(ranked[0].0, 1);
    }

    // 3 header words plus 2 words per item: slope 2, intercept 3; a fixed
    // string fits as constant.
    #[test]
    fn scaling_fits_per_item_slope_and_overhead() {
        let produce = |n: usize| {
            let mut listing = String::from("name size kind");
            listing.push_str(&" x y".repeat(n));
            listing
        };
        let words = |s: &str| s.split_whitespace().count();
        let linear = assess_scaling(&[0, 10, 50, 100], produce, words);
        assert!(
            (linear.per_item - 2.0).abs() < 1e-6,
            "per_item {}",
            linear.per_item
        );
        assert!(
            (linear.fixed_overhead - 3.0).abs() < 1e-6,
            "fixed {}",
            linear.fixed_overhead
        );
        assert!(!linear.is_constant);
        let flat = assess_scaling(&[1, 10, 100], |_| "fixed".to_string(), words);
        assert!(flat.is_constant && flat.per_item.abs() < 0.5);
    }

    // Caching a 90% prefix pays off over 10 turns but not over a single turn.
    #[test]
    fn cache_models_prefix_reuse_savings() {
        let report = assess_cache(900, 100, 10);
        assert!((report.cacheable_ratio - 0.9).abs() < 1e-9);
        assert_eq!(report.cost_uncached, 10_000);
        assert!(report.cost_cached < report.cost_uncached);
        assert!(report.savings_ratio > 2.0, "savings {}", report.savings_ratio);
        assert!(assess_cache(900, 100, 1).savings_ratio <= 1.0);
    }

    #[test]
    fn cacheable_prefix_is_the_longest_common_prefix() {
        let prompts = ["SYSTEM: tools…\nturn 1 do A", "SYSTEM: tools…\nturn 2 do B"];
        let words = |s: &str| s.split_whitespace().count();
        assert_eq!(cacheable_prefix_tokens(&prompts, words), 3);
        assert_eq!(cacheable_prefix_tokens(&["abc", "xyz"], words), 0);
        assert_eq!(cacheable_prefix_tokens(&[], words), 0);
    }
}