1use crate::{Array, Tensor};
9use anyhow::{anyhow, Result};
10use std::cell::RefCell;
11use std::rc::Rc;
12
/// Common interface for neural-network layers and containers.
pub trait Module {
    /// Runs the forward pass on `input` and returns the output tensor.
    fn forward(&self, input: &Tensor) -> Result<Tensor>;

    /// Returns shared handles to every trainable tensor owned by this module.
    fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>>;

    /// Switches the module into training mode. Default is a no-op; layers
    /// with mode-dependent behavior (e.g. `Dropout`, `BatchNorm1d`) override it.
    fn train(&mut self) {
    }

    /// Switches the module into evaluation mode. Default is a no-op.
    fn eval(&mut self) {
    }

    /// Clones this module behind a fresh `Box`. Exists because `Clone` is not
    /// object-safe; `impl Clone for Box<dyn Module>` delegates here.
    fn box_clone(&self) -> Box<dyn Module>;
}
43
/// Makes boxed modules cloneable by delegating to `Module::box_clone`.
impl Clone for Box<dyn Module> {
    fn clone(&self) -> Box<dyn Module> {
        self.box_clone()
    }
}
49
50#[derive(Clone)]
58pub struct Linear {
59 weight: Rc<RefCell<Tensor>>, bias: Rc<RefCell<Tensor>>, in_features: usize,
62 out_features: usize,
63}
64
65impl Linear {
66 pub fn new(in_features: usize, out_features: usize) -> Result<Self> {
68 let k = (1.0 / in_features as f32).sqrt();
70
71 let mut weight_data = Vec::with_capacity(in_features * out_features);
74 for i in 0..(in_features * out_features) {
75 let val = (i as f32 * 0.123).sin() * k;
77 weight_data.push(val);
78 }
79
80 let weight = Rc::new(RefCell::new(Tensor::new(
81 Array::new(vec![out_features, in_features], weight_data),
82 true, )));
84
85 let bias = Rc::new(RefCell::new(Tensor::new(
86 Array::new(vec![out_features], vec![0.0; out_features]),
87 true,
88 )));
89
90 Ok(Linear {
91 weight,
92 bias,
93 in_features,
94 out_features,
95 })
96 }
97
98 pub fn weight(&self) -> Rc<RefCell<Tensor>> {
99 self.weight.clone()
100 }
101
102 pub fn bias(&self) -> Rc<RefCell<Tensor>> {
103 self.bias.clone()
104 }
105
106 pub fn with_weights(
108 in_features: usize,
109 out_features: usize,
110 weight_data: Vec<f32>,
111 bias_data: Vec<f32>,
112 ) -> Result<Self> {
113 let weight = Rc::new(RefCell::new(Tensor::new(
114 Array::new(vec![out_features, in_features], weight_data),
115 true,
116 )));
117
118 let bias = Rc::new(RefCell::new(Tensor::new(
119 Array::new(vec![out_features], bias_data),
120 true,
121 )));
122
123 Ok(Linear {
124 weight,
125 bias,
126 in_features,
127 out_features,
128 })
129 }
130
131 pub fn parameters_mut(&self) -> Vec<Rc<RefCell<Tensor>>> {
133 vec![self.weight.clone(), self.bias.clone()]
134 }
135
136 pub fn backward(&self, input: &Tensor, grad_output: &Array) -> Result<()> {
138 let batch_size = input.data.shape[0];
139
140 {
142 let weight_shape;
143 let weight_data_len;
144 {
145 let weight = self.weight.borrow();
146 weight_shape = weight.data.shape.clone();
147 weight_data_len = weight.data.data.len();
148 }
149
150 let mut weight = self.weight.borrow_mut();
151
152 if weight.grad.is_none() {
153 weight.grad = Some(Rc::new(RefCell::new(Array::new(
154 weight_shape.clone(),
155 vec![0.0; weight_data_len],
156 ))));
157 }
158
159 let grad_weight = weight
160 .grad
161 .as_ref()
162 .ok_or_else(|| anyhow!("Failed to initialize weight gradient"))?;
163 let mut grad_w = grad_weight.borrow_mut();
164
165 for i in 0..self.out_features {
166 for j in 0..self.in_features {
167 let mut sum = 0.0;
168 for b in 0..batch_size {
169 let grad_out_val = grad_output.data[b * self.out_features + i];
170 let input_val = input.data.data[b * self.in_features + j];
171 sum += grad_out_val * input_val;
172 }
173 grad_w.data[i * self.in_features + j] += sum;
174 }
175 }
176 }
177
178 {
180 let bias_shape;
181 let bias_data_len;
182 {
183 let bias = self.bias.borrow();
184 bias_shape = bias.data.shape.clone();
185 bias_data_len = bias.data.data.len();
186 }
187
188 let mut bias = self.bias.borrow_mut();
189
190 if bias.grad.is_none() {
191 bias.grad = Some(Rc::new(RefCell::new(Array::new(
192 bias_shape.clone(),
193 vec![0.0; bias_data_len],
194 ))));
195 }
196
197 let grad_bias = bias
198 .grad
199 .as_ref()
200 .ok_or_else(|| anyhow!("Failed to initialize bias gradient"))?;
201 let mut grad_b = grad_bias.borrow_mut();
202
203 for i in 0..self.out_features {
204 let mut sum = 0.0;
205 for b in 0..batch_size {
206 sum += grad_output.data[b * self.out_features + i];
207 }
208 grad_b.data[i] += sum;
209 }
210 }
211
212 Ok(())
213 }
214}
215
216impl Module for Linear {
217 fn forward(&self, input: &Tensor) -> Result<Tensor> {
218 let w = self.weight.borrow();
225 let b = self.bias.borrow();
226
227 let w_t = w.transpose()?;
229
230 let output = input.matmul(&w_t)?;
232
233 output.add(&*b)
235 }
236
237 fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>> {
238 vec![self.weight.clone(), self.bias.clone()]
239 }
240
241 fn box_clone(&self) -> Box<dyn Module> {
242 Box::new(self.clone())
243 }
244}
245
/// Container that applies its layers in order, feeding each output into the
/// next layer's input.
#[derive(Clone)]
pub struct Sequential {
    // Layers applied in insertion order during `forward`.
    layers: Vec<Box<dyn Module>>,
}
260
261impl Sequential {
262 pub fn new(layers: Vec<Box<dyn Module>>) -> Self {
263 Sequential { layers }
264 }
265
266 pub fn add(&mut self, layer: Box<dyn Module>) {
267 self.layers.push(layer);
268 }
269}
270
271impl Module for Sequential {
272 fn forward(&self, input: &Tensor) -> Result<Tensor> {
273 let mut output = input.clone();
274 for layer in &self.layers {
275 output = layer.forward(&output)?;
276 }
277 Ok(output)
278 }
279
280 fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>> {
281 self.layers
282 .iter()
283 .flat_map(|layer| layer.parameters())
284 .collect()
285 }
286
287 fn train(&mut self) {
288 for layer in &mut self.layers {
289 layer.train();
290 }
291 }
292
293 fn eval(&mut self) {
294 for layer in &mut self.layers {
295 layer.eval();
296 }
297 }
298
299 fn box_clone(&self) -> Box<dyn Module> {
300 Box::new(self.clone())
301 }
302}
303
/// Rectified-linear activation layer; stateless, forward delegates to
/// `Tensor::relu`.
#[derive(Clone)]
pub struct ReLU;
307
308impl Module for ReLU {
309 fn forward(&self, input: &Tensor) -> Result<Tensor> {
310 input.relu()
311 }
312
313 fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>> {
314 vec![] }
316
317 fn box_clone(&self) -> Box<dyn Module> {
318 Box::new(self.clone())
319 }
320}
321
/// Sigmoid activation layer; stateless, forward delegates to
/// `Tensor::sigmoid`.
#[derive(Clone)]
pub struct Sigmoid;
325
326impl Module for Sigmoid {
327 fn forward(&self, input: &Tensor) -> Result<Tensor> {
328 input.sigmoid()
329 }
330
331 fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>> {
332 vec![]
333 }
334
335 fn box_clone(&self) -> Box<dyn Module> {
336 Box::new(self.clone())
337 }
338}
339
/// Softmax activation over the last axis; stateless, autograd-aware forward.
#[derive(Clone)]
pub struct Softmax;
346
347impl Module for Softmax {
348 fn forward(&self, input: &Tensor) -> Result<Tensor> {
349 let axis = input.data.shape.len() - 1;
353 let result = crate::ops::stats::softmax::softmax(&input.data, Some(axis))?;
354
355 if !crate::autograd::is_grad_enabled() || !input.requires_grad {
356 return Ok(Tensor::new(result, false));
357 }
358
359 let backward_fn = Box::new(crate::autograd::backward::softmax_backward);
360 let node = crate::autograd::ComputeNode::new(
361 crate::autograd::OpKind::Softmax,
362 vec![input.clone()],
363 Some(backward_fn),
364 );
365 Ok(Tensor::from_operation(result, node, true))
366 }
367
368 fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>> {
369 vec![]
370 }
371
372 fn box_clone(&self) -> Box<dyn Module> {
373 Box::new(self.clone())
374 }
375}
376
/// Dropout layer; forward delegates to `Tensor::dropout` with the stored
/// probability and the current training flag.
#[derive(Clone)]
pub struct Dropout {
    // Drop probability handed to `Tensor::dropout`; presumably in [0, 1] —
    // not validated here (TODO confirm upstream).
    p: f32,
    // Mode flag toggled by `train()` / `eval()`.
    training: bool,
}
383
384impl Dropout {
385 pub fn new(p: f32) -> Self {
386 Dropout { p, training: true }
387 }
388}
389
390impl Module for Dropout {
391 fn forward(&self, input: &Tensor) -> Result<Tensor> {
392 input.dropout(self.p, self.training)
393 }
394
395 fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>> {
396 vec![]
397 }
398
399 fn train(&mut self) {
400 self.training = true;
401 }
402
403 fn eval(&mut self) {
404 self.training = false;
405 }
406
407 fn box_clone(&self) -> Box<dyn Module> {
408 Box::new(self.clone())
409 }
410}
411
/// Reshaping layer; stores the dimension range forwarded to `Tensor::flatten`.
#[derive(Clone)]
pub struct Flatten {
    // First dimension to flatten.
    start_dim: usize,
    // Last dimension to flatten (inclusive/exclusive semantics defined by
    // `Tensor::flatten`).
    end_dim: usize,
}
418
419impl Flatten {
420 pub fn new(start_dim: usize, end_dim: usize) -> Self {
421 Flatten { start_dim, end_dim }
422 }
423}
424
425impl Module for Flatten {
426 fn forward(&self, input: &Tensor) -> Result<Tensor> {
427 input.flatten(self.start_dim, self.end_dim)
428 }
429
430 fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>> {
431 vec![]
432 }
433
434 fn box_clone(&self) -> Box<dyn Module> {
435 Box::new(self.clone())
436 }
437}
438
439#[derive(Clone)]
441pub struct Conv1d {
442 weight: Rc<RefCell<Tensor>>, bias: Rc<RefCell<Tensor>>, stride: usize,
445 padding: usize,
446 _in_channels: usize,
447 _out_channels: usize,
448 _kernel_size: usize,
449}
450
451impl Conv1d {
452 pub fn new(
453 in_channels: usize,
454 out_channels: usize,
455 kernel_size: usize,
456 stride: usize,
457 padding: usize,
458 ) -> Result<Self> {
459 let k = (1.0 / (in_channels * kernel_size) as f32).sqrt();
461
462 let mut weight_data = Vec::with_capacity(out_channels * in_channels * kernel_size);
463 for _ in 0..(out_channels * in_channels * kernel_size) {
464 weight_data.push(0.01); }
469
470 let weight_data: Vec<f32> = (0..(out_channels * in_channels * kernel_size))
474 .map(|i| ((i as f32 * 0.123).sin()) * k)
475 .collect();
476
477 let weight = Rc::new(RefCell::new(Tensor::new(
478 Array::new(vec![out_channels, in_channels, kernel_size], weight_data),
479 true,
480 )));
481
482 let bias = Rc::new(RefCell::new(Tensor::new(
483 Array::new(vec![out_channels], vec![0.0; out_channels]),
484 true,
485 )));
486
487 Ok(Conv1d {
488 weight,
489 bias,
490 stride,
491 padding,
492 _in_channels: in_channels,
493 _out_channels: out_channels,
494 _kernel_size: kernel_size,
495 })
496 }
497
498 pub fn weight(&self) -> Rc<RefCell<Tensor>> {
499 self.weight.clone()
500 }
501
502 pub fn bias(&self) -> Rc<RefCell<Tensor>> {
503 self.bias.clone()
504 }
505}
506
507impl Module for Conv1d {
508 fn forward(&self, input: &Tensor) -> Result<Tensor> {
509 let w = self.weight.borrow();
510 let b = self.bias.borrow();
511 input.conv1d(&*w, Some(&*b), self.stride, self.padding)
512 }
513
514 fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>> {
515 vec![self.weight.clone(), self.bias.clone()]
516 }
517
518 fn box_clone(&self) -> Box<dyn Module> {
519 Box::new(self.clone())
520 }
521}
522
523#[derive(Clone)]
525pub struct BatchNorm1d {
526 weight: Rc<RefCell<Tensor>>, bias: Rc<RefCell<Tensor>>, running_mean: Rc<RefCell<Tensor>>,
529 running_var: Rc<RefCell<Tensor>>,
530 momentum: f32,
531 eps: f32,
532 training: bool,
533}
534
535impl BatchNorm1d {
536 pub fn new(num_features: usize) -> Result<Self> {
537 let weight = Rc::new(RefCell::new(Tensor::new(
539 Array::new(vec![num_features], vec![1.0; num_features]),
540 true,
541 )));
542
543 let bias = Rc::new(RefCell::new(Tensor::new(
545 Array::new(vec![num_features], vec![0.0; num_features]),
546 true,
547 )));
548
549 let running_mean = Rc::new(RefCell::new(Tensor::new(
551 Array::new(vec![num_features], vec![0.0; num_features]),
552 true, )));
554
555 let running_var = Rc::new(RefCell::new(Tensor::new(
556 Array::new(vec![num_features], vec![1.0; num_features]),
557 true, )));
559
560 Ok(BatchNorm1d {
561 weight,
562 bias,
563 running_mean,
564 running_var,
565 momentum: 0.1,
566 eps: 1e-5,
567 training: true,
568 })
569 }
570
571 pub fn running_mean(&self) -> Rc<RefCell<Tensor>> {
572 self.running_mean.clone()
573 }
574}
575
576impl Module for BatchNorm1d {
577 fn forward(&self, input: &Tensor) -> Result<Tensor> {
578 let w = self.weight.borrow();
579 let b = self.bias.borrow();
580 let mut rm = self.running_mean.borrow_mut();
581 let mut rv = self.running_var.borrow_mut();
582
583 input.batch_norm(
584 &mut *rm,
585 &mut *rv,
586 &*w,
587 &*b,
588 self.training,
589 self.momentum,
590 self.eps,
591 )
592 }
593
594 fn parameters(&self) -> Vec<Rc<RefCell<Tensor>>> {
595 vec![self.weight.clone(), self.bias.clone()]
596 }
597
598 fn train(&mut self) {
599 self.training = true;
600 }
601
602 fn eval(&mut self) {
603 self.training = false;
604 }
605
606 fn box_clone(&self) -> Box<dyn Module> {
607 Box::new(self.clone())
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// A 3x2 layer with known weights maps [1, 2] to W·x + b exactly.
    #[test]
    fn test_linear_forward() -> Result<()> {
        let weights = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let biases = vec![0.1, 0.2, 0.3];
        let linear = Linear::with_weights(2, 3, weights, biases)?;

        let input = Tensor::new(Array::new(vec![1, 2], vec![1.0, 2.0]), false);
        let output = linear.forward(&input)?;

        assert_eq!(output.shape(), &[1, 3]);
        let vals = output.values();
        let expected = [5.1, 11.2, 17.3];
        for (got, want) in vals.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }

        Ok(())
    }

    /// An identity linear layer followed by ReLU keeps the output shape.
    #[test]
    fn test_sequential() -> Result<()> {
        let identity =
            Linear::with_weights(2, 2, vec![1.0, 0.0, 0.0, 1.0], vec![1.0, 2.0])?;
        let model = Sequential::new(vec![Box::new(identity), Box::new(ReLU)]);

        let input = Tensor::new(Array::new(vec![1, 2], vec![1.0, 2.0]), false);
        let output = model.forward(&input)?;

        assert_eq!(output.shape(), &[1, 2]);

        Ok(())
    }
}