use nalgebra::{DVector, DMatrix};
use rand::{Rng, thread_rng};
use rand::seq::SliceRandom;
use futures::future::join_all;
use serde::{Serialize, Deserialize};
use std::path::Path;
use tokio::fs;

pub mod activation;
pub mod optimizer;

#[cfg(feature = "datasets")]
pub mod dataset;

pub use activation::ActivationFunction;
pub use optimizer::{Optimizer, OptimizerState};

#[cfg(feature = "datasets")]
pub use dataset::{Dataset, DatasetLoader, DatasetError, PreprocessingConfig, FillStrategy};

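/// Early-stopping state: training is stopped once the monitored loss fails to
/// improve by at least `min_delta` for `patience` consecutive epochs.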
#[derive(Debug, Clone)]
pub struct EarlyStopping {
    pub patience: usize,
    pub min_delta: f64,
    pub best_loss: f64,
    pub counter: usize,
    pub restore_best_weights: bool,
}

impl EarlyStopping {
    pub fn new(patience: usize, min_delta: f64, restore_best_weights: bool) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            counter: 0,
            restore_best_weights,
        }
    }

    pub fn should_stop(&mut self, current_loss: f64) -> bool {
        if current_loss < self.best_loss - self.min_delta {
            self.best_loss = current_loss;
            self.counter = 0;
            false
        } else {
            self.counter += 1;
            self.counter >= self.patience
        }
    }

    pub fn reset(&mut self) {
        self.best_loss = f64::INFINITY;
        self.counter = 0;
    }
}

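/// Checkpointing configuration: optionally saves the best-scoring weights and/or
/// periodic snapshots to `filepath` using bincode serialization.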
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointConfig {
    pub save_best: bool,
    pub save_every: Option<usize>,
    pub filepath: String,
    pub monitor_loss: bool,
}

impl CheckpointConfig {
    pub fn new<P: AsRef<Path>>(filepath: P) -> Self {
        Self {
            save_best: true,
            save_every: None,
            filepath: filepath.as_ref().to_string_lossy().to_string(),
            monitor_loss: true,
        }
    }

    pub fn save_every(mut self, epochs: usize) -> Self {
        self.save_every = Some(epochs);
        self
    }

    pub async fn save_weights(&self, weights: &[(DMatrix<f64>, DVector<f64>)]) -> Result<(), Box<dyn std::error::Error>> {
        let data = bincode::serialize(weights)?;
        fs::write(&self.filepath, data).await?;
        Ok(())
    }

    pub async fn load_weights(&self) -> Result<Vec<(DMatrix<f64>, DVector<f64>)>, Box<dyn std::error::Error>> {
        let data = fs::read(&self.filepath).await?;
        let weights = bincode::deserialize(&data)?;
        Ok(weights)
    }
}

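/// Regularization options: `L1`/`L2` carry the penalty coefficient and
/// `Dropout` carries the drop probability.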
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Regularization {
    L2(f64),
    L1(f64),
    Dropout(f64),
    None,
}

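/// Supported loss functions; the same variant drives both `compute_loss` and
/// `compute_loss_gradient` so the loss value and its derivative stay consistent.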
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossFunction {
    MeanSquaredError,
    MeanAbsoluteError,
    BinaryCrossEntropy,
    CategoricalCrossEntropy,
    Huber { delta: f64 },
}

impl Default for LossFunction {
    fn default() -> Self {
        LossFunction::MeanSquaredError
    }
}

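/// Per-layer batch normalization state: learnable scale (`gamma`) and shift
/// (`beta`) plus running statistics used at inference time.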
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchNormLayer {
    gamma: DVector<f64>,
    beta: DVector<f64>,
    running_mean: DVector<f64>,
    running_var: DVector<f64>,
    momentum: f64,
    epsilon: f64,
    training: bool,
}

impl BatchNormLayer {
    pub fn new(size: usize) -> Self {
        Self {
            gamma: DVector::from_element(size, 1.0),
            beta: DVector::zeros(size),
            running_mean: DVector::zeros(size),
            running_var: DVector::from_element(size, 1.0),
            momentum: 0.1,
            epsilon: 1e-5,
            training: true,
        }
    }

    pub fn forward(&mut self, x: &DVector<f64>) -> (DVector<f64>, Option<(DVector<f64>, DVector<f64>, DVector<f64>)>) {
        if self.training {
            let mean = x.mean();
            let var = x.iter().map(|xi| (xi - mean).powi(2)).sum::<f64>() / x.len() as f64;
            let std_dev = (var + self.epsilon).sqrt();

            let normalized = x.map(|xi| (xi - mean) / std_dev);

            let output = normalized.component_mul(&self.gamma) + &self.beta;

            self.running_mean = &self.running_mean * (1.0 - self.momentum) + &DVector::from_element(x.len(), mean * self.momentum);
            self.running_var = &self.running_var * (1.0 - self.momentum) + &DVector::from_element(x.len(), var * self.momentum);

            let cache = Some((normalized, DVector::from_element(x.len(), mean), DVector::from_element(x.len(), std_dev)));
            (output, cache)
        } else {
            // Inference: normalize each element with its own running mean and variance.
            let normalized = DVector::from_fn(x.len(), |i, _| {
                (x[i] - self.running_mean[i]) / (self.running_var[i] + self.epsilon).sqrt()
            });
            let output = normalized.component_mul(&self.gamma) + &self.beta;
            (output, None)
        }
    }

    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }
}

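/// A fully connected feed-forward network with configurable activation,
/// optimizer, loss, regularization, and optional batch normalization.
///
/// # Examples
///
/// ```ignore
/// // Minimal sketch; `ActivationFunction::ReLU` and `Optimizer::SGD` are
/// // assumed variants of the companion modules.
/// let mut net = Hextral::new(4, &[16, 16], 1, ActivationFunction::ReLU, Optimizer::SGD);
/// net.set_regularization(Regularization::L2(1e-4));
/// let y = net.predict(&DVector::from_element(4, 0.5)).await; // inside an async context
/// ```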
#[derive(Clone, Serialize, Deserialize)]
pub struct Hextral {
    layers: Vec<(DMatrix<f64>, DVector<f64>)>,
    activation: ActivationFunction,
    optimizer: Optimizer,
    optimizer_state: OptimizerState,
    regularization: Regularization,
    loss_function: LossFunction,
    batch_norm_layers: Vec<Option<BatchNormLayer>>,
    use_batch_norm: bool,
}

impl Hextral {
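    /// Builds the layer stack with Xavier/Glorot-uniform initial weights,
    /// drawing each entry from `U(-b, b)` with `b = sqrt(6 / (fan_in + fan_out))`,
    /// and zero biases.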
    pub fn new(
        input_size: usize,
        hidden_sizes: &[usize],
        output_size: usize,
        activation: ActivationFunction,
        optimizer: Optimizer,
    ) -> Self {
        let mut layers = Vec::with_capacity(hidden_sizes.len() + 1);
        let mut rng = thread_rng();

        let mut prev_size = input_size;

        for &size in hidden_sizes {
            let bound = (6.0 / (size + prev_size) as f64).sqrt();
            let weight = DMatrix::from_fn(size, prev_size, |_, _| {
                rng.gen_range(-bound..bound)
            });
            let bias = DVector::zeros(size);
            layers.push((weight, bias));
            prev_size = size;
        }

        let bound = (6.0 / (output_size + prev_size) as f64).sqrt();
        let weight = DMatrix::from_fn(output_size, prev_size, |_, _| {
            rng.gen_range(-bound..bound)
        });
        let bias = DVector::zeros(output_size);
        layers.push((weight, bias));

        let layer_shapes: Vec<(usize, usize)> = layers.iter()
            .map(|(w, _)| (w.nrows(), w.ncols()))
            .collect();

        Hextral {
            layers,
            activation,
            optimizer_state: OptimizerState::new(&layer_shapes),
            optimizer,
            regularization: Regularization::None,
            loss_function: LossFunction::default(),
            batch_norm_layers: Vec::new(),
            use_batch_norm: false,
        }
    }

    pub fn set_regularization(&mut self, reg: Regularization) {
        self.regularization = reg;
    }

    pub fn set_loss_function(&mut self, loss: LossFunction) {
        self.loss_function = loss;
    }

    pub fn enable_batch_norm(&mut self) {
        if !self.use_batch_norm {
            self.use_batch_norm = true;
            self.batch_norm_layers.clear();

            for i in 0..self.layers.len() - 1 {
                let layer_size = self.layers[i].0.nrows();
                self.batch_norm_layers.push(Some(BatchNormLayer::new(layer_size)));
            }
            self.batch_norm_layers.push(None);
        }
    }

    pub fn disable_batch_norm(&mut self) {
        self.use_batch_norm = false;
        self.batch_norm_layers.clear();
    }

    pub fn set_training_mode(&mut self, training: bool) {
        for bn_layer in &mut self.batch_norm_layers {
            if let Some(bn) = bn_layer {
                bn.set_training(training);
            }
        }
    }

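    /// Runs a forward pass through every layer, applying the activation to all
    /// but the output layer. For networks with more than five layers the pass
    /// yields once to the Tokio scheduler at the midpoint so long computations
    /// stay cooperative.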
    pub async fn forward(&self, input: &DVector<f64>) -> DVector<f64> {
        let mut output = input.clone();

        if self.layers.len() > 5 {
            let mid = self.layers.len() / 2;

            for (i, (weight, bias)) in self.layers.iter().enumerate() {
                output = weight * &output + bias;
                if i < self.layers.len() - 1 {
                    output = self.activation.apply(&output);
                }
                if i == mid {
                    tokio::task::yield_now().await;
                }
            }
        } else {
            for (i, (weight, bias)) in self.layers.iter().enumerate() {
                output = weight * &output + bias;
                if i < self.layers.len() - 1 {
                    output = self.activation.apply(&output);
                }
            }
        }

        output
    }

    pub async fn predict(&self, input: &DVector<f64>) -> DVector<f64> {
        self.forward(input).await
    }

    pub async fn predict_batch(&self, inputs: &[DVector<f64>]) -> Vec<DVector<f64>> {
        if inputs.len() > 10 {
            let futures: Vec<_> = inputs.iter()
                .map(|input| self.predict(input))
                .collect();
            join_all(futures).await
        } else {
            let mut results = Vec::new();
            for input in inputs {
                results.push(self.predict(input).await);
            }
            results
        }
    }

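    /// Computes the per-sample loss for the configured `LossFunction`. MSE and
    /// MAE are returned as sums over the output vector (MSE scaled by 0.5), so
    /// callers that want a mean should divide by the output dimension.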
    pub fn compute_loss(&self, prediction: &DVector<f64>, target: &DVector<f64>) -> f64 {
        match &self.loss_function {
            LossFunction::MeanSquaredError => {
                let error = prediction - target;
                0.5 * error.dot(&error)
            },
            LossFunction::MeanAbsoluteError => {
                let error = prediction - target;
                error.iter().map(|x| x.abs()).sum::<f64>()
            },
            LossFunction::BinaryCrossEntropy => {
                let mut loss = 0.0;
                for (pred, targ) in prediction.iter().zip(target.iter()) {
                    let p = pred.max(1e-15).min(1.0 - 1e-15);
                    loss -= targ * p.ln() + (1.0 - targ) * (1.0 - p).ln();
                }
                loss
            },
            LossFunction::CategoricalCrossEntropy => {
                let mut loss = 0.0;
                for (pred, targ) in prediction.iter().zip(target.iter()) {
                    if *targ > 0.0 {
                        loss -= targ * pred.max(1e-15).ln();
                    }
                }
                loss
            },
            LossFunction::Huber { delta } => {
                let error = prediction - target;
                let mut loss = 0.0;
                for e in error.iter() {
                    if e.abs() <= *delta {
                        loss += 0.5 * e * e;
                    } else {
                        loss += delta * (e.abs() - 0.5 * delta);
                    }
                }
                loss
            }
        }
    }

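    /// Returns the gradient of the configured loss with respect to the
    /// prediction vector, matching the corresponding branch of `compute_loss`.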
    pub fn compute_loss_gradient(&self, prediction: &DVector<f64>, target: &DVector<f64>) -> DVector<f64> {
        match &self.loss_function {
            LossFunction::MeanSquaredError => {
                prediction - target
            },
            LossFunction::MeanAbsoluteError => {
                let error = prediction - target;
                error.map(|x| if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 })
            },
            LossFunction::BinaryCrossEntropy => {
                let mut grad = DVector::zeros(prediction.len());
                for i in 0..prediction.len() {
                    let p = prediction[i].max(1e-15).min(1.0 - 1e-15);
                    let t = target[i];
                    grad[i] = (p - t) / (p * (1.0 - p));
                }
                grad
            },
            LossFunction::CategoricalCrossEntropy => {
                let mut grad = DVector::zeros(prediction.len());
                for i in 0..prediction.len() {
                    if target[i] > 0.0 {
                        grad[i] = -target[i] / prediction[i].max(1e-15);
                    }
                }
                grad
            },
            LossFunction::Huber { delta } => {
                let error = prediction - target;
                error.map(|e| {
                    if e.abs() <= *delta {
                        e
                    } else {
                        delta * e.signum()
                    }
                })
            }
        }
    }

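    /// Performs one stochastic-gradient step on a single example: forward pass,
    /// backpropagation of the loss gradient, optional L1/L2 weight-decay term,
    /// then a parameter update via the configured optimizer. Returns the loss
    /// computed before the update.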
    pub async fn train_step(&mut self, input: &DVector<f64>, target: &DVector<f64>, learning_rate: f64) -> f64 {
        let mut activations = vec![input.clone()];
        let mut current = input.clone();

        for (i, (weight, bias)) in self.layers.iter().enumerate() {
            current = weight * &current + bias;
            if i < self.layers.len() - 1 {
                current = self.activation.apply(&current);
            }
            activations.push(current.clone());
        }

        let prediction = &activations[activations.len() - 1];

        let loss = self.compute_loss(prediction, target);

        let mut delta = self.compute_loss_gradient(prediction, target);

        for i in (0..self.layers.len()).rev() {
            let input_activation = &activations[i];
            let output_activation = &activations[i + 1];

            if i < self.layers.len() - 1 {
                let activation_grad = self.activation.apply_derivative(output_activation);
                delta = delta.component_mul(&activation_grad);
            }

            let weight_grad = &delta * input_activation.transpose();
            let bias_grad = delta.clone();

            let reg_weight_grad = match &self.regularization {
                Regularization::L2(lambda) => &self.layers[i].0 * *lambda,
                Regularization::L1(lambda) => self.layers[i].0.map(|w| *lambda * w.signum()),
                _ => DMatrix::zeros(self.layers[i].0.nrows(), self.layers[i].0.ncols()),
            };

            let final_weight_grad = weight_grad + reg_weight_grad;

            // Propagate the error through the pre-update weights before the
            // optimizer modifies them.
            let propagated_delta = if i > 0 {
                Some(self.layers[i].0.transpose() * &delta)
            } else {
                None
            };

            let (mut weights, mut biases) = self.layers[i].clone();
            self.optimizer.update_parameters(
                &mut weights,
                &mut biases,
                &final_weight_grad,
                &bias_grad,
                &mut self.optimizer_state,
                i,
                learning_rate,
            );
            self.layers[i] = (weights, biases);

            if let Some(propagated) = propagated_delta {
                delta = propagated;
            }
        }

        if self.layers.len() > 3 {
            tokio::task::yield_now().await;
        }

        loss
    }

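    /// Trains for `epochs` epochs, visiting the training set in shuffled
    /// batches of `batch_size` (default 32) and updating after every sample.
    /// Optionally evaluates on a validation set, applies early stopping, and
    /// writes checkpoints. Returns the per-epoch training and validation loss
    /// histories.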
    pub async fn train(
        &mut self,
        train_inputs: &[DVector<f64>],
        train_targets: &[DVector<f64>],
        learning_rate: f64,
        epochs: usize,
        batch_size: Option<usize>,
        val_inputs: Option<&[DVector<f64>]>,
        val_targets: Option<&[DVector<f64>]>,
        early_stopping: Option<EarlyStopping>,
        checkpoint_config: Option<CheckpointConfig>,
    ) -> Result<(Vec<f64>, Vec<f64>), Box<dyn std::error::Error>> {
        let mut train_loss_history = Vec::new();
        let mut val_loss_history = Vec::new();
        let mut early_stop = early_stopping;
        let mut best_val_loss = f64::INFINITY;
        let batch_size = batch_size.unwrap_or(32);

        for epoch in 0..epochs {
            let mut epoch_loss = 0.0;
            let mut indices: Vec<usize> = (0..train_inputs.len()).collect();
            indices.shuffle(&mut thread_rng());

            for batch in indices.chunks(batch_size) {
                for &i in batch {
                    epoch_loss += self.train_step(&train_inputs[i], &train_targets[i], learning_rate).await;
                }
                if batch_size > 10 {
                    tokio::task::yield_now().await;
                }
            }

            let train_loss = epoch_loss / train_inputs.len() as f64;
            train_loss_history.push(train_loss);

            let val_loss = if let (Some(val_inputs), Some(val_targets)) = (val_inputs, val_targets) {
                self.evaluate(val_inputs, val_targets).await
            } else {
                train_loss
            };
            val_loss_history.push(val_loss);

            if let Some(ref config) = checkpoint_config {
                let should_save_best = config.save_best && val_loss < best_val_loss;
                let should_save_periodic = config.save_every.map_or(false, |freq| (epoch + 1) % freq == 0);

                if should_save_best {
                    best_val_loss = val_loss;
                    config.save_weights(&self.layers).await?;
                }

                if should_save_periodic {
                    let periodic_path = format!("{}_epoch_{}", config.filepath, epoch + 1);
                    let periodic_config = CheckpointConfig::new(&periodic_path);
                    periodic_config.save_weights(&self.layers).await?;
                }
            }

            if let Some(ref mut early_stop) = early_stop {
                if early_stop.should_stop(val_loss) {
                    if early_stop.restore_best_weights {
                        if let Some(ref config) = checkpoint_config {
                            if config.save_best {
                                match config.load_weights().await {
                                    Ok(weights) => self.set_weights(weights),
                                    Err(_) => {}
                                }
                            }
                        }
                    }
                    break;
                }
            }

            if epoch % 10 == 0 {
                tokio::task::yield_now().await;
            }
        }

        Ok((train_loss_history, val_loss_history))
    }

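    /// Computes the mean loss over a test set, using `predict_batch` to run
    /// predictions concurrently when there are more than ten samples.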
    pub async fn evaluate(&self, test_inputs: &[DVector<f64>], test_targets: &[DVector<f64>]) -> f64 {
        if test_inputs.len() > 10 {
            let predictions = self.predict_batch(test_inputs).await;

            let mut total_loss = 0.0;
            for (prediction, target) in predictions.iter().zip(test_targets.iter()) {
                let loss = self.compute_loss(prediction, target);
                total_loss += loss;
            }
            total_loss / test_inputs.len() as f64
        } else {
            let mut total_loss = 0.0;
            for (input, target) in test_inputs.iter().zip(test_targets.iter()) {
                let prediction = self.predict(input).await;
                let loss = self.compute_loss(&prediction, target);
                total_loss += loss;
            }
            total_loss / test_inputs.len() as f64
        }
    }

    pub fn parameter_count(&self) -> usize {
        self.layers.iter()
            .map(|(weight, bias)| weight.len() + bias.len())
            .sum()
    }

    pub fn architecture(&self) -> Vec<usize> {
        let mut arch = vec![self.layers[0].0.ncols()];
        arch.extend(self.layers.iter().map(|(weight, _)| weight.nrows()));
        arch
    }

    pub fn get_weights(&self) -> Vec<(DMatrix<f64>, DVector<f64>)> {
        self.layers.clone()
    }

    pub fn set_weights(&mut self, weights: Vec<(DMatrix<f64>, DVector<f64>)>) {
        if weights.len() == self.layers.len() {
            self.layers = weights;
        }
    }
}