1use crate::{Tensor, Array};
9use crate::autograd::{Module, Optimizer};
10use anyhow::Result;
11
12
/// Abstraction over a differentiable loss function used during training.
///
/// Implementations delegate to the corresponding `Tensor` operation so the
/// resulting loss tensor stays connected to the autograd graph.
pub trait LossFunction {
    /// Computes the scalar loss between `predictions` and `targets`.
    ///
    /// # Errors
    /// Returns an error if the underlying tensor operation fails
    /// (e.g. a shape mismatch between the two tensors).
    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> Result<Tensor>;
}
17
/// Mean-squared-error loss, typically used for regression targets.
pub struct MSELoss;

impl LossFunction for MSELoss {
    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> Result<Tensor> {
        // Thin delegate: shape checking and gradient wiring happen inside
        // `Tensor::mse_loss`.
        predictions.mse_loss(targets)
    }
}
26
/// Cross-entropy loss, typically used for classification targets.
pub struct CrossEntropyLoss;

impl LossFunction for CrossEntropyLoss {
    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> Result<Tensor> {
        // Thin delegate: shape checking and gradient wiring happen inside
        // `Tensor::cross_entropy_loss`.
        predictions.cross_entropy_loss(targets)
    }
}
35
/// Aggregated evaluation results for one pass over a dataset.
pub struct Metrics {
    /// Average loss over all batches.
    pub loss: f32,
    /// Classification accuracy in `[0, 1]`, when it was computed.
    pub accuracy: Option<f32>,
}

impl Metrics {
    /// Builds metrics that only track a loss value.
    pub fn new(loss: f32) -> Self {
        Self {
            loss,
            accuracy: None,
        }
    }

