use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use torsh_core::{
dtype::TensorElement,
error::{Result, TorshError},
};
use crate::Tensor;
/// Identifier under which a [`TensorTracker`] stores a tracked tensor; returned by `track`.
pub type TrackId = u64;
/// Summary statistics over the elements of a tensor, as computed by
/// `TensorValueStats::from_tensor`.
#[derive(Debug, Clone)]
pub struct TensorValueStats<T: TensorElement> {
/// Smallest element seen, ignoring NaN and infinite values; `None` if there was none.
pub min: Option<T>,
/// Largest element seen, ignoring NaN and infinite values; `None` if there was none.
pub max: Option<T>,
/// Arithmetic mean (as `f64`) of the elements that are neither NaN nor infinite.
pub mean: Option<f64>,
/// Population standard deviation (divides by the element count, not count - 1)
/// of the elements that are neither NaN nor infinite.
pub std: Option<f64>,
/// Number of elements whose `f64` value is NaN.
pub nan_count: usize,
/// Number of elements whose `f64` value is infinite.
pub inf_count: usize,
/// Number of elements that compare equal to zero.
pub zero_count: usize,
/// Total number of elements in the tensor, including NaN and infinite ones.
pub total_elements: usize,
}
impl<T: TensorElement> TensorValueStats<T> {
    /// Computes summary statistics over every element of `tensor`.
    ///
    /// Elements whose `f64` value is NaN or infinite are only counted in
    /// `nan_count` / `inf_count`; they do not take part in `min`, `max`,
    /// `zero_count`, `mean` or `std`. `mean` and `std` (population standard
    /// deviation) are computed over exactly those elements that convert to a
    /// finite `f64`, and are `None` when there is no such element.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `tensor.to_vec()` if the data cannot be read.
    pub fn from_tensor(tensor: &Tensor<T>) -> Result<Self>
    where
        T: Copy + PartialOrd + num_traits::Zero + num_traits::ToPrimitive,
    {
        let data = tensor.to_vec()?;
        let total_elements = data.len();
        let zero = <T as num_traits::Zero>::zero();

        let mut min: Option<T> = None;
        let mut max: Option<T> = None;
        let mut nan_count = 0usize;
        let mut inf_count = 0usize;
        let mut zero_count = 0usize;
        // Number of values that actually contributed to `sum`. Using this as
        // the divisor (instead of `total - nan - inf`) keeps the mean correct
        // even if some element cannot be converted to `f64`.
        let mut finite_count = 0usize;
        let mut sum = 0.0f64;

        for &val in &data {
            if let Some(f_val) = num_traits::ToPrimitive::to_f64(&val) {
                if f_val.is_nan() {
                    nan_count += 1;
                    continue;
                }
                if f_val.is_infinite() {
                    inf_count += 1;
                    continue;
                }
                sum += f_val;
                finite_count += 1;
            }

            if min.map_or(true, |current| val < current) {
                min = Some(val);
            }
            if max.map_or(true, |current| val > current) {
                max = Some(val);
            }
            if val == zero {
                zero_count += 1;
            }
        }

        let mean = if finite_count > 0 {
            Some(sum / finite_count as f64)
        } else {
            None
        };

        // Second pass: population variance around the mean of the finite values.
        let std = mean.map(|mean_val| {
            let squared_diffs: f64 = data
                .iter()
                .filter_map(|v| num_traits::ToPrimitive::to_f64(v))
                .filter(|f| f.is_finite())
                .map(|f| {
                    let diff = f - mean_val;
                    diff * diff
                })
                .sum();
            (squared_diffs / finite_count as f64).sqrt()
        });

        Ok(Self {
            min,
            max,
            mean,
            std,
            nan_count,
            inf_count,
            zero_count,
            total_elements,
        })
    }
}
impl<T: TensorElement + fmt::Display> fmt::Display for TensorValueStats<T> {
    /// Writes a multi-line, human-readable summary.
    ///
    /// The header and the total element count are always written. Min/max
    /// are written only when both are present, mean and std only when
    /// present, and each of the NaN/Inf/zero counters only when non-zero.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Tensor Value Statistics:")?;
        writeln!(f, " Total elements: {}", self.total_elements)?;

