// axonml_train — benchmark.rs
1//! Model Benchmarking Utilities
2//!
3//! # File
4//! `crates/axonml-train/src/benchmark.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use std::time::Instant;
19
20use axonml_autograd::Variable;
21
22use axonml_nn::Module;
23
24use crate::hub::BenchmarkResult;
25
26// =============================================================================
27// Benchmarking Functions
28// =============================================================================
29
30/// Warm up a model by running a few forward passes.
31///
32/// This helps stabilize timing measurements by ensuring any lazy initialization
33/// is complete and caches are populated.
34pub fn warmup_model<M: Module>(model: &M, input: &Variable, iterations: usize) {
35    for _ in 0..iterations {
36        let _ = model.forward(input);
37    }
38}
39
40/// Benchmark model inference.
41///
42/// Runs the model forward pass multiple times and collects timing statistics.
43pub fn benchmark_model<M: Module>(
44    model: &M,
45    input: &Variable,
46    iterations: usize,
47) -> BenchmarkResult {
48    let mut latencies = Vec::with_capacity(iterations);
49
50    for _ in 0..iterations {
51        let start = Instant::now();
52        let _ = model.forward(input);
53        let elapsed = start.elapsed();
54        latencies.push(elapsed.as_secs_f64() * 1000.0);
55    }
56
57    // Estimate memory usage from input/output sizes
58    let input_elements: usize = input.data().shape().iter().product();
59    let peak_memory = (input_elements * 4 * 3) as u64; // Rough estimate: input + output + intermediate
60
61    BenchmarkResult::new("model", &latencies, peak_memory)
62}
63
64/// Benchmark model with custom name.
65pub fn benchmark_model_named<M: Module>(
66    model: &M,
67    input: &Variable,
68    iterations: usize,
69    name: &str,
70) -> BenchmarkResult {
71    let mut result = benchmark_model(model, input, iterations);
72    result.model_name = name.to_string();
73    result
74}
75
76/// Compare multiple models on the same input.
77pub fn compare_models<M: Module>(
78    models: &[(&str, &M)],
79    input: &Variable,
80    iterations: usize,
81) -> Vec<BenchmarkResult> {
82    let mut results = Vec::new();
83
84    for (name, model) in models {
85        // Warmup
86        warmup_model(*model, input, 5);
87
88        // Benchmark
89        let result = benchmark_model_named(*model, input, iterations, name);
90        results.push(result);
91    }
92
93    results
94}
95
96// =============================================================================
97// Throughput Testing
98// =============================================================================
99
/// Configuration for throughput testing across a sweep of batch sizes.
#[derive(Debug, Clone)]
pub struct ThroughputConfig {
    /// Batch sizes to test (each gets its own warmup + timed run)
    pub batch_sizes: Vec<usize>,
    /// Number of timed iterations per batch size
    pub iterations: usize,
    /// Untimed warmup iterations run before measuring each batch size
    pub warmup: usize,
}
110
111impl Default for ThroughputConfig {
112    fn default() -> Self {
113        Self {
114            batch_sizes: vec![1, 4, 8, 16, 32, 64],
115            iterations: 50,
116            warmup: 5,
117        }
118    }
119}
120
/// Throughput test result for a single batch size.
#[derive(Debug, Clone)]
pub struct ThroughputResult {
    /// Batch size this row was measured at
    pub batch_size: usize,
    /// Samples per second (per-call throughput scaled by batch size)
    pub throughput: f64,
    /// Average latency of one forward pass, in milliseconds
    pub latency_ms: f64,
    /// Average latency divided by batch size, in milliseconds per sample
    pub latency_per_sample_ms: f64,
}
133
134/// Run throughput tests across different batch sizes.
135pub fn throughput_test<M, F>(
136    model: &M,
137    input_fn: F,
138    config: &ThroughputConfig,
139) -> Vec<ThroughputResult>
140where
141    M: Module,
142    F: Fn(usize) -> Variable,
143{
144    let mut results = Vec::new();
145
146    for &batch_size in &config.batch_sizes {
147        let input = input_fn(batch_size);
148
149        // Warmup
150        warmup_model(model, &input, config.warmup);
151
152        // Benchmark
153        let bench = benchmark_model(model, &input, config.iterations);
154
155        results.push(ThroughputResult {
156            batch_size,
157            throughput: bench.throughput * batch_size as f64,
158            latency_ms: bench.avg_latency_ms,
159            latency_per_sample_ms: bench.avg_latency_ms / batch_size as f64,
160        });
161    }
162
163    results
164}
165
166/// Print throughput results in a table format.
167pub fn print_throughput_results(results: &[ThroughputResult]) {
168    println!(
169        "\n{:<12} {:>14} {:>14} {:>18}",
170        "Batch Size", "Throughput", "Latency (ms)", "Per Sample (ms)"
171    );
172    println!("{}", "-".repeat(60));
173
174    for result in results {
175        println!(
176            "{:<12} {:>12.1}/s {:>14.2} {:>18.3}",
177            result.batch_size, result.throughput, result.latency_ms, result.latency_per_sample_ms
178        );
179    }
180
181    // Find optimal batch size (highest throughput)
182    if let Some(best) = results.iter().max_by(|a, b| {
183        a.throughput
184            .partial_cmp(&b.throughput)
185            .unwrap_or(std::cmp::Ordering::Equal)
186    }) {
187        println!(
188            "\nOptimal batch size: {} ({:.1} samples/sec)",
189            best.batch_size, best.throughput
190        );
191    }
192}
193
194// =============================================================================
195// Memory Profiling
196// =============================================================================
197
/// Memory usage snapshot.
///
/// `Default` yields an all-zero snapshot.
#[derive(Debug, Clone, Default)]
pub struct MemorySnapshot {
    /// Tensor allocations count (not currently tracked;
    /// `profile_model_memory` always reports 0 here)
    pub tensor_count: usize,
    /// Total tensor memory in bytes
    pub tensor_bytes: u64,
    /// Parameter tensor count
    pub param_count: usize,
    /// Parameter memory in bytes (computed assuming 4-byte f32 elements)
    pub param_bytes: u64,
}
210
211impl MemorySnapshot {
212    /// Total memory in MB.
213    pub fn total_mb(&self) -> f64 {
214        (self.tensor_bytes + self.param_bytes) as f64 / 1_000_000.0
215    }
216}
217
218/// Profile memory usage of a model.
219pub fn profile_model_memory<M: Module>(model: &M) -> MemorySnapshot {
220    let params = model.parameters();
221    let param_count = params.len();
222
223    let param_bytes: u64 = params
224        .iter()
225        .map(|p| (p.numel() * 4) as u64) // 4 bytes per f32
226        .sum();
227
228    MemorySnapshot {
229        tensor_count: 0,
230        tensor_bytes: 0,
231        param_count,
232        param_bytes,
233    }
234}
235
236/// Print memory profile.
237pub fn print_memory_profile(snapshot: &MemorySnapshot, name: &str) {
238    println!("\nMemory Profile: {}", name);
239    println!(
240        "  Parameters: {} ({:.2} MB)",
241        snapshot.param_count,
242        snapshot.param_bytes as f64 / 1_000_000.0
243    );
244    println!(
245        "  Tensors: {} ({:.2} MB)",
246        snapshot.tensor_count,
247        snapshot.tensor_bytes as f64 / 1_000_000.0
248    );
249    println!("  Total: {:.2} MB", snapshot.total_mb());
250}
251
252// =============================================================================
253// Tests
254// =============================================================================
255
#[cfg(test)]
mod tests {
    use super::*;

