use crate::Dataset;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tenflowers_core::{Result, TensorError};
/// Collects timing and event data while a data pipeline runs, and turns it
/// into a [`ProfileReport`] with per-stage statistics, bottleneck detection,
/// and optimization recommendations.
pub struct PipelineProfiler {
    /// Human-readable pipeline name, echoed in generated reports.
    name: String,
    /// Set by `start()`; `None` until profiling begins.
    start_time: Option<Instant>,
    /// Recorded events, oldest first; capped at `config.max_events`
    /// (oldest evicted first by `record_event`).
    events: Vec<ProfileEvent>,
    /// Per-stage duration samples accumulated by `end_stage`.
    stage_timings: HashMap<String, Vec<Duration>>,
    /// Sampling and tracking configuration.
    config: ProfilerConfig,
}
/// Tuning knobs for a [`PipelineProfiler`].
#[derive(Debug, Clone)]
pub struct ProfilerConfig {
    /// Whether to track memory events.
    /// NOTE(review): not read by any code visible in this file — confirm intended use.
    pub track_memory: bool,
    /// Whether to track cache events.
    /// NOTE(review): not read by any code visible in this file — confirm intended use.
    pub track_cache: bool,
    /// Whether to track I/O events.
    /// NOTE(review): not read by any code visible in this file — confirm intended use.
    pub track_io: bool,
    /// Maximum number of events retained; when full, the oldest event is evicted.
    pub max_events: usize,
    /// Probability (0.0–1.0) that any given event is recorded; 1.0 records everything.
    pub sample_rate: f64,
}
impl Default for ProfilerConfig {
fn default() -> Self {
Self {
track_memory: true,
track_cache: true,
track_io: true,
max_events: 10000,
sample_rate: 1.0,
}
}
}
/// A single recorded profiling event.
#[derive(Debug, Clone)]
pub struct ProfileEvent {
    /// When the event was recorded.
    pub timestamp: Instant,
    /// What kind of event this is.
    pub event_type: EventType,
    /// Name of the pipeline stage the event belongs to ("root" for session-level events).
    pub stage: String,
    /// Duration associated with the event, when one applies (e.g. a stage end).
    pub duration: Option<Duration>,
    /// Free-form key/value annotations.
    /// NOTE(review): always left empty by `record_event` in the visible code.
    pub metadata: HashMap<String, String>,
}
/// Classification of profiling events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EventType {
    /// A stage began (recorded by `start_stage`).
    StageStart,
    /// A stage finished (recorded by `end_stage`).
    StageEnd,
    /// A data-loading operation.
    DataLoad,
    /// A transform operation.
    Transform,
    /// A cache lookup that hit; feeds the cache-hit-rate statistics.
    CacheHit,
    /// A cache lookup that missed.
    CacheMiss,
    /// A memory allocation event.
    MemoryAlloc,
    /// An I/O operation.
    IoOperation,
    /// A user-defined event with a free-form label.
    Custom(String),
}
impl PipelineProfiler {
    /// Creates a profiler with the given name and configuration.
    pub fn new(name: impl Into<String>, config: ProfilerConfig) -> Self {
        Self {
            name: name.into(),
            start_time: None,
            events: Vec::new(),
            stage_timings: HashMap::new(),
            config,
        }
    }

    /// Creates a profiler using [`ProfilerConfig::default`].
    pub fn default_config(name: impl Into<String>) -> Self {
        Self::new(name, ProfilerConfig::default())
    }

    /// Marks the beginning of a profiling session and records a start event.
    pub fn start(&mut self) {
        self.start_time = Some(Instant::now());
        self.record_event(
            EventType::Custom("profiling_started".to_string()),
            "root",
            None,
        );
    }

    /// Records a stop event carrying the total elapsed session time.
    /// Does nothing if `start` was never called.
    pub fn stop(&mut self) {
        if let Some(start) = self.start_time {
            let duration = start.elapsed();
            self.record_event(
                EventType::Custom("profiling_stopped".to_string()),
                "root",
                Some(duration),
            );
        }
    }

