Skip to main content

ferrotorch_nn/
hooks.rs

1//! Forward/backward hooks for [`Module`] instances.
2//!
//! PyTorch lets users attach hooks to any `nn.Module` to inspect or modify
//! activations during the forward and backward passes.  Because our [`Module`]
//! trait is stateless (no per-instance storage), hooks are added via the
//! [`HookedModule`]`<M, T>` wrapper, which stores the hooks externally, runs
//! them around the inner module's `forward`, and delegates every other
//! `Module` method to the inner module unchanged.
8//!
9//! # Example
10//!
11//! ```ignore
12//! use ferrotorch_nn::{HookedModule, Linear, Module};
13//!
14//! let linear = Linear::<f32>::new(4, 2, true).unwrap();
15//! let hooked = HookedModule::new(linear);
16//!
17//! let _handle = hooked.register_forward_hook(Box::new(|input, output| {
18//!     println!("in: {:?}  out: {:?}", input.shape(), output.shape());
19//! }));
20//! ```
21//!
22//! ## REQ status (per `.design/ferrotorch-nn/hooks.md`)
23//!
24//! | REQ | Status | Evidence |
25//! |---|---|---|
26//! | REQ-1 | SHIPPED | `pub type ForwardHook<T>` / `ForwardPreHook<T>` / `BackwardHook<T>` with `Send + Sync` bounds mirror PyTorch's hook closure signatures from `torch/nn/modules/module.py:1340-1660`; consumed by `ferrotorch-nn/src/module.rs:6` `use crate::hooks::{BackwardHook, ForwardHook, ForwardPreHook, HookHandle, HookedModule}`. |
27//! | REQ-2 | SHIPPED | `pub struct HookHandle` with `remove(self)` mirrors `torch.utils.hooks.RemovableHandle`; consumed by `module.rs` `with_*_hook` methods returning it as half of the tuple. |
28//! | REQ-3 | SHIPPED | `pub struct HookedModule<M, T: Float>` with three `Mutex<Vec<...>>` hook stores + `AtomicUsize` id counter; consumed by `module.rs` `with_forward_hook` constructing `HookedModule::new(self)`. |
29//! | REQ-4 | SHIPPED | `::new`, `inner`, `inner_mut`, `into_inner` accessors; consumed by `module.rs` `with_*_hook` methods (construction) and downstream code unwrapping via `into_inner` after removing all hooks. |
30//! | REQ-5 | SHIPPED | `register_forward_hook` / `register_forward_pre_hook` / `register_backward_hook` taking `&self`; consumed by `module.rs` `with_*_hook` calling each on the freshly-wrapped HookedModule. |
31//! | REQ-6 | SHIPPED | `impl Module<T> for HookedModule<M, T>` with chained pre-hooks + post-hooks + delegation; consumed by every `.forward(input)` call on a HookedModule — the production path through the `Module<T>` trait. |
32//! | REQ-7 | SHIPPED | `gc_hooks<H>` private helper invoked at the start of each hook-list traversal; consumed transitively by every `HookedModule::forward` call. Pinned by `test_hook_handle_remove` — second forward after `handle.remove()` does NOT fire the hook. |
33//! | REQ-8 | SHIPPED | `HookHandle::id(&self) -> usize` accessor mirrors upstream `RemovableHandle.id`; consumed by downstream observability code maintaining `hook_id → metadata` maps and load-bearing for the lazy-GC mechanism. |
34
35use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
36use std::sync::{Arc, Mutex};
37
38use ferrotorch_core::{FerrotorchResult, Float, Tensor};
39
40use crate::module::{Module, StateDict};
41use crate::parameter::Parameter;
42
43// ---------------------------------------------------------------------------
44// Hook type aliases
45// ---------------------------------------------------------------------------
46
/// A closure invoked *after* the forward pass with (input, output).
///
/// Intended for observation / logging; the return value is not used.
/// The closure returns `()`, so it cannot replace the output.  The `input`
/// argument handed over by [`HookedModule`]'s `forward` is the tensor the
/// inner module actually received, i.e. after every forward pre-hook ran.
pub type ForwardHook<T> = Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Send + Sync>;

/// A closure invoked *before* the forward pass with (input).
///
/// May return a replacement input tensor, allowing the hook to transform
/// activations before they reach the module.  Pre-hooks are chained: each one
/// receives the tensor returned by the previous one.  Returning `Err` aborts
/// the forward pass and the error is propagated to the caller of `forward`.
pub type ForwardPreHook<T> = Box<dyn Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync>;

