nabla_ml/nab_model.rs

use crate::nab_array::NDArray;
use crate::nab_layers::NabLayer;
use crate::nab_optimizers::NablaOptimizer;
use crate::nab_loss::NabLoss;
use std::collections::HashMap;
use serde::{Serialize, Deserialize};
use std::path::Path;
use serde_json;
use flate2::write::GzEncoder;
use flate2::read::GzDecoder;
use flate2::Compression;
use std::io::{Write, Read};

/// Global counter for assigning unique node IDs in the computation graph.
/// NOTE: `static mut` is not thread-safe; models must be built on a single
/// thread, and tests call `reset_node_id` to keep IDs deterministic.
static mut NEXT_NODE_ID: usize = 0;

/// Represents a node in the computation graph
pub struct Node {
    pub layer: NabLayer,
    pub inputs: Vec<usize>,  // Indices of input nodes
    pub output_shape: Vec<usize>,
}

/// Represents a model using the Functional API
///
/// # Examples
///
/// ```rust
/// use nabla_ml::nab_model::NabModel;
/// use nabla_ml::nab_layers::NabLayer;
///
/// // Create model architecture
/// let input = NabModel::input(vec![784]);
/// let dense1 = NabLayer::dense(784, 512, Some("relu"), Some("dense1"));
/// let x = input.apply(dense1);
/// let output_layer = NabLayer::dense(512, 10, Some("softmax"), Some("output"));
/// let output = x.apply(output_layer);
///
/// // Create and compile model
/// let mut model = NabModel::new_functional(vec![input], vec![output]);
/// model.compile(
///     "sgd",
///     0.1,
///     "categorical_crossentropy",
///     vec!["accuracy".to_string()]
/// );
/// ```
#[allow(dead_code)]
#[derive(Clone)]
pub struct NabModel {
    layers: Vec<NabLayer>,
    optimizer_type: String,
    learning_rate: f64,
    loss_type: String,  // e.g. "mse", "categorical_crossentropy"
    metrics: Vec<String>,
}

/// Represents an input node in the computation graph
#[derive(Clone)]
pub struct Input {
    shape: Vec<usize>,
    node_index: Option<usize>,
}

/// Represents an output node in the computation graph
#[derive(Clone)]
#[allow(dead_code)]
pub struct Output {
    layer: NabLayer,
    inputs: Vec<usize>,
    output_shape: Vec<usize>,
    previous_output: Option<Box<Output>>,
}

impl Input {
    /// Applies a layer to this input, preserving node connectivity
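    ///
    /// # Examples
    /// A minimal sketch of wiring a layer to a fresh input node:
    /// ```ignore
    /// let input = NabModel::input(vec![784]);
    /// let dense = NabLayer::dense(784, 128, Some("relu"), Some("dense1"));
    /// let out = input.apply(dense);
    /// ```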
    pub fn apply<L: Into<NabLayer>>(&self, layer: L) -> Output {
        let mut layer = layer.into();
        let output_shape = layer.compute_output_shape(&self.shape);
        
        // Get next ID safely
        let layer_id = unsafe {
            NEXT_NODE_ID += 1;
            NEXT_NODE_ID
        };
        
        layer.set_node_index(layer_id);
        
        println!("Connecting layer {} (id: {}) to input (id: {})", 
            layer.get_name(), 
            layer_id,
            self.node_index.unwrap()
        );
        
        Output {
            layer,
            inputs: vec![self.node_index.unwrap()],
            output_shape,
            previous_output: None,
        }
    }

    /// Returns the input shape of this Input node
    /// 
    /// # Returns
    /// 
    /// A reference to the shape vector
    pub fn get_input_shape(&self) -> &Vec<usize> {
        &self.shape
    }
}

impl Output {
    /// Applies a layer to this output, maintaining the graph structure
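    ///
    /// # Examples
    /// A minimal sketch of chaining layers through intermediate outputs:
    /// ```ignore
    /// let x = input.apply(dense1);   // Input -> Output
    /// let y = x.apply(dense2);       // Output -> Output
    /// ```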
    pub fn apply<L: Into<NabLayer>>(&self, layer: L) -> Output {
        let mut layer = layer.into();
        let output_shape = layer.compute_output_shape(&self.output_shape);
        
        // Get next ID safely
        let layer_id = unsafe {
            NEXT_NODE_ID += 1;
            NEXT_NODE_ID
        };
        
        layer.set_node_index(layer_id);
        
        println!("Connecting layer {} (id: {}) to {} (id: {})", 
            layer.get_name(), 
            layer_id,
            self.layer.get_name(),
            self.layer.node_index.unwrap()
        );
        
        Output {
            layer,
            inputs: vec![self.layer.node_index.unwrap()],
            output_shape,
            previous_output: Some(Box::new(self.clone())),
        }
    }

    /// Returns the layer that produced this output
    /// (currently unimplemented: always returns `None`)
    pub fn get_previous_layer(&self) -> Option<&NabLayer> {
        None // TODO: Implement layer tracking
    }
}

#[allow(dead_code)]
impl NabModel {
    /// Creates a new input layer with specified shape
    /// 
    /// # Arguments
    /// * `shape` - Shape of input excluding batch dimension
    /// 
    /// # Examples
    /// ```ignore
    /// let input = NabModel::input(vec![784]); // For MNIST images
    /// ```
    pub fn input(shape: Vec<usize>) -> Input {
        let node_index = unsafe {
            NEXT_NODE_ID += 1;
            NEXT_NODE_ID
        };
        
        Input {
            shape,
            node_index: Some(node_index),
        }
    }

    /// Creates a new model
    pub fn new() -> Self {
        NabModel {
            layers: Vec::new(),
            optimizer_type: String::new(),
            learning_rate: 0.0,
            loss_type: String::new(),
            metrics: Vec::new(),
        }
    }

    /// Adds a layer to the model
    pub fn add(&mut self, layer: NabLayer) -> &mut Self {
        self.layers.push(layer);
        self
    }

    /// Compiles the model with training configuration
    /// 
    /// # Arguments
    /// * `optimizer_type` - Optimization algorithm ("sgd", "adam", etc.)
    /// * `learning_rate` - Learning rate for optimization
    /// * `loss_type` - Loss function ("mse", "categorical_crossentropy")
    /// * `metrics` - Metrics to track during training
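    ///
    /// # Examples
    /// A minimal sketch, mirroring the configuration used in the tests:
    /// ```ignore
    /// model.compile("sgd", 0.1, "categorical_crossentropy", vec!["accuracy".to_string()]);
    /// ```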
    pub fn compile(&mut self, optimizer_type: &str, learning_rate: f64, 
                  loss_type: &str, metrics: Vec<String>) {
        self.optimizer_type = optimizer_type.to_string();
        self.learning_rate = learning_rate;
        self.loss_type = loss_type.to_string();
        self.metrics = metrics;
    }

    /// Trains for one epoch
    fn train_epoch(&mut self, x: &NDArray, y: &NDArray, batch_size: usize) -> HashMap<String, f64> {
        let mut metrics = HashMap::new();
        let mut total_loss = 0.0;
        let mut total_correct = 0;
        let num_samples = x.shape()[0];
        // Ceiling division: the final batch may be smaller than batch_size
        let num_batches = (num_samples + batch_size - 1) / batch_size;

        // Process mini-batches
        for batch_idx in 0..num_batches {
            let start_idx = batch_idx * batch_size;
            let end_idx = (start_idx + batch_size).min(num_samples);
            
            // Get batch data
            let x_batch = x.slice(start_idx, end_idx);
            let y_batch = y.slice(start_idx, end_idx);
            
            // Forward and backward pass as one operation
            let (predictions, loss) = self.forward_backward(&x_batch, &y_batch);
            
            // Accumulate metrics, weighting loss by actual batch size
            total_loss += loss * (end_idx - start_idx) as f64;
            total_correct += self.count_correct(&predictions, &y_batch);
        }
        
        // Calculate average metrics
        metrics.insert("loss".to_string(), total_loss / num_samples as f64);
        metrics.insert("accuracy".to_string(), total_correct as f64 / num_samples as f64);
        
        metrics
    }

    fn forward_backward(&mut self, x_batch: &NDArray, y_batch: &NDArray) -> (NDArray, f64) {
        // Forward pass
        let predictions = self.predict(x_batch);
        let loss = self.calculate_loss(&predictions, y_batch);
        let loss_grad = self.calculate_loss_gradient(&predictions, y_batch);
        
        // Backward pass
        let mut gradient = loss_grad;
        let learning_rate = self.learning_rate;  // Cache learning rate
        
        for layer in self.layers.iter_mut().rev() {
            if layer.is_trainable() {
                gradient = layer.backward(&gradient);
                
                // Update weights using cached learning rate
                if let Some(weights) = layer.weights.as_mut() {
                    let weight_grads = layer.weight_gradients.as_ref().unwrap();
                    NablaOptimizer::sgd_update(weights, weight_grads, learning_rate);
                }
                if let Some(biases) = layer.biases.as_mut() {
                    let bias_grads = layer.bias_gradients.as_ref().unwrap();
                    NablaOptimizer::sgd_update(biases, bias_grads, learning_rate);
                }
            }
        }
        
        (predictions, loss)
    }

    fn count_correct(&self, predictions: &NDArray, targets: &NDArray) -> usize {
        let pred_classes = predictions.argmax(Some(1));
        let true_classes = targets.argmax(Some(1));
        
        pred_classes.iter()
            .zip(true_classes.iter())
            .filter(|(&p, &t)| p == t)
            .count()
    }

    /// Creates a new model from input and output nodes
    pub fn new_functional(inputs: Vec<Input>, outputs: Vec<Output>) -> Self {
        let mut layers = Vec::new();
        let mut visited = std::collections::HashSet::new();
        
        // First add input layers
        for input in inputs {
            let mut layer = NabLayer::input(input.shape.clone(), None);
            layer.set_node_index(input.node_index.unwrap());
            visited.insert(input.node_index.unwrap());
            layers.push(layer);
        }
        
        // Then add remaining layers by traversing backwards from each output
        for output in outputs {
            let mut current = Some(output);
            let mut layer_stack = Vec::new();
            
            // Build stack of layers from output to input
            while let Some(curr) = current {
                if !visited.contains(&curr.layer.node_index.unwrap()) {
                    visited.insert(curr.layer.node_index.unwrap());
                    layer_stack.push(curr.layer);
                }
                current = curr.previous_output.map(|prev| *prev);
            }
            
            // Add layers in reverse order (from input to output)
            layers.extend(layer_stack.into_iter().rev());
        }

        NabModel {
            layers,
            optimizer_type: String::new(),
            learning_rate: 0.0,
            loss_type: String::new(),
            metrics: Vec::new(),
        }
    }

    /// Trains the model on input data
    /// 
    /// # Arguments
    /// * `x_train` - Training features
    /// * `y_train` - Training labels
    /// * `batch_size` - Mini-batch size
    /// * `epochs` - Number of training epochs
    /// * `validation_data` - Optional validation dataset
    /// 
    /// # Returns
    /// HashMap containing training history metrics
    /// 
    /// # Examples
    /// ```ignore
    /// let history = model.fit(
    ///     &x_train,
    ///     &y_train,
    ///     64,    // batch_size
    ///     5,     // epochs
    ///     Some((&x_test, &y_test))
    /// );
    /// ```
    pub fn fit(&mut self, x_train: &NDArray, y_train: &NDArray,
               batch_size: usize, epochs: usize,
               validation_data: Option<(&NDArray, &NDArray)>) 
               -> HashMap<String, Vec<f64>> {
        let mut history = HashMap::new();
        let mut train_metrics = Vec::new();
        let mut val_metrics = Vec::new();

        for epoch in 0..epochs {
            // Training phase
            let metrics = self.train_epoch(x_train, y_train, batch_size);
            train_metrics.push(metrics);

            // Validation phase
            if let Some((x_val, y_val)) = validation_data {
                let val_metric = self.evaluate(x_val, y_val, batch_size);
                val_metrics.push(val_metric);
            }

            // Print progress
            self.print_progress(epoch + 1, epochs, &train_metrics[epoch], 
                              val_metrics.last());
        }

        // Store history
        history.insert("loss".to_string(), 
            train_metrics.iter().map(|m| m["loss"]).collect());
        history.insert("accuracy".to_string(), 
            train_metrics.iter().map(|m| m["accuracy"]).collect());

        if !val_metrics.is_empty() {
            history.insert("val_loss".to_string(), 
                val_metrics.iter().map(|m| m["loss"]).collect());
            history.insert("val_accuracy".to_string(), 
                val_metrics.iter().map(|m| m["accuracy"]).collect());
        }

        history
    }

    /// Prints training progress
    fn print_progress(
        &self,
        epoch: usize,
        total_epochs: usize,
        train_metrics: &HashMap<String, f64>,
        val_metrics: Option<&HashMap<String, f64>>,
    ) {
        print!("Epoch {}/{} - ", epoch, total_epochs);
        for (name, value) in train_metrics {
            print!("{}: {:.4} ", name, value);
        }
        if let Some(val_metrics) = val_metrics {
            for (name, value) in val_metrics {
                print!("val_{}: {:.4} ", name, value);
            }
        }
        println!();
    }

    /// Evaluates model performance on test data
    /// 
    /// # Arguments
    /// * `x_test` - Test features
    /// * `y_test` - Test labels
    /// * `batch_size` - Batch size for evaluation
    /// 
    /// # Returns
    /// HashMap containing evaluation metrics
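    ///
    /// # Examples
    /// A minimal sketch (note: `batch_size` is currently unused during evaluation):
    /// ```ignore
    /// let metrics = model.evaluate(&x_test, &y_test, 32);
    /// println!("test accuracy: {:.4}", metrics["accuracy"]);
    /// ```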
    #[allow(unused_variables)]
    pub fn evaluate(&mut self, x_test: &NDArray, y_test: &NDArray,
                   batch_size: usize) -> HashMap<String, f64> {
        let mut metrics = HashMap::new();
        let predictions = self.predict(x_test);
        
        // Calculate loss
        let loss = self.calculate_loss(&predictions, y_test);
        metrics.insert("loss".to_string(), loss);
        
        // Calculate other metrics
        for metric in &self.metrics {
            match metric.as_str() {
                "accuracy" => {
                    let acc = self.calculate_accuracy(&predictions, y_test);
                    metrics.insert("accuracy".to_string(), acc);
                }
                _ => {}
            }
        }
        
        metrics
    }

    /// Calculates accuracy for classification tasks
    fn calculate_accuracy(&self, predictions: &NDArray, targets: &NDArray) -> f64 {
        let pred_classes = predictions.argmax(Some(1));
        let true_classes = targets.argmax(Some(1));
        
        let correct = pred_classes.iter()
            .zip(true_classes.iter())
            .filter(|(&p, &t)| p == t)
            .count();
            
        correct as f64 / predictions.shape()[0] as f64
    }

    /// Makes predictions on input data
    /// 
    /// # Arguments
    /// * `x` - Input features to predict on
    /// 
    /// # Returns
    /// NDArray of model predictions
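    ///
    /// # Examples
    /// ```ignore
    /// let predictions = model.predict(&x_test);
    /// let classes = predictions.argmax(Some(1));
    /// ```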
    pub fn predict(&mut self, x: &NDArray) -> NDArray {
        let mut current = x.clone();
        for layer in &mut self.layers {
            current = layer.forward(&current, false);
        }
        current
    }

    fn calculate_loss(&self, predictions: &NDArray, targets: &NDArray) -> f64 {
        match self.loss_type.as_str() {
            "mse" => NabLoss::mean_squared_error(predictions, targets),
            "categorical_crossentropy" => NabLoss::cross_entropy_loss(predictions, targets),
            _ => NabLoss::mean_squared_error(predictions, targets),
        }
    }

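    /// Computes the loss gradient passed to the backward pass.
    ///
    /// A note on the shared `(predictions - targets) / N` form: for softmax
    /// outputs paired with categorical cross-entropy, the gradient with
    /// respect to the pre-softmax logits simplifies to exactly this
    /// expression, and for MSE it matches up to a constant factor of 2
    /// (absorbed into the learning rate).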
    fn calculate_loss_gradient(&self, predictions: &NDArray, targets: &NDArray) -> NDArray {
        match self.loss_type.as_str() {
            "mse" => predictions.subtract(targets).divide_scalar(predictions.shape()[0] as f64),
            "categorical_crossentropy" => predictions.subtract(targets).divide_scalar(predictions.shape()[0] as f64),
            _ => predictions.subtract(targets).divide_scalar(predictions.shape()[0] as f64),
        }
    }

    /// Prints the layer stack for debugging
    pub fn print_layers(&self) {
        println!("\nLayer stack:");
        for (i, layer) in self.layers.iter().enumerate() {
            println!("{}: {} -> {:?}", i, layer.get_name(), layer.get_output_shape());
        }
    }

    /// Prints a summary of the model's layers and parameters
    /// 
    /// Displays a formatted table showing:
    /// - Layer name and type
    /// - Output shape
    /// - Number of parameters
    /// 
    /// Also shows total parameters, trainable parameters, and non-trainable parameters
    /// 
    /// # Example
    /// ```ignore
    /// use nabla_ml::nab_model::NabModel;
    /// use nabla_ml::nab_layers::NabLayer;
    /// 
    /// let input = NabModel::input(vec![784]);
    /// let dense = NabLayer::dense(784, 128, Some("relu"), Some("dense1"));
    /// let output = input.apply(dense);
    /// let mut model = NabModel::new_functional(vec![input], vec![output]);
    /// 
    /// model.summary();
    /// // Model: "functional"
    /// // ─────────────────────────────────────────────────────
    /// // Layer (type)          Output Shape         Param #
    /// // =================================================
    /// // input                 (None, 784)          0
    /// // dense1 (Dense)        (None, 128)          100,480
    /// // =================================================
    /// // Total params: 100,480
    /// // Trainable params: 100,480
    /// // Non-trainable params: 0
    /// ```
    pub fn summary(&self) {
        println!("Model: \"functional\"");
        println!("─────────────────────────────────────────────────────");
        println!("{:<20} {:<18} {:<10}", "Layer (type)", "Output Shape", "Param #");
        println!("=================================================");

        let mut total_params = 0;
        let mut trainable_params = 0;
        let mut non_trainable_params = 0;

        // Print each layer's info
        for layer in &self.layers {
            let (params, trainable) = self.count_params(layer);
            total_params += params;
            if trainable {
                trainable_params += params;
            } else {
                non_trainable_params += params;
            }

            let shape_str = format!("(None, {})", 
                layer.get_output_shape()
                    .iter()
                    .map(|x| x.to_string())
                    .collect::<Vec<_>>()
                    .join(", ")
            );

            let layer_type = if layer.get_name().contains("input") {
                layer.get_name().to_string()
            } else {
                format!("{} ({})", 
                    layer.get_name(),
                    layer.get_type()
                )
            };

            println!("{:<20} {:<18} {:<10}", 
                layer_type,
                shape_str,
                self.format_number(params)
            );
        }

        println!("=================================================");
        println!("Total params: {}", self.format_number(total_params));
        println!("Trainable params: {}", self.format_number(trainable_params));
        println!("Non-trainable params: {}", self.format_number(non_trainable_params));
    }

    /// Counts parameters for a given layer
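    ///
    /// For example, a Dense layer mapping 784 inputs to 128 units has
    /// 784 * 128 weights + 128 biases = 100,480 parameters.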
    fn count_params(&self, layer: &NabLayer) -> (usize, bool) {
        let mut params = 0;
        
        // Count weights
        if let Some(weights) = &layer.weights {
            params += weights.data().len();
        }
        
        // Count biases
        if let Some(biases) = &layer.biases {
            params += biases.data().len();
        }

        (params, layer.is_trainable())
    }

    /// Formats large numbers with commas
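    ///
    /// For example, `format_number(100480)` returns `"100,480"`.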
    fn format_number(&self, n: usize) -> String {
        n.to_string()
            .chars()
            .rev()
            .collect::<Vec<_>>()
            .chunks(3)
            .map(|chunk| chunk.iter().collect::<String>())
            .collect::<Vec<_>>()
            .join(",")
            .chars()
            .rev()
            .collect()
    }

    /// Saves the model to a compressed .ez file
    /// 
    /// # Arguments
    /// * `path` - Path to save the model (e.g. "model.ez")
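    ///
    /// # Examples
    /// ```ignore
    /// model.save_compressed("model.ez").expect("Failed to save model");
    /// ```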
    pub fn save_compressed<P: AsRef<Path>>(&self, path: P) -> std::io::Result<()> {
        // Create encoder with best compression
        let file = std::fs::File::create(path)?;
        let mut encoder = GzEncoder::new(file, Compression::best());
        
        // Prepare and serialize model data
        let model_data = ModelData {
            config: ModelConfig {
                optimizer_type: self.optimizer_type.clone(),
                learning_rate: self.learning_rate,
                loss_type: self.loss_type.clone(),
                metrics: self.metrics.clone(),
            },
            layers: self.layers.iter().map(|layer| LayerState {
                layer_type: layer.get_type().to_string(),
                name: layer.get_name().to_string(),
                input_shape: layer.input_shape.clone(),
                output_shape: layer.output_shape.clone(),
                weights: layer.weights.as_ref().map(|w| w.data().to_vec()),
                biases: layer.biases.as_ref().map(|b| b.data().to_vec()),
                activation: layer.activation.clone(),
            }).collect(),
        };

        // Write serialized data
        let serialized = serde_json::to_string(&model_data)?;
        encoder.write_all(serialized.as_bytes())?;
        encoder.finish()?;

        Ok(())
    }

    /// Loads a model from a compressed .ez file
    /// 
    /// # Arguments
    /// * `path` - Path to the model file (e.g. "model.ez")
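    ///
    /// # Examples
    /// ```ignore
    /// let loaded = NabModel::load_compressed("model.ez").expect("Failed to load model");
    /// ```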
    pub fn load_compressed<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
        // Create decoder
        let file = std::fs::File::open(path)?;
        let mut decoder = GzDecoder::new(file);
        
        // Read and decompress
        let mut contents = String::new();
        decoder.read_to_string(&mut contents)?;
        
        // Deserialize
        let model_data: ModelData = serde_json::from_str(&contents)?;
        
        // Reconstruct model
        let mut layers = Vec::new();
        for state in model_data.layers {
            let mut layer = match state.layer_type.as_str() {
                "Input" => NabLayer::input(state.input_shape.clone(), Some(&state.name)),
                "Dense" => NabLayer::dense(
                    state.input_shape[0],
                    state.output_shape[0],
                    state.activation.as_deref(),
                    Some(&state.name)
                ),
                _ => return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Unknown layer type: {}", state.layer_type)
                )),
            };

            // Restore weights and biases
            if let Some(weights) = state.weights {
                let weight_shape = match state.layer_type.as_str() {
                    "Dense" => vec![state.input_shape[0], state.output_shape[0]],
                    _ => state.input_shape.clone()
                };
                layer.weights = Some(NDArray::new(weights, weight_shape));
            }
            if let Some(biases) = state.biases {
                layer.biases = Some(NDArray::new(biases, vec![state.output_shape[0]]));
            }

            layers.push(layer);
        }

        Ok(NabModel {
            layers,
            optimizer_type: model_data.config.optimizer_type,
            learning_rate: model_data.config.learning_rate,
            loss_type: model_data.config.loss_type,
            metrics: model_data.config.metrics,
        })
    }
}

/// Serializable model configuration
#[derive(Serialize, Deserialize)]
struct ModelConfig {
    optimizer_type: String,
    learning_rate: f64,
    loss_type: String,
    metrics: Vec<String>,
}

/// Serializable layer state
#[derive(Serialize, Deserialize)]
struct LayerState {
    layer_type: String,
    name: String,
    input_shape: Vec<usize>,
    output_shape: Vec<usize>,
    weights: Option<Vec<f64>>,
    biases: Option<Vec<f64>>,
    activation: Option<String>,
}

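/// Complete serializable model state: configuration plus per-layer states.
///
/// A sketch of the serialized JSON layout (values illustrative; field names
/// follow the struct definitions above under serde's default naming):
/// ```ignore
/// {
///   "config": { "optimizer_type": "sgd", "learning_rate": 0.1,
///               "loss_type": "categorical_crossentropy", "metrics": ["accuracy"] },
///   "layers": [ { "layer_type": "Dense", "name": "dense1", ... } ]
/// }
/// ```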
#[derive(Serialize, Deserialize)]
struct ModelData {
    config: ModelConfig,
    layers: Vec<LayerState>,
}

/// Resets the global node ID counter
/// Used for testing to ensure consistent behavior
pub fn reset_node_id() {
    unsafe {
        NEXT_NODE_ID = 0;
    }
}

#[cfg(test)]
#[allow(unused_imports)]
#[allow(unused_variables)]
mod tests {
    use super::*;
    use crate::nab_activations::NablaActivation;
    use crate::nab_optimizers::NablaOptimizer;
    use crate::nab_loss::NabLoss;
    use crate::nab_mnist::NabMnist;
    use crate::nab_utils::NabUtils;

    #[test]
    fn test_linear_regression() {
        // Reset node ID counter before test
        reset_node_id();
        
        // Create synthetic data for linear regression:
        // y = 2x + 1 with some noise
        let x_data = NDArray::from_matrix(vec![
            vec![1.0], vec![2.0], vec![3.0], vec![4.0], vec![5.0]
        ]);
        let y_data = NDArray::from_matrix(vec![
            vec![3.1], vec![5.0], vec![6.9], vec![9.2], vec![11.0]
        ]);

        // Create model architecture
        let input = NabModel::input(vec![1]);
        let output_layer = NabLayer::dense(1, 1, None, Some("linear_output"));
        let output = input.apply(output_layer);

        // Create and compile model
        let mut model = NabModel::new_functional(vec![input], vec![output]);
        model.compile(
            "sgd",
            0.01,
            "mse",
            vec!["mse".to_string()]
        );

        // Train model: 100 epochs of full-batch gradient descent
        for _ in 0..100 {
            model.train_epoch(&x_data, &y_data, x_data.shape()[0]);
        }
        
        // Make predictions
        let predictions = model.predict(&x_data);
        
        // Verify predictions increase monotonically (consistent with a positive slope)
        let pred_vec = predictions.data();
        for i in 1..pred_vec.len() {
            assert!(pred_vec[i] > pred_vec[i-1], 
                "Predictions should increase monotonically. Found {} <= {} at index {}", 
                pred_vec[i], pred_vec[i-1], i
            );
        }
    }

    /// Tests the full training pipeline on the MNIST dataset
    /// 
    /// This test:
    /// 1. Loads and preprocesses MNIST data
    /// 2. Creates a neural network with:
    ///    - Input layer (784 units)
    ///    - Dense layer (32 units, ReLU)
    ///    - Dense layer (32 units, ReLU)
    ///    - Output layer (10 units, softmax)
    /// 3. Compiles with categorical cross-entropy loss
    /// 4. Trains for 10 epochs
    /// 5. Verifies accuracy exceeds 84%
    #[test]
    fn test_mnist_full_pipeline() {
        // Step 1: Load MNIST data
        println!("Internal test ... skipping ...");
        // println!("Loading MNIST data...");
        // let ((x_train, y_train), (x_test, y_test)) = NabUtils::load_and_split_dataset("datasets/mnist_test", 80.0).unwrap();

        // // Step 2: Normalize input data (scale pixels to 0-1)
        // println!("Normalizing data...");
        // let x_train = x_train.divide_scalar(255.0);
        // let x_test = x_test.divide_scalar(255.0);

        // // Step 2.5: Reshape input data
        // let x_train = x_train.reshape(&[x_train.shape()[0], 784])
        //     .expect("Failed to reshape training data");
        // let x_test = x_test.reshape(&[x_test.shape()[0], 784])
        //     .expect("Failed to reshape test data");

        // // Step 2.6: One-hot encode target data
        // println!("One-hot encoding targets...");
        // let y_train = NDArray::one_hot_encode(&y_train);
        // let y_test = NDArray::one_hot_encode(&y_test);

        // println!("Data shapes:");
        // println!("x_train: {:?}", x_train.shape());
        // println!("y_train: {:?}", y_train.shape());
        // println!("x_test: {:?}", x_test.shape());
        // println!("y_test: {:?}", y_test.shape());

        // // Step 3: Create model architecture
        // println!("Creating model...");
        // let input = NabModel::input(vec![784]);  // 28x28 = 784 pixels

        // // Dense layer with 32 units and ReLU activation
        // let dense1 = NabLayer::dense(784, 32, Some("relu"), Some("dense1"));
        // let x = input.apply(dense1);

        // // Dense layer with 32 units and ReLU activation
        // let dense2 = NabLayer::dense(32, 32, Some("relu"), Some("dense2"));
        // let x = x.apply(dense2);

        // // Output layer with 10 units (one per digit) and softmax activation
        // let output_layer = NabLayer::dense(32, 10, Some("softmax"), Some("output"));
        // let output = x.apply(output_layer);

        // // Step 4: Create and compile model
        // println!("Compiling model...");
        // let mut model = NabModel::new_functional(vec![input], vec![output]);
        // model.compile(
        //     "adam",
        //     0.1,                        // Learning rate
        //     "categorical_crossentropy",
        //     vec!["accuracy".to_string()]
        // );

        // // Step 5: Train model
        // println!("Training model...");
        // let history = model.fit(
        //     &x_train,
        //     &y_train,
        //     32,             // Batch size
        //     10,             // Epochs
        //     Some((&x_test, &y_test))
        // );

        // // Step 6: Evaluate final model
        // println!("Evaluating model...");
        // let eval_metrics = model.evaluate(&x_test, &y_test, 32);
        
        // // Print final results
        // println!("Final test accuracy: {:.2}%", eval_metrics["accuracy"] * 100.0);
        
        // // Verify model achieved reasonable accuracy (>84%)
        // assert!(eval_metrics["accuracy"] > 0.84, 
        //     "Model accuracy ({:.2}%) below expected threshold", 
        //     eval_metrics["accuracy"] * 100.0
        // );

        // // Verify training history contains expected metrics
        // assert!(history.contains_key("loss"));
        // assert!(history.contains_key("accuracy"));
        // assert!(history.contains_key("val_loss"));
        // assert!(history.contains_key("val_accuracy"));
    }

    #[test]
    fn test_model_summary() {
        // Reset node ID counter before test
        reset_node_id();
        
        // Create a simple model
        let input = NabModel::input(vec![784]);
        let dense1 = NabLayer::dense(784, 32, Some("relu"), Some("dense1"));
        let x = input.apply(dense1);

        let dense2 = NabLayer::dense(32, 32, Some("relu"), Some("dense2"));
        let x = x.apply(dense2);

        let output_layer = NabLayer::dense(32, 10, Some("softmax"), Some("output"));
        let output = x.apply(output_layer);
        
        let model = NabModel::new_functional(vec![input], vec![output]);

        // Lock stdout while the summary prints; note that the printed
        // output itself is not captured or asserted against here
        let output = std::io::stdout();
        let handle = output.lock();
        
        model.summary();

        // Verify parameter counts
        let total_params: usize = model.layers.iter()
            .map(|l| model.count_params(l).0)
            .sum();
        
        assert_eq!(total_params, 784*32 + 32 + 32*32 + 32 + 32*10 + 10); // weights + biases
    }

    #[test]
    fn test_model_save_load() {
        // Reset node ID counter before test
        reset_node_id();
        
        // Create a simple model
        let input = NabModel::input(vec![784]);
        let dense1 = NabLayer::dense(784, 32, Some("relu"), Some("dense1"));
        let x = input.apply(dense1);

        let dense2 = NabLayer::dense(32, 32, Some("relu"), Some("dense2"));
        let x = x.apply(dense2);

        let output_layer = NabLayer::dense(32, 10, Some("softmax"), Some("output"));
        let output = x.apply(output_layer);

        let mut model = NabModel::new_functional(vec![input], vec![output]);
        model.compile("sgd", 0.1, "categorical_crossentropy", vec!["accuracy".to_string()]);

        // Save the model
        model.save_compressed("test_model.ez").expect("Failed to save model");

        // Load the model
        let loaded_model = NabModel::load_compressed("test_model.ez").expect("Failed to load model");

        // Verify model configuration
        assert_eq!(loaded_model.optimizer_type, model.optimizer_type);
        assert_eq!(loaded_model.learning_rate, model.learning_rate);
        assert_eq!(loaded_model.loss_type, model.loss_type);
        assert_eq!(loaded_model.metrics, model.metrics);

        // Verify layers
        assert_eq!(loaded_model.layers.len(), model.layers.len());
        for (loaded, original) in loaded_model.layers.iter().zip(model.layers.iter()) {
            assert_eq!(loaded.get_type(), original.get_type());
            assert_eq!(loaded.get_output_shape(), original.get_output_shape());
            
            if let (Some(w1), Some(w2)) = (&loaded.weights, &original.weights) {
                assert_eq!(w1.shape(), w2.shape(), "Weight shapes don't match");
                assert!(w1.data().iter().zip(w2.data().iter())
                    .all(|(a, b)| (a - b).abs() < 1e-6), 
                    "Weight values don't match");
            }

            if let (Some(b1), Some(b2)) = (&loaded.biases, &original.biases) {
                assert_eq!(b1.shape(), b2.shape(), "Bias shapes don't match");
                assert!(b1.data().iter().zip(b2.data().iter())
                    .all(|(a, b)| (a - b).abs() < 1e-6),
                    "Bias values don't match");
            }
        }

        // Clean up test file
        std::fs::remove_file("test_model.ez").expect("Failed to clean up test file");
    }

    #[test]
    fn test_input_shape() {
        // Reset node ID counter before test
        reset_node_id();
        
        let shape = vec![784, 32];
        let input = NabModel::input(shape.clone());
        
        // Test that get_input_shape returns the correct shape
        assert_eq!(input.get_input_shape(), &shape);
        
        // Test that the shape is preserved when applying a layer
        let dense = NabLayer::dense(784, 128, Some("relu"), Some("dense1"));
        let output = input.apply(dense);
        assert_eq!(input.get_input_shape(), &shape, "Input shape should remain unchanged after applying layer");
    }
}