        // Min and max are reported as a pair or not at all.
        if let (Some(lowest), Some(highest)) = (self.min.as_ref(), self.max.as_ref()) {
            writeln!(f, " Min: {}", lowest)?;
            writeln!(f, " Max: {}", highest)?;
        }

        if let Some(average) = self.mean {
            writeln!(f, " Mean: {:.6}", average)?;
        }
        if let Some(deviation) = self.std {
            writeln!(f, " Std: {:.6}", deviation)?;
        }

        // Counters are omitted when zero to keep the summary short.
        let nans = self.nan_count;
        if nans != 0 {
            writeln!(f, " NaN count: {}", nans)?;
        }
        let infs = self.inf_count;
        if infs != 0 {
            writeln!(f, " Inf count: {}", infs)?;
        }
        let zeros = self.zero_count;
        if zeros != 0 {
            writeln!(f, " Zero count: {}", zeros)?;
        }

        Ok(())
    }
}
/// One entry in a tracked tensor's operation history.
#[derive(Debug, Clone)]
pub struct OperationRecord {
    /// Name of the operation that was applied.
    pub operation: String,
    /// Operation parameters, already rendered as strings.
    pub parameters: Vec<String>,
    /// Moment at which the record was created.
    pub timestamp: Instant,
    /// How long the operation took, if it was measured.
    pub duration: Option<Duration>,
    /// Tensor dimensions before the operation.
    pub shape_before: Vec<usize>,
    /// Tensor dimensions after the operation.
    pub shape_after: Vec<usize>,
}

impl fmt::Display for OperationRecord {
    /// Formats the record as `name(p1, p2): [before] -> [after] [duration]`.
    ///
    /// The parenthesised parameter list is omitted when there are no
    /// parameters, and the bracketed duration is omitted when it is `None`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.operation)?;
        if !self.parameters.is_empty() {
            // No trailing space here: the colon must follow directly, exactly
            // as it does when the parameter list is absent.
            write!(f, "({})", self.parameters.join(", "))?;
        }
        write!(f, ": {:?} -> {:?}", self.shape_before, self.shape_after)?;
        if let Some(duration) = self.duration {
            write!(f, " [{:?}]", duration)?;
        }
        Ok(())
    }
}
/// A labelled copy of a tensor's contents captured at one moment.
#[derive(Clone)]
pub struct TensorSnapshot<T: TensorElement> {
/// Element values copied out of the tensor when the snapshot was taken.
pub values: Vec<T>,
/// Tensor dimensions when the snapshot was taken.
pub shape: Vec<usize>,
/// Moment at which the snapshot was taken.
pub timestamp: Instant,
/// Caller-supplied name for this snapshot.
pub label: String,
}
/// A tensor together with the history recorded for it while it is tracked.
pub struct TrackedTensor<T: TensorElement> {
/// Identifier of this tracked tensor.
pub id: TrackId,
/// Human-readable name supplied when tracking started.
pub label: String,
/// Most recent tensor value; replaced each time an operation is recorded.
pub tensor: Tensor<T>,
/// Recorded operations, in the order they were recorded.
pub operations: Vec<OperationRecord>,
/// Snapshots taken so far, in the order they were taken.
pub snapshots: Vec<TensorSnapshot<T>>,
/// Moment at which this `TrackedTensor` was created.
pub start_time: Instant,
}
impl<T: TensorElement> TrackedTensor<T> {
    /// Wraps `tensor` with an empty operation and snapshot history, stamping
    /// the current instant as the start of tracking.
    pub fn new(id: TrackId, label: String, tensor: Tensor<T>) -> Self {
        let operations = Vec::new();
        let snapshots = Vec::new();
        let start_time = Instant::now();
        Self {
            id,
            label,
            tensor,
            operations,
            snapshots,
            start_time,
        }
    }

    /// Appends an operation to the history and makes `new_tensor` the current tensor.
    ///
    /// The record stores the dimensions of the tensor held before the call
    /// and the dimensions of `new_tensor`. `new_tensor` is cloned into `self`.
    pub fn record_operation(
        &mut self,
        operation: String,
        parameters: Vec<String>,
        new_tensor: &Tensor<T>,
        duration: Option<Duration>,
    ) {
        // Struct-literal fields are evaluated in written order, so both shapes
        // are read before the timestamp is taken and before the tensor is replaced.
        let record = OperationRecord {
            shape_before: self.tensor.shape().dims().to_vec(),
            shape_after: new_tensor.shape().dims().to_vec(),
            timestamp: Instant::now(),
            operation,
            parameters,
            duration,
        };
        self.operations.push(record);
        self.tensor = new_tensor.clone();
    }