/// A closure invoked during the backward pass with (grad_input, grad_output).
///
/// Intended for observation / logging of gradients.
///
/// NOTE(review): nothing in this file calls a `BackwardHook`; they are only
/// stored by `register_backward_hook`.  Confirm where (or whether) they are
/// dispatched before relying on them firing.
pub type BackwardHook<T> = Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Send + Sync>;
62
63// ---------------------------------------------------------------------------
64// HookHandle
65// ---------------------------------------------------------------------------
66
67/// An opaque handle returned when a hook is registered.
68///
69/// Calling [`HookHandle::remove`] unregisters the hook so it will not fire on
70/// subsequent forward/backward calls.  Dropping the handle *without* calling
71/// `remove` leaves the hook active (matching PyTorch semantics).
72#[derive(Debug)]
73pub struct HookHandle {
74    id: usize,
75    removed: Arc<AtomicBool>,
76}
77
78impl HookHandle {
79    fn new(id: usize, removed: Arc<AtomicBool>) -> Self {
80        Self { id, removed }
81    }
82
83    /// The unique identifier for this hook registration.
84    pub fn id(&self) -> usize {
85        self.id
86    }
87
88    /// Unregister the hook.  Subsequent forward/backward passes will skip it.
89    pub fn remove(self) {
90        self.removed.store(true, Ordering::Release);
91    }
92}
93
94// ---------------------------------------------------------------------------
95// Internal storage entry
96// ---------------------------------------------------------------------------
97
/// One entry in a hook list.  The `removed` flag is shared with the
/// corresponding [`HookHandle`]; when the handle is removed the flag is set
/// and the hook is lazily purged on the next invocation.
///
/// `H` is the boxed closure type stored for this list (one of the hook type
/// aliases declared above).
struct HookEntry<H> {
    // Registration id; the same value is stored in the matching `HookHandle`.
    #[allow(dead_code)] // Retained for future lookup-by-id operations.
    id: usize,
    // The user-supplied closure.
    hook: H,
    // Set to `true` by `HookHandle::remove`; entries with the flag set are
    // skipped and later dropped from the list.
    removed: Arc<AtomicBool>,
}
107
108// ---------------------------------------------------------------------------
109// HookedModule
110// ---------------------------------------------------------------------------
111
/// A wrapper that adds hook storage around any [`Module`].
///
/// `HookedModule` implements `Module<T>` itself, so it can be used anywhere
/// the inner module could be used.  Hooks are stored behind `Mutex`es and
/// the wrapper is `Send + Sync` as long as the inner module is.
///
/// Each hook kind has its own list and its own lock, which is why hooks can
/// be registered through `&self`.
pub struct HookedModule<M, T: Float> {
    // The wrapped module that every `Module` method is forwarded to.
    inner: M,
    // Hooks run after the inner module's `forward`, in registration order.
    forward_hooks: Mutex<Vec<HookEntry<ForwardHook<T>>>>,
    // Hooks run before the inner module's `forward`; each may replace the input.
    forward_pre_hooks: Mutex<Vec<HookEntry<ForwardPreHook<T>>>>,
    // Stored backward hooks.
    // NOTE(review): no code in this file reads this list — confirm the dispatch site.
    backward_hooks: Mutex<Vec<HookEntry<BackwardHook<T>>>>,
    // Source of hook ids; one counter is shared by all three hook lists.
    next_id: AtomicUsize,
}
124
125impl<M, T: Float> HookedModule<M, T> {
126    /// Wrap a module, enabling hook registration.
127    pub fn new(module: M) -> Self {
128        Self {
129            inner: module,
130            forward_hooks: Mutex::new(Vec::new()),
131            forward_pre_hooks: Mutex::new(Vec::new()),
132            backward_hooks: Mutex::new(Vec::new()),
133            next_id: AtomicUsize::new(0),
134        }
135    }
136
137    /// Register a hook that fires *after* `forward`.
138    ///
139    /// The hook receives `(&input, &output)`.  Returns a [`HookHandle`] that
140    /// can be used to unregister the hook.
141    pub fn register_forward_hook(&self, hook: ForwardHook<T>) -> HookHandle {
142        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
143        let removed = Arc::new(AtomicBool::new(false));
144        let entry = HookEntry {
145            id,
146            hook,
147            removed: Arc::clone(&removed),
148        };
149        self.forward_hooks.lock().unwrap().push(entry);
150        HookHandle::new(id, removed)
151    }
152
153    /// Register a hook that fires *before* `forward`.
154    ///
155    /// The hook receives `(&input)` and returns a (possibly modified) input
156    /// tensor.  Returns a [`HookHandle`].
157    pub fn register_forward_pre_hook(&self, hook: ForwardPreHook<T>) -> HookHandle {
158        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
159        let removed = Arc::new(AtomicBool::new(false));
160        let entry = HookEntry {
161            id,
162            hook,
163            removed: Arc::clone(&removed),
164        };
165        self.forward_pre_hooks.lock().unwrap().push(entry);
166        HookHandle::new(id, removed)
167    }
168
169    /// Register a hook that fires during the backward pass.
170    ///
171    /// The hook receives `(&grad_input, &grad_output)`.  Returns a
172    /// [`HookHandle`].
173    pub fn register_backward_hook(&self, hook: BackwardHook<T>) -> HookHandle {
174        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
175        let removed = Arc::new(AtomicBool::new(false));
176        let entry = HookEntry {
177            id,
178            hook,
179            removed: Arc::clone(&removed),
180        };
181        self.backward_hooks.lock().unwrap().push(entry);
182        HookHandle::new(id, removed)
183    }
184
185    /// Borrow the inner module.
186    pub fn inner(&self) -> &M {
187        &self.inner
188    }
189
190    /// Mutably borrow the inner module.
191    pub fn inner_mut(&mut self) -> &mut M {
192        &mut self.inner
193    }
194
195    /// Consume the wrapper and return the inner module.
196    pub fn into_inner(self) -> M {
197        self.inner
198    }
199
200    /// Purge entries whose handle has been removed.
201    fn gc_hooks<H>(hooks: &mut Vec<HookEntry<H>>) {
202        hooks.retain(|e| !e.removed.load(Ordering::Acquire));
203    }
204}
205
206// ---------------------------------------------------------------------------
207// Module implementation
208// ---------------------------------------------------------------------------
209
210impl<M: Module<T>, T: Float> Module<T> for HookedModule<M, T> {
211    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
212        // 1. Run forward pre-hooks (each may transform the input).
213        let mut x = input.clone();
214        {
215            let mut pre_hooks = self.forward_pre_hooks.lock().unwrap();
216            Self::gc_hooks(&mut pre_hooks);
217            for entry in pre_hooks.iter() {
218                if !entry.removed.load(Ordering::Acquire) {
219                    x = (entry.hook)(&x)?;
220                }
221            }
222        }
223
224        // 2. Run the inner module's forward pass.
225        let output = self.inner.forward(&x)?;
226
227        // 3. Run forward post-hooks (observe input + output).
228        {
229            let mut post_hooks = self.forward_hooks.lock().unwrap();
230            Self::gc_hooks(&mut post_hooks);
231            for entry in post_hooks.iter() {
232                if !entry.removed.load(Ordering::Acquire) {
233                    (entry.hook)(&x, &output);
234                }
235            }
236        }
237
238        Ok(output)
239    }
240
241    fn parameters(&self) -> Vec<&Parameter<T>> {
242        self.inner.parameters()
243    }
244
245    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
246        self.inner.parameters_mut()
247    }
248
249    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
250        self.inner.named_parameters()
251    }
252
253    fn train(&mut self) {
254        self.inner.train();
255    }
256
257    fn eval(&mut self) {
258        self.inner.eval();
259    }
260
261    fn is_training(&self) -> bool {
262        self.inner.is_training()
263    }
264
265    fn state_dict(&self) -> StateDict<T> {
266        self.inner.state_dict()
267    }
268
269    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
270        self.inner.load_state_dict(state, strict)
271    }
272}
273
274// ---------------------------------------------------------------------------
275// Tests
276// ---------------------------------------------------------------------------
277
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    use ferrotorch_core::{FerrotorchResult, Float, Tensor};

    use crate::module::Module;
    use crate::parameter::Parameter;

    use super::HookedModule;

    // -- Minimal test module ------------------------------------------------

    /// Toy module whose forward pass returns `input + input`.
    ///
    /// It carries one parameter and a training flag so that the delegation
    /// tests have something to observe.
    struct DoubleModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> DoubleModule<T> {
        /// Build a module with a ones-initialised weight of length `size`,
        /// starting in training mode.
        fn new(size: usize) -> FerrotorchResult<Self> {
            let weight = Parameter::ones(&[size])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for DoubleModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            // Adding the input to itself doubles every value.
            input.add_t(input)
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".to_string(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Wrap a fresh `DoubleModule<f32>` of the given size in a `HookedModule`.
    fn wrap_double(size: usize) -> HookedModule<DoubleModule<f32>, f32> {
        HookedModule::new(DoubleModule::<f32>::new(size).unwrap())
    }

    // -- Tests --------------------------------------------------------------

    #[test]
    fn test_forward_hook_captures_output_shape() {
        let hooked = wrap_double(4);

        let seen_shape = Arc::new(Mutex::new(Vec::<usize>::new()));
        let sink = Arc::clone(&seen_shape);

        let _handle = hooked.register_forward_hook(Box::new(move |_input, output| {
            *sink.lock().unwrap() = output.shape().to_vec();
        }));

        let ones = ferrotorch_core::ones::<f32>(&[4]).unwrap();
        let _out = hooked.forward(&ones).unwrap();

        assert_eq!(*seen_shape.lock().unwrap(), vec![4]);
    }

    #[test]
    fn test_forward_pre_hook_modifies_input() {
        let hooked = wrap_double(3);

        // The pre-hook discards the incoming tensor and substitutes zeros.
        let _handle = hooked.register_forward_pre_hook(Box::new(|input| {
            ferrotorch_core::zeros::<f32>(input.shape())
        }));

        let ones = ferrotorch_core::ones::<f32>(&[3]).unwrap();
        let out = hooked.forward(&ones).unwrap();

        // Doubling zeros still yields zeros.
        let data = out.data().unwrap();
        assert!(data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_multiple_hooks_fire_in_order() {
        let hooked = wrap_double(2);

        let order = Arc::new(Mutex::new(Vec::<usize>::new()));

        // Register three hooks that each record their own tag.  Dropping the
        // returned handle leaves the hook registered.
        for tag in 1..=3 {
            let log = Arc::clone(&order);
            let _ = hooked.register_forward_hook(Box::new(move |_input, _output| {
                log.lock().unwrap().push(tag);
            }));
        }

        let ones = ferrotorch_core::ones::<f32>(&[2]).unwrap();
        let _out = hooked.forward(&ones).unwrap();

        assert_eq!(*order.lock().unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn test_hook_handle_remove() {
        let hooked = wrap_double(2);

        let fired = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&fired);

        let handle = hooked.register_forward_hook(Box::new(move |_input, _output| {
            counter.fetch_add(1, Ordering::Relaxed);
        }));

        let ones = ferrotorch_core::ones::<f32>(&[2]).unwrap();

        // While registered, the hook fires once per forward call.
        let _out = hooked.forward(&ones).unwrap();
        assert_eq!(fired.load(Ordering::Relaxed), 1);

        // After removal the count must stay where it was.
        handle.remove();
        let _out = hooked.forward(&ones).unwrap();
        assert_eq!(fired.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_hooked_module_delegates_parameters() {
        let hooked = wrap_double(5);

        assert_eq!(hooked.parameters().len(), 1);
        assert_eq!(hooked.parameters()[0].shape(), &[5]);
    }

    #[test]
    fn test_hooked_module_delegates_named_parameters() {
        let hooked = wrap_double(3);

        let named = hooked.named_parameters();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "weight");
    }

    #[test]
    fn test_hooked_module_delegates_state_dict() {
        let hooked = wrap_double(4);

        let sd = hooked.state_dict();
        assert!(sd.contains_key("weight"));
        assert_eq!(sd["weight"].shape(), &[4]);
    }

    #[test]
    fn test_hooked_module_delegates_train_eval() {
        let mut hooked = wrap_double(2);

        assert!(hooked.is_training());
        hooked.eval();
        assert!(!hooked.is_training());
        hooked.train();
        assert!(hooked.is_training());
    }

    #[test]
    fn test_hooked_module_inner_access() {
        let hooked = wrap_double(3);
        assert_eq!(hooked.inner().parameters().len(), 1);
    }

    #[test]
    fn test_hooked_module_is_send_sync() {
        fn assert_send_sync<S: Send + Sync>() {}
        assert_send_sync::<HookedModule<DoubleModule<f32>, f32>>();
        assert_send_sync::<HookedModule<DoubleModule<f64>, f64>>();
    }

    #[test]
    fn test_backward_hook_registration() {
        let hooked = wrap_double(2);

        let called = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&called);

        let _handle = hooked.register_backward_hook(Box::new(move |_grad_in, _grad_out| {
            counter.fetch_add(1, Ordering::Relaxed);
        }));

        // A registered backward hook must neither break forward nor be
        // invoked by it.
        let ones = ferrotorch_core::ones::<f32>(&[2]).unwrap();
        let _out = hooked.forward(&ones).unwrap();

        assert_eq!(called.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_multiple_pre_hooks_chain() {
        let hooked = wrap_double(1);

        // Stage one: throw the input away and substitute zeros.
        let _h1 = hooked.register_forward_pre_hook(Box::new(|input| {
            ferrotorch_core::zeros::<f32>(input.shape())
        }));

        // Stage two: add ones to whatever stage one produced (0 + 1 = 1).
        let _h2 = hooked.register_forward_pre_hook(Box::new(|input| {
            let ones = ferrotorch_core::ones::<f32>(input.shape())?;
            input.add_t(&ones)
        }));

        let start = ferrotorch_core::from_slice::<f32>(&[42.0], &[1]).unwrap();
        let out = hooked.forward(&start).unwrap();

        // 42 -> 0 -> 1 through the pre-hooks, then doubled by the module: 2.
        let data = out.data().unwrap();
        assert_eq!(data, vec![2.0]);
    }
}