    /// Records a single event, subject to sampling and the event cap.
    ///
    /// When `sample_rate < 1.0`, each event (including the internal
    /// start/stop/stage events) is kept with that probability. When the
    /// buffer is full, the oldest event is evicted first.
    pub fn record_event(
        &mut self,
        event_type: EventType,
        stage: impl Into<String>,
        duration: Option<Duration>,
    ) {
        if self.config.sample_rate < 1.0 {
            use scirs2_core::random::rand_prelude::*;
            let mut rng = scirs2_core::random::rng();
            let sample: f64 = rng.random();
            if sample > self.config.sample_rate {
                return;
            }
        }
        if self.events.len() >= self.config.max_events {
            // NOTE: O(n) eviction; fine for the default 10k cap, but a
            // VecDeque-backed buffer would make this O(1) if the cap grows.
            self.events.remove(0);
        }
        self.events.push(ProfileEvent {
            timestamp: Instant::now(),
            event_type,
            stage: stage.into(),
            duration,
            metadata: HashMap::new(),
        });
    }

    /// Records a stage-start event and returns a timer for the stage.
    pub fn start_stage(&mut self, stage: impl Into<String>) -> StageTimer {
        let stage_name = stage.into();
        self.record_event(EventType::StageStart, &stage_name, None);
        // The timer measures from its own creation; the Instant argument is
        // not used as the timing origin.
        StageTimer::new(stage_name, Instant::now())
    }

    /// Records a stage-end event and folds the measured duration into the
    /// per-stage timing history.
    pub fn end_stage(&mut self, timer: StageTimer) {
        let duration = timer.elapsed();
        self.record_event(EventType::StageEnd, &timer.stage, Some(duration));
        self.stage_timings
            .entry(timer.stage.clone())
            .or_default()
            .push(duration);
    }

    /// Counts `(cache_hits, cache_misses)` over the recorded events.
    ///
    /// Shared by report generation, bottleneck detection, and
    /// recommendations, which previously each duplicated this scan.
    fn cache_counts(&self) -> (usize, usize) {
        let mut hits = 0;
        let mut misses = 0;
        for event in &self.events {
            match event.event_type {
                EventType::CacheHit => hits += 1,
                EventType::CacheMiss => misses += 1,
                _ => {}
            }
        }
        (hits, misses)
    }

    /// Returns the cache hit rate, or `None` when no cache events were seen.
    fn cache_hit_rate(&self) -> Option<f64> {
        let (hits, misses) = self.cache_counts();
        let total = hits + misses;
        if total == 0 {
            None
        } else {
            Some(hits as f64 / total as f64)
        }
    }

    /// Returns each stage's average duration (stages with at least one sample).
    fn average_stage_durations(&self) -> Vec<(&str, Duration)> {
        self.stage_timings
            .iter()
            .filter(|(_, durations)| !durations.is_empty())
            .map(|(stage, durations)| {
                (
                    stage.as_str(),
                    durations.iter().sum::<Duration>() / durations.len() as u32,
                )
            })
            .collect()
    }

    /// Builds a summary report from everything recorded so far.
    pub fn generate_report(&self) -> ProfileReport {
        let total_duration = self
            .start_time
            .map(|start| start.elapsed())
            .unwrap_or_default();
        let stage_stats = self
            .stage_timings
            .iter()
            .map(|(stage, durations)| (stage.clone(), StageStatistics::from_durations(durations)))
            .collect();
        let mut event_counts = HashMap::new();
        for event in &self.events {
            // Keyed by the Debug rendering of the event type.
            *event_counts
                .entry(format!("{:?}", event.event_type))
                .or_insert(0) += 1;
        }
        ProfileReport {
            pipeline_name: self.name.clone(),
            total_duration,
            total_events: self.events.len(),
            stage_stats,
            event_counts,
            cache_hit_rate: self.cache_hit_rate().unwrap_or(0.0),
            bottlenecks: self.identify_bottlenecks(),
            recommendations: self.generate_recommendations(),
        }
    }

