/// Tokenizer family used to price a piece of text.
///
/// The two OpenAI variants are counted with real BPE encodings only when the
/// crate is built with the `real-tokens` feature; without it they, like the
/// other variants, are estimated with [`heuristic_tokens`].
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Model {
    /// OpenAI GPT-4, labelled `cl100k_base`.
    OpenAiGpt4,
    /// OpenAI GPT-4o, labelled `o200k_base`.
    OpenAiGpt4o,
    /// Anthropic Claude; always counted with the heuristic estimator.
    AnthropicClaude,
    /// The plain heuristic estimator.
    Heuristic,
}
27
28impl Model {
29 pub fn name(self) -> &'static str {
31 match self {
32 Model::OpenAiGpt4 => "openai-gpt4 (cl100k_base)",
33 Model::OpenAiGpt4o => "openai-gpt4o (o200k_base)",
34 Model::AnthropicClaude => "anthropic-claude (approx)",
35 Model::Heuristic => "heuristic",
36 }
37 }
38
39 pub fn all() -> [Model; 4] {
41 [
42 Model::OpenAiGpt4,
43 Model::OpenAiGpt4o,
44 Model::AnthropicClaude,
45 Model::Heuristic,
46 ]
47 }
48
49 pub fn from_name(name: &str) -> Option<Model> {
53 match name.trim().to_ascii_lowercase().as_str() {
54 "gpt4" | "gpt-4" | "openai-gpt4" | "cl100k" | "cl100k_base" => Some(Model::OpenAiGpt4),
55 "gpt4o" | "gpt-4o" | "openai-gpt4o" | "o200k" | "o200k_base" => {
56 Some(Model::OpenAiGpt4o)
57 }
58 "claude" | "anthropic" | "anthropic-claude" => Some(Model::AnthropicClaude),
59 "heuristic" | "heur" => Some(Model::Heuristic),
60 _ => None,
61 }
62 }
63
64 pub fn is_exact(self) -> bool {
67 match self {
68 Model::OpenAiGpt4 | Model::OpenAiGpt4o => cfg!(feature = "real-tokens"),
69 Model::AnthropicClaude | Model::Heuristic => false,
70 }
71 }
72
73 pub fn count(self, text: &str) -> usize {
75 match self {
76 Model::OpenAiGpt4 => count_openai(text, false),
77 Model::OpenAiGpt4o => count_openai(text, true),
78 Model::AnthropicClaude => heuristic_tokens(text),
81 Model::Heuristic => heuristic_tokens(text),
82 }
83 }
84}
85
/// Counts `text` with an OpenAI BPE encoder from `tiktoken_rs`.
///
/// `o200k` selects `o200k_base`; otherwise `cl100k_base` is used. Each encoder
/// is built on first use and then reused through a `OnceLock`.
///
/// # Panics
///
/// Panics if `tiktoken_rs` fails to construct the requested encoder.
#[cfg(feature = "real-tokens")]
fn count_openai(text: &str, o200k: bool) -> usize {
    use std::sync::OnceLock;

    static CL100K: OnceLock<tiktoken_rs::CoreBPE> = OnceLock::new();
    static O200K: OnceLock<tiktoken_rs::CoreBPE> = OnceLock::new();

    let encoder = if o200k {
        O200K.get_or_init(|| tiktoken_rs::o200k_base().expect("load o200k_base"))
    } else {
        CL100K.get_or_init(|| tiktoken_rs::cl100k_base().expect("load cl100k_base"))
    };
    encoder.encode_with_special_tokens(text).len()
}
98
99#[cfg(not(feature = "real-tokens"))]
100fn count_openai(text: &str, _o200k: bool) -> usize {
101 heuristic_tokens(text)
102}
103
/// Estimates the token count of `text` without a real tokenizer.
///
/// Counting rule:
/// - every maximal run of alphanumeric characters is one token;
/// - every other character is one token on its own, except whitespace and
///   `_`, which only separate runs and cost nothing.
///
/// So `file_read` is 2 tokens, `file.read` is 3, and the empty string is 0.
pub fn heuristic_tokens(text: &str) -> usize {
    let mut count = 0usize;
    let mut prev_alnum = false;
    for ch in text.chars() {
        let alnum = ch.is_alphanumeric();
        if alnum {
            // Only the first character of a run opens a new token.
            if !prev_alnum {
                count += 1;
            }
        } else if !(ch.is_whitespace() || ch == '_') {
            count += 1;
        }
        prev_alnum = alnum;
    }
    count
}
128
/// Token cost of one program, split by where the tokens are spent.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AgentCost {
    /// Tokens in the program's standing context.
    pub standing_context: usize,
    /// Tokens in the program source.
    pub input: usize,
    /// Tokens in the program's sample output.
    pub output: usize,
    /// Retry overhead; added once to every total, independent of the turn count.
    pub retries: usize,
}

