1use std::collections::HashMap;
2
3use ferrotorch_core::{Device, FerrotorchError, FerrotorchResult, Float, Tensor};
4
5use crate::parameter::Parameter;
6
/// Name → tensor map of a module's parameters.
///
/// Produced by [`Module::state_dict`] and consumed by [`Module::load_state_dict`];
/// keys are the names reported by [`Module::named_parameters`].
pub type StateDict<T> = HashMap<String, Tensor<T>>;
9
/// Reduction mode selector with three variants: `Mean`, `Sum`, and `None`.
///
/// NOTE(review): the code that interprets these variants lives outside this
/// file; the per-variant descriptions below follow the variant names only —
/// confirm against the consumers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Reduction {
    /// Mean reduction (presumably averages the reduced values).
    Mean,
    /// Sum reduction (presumably adds the reduced values together).
    Sum,
    /// No reduction (presumably leaves values unreduced).
    None,
}
20
21pub trait Module<T: Float>: Send + Sync {
25 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>;
27
28 fn parameters(&self) -> Vec<&Parameter<T>>;
30
31 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>;
33
34 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>;
39
40 fn train(&mut self);
42
43 fn eval(&mut self);
45
46 fn is_training(&self) -> bool;
48
49 fn to_device(&mut self, device: Device) -> FerrotorchResult<()> {
53 for param in self.parameters_mut() {
54 *param = param.to(device)?;
55 }
56 Ok(())
57 }
58
59 fn state_dict(&self) -> StateDict<T> {
61 self.named_parameters()
62 .into_iter()
63 .map(|(name, param)| {
64 let data = param.tensor().clone();
66 (name, data)
67 })
68 .collect()
69 }
70
71 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
77 let named = self.named_parameters();
78 let known_keys: std::collections::HashSet<&str> =
79 named.iter().map(|(k, _)| k.as_str()).collect();
80
81 if strict {
82 for key in state.keys() {
83 if !known_keys.contains(key.as_str()) {
84 return Err(FerrotorchError::InvalidArgument {
85 message: format!("unexpected key in state_dict: \"{key}\""),
86 });
87 }
88 }
89 }
90
91 let param_names: Vec<String> = self
95 .named_parameters()
96 .into_iter()
97 .map(|(name, _)| name)
98 .collect();
99
100 let params_mut = self.parameters_mut();
101
102 for (name, param) in param_names.iter().zip(params_mut.into_iter()) {
103 if let Some(tensor) = state.get(name) {
104 if param.shape() != tensor.shape() {
105 return Err(FerrotorchError::ShapeMismatch {
106 message: format!(
107 "state_dict shape mismatch for \"{name}\": expected {:?}, got {:?}",
108 param.shape(),
109 tensor.shape()
110 ),
111 });
112 }
113 *param = Parameter::new(tensor.clone());
115 } else if strict {
116 return Err(FerrotorchError::InvalidArgument {
117 message: format!("missing key in state_dict: \"{name}\""),
118 });
119 }
120 }
121
122 Ok(())
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal single-parameter module used to exercise the trait's default methods.
    struct SimpleModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> SimpleModule<T> {
        fn new(size: usize) -> FerrotorchResult<Self> {
            Ok(Self {
                weight: Parameter::zeros(&[size])?,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for SimpleModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".to_string(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_module_parameters() {
        let m = SimpleModule::<f32>::new(5).unwrap();
        assert_eq!(m.parameters().len(), 1);
        assert_eq!(m.parameters()[0].shape(), &[5]);
    }

    #[test]
    fn test_module_named_parameters() {
        let m = SimpleModule::<f32>::new(3).unwrap();
        let named = m.named_parameters();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "weight");
    }

    #[test]
    fn test_module_train_eval() {
        let mut m = SimpleModule::<f32>::new(2).unwrap();
        assert!(m.is_training());
        m.eval();
        assert!(!m.is_training());
        m.train();
        assert!(m.is_training());
    }

    #[test]
    fn test_module_state_dict_roundtrip() {
        let m = SimpleModule::<f32>::new(4).unwrap();
        let sd = m.state_dict();
        assert!(sd.contains_key("weight"));
        assert_eq!(sd["weight"].shape(), &[4]);

        let mut m2 = SimpleModule::<f32>::new(4).unwrap();
        m2.load_state_dict(&sd, true).unwrap();

        // The loaded module must expose the same keys and shapes it was fed.
        let sd2 = m2.state_dict();
        assert_eq!(sd2.len(), 1);
        assert_eq!(sd2["weight"].shape(), &[4]);
    }

    #[test]
    fn test_module_state_dict_strict_extra_key() {
        let mut m = SimpleModule::<f32>::new(3).unwrap();
        let mut sd = HashMap::new();
        sd.insert(
            "weight".to_string(),
            ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
        );
        sd.insert(
            "extra".to_string(),
            ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
        );

        assert!(m.load_state_dict(&sd, true).is_err());
        assert!(m.load_state_dict(&sd, false).is_ok());
    }

    #[test]
    fn test_module_state_dict_missing_key() {
        let mut m = SimpleModule::<f32>::new(3).unwrap();
        let sd: StateDict<f32> = HashMap::new();

        // Strict mode rejects a state dict lacking "weight"; non-strict keeps it.
        assert!(m.load_state_dict(&sd, true).is_err());
        assert!(m.load_state_dict(&sd, false).is_ok());
        assert_eq!(m.parameters()[0].shape(), &[3]);
    }

    #[test]
    fn test_module_state_dict_shape_mismatch() {
        let mut m = SimpleModule::<f32>::new(3).unwrap();
        let mut sd = HashMap::new();
        sd.insert(
            "weight".to_string(),
            ferrotorch_core::zeros::<f32>(&[5]).unwrap(),
        );

        // Shape checks apply regardless of `strict`.
        assert!(m.load_state_dict(&sd, true).is_err());
        assert!(m.load_state_dict(&sd, false).is_err());
        assert_eq!(m.parameters()[0].shape(), &[3]);
    }

    #[test]
    fn test_module_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SimpleModule<f32>>();
    }

    #[test]
    fn test_reduction_enum() {
        assert_eq!(Reduction::Mean, Reduction::Mean);
        assert_ne!(Reduction::Mean, Reduction::Sum);
        assert_ne!(Reduction::Sum, Reduction::None);
    }

    #[test]
    fn test_to_device_cpu_preserves_weights() {
        let mut m = SimpleModule::<f32>::new(4).unwrap();
        m.to_device(ferrotorch_core::Device::Cpu).unwrap();
        assert_eq!(m.parameters().len(), 1);
        assert_eq!(m.parameters()[0].shape(), &[4]);
    }

    #[test]
    fn test_to_device_cuda_without_backend() {
        let mut m = SimpleModule::<f32>::new(3).unwrap();
        let result = m.to_device(ferrotorch_core::Device::Cuda(0));
        assert!(result.is_err());
    }
}
266}