    /// Flags slow stages (>100 ms average; High severity above 1 s) and a
    /// cache hit rate below 50%.
    fn identify_bottlenecks(&self) -> Vec<Bottleneck> {
        let mut bottlenecks = Vec::new();
        for (stage, avg_duration) in self.average_stage_durations() {
            if avg_duration.as_millis() > 100 {
                bottlenecks.push(Bottleneck {
                    category: BottleneckCategory::SlowStage,
                    description: format!("Stage '{}' is slow (avg: {:?})", stage, avg_duration),
                    severity: if avg_duration.as_millis() > 1000 {
                        Severity::High
                    } else {
                        Severity::Medium
                    },
                    affected_component: stage.to_string(),
                });
            }
        }
        if let Some(hit_rate) = self.cache_hit_rate() {
            if hit_rate < 0.5 {
                bottlenecks.push(Bottleneck {
                    category: BottleneckCategory::LowCacheHitRate,
                    description: format!("Low cache hit rate: {:.1}%", hit_rate * 100.0),
                    severity: Severity::Medium,
                    affected_component: "cache".to_string(),
                });
            }
        }
        bottlenecks
    }

    /// Suggests optimizations: caching improvements below a 70% hit rate and
    /// parallelization for stages averaging over 500 ms.
    fn generate_recommendations(&self) -> Vec<String> {
        let mut recommendations = Vec::new();
        if let Some(hit_rate) = self.cache_hit_rate() {
            if hit_rate < 0.7 {
                recommendations.push(
                    "Consider increasing cache size or using predictive prefetching".to_string(),
                );
            }
        }
        for (stage, avg_duration) in self.average_stage_durations() {
            if avg_duration.as_millis() > 500 {
                recommendations.push(format!(
                    "Optimize '{}' stage - consider parallelization or GPU acceleration",
                    stage
                ));
            }
        }
        if recommendations.is_empty() {
            recommendations.push("Pipeline is well optimized".to_string());
        }
        recommendations
    }

