use super::types::*;
use anyhow::Result;
use std::sync::Arc;
use tracing::info;
impl crate::DebugSession {
    /// Builds a [`StreamingDebugger`] from `config`, starts it, and returns a shared handle.
    ///
    /// The session itself keeps no reference to the returned debugger; holding the
    /// `Arc` is the caller's responsibility. An info-level log line naming this
    /// session's id is emitted once the debugger has started.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `StreamingDebugger::start`.
    pub async fn enable_streaming(
        &mut self,
        config: StreamingDebugConfig,
    ) -> Result<Arc<StreamingDebugger>> {
        let handle = Arc::new(StreamingDebugger::new(config));

        handle.start().await?;

        info!("Enabled streaming for debug session {}", self.id());
        Ok(handle)
    }
}
/// Streams a tensor's shape and values through `$streamer.send_tensor_data(...)`.
///
/// * `$streamer` — receiver exposing an async `send_tensor_data` method.
/// * `$session_id` — forwarded unchanged.
/// * `$tensor` — must provide `shape()` (whose result has `.to_vec()`) and `iter()`
///   yielding references to `Copy` elements convertible `Into<f64>`.
/// * `$name` — converted with `.to_string()`.
///
/// A fresh tensor id is generated per invocation via `uuid::Uuid::new_v4()`, so `uuid`
/// must be resolvable under that name at the call site. `$tensor` is evaluated exactly
/// once, and the borrow of it ends before the `.await`.
///
/// The expansion `.await`s, so it is only usable inside an `async` context; it evaluates
/// to whatever `send_tensor_data(...).await` returns.
#[macro_export]
macro_rules! stream_tensor {
    ($streamer:expr, $session_id:expr, $tensor:expr, $name:expr) => {{
        let tensor_id = uuid::Uuid::new_v4();
        // Bind `$tensor` once: expanding it for both `shape()` and `iter()` would re-run
        // side effects (and rebuild temporaries such as `make_tensor()`). The inner scope
        // drops the borrow before the `.await` below.
        let (shape, values) = {
            let tensor = &$tensor;
            let shape = tensor.shape().to_vec();
            let values: Vec<f64> = tensor.iter().map(|&x| x.into()).collect();
            (shape, values)
        };
        $streamer
            .send_tensor_data($session_id, tensor_id, $name.to_string(), shape, values)
            .await
    }};
}
/// Streams a layer's gradients through `$streamer.send_gradient_flow(...)`.
///
/// Each element yielded by `$gradients.iter()` is copied out and converted into an `f64`.
/// The resulting buffer is passed by reference, together with `$session_id` and
/// `$layer_name.to_string()`.
///
/// The expansion `.await`s, so it is only usable inside an `async` context; it evaluates
/// to whatever `send_gradient_flow(...).await` returns.
#[macro_export]
macro_rules! stream_gradients {
    ($streamer:expr, $session_id:expr, $layer_name:expr, $gradients:expr) => {{
        let as_f64 = $gradients
            .iter()
            .map(|&grad| -> f64 { grad.into() })
            .collect::<Vec<f64>>();
        $streamer
            .send_gradient_flow($session_id, $layer_name.to_string(), &as_f64)
            .await
    }};
}
/// Reports an anomaly through `$streamer.send_anomaly_detected(...)`.
///
/// Two forms are accepted (each with an optional trailing comma):
///
/// * `stream_anomaly!(streamer, session_id, anomaly_type, severity, description)` reports
///   a confidence of `0.95` and an empty affected-components list.
/// * `stream_anomaly!(streamer, session_id, anomaly_type, severity, description,
///   confidence, affected_components)` forwards the caller's confidence and
///   affected-components value unchanged.
///
/// `$description` is converted with `.to_string()`. Both forms `.await`, so they are only
/// usable inside an `async` context, and evaluate to the result of
/// `send_anomaly_detected(...).await`.
#[macro_export]
macro_rules! stream_anomaly {
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr $(,)?
    ) => {
        // Short form: keep the historical defaults.
        $crate::stream_anomaly!(
            $streamer, $session_id, $anomaly_type, $severity, $description, 0.95, vec![]
        )
    };
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr, $confidence:expr, $affected_components:expr $(,)?
    ) => {{
        $streamer
            .send_anomaly_detected(
                $session_id,
                $anomaly_type,
                $severity,
                $description.to_string(),
                $confidence,
                $affected_components,
            )
            .await
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::{Duration, SystemTime};
    use uuid::Uuid;

    /// Builds a `StreamEvent::TensorData` for `session_id` carrying a fixed 2x2 tensor and
    /// matching statistics, so filter tests differ only in the field under test.
    fn sample_tensor_event(session_id: Uuid) -> StreamEvent {
        StreamEvent::TensorData {
            session_id,
            tensor_id: Uuid::new_v4(),
            name: "test".to_string(),
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
            statistics: TensorStatistics {
                mean: 2.5,
                std: 1.29,
                min: 1.0,
                max: 4.0,
                nan_count: 0,
                inf_count: 0,
                zero_count: 0,
                sparsity: 0.0,
            },
            timestamp: SystemTime::now(),
        }
    }

    #[tokio::test]
    async fn test_streaming_debugger_creation() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        // A freshly constructed debugger must not report itself as running.
        assert!(!*debugger.is_running.read().await);
    }

    // NOTE(review): ignored in default runs; the reason is not recorded here — consider
    // `#[ignore = "..."]` so it shows up in test output.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore]
    async fn test_start_stop_streaming() {
        let config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let debugger = StreamingDebugger::new(config);
        // Bound the whole lifecycle so a hung start/stop fails instead of hanging the run.
        // `expect` (rather than `assert!(is_ok())`) surfaces the underlying error on failure.
        tokio::time::timeout(Duration::from_secs(3), async {
            debugger.start().await.expect("start should succeed");
            assert!(*debugger.is_running.read().await);
            tokio::time::sleep(Duration::from_millis(50)).await;
            debugger.stop().await.expect("stop should succeed");
            assert!(!*debugger.is_running.read().await);
            // NOTE(review): presumably gives tasks spawned by `start()` time to exit after
            // `stop()`; confirm whether it is still needed.
            tokio::time::sleep(Duration::from_millis(100)).await;
        })
        .await
        .expect("test should not time out");
    }

    // NOTE(review): ignored in default runs; reason not recorded here.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore]
    async fn test_subscription() {
        let config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let debugger = StreamingDebugger::new(config);
        tokio::time::timeout(Duration::from_secs(3), async {
            debugger.start().await.expect("start should succeed");
            let subscription = debugger
                .subscribe(
                    "test_subscriber".to_string(),
                    StreamFormat::Json,
                    StreamFilter::default(),
                )
                .await
                .expect("subscribe should succeed");
            assert_eq!(debugger.get_subscribers().await.len(), 1);
            debugger
                .unsubscribe(subscription.subscriber_id())
                .await
                .expect("unsubscribe should succeed");
            assert_eq!(debugger.get_subscribers().await.len(), 0);
            debugger.stop().await.expect("stop should succeed");
            tokio::time::sleep(Duration::from_millis(100)).await;
        })
        .await
        .expect("test should not time out");
    }

    #[tokio::test]
    async fn test_tensor_statistics() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = debugger.compute_tensor_statistics(&values);
        assert_eq!(stats.mean, 3.0);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.zero_count, 0);
    }

    #[tokio::test]
    async fn test_gradient_statistics() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        let gradients = vec![0.1, -0.2, 0.3, -0.1, 0.0];
        let stats = debugger.compute_gradient_statistics(&gradients);
        assert!(stats.l1_norm > 0.0);
        assert!(stats.l2_norm > 0.0);
        assert_eq!(stats.max_grad, 0.3);
        assert_eq!(stats.min_grad, -0.2);
    }

    #[tokio::test]
    async fn test_event_filtering() {
        let session_id1 = Uuid::new_v4();
        let session_id2 = Uuid::new_v4();
        let filter = StreamFilter {
            session_ids: Some(vec![session_id1]),
            event_types: Some(vec!["TensorData".to_string()]),
            min_severity: None,
            time_range: None,
            custom_filters: HashMap::new(),
        };
        // Identical events except for the session id: only the whitelisted one matches.
        let matching_event = sample_tensor_event(session_id1);
        let non_matching_event = sample_tensor_event(session_id2);
        assert!(StreamSubscription::matches_filter(&matching_event, &filter));
        assert!(!StreamSubscription::matches_filter(
            &non_matching_event,
            &filter
        ));
    }
}
/// A named reduction that collapses a batch of [`StreamEvent`]s into a single `f64`.
pub trait AggregationRule {
    /// Identifier for this rule.
    fn rule_name(&self) -> &str;

