// numrs/autograd/train.rs

//! High-level Training API
//!
//! Simplifies the training loop:
//! - Trainer for running epochs
//! - Evaluation metrics
//! - Early stopping
7
8use crate::{Tensor, Array};
9use crate::autograd::{Module, Optimizer};
10use anyhow::Result;
11
12
/// A loss function: maps a batch of predictions and targets to a loss tensor.
///
/// Implementations are expected to reduce to a scalar — the trainer reads
/// only `loss.values()[0]` — TODO confirm against the Tensor loss methods.
pub trait LossFunction {
    /// Computes the loss for `predictions` against `targets`.
    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> Result<Tensor>;
}
17
/// Mean Squared Error loss.
pub struct MSELoss;

impl LossFunction for MSELoss {
    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> Result<Tensor> {
        // Delegates entirely to the Tensor-level MSE implementation.
        predictions.mse_loss(targets)
    }
}
26
/// Cross Entropy loss.
pub struct CrossEntropyLoss;

impl LossFunction for CrossEntropyLoss {
    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> Result<Tensor> {
        // Delegates entirely to the Tensor-level cross-entropy implementation.
        predictions.cross_entropy_loss(targets)
    }
}
35
/// Evaluation metrics produced by a training or evaluation pass.
///
/// Standard derives are provided so metrics can be logged (`Debug`),
/// copied into a history (`Copy`/`Clone`) and compared in tests
/// (`PartialEq`). Both fields are plain value types, so `Copy` is free.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Metrics {
    /// Average loss over the processed batches.
    pub loss: f32,
    /// Classification accuracy in [0, 1]; `None` when it was not computed
    /// (e.g. pure regression training).
    pub accuracy: Option<f32>,
}

impl Metrics {
    /// Metrics carrying only a loss; accuracy is left unset.
    pub fn new(loss: f32) -> Self {
        Metrics { loss, accuracy: None }
    }