    /// Exports events as flat string maps, suitable for serialization.
    /// `duration_ms` is present only for events that carried a duration.
    pub fn export_events(&self) -> Vec<HashMap<String, String>> {
        self.events
            .iter()
            .map(|event| {
                let mut map = HashMap::new();
                map.insert("stage".to_string(), event.stage.clone());
                map.insert("type".to_string(), format!("{:?}", event.event_type));
                if let Some(duration) = event.duration {
                    map.insert("duration_ms".to_string(), duration.as_millis().to_string());
                }
                map
            })
            .collect()
    }
}
/// Measures the wall-clock duration of a single pipeline stage.
///
/// Created by `PipelineProfiler::start_stage` and consumed by
/// `PipelineProfiler::end_stage`.
pub struct StageTimer {
    /// Name of the stage being timed.
    stage: String,
    /// Moment the timer was created; `elapsed` measures from here.
    start: Instant,
}
impl StageTimer {
    /// Creates a timer that starts counting immediately.
    ///
    /// `_origin` is intentionally ignored: a stage timer always measures from
    /// its own creation, never from an externally supplied instant. (The
    /// parameter was previously named `start`, which triggered an
    /// unused-variable warning and wrongly suggested it seeded the timer.)
    fn new(stage: String, _origin: Instant) -> Self {
        Self {
            stage,
            start: Instant::now(),
        }
    }
    /// Time elapsed since the timer was created.
    fn elapsed(&self) -> Duration {
        self.start.elapsed()
    }
}
/// Summary statistics over the recorded durations of one pipeline stage.
#[derive(Debug, Clone)]
pub struct StageStatistics {
    /// Number of duration samples.
    pub count: usize,
    /// Sum of all samples.
    pub total_duration: Duration,
    /// Arithmetic mean of the samples.
    pub avg_duration: Duration,
    /// Smallest sample.
    pub min_duration: Duration,
    /// Largest sample.
    pub max_duration: Duration,
    /// Population standard deviation of the samples.
    pub std_dev: Duration,
}
impl StageStatistics {
    /// Aggregates a slice of durations; an empty slice yields all-zero stats.
    fn from_durations(durations: &[Duration]) -> Self {
        let count = durations.len();
        if count == 0 {
            let zero = Duration::from_secs(0);
            return Self {
                count: 0,
                total_duration: zero,
                avg_duration: zero,
                min_duration: zero,
                max_duration: zero,
                std_dev: zero,
            };
        }
        let total_duration: Duration = durations.iter().sum();
        let avg_duration = total_duration / count as u32;
        let min_duration = durations
            .iter()
            .copied()
            .min()
            .expect("collection should not be empty for min()");
        let max_duration = durations
            .iter()
            .copied()
            .max()
            .expect("collection should not be empty for max()");
        // Population variance in seconds², computed in f64 for precision.
        let mean_secs = avg_duration.as_secs_f64();
        let variance = durations
            .iter()
            .map(|d| {
                let diff = d.as_secs_f64() - mean_secs;
                diff * diff
            })
            .sum::<f64>()
            / count as f64;
        Self {
            count,
            total_duration,
            avg_duration,
            min_duration,
            max_duration,
            std_dev: Duration::from_secs_f64(variance.sqrt()),
        }
    }
}
/// Summary produced by [`PipelineProfiler::generate_report`].
#[derive(Debug, Clone)]
pub struct ProfileReport {
    /// Name of the profiled pipeline.
    pub pipeline_name: String,
    /// Wall-clock time since profiling started (zero if never started).
    pub total_duration: Duration,
    /// Number of events currently retained (after sampling and eviction).
    pub total_events: usize,
    /// Aggregated timing statistics per stage.
    pub stage_stats: HashMap<String, StageStatistics>,
    /// Event count per event type, keyed by the type's Debug rendering.
    pub event_counts: HashMap<String, usize>,
    /// hits / (hits + misses); 0.0 when no cache events were recorded.
    pub cache_hit_rate: f64,
    /// Detected performance problems.
    pub bottlenecks: Vec<Bottleneck>,
    /// Human-readable optimization suggestions.
    pub recommendations: Vec<String>,
}
impl ProfileReport {
    /// Renders the report as human-readable text: a header with overall
    /// totals, then optional sections for stage statistics, bottlenecks,
    /// and recommendations.
    pub fn format_report(&self) -> String {
        let heavy_rule = "=".repeat(60);
        let light_rule = "-".repeat(60);
        let mut out = String::new();

        // Header and overall totals.
        out.push_str(&format!(
            "Pipeline Profiling Report: {}\n",
            self.pipeline_name
        ));
        out.push_str(&heavy_rule);
        out.push('\n');
        out.push_str(&format!("Total Duration: {:?}\n", self.total_duration));
        out.push_str(&format!("Total Events: {}\n", self.total_events));
        out.push_str(&format!(
            "Cache Hit Rate: {:.1}%\n\n",
            self.cache_hit_rate * 100.0
        ));

        // Per-stage timings, largest total time first.
        if !self.stage_stats.is_empty() {
            out.push_str("Stage Statistics:\n");
            out.push_str(&light_rule);
            out.push('\n');
            let mut ranked: Vec<_> = self.stage_stats.iter().collect();
            ranked.sort_by(|lhs, rhs| rhs.1.total_duration.cmp(&lhs.1.total_duration));
            for (stage, stats) in ranked {
                out.push_str(&format!(
                    " {}: {} calls, avg {:?}, total {:?}\n",
                    stage, stats.count, stats.avg_duration, stats.total_duration
                ));
            }
            out.push('\n');
        }

        // Detected bottlenecks with their severity.
        if !self.bottlenecks.is_empty() {
            out.push_str("Identified Bottlenecks:\n");
            out.push_str(&light_rule);
            out.push('\n');
            for bottleneck in &self.bottlenecks {
                out.push_str(&format!(
                    " [{:?}] {}\n",
                    bottleneck.severity, bottleneck.description
                ));
            }
            out.push('\n');
        }

        // Numbered recommendation list.
        if !self.recommendations.is_empty() {
            out.push_str("Recommendations:\n");
            out.push_str(&light_rule);
            out.push('\n');
            for (number, recommendation) in self.recommendations.iter().enumerate() {
                out.push_str(&format!(" {}. {}\n", number + 1, recommendation));
            }
        }
        out
    }
}
/// A performance problem detected from profiling data.
#[derive(Debug, Clone)]
pub struct Bottleneck {
    /// Broad classification of the problem.
    pub category: BottleneckCategory,
    /// Human-readable description of the problem.
    pub description: String,
    /// How serious the problem is.
    pub severity: Severity,
    /// Name of the affected component (a stage name, or "cache").
    pub affected_component: String,
}
/// Broad classification of a detected performance bottleneck.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BottleneckCategory {
    /// A stage whose average duration exceeds the slow-stage threshold.
    SlowStage,
    /// Excessive memory consumption.
    /// NOTE(review): not emitted by the detection code visible in this file.
    HighMemoryUsage,
    /// Cache hit rate below the acceptable threshold.
    LowCacheHitRate,
    /// Slow I/O operations.
    /// NOTE(review): not emitted by the detection code visible in this file.
    SlowIo,
    /// An inefficient transform.
    /// NOTE(review): not emitted by the detection code visible in this file.
    InefficientTransform,
}
/// How serious a bottleneck is. The derived `Ord` follows declaration
/// order, so `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
    Low,
    Medium,
    High,
    Critical,
}
/// Stateless helpers for inspecting and validating datasets.
pub struct DatasetDebugger;
impl DatasetDebugger {
    /// Loads up to `num_samples` samples (capped at the dataset length) and
    /// summarizes each one's tensor shapes and sizes. Samples that fail to
    /// load are skipped silently.
    pub fn inspect_samples<T>(
        dataset: &dyn Dataset<T>,
        num_samples: usize,
    ) -> Result<Vec<SampleInfo>>
    where
        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
    {
        let limit = num_samples.min(dataset.len());
        let samples = (0..limit)
            .filter_map(|idx| {
                let (features, labels) = dataset.get(idx).ok()?;
                Some(SampleInfo {
                    index: idx,
                    feature_shape: features.shape().dims().to_vec(),
                    label_shape: labels.shape().dims().to_vec(),
                    feature_size: features.size(),
                    label_size: labels.size(),
                })
            })
            .collect();
        Ok(samples)
    }

