1use crate::Dataset;
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10use tenflowers_core::{Result, TensorError};
11
/// Collects timing and event data for a data pipeline and produces summary
/// reports, bottleneck analyses, and tuning recommendations.
pub struct PipelineProfiler {
    /// Human-readable name of the pipeline being profiled.
    name: String,
    /// When `start()` was called; `None` until a session begins.
    start_time: Option<Instant>,
    /// Recorded events, oldest first (capped at `config.max_events`).
    events: Vec<ProfileEvent>,
    /// Durations collected per stage via `start_stage`/`end_stage`.
    stage_timings: HashMap<String, Vec<Duration>>,
    /// Sampling and tracking configuration.
    config: ProfilerConfig,
}
25
/// Configuration options controlling what the profiler records.
#[derive(Debug, Clone)]
pub struct ProfilerConfig {
    /// Whether memory-related events should be tracked.
    /// NOTE(review): the `track_*` flags are not consulted anywhere in this
    /// module — confirm they are read by callers.
    pub track_memory: bool,
    /// Whether cache hit/miss events should be tracked.
    pub track_cache: bool,
    /// Whether I/O events should be tracked.
    pub track_io: bool,
    /// Maximum number of events kept in memory; the oldest are evicted first.
    pub max_events: usize,
    /// Fraction of events to record, in `[0.0, 1.0]`; `1.0` records everything.
    pub sample_rate: f64,
}

