use crate::analysis::types::*;
use crate::error::IntegrateResult;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{Rng, RngExt};
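/// Feed-forward network for predicting bifurcation types and parameter values
/// from extracted trajectory features.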
#[derive(Debug, Clone)]
pub struct BifurcationPredictionNetwork {
pub architecture: NetworkArchitecture,
pub model_parameters: ModelParameters,
pub training_config: super::training::TrainingConfiguration,
pub feature_extraction: super::features::FeatureExtraction,
pub performance_metrics: super::uncertainty::PerformanceMetrics,
pub uncertainty_quantification: super::uncertainty::UncertaintyQuantification,
}
#[derive(Debug, Clone)]
pub struct NetworkArchitecture {
pub input_size: usize,
pub hidden_layers: Vec<usize>,
pub output_size: usize,
pub activation_functions: Vec<ActivationFunction>,
    pub dropout_rates: Vec<f64>,
pub batch_normalization: Vec<bool>,
pub skip_connections: Vec<SkipConnection>,
}
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
ReLU,
LeakyReLU(f64),
Tanh,
Sigmoid,
Softmax,
Swish,
GELU,
ELU(f64),
}
#[derive(Debug, Clone)]
pub struct SkipConnection {
pub from_layer: usize,
pub to_layer: usize,
pub connection_type: ConnectionType,
}
#[derive(Debug, Clone, Copy)]
pub enum ConnectionType {
Addition,
Concatenation,
Gated,
}
#[derive(Debug, Clone)]
pub struct ModelParameters {
pub weights: Vec<Array2<f64>>,
pub biases: Vec<Array1<f64>>,
pub batch_norm_params: Vec<BatchNormParams>,
pub dropout_masks: Vec<Array1<bool>>,
}
#[derive(Debug, Clone)]
pub struct BatchNormParams {
pub scale: Array1<f64>,
pub shift: Array1<f64>,
pub running_mean: Array1<f64>,
pub running_var: Array1<f64>,
}
#[derive(Debug, Clone)]
pub struct BifurcationPrediction {
pub bifurcation_type: BifurcationType,
pub predicted_parameter: f64,
pub confidence: f64,
pub raw_output: Array1<f64>,
pub uncertainty_estimate: Option<UncertaintyEstimate>,
}
#[derive(Debug, Clone)]
pub struct UncertaintyEstimate {
pub epistemic_uncertainty: f64,
pub aleatoric_uncertainty: f64,
pub total_uncertainty: f64,
pub confidence_interval: (f64, f64),
}
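
// A minimal constructor sketch for `UncertaintyEstimate`, assuming the
// epistemic and aleatoric components combine in quadrature and a 95%
// normal-approximation interval; this helper is illustrative and is not
// referenced elsewhere in the module.
impl UncertaintyEstimate {
    #[allow(dead_code)]
    pub fn from_components(epistemic: f64, aleatoric: f64, mean: f64) -> Self {
        // Independent uncertainty sources add in quadrature.
        let total = (epistemic * epistemic + aleatoric * aleatoric).sqrt();
        // 1.96 sigma covers ~95% of a normal distribution.
        let half_width = 1.96 * total;
        Self {
            epistemic_uncertainty: epistemic,
            aleatoric_uncertainty: aleatoric,
            total_uncertainty: total,
            confidence_interval: (mean - half_width, mean + half_width),
        }
    }
}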
impl BifurcationPredictionNetwork {
pub fn new(input_size: usize, hidden_layers: Vec<usize>, output_size: usize) -> Self {
let architecture = NetworkArchitecture {
input_size,
hidden_layers: hidden_layers.clone(),
output_size,
activation_functions: vec![ActivationFunction::ReLU; hidden_layers.len() + 1],
            dropout_rates: vec![0.0; hidden_layers.len() + 1],
batch_normalization: vec![false; hidden_layers.len() + 1],
skip_connections: Vec::new(),
};
let model_parameters = Self::initialize_parameters(&architecture);
Self {
architecture,
model_parameters,
training_config: super::training::TrainingConfiguration::default(),
feature_extraction: super::features::FeatureExtraction::default(),
performance_metrics: super::uncertainty::PerformanceMetrics::default(),
uncertainty_quantification: super::uncertainty::UncertaintyQuantification::default(),
}
}
    fn initialize_parameters(arch: &NetworkArchitecture) -> ModelParameters {
        let mut rng = scirs2_core::random::rng();
        let mut weights = Vec::new();
        let mut biases = Vec::new();
        let mut prev_size = arch.input_size;
        // Xavier/Glorot uniform initialization: all-zero weights would leave
        // every unit in a layer identical (and ReLU units dead), so draw from
        // U(-sqrt(6 / (fan_in + fan_out)), +sqrt(6 / (fan_in + fan_out))).
        let layer_sizes = arch
            .hidden_layers
            .iter()
            .copied()
            .chain(std::iter::once(arch.output_size));
        for layer_size in layer_sizes {
            let bound = (6.0 / (prev_size + layer_size) as f64).sqrt();
            weights.push(Array2::from_shape_fn((prev_size, layer_size), |_| {
                (rng.random::<f64>() * 2.0 - 1.0) * bound
            }));
            biases.push(Array1::zeros(layer_size));
            prev_size = layer_size;
        }
        ModelParameters {
            weights,
            biases,
            batch_norm_params: Vec::new(),
            dropout_masks: Vec::new(),
        }
    }
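    /// Runs a single forward pass through the layer stack.
    ///
    /// The architecture's `batch_normalization` flags and `skip_connections`
    /// are recorded but not yet applied here. A minimal usage sketch (import
    /// paths elided):
    ///
    /// ```ignore
    /// let net = BifurcationPredictionNetwork::new(4, vec![8, 8], 7);
    /// let output = net.forward(&Array1::zeros(4))?;
    /// assert_eq!(output.len(), 7);
    /// ```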
    pub fn forward(&self, input: &Array1<f64>) -> IntegrateResult<Array1<f64>> {
        let mut activation = input.clone();
        for (i, (weights, bias)) in self
            .model_parameters
            .weights
            .iter()
            .zip(&self.model_parameters.biases)
            .enumerate()
        {
            // Affine step: weights are stored as (fan_in, fan_out), hence the transpose.
            activation = weights.t().dot(&activation) + bias;
            activation = self.apply_activation_function(
                &activation,
                self.architecture.activation_functions[i],
            )?;
            if self.architecture.dropout_rates[i] > 0.0 {
                activation =
                    Self::apply_dropout(&activation, self.architecture.dropout_rates[i])?;
            }
        }
        Ok(activation)
    }
fn apply_activation_function(
&self,
x: &Array1<f64>,
func: ActivationFunction,
) -> IntegrateResult<Array1<f64>> {
        let result = match func {
            ActivationFunction::ReLU => x.mapv(|v| v.max(0.0)),
            ActivationFunction::LeakyReLU(alpha) => x.mapv(|v| if v > 0.0 { v } else { alpha * v }),
            ActivationFunction::Tanh => x.mapv(|v| v.tanh()),
            ActivationFunction::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
            ActivationFunction::Softmax => {
                // Subtract the max before exponentiating for numerical stability.
                let max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let exp_x = x.mapv(|v| (v - max).exp());
                let sum = exp_x.sum();
                exp_x / sum
            }
            ActivationFunction::Swish => x.mapv(|v| v / (1.0 + (-v).exp())),
            ActivationFunction::GELU => {
                // Tanh approximation of GELU:
                // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
                let c = (2.0 / std::f64::consts::PI).sqrt();
                x.mapv(|v| 0.5 * v * (1.0 + (c * (v + 0.044715 * v.powi(3))).tanh()))
            }
            ActivationFunction::ELU(alpha) => {
                x.mapv(|v| if v > 0.0 { v } else { alpha * (v.exp() - 1.0) })
            }
        };
Ok(result)
}
    fn apply_dropout(x: &Array1<f64>, dropout_rate: f64) -> IntegrateResult<Array1<f64>> {
        if dropout_rate == 0.0 {
            return Ok(x.clone());
        }
        let mut rng = scirs2_core::random::rng();
        // Inverted dropout: surviving units are scaled by 1 / (1 - rate) so the
        // expected activation matches the no-dropout forward pass.
        let mask: Array1<f64> = Array1::from_shape_fn(x.len(), |_| {
            if rng.random::<f64>() < dropout_rate {
                0.0
            } else {
                1.0 / (1.0 - dropout_rate)
            }
        });
        Ok(x * &mask)
    }
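    /// Trains the network on `training_data` in mini-batches, optionally
    /// evaluating a validation set each epoch, and records per-epoch metrics.
    ///
    /// A minimal usage sketch; note that `backward` is currently a
    /// placeholder, so loss curves are recorded without weight updates:
    ///
    /// ```ignore
    /// let mut net = BifurcationPredictionNetwork::new(4, vec![8], 7);
    /// let data = vec![(Array1::zeros(4), Array1::zeros(7))];
    /// net.train(&data, None)?;
    /// ```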
pub fn train(
&mut self,
training_data: &[(Array1<f64>, Array1<f64>)],
validation_data: Option<&[(Array1<f64>, Array1<f64>)]>,
) -> IntegrateResult<()> {
let mut training_metrics = Vec::new();
let mut validation_metrics = Vec::new();
for epoch in 0..self.training_config.epochs {
let epoch_loss = self.train_epoch(training_data)?;
let epoch_metric = super::uncertainty::EpochMetrics {
epoch,
loss: epoch_loss,
accuracy: None,
precision: None,
recall: None,
f1_score: None,
learning_rate: self.get_current_learning_rate(epoch),
};
training_metrics.push(epoch_metric.clone());
if let Some(val_data) = validation_data {
let val_loss = self.evaluate(val_data)?;
let val_metric = super::uncertainty::EpochMetrics {
epoch,
loss: val_loss,
accuracy: None,
precision: None,
recall: None,
f1_score: None,
learning_rate: epoch_metric.learning_rate,
};
validation_metrics.push(val_metric);
}
if self.should_early_stop(&training_metrics, &validation_metrics) {
break;
}
}
self.performance_metrics.training_metrics = training_metrics;
self.performance_metrics.validation_metrics = validation_metrics;
Ok(())
}
    fn train_epoch(
        &mut self,
        training_data: &[(Array1<f64>, Array1<f64>)],
    ) -> IntegrateResult<f64> {
        let mut total_loss = 0.0;
        let mut num_batches = 0usize;
        let batch_size = self.training_config.batch_size;
        for batch_start in (0..training_data.len()).step_by(batch_size) {
            let batch_end = (batch_start + batch_size).min(training_data.len());
            let batch = &training_data[batch_start..batch_end];
            total_loss += self.train_batch(batch)?;
            num_batches += 1;
        }
        // Average the per-batch mean losses over the actual number of batches
        // (the last batch may be smaller than `batch_size`).
        Ok(total_loss / num_batches.max(1) as f64)
    }
fn train_batch(&mut self, batch: &[(Array1<f64>, Array1<f64>)]) -> IntegrateResult<f64> {
let mut total_loss = 0.0;
for (input, target) in batch {
let prediction = self.forward(input)?;
let loss = self.calculate_loss(&prediction, target)?;
total_loss += loss;
self.backward(&prediction, target, input)?;
}
Ok(total_loss / batch.len() as f64)
}
fn calculate_loss(
&self,
prediction: &Array1<f64>,
target: &Array1<f64>,
) -> IntegrateResult<f64> {
match self.training_config.loss_function {
super::training::LossFunction::MSE => {
let diff = prediction - target;
Ok(diff.dot(&diff) / prediction.len() as f64)
}
super::training::LossFunction::CrossEntropy => {
let epsilon = 1e-15;
                let pred_clipped = prediction.mapv(|p| p.clamp(epsilon, 1.0 - epsilon));
let loss = -target
.iter()
.zip(pred_clipped.iter())
.map(|(&t, &p)| t * p.ln())
.sum::<f64>();
Ok(loss)
}
super::training::LossFunction::FocalLoss(alpha, gamma) => {
let epsilon = 1e-15;
                let pred_clipped = prediction.mapv(|p| p.clamp(epsilon, 1.0 - epsilon));
let loss = -alpha
* target
.iter()
.zip(pred_clipped.iter())
.map(|(&t, &p)| t * (1.0 - p).powf(gamma) * p.ln())
.sum::<f64>();
Ok(loss)
}
super::training::LossFunction::HuberLoss(delta) => {
let diff = prediction - target;
let abs_diff = diff.mapv(|d| d.abs());
let loss = abs_diff
.iter()
.map(|&d| {
if d <= delta {
0.5 * d * d
} else {
delta * d - 0.5 * delta * delta
}
})
.sum::<f64>();
Ok(loss / prediction.len() as f64)
}
            super::training::LossFunction::WeightedMSE => {
                // Per-sample weights are not yet wired in, so this currently
                // falls back to plain MSE.
                let diff = prediction - target;
                Ok(diff.dot(&diff) / prediction.len() as f64)
            }
}
}
    fn backward(
        &mut self,
        _prediction: &Array1<f64>,
        _target: &Array1<f64>,
        _input: &Array1<f64>,
    ) -> IntegrateResult<()> {
        // Placeholder: full backpropagation through the layer stack is not yet
        // implemented, so parameters are left unchanged.
        Ok(())
    }
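
    // A sketch of the missing gradient step for a single linear output layer
    // under MSE loss; `sgd_step_linear_mse` is hypothetical (not called
    // anywhere) and only illustrates the update that `backward` would need to
    // perform per layer.
    #[allow(dead_code)]
    fn sgd_step_linear_mse(
        weights: &mut Array2<f64>, // shape (fan_in, fan_out), as in ModelParameters
        bias: &mut Array1<f64>,    // shape (fan_out,)
        input: &Array1<f64>,       // shape (fan_in,)
        prediction: &Array1<f64>,  // shape (fan_out,)
        target: &Array1<f64>,      // shape (fan_out,)
        learning_rate: f64,
    ) {
        let n = prediction.len() as f64;
        // dL/dy for L = (1/n) * ||y - t||^2.
        let grad_out = (prediction - target) * (2.0 / n);
        // dL/dW[i][j] = input[i] * grad_out[j]; dL/db[j] = grad_out[j].
        for ((i, j), w) in weights.indexed_iter_mut() {
            *w -= learning_rate * input[i] * grad_out[j];
        }
        *bias -= &(grad_out * learning_rate);
    }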
pub fn evaluate(&self, test_data: &[(Array1<f64>, Array1<f64>)]) -> IntegrateResult<f64> {
let mut total_loss = 0.0;
for (input, target) in test_data {
let prediction = self.forward(input)?;
let loss = self.calculate_loss(&prediction, target)?;
total_loss += loss;
}
Ok(total_loss / test_data.len() as f64)
}
fn get_current_learning_rate(&self, epoch: usize) -> f64 {
match &self.training_config.learning_rate {
super::training::LearningRateSchedule::Constant(lr) => *lr,
super::training::LearningRateSchedule::ExponentialDecay {
initial_lr,
decay_rate,
decay_steps,
} => initial_lr * decay_rate.powf(epoch as f64 / *decay_steps as f64),
super::training::LearningRateSchedule::CosineAnnealing {
initial_lr,
min_lr,
cycle_length,
} => {
let cycle_pos = (epoch % cycle_length) as f64 / *cycle_length as f64;
min_lr
+ (initial_lr - min_lr) * (1.0 + (cycle_pos * std::f64::consts::PI).cos()) / 2.0
}
super::training::LearningRateSchedule::StepDecay {
initial_lr,
drop_rate,
epochs_drop,
} => initial_lr * drop_rate.powf((epoch / epochs_drop) as f64),
            super::training::LearningRateSchedule::Adaptive { initial_lr, .. } => {
                // Adaptive adjustment is not yet implemented; fall back to the
                // initial rate.
                *initial_lr
            }
}
}
    fn should_early_stop(
        &self,
        _training_metrics: &[super::uncertainty::EpochMetrics],
        _validation_metrics: &[super::uncertainty::EpochMetrics],
    ) -> bool {
        if !self.training_config.early_stopping.enabled {
            return false;
        }
        // Placeholder: patience-based monitoring of the validation loss is not
        // yet implemented, so training always runs for the configured epochs.
        false
    }
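    /// Classifies the bifurcation type from a feature vector and reads off the
    /// predicted parameter value.
    ///
    /// A minimal usage sketch, assuming `features` was produced by the
    /// module's feature extraction:
    ///
    /// ```ignore
    /// let prediction = net.predict_bifurcation(&features)?;
    /// println!("{:?} near {}", prediction.bifurcation_type, prediction.predicted_parameter);
    /// ```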
pub fn predict_bifurcation(
&self,
features: &Array1<f64>,
) -> IntegrateResult<BifurcationPrediction> {
let raw_output = self.forward(features)?;
let bifurcation_type = self.classify_bifurcation_type(&raw_output)?;
let confidence = self.calculate_confidence(&raw_output)?;
        // Convention in this module: the first output component doubles as the
        // predicted bifurcation parameter value.
        let predicted_parameter = raw_output[0];
Ok(BifurcationPrediction {
bifurcation_type,
predicted_parameter,
confidence,
raw_output,
uncertainty_estimate: None,
})
}
    fn classify_bifurcation_type(&self, output: &Array1<f64>) -> IntegrateResult<BifurcationType> {
        // Arg-max over class scores; `total_cmp` gives a total order so NaN
        // outputs cannot panic the comparison.
        let max_idx = output
            .iter()
            .enumerate()
            .max_by(|&(_, a), &(_, b)| a.total_cmp(b))
            .map(|(idx, _)| idx)
            .unwrap_or(0);
let bifurcation_type = match max_idx {
0 => BifurcationType::Fold,
1 => BifurcationType::Transcritical,
2 => BifurcationType::Pitchfork,
3 => BifurcationType::Hopf,
4 => BifurcationType::PeriodDoubling,
5 => BifurcationType::Homoclinic,
_ => BifurcationType::Unknown,
};
Ok(bifurcation_type)
}
    fn calculate_confidence(&self, output: &Array1<f64>) -> IntegrateResult<f64> {
        // The largest output is reported as the confidence; this is only a
        // calibrated probability when the final activation is Softmax.
        let max_prob = output.iter().cloned().fold(0.0, f64::max);
        Ok(max_prob)
    }
}
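
// Lightweight sanity checks for the pieces above; these exercise only output
// shapes and the numerically stable softmax, not training behavior.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn forward_preserves_output_size() {
        let net = BifurcationPredictionNetwork::new(4, vec![8, 8], 7);
        let output = net
            .forward(&Array1::from_elem(4, 0.5))
            .expect("forward pass should succeed");
        assert_eq!(output.len(), 7);
    }

    #[test]
    fn softmax_outputs_sum_to_one() {
        let net = BifurcationPredictionNetwork::new(3, vec![], 3);
        let x = Array1::from(vec![1.0, 2.0, 3.0]);
        let y = net
            .apply_activation_function(&x, ActivationFunction::Softmax)
            .expect("softmax should succeed");
        assert!((y.sum() - 1.0).abs() < 1e-12);
    }
}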