    /// Checks that every sample (up to the first 100) matches the feature and
    /// label shapes of sample 0, recording a textual issue per mismatch or
    /// load failure. Fails only if sample 0 itself cannot be loaded.
    pub fn verify_consistency<T>(dataset: &dyn Dataset<T>) -> Result<ConsistencyReport>
    where
        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
    {
        let total = dataset.len();
        let checked = total.min(100);
        if checked == 0 {
            return Ok(ConsistencyReport {
                total_samples: 0,
                checked_samples: 0,
                issues: Vec::new(),
                is_consistent: true,
            });
        }
        // Sample 0 defines the expected shapes for the whole dataset.
        let (reference_features, reference_labels) = dataset.get(0)?;
        let expected_feature_shape = reference_features.shape().dims().to_vec();
        let expected_label_shape = reference_labels.shape().dims().to_vec();
        let mut issues = Vec::new();
        for idx in 1..checked {
            match dataset.get(idx) {
                Ok((features, labels)) => {
                    if features.shape().dims() != expected_feature_shape.as_slice() {
                        issues.push(format!(
                            "Sample {}: Inconsistent feature shape {:?}, expected {:?}",
                            idx,
                            features.shape().dims(),
                            expected_feature_shape
                        ));
                    }
                    if labels.shape().dims() != expected_label_shape.as_slice() {
                        issues.push(format!(
                            "Sample {}: Inconsistent label shape {:?}, expected {:?}",
                            idx,
                            labels.shape().dims(),
                            expected_label_shape
                        ));
                    }
                }
                Err(_) => issues.push(format!("Sample {}: Failed to load", idx)),
            }
        }
        Ok(ConsistencyReport {
            total_samples: total,
            checked_samples: checked,
            is_consistent: issues.is_empty(),
            issues,
        })
    }
}
/// Shape and size summary for a single dataset sample.
#[derive(Debug, Clone)]
pub struct SampleInfo {
    /// Index of the sample within the dataset.
    pub index: usize,
    /// Dimensions of the feature tensor.
    pub feature_shape: Vec<usize>,
    /// Dimensions of the label tensor.
    pub label_shape: Vec<usize>,
    /// Total number of elements in the feature tensor.
    pub feature_size: usize,
    /// Total number of elements in the label tensor.
    pub label_size: usize,
}
/// Result of [`DatasetDebugger::verify_consistency`].
#[derive(Debug, Clone)]
pub struct ConsistencyReport {
    /// Number of samples in the dataset.
    pub total_samples: usize,
    /// Number of samples actually checked (capped).
    pub checked_samples: usize,
    /// One message per detected inconsistency or load failure.
    pub issues: Vec<String>,
    /// True when no issues were found.
    pub is_consistent: bool,
}
impl ConsistencyReport {
    /// Renders the report as human-readable text with a numbered issue list.
    pub fn format_report(&self) -> String {
        let mut out = String::new();
        out += "Dataset Consistency Report\n";
        out += &"=".repeat(60);
        out.push('\n');
        out += &format!("Total Samples: {}\n", self.total_samples);
        out += &format!("Checked Samples: {}\n", self.checked_samples);
        out += &format!("Is Consistent: {}\n\n", self.is_consistent);
        if self.issues.is_empty() {
            out += "No issues found.\n";
        } else {
            out += &format!("Issues Found ({}):\n", self.issues.len());
            for (number, issue) in self.issues.iter().enumerate() {
                out += &format!(" {}. {}\n", number + 1, issue);
            }
        }
        out
    }
}
/// Record of a single transform step applied to one sample.
#[derive(Debug, Clone)]
pub struct InspectionEvent {
    /// Name of the pipeline step that produced this event.
    pub step_name: String,
    /// Shape of the feature tensor entering the step.
    pub input_shape: Vec<usize>,
    /// Shape of the feature tensor leaving the step; `None` if the step failed.
    pub output_shape: Option<Vec<usize>>,
    /// Wall-clock time the step took, in microseconds.
    pub latency_micros: u64,
    /// Error message when the step failed; `None` on success.
    pub error: Option<String>,
}
/// Aggregated results of running samples through an `InspectablePipeline`.
#[derive(Debug, Clone)]
pub struct PipelineInspectionReport {
    /// Every per-step event collected, in execution order.
    pub events: Vec<InspectionEvent>,
    /// Sum of all step latencies, in microseconds.
    pub total_latency_micros: u64,
    /// Number of events that carried an error.
    pub error_count: usize,
    /// Number of samples successfully loaded and inspected.
    pub sample_count: usize,
}
impl PipelineInspectionReport {
    /// Creates an empty report with all counters at zero.
    pub fn new() -> Self {
        Self {
            events: Vec::new(),
            total_latency_micros: 0,
            error_count: 0,
            sample_count: 0,
        }
    }
    /// Folds one event into the running totals, then stores it.
    fn push_event(&mut self, event: InspectionEvent) {
        self.total_latency_micros += event.latency_micros;
        self.error_count += usize::from(event.error.is_some());
        self.events.push(event);
    }
    /// Mean latency per recorded step, in microseconds (0 when empty).
    pub fn avg_latency_per_step_micros(&self) -> u64 {
        match self.events.len() as u64 {
            0 => 0,
            steps => self.total_latency_micros / steps,
        }
    }
    /// Fraction of events that failed (0.0 when empty).
    pub fn error_rate(&self) -> f64 {
        let total = self.events.len();
        if total == 0 {
            0.0
        } else {
            self.error_count as f64 / total as f64
        }
    }
}
impl Default for PipelineInspectionReport {
    fn default() -> Self {
        Self::new()
    }
}
/// A sequence of named transform steps whose per-step behavior (shapes,
/// latency, errors) can be observed sample by sample.
pub struct InspectablePipeline {
    /// Ordered (name, transform) pairs, applied in insertion order.
    steps: Vec<(String, Box<dyn crate::transforms::Transform<f32>>)>,
}
impl InspectablePipeline {
    /// Creates an empty pipeline with no steps.
    pub fn new() -> Self {
        Self { steps: Vec::new() }
    }