impl Default for ProfilerConfig {
    /// Track everything, keep up to 10 000 events, and record every event.
    fn default() -> Self {
        ProfilerConfig {
            track_memory: true,
            track_cache: true,
            track_io: true,
            max_events: 10_000,
            sample_rate: 1.0,
        }
    }
}
52
/// A single recorded profiling event.
#[derive(Debug, Clone)]
pub struct ProfileEvent {
    /// When the event was recorded.
    pub timestamp: Instant,
    /// What kind of event this is.
    pub event_type: EventType,
    /// Pipeline stage the event belongs to (`"root"` for session markers).
    pub stage: String,
    /// Duration associated with the event, when applicable.
    pub duration: Option<Duration>,
    /// Free-form key/value annotations.
    /// NOTE(review): `record_event` always leaves this empty — confirm
    /// whether anything populates it elsewhere.
    pub metadata: HashMap<String, String>,
}
67
/// Categories of profiling events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EventType {
    /// A pipeline stage began executing.
    StageStart,
    /// A pipeline stage finished executing.
    StageEnd,
    /// A data-loading operation.
    DataLoad,
    /// A data-transformation operation.
    Transform,
    /// A cache lookup that succeeded.
    CacheHit,
    /// A cache lookup that missed.
    CacheMiss,
    /// A memory allocation.
    MemoryAlloc,
    /// An I/O operation.
    IoOperation,
    /// A caller-defined event carrying its own label.
    Custom(String),
}
90
91impl PipelineProfiler {
92 pub fn new(name: impl Into<String>, config: ProfilerConfig) -> Self {
94 Self {
95 name: name.into(),
96 start_time: None,
97 events: Vec::new(),
98 stage_timings: HashMap::new(),
99 config,
100 }
101 }
102
103 pub fn default_config(name: impl Into<String>) -> Self {
105 Self::new(name, ProfilerConfig::default())
106 }
107
108 pub fn start(&mut self) {
110 self.start_time = Some(Instant::now());
111 self.record_event(
112 EventType::Custom("profiling_started".to_string()),
113 "root",
114 None,
115 );
116 }
117
118 pub fn stop(&mut self) {
120 if let Some(start) = self.start_time {
121 let duration = start.elapsed();
122 self.record_event(
123 EventType::Custom("profiling_stopped".to_string()),
124 "root",
125 Some(duration),
126 );
127 }
128 }
129
130 pub fn record_event(
132 &mut self,
133 event_type: EventType,
134 stage: impl Into<String>,
135 duration: Option<Duration>,
136 ) {
137 if self.config.sample_rate < 1.0 {
139 use scirs2_core::random::rand_prelude::*;
140 let mut rng = scirs2_core::random::rng();
141 let sample: f64 = rng.random();
142 if sample > self.config.sample_rate {
143 return;
144 }
145 }
146
147 if self.events.len() >= self.config.max_events {
148 self.events.remove(0);
150 }
151
152 let event = ProfileEvent {
153 timestamp: Instant::now(),
154 event_type,
155 stage: stage.into(),
156 duration,
157 metadata: HashMap::new(),
158 };
159
160 self.events.push(event);
161 }
162
163 pub fn start_stage(&mut self, stage: impl Into<String>) -> StageTimer {
165 let stage_name = stage.into();
166 self.record_event(EventType::StageStart, &stage_name, None);
167 StageTimer::new(stage_name, self.start_time.unwrap_or_else(Instant::now))
168 }
169
170 pub fn end_stage(&mut self, timer: StageTimer) {
172 let duration = timer.elapsed();
173 self.record_event(EventType::StageEnd, &timer.stage, Some(duration));
174
175 self.stage_timings
176 .entry(timer.stage.clone())
177 .or_insert_with(Vec::new)
178 .push(duration);
179 }
180
181 pub fn generate_report(&self) -> ProfileReport {
183 let total_duration = self
184 .start_time
185 .map(|start| start.elapsed())
186 .unwrap_or(Duration::from_secs(0));
187
188 let mut stage_stats = HashMap::new();
190 for (stage, durations) in &self.stage_timings {
191 let stats = StageStatistics::from_durations(durations);
192 stage_stats.insert(stage.clone(), stats);
193 }
194
195 let mut event_counts = HashMap::new();
197 for event in &self.events {
198 let event_name = format!("{:?}", event.event_type);
199 *event_counts.entry(event_name).or_insert(0) += 1;
200 }
201
202 let cache_hits = self
204 .events
205 .iter()
206 .filter(|e| e.event_type == EventType::CacheHit)
207 .count();
208 let cache_misses = self
209 .events
210 .iter()
211 .filter(|e| e.event_type == EventType::CacheMiss)
212 .count();
213 let cache_hit_rate = if cache_hits + cache_misses > 0 {
214 cache_hits as f64 / (cache_hits + cache_misses) as f64
215 } else {
216 0.0
217 };
218
219 ProfileReport {
220 pipeline_name: self.name.clone(),
221 total_duration,
222 total_events: self.events.len(),
223 stage_stats,
224 event_counts,
225 cache_hit_rate,
226 bottlenecks: self.identify_bottlenecks(),
227 recommendations: self.generate_recommendations(),
228 }
229 }
230
231 fn identify_bottlenecks(&self) -> Vec<Bottleneck> {
233 let mut bottlenecks = Vec::new();
234
235 for (stage, durations) in &self.stage_timings {
237 if durations.is_empty() {
238 continue;
239 }
240
241 let avg_duration = durations.iter().sum::<Duration>() / durations.len() as u32;
242
243 if avg_duration.as_millis() > 100 {
245 bottlenecks.push(Bottleneck {
246 category: BottleneckCategory::SlowStage,
247 description: format!("Stage '{}' is slow (avg: {:?})", stage, avg_duration),
248 severity: if avg_duration.as_millis() > 1000 {
249 Severity::High
250 } else {
251 Severity::Medium
252 },
253 affected_component: stage.clone(),
254 });
255 }
256 }
257
258 let cache_hits = self
260 .events
261 .iter()
262 .filter(|e| e.event_type == EventType::CacheHit)
263 .count();
264 let cache_misses = self
265 .events
266 .iter()
267 .filter(|e| e.event_type == EventType::CacheMiss)
268 .count();
269
270 if cache_hits + cache_misses > 0 {
271 let hit_rate = cache_hits as f64 / (cache_hits + cache_misses) as f64;
272 if hit_rate < 0.5 {
273 bottlenecks.push(Bottleneck {
274 category: BottleneckCategory::LowCacheHitRate,
275 description: format!("Low cache hit rate: {:.1}%", hit_rate * 100.0),
276 severity: Severity::Medium,
277 affected_component: "cache".to_string(),
278 });
279 }
280 }
281
282 bottlenecks
283 }
284
285 fn generate_recommendations(&self) -> Vec<String> {
287 let mut recommendations = Vec::new();
288
289 let cache_hits = self
291 .events
292 .iter()
293 .filter(|e| e.event_type == EventType::CacheHit)
294 .count();
295 let cache_misses = self
296 .events
297 .iter()
298 .filter(|e| e.event_type == EventType::CacheMiss)
299 .count();
300
301 if cache_hits + cache_misses > 0 {
302 let hit_rate = cache_hits as f64 / (cache_hits + cache_misses) as f64;
303 if hit_rate < 0.7 {
304 recommendations.push(
305 "Consider increasing cache size or using predictive prefetching".to_string(),
306 );
307 }
308 }
309
310 for (stage, durations) in &self.stage_timings {
312 if durations.is_empty() {
313 continue;
314 }
315
316 let avg_duration = durations.iter().sum::<Duration>() / durations.len() as u32;
317 if avg_duration.as_millis() > 500 {
318 recommendations.push(format!(
319 "Optimize '{}' stage - consider parallelization or GPU acceleration",
320 stage
321 ));
322 }
323 }
324
325 if recommendations.is_empty() {
326 recommendations.push("Pipeline is well optimized".to_string());
327 }
328
329 recommendations
330 }
331
332 pub fn export_events(&self) -> Vec<HashMap<String, String>> {
334 self.events
335 .iter()
336 .map(|event| {
337 let mut map = HashMap::new();
338 map.insert("stage".to_string(), event.stage.clone());
339 map.insert("type".to_string(), format!("{:?}", event.event_type));
340 if let Some(duration) = event.duration {
341 map.insert("duration_ms".to_string(), duration.as_millis().to_string());
342 }
343 map
344 })
345 .collect()
346 }
347}
348
/// Measures the wall-clock duration of a single pipeline stage.
pub struct StageTimer {
    /// Name of the stage being timed.
    stage: String,
    /// Instant at which the timer was created.
    start: Instant,
}