impl AgentCost {
    /// Total cost when the standing context is paid once and `input + output`
    /// is paid on every turn.
    ///
    /// `turns == 0` is treated as a single turn. The arithmetic saturates at
    /// `usize::MAX` instead of overflowing.
    pub fn total_over(&self, turns: usize) -> usize {
        let per_turn = self.input.saturating_add(self.output);
        self.standing_context
            .saturating_add(per_turn.saturating_mul(turns.max(1)))
            .saturating_add(self.retries)
    }

    /// Total cost when the standing context is paid again on every turn.
    ///
    /// `turns == 0` is treated as a single turn. The arithmetic saturates at
    /// `usize::MAX` instead of overflowing.
    pub fn total_standing_per_turn(&self, turns: usize) -> usize {
        let per_turn = self
            .standing_context
            .saturating_add(self.input)
            .saturating_add(self.output);
        per_turn
            .saturating_mul(turns.max(1))
            .saturating_add(self.retries)
    }
}

impl std::fmt::Display for AgentCost {
    /// Writes the four components as space-separated `key=value` pairs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "input={} output={} standing={} retries={}",
            self.input, self.output, self.standing_context, self.retries
        )
    }
}
171
/// A candidate program together with the texts that determine its token cost.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone)]
pub struct Program {
    /// Label for the program; not used in any cost computation in this file.
    pub name: String,
    /// Program text; its token count becomes [`AgentCost::input`].
    pub source: String,
    /// Output text; its token count becomes [`AgentCost::output`].
    pub output_sample: String,
    /// Context text; its token count becomes [`AgentCost::standing_context`].
    pub standing_context: String,
    /// Retry overhead, copied as-is (not tokenized) into [`AgentCost::retries`].
    pub retries: usize,
}