    /// Copies the current tensor's values and dimensions into a new snapshot named `label`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `to_vec()` on the current tensor; no
    /// snapshot is stored in that case.
    pub fn take_snapshot(&mut self, label: String) -> Result<()>
    where
        T: Copy,
    {
        let snapshot = TensorSnapshot {
            values: self.tensor.to_vec()?,
            shape: self.tensor.shape().dims().to_vec(),
            timestamp: Instant::now(),
            label,
        };
        self.snapshots.push(snapshot);
        Ok(())
    }
}
/// Settings that control what a tensor tracker records and how much history it keeps.
#[derive(Debug, Clone)]
pub struct TrackingConfig {
    /// Master switch for tracking.
    pub enabled: bool,
    /// Upper bound on the number of operation records kept per tensor.
    pub max_operations: usize,
    /// Upper bound on the number of snapshots kept per tensor.
    pub max_snapshots: usize,
    /// Whether snapshots are taken automatically.
    pub auto_snapshot: bool,
    /// Operation names to record; an empty list means no filtering.
    pub operation_filter: Vec<String>,
}

impl Default for TrackingConfig {
    /// Enabled, 1000 operations, 100 snapshots, no automatic snapshots, no filter.
    fn default() -> Self {
        TrackingConfig {
            enabled: true,
            max_operations: 1_000,
            max_snapshots: 100,
            auto_snapshot: false,
            operation_filter: Vec::new(),
        }
    }
}

impl TrackingConfig {
    /// Preset with a small history: 100 operations and 10 snapshots.
    pub fn minimal() -> Self {
        TrackingConfig {
            max_operations: 100,
            max_snapshots: 10,
            ..TrackingConfig::default()
        }
    }

    /// Preset with a large history (10000 operations, 1000 snapshots) and
    /// automatic snapshots switched on.
    pub fn comprehensive() -> Self {
        TrackingConfig {
            max_operations: 10_000,
            max_snapshots: 1_000,
            auto_snapshot: true,
            ..TrackingConfig::default()
        }
    }

    /// Preset with the default limits that restricts recording to the
    /// operation names listed in `operations`.
    pub fn filtered(operations: Vec<String>) -> Self {
        TrackingConfig {
            operation_filter: operations,
            ..TrackingConfig::default()
        }
    }
}
/// Registry of tracked tensors, keyed by [`TrackId`], with its state held behind
/// `Arc<RwLock<..>>` wrappers.
pub struct TensorTracker<T: TensorElement> {
/// Settings consulted when tracking, recording operations and taking snapshots.
config: Arc<RwLock<TrackingConfig>>,
/// Every tensor currently being tracked.
tensors: Arc<RwLock<HashMap<TrackId, TrackedTensor<T>>>>,
/// Identifier that will be assigned to the next tracked tensor.
next_id: Arc<RwLock<TrackId>>,
}
impl<T: TensorElement> TensorTracker<T> {
    /// Creates a tracker that uses [`TrackingConfig::default`].
    pub fn new() -> Self {
        Self::with_config(TrackingConfig::default())
    }

    /// Creates an empty tracker governed by `config`; identifiers start at 0.
    pub fn with_config(config: TrackingConfig) -> Self {
        Self {
            config: Arc::new(RwLock::new(config)),
            tensors: Arc::new(RwLock::new(HashMap::new())),
            next_id: Arc::new(RwLock::new(0)),
        }
    }

    /// Starts tracking `tensor` under `label` and returns its new identifier.
    ///
    /// When `auto_snapshot` is set in the configuration, an `"initial"`
    /// snapshot is taken before the tensor is registered.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when tracking is disabled, or the
    /// error from taking the initial snapshot.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal locks is poisoned.
    pub fn track(&mut self, tensor: Tensor<T>, label: impl Into<String>) -> Result<TrackId>
    where
        T: Copy,
    {
        // Read everything needed from the configuration under a single lock.
        let auto_snapshot = {
            let config = self.config.read().expect("lock should not be poisoned");
            if !config.enabled {
                return Err(TorshError::InvalidArgument(
                    "Tracking is disabled".to_string(),
                ));
            }
            config.auto_snapshot
        };

        let id = {
            let mut next_id = self.next_id.write().expect("lock should not be poisoned");
            let id = *next_id;
            *next_id += 1;
            id
        };

        // `tensor` is owned, so it is moved into the record rather than cloned.
        let mut tracked = TrackedTensor::new(id, label.into(), tensor);
        if auto_snapshot {
            tracked.take_snapshot("initial".to_string())?;
        }

        self.tensors
            .write()
            .expect("lock should not be poisoned")
            .insert(id, tracked);
        Ok(id)
    }

