Skip to main content

ferrotorch_nn/
buffer.rs

1//! `Buffer<T>` — non-trainable persistent module state. (#583)
2//!
3//! Mirrors `torch.nn.Module.register_buffer`. A buffer is a tensor that:
4//! - is part of the module's persistent state (saved / loaded with the
5//!   module via `state_dict`),
6//! - moves with the module across devices (`to_device`),
7//! - does **not** participate in gradient descent (no `requires_grad`).
8//!
9//! Typical uses: running mean / running variance in `BatchNorm`, position
10//! tables in attention layers, masks, scaling constants — anything the
11//! module needs to remember across forward passes that should not be
12//! optimized.
13//!
14//! Buffers are exposed via the `Module` trait through `buffers()`,
15//! `buffers_mut()`, and `named_buffers()`. Concrete modules opt in by
16//! storing `Buffer<T>` fields and overriding the relevant trait methods.
17//!
18//! ## REQ status (per `.design/ferrotorch-nn/buffer.md`)
19//!
20//! | REQ | Status | Evidence |
21//! |---|---|---|
22//! | REQ-1 | SHIPPED | `pub struct Buffer<T: Float> { data: Tensor<T> }` with `#[derive(Debug, Clone)]` mirrors `torch/nn/parameter.py:249-279` via R-DEV-7 newtype; consumed by `pub use buffer::Buffer` at `lib.rs:223` + `lib.rs:273` prelude; `ferrotorch-nn/src/module.rs:5` `use crate::buffer::Buffer`; `module.rs:374` `*buf = Buffer::new(tensor.clone())` inside the default `load_state_dict`. |
23//! | REQ-2 | SHIPPED | `Buffer::new(tensor)` enforces `requires_grad = false` mirroring `torch/nn/parameter.py:266-275`; consumed by `module.rs:374` during state-dict load. |
24//! | REQ-3 | SHIPPED | `Buffer::zeros` / `::ones` / `::from_slice` factories; consumed by `module.rs:543` `running_mean: Buffer::zeros(&[2])?` (BN canonical init pattern) and `norm.rs` BatchNorm layers. |
25//! | REQ-4 | SHIPPED | `tensor(&self) -> &Tensor<T>` and `into_tensor(self)` accessors; consumed by `module.rs:75` `buffer.tensor().clone()` inside the default `state_dict`. |
26//! | REQ-5 | SHIPPED | `set_data` re-enforces `requires_grad = false`; consumed by BN layers in `ferrotorch-nn/src/norm.rs` updating running mean / variance across forward passes. |
27//! | REQ-6 | SHIPPED | `to(device) -> FerrotorchResult<Self>`; consumed by `module.rs` `Module::to_device` default impl calling `buffer.to(device)?` for each buffer. |
28//! | REQ-7 | SHIPPED | `impl Deref<Target = Tensor<T>>` (R-DEV-7 Rust analog of Python inheritance); consumed by every callsite invoking a Tensor method on a Buffer (`buf.shape()` in `module.rs:365` shape check). |
29//! | REQ-8 | SHIPPED | `#[derive(Debug, Clone)]` with shallow Arc-backed clone; consumed by `Module::state_dict` calling `buffer.tensor().clone()`. |
30
31use ferrotorch_core::{Device, FerrotorchResult, Float, Tensor};
32
33/// A non-trainable tensor that is part of a module's persistent state.
34///
35/// Like [`crate::Parameter`], `Buffer<T>` derefs to `Tensor<T>` for all
36/// tensor operations and clones share the same underlying Arc identity.
37/// Unlike `Parameter`, `requires_grad` is **always false**.
38#[derive(Debug, Clone)]
39pub struct Buffer<T: Float> {
40    data: Tensor<T>,
41}
42
43impl<T: Float> Buffer<T> {
44    /// Wrap a tensor as a buffer. `requires_grad` is forced to `false`.
45    pub fn new(tensor: Tensor<T>) -> Self {
46        Self {
47            data: tensor.requires_grad_(false),
48        }
49    }
50
51    /// Create a zero-filled buffer with the given shape.
52    pub fn zeros(shape: &[usize]) -> FerrotorchResult<Self> {
53        let t = ferrotorch_core::zeros::<T>(shape)?;
54        Ok(Self::new(t))
55    }
56
57    /// Create a one-filled buffer with the given shape.
58    pub fn ones(shape: &[usize]) -> FerrotorchResult<Self> {
59        let t = ferrotorch_core::ones::<T>(shape)?;
60        Ok(Self::new(t))
61    }
62
63    /// Create a buffer from a slice + shape.
64    pub fn from_slice(data: &[T], shape: &[usize]) -> FerrotorchResult<Self> {
65        let t = ferrotorch_core::from_slice(data, shape)?;
66        Ok(Self::new(t))
67    }
68
69    /// Borrow the underlying tensor.
70    #[inline]
71    pub fn tensor(&self) -> &Tensor<T> {
72        &self.data
73    }
74
75    /// Consume and return the underlying tensor.
76    pub fn into_tensor(self) -> Tensor<T> {
77        self.data
78    }
79
80    /// Replace the buffer's data. The new tensor is set to
81    /// `requires_grad = false` regardless of its input state.
82    pub fn set_data(&mut self, tensor: Tensor<T>) {
83        self.data = tensor.requires_grad_(false);
84    }
85
86    /// Move this buffer to a device.
87    pub fn to(&self, device: Device) -> FerrotorchResult<Self> {
88        Ok(Self::new(self.data.to(device)?))
89    }
90}
91
92impl<T: Float> std::ops::Deref for Buffer<T> {
93    type Target = Tensor<T>;
94
95    #[inline]
96    fn deref(&self) -> &Self::Target {
97        &self.data
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created buffer must not request gradients.
    #[test]
    fn buffer_does_not_require_grad() {
        let buf = Buffer::<f32>::zeros(&[3, 4]).unwrap();
        assert!(!buf.requires_grad());
    }

    /// Tensor methods are reachable through the buffer via `Deref`.
    #[test]
    fn buffer_derefs_to_tensor() {
        let buf = Buffer::<f32>::ones(&[2, 3]).unwrap();
        assert_eq!(buf.numel(), 6);
        assert_eq!(buf.shape(), &[2, 3]);
    }

    /// Cloning a buffer yields a handle that `is_same` as the original.
    #[test]
    fn buffer_clone_shares_identity() {
        let original = Buffer::<f32>::zeros(&[4]).unwrap();
        let copy = original.clone();
        assert!(original.tensor().is_same(copy.tensor()));
    }

    /// `set_data` clears `requires_grad` even when the input had it set.
    #[test]
    fn buffer_set_data_keeps_no_grad() {
        let replacement = ferrotorch_core::ones::<f32>(&[3])
            .unwrap()
            .requires_grad_(true);
        assert!(replacement.requires_grad());

        let mut buf = Buffer::<f32>::zeros(&[3]).unwrap();
        buf.set_data(replacement);
        assert!(!buf.requires_grad());
    }

    /// Moving to the CPU device keeps the values and the no-grad state.
    #[test]
    fn buffer_to_cpu_preserves_data() {
        let source = Buffer::<f32>::from_slice(&[1.0, 2.0, 3.0], &[3]).unwrap();
        let on_cpu = source.to(ferrotorch_core::Device::Cpu).unwrap();
        assert!(!on_cpu.requires_grad());
        assert_eq!(on_cpu.data().unwrap(), &[1.0, 2.0, 3.0]);
    }
}