    /// Reduces `events` to one numeric value.
    ///
    /// # Errors
    ///
    /// Whether and when an error is returned is up to each implementation.
    fn aggregate(&self, events: &[StreamEvent]) -> Result<f64>;
}
#[cfg(test)]
mod enhanced_tests {
    use super::*;
    use std::time::{Duration, Instant, SystemTime};
    use uuid::Uuid;

    // NOTE(review): ignored in default runs; the reason is not recorded here — consider
    // `#[ignore = "..."]`.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore]
    async fn test_enhanced_streaming_debugger() {
        let base_config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let adaptive_config = AdaptiveStreamingConfig {
            monitoring_interval_ms: 500,
            ..Default::default()
        };
        let aggregation_config = RealTimeAggregationConfig {
            window_size_seconds: 1,
            ..Default::default()
        };
        let buffering_config = IntelligentBufferingConfig::default();
        let mut debugger = EnhancedStreamingDebugger::new(
            base_config,
            adaptive_config,
            aggregation_config,
            buffering_config,
        );
        // Bound the lifecycle so a hang fails the test; `expect` reports the underlying
        // error instead of a bare `assert!(is_ok())` failure.
        tokio::time::timeout(Duration::from_secs(5), async {
            debugger
                .start_enhanced_streaming()
                .await
                .expect("start_enhanced_streaming should succeed");
            tokio::time::sleep(Duration::from_millis(100)).await;
            debugger
                .stop_enhanced_streaming()
                .await
                .expect("stop_enhanced_streaming should succeed");
            tokio::time::sleep(Duration::from_millis(200)).await;
        })
        .await
        .expect("test should not time out");
    }