    /// Appends a named transform step; steps run in insertion order.
    pub fn add_step(
        &mut self,
        name: impl Into<String>,
        transform: Box<dyn crate::transforms::Transform<f32>>,
    ) {
        self.steps.push((name.into(), transform));
    }

    /// Runs `sample` through every step, recording one event per executed
    /// step. Stops at the first failing step; that step's event carries the
    /// error message and no output shape.
    pub fn inspect_sample(
        &self,
        sample: (tenflowers_core::Tensor<f32>, tenflowers_core::Tensor<f32>),
    ) -> Vec<InspectionEvent> {
        let mut events = Vec::with_capacity(self.steps.len());
        let mut current = sample;
        for (name, transform) in &self.steps {
            let input_shape = current.0.shape().to_vec();
            let start = std::time::Instant::now();
            // Pass ownership instead of cloning both tensors every step:
            // `apply` consumes the sample, and `current` is either reassigned
            // from the output (Ok) or never used again (Err → break).
            match transform.apply(current) {
                Ok(out) => {
                    let latency_micros = start.elapsed().as_micros() as u64;
                    events.push(InspectionEvent {
                        step_name: name.clone(),
                        input_shape,
                        output_shape: Some(out.0.shape().to_vec()),
                        latency_micros,
                        error: None,
                    });
                    current = out;
                }
                Err(e) => {
                    let latency_micros = start.elapsed().as_micros() as u64;
                    events.push(InspectionEvent {
                        step_name: name.clone(),
                        input_shape,
                        output_shape: None,
                        latency_micros,
                        error: Some(e.to_string()),
                    });
                    break;
                }
            }
        }
        events
    }