    /// Metrics carrying both a loss and an accuracy value.
    pub fn with_accuracy(loss: f32, accuracy: f32) -> Self {
        Metrics { loss, accuracy: Some(accuracy) }
    }
}
51
/// Simple in-memory dataset for training.
///
/// Samples are stored flattened in row-major order as 2-D arrays of shape
/// `[num_samples, dim]`, so a batch is a single contiguous slice.
pub struct Dataset {
    /// Flattened `[num_samples, input_dim]` input matrix.
    pub inputs: Array<f32>,
    /// Flattened `[num_samples, target_dim]` target matrix.
    pub targets: Array<f32>,
    /// Maximum samples per batch; the final batch may be smaller.
    pub batch_size: usize,
    /// Number of samples (rows) in the dataset.
    pub num_samples: usize,
}
59
60impl Dataset {
61    pub fn new(inputs: Vec<Vec<f32>>, targets: Vec<Vec<f32>>, batch_size: usize) -> Self {
62        assert_eq!(inputs.len(), targets.len(), "Inputs y targets deben tener mismo tamaño");
63        let num_samples = inputs.len();
64        let input_dim = if num_samples > 0 { inputs[0].len() } else { 0 };
65        let target_dim = if num_samples > 0 { targets[0].len() } else { 0 };
66        
67        // Flatten inputs
68        let mut flat_inputs = Vec::with_capacity(num_samples * input_dim);
69        for row in &inputs {
70            flat_inputs.extend_from_slice(row);
71        }
72        
73        // Flatten targets
74        let mut flat_targets = Vec::with_capacity(num_samples * target_dim);
75        for row in &targets {
76            flat_targets.extend_from_slice(row);
77        }
78        
79        Dataset { 
80            inputs: Array::new(vec![num_samples, input_dim], flat_inputs),
81            targets: Array::new(vec![num_samples, target_dim], flat_targets),
82            batch_size,
83            num_samples
84        }
85    }
86    
87    /// Devuelve el número de batches
88    pub fn num_batches(&self) -> usize {
89        (self.num_samples + self.batch_size - 1) / self.batch_size
90    }
91    
92    /// Devuelve un batch específico
93    pub fn get_batch(&self, batch_idx: usize) -> Result<(Tensor, Tensor)> {
94        let start = batch_idx * self.batch_size;
95        let end = (start + self.batch_size).min(self.num_samples);
96        
97        if start >= self.num_samples {
98            return Err(anyhow::anyhow!("Batch index fuera de rango"));
99        }
100        
101        let actual_batch_size = end - start;
102        let input_dim = self.inputs.shape[1];
103        let target_dim = self.targets.shape[1];
104        
105        // Slice input data efficiently (contiguous slice)
106        let input_start_idx = start * input_dim;
107        let input_end_idx = end * input_dim;
108        let batch_inputs_data = self.inputs.data[input_start_idx..input_end_idx].to_vec();
109        
110        // Slice target data efficiently
111        let target_start_idx = start * target_dim;
112        let target_end_idx = end * target_dim;
113        let batch_targets_data = self.targets.data[target_start_idx..target_end_idx].to_vec();
114        
115        let inputs_tensor = Tensor::new(
116            Array::new(vec![actual_batch_size, input_dim], batch_inputs_data),
117            false
118        );
119        
120        let targets_tensor = Tensor::new(
121            Array::new(vec![actual_batch_size, target_dim], batch_targets_data),
122            false
123        );
124        
125        Ok((inputs_tensor, targets_tensor))
126    }
127}
128
/// High-level trainer tying together a model, an optimizer and a loss.
/// 
/// Example:
/// ```ignore
/// let trainer = Trainer::new(model, optimizer, Box::new(MSELoss));
/// let metrics = trainer.train_epoch(&dataset)?;
/// println!("Loss: {:.4}", metrics.loss);
/// ```
pub struct Trainer<M: Module, O: Optimizer> {
    /// The model being trained; public so callers can inspect it directly.
    pub model: M,
    /// Optimizer that updates the model's parameters after each batch.
    optimizer: O,
    /// Boxed loss function, chosen at runtime.
    loss_fn: Box<dyn LossFunction>,
}
142
143impl<M: Module, O: Optimizer> Trainer<M, O> {
144    pub fn new(model: M, optimizer: O, loss_fn: Box<dyn LossFunction>) -> Self {
145        Trainer { model, optimizer, loss_fn }
146    }
147    
148    /// Get reference to the model (for debugging)
149    pub fn model(&self) -> &M {
150        &self.model
151    }
152    
153    /// Get mutable reference to the model
154    pub fn model_mut(&mut self) -> &mut M {
155        &mut self.model
156    }
157    
158    /// Entrena una época completa
159    pub fn train_epoch(&mut self, dataset: &Dataset) -> Result<Metrics> {
160        let mut total_loss = 0.0;
161        let num_batches = dataset.num_batches();
162        
163        for batch_idx in 0..num_batches {
164            let (inputs, targets) = dataset.get_batch(batch_idx)?;
165            
166            // Forward pass
167            let predictions = self.model.forward(&inputs)?;
168            
169            // Compute loss
170            let loss = self.loss_fn.compute(&predictions, &targets)?;
171            total_loss += loss.values()[0];
172            
173            // Backward pass
174            loss.backward()?;
175            
176            // Update weights
177            self.optimizer.step()?;
178            self.optimizer.zero_grad();
179        }
180        
181        let avg_loss = total_loss / num_batches as f32;
182        Ok(Metrics::new(avg_loss))
183    }
184    
185    /// Evalúa el modelo sin actualizar pesos
186    pub fn evaluate(&self, dataset: &Dataset) -> Result<Metrics> {
187        let mut total_loss = 0.0;
188        let mut correct = 0;
189        let mut total = 0;
190        let num_batches = dataset.num_batches();
191        
192        for batch_idx in 0..num_batches {
193            let (inputs, targets) = dataset.get_batch(batch_idx)?;
194            
195            // Forward pass (sin gradientes)
196            let predictions = self.model.forward(&inputs)?;
197            
198            // Compute loss
199            let loss = self.loss_fn.compute(&predictions, &targets)?;
200            total_loss += loss.values()[0];
201            
202            // Compute accuracy (para clasificación)
203            let pred_vals = predictions.values();
204            let target_vals = targets.values();
205            
206            let batch_size = predictions.shape()[0];
207            let num_classes = predictions.shape()[1];
208            
209            for i in 0..batch_size {
210                // Argmax de predicciones
211                let pred_start = i * num_classes;
212                let pred_end = pred_start + num_classes;
213                let pred_class = pred_vals[pred_start..pred_end]
214                    .iter()
215                    .enumerate()
216                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
217                    .map(|(idx, _)| idx)
218                    .unwrap_or(0); // Default to class 0 if empty (should not happen)
219                
220                // Argmax de targets
221                let target_start = i * num_classes;
222                let target_end = target_start + num_classes;
223                let target_class = target_vals[target_start..target_end]
224                    .iter()
225                    .enumerate()
226                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
227                    .map(|(idx, _)| idx)
228                    .unwrap_or(0);
229                
230                if pred_class == target_class {
231                    correct += 1;
232                }
233                total += 1;
234            }
235        }
236        
237        let avg_loss = total_loss / num_batches as f32;
238        let accuracy = correct as f32 / total as f32;
239        
240        Ok(Metrics::with_accuracy(avg_loss, accuracy))
241    }
242    
243    /// Training loop completo con múltiples epochs
244    pub fn fit(
245        &mut self,
246        train_dataset: &Dataset,
247        val_dataset: Option<&Dataset>,
248        epochs: usize,
249        verbose: bool,
250    ) -> Result<Vec<(Metrics, Option<Metrics>)>> {
251        let mut history = Vec::new();
252        
253        for epoch in 0..epochs {
254            // Train
255            let train_metrics = self.train_epoch(train_dataset)?;
256            
257            // Validate
258            let val_metrics = if let Some(val_ds) = val_dataset {
259                Some(self.evaluate(val_ds)?)
260            } else {
261                None
262            };
263            
264            if verbose {
265                print!("Epoch {}/{}: train_loss={:.4}", epoch + 1, epochs, train_metrics.loss);
266                
267                if let Some(ref vm) = val_metrics {
268                    print!(", val_loss={:.4}", vm.loss);
269                    if let Some(acc) = vm.accuracy {
270                        print!(", val_acc={:.4}", acc);
271                    }
272                }
273                println!();
274            }
275            
276            history.push((train_metrics, val_metrics));
277        }
278        
279        Ok(history)
280    }
281}
282
/// Builder for conveniently constructing a `Trainer`.
pub struct TrainerBuilder<M: Module> {
    /// Model handed over to the resulting Trainer.
    model: M,
    /// Learning rate passed to the optimizer; defaults to 0.01.
    learning_rate: f32,
}
288
289impl<M: Module> TrainerBuilder<M> {
290    pub fn new(model: M) -> Self {
291        TrainerBuilder {
292            model,
293            learning_rate: 0.01,
294        }
295    }
296    
297    pub fn learning_rate(mut self, lr: f32) -> Self {
298        self.learning_rate = lr;
299        self
300    }
301    
302    pub fn build_sgd(self, loss_fn: Box<dyn LossFunction>) -> Trainer<M, crate::autograd::SGD> {
303        let params = self.model.parameters();
304        let optimizer = crate::autograd::SGD::new(params, self.learning_rate, 0.9, 0.0);
305        Trainer::new(self.model, optimizer, loss_fn)
306    }
307    
308    pub fn build_adam(self, loss_fn: Box<dyn LossFunction>) -> Trainer<M, crate::autograd::Adam> {
309        let params = self.model.parameters();
310        let optimizer = crate::autograd::Adam::with_lr(params, self.learning_rate);
311        Trainer::new(self.model, optimizer, loss_fn)
312    }
313
314    /// Generic builder for any optimizer
315    pub fn build_with<O, F>(self, optimizer_factory: F, loss_fn: Box<dyn LossFunction>) -> Trainer<M, O>
316    where 
317        O: Optimizer,
318        F: FnOnce(Vec<std::rc::Rc<std::cell::RefCell<crate::Tensor>>>, f32) -> O,
319    {
320        let params = self.model.parameters();
321        let optimizer = optimizer_factory(params, self.learning_rate);
322        Trainer::new(self.model, optimizer, loss_fn)
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dataset() -> Result<()> {
        // Three 2-dimensional samples with scalar targets.
        let xs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let ys = vec![vec![0.0], vec![1.0], vec![2.0]];

        let ds = Dataset::new(xs, ys, 2);
        // ceil(3 / 2) == 2 batches.
        assert_eq!(ds.num_batches(), 2);

        // The first batch is full-sized: 2 inputs of dim 2, 2 targets of dim 1.
        let (bx, by) = ds.get_batch(0)?;
        assert_eq!(bx.shape(), &[2, 2]);
        assert_eq!(by.shape(), &[2, 1]);

        Ok(())
    }
}