use super::types::*;
use crate::provider::ProviderRegistry;
use crate::ralph::{Prd, QualityChecks, RalphConfig, RalphLoop};
use crate::telemetry::TOKEN_USAGE;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::time::Instant;
use tracing::{error, info, warn};
20
/// Per-token pricing for one model, in USD per one million tokens.
///
/// Populated from the models.dev catalog (see `fetch_pricing`); absent when
/// the lookup fails, in which case a flat fallback rate is used.
#[derive(Debug, Clone)]
struct ModelPricing {
    /// USD cost per 1M input (prompt) tokens.
    input_per_m: f64,
    /// USD cost per 1M output (completion) tokens.
    output_per_m: f64,
}
29
/// Drives the benchmark suite: discovers PRD files, runs each configured
/// model against them, aggregates the metrics, writes a JSON report, and
/// optionally submits results to a remote API.
pub struct BenchmarkRunner {
    // Suite configuration: PRD directory, model list, cost ceiling,
    // iteration limits, output path, and optional submission endpoint.
    config: BenchmarkConfig,
}
34
35impl BenchmarkRunner {
36 pub fn new(config: BenchmarkConfig) -> Self {
37 Self { config }
38 }
39
40 fn discover_prds(&self) -> anyhow::Result<Vec<(PathBuf, u8)>> {
42 let dir = Path::new(&self.config.prd_dir);
43 if !dir.exists() {
44 anyhow::bail!("Benchmark directory not found: {}", self.config.prd_dir);
45 }
46
47 let mut prds = Vec::new();
48 for entry in std::fs::read_dir(dir)? {
49 let entry = entry?;
50 let path = entry.path();
51 if path.extension().is_some_and(|e| e == "json") {
52 let filename = path
53 .file_name()
54 .unwrap_or_default()
55 .to_string_lossy()
56 .to_string();
57
58 let tier = detect_tier(&filename);
59
60 if let Some(filter_tier) = self.config.tier {
62 if tier != filter_tier {
63 continue;
64 }
65 }
66
67 prds.push((path, tier));
68 }
69 }
70
71 prds.sort_by_key(|(_, tier)| *tier);
72 Ok(prds)
73 }
74
75 fn parse_model(model_str: &str) -> anyhow::Result<(String, String)> {
77 let parts: Vec<&str> = model_str.splitn(2, ':').collect();
78 if parts.len() != 2 {
79 anyhow::bail!(
80 "Invalid model format '{}'. Expected 'provider:model'",
81 model_str
82 );
83 }
84 Ok((parts[0].to_string(), parts[1].to_string()))
85 }
86
87 fn scaffold_workspace(
92 prd_path: &Path,
93 prd: &Prd,
94 ) -> anyhow::Result<(tempfile::TempDir, PathBuf)> {
95 let tmp = tempfile::Builder::new()
96 .prefix(&format!("bench-{}-", prd.project))
97 .tempdir()?;
98 let workspace = tmp.path();
99
100 info!(
101 "Scaffolding benchmark workspace at {:?} for PRD '{}'",
102 workspace, prd.feature
103 );
104
105 let cargo_init = Command::new("cargo")
107 .args(["init", "--name", &prd.project])
108 .current_dir(workspace)
109 .output()?;
110 if !cargo_init.status.success() {
111 anyhow::bail!(
112 "cargo init failed: {}",
113 String::from_utf8_lossy(&cargo_init.stderr)
114 );
115 }
116
117 let mut cargo_add_args: Vec<&str> = vec!["add", "serde", "--features", "derive"];
120 let cargo_add = Command::new("cargo")
121 .args(&cargo_add_args)
122 .current_dir(workspace)
123 .output()?;
124 if !cargo_add.status.success() {
125 warn!(
126 "cargo add serde failed (non-fatal): {}",
127 String::from_utf8_lossy(&cargo_add.stderr)
128 );
129 }
130
131 let feature_lower = prd.feature.to_lowercase();
133 let desc_text: String = prd
134 .user_stories
135 .iter()
136 .map(|s| format!("{} {}", s.title, s.description))
137 .collect::<Vec<_>>()
138 .join(" ")
139 .to_lowercase();
140
141 if feature_lower.contains("api")
142 || feature_lower.contains("rest")
143 || desc_text.contains("axum")
144 || desc_text.contains("endpoint")
145 {
146 cargo_add_args = vec!["add", "axum", "tokio", "--features", "tokio/full"];
147 let _ = Command::new("cargo")
148 .args(&cargo_add_args)
149 .current_dir(workspace)
150 .output();
151 let _ = Command::new("cargo")
153 .args(["add", "serde_json"])
154 .current_dir(workspace)
155 .output();
156 }
157
158 if desc_text.contains("clap") || feature_lower.contains("cli") {
159 let _ = Command::new("cargo")
160 .args(["add", "clap", "--features", "derive"])
161 .current_dir(workspace)
162 .output();
163 }
164
165 if desc_text.contains("csv") {
166 let _ = Command::new("cargo")
167 .args(["add", "csv"])
168 .current_dir(workspace)
169 .output();
170 }
171
172 let git_init = Command::new("git")
174 .args(["init"])
175 .current_dir(workspace)
176 .output()?;
177 if !git_init.status.success() {
178 anyhow::bail!(
179 "git init failed: {}",
180 String::from_utf8_lossy(&git_init.stderr)
181 );
182 }
183
184 let _ = Command::new("git")
186 .args(["config", "user.email", "bench@codetether.run"])
187 .current_dir(workspace)
188 .output();
189 let _ = Command::new("git")
190 .args(["config", "user.name", "CodeTether Benchmark"])
191 .current_dir(workspace)
192 .output();
193
194 let _ = Command::new("git")
196 .args(["add", "-A"])
197 .current_dir(workspace)
198 .output();
199 let _ = Command::new("git")
200 .args(["commit", "-m", "initial scaffold for benchmark"])
201 .current_dir(workspace)
202 .output();
203
204 let dest_prd = workspace.join("prd.json");
206 std::fs::copy(prd_path, &dest_prd)?;
207
208 Ok((tmp, dest_prd))
209 }
210
211 fn run_quality_checks(working_dir: &Path, checks: &QualityChecks) -> Vec<QualityCheckResult> {
213 let mut results = Vec::new();
214
215 for (name, cmd) in [
216 ("typecheck", &checks.typecheck),
217 ("lint", &checks.lint),
218 ("test", &checks.test),
219 ("build", &checks.build),
220 ] {
221 if let Some(command) = cmd {
222 let output = Command::new("/bin/sh")
223 .arg("-c")
224 .arg(command)
225 .current_dir(working_dir)
226 .output();
227
228 let (passed, output_text) = match output {
229 Ok(o) => {
230 let passed = o.status.success();
231 let text = if passed {
232 None
233 } else {
234 let stderr = String::from_utf8_lossy(&o.stderr);
235 let stdout = String::from_utf8_lossy(&o.stdout);
236 Some(format!("{}\n{}", stdout, stderr))
237 };
238 (passed, text)
239 }
240 Err(e) => (false, Some(format!("Failed to execute: {}", e))),
241 };
242
243 results.push(QualityCheckResult {
244 name: name.to_string(),
245 passed,
246 output: output_text,
247 });
248 }
249 }
250
251 results
252 }
253
254 async fn fetch_pricing(provider_id: &str, model_id: &str) -> Option<ModelPricing> {
256 let url = "https://models.dev/api.json";
257 let client = reqwest::Client::builder()
258 .timeout(std::time::Duration::from_secs(10))
259 .build()
260 .ok()?;
261
262 let resp = client.get(url).send().await.ok()?;
263 let data: serde_json::Value = resp.json().await.ok()?;
264
265 let cost = data
267 .get(provider_id)?
268 .get("models")?
269 .get(model_id)?
270 .get("cost")?;
271
272 let input = cost.get("input")?.as_f64()?;
273 let output = cost.get("output")?.as_f64()?;
274
275 Some(ModelPricing {
276 input_per_m: input,
277 output_per_m: output,
278 })
279 }
280
281 fn calculate_cost(
283 input_tokens: u64,
284 output_tokens: u64,
285 pricing: &Option<ModelPricing>,
286 ) -> f64 {
287 match pricing {
288 Some(p) => {
289 (input_tokens as f64 / 1_000_000.0) * p.input_per_m
290 + (output_tokens as f64 / 1_000_000.0) * p.output_per_m
291 }
292 None => (input_tokens + output_tokens) as f64 * 0.000005,
294 }
295 }
296
297 async fn submit_results(result: &BenchmarkSuiteResult, api_url: &str, api_key: &str) {
299 let client = match reqwest::Client::builder()
300 .timeout(std::time::Duration::from_secs(30))
301 .build()
302 {
303 Ok(c) => c,
304 Err(e) => {
305 warn!("Failed to create HTTP client for submission: {}", e);
306 return;
307 }
308 };
309
310 let result_json = match serde_json::to_string(result) {
311 Ok(j) => j,
312 Err(e) => {
313 warn!("Failed to serialize results: {}", e);
314 return;
315 }
316 };
317
318 for mr in &result.model_results {
319 let submission = crate::benchmark::types::BenchmarkSubmission {
320 model: mr.model.clone(),
321 agent: format!("{} v{}", result.agent, result.agent_version),
322 result: result_json.clone(),
323 };
324
325 match client
326 .post(api_url)
327 .header("Authorization", format!("Bearer {}", api_key))
328 .json(&submission)
329 .send()
330 .await
331 {
332 Ok(resp) if resp.status().is_success() => {
333 info!("Submitted benchmark results for {}", mr.model);
334 }
335 Ok(resp) => {
336 warn!(
337 "Benchmark submission failed for {} (HTTP {})",
338 mr.model,
339 resp.status()
340 );
341 }
342 Err(e) => {
343 warn!("Failed to submit benchmark results for {}: {}", mr.model, e);
344 }
345 }
346 }
347 }
348
349 pub async fn run(&self) -> anyhow::Result<BenchmarkSuiteResult> {
351 let start = Instant::now();
352 let prds = self.discover_prds()?;
353
354 if prds.is_empty() {
355 anyhow::bail!("No benchmark PRDs found in {}", self.config.prd_dir);
356 }
357
358 info!("Discovered {} benchmark PRDs across tiers", prds.len());
359
360 let mut model_results = Vec::new();
361
362 for model_str in &self.config.models {
363 let (provider_name, model_name) = Self::parse_model(model_str)?;
364 info!("Benchmarking model: {}:{}", provider_name, model_name);
365
366 let registry = ProviderRegistry::from_vault().await?;
368 let provider = registry.get(&provider_name).ok_or_else(|| {
369 anyhow::anyhow!("Provider '{}' not found in Vault", provider_name)
370 })?;
371
372 let pricing = Self::fetch_pricing(&provider_name, &model_name).await;
374 if let Some(ref p) = pricing {
375 info!(
376 "Model pricing: ${:.2}/M input, ${:.2}/M output",
377 p.input_per_m, p.output_per_m
378 );
379 } else {
380 warn!(
381 "Could not fetch pricing for {}:{}, using fallback estimates",
382 provider_name, model_name
383 );
384 }
385
386 let mut prd_results = Vec::new();
387 let mut total_cost = 0.0_f64;
388
389 for (prd_path, tier) in &prds {
390 if let Some(ceiling) = self.config.cost_ceiling_usd {
392 if total_cost >= ceiling {
393 warn!(
394 "Cost ceiling ${:.2} reached — skipping remaining PRDs for model {}",
395 ceiling, model_str
396 );
397 break;
398 }
399 }
400
401 let prd_id = prd_path
402 .file_stem()
403 .unwrap_or_default()
404 .to_string_lossy()
405 .to_string();
406
407 info!("Running benchmark PRD: {} (tier {})", prd_id, tier);
408
409 match self
410 .run_single_prd(prd_path, *tier, provider.clone(), &model_name, &pricing)
411 .await
412 {
413 Ok(result) => {
414 total_cost += result.cost_usd;
415 info!(
416 "PRD {} complete: {}/{} stories passed ({:.0}%) — ${:.4}",
417 prd_id,
418 result.stories_passed,
419 result.stories_total,
420 result.pass_rate * 100.0,
421 result.cost_usd,
422 );
423 prd_results.push(result);
424 }
425 Err(e) => {
426 error!("Failed to run benchmark PRD {}: {}", prd_id, e);
427 prd_results.push(PrdBenchmarkResult {
428 prd_id,
429 prd_tier: *tier,
430 prd_feature: "FAILED".to_string(),
431 stories_total: 0,
432 stories_passed: 0,
433 pass_rate: 0.0,
434 duration_seconds: 0.0,
435 tokens_used: 0,
436 cost_usd: 0.0,
437 quality_checks: Vec::new(),
438 per_story: Vec::new(),
439 });
440 }
441 }
442 }
443
444 let aggregate = Self::compute_aggregate(&prd_results);
445 model_results.push(ModelBenchmarkResult {
446 model: model_str.clone(),
447 prd_results,
448 aggregate,
449 });
450 }
451
452 let summary = Self::compute_summary(&model_results);
453 let elapsed = start.elapsed();
454
455 let result = BenchmarkSuiteResult {
456 run_date: chrono::Utc::now().to_rfc3339(),
457 agent: "codetether".to_string(),
458 agent_version: env!("CARGO_PKG_VERSION").to_string(),
459 model_results,
460 summary,
461 };
462
463 info!("Benchmark suite complete in {:.1}s", elapsed.as_secs_f64());
464
465 let output_path = Path::new(&self.config.output);
467 let json = serde_json::to_string_pretty(&result)?;
468 tokio::fs::write(output_path, &json).await?;
469 info!("Results written to {}", self.config.output);
470
471 if let (Some(api_url), Some(api_key)) =
473 (&self.config.submit_api_url, &self.config.submit_api_key)
474 {
475 Self::submit_results(&result, api_url, api_key).await;
476 }
477
478 Ok(result)
479 }
480
    /// Runs one PRD end-to-end for a single model and packages the metrics.
    ///
    /// Scaffolds a throwaway workspace, executes the Ralph loop inside it,
    /// measures wall time and process-wide token-usage deltas, then runs the
    /// PRD's quality checks against the resulting code.
    ///
    /// # Errors
    /// Fails if the PRD cannot be loaded, the workspace cannot be scaffolded,
    /// or the Ralph loop construction/run fails.
    async fn run_single_prd(
        &self,
        prd_path: &Path,
        tier: u8,
        provider: std::sync::Arc<dyn crate::provider::Provider>,
        model: &str,
        pricing: &Option<ModelPricing>,
    ) -> anyhow::Result<PrdBenchmarkResult> {
        let prd = Prd::load(&prd_path.to_path_buf()).await?;
        let prd_id = prd_path
            .file_stem()
            .unwrap_or_default()
            .to_string_lossy()
            .to_string();
        let prd_feature = prd.feature.clone();
        let prd_quality_checks = prd.quality_checks.clone();

        // Keep the TempDir guard alive for the whole run: dropping it would
        // delete the workspace out from under the loop.
        let (_tmp_handle, workspace_prd_path) = Self::scaffold_workspace(prd_path, &prd)?;
        let workspace_dir: &Path = workspace_prd_path
            .parent()
            .ok_or_else(|| anyhow::anyhow!("Invalid workspace PRD path"))?;

        info!(
            "Benchmark workspace ready at {:?} for PRD '{}'",
            workspace_dir, prd_feature
        );

        // Token usage is a process-wide counter; usage is attributed to this
        // PRD by diffing snapshots taken around the run. NOTE(review): this
        // assumes no other token-consuming work runs concurrently in the
        // process — confirm if PRD runs are ever parallelized.
        let tokens_before = TOKEN_USAGE.global_snapshot();

        let start = Instant::now();

        let config = RalphConfig {
            prd_path: workspace_prd_path.to_string_lossy().to_string(),
            max_iterations: self.config.max_iterations,
            story_timeout_secs: self.config.story_timeout_secs,
            quality_checks_enabled: true,
            auto_commit: true,
            parallel_enabled: self.config.parallel,
            ..Default::default()
        };

        let mut ralph_loop = RalphLoop::new(
            workspace_prd_path.clone(),
            provider,
            model.to_string(),
            config,
        )
        .await?;

        let state = ralph_loop.run().await?;
        let duration = start.elapsed();

        // saturating_sub guards against the global counters having been reset
        // between the two snapshots.
        let tokens_after = TOKEN_USAGE.global_snapshot();
        let input_tokens = tokens_after
            .totals
            .input
            .saturating_sub(tokens_before.totals.input);
        let output_tokens = tokens_after
            .totals
            .output
            .saturating_sub(tokens_before.totals.output);
        let tokens_used = input_tokens + output_tokens;

        let cost_usd = Self::calculate_cost(input_tokens, output_tokens, pricing);

        // Build per-story results from the loop's progress log; iteration
        // count is approximated by the number of log entries for the story.
        let per_story: Vec<StoryBenchmarkResult> = state
            .prd
            .user_stories
            .iter()
            .map(|story| {
                let progress_entries: Vec<_> = state
                    .progress_log
                    .iter()
                    .filter(|p| p.story_id == story.id)
                    .collect();

                StoryBenchmarkResult {
                    story_id: story.id.clone(),
                    title: story.title.clone(),
                    passed: story.passes,
                    iterations: progress_entries.len(),
                    // TODO: per-story duration/token attribution is not
                    // tracked yet — only PRD-level totals are recorded.
                    duration_seconds: 0.0,
                    tokens_used: 0,
                    // Deduplicate file paths across the story's log entries.
                    files_changed: progress_entries
                        .iter()
                        .flat_map(|p| p.files_changed.iter().cloned())
                        .collect::<std::collections::HashSet<_>>()
                        .into_iter()
                        .collect(),
                }
            })
            .collect();

        let stories_passed = state.prd.passed_count();
        let stories_total = state.prd.user_stories.len();

        // Quality gates run against the final state of the workspace.
        let quality_checks = Self::run_quality_checks(workspace_dir, &prd_quality_checks);

        Ok(PrdBenchmarkResult {
            prd_id,
            prd_tier: tier,
            prd_feature,
            stories_total,
            stories_passed,
            pass_rate: if stories_total > 0 {
                stories_passed as f64 / stories_total as f64
            } else {
                0.0
            },
            duration_seconds: duration.as_secs_f64(),
            tokens_used,
            cost_usd,
            quality_checks,
            per_story,
        })
    }
606
607 fn compute_aggregate(prd_results: &[PrdBenchmarkResult]) -> AggregateMetrics {
609 let prds_attempted = prd_results.len();
610 let prds_fully_passed = prd_results.iter().filter(|r| r.pass_rate >= 1.0).count();
611
612 let total_stories: usize = prd_results.iter().map(|r| r.stories_total).sum();
613 let total_passed: usize = prd_results.iter().map(|r| r.stories_passed).sum();
614 let total_tokens: u64 = prd_results.iter().map(|r| r.tokens_used).sum();
615 let total_cost: f64 = prd_results.iter().map(|r| r.cost_usd).sum();
616 let total_duration: f64 = prd_results.iter().map(|r| r.duration_seconds).sum();
617
618 let overall_pass_rate = if total_stories > 0 {
619 total_passed as f64 / total_stories as f64
620 } else {
621 0.0
622 };
623
624 let avg_seconds_per_story = if total_passed > 0 {
625 total_duration / total_passed as f64
626 } else {
627 0.0
628 };
629
630 let avg_tokens_per_story = if total_passed > 0 {
631 total_tokens as f64 / total_passed as f64
632 } else {
633 0.0
634 };
635
636 let stories_per_hour = if total_duration > 0.0 {
637 total_passed as f64 / (total_duration / 3600.0)
638 } else {
639 0.0
640 };
641
642 AggregateMetrics {
643 prds_attempted,
644 prds_fully_passed,
645 overall_pass_rate,
646 total_stories,
647 total_stories_passed: total_passed,
648 avg_seconds_per_story,
649 avg_tokens_per_story,
650 total_cost_usd: total_cost,
651 avg_cost_per_story: if total_passed > 0 {
652 total_cost / total_passed as f64
653 } else {
654 0.0
655 },
656 total_duration_seconds: total_duration,
657 stories_per_hour,
658 }
659 }
660
661 fn compute_summary(model_results: &[ModelBenchmarkResult]) -> BenchmarkSummary {
663 if model_results.is_empty() {
664 return BenchmarkSummary {
665 best_pass_rate_model: String::new(),
666 fastest_model: String::new(),
667 cheapest_model: String::new(),
668 best_overall_model: String::new(),
669 rankings: Vec::new(),
670 };
671 }
672
673 let max_pass_rate = model_results
675 .iter()
676 .map(|r| r.aggregate.overall_pass_rate)
677 .fold(0.0_f64, f64::max);
678 let max_speed = model_results
679 .iter()
680 .map(|r| r.aggregate.stories_per_hour)
681 .fold(0.0_f64, f64::max);
682 let min_cost = model_results
683 .iter()
684 .filter(|r| r.aggregate.avg_cost_per_story > 0.0)
685 .map(|r| r.aggregate.avg_cost_per_story)
686 .fold(f64::INFINITY, f64::min);
687
688 let mut rankings: Vec<ModelRanking> = model_results
689 .iter()
690 .map(|r| {
691 let pass_rate_score = if max_pass_rate > 0.0 {
692 (r.aggregate.overall_pass_rate / max_pass_rate) * 100.0
693 } else {
694 0.0
695 };
696
697 let speed_score = if max_speed > 0.0 {
698 (r.aggregate.stories_per_hour / max_speed) * 100.0
699 } else {
700 0.0
701 };
702
703 let cost_score = if r.aggregate.avg_cost_per_story > 0.0 && min_cost.is_finite() {
704 (min_cost / r.aggregate.avg_cost_per_story) * 100.0
705 } else {
706 0.0
707 };
708
709 let overall_score = pass_rate_score * 0.50 + speed_score * 0.25 + cost_score * 0.25;
711
712 ModelRanking {
713 model: r.model.clone(),
714 pass_rate_score,
715 speed_score,
716 cost_score,
717 overall_score,
718 }
719 })
720 .collect();
721
722 rankings.sort_by(|a, b| {
723 b.overall_score
724 .partial_cmp(&a.overall_score)
725 .unwrap_or(std::cmp::Ordering::Equal)
726 });
727
728 let best_pass = model_results
729 .iter()
730 .max_by(|a, b| {
731 a.aggregate
732 .overall_pass_rate
733 .partial_cmp(&b.aggregate.overall_pass_rate)
734 .unwrap_or(std::cmp::Ordering::Equal)
735 })
736 .map(|r| r.model.clone())
737 .unwrap_or_default();
738
739 let fastest = model_results
740 .iter()
741 .max_by(|a, b| {
742 a.aggregate
743 .stories_per_hour
744 .partial_cmp(&b.aggregate.stories_per_hour)
745 .unwrap_or(std::cmp::Ordering::Equal)
746 })
747 .map(|r| r.model.clone())
748 .unwrap_or_default();
749
750 let cheapest = model_results
751 .iter()
752 .filter(|r| r.aggregate.avg_cost_per_story > 0.0)
753 .min_by(|a, b| {
754 a.aggregate
755 .avg_cost_per_story
756 .partial_cmp(&b.aggregate.avg_cost_per_story)
757 .unwrap_or(std::cmp::Ordering::Equal)
758 })
759 .map(|r| r.model.clone())
760 .unwrap_or_default();
761
762 let best_overall = rankings
763 .first()
764 .map(|r| r.model.clone())
765 .unwrap_or_default();
766
767 BenchmarkSummary {
768 best_pass_rate_model: best_pass,
769 fastest_model: fastest,
770 cheapest_model: cheapest,
771 best_overall_model: best_overall,
772 rankings,
773 }
774 }
775}