use scirs2_core::random::prelude::*;
use scirs2_core::random::ChaCha8Rng;
use scirs2_core::random::{Rng, SeedableRng};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use super::error::{RLEmbeddingError, RLEmbeddingResult};
/// Deep Q-Network pair for embedding optimization: an online Q-network
/// plus a target network that starts as an exact clone of it (see `new`).
#[derive(Debug, Clone)]
pub struct EmbeddingDQN {
/// Online network used for action-value estimation.
pub q_network: EmbeddingNetwork,
/// Target network; created as a clone of `q_network` at construction.
pub target_network: EmbeddingNetwork,
/// Architecture and training hyperparameters.
pub config: NetworkConfig,
/// Mutable optimizer / early-stopping bookkeeping.
pub training_state: NetworkTrainingState,
}
/// Actor-critic network pair for policy-based embedding training.
#[derive(Debug, Clone)]
pub struct EmbeddingPolicyNetwork {
/// Policy (actor) network.
pub actor_network: EmbeddingNetwork,
/// Value (critic) network, built with the same layer sizes as the actor.
pub critic_network: EmbeddingNetwork,
/// Hyperparameters shared by both networks.
pub config: NetworkConfig,
/// Mutable optimizer / early-stopping bookkeeping.
pub training_state: NetworkTrainingState,
}
/// A feed-forward network: a stack of dense layers plus stored
/// input/output normalization parameters and training metadata.
#[derive(Debug, Clone)]
pub struct EmbeddingNetwork {
/// Dense layers, applied in order by `forward`.
pub layers: Vec<NetworkLayer>,
/// Input normalization parameters. NOTE(review): stored but not
/// applied by `forward` in this file — confirm intended usage.
pub input_norm: NormalizationLayer,
/// Output scaling parameters. NOTE(review): likewise not applied by
/// `forward` here.
pub output_scaling: NormalizationLayer,
/// Creation time, training history, and performance metrics.
pub metadata: NetworkMetadata,
}
/// One dense layer. `weights` is indexed as `weights[neuron][input]`
/// (one row per output neuron), matching `biases[neuron]`.
#[derive(Debug, Clone)]
pub struct NetworkLayer {
pub weights: Vec<Vec<f64>>,
/// One bias per output neuron.
pub biases: Vec<f64>,
/// Nonlinearity applied to each neuron's weighted sum.
pub activation: ActivationFunction,
/// Dropout probability. NOTE(review): stored but not applied in the
/// forward pass shown in this file.
pub dropout_rate: f64,
/// Optional batch-norm parameters; `None` for layers built by
/// `EmbeddingNetwork::new`.
pub batch_norm: Option<BatchNormalization>,
}
/// Neuron activation functions; the exact formulas are implemented in
/// `EmbeddingNetwork::layer_forward`.
#[derive(Debug, Clone, PartialEq)]
pub enum ActivationFunction {
/// `max(0, x)`.
ReLU,
/// `x` for `x > 0`, otherwise `alpha * x`; the payload is `alpha`.
LeakyReLU(f64),
/// Hyperbolic tangent.
Tanh,
/// `1 / (1 + e^-x)`.
Sigmoid,
/// `x * sigmoid(x)`.
Swish,
/// Identity; used for output layers.
Linear,
}
/// Batch-normalization parameters for a layer.
/// NOTE(review): defined but not applied anywhere in this file;
/// field semantics below follow the standard batch-norm convention —
/// confirm against the code that uses them.
#[derive(Debug, Clone)]
pub struct BatchNormalization {
/// Running estimate of per-feature means.
pub running_mean: Vec<f64>,
/// Running estimate of per-feature variances.
pub running_var: Vec<f64>,
/// Learned per-feature scale.
pub gamma: Vec<f64>,
/// Learned per-feature shift.
pub beta: Vec<f64>,
/// Small constant for numerical stability.
pub epsilon: f64,
/// Update rate for the running statistics.
pub momentum: f64,
}
/// Per-feature normalization statistics; which fields are relevant
/// depends on `norm_type` (means/stds for standard score, mins/maxs
/// for min-max scaling).
#[derive(Debug, Clone)]
pub struct NormalizationLayer {
/// Per-feature means (standard-score normalization).
pub means: Vec<f64>,
/// Per-feature standard deviations (standard-score normalization).
pub stds: Vec<f64>,
/// Per-feature minima (min-max scaling).
pub mins: Vec<f64>,
/// Per-feature maxima (min-max scaling).
pub maxs: Vec<f64>,
/// Strategy selecting how the statistics are applied.
pub norm_type: NormalizationType,
}
/// Normalization strategies for `NormalizationLayer`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NormalizationType {
/// Standard score: `(x - mean) / std`.
StandardScore,
/// Min-max scaling using the stored per-feature range.
MinMax,
/// Robust scaling. NOTE(review): no implementation is visible in
/// this file — confirm the intended statistics.
Robust,
/// Pass values through unchanged.
None,
}
/// Static hyperparameters describing a network and how to train it.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
/// Sizes of every layer, input first and output last.
pub layer_sizes: Vec<usize>,
/// Base learning rate.
pub learning_rate: f64,
/// L1/L2, dropout, and early-stopping settings.
pub regularization: RegularizationConfig,
/// Gradient-update rule.
pub optimizer: OptimizerType,
/// Training objective.
pub loss_function: LossFunction,
}
/// Regularization and early-stopping hyperparameters.
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
/// L1 (lasso) penalty coefficient.
pub l1_strength: f64,
/// L2 (weight-decay) penalty coefficient.
pub l2_strength: f64,
/// Dropout probability.
pub dropout_rate: f64,
/// Epochs without improvement tolerated before stopping.
pub early_stopping_patience: usize,
}
/// Gradient-descent update rules.
#[derive(Debug, Clone, PartialEq)]
pub enum OptimizerType {
/// Plain stochastic gradient descent.
SGD,
/// Adam; `beta1`/`beta2` are the first-/second-moment decay rates.
Adam { beta1: f64, beta2: f64 },
/// RMSprop with its squared-gradient decay rate.
RMSprop { decay_rate: f64 },
/// Adaptive gradient (per-parameter accumulated scaling).
AdaGrad,
}
/// Training objectives.
#[derive(Debug, Clone, PartialEq)]
pub enum LossFunction {
/// Mean squared error.
MSE,
/// Huber loss; `delta` is the quadratic-to-linear crossover point.
Huber { delta: f64 },
/// Cross-entropy loss.
CrossEntropy,
/// Combination of several objectives. NOTE(review): the exact
/// composition is not defined in this file.
MultiObjective,
}
/// Bookkeeping about a network's creation and training runs.
#[derive(Debug, Clone)]
pub struct NetworkMetadata {
/// When the network was constructed.
pub created_at: Instant,
/// Per-epoch training records, in order.
pub training_history: Vec<TrainingEpoch>,
/// Aggregate quality metrics.
pub performance_metrics: NetworkPerformanceMetrics,
/// Free-form version string (`"1.0.0"` at construction).
pub version: String,
}
/// Record of a single training epoch.
#[derive(Debug, Clone)]
pub struct TrainingEpoch {
/// Epoch index.
pub epoch: usize,
/// Loss on the training split.
pub training_loss: f64,
/// Loss on the validation split.
pub validation_loss: f64,
/// Learning rate in effect during this epoch.
pub learning_rate: f64,
/// Wall-clock time the epoch took.
pub duration: Duration,
/// Additional named metrics.
pub metrics: HashMap<String, f64>,
}
/// Aggregate training-quality metrics. Initialized to neutral values
/// by `EmbeddingNetwork::new`; presumably updated by training code not
/// shown in this file.
#[derive(Debug, Clone)]
pub struct NetworkPerformanceMetrics {
/// Lowest validation loss seen so far (starts at `f64::INFINITY`).
pub best_validation_loss: f64,
/// How quickly the loss is decreasing.
pub convergence_rate: f64,
/// Gap between training and validation performance.
pub generalization_gap: f64,
/// Performance relative to parameter count.
pub parameter_efficiency: f64,
}
/// Mutable state carried across training steps.
#[derive(Debug, Clone)]
pub struct NetworkTrainingState {
/// Epoch currently being trained.
pub current_epoch: usize,
/// Learning rate currently in effect (may diverge from the config's
/// base rate if a schedule adjusts it).
pub current_lr: f64,
/// Optimizer accumulator buffers.
pub optimizer_state: OptimizerState,
/// Snapshot of the best weights seen so far, if any; presumably
/// shaped `[layer][neuron][input]` to mirror `NetworkLayer::weights`.
pub best_weights: Option<Vec<Vec<Vec<f64>>>>,
/// Consecutive non-improving epochs tracked for early stopping.
pub early_stopping_counter: usize,
}
/// Optimizer accumulators, presumably indexed `[layer][neuron][input]`
/// to mirror the weight layout — confirm against the update code.
#[derive(Debug, Clone)]
pub struct OptimizerState {
/// SGD momentum buffers.
pub momentum_buffers: Vec<Vec<Vec<f64>>>,
/// First-moment estimates (Adam).
pub first_moments: Vec<Vec<Vec<f64>>>,
/// Second-moment estimates (Adam / RMSprop-style).
pub second_moments: Vec<Vec<Vec<f64>>>,
/// Number of update steps taken so far.
pub iteration: usize,
}
impl EmbeddingDQN {
    /// Creates a DQN with a freshly initialized Q-network and a target
    /// network that starts as an identical copy of it.
    ///
    /// Defaults: Adam(0.9, 0.999), MSE loss, learning rate 1e-3,
    /// L1 1e-4 / L2 1e-3, dropout 0.1, early-stopping patience 100.
    ///
    /// # Errors
    /// Propagates any error from `EmbeddingNetwork::new` (e.g. fewer
    /// than two layer sizes).
    pub fn new(layer_sizes: &[usize], seed: Option<u64>) -> RLEmbeddingResult<Self> {
        // Single source of truth for the rate stored in both the
        // config and the training state.
        const LEARNING_RATE: f64 = 0.001;

        let online = EmbeddingNetwork::new(layer_sizes, seed)?;
        // Target starts in lockstep with the online network.
        let target = online.clone();

        let regularization = RegularizationConfig {
            l1_strength: 0.0001,
            l2_strength: 0.001,
            dropout_rate: 0.1,
            early_stopping_patience: 100,
        };
        let config = NetworkConfig {
            layer_sizes: layer_sizes.to_vec(),
            learning_rate: LEARNING_RATE,
            regularization,
            optimizer: OptimizerType::Adam {
                beta1: 0.9,
                beta2: 0.999,
            },
            loss_function: LossFunction::MSE,
        };

        // Optimizer buffers stay empty until training allocates them.
        let optimizer_state = OptimizerState {
            momentum_buffers: Vec::new(),
            first_moments: Vec::new(),
            second_moments: Vec::new(),
            iteration: 0,
        };
        let training_state = NetworkTrainingState {
            current_epoch: 0,
            current_lr: LEARNING_RATE,
            optimizer_state,
            best_weights: None,
            early_stopping_counter: 0,
        };

        Ok(Self {
            q_network: online,
            target_network: target,
            config,
            training_state,
        })
    }
}
impl EmbeddingPolicyNetwork {
    /// Creates an actor-critic pair with the given layer sizes.
    ///
    /// When a seed is supplied, the actor uses it directly and the
    /// critic uses `seed + 1` (wrapping), so initialization stays
    /// reproducible without the two networks starting with identical
    /// weights. (The previous code passed the same seed to both,
    /// which made `Some(seed)` produce byte-identical networks.)
    ///
    /// Defaults: Adam(0.9, 0.999), multi-objective loss, learning
    /// rate 1e-4, L1 1e-4 / L2 1e-3, dropout 0.1, patience 100.
    ///
    /// # Errors
    /// Propagates any error from `EmbeddingNetwork::new` (e.g. fewer
    /// than two layer sizes).
    pub fn new(layer_sizes: &[usize], seed: Option<u64>) -> RLEmbeddingResult<Self> {
        let actor_network = EmbeddingNetwork::new(layer_sizes, seed)?;
        // Derive a distinct but deterministic seed for the critic.
        let critic_network =
            EmbeddingNetwork::new(layer_sizes, seed.map(|s| s.wrapping_add(1)))?;
        let config = NetworkConfig {
            layer_sizes: layer_sizes.to_vec(),
            learning_rate: 0.0001,
            regularization: RegularizationConfig {
                l1_strength: 0.0001,
                l2_strength: 0.001,
                dropout_rate: 0.1,
                early_stopping_patience: 100,
            },
            optimizer: OptimizerType::Adam {
                beta1: 0.9,
                beta2: 0.999,
            },
            loss_function: LossFunction::MultiObjective,
        };
        // Optimizer buffers stay empty until training allocates them.
        let training_state = NetworkTrainingState {
            current_epoch: 0,
            current_lr: 0.0001,
            optimizer_state: OptimizerState {
                momentum_buffers: Vec::new(),
                first_moments: Vec::new(),
                second_moments: Vec::new(),
                iteration: 0,
            },
            best_weights: None,
            early_stopping_counter: 0,
        };
        Ok(Self {
            actor_network,
            critic_network,
            config,
            training_state,
        })
    }
}
impl EmbeddingNetwork {
    /// Builds a fully-connected network from `layer_sizes`
    /// (`[input, hidden..., output]`).
    ///
    /// Weights are drawn uniformly from `±sqrt(2 / fan_in)` (He-style
    /// fan-in scaling) and biases start at zero. Hidden layers use
    /// ReLU; the final layer is linear. `Some(seed)` gives
    /// deterministic initialization; `None` seeds from the thread RNG.
    ///
    /// Normalization parameters are initialized to identity values.
    /// NOTE(review): `forward` does not currently apply `input_norm`
    /// or `output_scaling`; confirm whether callers normalize
    /// externally.
    ///
    /// # Errors
    /// `ConfigurationError` if fewer than two sizes are given, or if
    /// any size is zero.
    pub fn new(layer_sizes: &[usize], seed: Option<u64>) -> RLEmbeddingResult<Self> {
        if layer_sizes.len() < 2 {
            return Err(RLEmbeddingError::ConfigurationError(
                "Network must have at least input and output layers".to_string(),
            ));
        }
        // Reject zero-width layers up front: `2.0 / 0` would make the
        // init scale infinite and `random_range(-inf..inf)` panic, and
        // downstream code indexes into per-neuron weight rows.
        if layer_sizes.contains(&0) {
            return Err(RLEmbeddingError::ConfigurationError(
                "All layer sizes must be non-zero".to_string(),
            ));
        }
        let mut rng = match seed {
            Some(s) => ChaCha8Rng::seed_from_u64(s),
            None => ChaCha8Rng::seed_from_u64(thread_rng().random()),
        };
        let mut layers = Vec::with_capacity(layer_sizes.len() - 1);
        for i in 0..layer_sizes.len() - 1 {
            let input_size = layer_sizes[i];
            let output_size = layer_sizes[i + 1];
            // He-style uniform init scaled by fan-in.
            let scale = (2.0 / input_size as f64).sqrt();
            let mut weights = vec![vec![0.0; input_size]; output_size];
            for row in &mut weights {
                for weight in row {
                    *weight = rng.random_range(-scale..scale);
                }
            }
            // The last layer stays linear so outputs are unbounded.
            let activation = if i == layer_sizes.len() - 2 {
                ActivationFunction::Linear
            } else {
                ActivationFunction::ReLU
            };
            layers.push(NetworkLayer {
                weights,
                biases: vec![0.0; output_size],
                activation,
                dropout_rate: 0.1,
                batch_norm: None,
            });
        }
        let input_size = layer_sizes[0];
        let output_size = layer_sizes[layer_sizes.len() - 1];
        // Identity normalization parameters (mean 0 / std 1, range [0, 1]).
        let input_norm = NormalizationLayer {
            means: vec![0.0; input_size],
            stds: vec![1.0; input_size],
            mins: vec![0.0; input_size],
            maxs: vec![1.0; input_size],
            norm_type: NormalizationType::StandardScore,
        };
        let output_scaling = NormalizationLayer {
            means: vec![0.0; output_size],
            stds: vec![1.0; output_size],
            mins: vec![0.0; output_size],
            maxs: vec![1.0; output_size],
            norm_type: NormalizationType::None,
        };
        let metadata = NetworkMetadata {
            created_at: Instant::now(),
            training_history: Vec::new(),
            performance_metrics: NetworkPerformanceMetrics {
                best_validation_loss: f64::INFINITY,
                convergence_rate: 0.0,
                generalization_gap: 0.0,
                parameter_efficiency: 0.0,
            },
            version: "1.0.0".to_string(),
        };
        Ok(Self {
            layers,
            input_norm,
            output_scaling,
            metadata,
        })
    }