    #[tokio::test]
    async fn test_network_condition_monitor() {
        let mut monitor = NetworkConditionMonitor::new();
        monitor.update_conditions().await;
        assert!(monitor.quality_score >= 0.0);
        assert!(monitor.quality_score <= 1.0);
        // An update should have recorded at least one history sample.
        assert!(!monitor.history.is_empty());
    }

    #[test]
    fn test_buffer_performance_predictor() {
        let predictor = BufferPerformancePredictor {
            performance_history: vec![
                BufferPerformancePoint {
                    buffer_size: 500,
                    throughput: 100.0,
                    latency: 50.0,
                    memory_usage: 50000,
                    timestamp: Instant::now(),
                },
                BufferPerformancePoint {
                    buffer_size: 1000,
                    throughput: 150.0,
                    latency: 40.0,
                    memory_usage: 100000,
                    timestamp: Instant::now(),
                },
            ],
            model_params: vec![],
            accuracy: 0.8,
        };
        let optimal_size = predictor
            .predict_optimal_size()
            .expect("predict_optimal_size should succeed");
        // The 1000-entry point has both higher throughput and lower latency.
        assert_eq!(optimal_size, 1000);
    }

    #[tokio::test]
    async fn test_importance_scorer() {
        let scorer = ImportanceScorer::new();
        let critical_event = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::GradientExplosion,
            severity: AnomalySeverity::Critical,
            description: "Critical gradient explosion".to_string(),
            confidence: 0.95,
            affected_components: vec!["layer1".to_string()],
            timestamp: SystemTime::now(),
        };
        let low_event = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::TrainingStagnation,
            severity: AnomalySeverity::Low,
            description: "Slow convergence detected".to_string(),
            confidence: 0.6,
            affected_components: vec!["layer2".to_string()],
            timestamp: SystemTime::now(),
        };
        let critical_score = scorer
            .calculate_importance(&critical_event)
            .await
            .expect("calculate_importance should succeed for critical event");
        let low_score = scorer
            .calculate_importance(&low_event)
            .await
            .expect("calculate_importance should succeed for low event");
        assert!(critical_score > low_score);
    }
}