1use std::collections::HashMap;
2
3use ferrotorch_core::{Device, FerrotorchError, FerrotorchResult, Float, Tensor};
4
5use crate::buffer::Buffer;
6use crate::hooks::{BackwardHook, ForwardHook, ForwardPreHook, HookHandle, HookedModule};
7use crate::parameter::Parameter;
8
/// Map from string keys to tensors.
///
/// In this file it is produced by [`Module::state_dict`] and consumed by
/// [`Module::load_state_dict`], where the keys are the names reported by
/// `named_parameters()` and `named_buffers()`.
pub type StateDict<T> = HashMap<String, Tensor<T>>;
11
/// Selector with three reduction modes.
///
/// NOTE(review): nothing in this file consumes the enum; presumably it mirrors
/// the `reduction` argument of loss functions defined elsewhere in the crate —
/// confirm against the callers before relying on the per-variant notes below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Reduction {
    /// Presumably: average the per-element values (TODO confirm).
    Mean,
    /// Presumably: sum the per-element values (TODO confirm).
    Sum,
    /// Presumably: apply no reduction and keep per-element values (TODO confirm).
    None,
}
22
23pub trait Module<T: Float>: Send + Sync {
27 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>;
29
30 fn parameters(&self) -> Vec<&Parameter<T>>;
32
33 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>;
35
36 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>;
41
42 fn train(&mut self);
44
45 fn eval(&mut self);
47
48 fn is_training(&self) -> bool;
50
51 fn to_device(&mut self, device: Device) -> FerrotorchResult<()> {
56 for param in self.parameters_mut() {
57 *param = param.to(device)?;
58 }
59 for buffer in self.buffers_mut() {
60 *buffer = buffer.to(device)?;
61 }
62 Ok(())
63 }
64
65 fn state_dict(&self) -> StateDict<T> {
70 let mut out: StateDict<T> = self
71 .named_parameters()
72 .into_iter()
73 .map(|(name, param)| (name, param.tensor().clone()))
74 .collect();
75 for (name, buffer) in self.named_buffers() {
76 out.insert(name, buffer.tensor().clone());
77 }
78 out
79 }
80
81 fn buffers(&self) -> Vec<&Buffer<T>> {
89 Vec::new()
90 }
91
92 fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> {
94 Vec::new()
95 }
96
97 fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> {
100 Vec::new()
101 }
102
103 fn as_any(&self) -> Option<&dyn std::any::Any> {
128 None
129 }
130
131 fn children(&self) -> Vec<&dyn Module<T>> {
137 Vec::new()
138 }
139
140 fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
143 Vec::new()
144 }
145
146 fn modules(&self) -> Vec<&dyn Module<T>>
153 where
154 Self: Sized,
155 {
156 let mut out: Vec<&dyn Module<T>> = vec![self];
157 out.extend(self.descendants_dyn());
158 out
159 }
160
161 fn descendants_dyn(&self) -> Vec<&dyn Module<T>> {
163 let mut out: Vec<&dyn Module<T>> = Vec::new();
164 for child in self.children() {
165 out.push(child);
166 out.extend(child.descendants_dyn());
167 }
168 out
169 }
170
171 fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
174 where
175 Self: Sized,
176 {
177 let mut out: Vec<(String, &dyn Module<T>)> = vec![(String::new(), self)];
178 out.extend(self.named_descendants_dyn());
179 out
180 }
181
182 fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)> {
184 let mut out: Vec<(String, &dyn Module<T>)> = Vec::new();
185 for (name, child) in self.named_children() {
186 out.push((name.clone(), child));
187 for (sub_name, sub_module) in child.named_descendants_dyn() {
188 let full = if sub_name.is_empty() {
189 name.clone()
190 } else if name.is_empty() {
191 sub_name
199 } else {
200 format!("{name}.{sub_name}")
201 };
202 out.push((full, sub_module));
203 }
204 }
205 out
206 }
207
208 fn with_forward_hook(self, hook: ForwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
234 where
235 Self: Sized,
236 {
237 let wrapped = HookedModule::new(self);
238 let handle = wrapped.register_forward_hook(hook);
239 (wrapped, handle)
240 }
241
242 fn with_forward_pre_hook(self, hook: ForwardPreHook<T>) -> (HookedModule<Self, T>, HookHandle)
246 where
247 Self: Sized,
248 {
249 let wrapped = HookedModule::new(self);
250 let handle = wrapped.register_forward_pre_hook(hook);
251 (wrapped, handle)
252 }
253
254 fn with_backward_hook(self, hook: BackwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
258 where
259 Self: Sized,
260 {
261 let wrapped = HookedModule::new(self);
262 let handle = wrapped.register_backward_hook(hook);
263 (wrapped, handle)
264 }
265
266 fn zero_grad(&self) -> FerrotorchResult<()> {
271 for param in self.parameters() {
272 param.tensor().zero_grad()?;
273 }
274 Ok(())
275 }
276
277 fn requires_grad_(&mut self, requires_grad: bool) {
280 for param in self.parameters_mut() {
281 param.set_requires_grad(requires_grad);
282 }
283 }
284
285 fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>)) {
294 for param in self.parameters_mut() {
295 f(param);
296 }
297 }
298
299 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
305 let mut known_keys: std::collections::HashSet<String> = self
307 .named_parameters()
308 .iter()
309 .map(|(k, _)| k.clone())
310 .collect();
311 for (k, _) in self.named_buffers() {
312 known_keys.insert(k);
313 }
314
315 if strict {
316 for key in state.keys() {
317 if !known_keys.contains(key) {
318 return Err(FerrotorchError::InvalidArgument {
319 message: format!("unexpected key in state_dict: \"{key}\""),
320 });
321 }
322 }
323 }
324
325 let param_names: Vec<String> = self
329 .named_parameters()
330 .into_iter()
331 .map(|(name, _)| name)
332 .collect();
333
334 let params_mut = self.parameters_mut();
335
336 for (name, param) in param_names.iter().zip(params_mut) {
337 if let Some(tensor) = state.get(name) {
338 if param.shape() != tensor.shape() {
339 return Err(FerrotorchError::ShapeMismatch {
340 message: format!(
341 "state_dict shape mismatch for \"{name}\": expected {:?}, got {:?}",
342 param.shape(),
343 tensor.shape()
344 ),
345 });
346 }
347 *param = Parameter::new(tensor.clone());
349 } else if strict {
350 return Err(FerrotorchError::InvalidArgument {
351 message: format!("missing key in state_dict: \"{name}\""),
352 });
353 }
354 }
355
356 let buffer_names: Vec<String> = self
358 .named_buffers()
359 .into_iter()
360 .map(|(name, _)| name)
361 .collect();
362 let buffers_mut = self.buffers_mut();
363 for (name, buf) in buffer_names.iter().zip(buffers_mut) {
364 if let Some(tensor) = state.get(name) {
365 if buf.shape() != tensor.shape() {
366 return Err(FerrotorchError::ShapeMismatch {
367 message: format!(
368 "state_dict shape mismatch for buffer \"{name}\": expected {:?}, got {:?}",
369 buf.shape(),
370 tensor.shape()
371 ),
372 });
373 }
374 *buf = Buffer::new(tensor.clone());
375 } else if strict {
376 return Err(FerrotorchError::InvalidArgument {
377 message: format!("missing buffer key in state_dict: \"{name}\""),
378 });
379 }
380 }
381
382 Ok(())
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaf module used by the tests: one parameter and a training flag.
    struct SimpleModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> SimpleModule<T> {
        /// Builds a module whose single weight is a zero vector of length `size`.
        fn new(size: usize) -> FerrotorchResult<Self> {
            let weight = Parameter::zeros(&[size])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for SimpleModule<T> {
        // Identity forward pass: returns a clone of the input.
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![(String::from("weight"), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_module_parameters() {
        let module = SimpleModule::<f32>::new(5).unwrap();
        let params = module.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].shape(), &[5]);
    }

    #[test]
    fn test_module_named_parameters() {
        let module = SimpleModule::<f32>::new(3).unwrap();
        let entries = module.named_parameters();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].0, "weight");
    }

    #[test]
    fn test_module_train_eval() {
        let mut module = SimpleModule::<f32>::new(2).unwrap();
        assert!(module.is_training());

        module.eval();
        assert!(!module.is_training());

        module.train();
        assert!(module.is_training());
    }

    #[test]
    fn test_module_state_dict_roundtrip() {
        let source = SimpleModule::<f32>::new(4).unwrap();
        let snapshot = source.state_dict();
        assert!(snapshot.contains_key("weight"));
        assert_eq!(snapshot["weight"].shape(), &[4]);

        let mut target = SimpleModule::<f32>::new(4).unwrap();
        target.load_state_dict(&snapshot, true).unwrap();
    }

    #[test]
    fn test_module_state_dict_strict_extra_key() {
        let mut module = SimpleModule::<f32>::new(3).unwrap();
        let state: StateDict<f32> = HashMap::from([
            (
                "weight".to_string(),
                ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
            ),
            (
                "extra".to_string(),
                ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
            ),
        ]);

        // The unknown key is rejected in strict mode and ignored otherwise.
        assert!(module.load_state_dict(&state, true).is_err());
        assert!(module.load_state_dict(&state, false).is_ok());
    }

    #[test]
    fn test_module_state_dict_shape_mismatch() {
        let mut module = SimpleModule::<f32>::new(3).unwrap();
        let state: StateDict<f32> = HashMap::from([(
            "weight".to_string(),
            ferrotorch_core::zeros::<f32>(&[5]).unwrap(),
        )]);

        assert!(module.load_state_dict(&state, true).is_err());
    }

    #[test]
    fn test_module_is_send_sync() {
        // Compile-time check only: the call type-checks iff the bounds hold.
        fn require_send_sync<M: Send + Sync>() {}
        require_send_sync::<SimpleModule<f32>>();
    }

    #[test]
    fn test_reduction_enum() {
        assert_eq!(Reduction::Mean, Reduction::Mean);
        assert_ne!(Reduction::Mean, Reduction::Sum);
    }

    #[test]
    fn test_to_device_cpu_preserves_weights() {
        let mut module = SimpleModule::<f32>::new(4).unwrap();
        module.to_device(ferrotorch_core::Device::Cpu).unwrap();

        let params = module.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].shape(), &[4]);
    }

    #[test]
    fn test_to_device_cuda_without_backend() {
        let mut module = SimpleModule::<f32>::new(3).unwrap();
        let outcome = module.to_device(ferrotorch_core::Device::Cuda(0));
        assert!(outcome.is_err());
    }

    /// Composite module used by the tests: its own weight, one buffer, and a
    /// `SimpleModule` child.
    struct ParentModule<T: Float> {
        weight: Parameter<T>,
        running_mean: Buffer<T>,
        child: SimpleModule<T>,
    }

    impl<T: Float> ParentModule<T> {
        fn new() -> FerrotorchResult<Self> {
            let weight = Parameter::ones(&[2, 2])?;
            let running_mean = Buffer::zeros(&[2])?;
            let child = SimpleModule::new(3)?;
            Ok(Self {
                weight,
                running_mean,
                child,
            })
        }
    }

    impl<T: Float> Module<T> for ParentModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            self.child.forward(input)
        }

        // Own weight first, then the child's parameters.
        fn parameters(&self) -> Vec<&Parameter<T>> {
            std::iter::once(&self.weight)
                .chain(self.child.parameters())
                .collect()
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            std::iter::once(&mut self.weight)
                .chain(self.child.parameters_mut())
                .collect()
        }

        // Child parameter names are prefixed with "child.".
        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            let own = std::iter::once(("weight".to_string(), &self.weight));
            let nested = self
                .child
                .named_parameters()
                .into_iter()
                .map(|(n, p)| (format!("child.{n}"), p));
            own.chain(nested).collect()
        }

        fn buffers(&self) -> Vec<&Buffer<T>> {
            vec![&self.running_mean]
        }

        fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> {
            vec![&mut self.running_mean]
        }

        fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> {
            vec![(String::from("running_mean"), &self.running_mean)]
        }

        fn children(&self) -> Vec<&dyn Module<T>> {
            vec![&self.child]
        }

        fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
            vec![(String::from("child"), &self.child)]
        }

        // Training state is delegated entirely to the child.
        fn train(&mut self) {
            self.child.train();
        }

        fn eval(&mut self) {
            self.child.eval();
        }

        fn is_training(&self) -> bool {
            self.child.is_training()
        }
    }

    #[test]
    fn module_buffers_default_is_empty() {
        let leaf = SimpleModule::<f32>::new(3).unwrap();
        assert!(leaf.buffers().is_empty());
        assert!(leaf.named_buffers().is_empty());
    }

    #[test]
    fn module_buffers_listed_for_overriding_module() {
        let parent = ParentModule::<f32>::new().unwrap();

        let bufs = parent.buffers();
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0].shape(), &[2]);

        let named = parent.named_buffers();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "running_mean");
    }

    #[test]
    fn module_children_listed_for_parent() {
        let parent = ParentModule::<f32>::new().unwrap();
        assert_eq!(parent.children().len(), 1);

        let named = parent.named_children();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "child");
    }

    #[test]
    fn module_named_modules_includes_self_and_descendants() {
        let parent = ParentModule::<f32>::new().unwrap();
        let listing = parent.named_modules();
        assert_eq!(listing.len(), 2);
        assert_eq!(listing[0].0, "");
        assert_eq!(listing[1].0, "child");
    }

    #[test]
    fn module_modules_includes_self_and_descendants() {
        let parent = ParentModule::<f32>::new().unwrap();
        assert_eq!(parent.modules().len(), 2);
    }

    #[test]
    fn module_zero_grad_succeeds() {
        let leaf = SimpleModule::<f32>::new(3).unwrap();
        leaf.zero_grad().unwrap();
    }

    #[test]
    fn module_requires_grad_toggles_all_parameters() {
        let mut parent = ParentModule::<f32>::new().unwrap();
        for p in parent.parameters() {
            assert!(p.requires_grad());
        }

        parent.requires_grad_(false);
        for p in parent.parameters() {
            assert!(!p.requires_grad());
        }

        parent.requires_grad_(true);
        for p in parent.parameters() {
            assert!(p.requires_grad());
        }
    }

    #[test]
    fn module_apply_to_parameters_visits_all() {
        let mut parent = ParentModule::<f32>::new().unwrap();
        let expected = parent.parameters().len();

        let mut visited = 0;
        parent.apply_to_parameters(&mut |_p| visited += 1);
        assert_eq!(visited, expected);
    }

    #[test]
    fn module_state_dict_includes_buffers() {
        let parent = ParentModule::<f32>::new().unwrap();
        let snapshot = parent.state_dict();
        assert!(snapshot.contains_key("weight"));
        assert!(snapshot.contains_key("running_mean"));
        assert!(snapshot.contains_key("child.weight"));
        assert_eq!(snapshot.len(), 3);
    }

    #[test]
    fn module_load_state_dict_with_buffer() {
        let mut parent = ParentModule::<f32>::new().unwrap();
        let state: StateDict<f32> = HashMap::from([
            (
                String::from("weight"),
                ferrotorch_core::ones::<f32>(&[2, 2]).unwrap(),
            ),
            (
                String::from("running_mean"),
                ferrotorch_core::from_slice::<f32>(&[7.0, 9.0], &[2]).unwrap(),
            ),
            (
                String::from("child.weight"),
                ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
            ),
        ]);

        parent.load_state_dict(&state, true).unwrap();
        assert_eq!(parent.buffers()[0].data().unwrap(), &[7.0, 9.0]);
    }

    #[test]
    fn module_descendants_dyn_excludes_self() {
        let parent = ParentModule::<f32>::new().unwrap();
        assert_eq!(parent.descendants_dyn().len(), 1);
    }

    #[test]
    fn module_named_descendants_dyn_paths() {
        let parent = ParentModule::<f32>::new().unwrap();
        let listing = parent.named_descendants_dyn();
        assert_eq!(listing.len(), 1);
        assert_eq!(listing[0].0, "child");
    }

    /// A wrapper that registers its inner module under the empty name must not
    /// make descendant paths start with a dot.
    #[test]
    fn module_named_descendants_dyn_empty_parent_no_leading_dot() {
        struct TransparentWrapper<T: Float> {
            inner: ParentModule<T>,
        }

        impl<T: Float> Module<T> for TransparentWrapper<T> {
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                self.inner.forward(input)
            }

            fn parameters(&self) -> Vec<&Parameter<T>> {
                self.inner.parameters()
            }

            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                self.inner.parameters_mut()
            }

            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                self.inner.named_parameters()
            }

            fn children(&self) -> Vec<&dyn Module<T>> {
                vec![&self.inner]
            }

            // The inner module is exposed under the empty name on purpose.
            fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
                vec![(String::new(), &self.inner)]
            }

            fn train(&mut self) {
                self.inner.train();
            }

            fn eval(&mut self) {
                self.inner.eval();
            }

            fn is_training(&self) -> bool {
                self.inner.is_training()
            }
        }

        let wrapper = TransparentWrapper::<f32> {
            inner: ParentModule::new().unwrap(),
        };

        let mut nd: Vec<String> = Vec::new();
        for (path, _) in wrapper.named_descendants_dyn() {
            nd.push(path);
        }

        assert_eq!(nd, vec![String::new(), "child".to_string()]);
        for p in &nd {
            assert!(
                !p.starts_with('.'),
                "transparent-wrapper descendant path '{p}' starts with '.'; \
                 the empty-parent branch in named_descendants_dyn has regressed",
            );
        }
    }

    #[test]
    fn with_forward_hook_wraps_and_fires() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_hook = Arc::clone(&calls);

        let module = SimpleModule::<f32>::new(2).unwrap();
        let (wrapped, _handle) = module.with_forward_hook(Box::new(move |_input, _output| {
            calls_in_hook.fetch_add(1, Ordering::SeqCst);
        }));

        let storage = ferrotorch_core::TensorStorage::cpu(vec![1.0_f32, 2.0]);
        let input = ferrotorch_core::Tensor::from_storage(storage, vec![2], false).unwrap();

        let _ = wrapped.forward(&input).unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn with_forward_pre_hook_wraps_and_fires() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_hook = Arc::clone(&calls);

        let module = SimpleModule::<f32>::new(2).unwrap();
        let (wrapped, _handle) = module.with_forward_pre_hook(Box::new(move |input| {
            calls_in_hook.fetch_add(1, Ordering::SeqCst);
            Ok(input.clone())
        }));

        let storage = ferrotorch_core::TensorStorage::cpu(vec![1.0_f32, 2.0]);
        let input = ferrotorch_core::Tensor::from_storage(storage, vec![2], false).unwrap();

        let _ = wrapped.forward(&input).unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn with_backward_hook_returns_handle() {
        let module = SimpleModule::<f32>::new(2).unwrap();
        let (wrapped, handle) = module.with_backward_hook(Box::new(|_gi, _go| {}));

        let storage = ferrotorch_core::TensorStorage::cpu(vec![3.0_f32]);
        let input = ferrotorch_core::Tensor::from_storage(storage, vec![1], false).unwrap();

        // The forward pass still works with a backward hook registered, and the
        // handle can be removed afterwards.
        let _ = wrapped.forward(&input).unwrap();
        handle.remove();
    }
}