kizzasi_model/profiling.rs

//! Model Profiling and Benchmarking Utilities
//!
//! Provides tools for measuring model performance:
//! - **Latency measurement**: Per-step inference timing with percentile statistics
//! - **Memory profiling**: Estimate state and weight memory from model dimensions
//! - **Throughput analysis**: Steps per second over a profiling run
//! - **Bottleneck identification**: Heuristic detection of latency, variance, and memory issues
//!
//! # Example
//!
//! ```rust,ignore
//! use kizzasi_model::profiling::ModelProfiler;
//! use kizzasi_model::mamba::Mamba;
//!
//! let model = Mamba::new(config)?;
//! let mut profiler = ModelProfiler::new(model);
//!
//! let results = profiler.profile_inference(num_steps, input_dim)?;
//! println!("Average latency: {:.2}μs", results.avg_latency_us);
//! ```

use crate::error::ModelResult;
use crate::AutoregressiveModel;
use scirs2_core::ndarray::Array1;
use std::time::{Duration, Instant};

/// Profiling results for model inference
#[derive(Debug, Clone)]
pub struct ProfilingResults {
    /// Total number of steps profiled
    pub num_steps: usize,
    /// Total execution time
    pub total_duration: Duration,
    /// Average latency per step
    pub avg_latency_us: f64,
    /// Minimum latency observed
    pub min_latency_us: f64,
    /// Maximum latency observed
    pub max_latency_us: f64,
    /// Median latency
    pub median_latency_us: f64,
    /// 95th percentile latency
    pub p95_latency_us: f64,
    /// 99th percentile latency
    pub p99_latency_us: f64,
    /// Throughput (steps per second)
    pub throughput_steps_per_sec: f64,
    /// Standard deviation of latency
    pub std_dev_us: f64,
}

impl ProfilingResults {
    /// Create results from timing measurements
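    ///
    /// # Example
    ///
    /// A minimal sketch computing statistics from hand-made timings:
    ///
    /// ```rust,ignore
    /// use std::time::Duration;
    ///
    /// let timings = vec![Duration::from_micros(100); 10];
    /// let results = ProfilingResults::from_timings(&timings);
    /// assert_eq!(results.num_steps, 10);
    /// ```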
    pub fn from_timings(timings: &[Duration]) -> Self {
        let num_steps = timings.len();

        if num_steps == 0 {
            return Self::default();
        }

        // Convert to microseconds
        let mut latencies_us: Vec<f64> = timings
            .iter()
            .map(|d| d.as_secs_f64() * 1_000_000.0)
            .collect();

        latencies_us.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let total_duration: Duration = timings.iter().sum();
        let total_us = total_duration.as_secs_f64() * 1_000_000.0;
        let avg_latency_us = total_us / num_steps as f64;

        let min_latency_us = latencies_us[0];
        let max_latency_us = latencies_us[num_steps - 1];

        let median_latency_us = if num_steps.is_multiple_of(2) {
            (latencies_us[num_steps / 2 - 1] + latencies_us[num_steps / 2]) / 2.0
        } else {
            latencies_us[num_steps / 2]
        };

        let p95_idx = ((num_steps as f64 * 0.95) as usize).min(num_steps - 1);
        let p95_latency_us = latencies_us[p95_idx];

        let p99_idx = ((num_steps as f64 * 0.99) as usize).min(num_steps - 1);
        let p99_latency_us = latencies_us[p99_idx];

        let throughput_steps_per_sec = num_steps as f64 / total_duration.as_secs_f64();

        // Calculate standard deviation
        let variance = latencies_us
            .iter()
            .map(|&x| (x - avg_latency_us).powi(2))
            .sum::<f64>()
            / num_steps as f64;
        let std_dev_us = variance.sqrt();

        Self {
            num_steps,
            total_duration,
            avg_latency_us,
            min_latency_us,
            max_latency_us,
            median_latency_us,
            p95_latency_us,
            p99_latency_us,
            throughput_steps_per_sec,
            std_dev_us,
        }
    }

    /// Format results as a human-readable string
    pub fn format_report(&self) -> String {
        format!(
            "Profiling Results:\n\
             ==================\n\
             Steps:              {}\n\
             Total Duration:     {:.2}ms\n\
             Average Latency:    {:.2}μs\n\
             Min Latency:        {:.2}μs\n\
             Max Latency:        {:.2}μs\n\
             Median Latency:     {:.2}μs\n\
             P95 Latency:        {:.2}μs\n\
             P99 Latency:        {:.2}μs\n\
             Std Dev:            {:.2}μs\n\
             Throughput:         {:.2} steps/sec\n",
            self.num_steps,
            self.total_duration.as_secs_f64() * 1000.0,
            self.avg_latency_us,
            self.min_latency_us,
            self.max_latency_us,
            self.median_latency_us,
            self.p95_latency_us,
            self.p99_latency_us,
            self.std_dev_us,
            self.throughput_steps_per_sec,
        )
    }
}

impl Default for ProfilingResults {
    fn default() -> Self {
        Self {
            num_steps: 0,
            total_duration: Duration::ZERO,
            avg_latency_us: 0.0,
            min_latency_us: 0.0,
            max_latency_us: 0.0,
            median_latency_us: 0.0,
            p95_latency_us: 0.0,
            p99_latency_us: 0.0,
            throughput_steps_per_sec: 0.0,
            std_dev_us: 0.0,
        }
    }
}

/// Model profiler for performance measurement
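///
/// # Example
///
/// A minimal usage sketch (using the `Mamba` model from this crate; any
/// `AutoregressiveModel` works the same way):
///
/// ```rust,ignore
/// use kizzasi_model::mamba::{Mamba, MambaConfig};
/// use kizzasi_model::profiling::ModelProfiler;
///
/// let config = MambaConfig::default().hidden_dim(64).num_layers(2);
/// let model = Mamba::new(config)?;
///
/// // Builder-style configuration: 5 warmup steps before measurement
/// let mut profiler = ModelProfiler::new(model).warmup_steps(5);
/// let results = profiler.profile_inference(100, 1)?;
/// println!("{}", results.format_report());
/// ```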
pub struct ModelProfiler<M: AutoregressiveModel> {
    model: M,
    warmup_steps: usize,
}

impl<M: AutoregressiveModel> ModelProfiler<M> {
    /// Create a new profiler
    pub fn new(model: M) -> Self {
        Self {
            model,
            warmup_steps: 10,
        }
    }

    /// Set number of warmup steps (default: 10)
    pub fn warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Profile single-step inference
    ///
    /// # Arguments
    ///
    /// * `num_steps` - Number of inference steps to measure
    /// * `input_dim` - Input dimension
    ///
    /// # Returns
    ///
    /// Profiling results with timing statistics
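    ///
    /// # Example
    ///
    /// A short sketch (the model is assumed to accept a 1-dimensional input,
    /// as in this module's tests):
    ///
    /// ```rust,ignore
    /// let mut profiler = ModelProfiler::new(model).warmup_steps(5);
    /// let results = profiler.profile_inference(1_000, 1)?;
    /// assert_eq!(results.num_steps, 1_000);
    /// println!("p95: {:.2}μs", results.p95_latency_us);
    /// ```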
    pub fn profile_inference(
        &mut self,
        num_steps: usize,
        input_dim: usize,
    ) -> ModelResult<ProfilingResults> {
        // Warmup
        let warmup_input = Array1::zeros(input_dim);
        for _ in 0..self.warmup_steps {
            let _ = self.model.step(&warmup_input)?;
        }

        // Reset state for clean measurements
        self.model.reset();

        // Measure inference
        let mut timings = Vec::with_capacity(num_steps);
        let input = Array1::from_elem(input_dim, 1.0);

        for _ in 0..num_steps {
            let start = Instant::now();
            let _ = self.model.step(&input)?;
            timings.push(start.elapsed());
        }

        Ok(ProfilingResults::from_timings(&timings))
    }

    /// Profile with varying input sizes
    ///
    /// Useful for understanding how performance scales with input dimension
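    ///
    /// # Example
    ///
    /// A minimal sketch; note that models whose input dimension is fixed by
    /// their config (as in this crate's tests) should instead be rebuilt per
    /// dimension:
    ///
    /// ```rust,ignore
    /// let scaling = profiler.profile_input_scaling(&[1, 8, 64], 100)?;
    /// for (dim, results) in &scaling {
    ///     println!("dim {:3}: {:.2}μs avg", dim, results.avg_latency_us);
    /// }
    /// ```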
    pub fn profile_input_scaling(
        &mut self,
        input_dims: &[usize],
        steps_per_dim: usize,
    ) -> ModelResult<Vec<(usize, ProfilingResults)>> {
        let mut results = Vec::with_capacity(input_dims.len());

        for &dim in input_dims {
            self.model.reset();
            let profile = self.profile_inference(steps_per_dim, dim)?;
            results.push((dim, profile));
        }

        Ok(results)
    }

    /// Profile memory usage (estimates based on model dimensions)
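    ///
    /// # Example
    ///
    /// A quick sketch of reading the estimate (no inference is run):
    ///
    /// ```rust,ignore
    /// let profiler = ModelProfiler::new(model);
    /// let memory = profiler.estimate_memory_usage();
    /// println!("{}", memory.format_report());
    /// assert!(memory.total_estimated_bytes > 0);
    /// ```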
    pub fn estimate_memory_usage(&self) -> MemoryProfile {
        let hidden_dim = self.model.hidden_dim();
        let state_dim = self.model.state_dim();
        let num_layers = self.model.num_layers();

        // Estimate per-layer state size (in bytes)
        // Typical SSM: hidden_dim * state_dim * 4 bytes (f32)
        let state_bytes_per_layer = hidden_dim * state_dim * 4;
        let total_state_bytes = state_bytes_per_layer * num_layers;

        // Estimate weight memory (rough approximation)
        // Typical: multiple weight matrices per layer
        let weight_estimate = hidden_dim * hidden_dim * 4 * num_layers * 5; // ~5 matrices per layer

        MemoryProfile {
            hidden_dim,
            state_dim,
            num_layers,
            state_memory_bytes: total_state_bytes,
            estimated_weight_memory_bytes: weight_estimate,
            total_estimated_bytes: total_state_bytes + weight_estimate,
        }
    }

    /// Get reference to underlying model
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Get mutable reference to underlying model
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }

    /// Consume profiler and return model
    pub fn into_model(self) -> M {
        self.model
    }
}

/// Memory usage profile
#[derive(Debug, Clone)]
pub struct MemoryProfile {
    /// Hidden dimension
    pub hidden_dim: usize,
    /// State dimension
    pub state_dim: usize,
    /// Number of layers
    pub num_layers: usize,
    /// Memory for hidden states (bytes)
    pub state_memory_bytes: usize,
    /// Estimated weight memory (bytes)
    pub estimated_weight_memory_bytes: usize,
    /// Total estimated memory (bytes)
    pub total_estimated_bytes: usize,
}

impl MemoryProfile {
    /// Format memory profile as human-readable string
    pub fn format_report(&self) -> String {
        format!(
            "Memory Profile:\n\
             ===============\n\
             Hidden Dim:         {}\n\
             State Dim:          {}\n\
             Num Layers:         {}\n\
             State Memory:       {:.2} MB\n\
             Weight Memory:      {:.2} MB (estimated)\n\
             Total Memory:       {:.2} MB (estimated)\n",
            self.hidden_dim,
            self.state_dim,
            self.num_layers,
            self.state_memory_bytes as f64 / 1_048_576.0,
            self.estimated_weight_memory_bytes as f64 / 1_048_576.0,
            self.total_estimated_bytes as f64 / 1_048_576.0,
        )
    }
}

/// Benchmark suite for comparing models
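///
/// # Example
///
/// A minimal sketch mirroring this module's tests (`Mamba` stands in for any
/// `AutoregressiveModel`):
///
/// ```rust,ignore
/// use kizzasi_model::mamba::{Mamba, MambaConfig};
/// use kizzasi_model::profiling::BenchmarkSuite;
///
/// let config = MambaConfig::default().hidden_dim(64).num_layers(2);
/// let model = Mamba::new(config)?;
///
/// let suite = BenchmarkSuite::new().num_steps(100).warmup_steps(5);
/// let results = suite.benchmark(model)?;
/// println!("{:.1} steps/sec", results.throughput_steps_per_sec);
/// ```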
pub struct BenchmarkSuite {
    num_steps: usize,
    warmup_steps: usize,
    input_dim: usize,
}

impl BenchmarkSuite {
    /// Create a new benchmark suite
    pub fn new() -> Self {
        Self {
            num_steps: 1000,
            warmup_steps: 10,
            input_dim: 1,
        }
    }

    /// Set number of benchmark steps
    pub fn num_steps(mut self, steps: usize) -> Self {
        self.num_steps = steps;
        self
    }

    /// Set number of warmup steps
    pub fn warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Set input dimension
    pub fn input_dim(mut self, dim: usize) -> Self {
        self.input_dim = dim;
        self
    }

    /// Run benchmark on a model
    pub fn benchmark<M: AutoregressiveModel>(&self, model: M) -> ModelResult<ProfilingResults> {
        let mut profiler = ModelProfiler::new(model).warmup_steps(self.warmup_steps);
        profiler.profile_inference(self.num_steps, self.input_dim)
    }
}

impl Default for BenchmarkSuite {
    fn default() -> Self {
        Self::new()
    }
}

/// Bottleneck severity level
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum BottleneckSeverity {
    /// Low severity - minor optimization opportunity
    Low,
    /// Medium severity - noticeable performance impact
    Medium,
    /// High severity - significant performance bottleneck
    High,
    /// Critical severity - major performance issue
    Critical,
}

/// Information about a specific bottleneck
#[derive(Debug, Clone)]
pub struct BottleneckInfo {
    /// Name of the bottleneck (e.g., "Matrix multiplication")
    pub name: String,
    /// Description of the issue
    pub description: String,
    /// Estimated time spent (microseconds)
    pub estimated_time_us: f64,
    /// Percentage of total execution time
    pub percentage_of_total: f64,
    /// Severity level
    pub severity: BottleneckSeverity,
    /// Recommended optimizations
    pub recommendations: Vec<String>,
}

/// Bottleneck analysis for a single model
#[derive(Debug, Clone)]
pub struct ModelBottleneckAnalysis {
    /// Model type name
    pub model_name: String,
    /// Profiling results
    pub results: ProfilingResults,
    /// Memory profile
    pub memory: MemoryProfile,
    /// Identified bottlenecks
    pub bottlenecks: Vec<BottleneckInfo>,
    /// Overall performance rating (0-100, higher is better)
    pub performance_score: f64,
}

impl ModelBottleneckAnalysis {
    /// Analyze a model for bottlenecks
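    ///
    /// # Example
    ///
    /// A short sketch; the name labels the report and enables model-specific
    /// checks (e.g. `"Transformer"`):
    ///
    /// ```rust,ignore
    /// let analysis = ModelBottleneckAnalysis::analyze(model, "Mamba".to_string(), 1_000)?;
    /// println!("{}", analysis.format_report());
    /// println!("score: {:.1}/100", analysis.performance_score);
    /// ```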
    pub fn analyze<M: AutoregressiveModel>(
        model: M,
        model_name: String,
        num_steps: usize,
    ) -> ModelResult<Self> {
        let mut profiler = ModelProfiler::new(model).warmup_steps(10);

        // Profile inference
        let results = profiler.profile_inference(num_steps, 1)?;

        // Get memory profile
        let memory = profiler.estimate_memory_usage();

        // Identify bottlenecks based on performance characteristics
        let bottlenecks = Self::identify_bottlenecks(&results, &memory, &model_name);

        // Calculate performance score (0-100)
        let performance_score = Self::calculate_performance_score(&results, &memory);

        Ok(Self {
            model_name,
            results,
            memory,
            bottlenecks,
            performance_score,
        })
    }

    /// Identify bottlenecks from profiling results
    fn identify_bottlenecks(
        results: &ProfilingResults,
        memory: &MemoryProfile,
        model_name: &str,
    ) -> Vec<BottleneckInfo> {
        let mut bottlenecks = Vec::new();

        // Check latency bottlenecks
        if results.avg_latency_us > 200.0 {
            let severity = if results.avg_latency_us > 1000.0 {
                BottleneckSeverity::Critical
            } else if results.avg_latency_us > 500.0 {
                BottleneckSeverity::High
            } else {
                BottleneckSeverity::Medium
            };

            bottlenecks.push(BottleneckInfo {
                name: "High average latency".to_string(),
                description: format!(
                    "Average latency of {:.2}μs exceeds target of 200μs",
                    results.avg_latency_us
                ),
                estimated_time_us: results.avg_latency_us,
                percentage_of_total: 100.0,
                severity,
                recommendations: vec![
                    "Consider using SIMD optimizations".to_string(),
                    "Enable parallel processing for multi-head operations".to_string(),
                    "Use cache-friendly memory layouts".to_string(),
                ],
            });
        }

        // Check latency variance (high variance indicates unstable performance)
        let cv = results.std_dev_us / results.avg_latency_us; // Coefficient of variation
        if cv > 0.5 {
            bottlenecks.push(BottleneckInfo {
                name: "High latency variance".to_string(),
                description: format!(
                    "Standard deviation {:.2}μs is {:.1}% of mean, indicating unstable performance",
                    results.std_dev_us,
                    cv * 100.0
                ),
                estimated_time_us: results.std_dev_us,
                percentage_of_total: cv * 100.0,
                severity: if cv > 1.0 {
                    BottleneckSeverity::High
                } else {
                    BottleneckSeverity::Medium
                },
                recommendations: vec![
                    "Investigate cache misses and memory allocation patterns".to_string(),
                    "Use memory pooling to reduce allocation variance".to_string(),
                ],
            });
        }

        // Check memory usage
        let memory_mb = memory.total_estimated_bytes as f64 / (1024.0 * 1024.0);
        if memory_mb > 100.0 {
            bottlenecks.push(BottleneckInfo {
                name: "High memory usage".to_string(),
                description: format!("Estimated memory usage of {:.2}MB is high", memory_mb),
                estimated_time_us: 0.0,
                percentage_of_total: 0.0,
                severity: if memory_mb > 500.0 {
                    BottleneckSeverity::High
                } else {
                    BottleneckSeverity::Medium
                },
                recommendations: vec![
                    "Consider quantization (INT8/FP16) to reduce memory footprint".to_string(),
                    "Use sparse representations where applicable".to_string(),
                ],
            });
        }

        // Model-specific bottleneck identification
        if model_name.contains("Transformer") {
            // Transformer-specific bottlenecks
            if results.avg_latency_us > 500.0 {
                bottlenecks.push(BottleneckInfo {
                    name: "Quadratic attention complexity".to_string(),
                    description: "Transformer attention cost grows with context length, O(N²) over a full sequence"
                        .to_string(),
                    estimated_time_us: results.avg_latency_us * 0.7, // Estimate 70% in attention
                    percentage_of_total: 70.0,
                    severity: BottleneckSeverity::High,
                    recommendations: vec![
                        "Consider using linear attention variants (e.g., Performers)".to_string(),
                        "Use Flash Attention for memory-efficient attention".to_string(),
                        "Switch to SSM-based models (Mamba, RWKV) for O(1) inference".to_string(),
                    ],
                });
            }
        }

        bottlenecks
    }

    /// Calculate overall performance score (0-100)
    fn calculate_performance_score(results: &ProfilingResults, memory: &MemoryProfile) -> f64 {
        // Target: <100μs latency, <50MB memory
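        //
        // Worked example (illustrative numbers, not measured): with an average
        // latency of 100μs, 50MB estimated memory, and a coefficient of
        // variation of 0.2:
        //   latency_score   = 100 * (100 / (100 + 100)) = 50.0
        //   memory_score    = 100 * (50 / (50 + 50))     = 50.0
        //   stability_score = 100 * (1 / (1 + 0.2))      ≈ 83.3
        //   score           = 0.5*50 + 0.3*50 + 0.2*83.3 ≈ 56.7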
        let latency_score = 100.0 * (100.0 / (results.avg_latency_us + 100.0));
        let memory_mb = memory.total_estimated_bytes as f64 / (1024.0 * 1024.0);
        let memory_score = 100.0 * (50.0 / (memory_mb + 50.0));

        // Variance penalty
        let cv = results.std_dev_us / results.avg_latency_us;
        let stability_score = 100.0 * (1.0 / (1.0 + cv));

        // Weighted average
        (latency_score * 0.5 + memory_score * 0.3 + stability_score * 0.2).min(100.0)
    }

    /// Format analysis as a detailed report
    pub fn format_report(&self) -> String {
        let mut report = String::new();

        report.push_str("\n═══════════════════════════════════════\n");
        report.push_str(&format!("  {} Analysis Report\n", self.model_name));
        report.push_str("═══════════════════════════════════════\n\n");

        // Performance score
        report.push_str(&format!(
            "Performance Score: {:.1}/100\n\n",
            self.performance_score
        ));

        // Profiling results
        report.push_str(&self.results.format_report());
        report.push('\n');

        // Memory profile
        report.push_str(&self.memory.format_report());
        report.push('\n');

        // Bottlenecks
        if self.bottlenecks.is_empty() {
            report.push_str("✓ No significant bottlenecks identified!\n");
        } else {
            report.push_str(&format!(
                "⚠ {} Bottleneck(s) Identified:\n\n",
                self.bottlenecks.len()
            ));

            for (i, bottleneck) in self.bottlenecks.iter().enumerate() {
                let severity_icon = match bottleneck.severity {
                    BottleneckSeverity::Low => "ℹ",
                    BottleneckSeverity::Medium => "⚠",
                    BottleneckSeverity::High => "⚠⚠",
                    BottleneckSeverity::Critical => "🔥",
                };

                report.push_str(&format!(
                    "{}. {} {}\n",
                    i + 1,
                    severity_icon,
                    bottleneck.name
                ));
                report.push_str(&format!("   {}\n", bottleneck.description));

                if bottleneck.estimated_time_us > 0.0 {
                    report.push_str(&format!(
                        "   Time: {:.2}μs ({:.1}% of total)\n",
                        bottleneck.estimated_time_us, bottleneck.percentage_of_total
                    ));
                }

                report.push_str("   Recommendations:\n");
                for rec in &bottleneck.recommendations {
                    report.push_str(&format!("     • {}\n", rec));
                }
                report.push('\n');
            }
        }

        report
    }
}

/// Comprehensive comparison of all models
#[derive(Debug, Clone)]
pub struct ComprehensiveComparison {
    /// Analyses for each model
    pub analyses: Vec<ModelBottleneckAnalysis>,
    /// Best performing model by latency
    pub fastest_model: String,
    /// Most memory efficient model
    pub most_memory_efficient: String,
    /// Overall best model (by performance score)
    pub best_overall: String,
}

impl ComprehensiveComparison {
    /// Generate comparison report
    pub fn format_report(&self) -> String {
        let mut report = String::new();

        report.push('\n');
        report.push_str("╔═══════════════════════════════════════════════════════════════╗\n");
        report.push_str("║      COMPREHENSIVE MODEL PERFORMANCE COMPARISON               ║\n");
        report.push_str("╚═══════════════════════════════════════════════════════════════╝\n\n");

        // Summary table
        report.push_str("┌─────────────┬──────────────┬──────────────┬──────────────┬────────┐\n");
        report.push_str("│ Model       │ Avg Latency  │ Throughput   │ Memory (MB)  │ Score  │\n");
        report.push_str("├─────────────┼──────────────┼──────────────┼──────────────┼────────┤\n");

        for analysis in &self.analyses {
            let memory_mb = analysis.memory.total_estimated_bytes as f64 / (1024.0 * 1024.0);
            report.push_str(&format!(
                "│ {:11} │ {:9.2} μs │ {:9.1} /s │ {:11.2}  │ {:5.1}  │\n",
                analysis.model_name,
                analysis.results.avg_latency_us,
                analysis.results.throughput_steps_per_sec,
                memory_mb,
                analysis.performance_score
            ));
        }

        report
            .push_str("└─────────────┴──────────────┴──────────────┴──────────────┴────────┘\n\n");

        // Winners
        report.push_str(&format!(
            "🏆 Fastest Model:            {}\n",
            self.fastest_model
        ));
        report.push_str(&format!(
            "💾 Most Memory Efficient:    {}\n",
            self.most_memory_efficient
        ));
        report.push_str(&format!(
            "⭐ Best Overall:             {}\n\n",
            self.best_overall
        ));

        // Detailed analyses
        report.push_str("═══════════════════════════════════════════════════════════════\n");
        report.push_str("           DETAILED BOTTLENECK ANALYSES\n");
        report.push_str("═══════════════════════════════════════════════════════════════\n");

        for analysis in &self.analyses {
            report.push_str(&analysis.format_report());
        }

        report
    }
}

/// Comprehensive profiler for all models
pub struct ComprehensiveProfiler {
    num_steps: usize,
}

impl ComprehensiveProfiler {
    /// Create a new comprehensive profiler
    pub fn new() -> Self {
        Self { num_steps: 1000 }
    }

    /// Set number of steps for profiling
    pub fn num_steps(mut self, steps: usize) -> Self {
        self.num_steps = steps;
        self
    }

    /// Profile all available models and generate comprehensive comparison
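    ///
    /// # Example
    ///
    /// A sketch of a full comparison run (profiling every model can take a
    /// while; reduce `num_steps` for a quick pass):
    ///
    /// ```rust,ignore
    /// use kizzasi_model::profiling::ComprehensiveProfiler;
    ///
    /// let comparison = ComprehensiveProfiler::new().num_steps(200).profile_all_models()?;
    /// println!("{}", comparison.format_report());
    /// println!("best overall: {}", comparison.best_overall);
    /// ```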
    pub fn profile_all_models(&self) -> ModelResult<ComprehensiveComparison> {
        use crate::{mamba::*, mamba2::*, rwkv::*, s4::*, s5::*, transformer::*};

        let mut analyses = Vec::new();

        // Define common configuration
        let hidden_dim = 256;
        let num_layers = 4;
        let state_dim = 64;

        // Profile Mamba
        let mamba_config = MambaConfig::default()
            .hidden_dim(hidden_dim)
            .state_dim(state_dim)
            .num_layers(num_layers);

        let mamba = Mamba::new(mamba_config)?;
        let mamba_analysis =
            ModelBottleneckAnalysis::analyze(mamba, "Mamba".to_string(), self.num_steps)?;
        analyses.push(mamba_analysis);

        // Profile Mamba2
        let mamba2_config = Mamba2Config::default()
            .hidden_dim(hidden_dim)
            .state_dim(state_dim)
            .num_layers(num_layers)
            .num_heads(4);

        let mamba2 = Mamba2::new(mamba2_config)?;
        let mamba2_analysis =
            ModelBottleneckAnalysis::analyze(mamba2, "Mamba2".to_string(), self.num_steps)?;
        analyses.push(mamba2_analysis);

        // Profile RWKV
        let rwkv_config = RwkvConfig::default()
            .hidden_dim(hidden_dim)
            .num_layers(num_layers)
            .num_heads(4);

        let rwkv = Rwkv::new(rwkv_config)?;
        let rwkv_analysis =
            ModelBottleneckAnalysis::analyze(rwkv, "RWKV".to_string(), self.num_steps)?;
        analyses.push(rwkv_analysis);

        // Profile S4D
        let s4_config = S4Config::default()
            .hidden_dim(hidden_dim)
            .state_dim(state_dim)
            .num_layers(num_layers);

        let s4 = S4D::new(s4_config)?;
        let s4_analysis = ModelBottleneckAnalysis::analyze(s4, "S4D".to_string(), self.num_steps)?;
        analyses.push(s4_analysis);

        // Profile S5 (S5Config has no fluent setters, so construct it directly)
        let s5_config = S5Config::new(1, hidden_dim, num_layers);

        let s5 = S5::new(s5_config)?;
        let s5_analysis = ModelBottleneckAnalysis::analyze(s5, "S5".to_string(), self.num_steps)?;
        analyses.push(s5_analysis);

        // Profile Transformer
        let transformer_config = TransformerConfig::default()
            .hidden_dim(hidden_dim)
            .num_heads(4)
            .num_layers(num_layers);

        let transformer = Transformer::new(transformer_config)?;
        let transformer_analysis = ModelBottleneckAnalysis::analyze(
            transformer,
            "Transformer".to_string(),
            self.num_steps,
        )?;
        analyses.push(transformer_analysis);

        // Determine winners
        let fastest_model = analyses
            .iter()
            .min_by(|a, b| {
                a.results
                    .avg_latency_us
                    .partial_cmp(&b.results.avg_latency_us)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|a| a.model_name.clone())
            .unwrap_or_default();

        let most_memory_efficient = analyses
            .iter()
            .min_by(|a, b| {
                a.memory
                    .total_estimated_bytes
                    .cmp(&b.memory.total_estimated_bytes)
            })
            .map(|a| a.model_name.clone())
            .unwrap_or_default();

        let best_overall = analyses
            .iter()
            .max_by(|a, b| {
                a.performance_score
                    .partial_cmp(&b.performance_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|a| a.model_name.clone())
            .unwrap_or_default();

        Ok(ComprehensiveComparison {
            analyses,
            fastest_model,
            most_memory_efficient,
            best_overall,
        })
    }
}

impl Default for ComprehensiveProfiler {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mamba::{Mamba, MambaConfig};

    #[test]
    fn test_profiling_results() {
        let timings = vec![
            Duration::from_micros(100),
            Duration::from_micros(150),
            Duration::from_micros(120),
            Duration::from_micros(200),
            Duration::from_micros(110),
        ];

        let results = ProfilingResults::from_timings(&timings);

        assert_eq!(results.num_steps, 5);
        assert!(results.avg_latency_us > 0.0);
        assert!(results.min_latency_us <= results.avg_latency_us);
        assert!(results.avg_latency_us <= results.max_latency_us);
        assert!(results.throughput_steps_per_sec > 0.0);
    }

    #[test]
    #[ignore] // Slow test: ~39s due to comprehensive profiling (100 steps)
    fn test_model_profiler() {
        let config = MambaConfig::default().hidden_dim(64).num_layers(2);
        let model = Mamba::new(config).expect("Failed to create Mamba model");

        let mut profiler = ModelProfiler::new(model).warmup_steps(5);

        let results = profiler
            .profile_inference(100, 1)
            .expect("Failed to profile inference");

        assert_eq!(results.num_steps, 100);
        assert!(results.avg_latency_us > 0.0);
        assert!(results.throughput_steps_per_sec > 0.0);
    }

    #[test]
    #[ignore] // Slow test: ~104s due to multiple model creation and profiling runs
    fn test_input_scaling() {
        // Test with different hidden dimensions (not input dims, since input_dim is fixed by model config)
        let hidden_dims = vec![32, 64, 128];
        let mut results = Vec::new();

        for hidden_dim in hidden_dims {
            let config = MambaConfig::default().hidden_dim(hidden_dim).num_layers(2);
            let model = Mamba::new(config).expect("Failed to create Mamba model");

            let mut profiler = ModelProfiler::new(model).warmup_steps(5);
            let profile = profiler
                .profile_inference(50, 1)
                .expect("Failed to profile inference");

            results.push((hidden_dim, profile));
        }

        assert_eq!(results.len(), 3);

        // Verify all profiles are valid
        for (dim, result) in &results {
            assert_eq!(result.num_steps, 50);
            assert!(*dim > 0);
            assert!(result.avg_latency_us > 0.0);
        }
    }

    #[test]
    fn test_memory_profile() {
        let config = MambaConfig::default().hidden_dim(256).num_layers(4);
        let model = Mamba::new(config).expect("Failed to create Mamba model");

        let profiler = ModelProfiler::new(model);
        let memory = profiler.estimate_memory_usage();

        assert_eq!(memory.hidden_dim, 256);
        assert_eq!(memory.num_layers, 4);
        assert!(memory.total_estimated_bytes > 0);
    }

    #[test]
    #[ignore] // Slow test: ~32s due to benchmark suite execution (100 steps)
    fn test_benchmark_suite() {
        let config = MambaConfig::default().hidden_dim(64).num_layers(2);
        let model = Mamba::new(config).expect("Failed to create Mamba model");

        let suite = BenchmarkSuite::new().num_steps(100).warmup_steps(5);

        let results = suite.benchmark(model).expect("Failed to run benchmark");

        assert_eq!(results.num_steps, 100);
        assert!(results.avg_latency_us > 0.0);
    }

    #[test]
    fn test_format_report() {
        let timings = vec![Duration::from_micros(100); 10];
        let results = ProfilingResults::from_timings(&timings);

        let report = results.format_report();
        assert!(report.contains("Profiling Results"));
        assert!(report.contains("Average Latency"));
        assert!(report.contains("Throughput"));
    }
}