impl StageTimer {
    /// Creates a timer that starts measuring immediately.
    ///
    /// The `_start` argument is intentionally ignored: stage timing must
    /// begin when the stage actually starts (i.e. now), not at an earlier
    /// instant supplied by the caller. The parameter is kept only for
    /// call-site compatibility.
    fn new(stage: String, _start: Instant) -> Self {
        Self {
            stage,
            start: Instant::now(),
        }
    }

    /// Time elapsed since the timer was created.
    fn elapsed(&self) -> Duration {
        self.start.elapsed()
    }
}
367
/// Summary statistics for the recorded executions of one pipeline stage.
#[derive(Debug, Clone)]
pub struct StageStatistics {
    /// Number of recorded executions.
    pub count: usize,
    /// Sum of all recorded durations.
    pub total_duration: Duration,
    /// Mean duration.
    pub avg_duration: Duration,
    /// Shortest recorded duration.
    pub min_duration: Duration,
    /// Longest recorded duration.
    pub max_duration: Duration,
    /// Population standard deviation of the durations.
    pub std_dev: Duration,
}

impl StageStatistics {
    /// Computes summary statistics over a slice of durations.
    /// An empty slice yields all-zero statistics.
    fn from_durations(durations: &[Duration]) -> Self {
        let count = durations.len();
        if count == 0 {
            let zero = Duration::from_secs(0);
            return Self {
                count: 0,
                total_duration: zero,
                avg_duration: zero,
                min_duration: zero,
                max_duration: zero,
                std_dev: zero,
            };
        }

        let total: Duration = durations.iter().sum();
        let avg = total / count as u32;

        // Track the extremes in a single pass over the slice.
        let mut min = durations[0];
        let mut max = durations[0];
        for &d in &durations[1..] {
            min = min.min(d);
            max = max.max(d);
        }

        // Population variance in seconds², then back to a Duration.
        let avg_secs = avg.as_secs_f64();
        let variance = durations
            .iter()
            .map(|d| (d.as_secs_f64() - avg_secs).powi(2))
            .sum::<f64>()
            / count as f64;

        Self {
            count,
            total_duration: total,
            avg_duration: avg,
            min_duration: min,
            max_duration: max,
            std_dev: Duration::from_secs_f64(variance.sqrt()),
        }
    }
}
430
/// A complete profiling summary produced by `PipelineProfiler::generate_report`.
#[derive(Debug, Clone)]
pub struct ProfileReport {
    /// Name of the profiled pipeline.
    pub pipeline_name: String,
    /// Wall-clock time since the session started.
    pub total_duration: Duration,
    /// Number of events retained at report time.
    pub total_events: usize,
    /// Per-stage timing statistics, keyed by stage name.
    pub stage_stats: HashMap<String, StageStatistics>,
    /// Event counts keyed by the Debug-formatted event type.
    pub event_counts: HashMap<String, usize>,
    /// Cache hits / (hits + misses); 0.0 when no cache events were recorded.
    pub cache_hit_rate: f64,
    /// Performance problems detected in the recorded data.
    pub bottlenecks: Vec<Bottleneck>,
    /// Human-readable tuning suggestions.
    pub recommendations: Vec<String>,
}
451
452impl ProfileReport {
453 pub fn format_report(&self) -> String {
455 let mut report = String::new();
456
457 report.push_str(&format!(
458 "Pipeline Profiling Report: {}\n",
459 self.pipeline_name
460 ));
461 report.push_str("=".repeat(60).as_str());
462 report.push('\n');
463
464 report.push_str(&format!("Total Duration: {:?}\n", self.total_duration));
465 report.push_str(&format!("Total Events: {}\n", self.total_events));
466 report.push_str(&format!(
467 "Cache Hit Rate: {:.1}%\n\n",
468 self.cache_hit_rate * 100.0
469 ));
470
471 if !self.stage_stats.is_empty() {
473 report.push_str("Stage Statistics:\n");
474 report.push_str("-".repeat(60).as_str());
475 report.push('\n');
476
477 let mut stages: Vec<_> = self.stage_stats.iter().collect();
478 stages.sort_by_key(|a| std::cmp::Reverse(a.1.total_duration));
479
480 for (stage, stats) in stages {
481 report.push_str(&format!(
482 " {}: {} calls, avg {:?}, total {:?}\n",
483 stage, stats.count, stats.avg_duration, stats.total_duration
484 ));
485 }
486 report.push('\n');
487 }
488
489 if !self.bottlenecks.is_empty() {
491 report.push_str("Identified Bottlenecks:\n");
492 report.push_str("-".repeat(60).as_str());
493 report.push('\n');
494
495 for bottleneck in &self.bottlenecks {
496 report.push_str(&format!(
497 " [{:?}] {}\n",
498 bottleneck.severity, bottleneck.description
499 ));
500 }
501 report.push('\n');
502 }
503
504 if !self.recommendations.is_empty() {
506 report.push_str("Recommendations:\n");
507 report.push_str("-".repeat(60).as_str());
508 report.push('\n');
509
510 for (i, rec) in self.recommendations.iter().enumerate() {
511 report.push_str(&format!(" {}. {}\n", i + 1, rec));
512 }
513 }
514
515 report
516 }
517}
518
/// A detected performance problem in the profiled pipeline.
#[derive(Debug, Clone)]
pub struct Bottleneck {
    /// Which kind of problem this is.
    pub category: BottleneckCategory,
    /// Human-readable description of the problem.
    pub description: String,
    /// How severe the problem is judged to be.
    pub severity: Severity,
    /// The stage or subsystem the problem applies to (e.g. a stage name,
    /// or `"cache"`).
    pub affected_component: String,
}
531
/// Categories of pipeline bottlenecks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BottleneckCategory {
    /// A stage whose average duration is too high.
    SlowStage,
    /// Excessive memory consumption.
    HighMemoryUsage,
    /// Cache hit rate below the acceptable threshold.
    LowCacheHitRate,
    /// Slow input/output operations.
    SlowIo,
    /// A transform that performs poorly.
    InefficientTransform,
}
546
/// Severity levels for bottlenecks, ordered from least to most severe
/// (the `Ord` derive relies on this declaration order).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
    /// Minor impact.
    Low,
    /// Noticeable impact.
    Medium,
    /// Significant impact.
    High,
    /// Blocking or near-blocking impact.
    Critical,
}
559
560pub struct DatasetDebugger;
562
563impl DatasetDebugger {
564 pub fn inspect_samples<T>(
566 dataset: &dyn Dataset<T>,
567 num_samples: usize,
568 ) -> Result<Vec<SampleInfo>>
569 where
570 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
571 {
572 let mut samples = Vec::new();
573 let count = num_samples.min(dataset.len());
574
575 for i in 0..count {
576 if let Ok((features, labels)) = dataset.get(i) {
577 samples.push(SampleInfo {
578 index: i,
579 feature_shape: features.shape().dims().to_vec(),
580 label_shape: labels.shape().dims().to_vec(),
581 feature_size: features.size(),
582 label_size: labels.size(),
583 });
584 }
585 }
586
587 Ok(samples)
588 }
589
590 pub fn verify_consistency<T>(dataset: &dyn Dataset<T>) -> Result<ConsistencyReport>
592 where
593 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
594 {
595 let mut issues = Vec::new();
596 let samples_to_check = dataset.len().min(100);
597
598 if samples_to_check == 0 {
599 return Ok(ConsistencyReport {
600 total_samples: 0,
601 checked_samples: 0,
602 issues,
603 is_consistent: true,
604 });
605 }
606
607 let (first_features, first_labels) = dataset.get(0)?;
609 let expected_feature_shape = first_features.shape().dims().to_vec();
610 let expected_label_shape = first_labels.shape().dims().to_vec();
611
612 for i in 1..samples_to_check {
613 if let Ok((features, labels)) = dataset.get(i) {
614 if features.shape().dims() != expected_feature_shape.as_slice() {
615 issues.push(format!(
616 "Sample {}: Inconsistent feature shape {:?}, expected {:?}",
617 i,
618 features.shape().dims(),
619 expected_feature_shape
620 ));
621 }
622 if labels.shape().dims() != expected_label_shape.as_slice() {
623 issues.push(format!(
624 "Sample {}: Inconsistent label shape {:?}, expected {:?}",
625 i,
626 labels.shape().dims(),
627 expected_label_shape
628 ));
629 }
630 } else {
631 issues.push(format!("Sample {}: Failed to load", i));
632 }
633 }
634
635 let is_consistent = issues.is_empty();
636 Ok(ConsistencyReport {
637 total_samples: dataset.len(),
638 checked_samples: samples_to_check,
639 issues,
640 is_consistent,
641 })
642 }
643}
644
/// Shape and size information about one dataset sample.
#[derive(Debug, Clone)]
pub struct SampleInfo {
    /// Index of the sample within the dataset.
    pub index: usize,
    /// Dimensions of the feature tensor.
    pub feature_shape: Vec<usize>,
    /// Dimensions of the label tensor.
    pub label_shape: Vec<usize>,
    /// Total element count of the feature tensor.
    pub feature_size: usize,
    /// Total element count of the label tensor.
    pub label_size: usize,
}
659
/// Result of a dataset consistency check.
#[derive(Debug, Clone)]
pub struct ConsistencyReport {
    /// Total number of samples in the dataset.
    pub total_samples: usize,
    /// Number of samples actually examined.
    pub checked_samples: usize,
    /// Human-readable descriptions of every inconsistency found.
    pub issues: Vec<String>,
    /// `true` when no issues were found.
    pub is_consistent: bool,
}

