Skip to main content

ferrum_cli/commands/
bench.rs

1//! Bench command - Throughput and latency benchmarking
2//!
3//! Modes:
4//!   ferrum bench qwen3:4b                              # default: sequential, 5 rounds
5//!   ferrum bench qwen3:4b --concurrency 4              # concurrent requests (tests batch decode)
6//!   ferrum bench qwen3:4b --max-tokens 1024            # long decode (tests flash decode)
7//!   ferrum bench qwen3:4b --long-context               # 2k prompt + 256 decode
8//!   ferrum bench qwen3:4b --concurrency 8 --max-tokens 64  # throughput stress test
9
10use crate::config::CliConfig;
11use chrono::Utc;
12use clap::Args;
13use colored::*;
14use ferrum_models::HfDownloader;
15use ferrum_types::{InferenceRequest, Priority, RequestId, Result, SamplingParams};
16use futures::StreamExt;
17use std::collections::HashMap;
18use std::time::Instant;
19use uuid::Uuid;
20
// Arguments for `ferrum bench`.
//
// NOTE: the `///` doc comments on each field double as the clap-generated CLI
// help text, so they are user-visible output and deliberately left untouched.
#[derive(Args)]
pub struct BenchCommand {
    /// Model name (e.g., qwen3:0.6b, qwen3:4b)
    #[arg(default_value = "qwen3:0.6b")]
    pub model: String,

    /// Number of benchmark rounds
    #[arg(long, default_value = "5")]
    pub rounds: usize,

    /// Max tokens per request
    #[arg(long, default_value = "128")]
    pub max_tokens: u32,

    /// Backend: auto, cpu, cuda, metal
    #[arg(long, default_value = "auto")]
    pub backend: String,

    /// Prompt to use
    #[arg(long, default_value = "Explain the theory of relativity in detail.")]
    pub prompt: String,

    /// Number of concurrent requests (>1 tests batch decode)
    #[arg(long, default_value = "1")]
    pub concurrency: usize,

