// tenflowers_dataset/throughput_benchmark.rs
1//! Throughput benchmark performance harness for datasets
2//!
3//! This module provides comprehensive benchmarking capabilities for measuring
4//! dataset loading, transformation, and iteration performance. It supports
5//! multi-threaded testing, various batch sizes, and detailed performance metrics.
6
7use crate::Dataset;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11
12#[cfg(feature = "parallel")]
13use rayon::prelude::*;
14
/// Throughput benchmark configuration
///
/// Build via `ThroughputBenchmarkConfig::default()` and override individual
/// fields with struct-update syntax (`..Default::default()`), as the tests
/// in this module do.
#[derive(Debug, Clone)]
pub struct ThroughputBenchmarkConfig {
    /// Number of warmup iterations before measurement (untimed passes).
    pub warmup_iterations: usize,
    /// Number of measurement iterations (timed passes over the samples).
    pub measurement_iterations: usize,
    /// Batch size for batched operations.
    /// NOTE(review): not read anywhere in this module — `benchmark_batched`
    /// takes the batch size as an explicit argument. Confirm intended use.
    pub batch_size: Option<usize>,
    /// Number of worker threads for parallel testing.
    /// NOTE(review): not read anywhere in this module —
    /// `benchmark_multithreaded` takes the thread count as an explicit
    /// argument. Confirm intended use.
    pub num_threads: Option<usize>,
    /// Whether to measure memory usage.
    pub measure_memory: bool,
    /// Whether to include detailed per-sample timings; when false only an
    /// aggregate latency figure is retained.
    pub detailed_timings: bool,
    /// Maximum samples to benchmark per pass (None = all samples).
    pub max_samples: Option<usize>,
}
33
34impl Default for ThroughputBenchmarkConfig {
35    fn default() -> Self {
36        Self {
37            warmup_iterations: 10,
38            measurement_iterations: 100,
39            batch_size: None,
40            num_threads: None,
41            measure_memory: false,
42            detailed_timings: false,
43            max_samples: None,
44        }
45    }
46}
47
/// Results from a throughput benchmark
///
/// All latency figures are in microseconds. Results produced by
/// `benchmark_multithreaded` approximate per-sample latencies from the
/// aggregate run time, so its percentile/min/max fields all equal the mean
/// and `latency_std_dev_us` is 0.
#[derive(Debug, Clone)]
pub struct ThroughputBenchmarkResult {
    /// Dataset name or identifier supplied by the caller
    pub dataset_name: String,
    /// Total samples processed
    pub samples_processed: usize,
    /// Total wall-clock time of the measurement phase
    pub total_duration: Duration,
    /// Samples per second (samples_processed / total_duration)
    pub samples_per_second: f64,
    /// Average latency per sample (microseconds)
    pub avg_latency_us: f64,
    /// P50 latency (microseconds)
    pub p50_latency_us: f64,
    /// P95 latency (microseconds)
    pub p95_latency_us: f64,
    /// P99 latency (microseconds)
    pub p99_latency_us: f64,
    /// Minimum latency (microseconds)
    pub min_latency_us: f64,
    /// Maximum latency (microseconds)
    pub max_latency_us: f64,
    /// Standard deviation of latency (microseconds)
    pub latency_std_dev_us: f64,
    /// Memory usage statistics (Some only when `measure_memory` was enabled
    /// and at least one sample was recorded)
    pub memory_stats: Option<MemoryStats>,
    /// Per-thread statistics (empty for single-threaded runs)
    pub per_thread_stats: Vec<ThreadStats>,
    /// When the result was assembled (monotonic `Instant`, not wall-clock)
    pub timestamp: Instant,
}
80
/// Memory usage statistics
///
/// NOTE(review): `get_current_memory_usage` is currently a stub returning 0,
/// so these values are placeholders until a real platform-specific probe is
/// wired in.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Peak memory usage in bytes
    pub peak_bytes: usize,
    /// Average memory usage in bytes
    pub avg_bytes: usize,
    /// Memory snapshots recorded per second of benchmark time
    /// (computed as sample count / total duration, despite the name)
    pub allocations_per_second: f64,
}
91
/// Per-thread statistics
///
/// One entry per rayon worker range in a `benchmark_multithreaded` run.
#[derive(Debug, Clone)]
pub struct ThreadStats {
    /// Thread identifier (index of the sample range, not an OS thread id)
    pub thread_id: usize,
    /// Samples processed by this thread across all measurement iterations
    pub samples_processed: usize,
    /// Time spent by this thread
    pub duration: Duration,
    /// Samples per second for this thread
    pub samples_per_second: f64,
}
104
/// Throughput benchmark harness
///
/// Holds the configuration plus the metric collections filled in during a
/// run. The collections sit behind `Arc<Mutex<_>>` so the parallel
/// (`rayon`) benchmark path can record from worker threads.
pub struct ThroughputBenchmarkHarness {
    /// Benchmark configuration
    config: ThroughputBenchmarkConfig,
    /// Collected per-sample latencies in microseconds
    sample_latencies: Arc<Mutex<Vec<u64>>>,
    /// Memory usage samples in bytes (populated only when
    /// `config.measure_memory` is true)
    memory_samples: Arc<Mutex<Vec<usize>>>,
    /// Per-thread statistics from multi-threaded runs
    thread_stats: Arc<Mutex<Vec<ThreadStats>>>,
}
116
117impl ThroughputBenchmarkHarness {
118    /// Create a new benchmark harness with configuration
119    pub fn new(config: ThroughputBenchmarkConfig) -> Self {
120        Self {
121            config,
122            sample_latencies: Arc::new(Mutex::new(Vec::new())),
123            memory_samples: Arc::new(Mutex::new(Vec::new())),
124            thread_stats: Arc::new(Mutex::new(Vec::new())),
125        }
126    }
127
128    /// Create a default benchmark harness
129    pub fn default() -> Self {
130        Self::new(ThroughputBenchmarkConfig::default())
131    }
132
133    /// Benchmark a dataset's iteration performance
134    pub fn benchmark<T, D>(
135        &mut self,
136        dataset: &D,
137        name: impl Into<String>,
138    ) -> ThroughputBenchmarkResult
139    where
140        T: Clone + Send + Sync + 'static,
141        D: Dataset<T> + Sync,
142    {
143        let dataset_name = name.into();
144        let total_samples = if let Some(max) = self.config.max_samples {
145            max.min(dataset.len())
146        } else {
147            dataset.len()
148        };
149
150        // Warmup phase
151        self.warmup_phase(dataset, total_samples);
152
153        // Measurement phase
154        let start_time = Instant::now();
155        self.measurement_phase(dataset, total_samples);
156        let total_duration = start_time.elapsed();
157
158        // Calculate statistics
159        let latencies = self
160            .sample_latencies
161            .lock()
162            .expect("lock should not be poisoned")
163            .clone();
164        let stats = calculate_latency_statistics(&latencies);
165
166        // Calculate memory statistics if measured
167        let memory_stats = if self.config.measure_memory {
168            let memory_samples = self
169                .memory_samples
170                .lock()
171                .expect("lock should not be poisoned");
172            if !memory_samples.is_empty() {
173                let peak_bytes = *memory_samples.iter().max().unwrap_or(&0);
174                let avg_bytes = memory_samples.iter().sum::<usize>() / memory_samples.len();
175                let allocations_per_second =
176                    memory_samples.len() as f64 / total_duration.as_secs_f64();
177                Some(MemoryStats {
178                    peak_bytes,
179                    avg_bytes,
180                    allocations_per_second,
181                })
182            } else {
183                None
184            }
185        } else {
186            None
187        };
188
189        // Get per-thread statistics if multi-threaded
190        let per_thread_stats = self
191            .thread_stats
192            .lock()
193            .expect("lock should not be poisoned")
194            .clone();
195
196        ThroughputBenchmarkResult {
197            dataset_name,
198            samples_processed: total_samples,
199            total_duration,
200            samples_per_second: total_samples as f64 / total_duration.as_secs_f64(),
201            avg_latency_us: stats.mean,
202            p50_latency_us: stats.p50,
203            p95_latency_us: stats.p95,
204            p99_latency_us: stats.p99,
205            min_latency_us: stats.min,
206            max_latency_us: stats.max,
207            latency_std_dev_us: stats.std_dev,
208            memory_stats,
209            per_thread_stats,
210            timestamp: Instant::now(),
211        }
212    }
213
214    /// Benchmark with batched access
215    pub fn benchmark_batched<T, D>(
216        &mut self,
217        dataset: &D,
218        batch_size: usize,
219        name: impl Into<String>,
220    ) -> ThroughputBenchmarkResult
221    where
222        T: Clone + Send + Sync + 'static,
223        D: Dataset<T> + Sync,
224    {
225        let dataset_name = name.into();
226        let total_samples = if let Some(max) = self.config.max_samples {
227            max.min(dataset.len())
228        } else {
229            dataset.len()
230        };
231
232        // Warmup phase with batches
233        self.warmup_phase_batched(dataset, batch_size, total_samples);
234
235        // Measurement phase with batches
236        let start_time = Instant::now();
237        self.measurement_phase_batched(dataset, batch_size, total_samples);
238        let total_duration = start_time.elapsed();
239
240        // Calculate statistics
241        let latencies = self
242            .sample_latencies
243            .lock()
244            .expect("lock should not be poisoned")
245            .clone();
246        let stats = calculate_latency_statistics(&latencies);
247
248        // Calculate memory statistics if measured
249        let memory_stats = if self.config.measure_memory {
250            let memory_samples = self
251                .memory_samples
252                .lock()
253                .expect("lock should not be poisoned");
254            if !memory_samples.is_empty() {
255                let peak_bytes = *memory_samples.iter().max().unwrap_or(&0);
256                let avg_bytes = memory_samples.iter().sum::<usize>() / memory_samples.len();
257                let allocations_per_second =
258                    memory_samples.len() as f64 / total_duration.as_secs_f64();
259                Some(MemoryStats {
260                    peak_bytes,
261                    avg_bytes,
262                    allocations_per_second,
263                })
264            } else {
265                None
266            }
267        } else {
268            None
269        };
270
271        // Get per-thread statistics if multi-threaded
272        let per_thread_stats = self
273            .thread_stats
274            .lock()
275            .expect("lock should not be poisoned")
276            .clone();
277
278        ThroughputBenchmarkResult {
279            dataset_name,
280            samples_processed: total_samples,
281            total_duration,
282            samples_per_second: total_samples as f64 / total_duration.as_secs_f64(),
283            avg_latency_us: stats.mean,
284            p50_latency_us: stats.p50,
285            p95_latency_us: stats.p95,
286            p99_latency_us: stats.p99,
287            min_latency_us: stats.min,
288            max_latency_us: stats.max,
289            latency_std_dev_us: stats.std_dev,
290            memory_stats,
291            per_thread_stats,
292            timestamp: Instant::now(),
293        }
294    }
295
296    /// Compare multiple datasets
297    pub fn compare_datasets<T, D>(
298        &mut self,
299        datasets: Vec<(&D, String)>,
300    ) -> HashMap<String, ThroughputBenchmarkResult>
301    where
302        T: Clone + Send + Sync + 'static,
303        D: Dataset<T> + Sync,
304    {
305        let mut results = HashMap::new();
306
307        for (dataset, name) in datasets {
308            let result = self.benchmark(dataset, name.clone());
309            results.insert(name, result);
310        }
311
312        results
313    }
314
315    /// Benchmark with multi-threading (requires parallel feature)
316    #[cfg(feature = "parallel")]
317    pub fn benchmark_multithreaded<T, D>(
318        &mut self,
319        dataset: &D,
320        num_threads: usize,
321        name: impl Into<String>,
322    ) -> ThroughputBenchmarkResult
323    where
324        T: Clone + Send + Sync + 'static,
325        D: Dataset<T> + Sync,
326    {
327        let dataset_name = name.into();
328        let total_samples = if let Some(max) = self.config.max_samples {
329            max.min(dataset.len())
330        } else {
331            dataset.len()
332        };
333
334        // Warmup phase
335        self.warmup_phase(dataset, total_samples);
336
337        // Clear thread stats
338        self.thread_stats
339            .lock()
340            .expect("lock should not be poisoned")
341            .clear();
342
343        // Measurement phase with parallel execution
344        let start_time = Instant::now();
345
346        // Divide samples among threads
347        let samples_per_thread = (total_samples + num_threads - 1) / num_threads;
348        let thread_ranges: Vec<_> = (0..num_threads)
349            .map(|i| {
350                let start = i * samples_per_thread;
351                let end = ((i + 1) * samples_per_thread).min(total_samples);
352                (i, start, end)
353            })
354            .collect();
355
356        // Execute benchmark in parallel
357        let thread_stats_mutex = Arc::clone(&self.thread_stats);
358        thread_ranges
359            .par_iter()
360            .for_each(|(thread_id, start, end)| {
361                let thread_start = Instant::now();
362                let mut samples_processed = 0;
363
364                for _ in 0..self.config.measurement_iterations {
365                    for i in *start..*end {
366                        let _ = dataset.get(i);
367                        samples_processed += 1;
368                    }
369                }
370
371                let thread_duration = thread_start.elapsed();
372                let samples_per_second = samples_processed as f64 / thread_duration.as_secs_f64();
373
374                // Record thread statistics
375                let mut stats = thread_stats_mutex
376                    .lock()
377                    .expect("lock should not be poisoned");
378                stats.push(ThreadStats {
379                    thread_id: *thread_id,
380                    samples_processed,
381                    duration: thread_duration,
382                    samples_per_second,
383                });
384            });
385
386        let total_duration = start_time.elapsed();
387
388        // Calculate statistics (using thread stats for latency approximation)
389        let thread_stats = self
390            .thread_stats
391            .lock()
392            .expect("lock should not be poisoned")
393            .clone();
394        let total_processed: usize = thread_stats.iter().map(|s| s.samples_processed).sum();
395        let avg_latency_us = (total_duration.as_micros() as f64) / (total_processed as f64);
396
397        // Calculate memory statistics if measured
398        let memory_stats = if self.config.measure_memory {
399            let memory_samples = self
400                .memory_samples
401                .lock()
402                .expect("lock should not be poisoned");
403            if !memory_samples.is_empty() {
404                let peak_bytes = *memory_samples.iter().max().unwrap_or(&0);
405                let avg_bytes = memory_samples.iter().sum::<usize>() / memory_samples.len();
406                let allocations_per_second =
407                    memory_samples.len() as f64 / total_duration.as_secs_f64();
408                Some(MemoryStats {
409                    peak_bytes,
410                    avg_bytes,
411                    allocations_per_second,
412                })
413            } else {
414                None
415            }
416        } else {
417            None
418        };
419
420        ThroughputBenchmarkResult {
421            dataset_name,
422            samples_processed: total_processed,
423            total_duration,
424            samples_per_second: total_processed as f64 / total_duration.as_secs_f64(),
425            avg_latency_us,
426            p50_latency_us: avg_latency_us,
427            p95_latency_us: avg_latency_us,
428            p99_latency_us: avg_latency_us,
429            min_latency_us: avg_latency_us,
430            max_latency_us: avg_latency_us,
431            latency_std_dev_us: 0.0,
432            memory_stats,
433            per_thread_stats: thread_stats,
434            timestamp: Instant::now(),
435        }
436    }
437
438    /// Reset collected metrics
439    pub fn reset(&mut self) {
440        self.sample_latencies
441            .lock()
442            .expect("lock should not be poisoned")
443            .clear();
444        self.memory_samples
445            .lock()
446            .expect("lock should not be poisoned")
447            .clear();
448        self.thread_stats
449            .lock()
450            .expect("lock should not be poisoned")
451            .clear();
452    }
453
454    /// Get current memory usage (platform-specific approximation)
455    fn get_current_memory_usage(&self) -> usize {
456        // This is a basic approximation. On Linux/Unix systems, you could read from /proc
457        // For now, we'll return 0 as a placeholder. Real implementation would use
458        // platform-specific APIs or crates like `jemalloc_ctl` or `memory-stats`
459        0
460    }
461
462    /// Track memory usage during benchmark
463    fn track_memory(&self) {
464        if self.config.measure_memory {
465            let mem = self.get_current_memory_usage();
466            self.memory_samples
467                .lock()
468                .expect("lock should not be poisoned")
469                .push(mem);
470        }
471    }
472
473    // Private helper methods
474
475    fn warmup_phase<T, D>(&self, dataset: &D, total_samples: usize)
476    where
477        T: Clone + Send + Sync + 'static,
478        D: Dataset<T>,
479    {
480        for _ in 0..self.config.warmup_iterations {
481            for i in 0..total_samples {
482                let _ = dataset.get(i);
483            }
484        }
485    }
486
487    fn measurement_phase<T, D>(&self, dataset: &D, total_samples: usize)
488    where
489        T: Clone + Send + Sync + 'static,
490        D: Dataset<T>,
491    {
492        let mut latencies = self
493            .sample_latencies
494            .lock()
495            .expect("lock should not be poisoned");
496        latencies.clear();
497
498        for _ in 0..self.config.measurement_iterations {
499            self.track_memory(); // Track memory at start of each iteration
500
501            for i in 0..total_samples {
502                let start = Instant::now();
503                let _ = dataset.get(i);
504                let latency = start.elapsed().as_micros() as u64;
505
506                if self.config.detailed_timings {
507                    latencies.push(latency);
508                }
509
510                // Track memory periodically (every 100 samples)
511                if i % 100 == 0 {
512                    self.track_memory();
513                }
514            }
515        }
516
517        // If not detailed, record average latency
518        if !self.config.detailed_timings && !latencies.is_empty() {
519            let avg = latencies.iter().sum::<u64>() / latencies.len() as u64;
520            latencies.clear();
521            latencies.push(avg);
522        }
523    }
524
525    fn warmup_phase_batched<T, D>(&self, dataset: &D, batch_size: usize, total_samples: usize)
526    where
527        T: Clone + Send + Sync + 'static,
528        D: Dataset<T>,
529    {
530        for _ in 0..self.config.warmup_iterations {
531            for batch_start in (0..total_samples).step_by(batch_size) {
532                let batch_end = (batch_start + batch_size).min(total_samples);
533                for i in batch_start..batch_end {
534                    let _ = dataset.get(i);
535                }
536            }
537        }
538    }
539
540    fn measurement_phase_batched<T, D>(&self, dataset: &D, batch_size: usize, total_samples: usize)
541    where
542        T: Clone + Send + Sync + 'static,
543        D: Dataset<T>,
544    {
545        let mut latencies = self
546            .sample_latencies
547            .lock()
548            .expect("lock should not be poisoned");
549        latencies.clear();
550
551        for _ in 0..self.config.measurement_iterations {
552            for batch_start in (0..total_samples).step_by(batch_size) {
553                let batch_end = (batch_start + batch_size).min(total_samples);
554                let start = Instant::now();
555
556                for i in batch_start..batch_end {
557                    let _ = dataset.get(i);
558                }
559
560                let batch_latency = start.elapsed().as_micros() as u64;
561                let per_sample_latency = batch_latency / (batch_end - batch_start) as u64;
562
563                if self.config.detailed_timings {
564                    latencies.push(per_sample_latency);
565                }
566            }
567        }
568    }
569}
570
/// Aggregate latency statistics, all in microseconds.
struct LatencyStatistics {
    // Arithmetic mean.
    mean: f64,
    // Smallest observed latency.
    min: f64,
    // Largest observed latency.
    max: f64,
    // Median (nearest-rank 50th percentile).
    p50: f64,
    // Nearest-rank 95th percentile.
    p95: f64,
    // Nearest-rank 99th percentile.
    p99: f64,
    // Population standard deviation (divides by n, not n - 1).
    std_dev: f64,
}