impl ConsistencyReport {
    /// Renders the report as human-readable text.
    pub fn format_report(&self) -> String {
        let mut out = String::new();

        // Header banner.
        out.push_str("Dataset Consistency Report\n");
        out.push_str(&"=".repeat(60));
        out.push('\n');

        // Summary counters.
        out.push_str(&format!("Total Samples: {}\n", self.total_samples));
        out.push_str(&format!("Checked Samples: {}\n", self.checked_samples));
        out.push_str(&format!("Is Consistent: {}\n\n", self.is_consistent));

        // Numbered issue list, or an all-clear line.
        if self.issues.is_empty() {
            out.push_str("No issues found.\n");
        } else {
            out.push_str(&format!("Issues Found ({}):\n", self.issues.len()));
            for (i, issue) in self.issues.iter().enumerate() {
                out.push_str(&format!("  {}. {}\n", i + 1, issue));
            }
        }

        out
    }
}
698
/// One step's worth of observations from running a sample through an
/// inspectable pipeline.
#[derive(Debug, Clone)]
pub struct InspectionEvent {
    /// Name of the pipeline step that produced this event.
    pub step_name: String,
    /// Shape of the step's input feature tensor.
    pub input_shape: Vec<usize>,
    /// Shape of the output features; `None` when the step failed.
    pub output_shape: Option<Vec<usize>>,
    /// Wall-clock time spent in the step, in microseconds.
    pub latency_micros: u64,
    /// Error message when the step failed, otherwise `None`.
    pub error: Option<String>,
}

