Skip to main content

ferrotorch_nn/
module.rs

1use std::collections::HashMap;
2
3use ferrotorch_core::{Device, FerrotorchError, FerrotorchResult, Float, Tensor};
4
5use crate::parameter::Parameter;
6
7/// A map from parameter names to tensors, used for serialization.
8pub type StateDict<T> = HashMap<String, Tensor<T>>;
9
/// How a loss function collapses its individual losses into a result.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Reduction {
    /// Average over all losses.
    Mean,
    /// Add up all losses.
    Sum,
    /// Apply no reduction; the loss tensor is returned unreduced.
    None,
}
20
21/// The trait that all neural network layers implement.
22///
23/// Requires `Send + Sync` to match `Tensor<T>`'s thread-safety guarantees.
24pub trait Module<T: Float>: Send + Sync {
25    /// Forward pass. Takes input tensor, returns output tensor.
26    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>;
27
28    /// Iterate over all learnable parameters.
29    fn parameters(&self) -> Vec<&Parameter<T>>;
30
31    /// Iterate over all learnable parameters mutably.
32    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>;
33
34    /// Named parameters for state dict serialization.
35    ///
36    /// Keys use dot-separated paths for nested modules
37    /// (e.g., `"layer1.weight"`, `"layer1.bias"`).
38    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>;
39
40    /// Set training mode. Affects dropout, batchnorm, etc.
41    fn train(&mut self);
42
43    /// Set evaluation mode.
44    fn eval(&mut self);
45
46    /// Whether the module is in training mode.
47    fn is_training(&self) -> bool;
48
49    /// Move all parameters to a device.
50    ///
51    /// Default implementation iterates `parameters_mut()` and transfers each.
52    fn to_device(&mut self, device: Device) -> FerrotorchResult<()> {
53        for param in self.parameters_mut() {
54            *param = param.to(device)?;
55        }
56        Ok(())
57    }
58
59    /// Export parameters as a state dict.
60    fn state_dict(&self) -> StateDict<T> {
61        self.named_parameters()
62            .into_iter()
63            .map(|(name, param)| {
64                // Clone the tensor data (not just the Arc) for serialization.
65                let data = param.tensor().clone();
66                (name, data)
67            })
68            .collect()
69    }
70
71    /// Load parameters from a state dict.
72    ///
73    /// When `strict` is `true` (default), unexpected keys are an error.
74    /// When `false`, unexpected keys are silently ignored and missing
75    /// keys leave existing parameter values unchanged.
76    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
77        let named = self.named_parameters();
78        let known_keys: std::collections::HashSet<&str> =
79            named.iter().map(|(k, _)| k.as_str()).collect();
80
81        if strict {
82            for key in state.keys() {
83                if !known_keys.contains(key.as_str()) {
84                    return Err(FerrotorchError::InvalidArgument {
85                        message: format!("unexpected key in state_dict: \"{key}\""),
86                    });
87                }
88            }
89        }
90
91        // We need mutable access to parameters. Use named_parameters to get
92        // the mapping, then parameters_mut to actually update.
93        // This two-pass approach avoids borrowing issues.
94        let param_names: Vec<String> = self
95            .named_parameters()
96            .into_iter()
97            .map(|(name, _)| name)
98            .collect();
99
100        let params_mut = self.parameters_mut();
101
102        for (name, param) in param_names.iter().zip(params_mut.into_iter()) {
103            if let Some(tensor) = state.get(name) {
104                if param.shape() != tensor.shape() {
105                    return Err(FerrotorchError::ShapeMismatch {
106                        message: format!(
107                            "state_dict shape mismatch for \"{name}\": expected {:?}, got {:?}",
108                            param.shape(),
109                            tensor.shape()
110                        ),
111                    });
112                }
113                // Replace the parameter data with the loaded tensor.
114                *param = Parameter::new(tensor.clone());
115            } else if strict {
116                return Err(FerrotorchError::InvalidArgument {
117                    message: format!("missing key in state_dict: \"{name}\""),
118                });
119            }
120        }
121
122        Ok(())
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal test module with one parameter.
    struct SimpleModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> SimpleModule<T> {
        /// Builds a module whose single `weight` parameter is a zero tensor of
        /// shape `[size]`, starting in training mode.
        fn new(size: usize) -> FerrotorchResult<Self> {
            Ok(Self {
                weight: Parameter::zeros(&[size])?,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for SimpleModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            // Just return input for testing.
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".to_string(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_module_parameters() {
        let m = SimpleModule::<f32>::new(5).unwrap();
        assert_eq!(m.parameters().len(), 1);
        assert_eq!(m.parameters()[0].shape(), &[5]);
    }

    #[test]
    fn test_module_named_parameters() {
        let m = SimpleModule::<f32>::new(3).unwrap();
        let named = m.named_parameters();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "weight");
    }

    #[test]
    fn test_module_train_eval() {
        let mut m = SimpleModule::<f32>::new(2).unwrap();
        assert!(m.is_training());
        m.eval();
        assert!(!m.is_training());
        m.train();
        assert!(m.is_training());
    }

    #[test]
    fn test_module_state_dict_roundtrip() {
        let m = SimpleModule::<f32>::new(4).unwrap();
        let sd = m.state_dict();
        assert!(sd.contains_key("weight"));
        assert_eq!(sd["weight"].shape(), &[4]);

        let mut m2 = SimpleModule::<f32>::new(4).unwrap();
        m2.load_state_dict(&sd, true).unwrap();
        // The loaded parameter keeps the expected shape.
        assert_eq!(m2.parameters()[0].shape(), &[4]);
    }

    #[test]
    fn test_module_state_dict_strict_extra_key() {
        let mut m = SimpleModule::<f32>::new(3).unwrap();
        let mut sd = HashMap::new();
        sd.insert(
            "weight".to_string(),
            ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
        );
        sd.insert(
            "extra".to_string(),
            ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
        );

        // Strict mode rejects the unknown "extra" key; non-strict ignores it.
        assert!(m.load_state_dict(&sd, true).is_err());
        assert!(m.load_state_dict(&sd, false).is_ok());
    }

    #[test]
    fn test_module_state_dict_missing_key() {
        let mut m = SimpleModule::<f32>::new(3).unwrap();
        let sd: StateDict<f32> = HashMap::new();

        // Strict mode requires an entry for every parameter.
        assert!(m.load_state_dict(&sd, true).is_err());
        // Non-strict mode leaves parameters without an entry as they were.
        assert!(m.load_state_dict(&sd, false).is_ok());
        assert_eq!(m.parameters().len(), 1);
        assert_eq!(m.parameters()[0].shape(), &[3]);
    }

    #[test]
    fn test_module_state_dict_shape_mismatch() {
        let mut m = SimpleModule::<f32>::new(3).unwrap();
        let mut sd = HashMap::new();
        sd.insert(
            "weight".to_string(),
            ferrotorch_core::zeros::<f32>(&[5]).unwrap(),
        );

        assert!(m.load_state_dict(&sd, true).is_err());
    }

    #[test]
    fn test_module_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SimpleModule<f32>>();
    }

    #[test]
    fn test_reduction_enum() {
        assert_eq!(Reduction::Mean, Reduction::Mean);
        assert_ne!(Reduction::Mean, Reduction::Sum);
    }

    #[test]
    fn test_to_device_cpu_preserves_weights() {
        let mut m = SimpleModule::<f32>::new(4).unwrap();
        m.to_device(ferrotorch_core::Device::Cpu).unwrap();
        assert_eq!(m.parameters().len(), 1);
        assert_eq!(m.parameters()[0].shape(), &[4]);
    }

    #[test]
    fn test_to_device_cuda_without_backend() {
        // NOTE(review): presumably this test runs with no CUDA backend
        // available, so the transfer is expected to fail — confirm it is
        // skipped or adjusted in GPU-enabled builds.
        let mut m = SimpleModule::<f32>::new(3).unwrap();
        let result = m.to_device(ferrotorch_core::Device::Cuda(0));
        assert!(result.is_err());
    }
}