1use crate::nab_array::NDArray;
2use crate::nab_layers::NabLayer;
3use crate::nab_optimizers::NablaOptimizer;
4use crate::nab_loss::NabLoss;
5use std::collections::HashMap;
6use serde::{Serialize, Deserialize};
7use std::path::Path;
8use serde_json;
9use flate2::write::GzEncoder;
10use flate2::read::GzDecoder;
11use flate2::Compression;
12use std::io::{Write, Read};
13
/// Process-wide counter that hands out graph node ids.
///
/// Every allocation site increments before reading, so the first id issued after a
/// reset is `1`.
///
/// NOTE(review): this is an unsynchronized `static mut`; callers on different threads
/// (e.g. tests run in parallel) race on it. An `AtomicUsize` would be the sound
/// replacement, but every access site in this file would need to switch at once.
static mut NEXT_NODE_ID: usize = 0_usize;
15
16pub struct Node {
18 pub layer: NabLayer,
19 pub inputs: Vec<usize>, pub output_shape: Vec<usize>,
21}
22
23#[allow(dead_code)]
48#[derive(Clone)]
49pub struct NabModel {
50 layers: Vec<NabLayer>,
51 optimizer_type: String,
52 learning_rate: f64,
53 loss_type: String, metrics: Vec<String>,
55}
56
/// A symbolic input placeholder used to start a functional-API graph.
#[derive(Clone)]
pub struct Input {
    /// Shape of the data fed into this input.
    shape: Vec<usize>,
    /// Graph id assigned when the input was created.
    node_index: Option<usize>,
}
63
64#[derive(Clone)]
66#[allow(dead_code)]
67pub struct Output {
68 layer: NabLayer,
69 inputs: Vec<usize>,
70 output_shape: Vec<usize>,
71 previous_output: Option<Box<Output>>,
72}
73
74impl Input {
75 pub fn apply<L: Into<NabLayer>>(&self, layer: L) -> Output {
77 let mut layer = layer.into();
78 let output_shape = layer.compute_output_shape(&self.shape);
79
80 let layer_id = unsafe {
82 NEXT_NODE_ID += 1;
83 NEXT_NODE_ID
84 };
85
86 layer.set_node_index(layer_id);
87
88 println!("Connecting layer {} (id: {}) to input (id: {})",
89 layer.get_name(),
90 layer_id,
91 self.node_index.unwrap()
92 );
93
94 Output {
95 layer,
96 inputs: vec![self.node_index.unwrap()],
97 output_shape,
98 previous_output: None,
99 }
100 }
101
102 pub fn get_input_shape(&self) -> &Vec<usize> {
108 &self.shape
109 }
110}
111
112impl Output {
113 pub fn apply<L: Into<NabLayer>>(&self, layer: L) -> Output {
115 let mut layer = layer.into();
116 let output_shape = layer.compute_output_shape(&self.output_shape);
117
118 let layer_id = unsafe {
120 NEXT_NODE_ID += 1;
121 NEXT_NODE_ID
122 };
123
124 layer.set_node_index(layer_id);
125
126 println!("Connecting layer {} (id: {}) to {} (id: {})",
127 layer.get_name(),
128 layer_id,
129 self.layer.get_name(),
130 self.layer.node_index.unwrap()
131 );
132
133 Output {
134 layer,
135 inputs: vec![self.layer.node_index.unwrap()],
136 output_shape,
137 previous_output: Some(Box::new(self.clone())),
138 }
139 }
140
141 pub fn get_previous_layer(&self) -> Option<&NabLayer> {
143 None }
146}
147
148#[allow(dead_code)]
149impl NabModel {
150 pub fn input(shape: Vec<usize>) -> Input {
160 let node_index = unsafe {
161 NEXT_NODE_ID += 1;
162 NEXT_NODE_ID
163 };
164
165 Input {
166 shape,
167 node_index: Some(node_index),
168 }
169 }
170
171 pub fn new() -> Self {
173 NabModel {
174 layers: Vec::new(),
175 optimizer_type: String::new(),
176 learning_rate: 0.0,
177 loss_type: String::new(),
178 metrics: Vec::new(),
179 }
180 }
181
182 pub fn add(&mut self, layer: NabLayer) -> &mut Self {
184 self.layers.push(layer);
185 self
186 }
187
188 pub fn compile(&mut self, optimizer_type: &str, learning_rate: f64,
196 loss_type: &str, metrics: Vec<String>) {
197 self.optimizer_type = optimizer_type.to_string();
198 self.learning_rate = learning_rate;
199 self.loss_type = loss_type.to_string();
200 self.metrics = metrics;
201 }
202
203 fn train_epoch(&mut self, x: &NDArray, y: &NDArray, batch_size: usize) -> HashMap<String, f64> {
205 let mut metrics = HashMap::new();
206 let mut total_loss = 0.0;
207 let mut total_correct = 0;
208 let num_samples = x.shape()[0];
209 let num_batches = (num_samples + batch_size - 1) / batch_size;
210
211 for batch_idx in 0..num_batches {
213 let start_idx = batch_idx * batch_size;
214 let end_idx = (start_idx + batch_size).min(num_samples);
215
216 let x_batch = x.slice(start_idx, end_idx);
218 let y_batch = y.slice(start_idx, end_idx);
219
220 let (predictions, loss) = self.forward_backward(&x_batch, &y_batch);
222
223 total_loss += loss * (end_idx - start_idx) as f64;
225 total_correct += self.count_correct(&predictions, &y_batch);
226 }
227
228 metrics.insert("loss".to_string(), total_loss / num_samples as f64);
230 metrics.insert("accuracy".to_string(), total_correct as f64 / num_samples as f64);
231
232 metrics
233 }
234
235 fn forward_backward(&mut self, x_batch: &NDArray, y_batch: &NDArray) -> (NDArray, f64) {
236 let predictions = self.predict(x_batch);
238 let loss = self.calculate_loss(&predictions, y_batch);
239 let loss_grad = self.calculate_loss_gradient(&predictions, y_batch);
240
241 let mut gradient = loss_grad;
243 let learning_rate = self.learning_rate; for layer in self.layers.iter_mut().rev() {
246 if layer.is_trainable() {
247 gradient = layer.backward(&gradient);
248
249 if let Some(weights) = layer.weights.as_mut() {
251 let weight_grads = layer.weight_gradients.as_ref().unwrap();
252 NablaOptimizer::sgd_update(weights, weight_grads, learning_rate);
253 }
254 if let Some(biases) = layer.biases.as_mut() {
255 let bias_grads = layer.bias_gradients.as_ref().unwrap();
256 NablaOptimizer::sgd_update(biases, bias_grads, learning_rate);
257 }
258 }
259 }
260
261 (predictions, loss)
262 }
263
264 fn count_correct(&self, predictions: &NDArray, targets: &NDArray) -> usize {
265 let pred_classes = predictions.argmax(Some(1));
266 let true_classes = targets.argmax(Some(1));
267
268 pred_classes.iter()
269 .zip(true_classes.iter())
270 .filter(|(&p, &t)| p == t)
271 .count()
272 }
273
274 pub fn new_functional(inputs: Vec<Input>, outputs: Vec<Output>) -> Self {
276 let mut layers = Vec::new();
277 let mut visited = std::collections::HashSet::new();
278
279 for input in inputs {
281 let mut layer = NabLayer::input(input.shape.clone(), None);
282 layer.set_node_index(input.node_index.unwrap());
283 visited.insert(input.node_index.unwrap());
284 layers.push(layer);
285 }
286
287 for output in outputs {
289 let mut current = Some(output);
290 let mut layer_stack = Vec::new();
291
292 while let Some(curr) = current {
294 if !visited.contains(&curr.layer.node_index.unwrap()) {
295 visited.insert(curr.layer.node_index.unwrap());
296 layer_stack.push(curr.layer);
297 }
298 current = curr.previous_output.map(|prev| *prev);
299 }
300
301 layers.extend(layer_stack.into_iter().rev());
303 }
304
305 NabModel {
306 layers,
307 optimizer_type: String::new(),
308 learning_rate: 0.0,
309 loss_type: String::new(),
310 metrics: Vec::new(),
311 }
312 }
313
314 pub fn fit(&mut self, x_train: &NDArray, y_train: &NDArray,
337 batch_size: usize, epochs: usize,
338 validation_data: Option<(&NDArray, &NDArray)>)
339 -> HashMap<String, Vec<f64>> {
340 let mut history = HashMap::new();
341 let mut train_metrics = Vec::new();
342 let mut val_metrics = Vec::new();
343
344 for epoch in 0..epochs {
345 let metrics = self.train_epoch(x_train, y_train, batch_size);
347 train_metrics.push(metrics);
348
349 if let Some((x_val, y_val)) = validation_data {
351 let val_metric = self.evaluate(x_val, y_val, batch_size);
352 val_metrics.push(val_metric);
353 }
354
355 self.print_progress(epoch + 1, epochs, &train_metrics[epoch],
357 val_metrics.last());
358 }
359
360 history.insert("loss".to_string(),
362 train_metrics.iter().map(|m| m["loss"]).collect());
363 history.insert("accuracy".to_string(),
364 train_metrics.iter().map(|m| m["accuracy"]).collect());
365
366 if !val_metrics.is_empty() {
367 history.insert("val_loss".to_string(),
368 val_metrics.iter().map(|m| m["loss"]).collect());
369 history.insert("val_accuracy".to_string(),
370 val_metrics.iter().map(|m| m["accuracy"]).collect());
371 }
372
373 history
374 }
375
376 fn print_progress(
378 &self,
379 epoch: usize,
380 total_epochs: usize,
381 train_metrics: &HashMap<String, f64>,
382 val_metrics: Option<&HashMap<String, f64>>,
383 ) {
384 print!("Epoch {}/{} - ", epoch, total_epochs);
385 for (name, value) in train_metrics {
386 print!("{}: {:.4} ", name, value);
387 }
388 if let Some(val_metrics) = val_metrics {
389 for (name, value) in val_metrics {
390 print!("val_{}: {:.4} ", name, value);
391 }
392 }
393 println!();
394 }
395
396 #[allow(unused_variables)]
406 pub fn evaluate(&mut self, x_test: &NDArray, y_test: &NDArray,
407 batch_size: usize) -> HashMap<String, f64> {
408 let mut metrics = HashMap::new();
409 let predictions = self.predict(x_test);
410
411 let loss = self.calculate_loss(&predictions, y_test);
413 metrics.insert("loss".to_string(), loss);
414
415 for metric in &self.metrics {
417 match metric.as_str() {
418 "accuracy" => {
419 let acc = self.calculate_accuracy(&predictions, y_test);
420 metrics.insert("accuracy".to_string(), acc);
421 }
422 _ => {}
423 }
424 }
425
426 metrics
427 }
428
429 fn calculate_accuracy(&self, predictions: &NDArray, targets: &NDArray) -> f64 {
431 let pred_classes = predictions.argmax(Some(1));
432 let true_classes = targets.argmax(Some(1));
433
434 let correct = pred_classes.iter()
435 .zip(true_classes.iter())
436 .filter(|(&p, &t)| p == t)
437 .count();
438
439 correct as f64 / predictions.shape()[0] as f64
440 }
441
442 pub fn predict(&mut self, x: &NDArray) -> NDArray {
450 let mut current = x.clone();
451 for layer in &mut self.layers {
452 current = layer.forward(¤t, false);
453 }
454 current
455 }
456
457 fn calculate_loss(&self, predictions: &NDArray, targets: &NDArray) -> f64 {
458 match self.loss_type.as_str() {
459 "mse" => NabLoss::mean_squared_error(predictions, targets),
460 "categorical_crossentropy" => NabLoss::cross_entropy_loss(predictions, targets),
461 _ => NabLoss::mean_squared_error(predictions, targets),
462 }
463 }
464
465 fn calculate_loss_gradient(&self, predictions: &NDArray, targets: &NDArray) -> NDArray {
466 match self.loss_type.as_str() {
467 "mse" => predictions.subtract(targets).divide_scalar(predictions.shape()[0] as f64),
468 "categorical_crossentropy" => predictions.subtract(targets).divide_scalar(predictions.shape()[0] as f64),
469 _ => predictions.subtract(targets).divide_scalar(predictions.shape()[0] as f64),
470 }
471 }
472
473 pub fn print_layers(&self) {
475 println!("\nLayer stack:");
476 for (i, layer) in self.layers.iter().enumerate() {
477 println!("{}: {} -> {:?}", i, layer.get_name(), layer.get_output_shape());
478 }
479 }
480
481 pub fn summary(&self) {
513 println!("Model: \"functional\"");
514 println!("─────────────────────────────────────────────────────");
515 println!("{:<20} {:<18} {:<10}", "Layer (type)", "Output Shape", "Param #");
516 println!("=================================================");
517
518 let mut total_params = 0;
519 let mut trainable_params = 0;
520 let mut non_trainable_params = 0;
521
522 for layer in &self.layers {
524 let (params, trainable) = self.count_params(layer);
525 total_params += params;
526 if trainable {
527 trainable_params += params;
528 } else {
529 non_trainable_params += params;
530 }
531
532 let shape_str = format!("(None, {})",
533 layer.get_output_shape()
534 .iter()
535 .map(|x| x.to_string())
536 .collect::<Vec<_>>()
537 .join(", ")
538 );
539
540 let layer_type = if layer.get_name().contains("input") {
541 layer.get_name().to_string()
542 } else {
543 format!("{} ({})",
544 layer.get_name(),
545 layer.get_type()
546 )
547 };
548
549 println!("{:<20} {:<18} {:<10}",
550 layer_type,
551 shape_str,
552 self.format_number(params)
553 );
554 }
555
556 println!("=================================================");
557 println!("Total params: {}", self.format_number(total_params));
558 println!("Trainable params: {}", self.format_number(trainable_params));
559 println!("Non-trainable params: {}", self.format_number(non_trainable_params));
560 }
561
562 fn count_params(&self, layer: &NabLayer) -> (usize, bool) {
564 let mut params = 0;
565
566 if let Some(weights) = &layer.weights {
568 params += weights.data().len();
569 }
570
571 if let Some(biases) = &layer.biases {
573 params += biases.data().len();
574 }
575
576 (params, layer.is_trainable())
577 }
578
579 fn format_number(&self, n: usize) -> String {
581 n.to_string()
582 .chars()
583 .rev()
584 .collect::<Vec<_>>()
585 .chunks(3)
586 .map(|chunk| chunk.iter().collect::<String>())
587 .collect::<Vec<_>>()
588 .join(",")
589 .chars()
590 .rev()
591 .collect()
592 }
593
594 pub fn save_compressed<P: AsRef<Path>>(&self, path: P) -> std::io::Result<()> {
599 let file = std::fs::File::create(path)?;
601 let mut encoder = GzEncoder::new(file, Compression::best());
602
603 let model_data = ModelData {
605 config: ModelConfig {
606 optimizer_type: self.optimizer_type.clone(),
607 learning_rate: self.learning_rate,
608 loss_type: self.loss_type.clone(),
609 metrics: self.metrics.clone(),
610 },
611 layers: self.layers.iter().map(|layer| LayerState {
612 layer_type: layer.get_type().to_string(),
613 name: layer.get_name().to_string(),
614 input_shape: layer.input_shape.clone(),
615 output_shape: layer.output_shape.clone(),
616 weights: layer.weights.as_ref().map(|w| w.data().to_vec()),
617 biases: layer.biases.as_ref().map(|b| b.data().to_vec()),
618 activation: layer.activation.clone(),
619 }).collect(),
620 };
621
622 let serialized = serde_json::to_string(&model_data)?;
624 encoder.write_all(serialized.as_bytes())?;
625 encoder.finish()?;
626
627 Ok(())
628 }
629
630 pub fn load_compressed<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
635 let file = std::fs::File::open(path)?;
637 let mut decoder = GzDecoder::new(file);
638
639 let mut contents = String::new();
641 decoder.read_to_string(&mut contents)?;
642
643 let model_data: ModelData = serde_json::from_str(&contents)?;
645
646 let mut layers = Vec::new();
648 for state in model_data.layers {
649 let mut layer = match state.layer_type.as_str() {
650 "Input" => NabLayer::input(state.input_shape.clone(), Some(&state.name)),
651 "Dense" => NabLayer::dense(
652 state.input_shape[0],
653 state.output_shape[0],
654 state.activation.as_deref(),
655 Some(&state.name)
656 ),
657 _ => return Err(std::io::Error::new(
658 std::io::ErrorKind::InvalidData,
659 format!("Unknown layer type: {}", state.layer_type)
660 )),
661 };
662
663 if let Some(weights) = state.weights {
665 let weight_shape = match state.layer_type.as_str() {
666 "Dense" => vec![state.input_shape[0], state.output_shape[0]],
667 _ => state.input_shape.clone()
668 };
669 layer.weights = Some(NDArray::new(weights, weight_shape));
670 }
671 if let Some(biases) = state.biases {
672 layer.biases = Some(NDArray::new(biases, vec![state.output_shape[0]]));
673 }
674
675 layers.push(layer);
676 }
677
678 Ok(NabModel {
679 layers,
680 optimizer_type: model_data.config.optimizer_type,
681 learning_rate: model_data.config.learning_rate,
682 loss_type: model_data.config.loss_type,
683 metrics: model_data.config.metrics,
684 })
685 }
686}
687
688#[derive(Serialize, Deserialize)]
690struct ModelConfig {
691 optimizer_type: String,
692 learning_rate: f64,
693 loss_type: String,
694 metrics: Vec<String>,
695}
696
697#[derive(Serialize, Deserialize)]
699struct LayerState {
700 layer_type: String,
701 name: String,
702 input_shape: Vec<usize>,
703 output_shape: Vec<usize>,
704 weights: Option<Vec<f64>>,
705 biases: Option<Vec<f64>>,
706 activation: Option<String>,
707}
708
709#[derive(Serialize, Deserialize)]
710struct ModelData {
711 config: ModelConfig,
712 layers: Vec<LayerState>,
713}
714
715pub fn reset_node_id() {
718 unsafe {
719 NEXT_NODE_ID = 0;
720 }
721}
722
#[cfg(test)]
#[allow(unused_imports)]
#[allow(unused_variables)]
mod tests {
    use super::*;
    use crate::nab_activations::NablaActivation;
    use crate::nab_optimizers::NablaOptimizer;
    use crate::nab_loss::NabLoss;
    use crate::nab_mnist::NabMnist;
    use crate::nab_utils::NabUtils;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests that build graphs.
    ///
    /// `NEXT_NODE_ID` is a process-wide `static mut`, and the test harness runs tests on
    /// parallel threads. Without this lock, one test's `reset_node_id()` can rewind the
    /// counter while another test is wiring its graph. That produces duplicate node ids,
    /// which `new_functional` then de-duplicates away, silently dropping layers.
    static NODE_ID_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the node-id lock (tolerating poisoning from an earlier failed test) and
    /// resets the counter. Hold the returned guard for the whole test.
    fn lock_and_reset_node_ids() -> MutexGuard<'static, ()> {
        let guard = NODE_ID_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        reset_node_id();
        guard
    }

    #[test]
    fn test_linear_regression() {
        let _guard = lock_and_reset_node_ids();

        let x_data = NDArray::from_matrix(vec![
            vec![1.0], vec![2.0], vec![3.0], vec![4.0], vec![5.0]
        ]);
        let y_data = NDArray::from_matrix(vec![
            vec![3.1], vec![5.0], vec![6.9], vec![9.2], vec![11.0]
        ]);

        let input = NabModel::input(vec![1]);
        let output_layer = NabLayer::dense(1, 1, None, Some("linear_output"));
        let output = input.apply(output_layer);

        let mut model = NabModel::new_functional(vec![input], vec![output]);
        model.compile(
            "sgd",
            0.01,
            "mse",
            vec!["mse".to_string()]
        );

        for _ in 0..100 {
            model.train_epoch(&x_data, &y_data, x_data.shape()[0]);
        }

        let predictions = model.predict(&x_data);

        let pred_vec = predictions.data();
        for i in 1..pred_vec.len() {
            assert!(pred_vec[i] > pred_vec[i - 1],
                "Predictions should increase monotonically. Found {} <= {} at index {}",
                pred_vec[i], pred_vec[i - 1], i
            );
        }
    }

    #[test]
    fn test_mnist_full_pipeline() {
        println!("Internal test ... skipping ...");
    }

    #[test]
    fn test_model_summary() {
        let _guard = lock_and_reset_node_ids();

        let input = NabModel::input(vec![784]);
        let dense1 = NabLayer::dense(784, 32, Some("relu"), Some("dense1"));
        let x = input.apply(dense1);

        let dense2 = NabLayer::dense(32, 32, Some("relu"), Some("dense2"));
        let x = x.apply(dense2);

        let output_layer = NabLayer::dense(32, 10, Some("softmax"), Some("output"));
        let output = x.apply(output_layer);

        let model = NabModel::new_functional(vec![input], vec![output]);

        model.summary();

        let total_params: usize = model.layers.iter()
            .map(|l| model.count_params(l).0)
            .sum();

        assert_eq!(total_params, 784 * 32 + 32 + 32 * 32 + 32 + 32 * 10 + 10);
    }

    #[test]
    fn test_model_save_load() {
        let _guard = lock_and_reset_node_ids();

        let input = NabModel::input(vec![784]);
        let dense1 = NabLayer::dense(784, 32, Some("relu"), Some("dense1"));
        let x = input.apply(dense1);

        let dense2 = NabLayer::dense(32, 32, Some("relu"), Some("dense2"));
        let x = x.apply(dense2);

        let output_layer = NabLayer::dense(32, 10, Some("softmax"), Some("output"));
        let output = x.apply(output_layer);

        let mut model = NabModel::new_functional(vec![input], vec![output]);
        model.compile("sgd", 0.1, "categorical_crossentropy", vec!["accuracy".to_string()]);

        // Write to the temp dir (not the CWD) and remove the file before asserting so a
        // failed assertion does not leave it behind.
        let path = std::env::temp_dir()
            .join(format!("nab_model_test_{}.ez", std::process::id()));
        model.save_compressed(&path).expect("Failed to save model");
        let loaded = NabModel::load_compressed(&path);
        std::fs::remove_file(&path).expect("Failed to clean up test file");
        let loaded_model = loaded.expect("Failed to load model");

        assert_eq!(loaded_model.optimizer_type, model.optimizer_type);
        assert_eq!(loaded_model.learning_rate, model.learning_rate);
        assert_eq!(loaded_model.loss_type, model.loss_type);
        assert_eq!(loaded_model.metrics, model.metrics);

        assert_eq!(loaded_model.layers.len(), model.layers.len());
        for (loaded, original) in loaded_model.layers.iter().zip(model.layers.iter()) {
            assert_eq!(loaded.get_type(), original.get_type());
            assert_eq!(loaded.get_output_shape(), original.get_output_shape());

            if let (Some(w1), Some(w2)) = (&loaded.weights, &original.weights) {
                assert_eq!(w1.shape(), w2.shape(), "Weight shapes don't match");
                assert!(w1.data().iter().zip(w2.data().iter())
                        .all(|(a, b)| (a - b).abs() < 1e-6),
                        "Weight values don't match");
            }

            if let (Some(b1), Some(b2)) = (&loaded.biases, &original.biases) {
                assert_eq!(b1.shape(), b2.shape(), "Bias shapes don't match");
                assert!(b1.data().iter().zip(b2.data().iter())
                        .all(|(a, b)| (a - b).abs() < 1e-6),
                        "Bias values don't match");
            }
        }
    }

    #[test]
    fn test_input_shape() {
        let _guard = lock_and_reset_node_ids();

        let shape = vec![784, 32];
        let input = NabModel::input(shape.clone());

        assert_eq!(input.get_input_shape(), &shape);

        let dense = NabLayer::dense(784, 128, Some("relu"), Some("dense1"));
        let output = input.apply(dense);
        assert_eq!(input.get_input_shape(), &shape, "Input shape should remain unchanged after applying layer");
    }
}
985