    /// Runs a forward pass: applies each layer to the previous layer's
    /// activations, starting from `input`.
    ///
    /// # Errors
    /// `NeuralNetworkError` if `input` (or any intermediate
    /// activation) does not match the next layer's fan-in.
    pub fn forward(&self, input: &[f64]) -> RLEmbeddingResult<Vec<f64>> {
        let mut activations = input.to_vec();
        for layer in &self.layers {
            activations = self.layer_forward(&activations, layer)?;
        }
        Ok(activations)
    }

    /// Applies one dense layer: per-neuron weighted sum plus bias,
    /// then the layer's activation function. Dropout and batch-norm
    /// parameters are stored on the layer but not applied here.
    ///
    /// # Errors
    /// `NeuralNetworkError` if `input` length does not match the
    /// layer's fan-in.
    fn layer_forward(&self, input: &[f64], layer: &NetworkLayer) -> RLEmbeddingResult<Vec<f64>> {
        // Fan-in is the width of a weight row; `first()` avoids the
        // panic the previous `layer.weights[0]` indexing caused on a
        // zero-neuron layer.
        let expected = layer.weights.first().map_or(0, |row| row.len());
        if input.len() != expected {
            return Err(RLEmbeddingError::NeuralNetworkError(format!(
                "Input size {} doesn't match layer input size {}",
                input.len(),
                expected
            )));
        }
        let mut output = Vec::with_capacity(layer.weights.len());
        for (neuron_weights, &bias) in layer.weights.iter().zip(&layer.biases) {
            // Affine part: bias + dot(input, weights).
            let pre: f64 = bias
                + input
                    .iter()
                    .zip(neuron_weights)
                    .map(|(&inp, &w)| inp * w)
                    .sum::<f64>();
            let post = match layer.activation {
                ActivationFunction::ReLU => pre.max(0.0),
                ActivationFunction::LeakyReLU(alpha) => {
                    if pre > 0.0 {
                        pre
                    } else {
                        alpha * pre
                    }
                }
                ActivationFunction::Tanh => pre.tanh(),
                ActivationFunction::Sigmoid => 1.0 / (1.0 + (-pre).exp()),
                // Swish: x * sigmoid(x).
                ActivationFunction::Swish => pre / (1.0 + (-pre).exp()),
                ActivationFunction::Linear => pre,
            };
            output.push(post);
        }
        Ok(output)
    }
}