    // Default config must offer a non-trivial sweep and positive work.
    #[test]
    fn test_throughput_config_default() {
        let config = ThroughputConfig::default();
        assert!(!config.batch_sizes.is_empty());
        assert!(config.iterations > 0);
    }

    // 1,000,000 + 500,000 bytes => 1.5 MB (decimal megabytes).
    #[test]
    fn test_memory_snapshot_total() {
        let snapshot = MemorySnapshot {
            tensor_count: 10,
            tensor_bytes: 1_000_000,
            param_count: 5,
            param_bytes: 500_000,
        };
        assert!((snapshot.total_mb() - 1.5).abs() < 0.01);
    }

    // Plain construction/field sanity check for ThroughputResult.
    #[test]
    fn test_throughput_result() {
        let result = ThroughputResult {
            batch_size: 32,
            throughput: 1000.0,
            latency_ms: 32.0,
            latency_per_sample_ms: 1.0,
        };
        assert_eq!(result.batch_size, 32);
        assert!((result.latency_per_sample_ms - 1.0).abs() < 0.01);
    }

    // End-to-end smoke test: warmup + benchmark on a small Linear layer.
    #[test]
    fn test_benchmark_model() {
        use axonml_nn::Linear;
        use axonml_tensor::Tensor;

        let model = Linear::new(10, 5);
        let input = Variable::new(Tensor::randn(&[4, 10]), false);

        warmup_model(&model, &input, 2);
        let result = benchmark_model(&model, &input, 10);

        assert_eq!(result.iterations, 10);
        assert!(result.avg_latency_ms >= 0.0);
        assert!(result.throughput > 0.0);
    }

    // Parameter accounting should be non-zero for a real layer.
    #[test]
    fn test_profile_model_memory() {
        use axonml_nn::Linear;

        let model = Linear::new(100, 50);
        let snapshot = profile_model_memory(&model);

        assert!(snapshot.param_count > 0);
        assert!(snapshot.param_bytes > 0);
    }
}