/// Calculate latency statistics from collected samples.
///
/// Returns all-zero statistics for an empty slice. Percentiles use the
/// nearest-rank method: the p-th percentile of n sorted values is the
/// element at index `ceil(p * n) - 1`.
fn calculate_latency_statistics(latencies: &[u64]) -> LatencyStatistics {
    if latencies.is_empty() {
        return LatencyStatistics {
            mean: 0.0,
            min: 0.0,
            max: 0.0,
            p50: 0.0,
            p95: 0.0,
            p99: 0.0,
            std_dev: 0.0,
        };
    }

    let mut sorted = latencies.to_vec();
    sorted.sort_unstable();

    let n = sorted.len();
    let mean = sorted.iter().sum::<u64>() as f64 / n as f64;

    // Population variance (divide by n, not n - 1).
    let variance = sorted
        .iter()
        .map(|&x| {
            let diff = x as f64 - mean;
            diff * diff
        })
        .sum::<f64>()
        / n as f64;

    // Nearest-rank percentile. Bug fix: the previous floor-based index
    // (`(n as f64 * p) as usize`) systematically selected one rank too
    // high — e.g. the p50 of 4 samples picked the 3rd value, not the 2nd.
    let percentile = |p: f64| -> f64 {
        let rank = ((p * n as f64).ceil() as usize).clamp(1, n);
        sorted[rank - 1] as f64
    };

    LatencyStatistics {
        mean,
        min: sorted[0] as f64,
        max: sorted[n - 1] as f64,
        p50: percentile(0.50),
        p95: percentile(0.95),
        p99: percentile(0.99),
        std_dev: variance.sqrt(),
    }
}
627
628impl ThroughputBenchmarkResult {
629    /// Generate a human-readable report
630    pub fn generate_report(&self) -> String {
631        let mut report = String::new();
632
633        report.push_str(&format!(
634            "=== Throughput Benchmark Report: {} ===\n\n",
635            self.dataset_name
636        ));
637
638        report.push_str("## Overall Statistics\n");
639        report.push_str(&format!(
640            "  Samples Processed: {}\n",
641            self.samples_processed
642        ));
643        report.push_str(&format!("  Total Duration: {:.2?}\n", self.total_duration));
644        report.push_str(&format!(
645            "  Throughput: {:.2} samples/sec\n\n",
646            self.samples_per_second
647        ));
648
649        report.push_str("## Latency Statistics (microseconds)\n");
650        report.push_str(&format!("  Average: {:.2}\n", self.avg_latency_us));
651        report.push_str(&format!("  Minimum: {:.2}\n", self.min_latency_us));
652        report.push_str(&format!("  Maximum: {:.2}\n", self.max_latency_us));
653        report.push_str(&format!("  Std Dev: {:.2}\n", self.latency_std_dev_us));
654        report.push_str(&format!("  P50: {:.2}\n", self.p50_latency_us));
655        report.push_str(&format!("  P95: {:.2}\n", self.p95_latency_us));
656        report.push_str(&format!("  P99: {:.2}\n\n", self.p99_latency_us));
657
658        if let Some(ref mem_stats) = self.memory_stats {
659            report.push_str("## Memory Statistics\n");
660            report.push_str(&format!("  Peak: {} bytes\n", mem_stats.peak_bytes));
661            report.push_str(&format!("  Average: {} bytes\n", mem_stats.avg_bytes));
662            report.push_str(&format!(
663                "  Allocations/sec: {:.2}\n\n",
664                mem_stats.allocations_per_second
665            ));
666        }
667
668        if !self.per_thread_stats.is_empty() {
669            report.push_str("## Per-Thread Statistics\n");
670            for thread_stat in &self.per_thread_stats {
671                report.push_str(&format!(
672                    "  Thread {}: {} samples, {:.2?}, {:.2} samples/sec\n",
673                    thread_stat.thread_id,
674                    thread_stat.samples_processed,
675                    thread_stat.duration,
676                    thread_stat.samples_per_second
677                ));
678            }
679        }
680
681        report
682    }
683
684    /// Export results as CSV
685    pub fn to_csv(&self) -> String {
686        format!(
687            "{},{},{},{:.2},{:.2},{:.2},{:.2},{:.2},{:.2},{:.2},{:.2}\n",
688            self.dataset_name,
689            self.samples_processed,
690            self.total_duration.as_millis(),
691            self.samples_per_second,
692            self.avg_latency_us,
693            self.min_latency_us,
694            self.max_latency_us,
695            self.p50_latency_us,
696            self.p95_latency_us,
697            self.p99_latency_us,
698            self.latency_std_dev_us
699        )
700    }
701}
702
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    /// The default harness carries the documented config defaults.
    #[test]
    fn test_benchmark_harness_creation() {
        let harness = ThroughputBenchmarkHarness::default();
        assert_eq!(harness.config.warmup_iterations, 10);
        assert_eq!(harness.config.measurement_iterations, 100);
    }

    /// End-to-end sequential benchmark over a tiny 4-sample dataset;
    /// checks the reported counts and that the rates are sane.
    #[test]
    fn test_basic_benchmark() {
        let features =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2])
                .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0, 3.0], &[4])
            .expect("test: tensor creation should succeed");

        let dataset = TensorDataset::new(features, labels);
        let mut harness = ThroughputBenchmarkHarness::new(ThroughputBenchmarkConfig {
            warmup_iterations: 1,
            measurement_iterations: 5,
            max_samples: Some(4),
            ..Default::default()
        });

        let result = harness.benchmark(&dataset, "test_dataset");

        // samples_processed reflects unique samples per pass, not
        // samples * measurement_iterations.
        assert_eq!(result.samples_processed, 4);
        assert!(result.samples_per_second > 0.0);
        assert!(result.avg_latency_us >= 0.0);
    }

    /// Batched benchmark with a batch size (2) that does not divide the
    /// sample count (5), exercising the partial final batch.
    #[test]
    fn test_batched_benchmark() {
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            &[5, 2],
        )
        .expect("test: operation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5])
            .expect("test: tensor creation should succeed");

        let dataset = TensorDataset::new(features, labels);
        let mut harness = ThroughputBenchmarkHarness::new(ThroughputBenchmarkConfig {
            warmup_iterations: 1,
            measurement_iterations: 3,
            max_samples: Some(5),
            ..Default::default()
        });

        let result = harness.benchmark_batched(&dataset, 2, "batched_test");

        assert_eq!(result.samples_processed, 5);
        assert!(result.samples_per_second > 0.0);
    }

    /// The text report contains every expected section header.
    #[test]
    fn test_generate_report() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");

        let dataset = TensorDataset::new(features, labels);
        let mut harness = ThroughputBenchmarkHarness::new(ThroughputBenchmarkConfig {
            warmup_iterations: 1,
            measurement_iterations: 2,
            ..Default::default()
        });

        let result = harness.benchmark(&dataset, "report_test");
        let report = result.generate_report();

        assert!(report.contains("Throughput Benchmark Report"));
        assert!(report.contains("Samples Processed"));
        assert!(report.contains("Throughput:"));
        assert!(report.contains("Latency Statistics"));
    }

    /// CSV export includes the dataset name and is comma-separated.
    #[test]
    fn test_csv_export() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[1, 2])
            .expect("test: tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![0.0], &[1]).expect("test: tensor creation should succeed");

        let dataset = TensorDataset::new(features, labels);
        let mut harness = ThroughputBenchmarkHarness::new(ThroughputBenchmarkConfig {
            warmup_iterations: 1,
            measurement_iterations: 1,
            ..Default::default()
        });

        let result = harness.benchmark(&dataset, "csv_test");
        let csv = result.to_csv();

        assert!(csv.contains("csv_test"));
        assert!(csv.contains(','));
    }

    /// `reset()` clears the latency samples collected by a prior run.
    /// Uses detailed_timings so per-sample latencies are actually recorded.
    #[test]
    fn test_reset() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[1, 2])
            .expect("test: tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![0.0], &[1]).expect("test: tensor creation should succeed");

        let dataset = TensorDataset::new(features, labels);
        let mut harness = ThroughputBenchmarkHarness::new(ThroughputBenchmarkConfig {
            warmup_iterations: 1,
            measurement_iterations: 2,
            detailed_timings: true,
            ..Default::default()
        });

        let _ = harness.benchmark(&dataset, "test1");
        assert!(!harness
            .sample_latencies
            .lock()
            .expect("lock should not be poisoned")
            .is_empty());

        harness.reset();
        assert!(harness
            .sample_latencies
            .lock()
            .expect("lock should not be poisoned")
            .is_empty());
    }

    /// `compare_datasets` returns one result per (dataset, name) pair,
    /// keyed by the supplied name.
    #[test]
    fn test_compare_datasets() {
        let features1 = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[1, 2])
            .expect("test: tensor creation should succeed");
        let labels1 =
            Tensor::<f32>::from_vec(vec![0.0], &[1]).expect("test: tensor creation should succeed");
        let dataset1 = TensorDataset::new(features1, labels1);

        let features2 = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[1, 2])
            .expect("test: tensor creation should succeed");
        let labels2 =
            Tensor::<f32>::from_vec(vec![1.0], &[1]).expect("test: tensor creation should succeed");
        let dataset2 = TensorDataset::new(features2, labels2);

        let mut harness = ThroughputBenchmarkHarness::new(ThroughputBenchmarkConfig {
            warmup_iterations: 1,
            measurement_iterations: 1,
            ..Default::default()
        });

        let results = harness.compare_datasets(vec![
            (&dataset1, "dataset1".to_string()),
            (&dataset2, "dataset2".to_string()),
        ]);

        assert_eq!(results.len(), 2);
        assert!(results.contains_key("dataset1"));
        assert!(results.contains_key("dataset2"));
    }
}