hextral/
lib.rs

use nalgebra::{DVector, DMatrix};
use rand::{Rng, thread_rng};
use rand::seq::SliceRandom;
use futures::future::join_all;
use serde::{Serialize, Deserialize};
use std::path::Path;
use tokio::fs;

pub mod activation;
pub mod optimizer;

#[cfg(feature = "datasets")]
pub mod dataset;

pub use activation::ActivationFunction;
pub use optimizer::{Optimizer, OptimizerState};

#[cfg(feature = "datasets")]
pub use dataset::{Dataset, DatasetLoader, DatasetError, PreprocessingConfig, FillStrategy};
#[derive(Debug, Clone)]
pub struct EarlyStopping {
    pub patience: usize,
    pub min_delta: f64,
    pub best_loss: f64,
    pub counter: usize,
    pub restore_best_weights: bool,
}

impl EarlyStopping {
    pub fn new(patience: usize, min_delta: f64, restore_best_weights: bool) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            counter: 0,
            restore_best_weights,
        }
    }

    pub fn should_stop(&mut self, current_loss: f64) -> bool {
        if current_loss < self.best_loss - self.min_delta {
            self.best_loss = current_loss;
            self.counter = 0;
            false
        } else {
            self.counter += 1;
            self.counter >= self.patience
        }
    }

    pub fn reset(&mut self) {
        self.best_loss = f64::INFINITY;
        self.counter = 0;
    }
}
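
// A minimal usage sketch of the early-stopping contract above (illustrative test,
// not part of the original file): a loss must drop by more than `min_delta` to reset
// the counter, and training stops once `patience` non-improving epochs accumulate.
#[cfg(test)]
mod early_stopping_example {
    use super::*;

    #[test]
    fn stops_after_patience_without_improvement() {
        let mut es = EarlyStopping::new(2, 0.01, false);
        assert!(!es.should_stop(1.0));   // first loss always counts as an improvement
        assert!(!es.should_stop(0.995)); // improvement smaller than min_delta, counter = 1
        assert!(es.should_stop(0.995));  // counter reaches patience = 2, stop
        es.reset();
        assert_eq!(es.counter, 0);
        assert!(es.best_loss.is_infinite());
    }
}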

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointConfig {
    pub save_best: bool,
    pub save_every: Option<usize>,
    pub filepath: String,
    pub monitor_loss: bool,
}

impl CheckpointConfig {
    pub fn new<P: AsRef<Path>>(filepath: P) -> Self {
        Self {
            save_best: true,
            save_every: None,
            filepath: filepath.as_ref().to_string_lossy().to_string(),
            monitor_loss: true,
        }
    }

    pub fn save_every(mut self, epochs: usize) -> Self {
        self.save_every = Some(epochs);
        self
    }

    pub async fn save_weights(&self, weights: &[(DMatrix<f64>, DVector<f64>)]) -> Result<(), Box<dyn std::error::Error>> {
        let data = bincode::serialize(weights)?;
        fs::write(&self.filepath, data).await?;
        Ok(())
    }

    pub async fn load_weights(&self) -> Result<Vec<(DMatrix<f64>, DVector<f64>)>, Box<dyn std::error::Error>> {
        let data = fs::read(&self.filepath).await?;
        let weights = bincode::deserialize(&data)?;
        Ok(weights)
    }
}
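
// A small sketch of the checkpoint builder above (illustrative test, not part of the
// original file; no file I/O is performed): `new` defaults to saving only the best
// weights, and `save_every` opts into additional periodic saves.
#[cfg(test)]
mod checkpoint_config_example {
    use super::*;

    #[test]
    fn builder_defaults_and_periodic_saving() {
        let cfg = CheckpointConfig::new("weights.bin").save_every(5);
        assert!(cfg.save_best);
        assert!(cfg.monitor_loss);
        assert_eq!(cfg.save_every, Some(5));
        assert_eq!(cfg.filepath, "weights.bin");
    }
}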

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Regularization {
    L2(f64),
    L1(f64),
    Dropout(f64),
    None,
}
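
// Gradient contributions used by `train_step` below: L2(lambda) adds lambda * w to the
// weight gradient and L1(lambda) adds lambda * sign(w); Dropout and None contribute no
// penalty term in the current backward pass.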

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossFunction {
    MeanSquaredError,
    MeanAbsoluteError,
    BinaryCrossEntropy,
    CategoricalCrossEntropy,
    /// Huber Loss: smooth combination of MSE and MAE
    Huber { delta: f64 },
}
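
// Worked example of the Huber case, per element with threshold `delta` (matches the
// implementation in `compute_loss` below):
//   |e| <= delta: 0.5 * e^2,                  e.g. delta = 1.0, e = 0.5 -> 0.125
//   |e| >  delta: delta * (|e| - 0.5 * delta), e.g. delta = 1.0, e = 3.0 -> 2.5
// so small residuals behave like MSE and large ones grow linearly like MAE.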

impl Default for LossFunction {
    fn default() -> Self {
        LossFunction::MeanSquaredError
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchNormLayer {
    gamma: DVector<f64>,
    beta: DVector<f64>,
    running_mean: DVector<f64>,
    running_var: DVector<f64>,
    momentum: f64,
    epsilon: f64,
    training: bool,
}

impl BatchNormLayer {
    pub fn new(size: usize) -> Self {
        Self {
            gamma: DVector::from_element(size, 1.0),
            beta: DVector::zeros(size),
            running_mean: DVector::zeros(size),
            running_var: DVector::from_element(size, 1.0),
            momentum: 0.1,
            epsilon: 1e-5,
            training: true,
        }
    }

    /// Apply batch normalization forward pass
    pub fn forward(&mut self, x: &DVector<f64>) -> (DVector<f64>, Option<(DVector<f64>, DVector<f64>, DVector<f64>)>) {
        if self.training {
            // Training mode: compute batch statistics
            let mean = x.mean();
            let var = x.iter().map(|xi| (xi - mean).powi(2)).sum::<f64>() / x.len() as f64;
            let std_dev = (var + self.epsilon).sqrt();

            // Normalize
            let normalized = x.map(|xi| (xi - mean) / std_dev);

            // Scale and shift
            let output = normalized.component_mul(&self.gamma) + &self.beta;

            // Update running statistics
            self.running_mean = &self.running_mean * (1.0 - self.momentum) + &DVector::from_element(x.len(), mean * self.momentum);
            self.running_var = &self.running_var * (1.0 - self.momentum) + &DVector::from_element(x.len(), var * self.momentum);

            // Return normalized values and cache for backward pass
            let cache = Some((normalized, DVector::from_element(x.len(), mean), DVector::from_element(x.len(), std_dev)));
            (output, cache)
        } else {
            // Inference mode: use the per-element running statistics
            let normalized = DVector::from_fn(x.len(), |i, _| {
                (x[i] - self.running_mean[i]) / (self.running_var[i] + self.epsilon).sqrt()
            });
            let output = normalized.component_mul(&self.gamma) + &self.beta;
            (output, None)
        }
    }

    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }
}
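
// A short numeric sketch of the forward pass above (illustrative test, not part of the
// original file): in training mode a feature vector is normalized to zero mean, then
// scaled by gamma (1.0) and shifted by beta (0.0), and a cache is returned for backprop.
#[cfg(test)]
mod batch_norm_example {
    use super::*;

    #[test]
    fn training_forward_normalizes_input() {
        let mut bn = BatchNormLayer::new(3);
        let x = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let (out, cache) = bn.forward(&x);

        // The mean of [1, 2, 3] is 2, so the normalized output sums to ~0
        assert!(out.sum().abs() < 1e-9);
        assert!(cache.is_some()); // training mode returns a cache for the backward pass

        // Inference mode uses running statistics and returns no cache
        bn.set_training(false);
        let (_out, cache) = bn.forward(&x);
        assert!(cache.is_none());
    }
}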

#[derive(Clone, Serialize, Deserialize)]
pub struct Hextral {
    layers: Vec<(DMatrix<f64>, DVector<f64>)>,
    activation: ActivationFunction,
    optimizer: Optimizer,
    optimizer_state: OptimizerState,
    regularization: Regularization,
    loss_function: LossFunction,
    batch_norm_layers: Vec<Option<BatchNormLayer>>,
    use_batch_norm: bool,
}

impl Hextral {
    pub fn new(
        input_size: usize,
        hidden_sizes: &[usize],
        output_size: usize,
        activation: ActivationFunction,
        optimizer: Optimizer,
    ) -> Self {
        let mut layers = Vec::with_capacity(hidden_sizes.len() + 1);
        let mut rng = thread_rng();

        let mut prev_size = input_size;

        // Initialize hidden layers with Xavier initialization
        for &size in hidden_sizes {
            let bound = (6.0 / (size + prev_size) as f64).sqrt();
            let weight = DMatrix::from_fn(size, prev_size, |_, _| {
                rng.gen_range(-bound..bound)
            });
            let bias = DVector::zeros(size);
            layers.push((weight, bias));
            prev_size = size;
        }

        // Initialize output layer
        let bound = (6.0 / (output_size + prev_size) as f64).sqrt();
        let weight = DMatrix::from_fn(output_size, prev_size, |_, _| {
            rng.gen_range(-bound..bound)
        });
        let bias = DVector::zeros(output_size);
        layers.push((weight, bias));

        // Create layer shapes for optimizer state initialization
        let layer_shapes: Vec<(usize, usize)> = layers.iter()
            .map(|(w, _)| (w.nrows(), w.ncols()))
            .collect();

        Hextral {
            layers,
            activation,
            optimizer_state: OptimizerState::new(&layer_shapes),
            optimizer,
            regularization: Regularization::None,
            loss_function: LossFunction::default(),
            batch_norm_layers: Vec::new(),
            use_batch_norm: false,
        }
    }
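
    // For illustration: `Hextral::new(4, &[8, 8], 2, ...)` builds three layers with
    // weight shapes 8x4, 8x8 and 2x8 (plus bias vectors of length 8, 8 and 2), so
    // `architecture()` reports [4, 8, 8, 2] and `parameter_count()` reports
    // 40 + 72 + 18 = 130 trainable parameters.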

    /// Set regularization
    pub fn set_regularization(&mut self, reg: Regularization) {
        self.regularization = reg;
    }

    /// Set loss function
    pub fn set_loss_function(&mut self, loss: LossFunction) {
        self.loss_function = loss;
    }

    /// Enable batch normalization for all hidden layers
    pub fn enable_batch_norm(&mut self) {
        if !self.use_batch_norm {
            self.use_batch_norm = true;
            self.batch_norm_layers.clear();

            // Add batch norm layers for all but the output layer
            for i in 0..self.layers.len() - 1 {
                let layer_size = self.layers[i].0.nrows(); // Number of outputs from this layer
                self.batch_norm_layers.push(Some(BatchNormLayer::new(layer_size)));
            }
            // No batch norm for output layer
            self.batch_norm_layers.push(None);
        }
    }

    /// Disable batch normalization
    pub fn disable_batch_norm(&mut self) {
        self.use_batch_norm = false;
        self.batch_norm_layers.clear();
    }

    /// Set training mode for batch normalization
    pub fn set_training_mode(&mut self, training: bool) {
        for bn_layer in &mut self.batch_norm_layers {
            if let Some(bn) = bn_layer {
                bn.set_training(training);
            }
        }
    }

    pub async fn forward(&self, input: &DVector<f64>) -> DVector<f64> {
        let mut output = input.clone();

        // Only yield if the network has many layers
        if self.layers.len() > 5 {
            let mid = self.layers.len() / 2;

            for (i, (weight, bias)) in self.layers.iter().enumerate() {
                output = weight * &output + bias;
                if i < self.layers.len() - 1 {
                    output = self.activation.apply(&output);
                }
                if i == mid {
                    tokio::task::yield_now().await;
                }
            }
        } else {
            for (i, (weight, bias)) in self.layers.iter().enumerate() {
                output = weight * &output + bias;
                if i < self.layers.len() - 1 {
                    output = self.activation.apply(&output);
                }
            }
        }

        output
    }

    pub async fn predict(&self, input: &DVector<f64>) -> DVector<f64> {
        self.forward(input).await
    }

    pub async fn predict_batch(&self, inputs: &[DVector<f64>]) -> Vec<DVector<f64>> {
        if inputs.len() > 10 {
            let futures: Vec<_> = inputs.iter()
                .map(|input| self.predict(input))
                .collect();
            join_all(futures).await
        } else {
            let mut results = Vec::new();
            for input in inputs {
                results.push(self.predict(input).await);
            }
            results
        }
    }
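
    // Note on `predict_batch`: `join_all` drives all prediction futures concurrently on
    // the current task; since `forward` is CPU-bound and only yields at its midpoint for
    // deep networks, this interleaves work rather than running it on multiple threads.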

    /// Compute loss between prediction and target
    pub fn compute_loss(&self, prediction: &DVector<f64>, target: &DVector<f64>) -> f64 {
        match &self.loss_function {
            LossFunction::MeanSquaredError => {
                let error = prediction - target;
                0.5 * error.dot(&error)
            },
            LossFunction::MeanAbsoluteError => {
                let error = prediction - target;
                error.iter().map(|x| x.abs()).sum::<f64>()
            },
            LossFunction::BinaryCrossEntropy => {
                let mut loss = 0.0;
                for (pred, targ) in prediction.iter().zip(target.iter()) {
                    let p = pred.max(1e-15).min(1.0 - 1e-15); // Clamp to avoid log(0)
                    loss -= targ * p.ln() + (1.0 - targ) * (1.0 - p).ln();
                }
                loss
            },
            LossFunction::CategoricalCrossEntropy => {
                let mut loss = 0.0;
                for (pred, targ) in prediction.iter().zip(target.iter()) {
                    if *targ > 0.0 {
                        loss -= targ * pred.max(1e-15).ln();
                    }
                }
                loss
            },
            LossFunction::Huber { delta } => {
                let error = prediction - target;
                let mut loss = 0.0;
                for e in error.iter() {
                    if e.abs() <= *delta {
                        loss += 0.5 * e * e;
                    } else {
                        loss += delta * (e.abs() - 0.5 * delta);
                    }
                }
                loss
            }
        }
    }
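
    // Convention note: MeanSquaredError returns 0.5 * sum(e^2) for a single sample (and
    // MeanAbsoluteError the unaveraged sum of |e|); the 0.5 factor makes the matching
    // gradient in `compute_loss_gradient` exactly `prediction - target`.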

    /// Compute loss gradient for backpropagation
    pub fn compute_loss_gradient(&self, prediction: &DVector<f64>, target: &DVector<f64>) -> DVector<f64> {
        match &self.loss_function {
            LossFunction::MeanSquaredError => {
                prediction - target
            },
            LossFunction::MeanAbsoluteError => {
                let error = prediction - target;
                error.map(|x| if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 })
            },
            LossFunction::BinaryCrossEntropy => {
                let mut grad = DVector::zeros(prediction.len());
                for i in 0..prediction.len() {
                    let p = prediction[i].max(1e-15).min(1.0 - 1e-15);
                    let t = target[i];
                    grad[i] = (p - t) / (p * (1.0 - p));
                }
                grad
            },
            LossFunction::CategoricalCrossEntropy => {
                let mut grad = DVector::zeros(prediction.len());
                for i in 0..prediction.len() {
                    if target[i] > 0.0 {
                        grad[i] = -target[i] / prediction[i].max(1e-15);
                    }
                }
                grad
            },
            LossFunction::Huber { delta } => {
                let error = prediction - target;
                error.map(|e| {
                    if e.abs() <= *delta {
                        e
                    } else {
                        delta * e.signum()
                    }
                })
            }
        }
    }

    pub async fn train_step(&mut self, input: &DVector<f64>, target: &DVector<f64>, learning_rate: f64) -> f64 {
        // Forward pass - collect activations
        let mut activations = vec![input.clone()];
        let mut current = input.clone();

        for (i, (weight, bias)) in self.layers.iter().enumerate() {
            current = weight * &current + bias;
            if i < self.layers.len() - 1 {
                current = self.activation.apply(&current);
            }
            activations.push(current.clone());
        }

        let prediction = &activations[activations.len() - 1];

        // Compute loss using the configured loss function
        let loss = self.compute_loss(prediction, target);

        // Backward pass - compute loss gradient
        let mut delta = self.compute_loss_gradient(prediction, target);

        for i in (0..self.layers.len()).rev() {
            let input_activation = &activations[i];
            let output_activation = &activations[i + 1];

            // Apply activation derivative (except for the output layer)
            if i < self.layers.len() - 1 {
                let activation_grad = self.activation.apply_derivative(output_activation);
                delta = delta.component_mul(&activation_grad);
            }

            // Compute gradients
            let weight_grad = &delta * input_activation.transpose();
            let bias_grad = delta.clone();

            // Apply regularization
            let reg_weight_grad = match &self.regularization {
                Regularization::L2(lambda) => &self.layers[i].0 * *lambda,
                Regularization::L1(lambda) => self.layers[i].0.map(|w| *lambda * w.signum()),
                _ => DMatrix::zeros(self.layers[i].0.nrows(), self.layers[i].0.ncols()),
            };

            let final_weight_grad = weight_grad + reg_weight_grad;

            // Propagate the error to the previous layer before updating, so the
            // backward pass uses this layer's pre-update weights
            let next_delta = if i > 0 {
                Some(self.layers[i].0.transpose() * &delta)
            } else {
                None
            };

            // Update parameters using the optimizer
            let (mut weights, mut biases) = self.layers[i].clone();
            self.optimizer.update_parameters(
                &mut weights,
                &mut biases,
                &final_weight_grad,
                &bias_grad,
                &mut self.optimizer_state,
                i,
                learning_rate,
            );
            self.layers[i] = (weights, biases);

            if let Some(d) = next_delta {
                delta = d;
            }
        }

        // Yield occasionally for async compatibility
        if self.layers.len() > 3 {
            tokio::task::yield_now().await;
        }

        loss
    }

    /// Full async training method with early stopping and checkpoints
    pub async fn train(
        &mut self,
        train_inputs: &[DVector<f64>],
        train_targets: &[DVector<f64>],
        learning_rate: f64,
        epochs: usize,
        batch_size: Option<usize>,
        val_inputs: Option<&[DVector<f64>]>,
        val_targets: Option<&[DVector<f64>]>,
        early_stopping: Option<EarlyStopping>,
        checkpoint_config: Option<CheckpointConfig>,
    ) -> Result<(Vec<f64>, Vec<f64>), Box<dyn std::error::Error>> {
        let mut train_loss_history = Vec::new();
        let mut val_loss_history = Vec::new();
        let mut early_stop = early_stopping;
        let mut best_val_loss = f64::INFINITY;
        let batch_size = batch_size.unwrap_or(32);

        for epoch in 0..epochs {
            // Training phase
            let mut epoch_loss = 0.0;
            let mut indices: Vec<usize> = (0..train_inputs.len()).collect();
            indices.shuffle(&mut thread_rng());

            for batch in indices.chunks(batch_size) {
                for &i in batch {
                    epoch_loss += self.train_step(&train_inputs[i], &train_targets[i], learning_rate).await;
                }
                if batch_size > 10 {
                    tokio::task::yield_now().await;
                }
            }

            let train_loss = epoch_loss / train_inputs.len() as f64;
            train_loss_history.push(train_loss);

            // Validation phase
            let val_loss = if let (Some(val_inputs), Some(val_targets)) = (val_inputs, val_targets) {
                self.evaluate(val_inputs, val_targets).await
            } else {
                train_loss // Use training loss if no validation data
            };
            val_loss_history.push(val_loss);

            // Checkpoint management
            if let Some(ref config) = checkpoint_config {
                let should_save_best = config.save_best && val_loss < best_val_loss;
                let should_save_periodic = config.save_every.map_or(false, |freq| (epoch + 1) % freq == 0);

                if should_save_best {
                    best_val_loss = val_loss;
                    config.save_weights(&self.layers).await?;
                }

                if should_save_periodic {
                    let periodic_path = format!("{}_epoch_{}", config.filepath, epoch + 1);
                    let periodic_config = CheckpointConfig::new(&periodic_path);
                    periodic_config.save_weights(&self.layers).await?;
                }
            }

            // Early stopping check
            if let Some(ref mut early_stop) = early_stop {
                if early_stop.should_stop(val_loss) {
                    if early_stop.restore_best_weights {
                        if let Some(ref config) = checkpoint_config {
                            if config.save_best {
                                match config.load_weights().await {
                                    Ok(weights) => self.set_weights(weights),
                                    Err(_) => {} // Continue with current weights if loading fails
                                }
                            }
                        }
                    }
                    break;
                }
            }

            // Yield occasionally for long training
            if epoch % 10 == 0 {
                tokio::task::yield_now().await;
            }
        }

        Ok((train_loss_history, val_loss_history))
    }
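
    // Usage sketch (illustrative only; `ActivationFunction::ReLU` and `Optimizer::SGD`
    // are placeholder variant names -- substitute whatever the activation and optimizer
    // modules actually export):
    //
    //     let mut net = Hextral::new(4, &[16, 16], 1, ActivationFunction::ReLU, Optimizer::SGD);
    //     let early = Some(EarlyStopping::new(10, 1e-4, true));
    //     let ckpt = Some(CheckpointConfig::new("best_weights.bin"));
    //     let (train_hist, val_hist) = net
    //         .train(&inputs, &targets, 0.01, 100, Some(32),
    //                Some(&val_inputs), Some(&val_targets), early, ckpt)
    //         .await?;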

    pub async fn evaluate(&self, test_inputs: &[DVector<f64>], test_targets: &[DVector<f64>]) -> f64 {
        if test_inputs.len() > 10 {
            // Process predictions concurrently for large datasets
            let predictions = self.predict_batch(test_inputs).await;

            let mut total_loss = 0.0;
            for (prediction, target) in predictions.iter().zip(test_targets.iter()) {
                let loss = self.compute_loss(prediction, target);
                total_loss += loss;
            }
            total_loss / test_inputs.len() as f64
        } else {
            // Process sequentially for small datasets
            let mut total_loss = 0.0;
            for (input, target) in test_inputs.iter().zip(test_targets.iter()) {
                let prediction = self.predict(input).await;
                let loss = self.compute_loss(&prediction, target);
                total_loss += loss;
            }
            total_loss / test_inputs.len() as f64
        }
    }

    /// Get the number of parameters in the network
    pub fn parameter_count(&self) -> usize {
        self.layers.iter()
            .map(|(weight, bias)| weight.len() + bias.len())
            .sum()
    }

    /// Get network architecture info
    pub fn architecture(&self) -> Vec<usize> {
        let mut arch = vec![self.layers[0].0.ncols()]; // input size
        arch.extend(self.layers.iter().map(|(weight, _)| weight.nrows()));
        arch
    }

    /// Save network weights (simplified serialization)
    pub fn get_weights(&self) -> Vec<(DMatrix<f64>, DVector<f64>)> {
        self.layers.clone()
    }

    /// Load network weights (simplified deserialization)
    pub fn set_weights(&mut self, weights: Vec<(DMatrix<f64>, DVector<f64>)>) {
        if weights.len() == self.layers.len() {
            self.layers = weights;
        }
    }
}