    /// Inspects up to `n_samples` samples from `dataset`, aggregating all
    /// per-step events into one report. Samples that fail to load are
    /// skipped silently and do not count toward `sample_count`.
    pub fn run_inspection_batch<D>(&self, dataset: &D, n_samples: usize) -> PipelineInspectionReport
    where
        D: crate::Dataset<f32>,
    {
        let mut report = PipelineInspectionReport::new();
        let count = n_samples.min(dataset.len());
        for idx in 0..count {
            if let Ok(sample) = dataset.get(idx) {
                for event in self.inspect_sample(sample) {
                    report.push_event(event);
                }
                report.sample_count += 1;
            }
        }
        report
    }
}
impl Default for InspectablePipeline {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the profiler, dataset debugger, and inspectable pipeline.
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    // Construction stores the given pipeline name.
    #[test]
    fn test_profiler_creation() {
        let profiler = PipelineProfiler::default_config("test_pipeline");
        assert_eq!(profiler.name, "test_pipeline");
    }

    // start/stop plus manually recorded events all appear in the report.
    #[test]
    fn test_profiler_events() {
        let mut profiler = PipelineProfiler::default_config("test");
        profiler.start();
        profiler.record_event(EventType::DataLoad, "load_stage", None);
        profiler.record_event(
            EventType::Transform,
            "transform_stage",
            Some(Duration::from_millis(10)),
        );
        profiler.stop();
        let report = profiler.generate_report();
        assert!(report.total_events > 0);
    }

