1use crate::config::CliConfig;
11use chrono::Utc;
12use clap::Args;
13use colored::*;
14use ferrum_models::HfDownloader;
15use ferrum_types::{InferenceRequest, Priority, RequestId, Result, SamplingParams};
16use futures::StreamExt;
17use std::collections::HashMap;
18use std::time::Instant;
19use uuid::Uuid;
20
21#[derive(Args)]
22pub struct BenchCommand {
23 #[arg(default_value = "qwen3:0.6b")]
25 pub model: String,
26
27 #[arg(long, default_value = "5")]
29 pub rounds: usize,
30
31 #[arg(long, default_value = "128")]
33 pub max_tokens: u32,
34
35 #[arg(long, default_value = "auto")]
37 pub backend: String,
38
39 #[arg(long, default_value = "Explain the theory of relativity in detail.")]
41 pub prompt: String,
42
43 #[arg(long, default_value = "1")]
45 pub concurrency: usize,
46
47 #[arg(long)]
49 pub long_context: bool,
50}
51
52pub async fn execute(cmd: BenchCommand, config: CliConfig) -> Result<()> {
53 let model_id = super::run::resolve_model_alias(&cmd.model);
54 eprintln!("{}", format!("Ferrum Benchmark - {}", model_id).bold());
55 eprintln!("{}", "=".repeat(60).dimmed());
56
57 let cache_dir = super::run::get_hf_cache_dir(&config);
59 let source = match super::run::find_cached_model(&cache_dir, &model_id) {
60 Some(source) => source,
61 None => {
62 eprintln!("Downloading model...");
63 let token = std::env::var("HF_TOKEN")
64 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
65 .ok();
66 let downloader = HfDownloader::new(cache_dir.clone(), token)?;
67 let snapshot_path = downloader.download(&model_id, None).await?;
68 let format = super::run::detect_format(&snapshot_path);
69 ferrum_models::source::ResolvedModelSource {
70 original: model_id.clone(),
71 local_path: snapshot_path,
72 format,
73 from_cache: false,
74 }
75 }
76 };
77
78 unsafe {
79 std::env::set_var(
80 "FERRUM_MODEL_PATH",
81 source.local_path.to_string_lossy().to_string(),
82 );
83 }
84
85 let device = super::run::select_device(&cmd.backend);
86 eprintln!("{} {:?}", "Device:".dimmed(), device);
87
88 #[cfg(feature = "cuda")]
90 {
91 if let Ok(d) = candle_core::Device::new_cuda(0) {
92 if let Ok(cd) = d.as_cuda_device() {
93 let name = cd.cuda_stream().context().name().unwrap_or_default();
94 eprintln!("GPU 0: {name}");
95 }
96 }
97 if let Ok(d) = candle_core::Device::new_cuda(1) {
98 if let Ok(cd) = d.as_cuda_device() {
99 let name = cd.cuda_stream().context().name().unwrap_or_default();
100 eprintln!("GPU 1: {name}");
101 }
102 }
103 #[cfg(feature = "cuda")]
104 {
105 let tp = std::env::var("FERRUM_TP")
106 .ok()
107 .and_then(|v| v.parse::<usize>().ok())
108 .unwrap_or_else(|| {
109 candle_core::cuda_backend::cudarc::driver::CudaContext::device_count()
110 .map(|n| n as usize)
111 .unwrap_or(1)
112 });
113 if tp > 1 {
114 eprintln!("Tensor Parallel: TP={tp}");
115 }
116 }
117 }
118
119 let mut engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
122 engine_config.scheduler.policy = ferrum_types::SchedulingPolicy::ContinuousBatch;
123 let engine = ferrum_engine::create_mvp_engine(engine_config).await?;
124
125 let prompt = if cmd.long_context {
126 generate_long_prompt()
127 } else {
128 cmd.prompt.clone()
129 };
130
131 let mode_str = if cmd.concurrency > 1 {
132 format!("concurrent({})", cmd.concurrency)
133 } else if cmd.long_context {
134 "long-context".to_string()
135 } else {
136 "sequential".to_string()
137 };
138
139 eprintln!(
140 "{}",
141 format!(
142 "Config: {} rounds, {} max_tokens, mode={}, prompt_len=~{}chars",
143 cmd.rounds,
144 cmd.max_tokens,
145 mode_str,
146 prompt.len()
147 )
148 .dimmed()
149 );
150 eprintln!("{}", "=".repeat(60).dimmed());
151
152 eprintln!("{}", "Warmup...".dimmed());
154 let _ = run_single(&*engine, &model_id, "Hello", 16).await;
155 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
157
158 if cmd.concurrency > 1 {
159 run_concurrent_bench(&*engine, &model_id, &prompt, &cmd).await
160 } else {
161 run_sequential_bench(&*engine, &model_id, &prompt, &cmd).await
162 }
163}
164
165async fn run_sequential_bench(
168 engine: &(dyn ferrum_interfaces::InferenceEngine + Send + Sync),
169 model_id: &str,
170 prompt: &str,
171 cmd: &BenchCommand,
172) -> Result<()> {
173 let mut total_tokens: usize = 0;
174 let mut total_time_ms: f64 = 0.0;
175 let mut ttft_ms_list: Vec<f64> = Vec::new();
176 let mut tps_list: Vec<f64> = Vec::new();
177 let mut tpot_ms_list: Vec<f64> = Vec::new();
178 let mut decode_tps_list: Vec<f64> = Vec::new();
179
180 for round in 1..=cmd.rounds {
181 eprintln!("{}", format!("Round {}/{}...", round, cmd.rounds).dimmed());
182
183 let result = run_single(engine, model_id, prompt, cmd.max_tokens).await?;
184 let (tps, decode_tokens, tpot_ms, decode_tps) = compute_metrics(&result);
185
186 total_tokens += result.token_count;
187 total_time_ms += result.total_ms;
188 ttft_ms_list.push(result.ttft_ms);
189 tps_list.push(tps);
190 if decode_tokens > 0 {
191 tpot_ms_list.push(tpot_ms);
192 decode_tps_list.push(decode_tps);
193 }
194
195 eprintln!(
196 " {} tokens in {:.1}ms ({:.1} tok/s, TTFT {:.1}ms, TPOT {:.2}ms, decode {:.1} tok/s)",
197 result.token_count, result.total_ms, tps, result.ttft_ms, tpot_ms, decode_tps
198 );
199
200 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
202 }
203
204 print_summary(
205 model_id,
206 cmd,
207 "sequential",
208 total_tokens,
209 total_time_ms,
210 &tps_list,
211 &ttft_ms_list,
212 &tpot_ms_list,
213 &decode_tps_list,
214 );
215 Ok(())
216}
217
218async fn run_concurrent_bench(
221 engine: &(dyn ferrum_interfaces::InferenceEngine + Send + Sync),
222 model_id: &str,
223 prompt: &str,
224 cmd: &BenchCommand,
225) -> Result<()> {
226 let mut total_tokens: usize = 0;
227 let mut total_time_ms: f64 = 0.0;
228 let mut all_ttft: Vec<f64> = Vec::new();
229 let mut all_tpot: Vec<f64> = Vec::new();
230 let mut round_tps_list: Vec<f64> = Vec::new();
231
232 for round in 1..=cmd.rounds {
233 eprintln!(
234 "{}",
235 format!(
236 "Round {}/{} ({} concurrent)...",
237 round, cmd.rounds, cmd.concurrency
238 )
239 .dimmed()
240 );
241
242 let round_start = Instant::now();
243
244 let mut handles = Vec::with_capacity(cmd.concurrency);
246 for _ in 0..cmd.concurrency {
247 let request = make_request(model_id, prompt, cmd.max_tokens);
248 let stream = engine.infer_stream(request).await?;
249 handles.push(tokio::spawn(collect_stream(stream)));
250 }
251
252 let mut round_tokens = 0usize;
254 let mut round_results = Vec::new();
255 for handle in handles {
256 match handle.await {
257 Ok(Ok(result)) => {
258 round_tokens += result.token_count;
259 round_results.push(result);
260 }
261 Ok(Err(e)) => eprintln!(" request error: {e}"),
262 Err(e) => eprintln!(" join error: {e}"),
263 }
264 }
265
266 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
269
270 let round_ms = round_start.elapsed().as_secs_f64() * 1000.0;
271 let round_tps = if round_ms > 0.0 {
272 round_tokens as f64 / (round_ms / 1000.0)
273 } else {
274 0.0
275 };
276
277 total_tokens += round_tokens;
278 total_time_ms += round_ms;
279 round_tps_list.push(round_tps);
280
281 for r in &round_results {
283 all_ttft.push(r.ttft_ms);
284 let decode_tokens = r.token_count.saturating_sub(1);
285 if decode_tokens > 0 {
286 let tpot = (r.total_ms - r.ttft_ms) / decode_tokens as f64;
287 all_tpot.push(tpot);
288 }
289 }
290
291 let avg_ttft = round_results.iter().map(|r| r.ttft_ms).sum::<f64>()
292 / round_results.len().max(1) as f64;
293
294 eprintln!(
295 " {} requests, {} tokens in {:.1}ms ({:.1} tok/s total, avg TTFT {:.1}ms)",
296 round_results.len(),
297 round_tokens,
298 round_ms,
299 round_tps,
300 avg_ttft
301 );
302 }
303
304 let avg_tps = if !round_tps_list.is_empty() {
306 round_tps_list.iter().sum::<f64>() / round_tps_list.len() as f64
307 } else {
308 0.0
309 };
310 let avg_ttft = if !all_ttft.is_empty() {
311 all_ttft.iter().sum::<f64>() / all_ttft.len() as f64
312 } else {
313 0.0
314 };
315 let avg_tpot = if !all_tpot.is_empty() {
316 all_tpot.iter().sum::<f64>() / all_tpot.len() as f64
317 } else {
318 0.0
319 };
320
321 eprintln!();
322 eprintln!("{}", "=".repeat(60));
323 eprintln!("{}", "BENCHMARK RESULTS (concurrent)".bold());
324 eprintln!("{}", "=".repeat(60));
325 eprintln!("Model: {}", model_id);
326 eprintln!(
327 "Backend: {:?}",
328 super::run::select_device(&cmd.backend)
329 );
330 eprintln!("Rounds: {}", cmd.rounds);
331 eprintln!("Concurrency: {}", cmd.concurrency);
332 eprintln!("Max tokens/req: {}", cmd.max_tokens);
333 eprintln!("{}", "-".repeat(60));
334 eprintln!(
335 "Throughput (total): {:.1} tok/s avg ({:.1} min, {:.1} max)",
336 avg_tps,
337 round_tps_list.iter().cloned().fold(f64::INFINITY, f64::min),
338 round_tps_list.iter().cloned().fold(0.0_f64, f64::max)
339 );
340 eprintln!(
341 "TTFT: {:.1}ms avg, {:.1}ms p99",
342 avg_ttft,
343 percentile(&all_ttft, 99.0)
344 );
345 eprintln!(
346 "TPOT: {:.2}ms avg, {:.2}ms p99",
347 avg_tpot,
348 percentile(&all_tpot, 99.0)
349 );
350 eprintln!("Total tokens: {}", total_tokens);
351 eprintln!("Total time: {:.1}ms", total_time_ms);
352 eprintln!("{}", "=".repeat(60));
353
354 Ok(())
355}
356
/// Raw timing data collected from a single streamed inference request.
struct BenchResult {
    // Number of stream chunks that carried a token.
    token_count: usize,
    // Time to first token, in milliseconds (0.0 if no token ever arrived).
    ttft_ms: f64,
    // Wall-clock duration of the whole stream, in milliseconds.
    total_ms: f64,
}
364
365fn compute_metrics(r: &BenchResult) -> (f64, usize, f64, f64) {
366 let tps = if r.total_ms > 0.0 {
367 r.token_count as f64 / (r.total_ms / 1000.0)
368 } else {
369 0.0
370 };
371 let decode_tokens = r.token_count.saturating_sub(1);
372 let decode_time_ms = r.total_ms - r.ttft_ms;
373 let tpot_ms = if decode_tokens > 0 {
374 decode_time_ms / decode_tokens as f64
375 } else {
376 0.0
377 };
378 let decode_tps = if decode_time_ms > 0.0 {
379 decode_tokens as f64 / (decode_time_ms / 1000.0)
380 } else {
381 0.0
382 };
383 (tps, decode_tokens, tpot_ms, decode_tps)
384}
385
386fn make_request(model_id: &str, prompt: &str, max_tokens: u32) -> InferenceRequest {
387 InferenceRequest {
388 id: RequestId(Uuid::new_v4()),
389 model_id: ferrum_types::ModelId(model_id.to_string()),
390 prompt: prompt.to_string(),
391 sampling_params: SamplingParams {
392 max_tokens: max_tokens as usize,
393 temperature: 0.7,
394 top_p: 0.9,
395 repetition_penalty: 1.1,
396 stop_sequences: vec![
397 "<|im_end|>".to_string(),
398 "</s>".to_string(),
399 "<|endoftext|>".to_string(),
400 ],
401 ..Default::default()
402 },
403 stream: true,
404 priority: Priority::Normal,
405 client_id: None,
406 session_id: None,
407 created_at: Utc::now(),
408 metadata: HashMap::new(),
409 }
410}
411
412async fn run_single(
413 engine: &(dyn ferrum_interfaces::InferenceEngine + Send + Sync),
414 model_id: &str,
415 prompt: &str,
416 max_tokens: u32,
417) -> Result<BenchResult> {
418 let request = make_request(model_id, prompt, max_tokens);
419 let stream = engine.infer_stream(request).await?;
420 collect_stream(stream).await
421}
422
423async fn collect_stream(
424 mut stream: std::pin::Pin<
425 Box<
426 dyn futures::Stream<
427 Item = std::result::Result<
428 ferrum_types::StreamChunk,
429 ferrum_types::FerrumError,
430 >,
431 > + Send,
432 >,
433 >,
434) -> Result<BenchResult> {
435 let start = Instant::now();
436 let mut token_count = 0usize;
437 let mut first_token_time: Option<f64> = None;
438
439 let mut got_finish = false;
440 while let Some(result) = stream.next().await {
441 match result {
442 Ok(chunk) => {
443 if chunk.token.is_some() {
444 token_count += 1;
445 if first_token_time.is_none() {
446 first_token_time = Some(start.elapsed().as_secs_f64() * 1000.0);
447 }
448 }
449 if chunk.finish_reason.is_some() {
450 got_finish = true;
451 break;
452 }
453 }
454 Err(_) => break,
455 }
456 }
457 if !got_finish && token_count > 0 {
458 eprintln!(
459 " [warn] stream ended without finish_reason ({} tokens)",
460 token_count
461 );
462 }
463
464 Ok(BenchResult {
465 token_count,
466 ttft_ms: first_token_time.unwrap_or(0.0),
467 total_ms: start.elapsed().as_secs_f64() * 1000.0,
468 })
469}
470
471fn print_summary(
472 model_id: &str,
473 cmd: &BenchCommand,
474 mode: &str,
475 total_tokens: usize,
476 total_time_ms: f64,
477 tps_list: &[f64],
478 ttft_ms_list: &[f64],
479 tpot_ms_list: &[f64],
480 decode_tps_list: &[f64],
481) {
482 let avg_tps = if !tps_list.is_empty() {
483 tps_list.iter().sum::<f64>() / tps_list.len() as f64
484 } else {
485 0.0
486 };
487 let min_tps = tps_list.iter().cloned().fold(f64::INFINITY, f64::min);
488 let max_tps = tps_list.iter().cloned().fold(0.0_f64, f64::max);
489 let avg_ttft = if !ttft_ms_list.is_empty() {
490 ttft_ms_list.iter().sum::<f64>() / ttft_ms_list.len() as f64
491 } else {
492 0.0
493 };
494 let avg_decode_tps = if !decode_tps_list.is_empty() {
495 decode_tps_list.iter().sum::<f64>() / decode_tps_list.len() as f64
496 } else {
497 0.0
498 };
499 let avg_tpot = if !tpot_ms_list.is_empty() {
500 tpot_ms_list.iter().sum::<f64>() / tpot_ms_list.len() as f64
501 } else {
502 0.0
503 };
504
505 eprintln!();
506 eprintln!("{}", "=".repeat(60));
507 eprintln!("{}", format!("BENCHMARK RESULTS ({})", mode).bold());
508 eprintln!("{}", "=".repeat(60));
509 eprintln!("Model: {}", model_id);
510 eprintln!(
511 "Backend: {:?}",
512 super::run::select_device(&cmd.backend)
513 );
514 eprintln!("Rounds: {}", cmd.rounds);
515 eprintln!("Max tokens/round: {}", cmd.max_tokens);
516 eprintln!("{}", "-".repeat(60));
517 eprintln!(
518 "Throughput (e2e): {:.1} tok/s avg ({:.1} min, {:.1} max)",
519 avg_tps, min_tps, max_tps
520 );
521 eprintln!("Decode only: {:.1} tok/s avg", avg_decode_tps);
522 eprintln!(
523 "TTFT: {:.1}ms avg, {:.1}ms p99",
524 avg_ttft,
525 percentile(ttft_ms_list, 99.0)
526 );
527 eprintln!(
528 "TPOT: {:.2}ms avg, {:.2}ms p99",
529 avg_tpot,
530 percentile(tpot_ms_list, 99.0)
531 );
532 eprintln!("Total tokens: {}", total_tokens);
533 eprintln!("Total time: {:.1}ms", total_time_ms);
534 eprintln!("{}", "=".repeat(60));
535}
536
/// Build the synthetic long-context prompt: an analysis instruction, the
/// same filler paragraph repeated 25 times, then a closing instruction.
fn generate_long_prompt() -> String {
    let filler = "The history of artificial intelligence is a fascinating journey through decades of research, breakthroughs, and setbacks. From the early days of symbolic AI in the 1950s, through the AI winters, to the modern era of deep learning and large language models, the field has undergone remarkable transformations. ";
    let mut prompt = String::with_capacity(8192);
    prompt.push_str("Please provide a comprehensive analysis of the following text, identifying key themes, patterns, and insights:\n\n");
    prompt.push_str(&filler.repeat(25));
    prompt.push_str("\n\nNow analyze the above text in detail:");
    prompt
}
549
/// Nearest-rank (ceiling) percentile of `data`; returns 0.0 for an empty
/// slice. NaN values compare as equal, matching the original ordering.
fn percentile(data: &[f64], pct: f64) -> f64 {
    if data.is_empty() {
        return 0.0;
    }
    let mut ranked = data.to_vec();
    // sort_unstable is fine: equal f64 values are indistinguishable.
    ranked.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let last = ranked.len() - 1;
    let rank = ((pct / 100.0) * last as f64).ceil() as usize;
    ranked[rank.min(last)]
}