ferrotorch_nn/buffer.rs
1//! `Buffer<T>` — non-trainable persistent module state. (#583)
2//!
3//! Mirrors `torch.nn.Module.register_buffer`. A buffer is a tensor that:
4//! - is part of the module's persistent state (saved / loaded with the
5//! module via `state_dict`),
6//! - moves with the module across devices (`to_device`),
7//! - does **not** participate in gradient descent (no `requires_grad`).
8//!
9//! Typical uses: running mean / running variance in `BatchNorm`, position
10//! tables in attention layers, masks, scaling constants — anything the
11//! module needs to remember across forward passes that should not be
12//! optimized.
13//!
14//! Buffers are exposed via the `Module` trait through `buffers()`,
15//! `buffers_mut()`, and `named_buffers()`. Concrete modules opt in by
16//! storing `Buffer<T>` fields and overriding the relevant trait methods.
17//!
18//! ## REQ status (per `.design/ferrotorch-nn/buffer.md`)
19//!
20//! | REQ | Status | Evidence |
21//! |---|---|---|
22//! | REQ-1 | SHIPPED | `pub struct Buffer<T: Float> { data: Tensor<T> }` with `#[derive(Debug, Clone)]` mirrors `torch/nn/parameter.py:249-279` via R-DEV-7 newtype; consumed by `pub use buffer::Buffer` at `lib.rs:223` + `lib.rs:273` prelude; `ferrotorch-nn/src/module.rs:5` `use crate::buffer::Buffer`; `module.rs:374` `*buf = Buffer::new(tensor.clone())` inside the default `load_state_dict`. |
23//! | REQ-2 | SHIPPED | `Buffer::new(tensor)` enforces `requires_grad = false` mirroring `torch/nn/parameter.py:266-275`; consumed by `module.rs:374` during state-dict load. |
24//! | REQ-3 | SHIPPED | `Buffer::zeros` / `::ones` / `::from_slice` factories; consumed by `module.rs:543` `running_mean: Buffer::zeros(&[2])?` (BN canonical init pattern) and `norm.rs` BatchNorm layers. |
25//! | REQ-4 | SHIPPED | `tensor(&self) -> &Tensor<T>` and `into_tensor(self)` accessors; consumed by `module.rs:75` `buffer.tensor().clone()` inside the default `state_dict`. |
26//! | REQ-5 | SHIPPED | `set_data` re-enforces `requires_grad = false`; consumed by BN layers in `ferrotorch-nn/src/norm.rs` updating running mean / variance across forward passes. |
27//! | REQ-6 | SHIPPED | `to(device) -> FerrotorchResult<Self>`; consumed by `module.rs` `Module::to_device` default impl calling `buffer.to(device)?` for each buffer. |
28//! | REQ-7 | SHIPPED | `impl Deref<Target = Tensor<T>>` (R-DEV-7 Rust analog of Python inheritance); consumed by every callsite invoking a Tensor method on a Buffer (`buf.shape()` in `module.rs:365` shape check). |
29//! | REQ-8 | SHIPPED | `#[derive(Debug, Clone)]` with shallow Arc-backed clone; consumed by `Module::state_dict` calling `buffer.tensor().clone()`. |
30
31use ferrotorch_core::{Device, FerrotorchResult, Float, Tensor};
32
33/// A non-trainable tensor that is part of a module's persistent state.
34///
35/// Like [`crate::Parameter`], `Buffer<T>` derefs to `Tensor<T>` for all
36/// tensor operations and clones share the same underlying Arc identity.
37/// Unlike `Parameter`, `requires_grad` is **always false**.
38#[derive(Debug, Clone)]
39pub struct Buffer<T: Float> {
40 data: Tensor<T>,
41}
42
43impl<T: Float> Buffer<T> {
44 /// Wrap a tensor as a buffer. `requires_grad` is forced to `false`.
45 pub fn new(tensor: Tensor<T>) -> Self {
46 Self {
47 data: tensor.requires_grad_(false),
48 }
49 }
50
51 /// Create a zero-filled buffer with the given shape.
52 pub fn zeros(shape: &[usize]) -> FerrotorchResult<Self> {
53 let t = ferrotorch_core::zeros::<T>(shape)?;
54 Ok(Self::new(t))
55 }
56
57 /// Create a one-filled buffer with the given shape.
58 pub fn ones(shape: &[usize]) -> FerrotorchResult<Self> {
59 let t = ferrotorch_core::ones::<T>(shape)?;
60 Ok(Self::new(t))
61 }
62
63 /// Create a buffer from a slice + shape.
64 pub fn from_slice(data: &[T], shape: &[usize]) -> FerrotorchResult<Self> {
65 let t = ferrotorch_core::from_slice(data, shape)?;
66 Ok(Self::new(t))
67 }
68
69 /// Borrow the underlying tensor.
70 #[inline]
71 pub fn tensor(&self) -> &Tensor<T> {
72 &self.data
73 }
74
75 /// Consume and return the underlying tensor.
76 pub fn into_tensor(self) -> Tensor<T> {
77 self.data
78 }
79
80 /// Replace the buffer's data. The new tensor is set to
81 /// `requires_grad = false` regardless of its input state.
82 pub fn set_data(&mut self, tensor: Tensor<T>) {
83 self.data = tensor.requires_grad_(false);
84 }
85
86 /// Move this buffer to a device.
87 pub fn to(&self, device: Device) -> FerrotorchResult<Self> {
88 Ok(Self::new(self.data.to(device)?))
89 }
90}
91
92impl<T: Float> std::ops::Deref for Buffer<T> {
93 type Target = Tensor<T>;
94
95 #[inline]
96 fn deref(&self) -> &Self::Target {
97 &self.data
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed buffer never tracks gradients.
    #[test]
    fn buffer_does_not_require_grad() {
        let zeroed = Buffer::<f32>::zeros(&[3, 4]).unwrap();
        assert!(!zeroed.requires_grad());
    }

    /// Tensor methods are reachable straight through the buffer via `Deref`.
    #[test]
    fn buffer_derefs_to_tensor() {
        let filled = Buffer::<f32>::ones(&[2, 3]).unwrap();
        assert_eq!(filled.shape(), &[2, 3]);
        assert_eq!(filled.numel(), 6);
    }

    /// Cloning a buffer yields a handle to the same underlying tensor.
    #[test]
    fn buffer_clone_shares_identity() {
        let original = Buffer::<f32>::zeros(&[4]).unwrap();
        let duplicate = original.clone();
        assert!(original.tensor().is_same(duplicate.tensor()));
    }

    /// `set_data` clears `requires_grad` even when the input has it set.
    #[test]
    fn buffer_set_data_keeps_no_grad() {
        let mut target = Buffer::<f32>::zeros(&[3]).unwrap();

        let ones = ferrotorch_core::ones::<f32>(&[3]).unwrap();
        let with_grad = ones.requires_grad_(true);
        assert!(with_grad.requires_grad());

        target.set_data(with_grad);
        assert!(!target.requires_grad());
    }

    /// Moving to the CPU device keeps both the values and the no-grad flag.
    #[test]
    fn buffer_to_cpu_preserves_data() {
        let source = Buffer::<f32>::from_slice(&[1.0, 2.0, 3.0], &[3]).unwrap();
        let on_cpu = source.to(ferrotorch_core::Device::Cpu).unwrap();
        assert_eq!(on_cpu.data().unwrap(), &[1.0, 2.0, 3.0]);
        assert!(!on_cpu.requires_grad());
    }
}