    /// Stops tracking `id`, discarding its history.
    ///
    /// Always returns `Ok(())`, including when `id` was not tracked.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn untrack(&mut self, id: TrackId) -> Result<()> {
        self.tensors
            .write()
            .expect("lock should not be poisoned")
            .remove(&id);
        Ok(())
    }

    /// Records that `operation` with `parameters` produced `result_tensor`
    /// for the tensor tracked as `id`.
    ///
    /// Returns `Ok(())` without recording anything when tracking is disabled,
    /// or when the operation filter is non-empty and does not contain the
    /// operation name. Otherwise the record is appended, the oldest record is
    /// dropped if `max_operations` is exceeded, and — when `auto_snapshot` is
    /// set — an `after_<operation>` snapshot is taken, with the oldest
    /// snapshot dropped if `max_snapshots` is exceeded.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when `id` is not tracked, or the
    /// error from taking the automatic snapshot.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal locks is poisoned.
    pub fn record_operation<P: fmt::Display>(
        &self,
        id: TrackId,
        operation: impl Into<String>,
        parameters: Vec<P>,
        result_tensor: &Tensor<T>,
    ) -> Result<()>
    where
        T: Copy,
    {
        let config = self.config.read().expect("lock should not be poisoned");
        if !config.enabled {
            return Ok(());
        }
        let operation_str = operation.into();
        if !config.operation_filter.is_empty() && !config.operation_filter.contains(&operation_str)
        {
            return Ok(());
        }
        let auto_snapshot = config.auto_snapshot;
        let max_operations = config.max_operations;
        let max_snapshots = config.max_snapshots;
        // Release the configuration before taking the tensors lock.
        drop(config);

        let mut tensors = self.tensors.write().expect("lock should not be poisoned");
        let tracked = tensors.get_mut(&id).ok_or_else(|| {
            TorshError::InvalidArgument(format!("Tensor with ID {} is not tracked", id))
        })?;

        let params: Vec<String> = parameters.iter().map(|p| p.to_string()).collect();
        tracked.record_operation(operation_str.clone(), params, result_tensor, None);
        if tracked.operations.len() > max_operations {
            tracked.operations.remove(0);
        }

        if auto_snapshot {
            tracked.take_snapshot(format!("after_{}", operation_str))?;
            // Automatic snapshots obey the same cap as manual ones; without
            // this the snapshot list would grow with every recorded operation.
            if tracked.snapshots.len() > max_snapshots {
                tracked.snapshots.remove(0);
            }
        }
        Ok(())
    }

    /// Takes a snapshot named `label` of the tensor tracked as `id`, dropping
    /// the oldest snapshot if `max_snapshots` is exceeded.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when `id` is not tracked, or the
    /// error from copying the tensor data.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal locks is poisoned.
    pub fn snapshot(&self, id: TrackId, label: impl Into<String>) -> Result<()>
    where
        T: Copy,
    {
        // Read the limit first so the configuration lock is never acquired
        // while the tensors lock is held.
        let max_snapshots = self
            .config
            .read()
            .expect("lock should not be poisoned")
            .max_snapshots;

        let mut tensors = self.tensors.write().expect("lock should not be poisoned");
        let tracked = tensors.get_mut(&id).ok_or_else(|| {
            TorshError::InvalidArgument(format!("Tensor with ID {} is not tracked", id))
        })?;
        tracked.take_snapshot(label.into())?;
        if tracked.snapshots.len() > max_snapshots {
            tracked.snapshots.remove(0);
        }
        Ok(())
    }

    /// Builds a plain-text report for the tensor tracked as `id`.
    ///
    /// The report contains the label and id, elapsed tracking time, current
    /// shape, operation and snapshot counts, current value statistics (left
    /// out if they cannot be computed), the operation history and the list of
    /// snapshots.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when `id` is not tracked.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn generate_report(&self, id: TrackId) -> Result<String>
    where
        T: Copy + PartialOrd + num_traits::Zero + num_traits::ToPrimitive + fmt::Display,
    {
        let tensors = self.tensors.read().expect("lock should not be poisoned");
        let tracked = tensors.get(&id).ok_or_else(|| {
            TorshError::InvalidArgument(format!("Tensor with ID {} is not tracked", id))
        })?;

        let mut report = String::new();
        report.push_str(&format!(
            "=== Tracking Report for '{}' (ID: {}) ===\n\n",
            tracked.label, tracked.id
        ));
        report.push_str(&format!(
            "Tracking duration: {:?}\n",
            tracked.start_time.elapsed()
        ));
        report.push_str(&format!(
            "Current shape: {:?}\n",
            tracked.tensor.shape().dims()
        ));
        report.push_str(&format!(
            "Operations performed: {}\n",
            tracked.operations.len()
        ));
        report.push_str(&format!("Snapshots taken: {}\n\n", tracked.snapshots.len()));

        // Statistics are best-effort: a failure to read the tensor data only
        // omits this section instead of failing the whole report.
        if let Ok(stats) = TensorValueStats::from_tensor(&tracked.tensor) {
            report.push_str("Current Value Statistics:\n");
            report.push_str(&format!("{}\n", stats));
        }

        if !tracked.operations.is_empty() {
            report.push_str("\nOperation History:\n");
            for (i, op) in tracked.operations.iter().enumerate() {
                report.push_str(&format!(" {}. {}\n", i + 1, op));
            }
        }

        if !tracked.snapshots.is_empty() {
            report.push_str("\nSnapshots:\n");
            for (i, snapshot) in tracked.snapshots.iter().enumerate() {
                report.push_str(&format!(
                    " {}. '{}' - shape: {:?}, elements: {}\n",
                    i + 1,
                    snapshot.label,
                    snapshot.shape,
                    snapshot.values.len()
                ));
            }
        }
        Ok(report)
    }

    /// Returns a clone of the current tensor tracked as `id`.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when `id` is not tracked.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get_tensor(&self, id: TrackId) -> Result<Tensor<T>> {
        let tensors = self.tensors.read().expect("lock should not be poisoned");
        let tracked = tensors.get(&id).ok_or_else(|| {
            TorshError::InvalidArgument(format!("Tensor with ID {} is not tracked", id))
        })?;
        Ok(tracked.tensor.clone())
    }

    /// Returns the identifiers of all tracked tensors, in the map's
    /// (unspecified) iteration order.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn tracked_ids(&self) -> Vec<TrackId> {
        self.tensors
            .read()
            .expect("lock should not be poisoned")
            .keys()
            .copied()
            .collect()
    }

    /// Removes every tracked tensor and resets the identifier counter to 0.
    ///
    /// Because the counter restarts, identifiers handed out before the call
    /// will be issued again for tensors tracked afterwards.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal locks is poisoned.
    pub fn clear(&mut self) {
        self.tensors
            .write()
            .expect("lock should not be poisoned")
            .clear();
        *self.next_id.write().expect("lock should not be poisoned") = 0;
    }
}
impl<T: TensorElement> Default for TensorTracker<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation;
    use torsh_core::device::DeviceType;

    /// Tracking, recording one operation, retrieving and untracking a tensor.
    #[test]
    fn test_tensor_tracker_basic() {
        let mut tracker = TensorTracker::new();
        let tensor = creation::ones::<f32>(&[2, 2]).expect("ones creation should succeed");
        let id = tracker
            .track(tensor.clone(), "test_tensor")
            .expect("tracking should succeed");
        assert_eq!(tracker.tracked_ids().len(), 1);

        let result = tensor
            .mul_scalar(2.0)
            .expect("scalar multiplication should succeed");
        tracker
            .record_operation(id, "mul_scalar", vec![2.0], &result)
            .expect("recording the operation should succeed");

        let retrieved = tracker
            .get_tensor(id)
            .expect("tracked tensor should be retrievable");
        assert_eq!(retrieved.shape().dims(), &[2, 2]);

        tracker.untrack(id).expect("untracking should succeed");
        assert_eq!(tracker.tracked_ids().len(), 0);
    }

    /// Statistics of a small, fully finite tensor.
    #[test]
    fn test_tensor_value_stats() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let tensor = Tensor::from_data(data, vec![5], DeviceType::Cpu)
            .expect("tensor creation should succeed");
        let stats = TensorValueStats::from_tensor(&tensor).expect("from_tensor should succeed");

        assert_eq!(stats.total_elements, 5);
        assert_eq!(stats.min, Some(1.0));
        assert_eq!(stats.max, Some(5.0));
        assert!((stats.mean.expect("mean should be available") - 3.0).abs() < 1e-6);
        // Population standard deviation of 1..=5 is sqrt(2).
        assert!((stats.std.expect("std should be available") - 2.0f64.sqrt()).abs() < 1e-6);
        assert_eq!(stats.nan_count, 0);
        assert_eq!(stats.inf_count, 0);
        assert_eq!(stats.zero_count, 0);
    }

    /// Manual snapshots accumulate on the tracked tensor.
    #[test]
    fn test_tracking_snapshots() {
        let mut tracker = TensorTracker::new();
        let tensor = creation::ones::<f32>(&[3, 3]).expect("ones creation should succeed");
        let id = tracker
            .track(tensor.clone(), "snapshot_test")
            .expect("tracking should succeed");

        tracker
            .snapshot(id, "first_snapshot")
            .expect("snapshot should succeed");
        tracker
            .snapshot(id, "second_snapshot")
            .expect("snapshot should succeed");

        let tensors = tracker.tensors.read().expect("lock should not be poisoned");
        let tracked = tensors.get(&id).expect("tensor should be tracked");
        assert_eq!(tracked.snapshots.len(), 2);
    }

    /// The report mentions the label, the recorded operation and the count.
    #[test]
    fn test_tracking_report() {
        let mut tracker = TensorTracker::new();
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![3], DeviceType::Cpu)
            .expect("tensor creation should succeed");
        let id = tracker
            .track(tensor.clone(), "report_test")
            .expect("tracking should succeed");

        let result = tensor
            .mul_scalar(2.0)
            .expect("scalar multiplication should succeed");
        tracker
            .record_operation(id, "mul_scalar", vec![2.0], &result)
            .expect("recording the operation should succeed");

        let report = tracker
            .generate_report(id)
            .expect("report generation should succeed");
        assert!(report.contains("report_test"));
        assert!(report.contains("mul_scalar"));
        assert!(report.contains("Operations performed: 1"));
    }

    /// A tracker built from a preset config hands out ids starting at 0.
    #[test]
    fn test_tracking_config() {
        let config = TrackingConfig::minimal();
        let mut tracker = TensorTracker::with_config(config);
        let tensor = creation::ones::<f32>(&[2, 2]).expect("ones creation should succeed");
        let id = tracker
            .track(tensor, "config_test")
            .expect("tracking should succeed");
        assert_eq!(tracker.tracked_ids().len(), 1);
        assert_eq!(id, 0);
    }

    /// Only operations named in the filter are recorded.
    #[test]
    fn test_operation_filtering() {
        let config = TrackingConfig::filtered(vec!["add".to_string(), "mul".to_string()]);
        let mut tracker = TensorTracker::with_config(config);
        let tensor = creation::ones::<f32>(&[2, 2]).expect("ones creation should succeed");
        let id = tracker
            .track(tensor.clone(), "filter_test")
            .expect("tracking should succeed");

        let result = tensor
            .mul_scalar(2.0)
            .expect("scalar multiplication should succeed");
        tracker
            .record_operation(id, "mul", vec![2.0], &result)
            .expect("recording an allowed operation should succeed");

        let result2 = result
            .add_scalar(1.0)
            .expect("scalar addition should succeed");
        // "sub" is not in the filter, so this call succeeds without recording.
        tracker
            .record_operation(id, "sub", vec![1.0], &result2)
            .expect("a filtered-out operation should be skipped without error");

        let tensors = tracker.tensors.read().expect("lock should not be poisoned");
        let tracked = tensors.get(&id).expect("tensor should be tracked");
        assert_eq!(tracked.operations.len(), 1);
        assert_eq!(tracked.operations[0].operation, "mul");
    }
}