1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
36use std::sync::{Arc, Mutex};
37
38use ferrotorch_core::{FerrotorchResult, Float, Tensor};
39
40use crate::module::{Module, StateDict};
41use crate::parameter::Parameter;
42
/// Boxed callback that [`HookedModule`] runs after the wrapped module's forward pass.
///
/// It is called with two tensor references: the input that was handed to the
/// wrapped module (i.e. after any pre-hooks rewrote it) and the output the
/// module produced. The callback returns nothing, so it can observe the pair
/// but cannot replace the output. The `Send + Sync` bounds allow the boxed
/// closure to be shared across threads.
pub type ForwardHook<T> = Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Send + Sync>;
51
/// Boxed callback that [`HookedModule`] runs before the wrapped module's forward pass.
///
/// It receives the current input tensor and returns the tensor that should be
/// used in its place. Returning `Err` makes `HookedModule::forward` return that
/// error (it is propagated with `?`), so the wrapped module is not called.
pub type ForwardPreHook<T> = Box<dyn Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync>;
57
/// Boxed callback type accepted by `HookedModule::register_backward_hook`.
///
/// The callback takes two tensor references and returns nothing.
/// NOTE(review): the tests in this file name the arguments `_grad_in` and
/// `_grad_out`, so they are presumably gradients w.r.t. input and output —
/// confirm against the autograd integration. Nothing in this file invokes
/// backward hooks; they are only stored.
pub type BackwardHook<T> = Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Send + Sync>;
62
/// Token handed back when a hook is registered on a `HookedModule`.
///
/// It carries the hook's numeric id and a flag shared with the stored hook
/// entry. The hook stays registered until [`HookHandle::remove`] is called;
/// no `Drop` behaviour is defined for the handle here.
#[derive(Debug)]
pub struct HookHandle {
    /// Identifier assigned to the hook when it was registered.
    id: usize,
    /// Flag shared with the hook entry; `true` once the hook has been removed.
    removed: Arc<AtomicBool>,
}

impl HookHandle {
    /// Builds a handle for hook `id` that shares the `removed` flag with its entry.
    fn new(id: usize, removed: Arc<AtomicBool>) -> Self {
        HookHandle { id, removed }
    }

    /// Returns the identifier assigned to the hook at registration time.
    pub fn id(&self) -> usize {
        self.id
    }