    /// Builds metrics carrying both a loss and an accuracy value.
    pub fn with_accuracy(loss: f32, accuracy: f32) -> Self {
        let mut metrics = Self::new(loss);
        metrics.accuracy = Some(accuracy);
        metrics
    }
}
51
/// An in-memory dataset of fixed-width input/target rows, served in
/// mini-batches.
pub struct Dataset {
    /// Row-major flattened inputs, shape `[num_samples, input_dim]`.
    pub inputs: Array<f32>,
    /// Row-major flattened targets, shape `[num_samples, target_dim]`.
    pub targets: Array<f32>,
    /// Maximum samples per batch; the final batch may be smaller.
    pub batch_size: usize,
    /// Total number of samples stored.
    pub num_samples: usize,
}
59
60impl Dataset {
61 pub fn new(inputs: Vec<Vec<f32>>, targets: Vec<Vec<f32>>, batch_size: usize) -> Self {
62 assert_eq!(inputs.len(), targets.len(), "Inputs y targets deben tener mismo tamaño");
63 let num_samples = inputs.len();
64 let input_dim = if num_samples > 0 { inputs[0].len() } else { 0 };
65 let target_dim = if num_samples > 0 { targets[0].len() } else { 0 };
66
67 let mut flat_inputs = Vec::with_capacity(num_samples * input_dim);
69 for row in &inputs {
70 flat_inputs.extend_from_slice(row);
71 }
72
73 let mut flat_targets = Vec::with_capacity(num_samples * target_dim);
75 for row in &targets {
76 flat_targets.extend_from_slice(row);
77 }
78
79 Dataset {
80 inputs: Array::new(vec![num_samples, input_dim], flat_inputs),
81 targets: Array::new(vec![num_samples, target_dim], flat_targets),
82 batch_size,
83 num_samples
84 }
85 }
86
87 pub fn num_batches(&self) -> usize {
89 (self.num_samples + self.batch_size - 1) / self.batch_size
90 }
91
92 pub fn get_batch(&self, batch_idx: usize) -> Result<(Tensor, Tensor)> {
94 let start = batch_idx * self.batch_size;
95 let end = (start + self.batch_size).min(self.num_samples);
96
97 if start >= self.num_samples {
98 return Err(anyhow::anyhow!("Batch index fuera de rango"));
99 }
100
101 let actual_batch_size = end - start;
102 let input_dim = self.inputs.shape[1];
103 let target_dim = self.targets.shape[1];
104
105 let input_start_idx = start * input_dim;
107 let input_end_idx = end * input_dim;
108 let batch_inputs_data = self.inputs.data[input_start_idx..input_end_idx].to_vec();
109
110 let target_start_idx = start * target_dim;
112 let target_end_idx = end * target_dim;
113 let batch_targets_data = self.targets.data[target_start_idx..target_end_idx].to_vec();
114
115 let inputs_tensor = Tensor::new(
116 Array::new(vec![actual_batch_size, input_dim], batch_inputs_data),
117 false
118 );
119
120 let targets_tensor = Tensor::new(
121 Array::new(vec![actual_batch_size, target_dim], batch_targets_data),
122 false
123 );
124
125 Ok((inputs_tensor, targets_tensor))
126 }
127}
128
/// Drives the training loop: owns the model, its optimizer, and the loss.
pub struct Trainer<M: Module, O: Optimizer> {
    /// The model being trained; public so callers can inspect it after `fit`.
    pub model: M,
    // Updates the model parameters after each batch.
    optimizer: O,
    // Boxed trait object so any `LossFunction` can be chosen at runtime.
    loss_fn: Box<dyn LossFunction>,
}
142
143impl<M: Module, O: Optimizer> Trainer<M, O> {
144 pub fn new(model: M, optimizer: O, loss_fn: Box<dyn LossFunction>) -> Self {
145 Trainer { model, optimizer, loss_fn }
146 }
147
148 pub fn model(&self) -> &M {
150 &self.model
151 }
152
153 pub fn model_mut(&mut self) -> &mut M {
155 &mut self.model
156 }
157
158 pub fn train_epoch(&mut self, dataset: &Dataset) -> Result<Metrics> {
160 let mut total_loss = 0.0;
161 let num_batches = dataset.num_batches();
162
163 for batch_idx in 0..num_batches {
164 let (inputs, targets) = dataset.get_batch(batch_idx)?;
165
166 let predictions = self.model.forward(&inputs)?;
168
169 let loss = self.loss_fn.compute(&predictions, &targets)?;
171 total_loss += loss.values()[0];
172
173 loss.backward()?;
175
176 self.optimizer.step()?;
178 self.optimizer.zero_grad();
179 }
180
181 let avg_loss = total_loss / num_batches as f32;
182 Ok(Metrics::new(avg_loss))
183 }
184
185 pub fn evaluate(&self, dataset: &Dataset) -> Result<Metrics> {
187 let mut total_loss = 0.0;
188 let mut correct = 0;
189 let mut total = 0;
190 let num_batches = dataset.num_batches();
191
192 for batch_idx in 0..num_batches {
193 let (inputs, targets) = dataset.get_batch(batch_idx)?;
194
195 let predictions = self.model.forward(&inputs)?;
197
198 let loss = self.loss_fn.compute(&predictions, &targets)?;
200 total_loss += loss.values()[0];
201
202 let pred_vals = predictions.values();
204 let target_vals = targets.values();
205
206 let batch_size = predictions.shape()[0];
207 let num_classes = predictions.shape()[1];
208
209 for i in 0..batch_size {
210 let pred_start = i * num_classes;
212 let pred_end = pred_start + num_classes;
213 let pred_class = pred_vals[pred_start..pred_end]
214 .iter()
215 .enumerate()
216 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
217 .map(|(idx, _)| idx)
218 .unwrap_or(0); let target_start = i * num_classes;
222 let target_end = target_start + num_classes;
223 let target_class = target_vals[target_start..target_end]
224 .iter()
225 .enumerate()
226 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
227 .map(|(idx, _)| idx)
228 .unwrap_or(0);
229
230 if pred_class == target_class {
231 correct += 1;
232 }
233 total += 1;
234 }
235 }
236
237 let avg_loss = total_loss / num_batches as f32;
238 let accuracy = correct as f32 / total as f32;
239
240 Ok(Metrics::with_accuracy(avg_loss, accuracy))
241 }
242
243 pub fn fit(
245 &mut self,
246 train_dataset: &Dataset,
247 val_dataset: Option<&Dataset>,
248 epochs: usize,
249 verbose: bool,
250 ) -> Result<Vec<(Metrics, Option<Metrics>)>> {
251 let mut history = Vec::new();
252
253 for epoch in 0..epochs {
254 let train_metrics = self.train_epoch(train_dataset)?;
256
257 let val_metrics = if let Some(val_ds) = val_dataset {
259 Some(self.evaluate(val_ds)?)
260 } else {
261 None
262 };
263
264 if verbose {
265 print!("Epoch {}/{}: train_loss={:.4}", epoch + 1, epochs, train_metrics.loss);
266
267 if let Some(ref vm) = val_metrics {
268 print!(", val_loss={:.4}", vm.loss);
269 if let Some(acc) = vm.accuracy {
270 print!(", val_acc={:.4}", acc);
271 }
272 }
273 println!();
274 }
275
276 history.push((train_metrics, val_metrics));
277 }
278
279 Ok(history)
280 }
281}
282
283pub struct TrainerBuilder<M: Module> {
285 model: M,
286 learning_rate: f32,
287}
288
289impl<M: Module> TrainerBuilder<M> {
290 pub fn new(model: M) -> Self {
291 TrainerBuilder {
292 model,
293 learning_rate: 0.01,
294 }
295 }
296
297 pub fn learning_rate(mut self, lr: f32) -> Self {
298 self.learning_rate = lr;
299 self
300 }
301
302 pub fn build_sgd(self, loss_fn: Box<dyn LossFunction>) -> Trainer<M, crate::autograd::SGD> {
303 let params = self.model.parameters();
304 let optimizer = crate::autograd::SGD::new(params, self.learning_rate, 0.9, 0.0);
305 Trainer::new(self.model, optimizer, loss_fn)
306 }
307
308 pub fn build_adam(self, loss_fn: Box<dyn LossFunction>) -> Trainer<M, crate::autograd::Adam> {
309 let params = self.model.parameters();
310 let optimizer = crate::autograd::Adam::with_lr(params, self.learning_rate);
311 Trainer::new(self.model, optimizer, loss_fn)
312 }
313
314 pub fn build_with<O, F>(self, optimizer_factory: F, loss_fn: Box<dyn LossFunction>) -> Trainer<M, O>
316 where
317 O: Optimizer,
318 F: FnOnce(Vec<std::rc::Rc<std::cell::RefCell<crate::Tensor>>>, f32) -> O,
319 {
320 let params = self.model.parameters();
321 let optimizer = optimizer_factory(params, self.learning_rate);
322 Trainer::new(self.model, optimizer, loss_fn)
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Batching sanity check: 3 samples with batch_size 2 yield 2 batches,
    /// and the first batch is full-sized.
    #[test]
    fn test_dataset() -> Result<()> {
        // Same data as before, built programmatically: inputs are
        // [1,2], [3,4], [5,6]; targets are [0], [1], [2].
        let inputs: Vec<Vec<f32>> = (0..3)
            .map(|i| vec![(2 * i + 1) as f32, (2 * i + 2) as f32])
            .collect();
        let targets: Vec<Vec<f32>> = (0..3).map(|i| vec![i as f32]).collect();

        let dataset = Dataset::new(inputs, targets, 2);
        assert_eq!(dataset.num_batches(), 2);

        let (batch_inputs, batch_targets) = dataset.get_batch(0)?;
        assert_eq!(batch_inputs.shape(), &[2, 2]);
        assert_eq!(batch_targets.shape(), &[2, 1]);

        Ok(())
    }
}