    /// Long-context mode: use a ~2k token prompt (tests flash decode / paged KV)
    #[arg(long)]
    pub long_context: bool,
}
51
52pub async fn execute(cmd: BenchCommand, config: CliConfig) -> Result<()> {
53    let model_id = super::run::resolve_model_alias(&cmd.model);
54    eprintln!("{}", format!("Ferrum Benchmark - {}", model_id).bold());
55    eprintln!("{}", "=".repeat(60).dimmed());
56
57    // Find or download model
58    let cache_dir = super::run::get_hf_cache_dir(&config);
59    let source = match super::run::find_cached_model(&cache_dir, &model_id) {
60        Some(source) => source,
61        None => {
62            eprintln!("Downloading model...");
63            let token = std::env::var("HF_TOKEN")
64                .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
65                .ok();
66            let downloader = HfDownloader::new(cache_dir.clone(), token)?;
67            let snapshot_path = downloader.download(&model_id, None).await?;
68            let format = super::run::detect_format(&snapshot_path);
69            ferrum_models::source::ResolvedModelSource {
70                original: model_id.clone(),
71                local_path: snapshot_path,
72                format,
73                from_cache: false,
74            }
75        }
76    };
77
78    unsafe {
79        std::env::set_var(
80            "FERRUM_MODEL_PATH",
81            source.local_path.to_string_lossy().to_string(),
82        );
83    }
84
85    let device = super::run::select_device(&cmd.backend);
86    eprintln!("{} {:?}", "Device:".dimmed(), device);
87
88    // Show GPU info
89    #[cfg(feature = "cuda")]
90    {
91        if let Ok(d) = candle_core::Device::new_cuda(0) {
92            if let Ok(cd) = d.as_cuda_device() {
93                let name = cd.cuda_stream().context().name().unwrap_or_default();
94                eprintln!("GPU 0: {name}");
95            }
96        }
97        if let Ok(d) = candle_core::Device::new_cuda(1) {
98            if let Ok(cd) = d.as_cuda_device() {
99                let name = cd.cuda_stream().context().name().unwrap_or_default();
100                eprintln!("GPU 1: {name}");
101            }
102        }
103        #[cfg(feature = "cuda")]
104        {
105            let tp = std::env::var("FERRUM_TP")
106                .ok()
107                .and_then(|v| v.parse::<usize>().ok())
108                .unwrap_or_else(|| {
109                    candle_core::cuda_backend::cudarc::driver::CudaContext::device_count()
110                        .map(|n| n as usize)
111                        .unwrap_or(1)
112                });
113            if tp > 1 {
114                eprintln!("Tensor Parallel: TP={tp}");
115            }
116        }
117    }
118
119    // Create engine with ContinuousBatch scheduler (not Priority).
120    // DefaultInferenceEngine (Priority) has stream lifecycle issues with bench.
121    let mut engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
122    engine_config.scheduler.policy = ferrum_types::SchedulingPolicy::ContinuousBatch;
123    let engine = ferrum_engine::create_mvp_engine(engine_config).await?;
124
125    let prompt = if cmd.long_context {
126        generate_long_prompt()
127    } else {
128        cmd.prompt.clone()
129    };
130
131    let mode_str = if cmd.concurrency > 1 {
132        format!("concurrent({})", cmd.concurrency)
133    } else if cmd.long_context {
134        "long-context".to_string()
135    } else {
136        "sequential".to_string()
137    };
138
139    eprintln!(
140        "{}",
141        format!(
142            "Config: {} rounds, {} max_tokens, mode={}, prompt_len=~{}chars",
143            cmd.rounds,
144            cmd.max_tokens,
145            mode_str,
146            prompt.len()
147        )
148        .dimmed()
149    );
150    eprintln!("{}", "=".repeat(60).dimmed());
151
152    // Warmup
153    eprintln!("{}", "Warmup...".dimmed());
154    let _ = run_single(&*engine, &model_id, "Hello", 16).await;
155    // Let engine finish cleanup before starting benchmark rounds
156    tokio::time::sleep(std::time::Duration::from_millis(500)).await;
157
158    if cmd.concurrency > 1 {
159        run_concurrent_bench(&*engine, &model_id, &prompt, &cmd).await
160    } else {
161        run_sequential_bench(&*engine, &model_id, &prompt, &cmd).await
162    }
163}
164
165// ── Sequential benchmark (existing behavior) ────────────────────────
166
167async fn run_sequential_bench(
168    engine: &(dyn ferrum_interfaces::InferenceEngine + Send + Sync),
169    model_id: &str,
170    prompt: &str,
171    cmd: &BenchCommand,
172) -> Result<()> {
173    let mut total_tokens: usize = 0;
174    let mut total_time_ms: f64 = 0.0;
175    let mut ttft_ms_list: Vec<f64> = Vec::new();
176    let mut tps_list: Vec<f64> = Vec::new();
177    let mut tpot_ms_list: Vec<f64> = Vec::new();
178    let mut decode_tps_list: Vec<f64> = Vec::new();
179
180    for round in 1..=cmd.rounds {
181        eprintln!("{}", format!("Round {}/{}...", round, cmd.rounds).dimmed());
182
183        let result = run_single(engine, model_id, prompt, cmd.max_tokens).await?;
184        let (tps, decode_tokens, tpot_ms, decode_tps) = compute_metrics(&result);
185
186        total_tokens += result.token_count;
187        total_time_ms += result.total_ms;
188        ttft_ms_list.push(result.ttft_ms);
189        tps_list.push(tps);
190        if decode_tokens > 0 {
191            tpot_ms_list.push(tpot_ms);
192            decode_tps_list.push(decode_tps);
193        }
194
195        eprintln!(
196            "  {} tokens in {:.1}ms ({:.1} tok/s, TTFT {:.1}ms, TPOT {:.2}ms, decode {:.1} tok/s)",
197            result.token_count, result.total_ms, tps, result.ttft_ms, tpot_ms, decode_tps
198        );
199
200        // Let engine finish cleanup between rounds
201        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
202    }
203
204    print_summary(
205        model_id,
206        cmd,
207        "sequential",
208        total_tokens,
209        total_time_ms,
210        &tps_list,
211        &ttft_ms_list,
212        &tpot_ms_list,
213        &decode_tps_list,
214    );
215    Ok(())
216}
217
218// ── Concurrent benchmark (tests batch decode) ───────────────────────
219
220async fn run_concurrent_bench(
221    engine: &(dyn ferrum_interfaces::InferenceEngine + Send + Sync),
222    model_id: &str,
223    prompt: &str,
224    cmd: &BenchCommand,
225) -> Result<()> {
226    let mut total_tokens: usize = 0;
227    let mut total_time_ms: f64 = 0.0;
228    let mut all_ttft: Vec<f64> = Vec::new();
229    let mut all_tpot: Vec<f64> = Vec::new();
230    let mut round_tps_list: Vec<f64> = Vec::new();
231
232    for round in 1..=cmd.rounds {
233        eprintln!(
234            "{}",
235            format!(
236                "Round {}/{} ({} concurrent)...",
237                round, cmd.rounds, cmd.concurrency
238            )
239            .dimmed()
240        );
241
242        let round_start = Instant::now();
243
244        // Launch N concurrent requests
245        let mut handles = Vec::with_capacity(cmd.concurrency);
246        for _ in 0..cmd.concurrency {
247            let request = make_request(model_id, prompt, cmd.max_tokens);
248            let stream = engine.infer_stream(request).await?;
249            handles.push(tokio::spawn(collect_stream(stream)));
250        }
251
252        // Collect all results
253        let mut round_tokens = 0usize;
254        let mut round_results = Vec::new();
255        for handle in handles {
256            match handle.await {
257                Ok(Ok(result)) => {
258                    round_tokens += result.token_count;
259                    round_results.push(result);
260                }
261                Ok(Err(e)) => eprintln!("  request error: {e}"),
262                Err(e) => eprintln!("  join error: {e}"),
263            }
264        }
265
266        // Let the engine finish cleaning up completed requests
267        // (complete_request runs asynchronously after stream ends)
268        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
269
270        let round_ms = round_start.elapsed().as_secs_f64() * 1000.0;
271        let round_tps = if round_ms > 0.0 {
272            round_tokens as f64 / (round_ms / 1000.0)
273        } else {
274            0.0
275        };
276
277        total_tokens += round_tokens;
278        total_time_ms += round_ms;
279        round_tps_list.push(round_tps);
280
281        // Per-request metrics
282        for r in &round_results {
283            all_ttft.push(r.ttft_ms);
284            let decode_tokens = r.token_count.saturating_sub(1);
285            if decode_tokens > 0 {
286                let tpot = (r.total_ms - r.ttft_ms) / decode_tokens as f64;
287                all_tpot.push(tpot);
288            }
289        }
290
291        let avg_ttft = round_results.iter().map(|r| r.ttft_ms).sum::<f64>()
292            / round_results.len().max(1) as f64;
293
294        eprintln!(
295            "  {} requests, {} tokens in {:.1}ms ({:.1} tok/s total, avg TTFT {:.1}ms)",
296            round_results.len(),
297            round_tokens,
298            round_ms,
299            round_tps,
300            avg_ttft
301        );
302    }
303
304    // Summary
305    let avg_tps = if !round_tps_list.is_empty() {
306        round_tps_list.iter().sum::<f64>() / round_tps_list.len() as f64
307    } else {
308        0.0
309    };
310    let avg_ttft = if !all_ttft.is_empty() {
311        all_ttft.iter().sum::<f64>() / all_ttft.len() as f64
312    } else {
313        0.0
314    };
315    let avg_tpot = if !all_tpot.is_empty() {
316        all_tpot.iter().sum::<f64>() / all_tpot.len() as f64
317    } else {
318        0.0
319    };
320
321    eprintln!();
322    eprintln!("{}", "=".repeat(60));
323    eprintln!("{}", "BENCHMARK RESULTS (concurrent)".bold());
324    eprintln!("{}", "=".repeat(60));
325    eprintln!("Model:             {}", model_id);
326    eprintln!(
327        "Backend:           {:?}",
328        super::run::select_device(&cmd.backend)
329    );
330    eprintln!("Rounds:            {}", cmd.rounds);
331    eprintln!("Concurrency:       {}", cmd.concurrency);
332    eprintln!("Max tokens/req:    {}", cmd.max_tokens);
333    eprintln!("{}", "-".repeat(60));
334    eprintln!(
335        "Throughput (total): {:.1} tok/s avg ({:.1} min, {:.1} max)",
336        avg_tps,
337        round_tps_list.iter().cloned().fold(f64::INFINITY, f64::min),
338        round_tps_list.iter().cloned().fold(0.0_f64, f64::max)
339    );
340    eprintln!(
341        "TTFT:              {:.1}ms avg, {:.1}ms p99",
342        avg_ttft,
343        percentile(&all_ttft, 99.0)
344    );
345    eprintln!(
346        "TPOT:              {:.2}ms avg, {:.2}ms p99",
347        avg_tpot,
348        percentile(&all_tpot, 99.0)
349    );
350    eprintln!("Total tokens:      {}", total_tokens);
351    eprintln!("Total time:        {:.1}ms", total_time_ms);
352    eprintln!("{}", "=".repeat(60));
353
354    Ok(())
355}
356
357// ── Helpers ─────────────────────────────────────────────────────────
358
/// Timing stats for one streamed benchmark request.
struct BenchResult {
    /// Number of stream chunks that carried a token.
    token_count: usize,
    /// Time to first token in milliseconds (0.0 if no token arrived).
    ttft_ms: f64,
    /// Total wall time from request start to stream end, in milliseconds.
    total_ms: f64,
}
364
365fn compute_metrics(r: &BenchResult) -> (f64, usize, f64, f64) {
366    let tps = if r.total_ms > 0.0 {
367        r.token_count as f64 / (r.total_ms / 1000.0)
368    } else {
369        0.0
370    };
371    let decode_tokens = r.token_count.saturating_sub(1);
372    let decode_time_ms = r.total_ms - r.ttft_ms;
373    let tpot_ms = if decode_tokens > 0 {
374        decode_time_ms / decode_tokens as f64
375    } else {
376        0.0
377    };
378    let decode_tps = if decode_time_ms > 0.0 {
379        decode_tokens as f64 / (decode_time_ms / 1000.0)
380    } else {
381        0.0
382    };
383    (tps, decode_tokens, tpot_ms, decode_tps)
384}
385
386fn make_request(model_id: &str, prompt: &str, max_tokens: u32) -> InferenceRequest {
387    InferenceRequest {
388        id: RequestId(Uuid::new_v4()),
389        model_id: ferrum_types::ModelId(model_id.to_string()),
390        prompt: prompt.to_string(),
391        sampling_params: SamplingParams {
392            max_tokens: max_tokens as usize,
393            temperature: 0.7,
394            top_p: 0.9,
395            repetition_penalty: 1.1,
396            stop_sequences: vec![
397                "<|im_end|>".to_string(),
398                "</s>".to_string(),
399                "<|endoftext|>".to_string(),
400            ],
401            ..Default::default()
402        },
403        stream: true,
404        priority: Priority::Normal,
405        client_id: None,
406        session_id: None,
407        created_at: Utc::now(),
408        metadata: HashMap::new(),
409    }
410}
411
412async fn run_single(
413    engine: &(dyn ferrum_interfaces::InferenceEngine + Send + Sync),
414    model_id: &str,
415    prompt: &str,
416    max_tokens: u32,
417) -> Result<BenchResult> {
418    let request = make_request(model_id, prompt, max_tokens);
419    let stream = engine.infer_stream(request).await?;
420    collect_stream(stream).await
421}
422
423async fn collect_stream(
424    mut stream: std::pin::Pin<
425        Box<
426            dyn futures::Stream<
427                    Item = std::result::Result<
428                        ferrum_types::StreamChunk,
429                        ferrum_types::FerrumError,
430                    >,
431                > + Send,
432        >,
433    >,
434) -> Result<BenchResult> {
435    let start = Instant::now();
436    let mut token_count = 0usize;
437    let mut first_token_time: Option<f64> = None;
438
439    let mut got_finish = false;
440    while let Some(result) = stream.next().await {
441        match result {
442            Ok(chunk) => {
443                if chunk.token.is_some() {
444                    token_count += 1;
445                    if first_token_time.is_none() {
446                        first_token_time = Some(start.elapsed().as_secs_f64() * 1000.0);
447                    }
448                }
449                if chunk.finish_reason.is_some() {
450                    got_finish = true;
451                    break;
452                }
453            }
454            Err(_) => break,
455        }
456    }
457    if !got_finish && token_count > 0 {
458        eprintln!(
459            "  [warn] stream ended without finish_reason ({} tokens)",
460            token_count
461        );
462    }
463
464    Ok(BenchResult {
465        token_count,
466        ttft_ms: first_token_time.unwrap_or(0.0),
467        total_ms: start.elapsed().as_secs_f64() * 1000.0,
468    })
469}
470
471fn print_summary(
472    model_id: &str,
473    cmd: &BenchCommand,
474    mode: &str,
475    total_tokens: usize,
476    total_time_ms: f64,
477    tps_list: &[f64],
478    ttft_ms_list: &[f64],
479    tpot_ms_list: &[f64],
480    decode_tps_list: &[f64],
481) {
482    let avg_tps = if !tps_list.is_empty() {
483        tps_list.iter().sum::<f64>() / tps_list.len() as f64
484    } else {
485        0.0
486    };
487    let min_tps = tps_list.iter().cloned().fold(f64::INFINITY, f64::min);
488    let max_tps = tps_list.iter().cloned().fold(0.0_f64, f64::max);
489    let avg_ttft = if !ttft_ms_list.is_empty() {
490        ttft_ms_list.iter().sum::<f64>() / ttft_ms_list.len() as f64
491    } else {
492        0.0
493    };
494    let avg_decode_tps = if !decode_tps_list.is_empty() {
495        decode_tps_list.iter().sum::<f64>() / decode_tps_list.len() as f64
496    } else {
497        0.0
498    };
499    let avg_tpot = if !tpot_ms_list.is_empty() {
500        tpot_ms_list.iter().sum::<f64>() / tpot_ms_list.len() as f64
501    } else {
502        0.0
503    };
504
505    eprintln!();
506    eprintln!("{}", "=".repeat(60));
507    eprintln!("{}", format!("BENCHMARK RESULTS ({})", mode).bold());
508    eprintln!("{}", "=".repeat(60));
509    eprintln!("Model:             {}", model_id);
510    eprintln!(
511        "Backend:           {:?}",
512        super::run::select_device(&cmd.backend)
513    );
514    eprintln!("Rounds:            {}", cmd.rounds);
515    eprintln!("Max tokens/round:  {}", cmd.max_tokens);
516    eprintln!("{}", "-".repeat(60));
517    eprintln!(
518        "Throughput (e2e):  {:.1} tok/s avg ({:.1} min, {:.1} max)",
519        avg_tps, min_tps, max_tps
520    );
521    eprintln!("Decode only:       {:.1} tok/s avg", avg_decode_tps);
522    eprintln!(
523        "TTFT:              {:.1}ms avg, {:.1}ms p99",
524        avg_ttft,
525        percentile(ttft_ms_list, 99.0)
526    );
527    eprintln!(
528        "TPOT:              {:.2}ms avg, {:.2}ms p99",
529        avg_tpot,
530        percentile(tpot_ms_list, 99.0)
531    );
532    eprintln!("Total tokens:      {}", total_tokens);
533    eprintln!("Total time:        {:.1}ms", total_time_ms);
534    eprintln!("{}", "=".repeat(60));
535}
536
/// Build a long (~2k-token / ~8k-char) analysis prompt by repeating a fixed
/// paragraph 25 times between an instruction header and footer.
fn generate_long_prompt() -> String {
    let base = "The history of artificial intelligence is a fascinating journey through decades of research, breakthroughs, and setbacks. From the early days of symbolic AI in the 1950s, through the AI winters, to the modern era of deep learning and large language models, the field has undergone remarkable transformations. ";
    format!(
        "Please provide a comprehensive analysis of the following text, identifying key themes, patterns, and insights:\n\n{}\n\nNow analyze the above text in detail:",
        base.repeat(25)
    )
}
549
/// Nearest-rank (ceiling) percentile of `data`; returns 0.0 for an empty
/// slice. Input need not be sorted.
///
/// NaN values compare as equal to everything here, so their sorted position
/// is unspecified — callers only pass finite timing samples.
fn percentile(data: &[f64], pct: f64) -> f64 {
    if data.is_empty() {
        return 0.0;
    }
    let mut sorted = data.to_vec();
    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let last = sorted.len() - 1;
    // Ceiling of the fractional rank; clamp defensively to the last index.
    let rank = ((pct / 100.0) * last as f64).ceil() as usize;
    sorted[rank.min(last)]
}