    /// Consumes the handle and flags its hook as removed.
    ///
    /// The flag is written with `Release` ordering; readers in this file load
    /// it with `Acquire`.
    pub fn remove(self) {
        let HookHandle { removed: flag, .. } = self;
        flag.store(true, Ordering::Release);
    }
}
93
/// One registered hook together with its bookkeeping.
///
/// `H` is the boxed callback type (one of the hook type aliases in this file).
struct HookEntry<H> {
    /// Id handed out at registration. It is stored but not read anywhere in
    /// this file, hence the `dead_code` allowance.
    #[allow(dead_code)] id: usize,
    /// The callback itself.
    hook: H,
    /// Removal flag shared with the matching `HookHandle`; `HookHandle::remove`
    /// sets it to `true`, after which the entry is skipped and later dropped.
    removed: Arc<AtomicBool>,
}
107
/// Wrapper that attaches removable hooks to a module of type `M`.
///
/// The hook lists sit behind `Mutex`es and the id counter is atomic, which is
/// what lets the `register_*` methods take `&self` rather than `&mut self`.
pub struct HookedModule<M, T: Float> {
    /// The wrapped module; every `Module` method delegates to it.
    inner: M,
    /// Hooks called after `inner.forward`, in registration order.
    forward_hooks: Mutex<Vec<HookEntry<ForwardHook<T>>>>,
    /// Hooks called before `inner.forward`; each one may replace the input.
    forward_pre_hooks: Mutex<Vec<HookEntry<ForwardPreHook<T>>>>,
    /// Hooks stored by `register_backward_hook`; nothing in this file calls them.
    backward_hooks: Mutex<Vec<HookEntry<BackwardHook<T>>>>,
    /// Source of hook ids, shared by all three hook kinds.
    next_id: AtomicUsize,
}
124
125impl<M, T: Float> HookedModule<M, T> {
126 pub fn new(module: M) -> Self {
128 Self {
129 inner: module,
130 forward_hooks: Mutex::new(Vec::new()),
131 forward_pre_hooks: Mutex::new(Vec::new()),
132 backward_hooks: Mutex::new(Vec::new()),
133 next_id: AtomicUsize::new(0),
134 }
135 }
136
137 pub fn register_forward_hook(&self, hook: ForwardHook<T>) -> HookHandle {
142 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
143 let removed = Arc::new(AtomicBool::new(false));
144 let entry = HookEntry {
145 id,
146 hook,
147 removed: Arc::clone(&removed),
148 };
149 self.forward_hooks.lock().unwrap().push(entry);
150 HookHandle::new(id, removed)
151 }
152
153 pub fn register_forward_pre_hook(&self, hook: ForwardPreHook<T>) -> HookHandle {
158 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
159 let removed = Arc::new(AtomicBool::new(false));
160 let entry = HookEntry {
161 id,
162 hook,
163 removed: Arc::clone(&removed),
164 };
165 self.forward_pre_hooks.lock().unwrap().push(entry);
166 HookHandle::new(id, removed)
167 }
168
169 pub fn register_backward_hook(&self, hook: BackwardHook<T>) -> HookHandle {
174 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
175 let removed = Arc::new(AtomicBool::new(false));
176 let entry = HookEntry {
177 id,
178 hook,
179 removed: Arc::clone(&removed),
180 };
181 self.backward_hooks.lock().unwrap().push(entry);
182 HookHandle::new(id, removed)
183 }
184
185 pub fn inner(&self) -> &M {
187 &self.inner
188 }
189
190 pub fn inner_mut(&mut self) -> &mut M {
192 &mut self.inner
193 }
194
195 pub fn into_inner(self) -> M {
197 self.inner
198 }
199
200 fn gc_hooks<H>(hooks: &mut Vec<HookEntry<H>>) {
202 hooks.retain(|e| !e.removed.load(Ordering::Acquire));
203 }
204}
205
206impl<M: Module<T>, T: Float> Module<T> for HookedModule<M, T> {
211 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
212 let mut x = input.clone();
214 {
215 let mut pre_hooks = self.forward_pre_hooks.lock().unwrap();
216 Self::gc_hooks(&mut pre_hooks);
217 for entry in pre_hooks.iter() {
218 if !entry.removed.load(Ordering::Acquire) {
219 x = (entry.hook)(&x)?;
220 }
221 }
222 }
223
224 let output = self.inner.forward(&x)?;
226
227 {
229 let mut post_hooks = self.forward_hooks.lock().unwrap();
230 Self::gc_hooks(&mut post_hooks);
231 for entry in post_hooks.iter() {
232 if !entry.removed.load(Ordering::Acquire) {
233 (entry.hook)(&x, &output);
234 }
235 }
236 }
237
238 Ok(output)
239 }
240
241 fn parameters(&self) -> Vec<&Parameter<T>> {
242 self.inner.parameters()
243 }
244
245 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
246 self.inner.parameters_mut()
247 }
248
249 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
250 self.inner.named_parameters()
251 }
252
253 fn train(&mut self) {
254 self.inner.train();
255 }
256
257 fn eval(&mut self) {
258 self.inner.eval();
259 }
260
261 fn is_training(&self) -> bool {
262 self.inner.is_training()
263 }
264
265 fn state_dict(&self) -> StateDict<T> {
266 self.inner.state_dict()
267 }
268
269 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
270 self.inner.load_state_dict(state, strict)
271 }
272}
273
#[cfg(test)]
mod tests {
    // Unit tests for `HookedModule`: hook firing, ordering, removal, and
    // delegation of the `Module` API to the wrapped module.

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    use ferrotorch_core::{FerrotorchResult, Float, Tensor};

    use crate::module::Module;
    use crate::parameter::Parameter;

    use super::HookedModule;