impl Program {
    /// Creates a program with the given name and source, an empty output
    /// sample, an empty standing context, and zero retries.
    pub fn new(name: impl Into<String>, source: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            source: source.into(),
            output_sample: String::new(),
            standing_context: String::new(),
            retries: 0,
        }
    }

    /// Returns the program with its output sample replaced by `sample`.
    #[must_use]
    pub fn with_output(mut self, sample: impl Into<String>) -> Self {
        self.output_sample = sample.into();
        self
    }

    /// Returns the program with its standing context replaced by `ctx`.
    #[must_use]
    pub fn with_standing_context(mut self, ctx: impl Into<String>) -> Self {
        self.standing_context = ctx.into();
        self
    }

    /// Returns the program with its retry overhead set to `retries`.
    #[must_use]
    pub fn with_retries(mut self, retries: usize) -> Self {
        self.retries = retries;
        self
    }
}
215
216pub fn evaluate(program: &Program, model: Model) -> AgentCost {
218 AgentCost {
219 standing_context: model.count(&program.standing_context),
220 input: model.count(&program.source),
221 output: model.count(&program.output_sample),
222 retries: program.retries,
223 }
224}
225
226pub fn evaluate_all(program: &Program) -> Vec<(Model, AgentCost)> {
228 Model::all()
229 .into_iter()
230 .map(|m| (m, evaluate(program, m)))
231 .collect()
232}
233
234pub fn evaluate_with<F: Fn(&str) -> usize>(program: &Program, count: F) -> AgentCost {
246 AgentCost {
247 standing_context: count(&program.standing_context),
248 input: count(&program.source),
249 output: count(&program.output_sample),
250 retries: program.retries,
251 }
252}
253
254#[cfg_attr(feature = "serde", derive(serde::Serialize))]
256#[derive(Debug, Clone)]
257pub struct Comparison {
258 pub model: Model,
260 pub turns: usize,
262 pub a: AgentCost,
264 pub b: AgentCost,
266 pub a_total: usize,
268 pub b_total: usize,
270 pub winner_is_a: bool,
272 pub ratio: f64,
274}
275
276pub fn compare(a: &Program, b: &Program, model: Model, turns: usize) -> Comparison {
278 let (ca, cb) = (evaluate(a, model), evaluate(b, model));
279 let (at, bt) = (ca.total_over(turns), cb.total_over(turns));
280 let winner_is_a = at <= bt;
281 let (lo, hi) = if at <= bt { (at, bt) } else { (bt, at) };
282 let ratio = if lo == 0 { 1.0 } else { hi as f64 / lo as f64 };
283 Comparison {
284 model,
285 turns,
286 a: ca,
287 b: cb,
288 a_total: at,
289 b_total: bt,
290 winner_is_a,
291 ratio,
292 }
293}
294
295impl std::fmt::Display for Comparison {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 let winner = if self.winner_is_a { "A" } else { "B" };
298 write!(
299 f,
300 "{}: A={} B={} over {} turns → {} wins ({:.2}x){}",
301 self.model.name(),
302 self.a_total,
303 self.b_total,
304 self.turns,
305 winner,
306 self.ratio,
307 if self.model.is_exact() { "" } else { " [est]" },
308 )
309 }
310}
311
312pub fn rank(programs: &[Program], model: Model, turns: usize) -> Vec<(usize, usize)> {
316 rank_with(programs, |s| model.count(s), turns)
317}
318
319pub fn rank_with<F: Fn(&str) -> usize>(
322 programs: &[Program],
323 count: F,
324 turns: usize,
325) -> Vec<(usize, usize)> {
326 let mut ranked: Vec<(usize, usize)> = programs
327 .iter()
328 .enumerate()
329 .map(|(i, p)| (i, evaluate_with(p, &count).total_over(turns)))
330 .collect();
331 ranked.sort_by_key(|&(_, total)| total);
332 ranked
333}
334
/// Linear fit of token count against input size; produced by [`assess_scaling`].
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone)]
pub struct ScalingReport {
    /// Measured `(size, token count)` pairs, in the order the sizes were given.
    pub samples: Vec<(usize, usize)>,
    /// Fitted slope: tokens added per unit of size.
    pub per_item: f64,
    /// Fitted intercept: tokens at size zero.
    pub fixed_overhead: f64,
    /// `true` when the slope's magnitude is below 0.5 tokens per item.
    pub is_constant: bool,
}
350
/// Ordinary least-squares fit of `y = slope * x + intercept` over `points`.
///
/// Returns `(slope, intercept)`. Degenerate inputs are handled without
/// dividing by zero:
/// - no points: `(0.0, 0.0)`;
/// - no spread in `x` (a single point, or all `x` equal): slope `0.0` and the
///   mean of `y` as the intercept.
fn least_squares(points: &[(usize, usize)]) -> (f64, f64) {
    if points.is_empty() {
        return (0.0, 0.0);
    }
    let n = points.len() as f64;

    // Accumulate all four sums in one pass over the points.
    let (mut sx, mut sy, mut sxx, mut sxy) = (0.0f64, 0.0f64, 0.0f64, 0.0f64);
    for &(x, y) in points {
        let (x, y) = (x as f64, y as f64);
        sx += x;
        sy += y;
        sxx += x * x;
        sxy += x * y;
    }

    let denom = n * sxx - sx * sx;
    if denom.abs() < f64::EPSILON {
        return (0.0, sy / n);
    }
    let slope = (n * sxy - sx * sy) / denom;
    let intercept = (sy - slope * sx) / n;
    (slope, intercept)
}
370
371pub fn assess_scaling<P, C>(sizes: &[usize], produce: P, count: C) -> ScalingReport
375where
376 P: Fn(usize) -> String,
377 C: Fn(&str) -> usize,
378{
379 let samples: Vec<(usize, usize)> = sizes.iter().map(|&n| (n, count(&produce(n)))).collect();
380 let (per_item, fixed_overhead) = least_squares(&samples);
381 ScalingReport {
382 is_constant: per_item.abs() < 0.5,
383 per_item,
384 fixed_overhead,
385 samples,
386 }
387}
388
389impl std::fmt::Display for ScalingReport {
390 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391 write!(
392 f,
393 "{:.2} tok/item + {:.0} fixed{}",
394 self.per_item,
395 self.fixed_overhead,
396 if self.is_constant { " (≈O(1))" } else { "" }
397 )
398 }
399}
400
/// Price multiplier applied to the cacheable prefix on the first turn of
/// [`assess_cache`]'s cached-cost model.
// NOTE(review): 1.25 looks like a provider's cache-write surcharge — confirm the source.
pub const CACHE_WRITE_MULT: f64 = 1.25;
/// Price multiplier applied to the cacheable prefix on every turn after the
/// first in [`assess_cache`]'s cached-cost model.
// NOTE(review): 0.1 looks like a provider's cache-read discount — confirm the source.
pub const CACHE_READ_MULT: f64 = 0.1;
406
/// Estimated effect of prompt-prefix caching; produced by [`assess_cache`].
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone)]
pub struct CacheReport {
    /// Tokens in the shared prefix that is treated as cacheable.
    pub prefix: usize,
    /// Tokens per turn outside the prefix, always paid at full price.
    pub variable: usize,
    /// Number of turns the costs cover (at least 1).
    pub turns: usize,
    /// Fraction of a turn's tokens that belongs to the prefix.
    pub cacheable_ratio: f64,
    /// Total cost when every turn pays full price for prefix and variable part.
    pub cost_uncached: usize,
    /// Total cost when the prefix is priced with the cache multipliers.
    pub cost_cached: usize,
    /// `cost_uncached / cost_cached`; above 1.0 means caching is cheaper.
    pub savings_ratio: f64,
}
430
431pub fn assess_cache(prefix: usize, variable: usize, turns: usize) -> CacheReport {
434 let turns = turns.max(1);
435 let t = turns as f64;
436 let (p, v) = (prefix as f64, variable as f64);
437 let cost_uncached = ((p + v) * t).round() as usize;
438 let cached = p * CACHE_WRITE_MULT + p * CACHE_READ_MULT * (t - 1.0) + v * t;
439 let cost_cached = (cached.round() as usize).max(1);
440 let total = (prefix + variable).max(1) as f64;
441 CacheReport {
442 prefix,
443 variable,
444 turns,
445 cacheable_ratio: p / total,
446 cost_uncached,
447 cost_cached,
448 savings_ratio: cost_uncached as f64 / cost_cached as f64,
449 }
450}
451
452impl std::fmt::Display for CacheReport {
453 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454 write!(
455 f,
456 "cacheable {:.0}% → {} vs {} over {} turns ({:.2}x cheaper)",
457 self.cacheable_ratio * 100.0,
458 self.cost_cached,
459 self.cost_uncached,
460 self.turns,
461 self.savings_ratio
462 )
463 }
464}
465
/// Token count of the longest prefix shared by every prompt in `prompts`.
///
/// The prefix is compared character by character, so it never splits a
/// multi-byte character. Returns `0` without calling `count` when `prompts`
/// is empty; otherwise returns `count(prefix)`, where `prefix` may be the
/// empty string. With a single prompt the whole prompt is the prefix.
pub fn cacheable_prefix_tokens<C: Fn(&str) -> usize>(prompts: &[&str], count: C) -> usize {
    let (first, rest) = match prompts.split_first() {
        Some((&first, rest)) => (first, rest),
        None => return 0,
    };

    // Byte length of the shared prefix within `first`; always on a char boundary
    // because it is a sum of whole-character widths measured from the start.
    let mut shared = first.len();
    for other in rest {
        if shared == 0 {
            break;
        }
        shared = first[..shared]
            .chars()
            .zip(other.chars())
            .take_while(|(a, b)| a == b)
            .map(|(a, _)| a.len_utf8())
            .sum();
    }
    count(&first[..shared])
}
490
/// Unit tests for the cost model: counters, totals, ranking, scaling fits, and
/// cache estimates.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn heuristic_is_deterministic_and_sane() {
        let s = "file.read(\"README.md\")";
        assert_eq!(heuristic_tokens(s), heuristic_tokens(s));
        assert!(heuristic_tokens(s) > 0);
        assert_eq!(heuristic_tokens(""), 0);
        assert!(heuristic_tokens("a b c") >= heuristic_tokens("a b"));
    }

    #[test]
    fn agent_cost_total_amortizes_standing_context_once() {
        let c = AgentCost {
            standing_context: 1000,
            input: 10,
            output: 20,
            retries: 5,
        };
        assert_eq!(c.total_over(1), 1035);
        assert_eq!(c.total_over(10), 1000 + 300 + 5);
        // Zero turns is priced like a single turn.
        assert_eq!(c.total_over(0), c.total_over(1));
    }

    #[test]
    fn standing_context_can_dominate_a_small_input_win() {
        let cipher = Program::new("t", "F.r x")
            .with_standing_context("<a multi-kilobyte cipher cheatsheet ".repeat(120).as_str());
        let legible = Program::new("t", "file.read x").with_standing_context("short index");
        let cmp = compare(&legible, &cipher, Model::Heuristic, 30);
        assert!(cmp.winner_is_a, "legible wins once standing context counts");
        assert!(cmp.ratio > 1.0);
    }

    #[test]
    fn evaluate_all_covers_every_model() {
        let p = Program::new("t", "len([1,2,3])");
        let all = evaluate_all(&p);
        assert_eq!(all.len(), 4);
        for (_m, c) in all {
            assert!(c.input > 0);
        }
    }

    #[test]
    fn heuristic_splits_snake_case_subwords() {
        assert_eq!(heuristic_tokens("file_read"), 2);
        assert_eq!(heuristic_tokens("a_b_c"), 3);
        assert_eq!(heuristic_tokens("file.read"), 3);
        assert_eq!(heuristic_tokens("len"), 1);
    }

    #[test]
    fn model_from_name_parses_aliases() {
        assert_eq!(Model::from_name("gpt-4"), Some(Model::OpenAiGpt4));
        assert_eq!(Model::from_name("o200k"), Some(Model::OpenAiGpt4o));
        assert_eq!(Model::from_name("CLAUDE"), Some(Model::AnthropicClaude));
        assert_eq!(Model::from_name("heur"), Some(Model::Heuristic));
        assert_eq!(Model::from_name("nope"), None);
    }

    #[test]
    fn rank_orders_programs_cheapest_first() {
        let cheap = Program::new("cheap", "file.read x").with_standing_context("short");
        let dear = Program::new("dear", "file.read x")
            .with_standing_context("a much longer cheatsheet ".repeat(50).as_str());
        let progs = [dear, cheap];
        let ranked = rank(&progs, Model::Heuristic, 30);
        assert_eq!(ranked.len(), 2);
        // `cheap` sits at index 1 of the input and must come out first.
        assert_eq!(ranked[0].0, 1);
        assert!(ranked[0].1 <= ranked[1].1);
    }

    #[test]
    fn displays_are_non_empty() {
        let c = AgentCost {
            standing_context: 10,
            input: 5,
            output: 2,
            retries: 0,
        };
        assert!(c.to_string().contains("input=5"));
        let cmp = compare(
            &Program::new("a", "x"),
            &Program::new("b", "yy"),
            Model::Heuristic,
            10,
        );
        assert!(cmp.to_string().contains("wins"));
    }

    #[test]
    fn evaluate_with_uses_a_custom_counter() {
        let p = Program::new("p", "abc")
            .with_output("de")
            .with_standing_context("fghi")
            .with_retries(7);
        let cost = evaluate_with(&p, |s| s.chars().count());
        assert_eq!(cost.input, 3);
        assert_eq!(cost.output, 2);
        assert_eq!(cost.standing_context, 4);
        assert_eq!(cost.retries, 7);
    }

    #[test]
    fn standing_per_turn_is_the_no_caching_upper_bound() {
        let c = AgentCost {
            standing_context: 100,
            input: 10,
            output: 5,
            retries: 0,
        };
        assert_eq!(c.total_over(10), 100 + 150);
        assert_eq!(c.total_standing_per_turn(10), (100 + 15) * 10);
        assert!(c.total_standing_per_turn(10) > c.total_over(10));
    }

    #[test]
    fn rank_with_custom_counter_orders_cheapest_first() {
        let progs = [
            Program::new("long", "a much longer program body here"),
            Program::new("short", "x"),
        ];
        let ranked = rank_with(&progs, |s| s.split_whitespace().count(), 1);
        assert_eq!(ranked[0].0, 1);
    }

    #[test]
    fn scaling_fits_per_item_slope_and_overhead() {
        // Three fixed header words plus two words per item.
        let produce = |n: usize| {
            let mut s = String::from("name size kind");
            for _ in 0..n {
                s.push_str(" x y");
            }
            s
        };
        let words = |s: &str| s.split_whitespace().count();
        let r = assess_scaling(&[0, 10, 50, 100], produce, words);
        assert!((r.per_item - 2.0).abs() < 1e-6, "per_item {}", r.per_item);
        assert!(
            (r.fixed_overhead - 3.0).abs() < 1e-6,
            "fixed {}",
            r.fixed_overhead
        );
        assert!(!r.is_constant);

        let c = assess_scaling(&[1, 10, 100], |_| "fixed".to_string(), words);
        assert!(c.is_constant && c.per_item.abs() < 0.5);
    }

    #[test]
    fn cache_models_prefix_reuse_savings() {
        let r = assess_cache(900, 100, 10);
        assert!((r.cacheable_ratio - 0.9).abs() < 1e-9);
        assert_eq!(r.cost_uncached, 10_000);
        assert!(r.cost_cached < r.cost_uncached);
        assert!(r.savings_ratio > 2.0, "savings {}", r.savings_ratio);
        // A single turn pays the write surcharge and never reads the cache back.
        assert!(assess_cache(900, 100, 1).savings_ratio <= 1.0);
    }

    #[test]
    fn cacheable_prefix_is_the_longest_common_prefix() {
        let prompts = ["SYSTEM: tools…\nturn 1 do A", "SYSTEM: tools…\nturn 2 do B"];
        let words = |s: &str| s.split_whitespace().count();
        assert_eq!(cacheable_prefix_tokens(&prompts, words), 3);
        assert_eq!(cacheable_prefix_tokens(&["abc", "xyz"], words), 0);
        assert_eq!(cacheable_prefix_tokens(&[], words), 0);
    }
}