/// Aggregated results of running samples through an inspectable pipeline.
#[derive(Debug, Clone)]
pub struct PipelineInspectionReport {
    /// All step-level events collected during inspection.
    pub events: Vec<InspectionEvent>,
    /// Sum of latencies across all events, in microseconds.
    pub total_latency_micros: u64,
    /// Number of events that carried an error.
    pub error_count: usize,
    /// Number of samples run through the pipeline.
    pub sample_count: usize,
}

impl PipelineInspectionReport {
    /// Creates an empty report.
    pub fn new() -> Self {
        Self {
            events: Vec::new(),
            total_latency_micros: 0,
            error_count: 0,
            sample_count: 0,
        }
    }

    /// Appends an event, updating the running latency and error totals.
    fn push_event(&mut self, event: InspectionEvent) {
        self.total_latency_micros += event.latency_micros;
        self.error_count += usize::from(event.error.is_some());
        self.events.push(event);
    }

    /// Mean per-step latency in microseconds (0 when there are no events).
    pub fn avg_latency_per_step_micros(&self) -> u64 {
        match self.events.len() as u64 {
            0 => 0,
            n => self.total_latency_micros / n,
        }
    }

    /// Fraction of events that failed (0.0 when there are no events).
    pub fn error_rate(&self) -> f64 {
        match self.events.len() {
            0 => 0.0,
            n => self.error_count as f64 / n as f64,
        }
    }
}
766
767impl Default for PipelineInspectionReport {
768 fn default() -> Self {
769 Self::new()
770 }
771}
772
/// A transform pipeline whose per-step behavior (shapes, latency, errors)
/// can be inspected sample-by-sample.
pub struct InspectablePipeline {
    // Ordered list of (step name, transform) pairs, applied in sequence.
    steps: Vec<(String, Box<dyn crate::transforms::Transform<f32>>)>,
}
780
781impl InspectablePipeline {
782 pub fn new() -> Self {
784 Self { steps: Vec::new() }
785 }
786
787 pub fn add_step(
789 &mut self,
790 name: impl Into<String>,
791 transform: Box<dyn crate::transforms::Transform<f32>>,
792 ) {
793 self.steps.push((name.into(), transform));
794 }
795
796 pub fn inspect_sample(
799 &self,
800 sample: (tenflowers_core::Tensor<f32>, tenflowers_core::Tensor<f32>),
801 ) -> Vec<InspectionEvent> {
802 let mut events = Vec::with_capacity(self.steps.len());
803 let mut current = sample;
804
805 for (name, transform) in &self.steps {
806 let input_shape = current.0.shape().to_vec();
807 let start = std::time::Instant::now();
808 match transform.apply(current.clone()) {
809 Ok(out) => {
810 let latency_micros = start.elapsed().as_micros() as u64;
811 let output_shape = Some(out.0.shape().to_vec());
812 events.push(InspectionEvent {
813 step_name: name.clone(),
814 input_shape,
815 output_shape,
816 latency_micros,
817 error: None,
818 });
819 current = out;
820 }
821 Err(e) => {
822 let latency_micros = start.elapsed().as_micros() as u64;
823 events.push(InspectionEvent {
824 step_name: name.clone(),
825 input_shape,
826 output_shape: None,
827 latency_micros,
828 error: Some(e.to_string()),
829 });
830 break;
831 }
832 }
833 }
834
835 events
836 }
837
838 pub fn run_inspection_batch<D>(&self, dataset: &D, n_samples: usize) -> PipelineInspectionReport
840 where
841 D: crate::Dataset<f32>,
842 {
843 let mut report = PipelineInspectionReport::new();
844 let count = n_samples.min(dataset.len());
845
846 for idx in 0..count {
847 if let Ok(sample) = dataset.get(idx) {
848 for event in self.inspect_sample(sample) {
849 report.push_event(event);
850 }
851 report.sample_count += 1;
852 }
853 }
854
855 report
856 }
857}
858
859impl Default for InspectablePipeline {
860 fn default() -> Self {
861 Self::new()
862 }
863}
864
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    // Construction retains the pipeline name.
    #[test]
    fn test_profiler_creation() {
        let profiler = PipelineProfiler::default_config("test_pipeline");
        assert_eq!(profiler.name, "test_pipeline");
    }

    // Recorded events (including start/stop markers) show up in the report.
    #[test]
    fn test_profiler_events() {
        let mut profiler = PipelineProfiler::default_config("test");
        profiler.start();

        profiler.record_event(EventType::DataLoad, "load_stage", None);
        profiler.record_event(
            EventType::Transform,
            "transform_stage",
            Some(Duration::from_millis(10)),
        );

        profiler.stop();

        let report = profiler.generate_report();
        assert!(report.total_events > 0);
    }

    // A started/ended stage produces an entry in the report's stage stats.
    #[test]
    fn test_stage_timing() {
        let mut profiler = PipelineProfiler::default_config("test");
        profiler.start();

        let timer = profiler.start_stage("test_stage");
        std::thread::sleep(Duration::from_millis(10));
        profiler.end_stage(timer);

        let report = profiler.generate_report();
        assert!(report.stage_stats.contains_key("test_stage"));
    }

    // inspect_samples is clamped to the dataset length (5 requested, 2 exist)
    // and reports per-sample feature shapes.
    #[test]
    fn test_dataset_debugger_inspect() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let samples =
            DatasetDebugger::inspect_samples(&dataset, 5).expect("test: operation should succeed");
        assert_eq!(samples.len(), 2);
        assert_eq!(samples[0].feature_shape, vec![2]);
    }

    // A dataset with uniform shapes passes the consistency check.
    #[test]
    fn test_consistency_check() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let report =
            DatasetDebugger::verify_consistency(&dataset).expect("test: operation should succeed");
        assert!(report.is_consistent);
        assert_eq!(report.total_samples, 2);
    }

    // The formatted report contains the header and any timed stage names.
    #[test]
    fn test_profile_report_generation() {
        let mut profiler = PipelineProfiler::default_config("test");
        profiler.start();

        let timer = profiler.start_stage("data_loading");
        std::thread::sleep(Duration::from_millis(5));
        profiler.end_stage(timer);

        profiler.stop();

        let report = profiler.generate_report();
        let report_string = report.format_report();

        assert!(report_string.contains("Pipeline Profiling Report"));
        assert!(report_string.contains("data_loading"));
    }

    // Minimal pass-through transform used by the pipeline-inspection tests.
    struct IdentityTransform;

    impl crate::transforms::Transform<f32> for IdentityTransform {
        fn apply(
            &self,
            sample: (Tensor<f32>, Tensor<f32>),
        ) -> tenflowers_core::Result<(Tensor<f32>, Tensor<f32>)> {
            Ok(sample)
        }
    }

    // A successful step records one event with an output shape and no error.
    #[test]
    fn test_inspectable_pipeline_records_events() {
        let mut pipeline = InspectablePipeline::new();
        pipeline.add_step("identity", Box::new(IdentityTransform));

        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3])
            .expect("test: tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0], &[1]).expect("test: tensor creation should succeed");

        let events = pipeline.inspect_sample((features, labels));
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].step_name, "identity");
        assert!(events[0].error.is_none());
        assert!(events[0].output_shape.is_some());
    }

    // Identity steps preserve the input shape across consecutive steps.
    #[test]
    fn test_inspectable_pipeline_shape_tracking() {
        let mut pipeline = InspectablePipeline::new();
        pipeline.add_step("step1", Box::new(IdentityTransform));
        pipeline.add_step("step2", Box::new(IdentityTransform));

        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");

        let events = pipeline.inspect_sample((features, labels));
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].input_shape, vec![2, 2]);
        assert_eq!(events[1].input_shape, vec![2, 2]);
    }

    // Batch inspection is clamped to the dataset size and aggregates events
    // and error counts across samples.
    #[test]
    fn test_run_inspection_batch_aggregation() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let mut pipeline = InspectablePipeline::new();
        pipeline.add_step("identity", Box::new(IdentityTransform));

        let report = pipeline.run_inspection_batch(&dataset, 100);
        assert_eq!(report.sample_count, 2);
        assert_eq!(report.events.len(), 2);
        assert_eq!(report.error_count, 0);
        assert_eq!(report.error_rate(), 0.0);
    }

    // An empty report yields zero averages without dividing by zero.
    #[test]
    fn test_pipeline_inspection_report_empty() {
        let report = PipelineInspectionReport::new();
        assert_eq!(report.avg_latency_per_step_micros(), 0);
        assert_eq!(report.error_rate(), 0.0);
    }
}