    /// Minimal module used by the tests: `forward` returns `input + input`.
    struct DoubleModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> DoubleModule<T> {
        /// Builds the module with a ones-initialised weight of length `size`.
        fn new(size: usize) -> FerrotorchResult<Self> {
            let weight: Parameter<T> = Parameter::ones(&[size])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for DoubleModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            input.add_t(input)
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".to_string(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// A forward hook sees the output tensor and can record its shape.
    #[test]
    fn test_forward_hook_captures_output_shape() {
        let hooked = HookedModule::new(DoubleModule::<f32>::new(4).unwrap());

        let seen_shape = Arc::new(Mutex::new(Vec::<usize>::new()));
        let sink = Arc::clone(&seen_shape);

        let _handle = hooked.register_forward_hook(Box::new(move |_input, output| {
            *sink.lock().unwrap() = output.shape().to_vec();
        }));

        let ones = ferrotorch_core::ones::<f32>(&[4]).unwrap();
        let _out = hooked.forward(&ones).unwrap();

        assert_eq!(*seen_shape.lock().unwrap(), vec![4]);
    }

    /// A pre-hook's return value replaces the input given to the module.
    #[test]
    fn test_forward_pre_hook_modifies_input() {
        let hooked = HookedModule::new(DoubleModule::<f32>::new(3).unwrap());

        // Swap whatever comes in for zeros of the same shape.
        let _handle = hooked.register_forward_pre_hook(Box::new(|input| {
            ferrotorch_core::zeros::<f32>(input.shape())
        }));

        let ones = ferrotorch_core::ones::<f32>(&[3]).unwrap();
        let out = hooked.forward(&ones).unwrap();

        // 0 + 0 == 0 for every element.
        let data = out.data().unwrap();
        assert!(data.iter().all(|&v| v == 0.0));
    }

    /// Forward hooks fire in the order they were registered.
    #[test]
    fn test_multiple_hooks_fire_in_order() {
        let hooked = HookedModule::new(DoubleModule::<f32>::new(2).unwrap());

        let order = Arc::new(Mutex::new(Vec::<usize>::new()));

        // Register three hooks tagged 1, 2, 3; keep the handles alive.
        let _handles: Vec<_> = (1..=3usize)
            .map(|tag| {
                let log = Arc::clone(&order);
                hooked.register_forward_hook(Box::new(move |_input, _output| {
                    log.lock().unwrap().push(tag);
                }))
            })
            .collect();

        let ones = ferrotorch_core::ones::<f32>(&[2]).unwrap();
        let _out = hooked.forward(&ones).unwrap();

        assert_eq!(*order.lock().unwrap(), vec![1, 2, 3]);
    }

    /// After `remove`, the hook no longer fires.
    #[test]
    fn test_hook_handle_remove() {
        let hooked = HookedModule::new(DoubleModule::<f32>::new(2).unwrap());

        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        let handle = hooked.register_forward_hook(Box::new(move |_input, _output| {
            counter.fetch_add(1, Ordering::Relaxed);
        }));

        let ones = ferrotorch_core::ones::<f32>(&[2]).unwrap();

        // First pass: the hook is live.
        let _out = hooked.forward(&ones).unwrap();
        assert_eq!(calls.load(Ordering::Relaxed), 1);

        handle.remove();

        // Second pass: the count must not move.
        let _out = hooked.forward(&ones).unwrap();
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_hooked_module_delegates_parameters() {
        let hooked = HookedModule::new(DoubleModule::<f32>::new(5).unwrap());

        let params = hooked.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(hooked.parameters()[0].shape(), &[5]);
    }

    #[test]
    fn test_hooked_module_delegates_named_parameters() {
        let hooked = HookedModule::new(DoubleModule::<f32>::new(3).unwrap());

        let named = hooked.named_parameters();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "weight");
    }

    #[test]
    fn test_hooked_module_delegates_state_dict() {
        let hooked = HookedModule::new(DoubleModule::<f32>::new(4).unwrap());

        let state = hooked.state_dict();
        assert!(state.contains_key("weight"));
        assert_eq!(state["weight"].shape(), &[4]);
    }

    #[test]
    fn test_hooked_module_delegates_train_eval() {
        let mut hooked = HookedModule::new(DoubleModule::<f32>::new(2).unwrap());

        assert!(hooked.is_training());
        hooked.eval();
        assert!(!hooked.is_training());
        hooked.train();
        assert!(hooked.is_training());
    }

    #[test]
    fn test_hooked_module_inner_access() {
        let wrapped = DoubleModule::<f32>::new(3).unwrap();
        let hooked: HookedModule<_, f32> = HookedModule::new(wrapped);
        assert_eq!(hooked.inner().parameters().len(), 1);
    }

    /// Compile-time check that the wrapper is `Send + Sync`.
    #[test]
    fn test_hooked_module_is_send_sync() {
        fn assert_send_sync<S: Send + Sync>() {}
        assert_send_sync::<HookedModule<DoubleModule<f32>, f32>>();
        assert_send_sync::<HookedModule<DoubleModule<f64>, f64>>();
    }

    /// Backward hooks can be registered; a forward pass does not call them.
    #[test]
    fn test_backward_hook_registration() {
        let hooked = HookedModule::new(DoubleModule::<f32>::new(2).unwrap());

        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        let _handle = hooked.register_backward_hook(Box::new(move |_grad_in, _grad_out| {
            counter.fetch_add(1, Ordering::Relaxed);
        }));

        let ones = ferrotorch_core::ones::<f32>(&[2]).unwrap();
        let _out = hooked.forward(&ones).unwrap();

        assert_eq!(calls.load(Ordering::Relaxed), 0);
    }

    /// Pre-hooks chain: each one receives the previous hook's result.
    #[test]
    fn test_multiple_pre_hooks_chain() {
        let hooked = HookedModule::new(DoubleModule::<f32>::new(1).unwrap());

        // First hook: discard the input and substitute zeros.
        let _zeroing = hooked.register_forward_pre_hook(Box::new(|input| {
            ferrotorch_core::zeros::<f32>(input.shape())
        }));

        // Second hook: add one to whatever the first hook produced.
        let _plus_one = hooked.register_forward_pre_hook(Box::new(|input| {
            let ones = ferrotorch_core::ones::<f32>(input.shape())?;
            input.add_t(&ones)
        }));

        let start = ferrotorch_core::from_slice::<f32>(&[42.0], &[1]).unwrap();
        let out = hooked.forward(&start).unwrap();

        // 42 -> 0 -> 1, then the module doubles it to 2.
        let data = out.data().unwrap();
        assert_eq!(data, vec![2.0]);
    }
}