mod tests {
use super::super::*;
use std::collections::HashMap;
#[test]
fn test_forecasting_engine_creation() {
let engine = ForecastingEngine::new();
assert!(engine.is_ok());
}
#[test]
fn test_anomaly_detector_creation() {
let detector = AnomalyDetector::new();
assert!(detector.is_ok());
}
#[test]
fn test_causal_analyzer_creation() {
let analyzer = CausalAnalyzer::new();
assert!(analyzer.is_ok());
}
#[test]
fn test_density_adaptive_parameter_selection() -> Result<()> {
let dbscan = AdaptiveDBSCAN::new(DBSCANConfig::default())?;
let mut points = Vec::new();
for i in 0..10 {
points.push(Point::new(i, &[i as f64], None, i as f64));
}
for i in 10..13 {
points.push(Point::new(i, &[i as f64 * 10.0], None, i as f64));
}
let params = dbscan.calculate_adaptive_parameters(&points);
let (epsilon, min_samples) = params;
assert!(epsilon > 0.0);
assert!(epsilon < 100.0);
assert!(min_samples >= 2);
assert!(min_samples <= points.len());
Ok(())
}
#[test]
fn test_anomaly_detection_accuracy() -> Result<()> {
let mut dbscan = AdaptiveDBSCAN::new(DBSCANConfig::default())?;
let values = vec![1.0, 1.1, 0.9, 1.0, 0.95, 1.05, 50.0, 1.0, 0.98, 1.02];
let timestamps: Vec<f64> = (0..values.len()).map(|i| i as f64).collect();
let labels = dbscan.detect_anomalies_dbscan(&values, ×tamps);
let noise_count = labels
.iter()
.filter(|&label| matches!(label, ClusterLabel::Noise))
.count();
assert!(noise_count >= 1, "Should detect at least one anomaly");
let cluster_count = labels
.iter()
.filter(|&label| matches!(label, ClusterLabel::Cluster(_)))
.count();
assert!(
cluster_count >= values.len() - 3,
"Should classify most points as clusters"
);
Ok(())
}
#[test]
fn test_streaming_cluster_updates() -> Result<()> {
let mut dbscan = AdaptiveDBSCAN::new(DBSCANConfig {
window_size: 5,
density: 0.1,
min_cluster_size: 2,
max_distance: 1.0,
})?;
let mut labels = Vec::new();
for i in 0..8 {
let point = Point::new(i, &[i as f64], None, i as f64);
let label = dbscan.update_streaming_clusters(point);
labels.push(label);
assert!(dbscan.streaming_clusters.window.len() <= 5);
}
let has_clusters = labels
.iter()
.any(|&label| matches!(label, ClusterLabel::Cluster(_)));
let has_noise = labels
.iter()
.any(|&label| matches!(label, ClusterLabel::Noise));
assert!(
has_clusters || has_noise,
"Should produce some clustering results"
);
Ok(())
}
#[test]
fn test_kdtree_neighbor_queries() -> Result<()> {
let mut points = Vec::new();
for i in 0..10 {
let features = vec![i as f64, (i * 2) as f64];
points.push(Point::new(i, &[i as f64], None, i as f64));
points[i].features = features;
}
let kd_tree = KDTree::build(&points);
let center = vec![5.0, 10.0];
let neighbors = kd_tree.find_neighbors(¢er, 3.0);
assert!(!neighbors.is_empty());
for neighbor in &neighbors {
let distance = calculate_distance(¢er, &neighbor.features);
assert!(distance <= 3.0);
}
Ok(())
}
#[test]
fn test_dbscan_edge_cases() -> Result<()> {
let dbscan = AdaptiveDBSCAN::new(DBSCANConfig::default())?;
let empty_labels = dbscan.apply_dbscan(&[], (1.0, 2));
assert!(empty_labels.is_empty());
let single_point = vec![Point::new(0, &[1.0], None, 0.0)];
let single_labels = dbscan.apply_dbscan(&single_point, (1.0, 1));
assert_eq!(single_labels.len(), 1);
assert!(matches!(single_labels[0], ClusterLabel::Cluster(_)));
let mut high_dim_point = Point::new(0, &[1.0], None, 0.0);
high_dim_point.features = vec![1.0; 20]; let high_dim_labels = dbscan.apply_dbscan(&[high_dim_point], (2.0, 1));
assert_eq!(high_dim_labels.len(), 1);
Ok(())
}
#[test]
fn test_multidimensional_feature_handling() -> Result<()> {
let dbscan = AdaptiveDBSCAN::new(DBSCANConfig::default())?;
let mut points = Vec::new();
for i in 0..5 {
let mut point = Point::new(i, &[1.0], None, i as f64);
point.features = vec![1.0, 1.0, 1.0];
points.push(point);
}
for i in 5..10 {
let mut point = Point::new(i, &[2.0], None, i as f64);
point.features = vec![2.0, 2.0, 2.0];
points.push(point);
}
let mut outlier = Point::new(10, &[5.0], None, 10.0);
outlier.features = vec![5.0, 5.0, 5.0];
points.push(outlier);
let labels = dbscan.apply_dbscan(&points, (0.5, 2));
let cluster_ids: std::collections::HashSet<usize> = labels
.iter()
.filter_map(|label| {
if let ClusterLabel::Cluster(id) = label {
Some(*id)
} else {
None
}
})
.collect();
assert!(
!cluster_ids.is_empty(),
"Should identify at least one cluster"
);
let _outlier_label = labels[10];
Ok(())
}
#[test]
fn test_dbscan_integration_with_anomaly_detector() -> Result<()> {
let mut detector = AnomalyDetector::new()?;
let mut data = HashMap::new();
data.insert(
"test_series".to_string(),
vec![
1.0, 1.1, 0.9, 1.0, 0.95, 1.05, 25.0, 1.0, 0.98, 1.02, 1.1, 0.89,
],
);
let anomalies = detector.detect_anomalies(&data)?;
assert!(!anomalies.is_empty());
let anomaly_result = &anomalies[0];
assert_eq!(anomaly_result.variable, "test_series");
assert_eq!(anomaly_result.method, "DBSCAN");
assert!(!anomaly_result.anomaly_indices.is_empty());
for &score in &anomaly_result.anomaly_scores {
assert!(score >= 0.0);
}
assert!(anomaly_result.confidence >= 0.0 && anomaly_result.confidence <= 1.0);
Ok(())
}
#[test]
fn test_forecast_generation() -> Result<()> {
let mut engine = ForecastingEngine::new()?;
let mut data = HashMap::new();
data.insert(
"trend".to_string(),
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
);
let forecasts = engine.forecast(&data)?;
assert!(!forecasts.is_empty());
let forecast = &forecasts[0];
assert_eq!(forecast.variable, "trend");
assert_eq!(forecast.point_forecasts.len(), 10);
Ok(())
}
#[test]
fn test_anomaly_detection() -> Result<()> {
let mut detector = AnomalyDetector::new()?;
let mut data = HashMap::new();
let series = vec![1.0, 1.1, 0.9, 1.0, 0.95, 1.05, 50.0, 1.0, 0.98, 1.02];
data.insert("test".to_string(), series);
let anomalies = detector.detect_anomalies(&data)?;
assert!(!anomalies.is_empty());
let anomaly = &anomalies[0];
assert_eq!(anomaly.variable, "test");
assert!(!anomaly.anomaly_indices.is_empty());
Ok(())
}
#[test]
fn test_causal_analysis() -> Result<()> {
let analyzer = CausalAnalyzer::new()?;
let mut data = HashMap::new();
let x: Vec<f64> = (0..20).map(|i| i as f64).collect();
let y: Vec<f64> = x.iter().map(|&val| 2.0 * val + 1.0).collect();
data.insert("x".to_string(), x);
data.insert("y".to_string(), y);
let _causal_results = analyzer.analyze_causality(&data)?;
Ok(())
}
#[test]
fn test_comprehensive_analysis() -> Result<()> {
let mut data = HashMap::new();
data.insert("series1".to_string(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
data.insert("series2".to_string(), vec![2.0, 4.0, 6.0, 8.0, 10.0]);
let config = PredictiveConfig::default();
let results = run_predictive_analysis(&data, config)?;
assert!(!results.forecasts.is_empty());
assert!(!results.anomalies.is_empty());
assert_eq!(results.metadata.variables_analyzed, 2);
Ok(())
}
#[test]
fn test_ets_seasonality_detection() -> Result<()> {
let engine = ForecastingEngine::new()?;
let seasonal_data: Vec<f64> = (0..20)
.map(|i| {
let base = 10.0;
let trend = i as f64 * 0.5;
let seasonal = [0.0, 2.0, -1.0, 1.0][i % 4];
base + trend + seasonal
})
.collect();
let seasonality = engine.detect_seasonality(&seasonal_data)?;
assert!(
seasonality.strength > 0.05,
"Should detect some seasonal strength"
);
Ok(())
}
#[test]
fn test_ets_additive_vs_multiplicative_selection() -> Result<()> {
let engine = ForecastingEngine::new()?;
let additive_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let result = engine.select_and_fit_ets_model(&additive_data, 0)?;
assert!(result.aic.is_finite());
assert!(result.fit_quality >= 0.0 && result.fit_quality <= 1.0);
let multiplicative_data = vec![1.0, 2.0, 4.0, 8.0, 16.0, 32.0];
let result2 = engine.select_and_fit_ets_model(&multiplicative_data, 0)?;
assert!(result2.aic.is_finite());
assert!(result2.fit_quality >= 0.0 && result2.fit_quality <= 1.0);
Ok(())
}
#[test]
fn test_ets_confidence_intervals() -> Result<()> {
let engine = ForecastingEngine::new()?;
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let result = engine.select_and_fit_ets_model(&data, 0)?;
let forecasts = engine.forecast_ets(&result.model, &result.state, 5)?;
let (lower_bounds, upper_bounds) =
engine.calculate_confidence_intervals(&forecasts, &result.state, 5)?;
assert_eq!(forecasts.len(), 5);
assert_eq!(lower_bounds.len(), 5);
assert_eq!(upper_bounds.len(), 5);
for i in 0..5 {
assert!(lower_bounds[i] <= forecasts[i]);
assert!(upper_bounds[i] >= forecasts[i]);
}
Ok(())
}
#[test]
fn test_ets_parameter_estimation_convergence() -> Result<()> {
let engine = ForecastingEngine::new()?;
let data = vec![1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5];
let result = engine.fit_ets_model(
&data,
&ETSModelSpec {
error: ETSErrorType::Additive,
trend: ETSTrendType::Additive,
seasonal: ETSSeasonalType::None,
},
)?;
assert!(result.model.alpha > 0.0 && result.model.alpha < 1.0);
assert!(result.model.beta > 0.0 && result.model.beta < 1.0);
assert!(result.model.gamma >= 0.0 && result.model.gamma < 1.0);
assert!(result.aic.is_finite());
assert!(result.log_likelihood.is_finite());
Ok(())
}
#[test]
fn test_ets_edge_cases() -> Result<()> {
let engine = ForecastingEngine::new()?;
let single_obs = vec![5.0];
let result = engine.select_and_fit_ets_model(&single_obs, 0);
assert!(result.is_err());
let two_obs = vec![1.0, 2.0];
let result = engine.select_and_fit_ets_model(&two_obs, 0);
assert!(result.is_ok());
let constant_data = vec![5.0; 10];
let result = engine.select_and_fit_ets_model(&constant_data, 0)?;
assert!(result.aic.is_finite());
let trend_data: Vec<f64> = (0..10).map(|i| i as f64).collect();
let result = engine.select_and_fit_ets_model(&trend_data, 0)?;
assert!(result.fit_quality >= 0.0);
Ok(())
}
#[test]
fn test_ets_incremental_updates() -> Result<()> {
let engine = ForecastingEngine::new()?;
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let model = engine.initialize_parameters(
&data,
&ETSModelSpec {
error: ETSErrorType::Additive,
trend: ETSTrendType::Additive,
seasonal: ETSSeasonalType::None,
},
)?;
let mut state = engine.initialize_state(&data, &model)?;
let new_observation = 6.0;
state = engine.update_ets_state(&state, new_observation, &model)?;
assert_eq!(state.n_obs, 6);
assert_eq!(state.last_observation, 6.0);
Ok(())
}
#[test]
fn test_ets_model_types() {
assert_eq!(ETSErrorType::Additive.as_str(), "A");
assert_eq!(ETSErrorType::Multiplicative.as_str(), "M");
assert_eq!(ETSTrendType::None.as_str(), "N");
assert_eq!(ETSTrendType::Additive.as_str(), "A");
assert_eq!(ETSTrendType::AdditiveDamped.as_str(), "Ad");
assert_eq!(ETSSeasonalType::None.as_str(), "N");
assert_eq!(ETSSeasonalType::Additive.as_str(), "A");
assert_eq!(ETSSeasonalType::Multiplicative.as_str(), "M");
}
}