    // A timed stage shows up in the report's per-stage statistics.
    #[test]
    fn test_stage_timing() {
        let mut profiler = PipelineProfiler::default_config("test");
        profiler.start();
        let timer = profiler.start_stage("test_stage");
        std::thread::sleep(Duration::from_millis(10));
        profiler.end_stage(timer);
        let report = profiler.generate_report();
        assert!(report.stage_stats.contains_key("test_stage"));
    }

    // inspect_samples caps at the dataset length (5 requested, 2 available)
    // and reports per-sample shapes.
    #[test]
    fn test_dataset_debugger_inspect() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let samples =
            DatasetDebugger::inspect_samples(&dataset, 5).expect("test: operation should succeed");
        assert_eq!(samples.len(), 2);
        assert_eq!(samples[0].feature_shape, vec![2]);
    }

    // A uniformly shaped dataset passes the consistency check.
    #[test]
    fn test_consistency_check() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let report =
            DatasetDebugger::verify_consistency(&dataset).expect("test: operation should succeed");
        assert!(report.is_consistent);
        assert_eq!(report.total_samples, 2);
    }

    // The formatted report contains the header and the timed stage name.
    #[test]
    fn test_profile_report_generation() {
        let mut profiler = PipelineProfiler::default_config("test");
        profiler.start();
        let timer = profiler.start_stage("data_loading");
        std::thread::sleep(Duration::from_millis(5));
        profiler.end_stage(timer);
        profiler.stop();
        let report = profiler.generate_report();
        let report_string = report.format_report();
        assert!(report_string.contains("Pipeline Profiling Report"));
        assert!(report_string.contains("data_loading"));
    }

    // Pass-through transform used to exercise the inspectable pipeline.
    struct IdentityTransform;
    impl crate::transforms::Transform<f32> for IdentityTransform {
        fn apply(
            &self,
            sample: (Tensor<f32>, Tensor<f32>),
        ) -> tenflowers_core::Result<(Tensor<f32>, Tensor<f32>)> {
            Ok(sample)
        }
    }

    // A successful step yields one event with an output shape and no error.
    #[test]
    fn test_inspectable_pipeline_records_events() {
        let mut pipeline = InspectablePipeline::new();
        pipeline.add_step("identity", Box::new(IdentityTransform));
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3])
            .expect("test: tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0], &[1]).expect("test: tensor creation should succeed");
        let events = pipeline.inspect_sample((features, labels));
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].step_name, "identity");
        assert!(events[0].error.is_none());
        assert!(events[0].output_shape.is_some());
    }

    // Identity steps preserve the input shape through every step.
    #[test]
    fn test_inspectable_pipeline_shape_tracking() {
        let mut pipeline = InspectablePipeline::new();
        pipeline.add_step("step1", Box::new(IdentityTransform));
        pipeline.add_step("step2", Box::new(IdentityTransform));
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let events = pipeline.inspect_sample((features, labels));
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].input_shape, vec![2, 2]);
        assert_eq!(events[1].input_shape, vec![2, 2]);
    }

    // Batch inspection is capped at the dataset length (100 requested,
    // 2 available) and aggregates one event per sample per step.
    #[test]
    fn test_run_inspection_batch_aggregation() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let mut pipeline = InspectablePipeline::new();
        pipeline.add_step("identity", Box::new(IdentityTransform));
        let report = pipeline.run_inspection_batch(&dataset, 100);
        assert_eq!(report.sample_count, 2);
        assert_eq!(report.events.len(), 2);
        assert_eq!(report.error_count, 0);
        assert_eq!(report.error_rate(), 0.0);
    }

    // An empty report returns zero averages without dividing by zero.
    #[test]
    fn test_pipeline_inspection_report_empty() {
        let report = PipelineInspectionReport::new();
        assert_eq!(report.avg_latency_per_step_micros(), 0);
        assert_eq!(report.error_rate(), 0.0);
    }
}