Skip to main content

oxicuda_vision/patch_embed/
pos_embed.rs

1//! 2-D positional encodings for Vision Transformers.
2//!
3//! Provides:
4//! - **`pos_2d_sincos`**: deterministic 2-D sinusoidal positional encoding
5//!   as used in MAE / BEiT / DeiT. The first half of `dim` encodes the
6//!   row (H) axis; the second half encodes the column (W) axis.
7//! - **`LearnablePosEmbed`**: a simple learned position table.
8//! - **`add_pos_embed`**: in-place addition of position embeddings to tokens.
9
10use crate::{
11    error::{VisionError, VisionResult},
12    handle::LcgRng,
13};
14
15// ─── 2-D sinusoidal position encoding ────────────────────────────────────────
16
17/// Compute a 2-D sinusoidal positional encoding for a `grid_h × grid_w` grid.
18///
19/// Each position `(h, w)` gets a `dim`-dimensional vector: the first `dim/2`
20/// dimensions encode the row using the standard 1-D sinusoidal schedule, and
21/// the second `dim/2` dimensions encode the column.
22///
23/// The temperature-based frequency schedule (Vaswani et al.) is used:
24///
25/// ```text
26/// freq[k] = 1 / 10000^(2k / dim_half)
27/// encoding[h, w, k]        = sin(h * freq[k])   for k in [0, dim/4)
28/// encoding[h, w, k + dim/4] = cos(h * freq[k])   for k in [0, dim/4)
29/// encoding[h, w, dim/2 + k]        = sin(w * freq[k])
30/// encoding[h, w, dim/2 + k + dim/4] = cos(w * freq[k])
31/// ```
32///
33/// Returns a flat `[grid_h * grid_w, dim]` `Vec<f32>` in row-major order.
34/// `dim` must be divisible by 4.
35pub fn pos_2d_sincos(grid_h: usize, grid_w: usize, dim: usize) -> VisionResult<Vec<f32>> {
36    if dim == 0 || dim % 4 != 0 {
37        return Err(VisionError::InvalidEmbedDim(dim));
38    }
39    if grid_h == 0 || grid_w == 0 {
40        return Err(VisionError::InvalidImageSize {
41            height: grid_h,
42            width: grid_w,
43            channels: 1,
44        });
45    }
46
47    let n = grid_h * grid_w;
48    let dim_half = dim / 2; // split: first half H, second half W
49    let dim_qtr = dim / 4; // sin/cos each get dim_qtr freqs
50
51    let mut out = vec![0.0f32; n * dim];
52
53    // Temperature = 10000^(2k / dim_half), k ∈ [0, dim_qtr)
54    let freqs: Vec<f32> = (0..dim_qtr)
55        .map(|k| 1.0 / 10000_f32.powf(2.0 * k as f32 / dim_half as f32))
56        .collect();
57
58    for h in 0..grid_h {
59        for w in 0..grid_w {
60            let pos = h * grid_w + w;
61            let base = pos * dim;
62
63            // H-axis encoding: indices [0, dim_qtr) sin, [dim_qtr, dim_half) cos
64            for k in 0..dim_qtr {
65                let angle = h as f32 * freqs[k];
66                out[base + k] = angle.sin();
67                out[base + dim_qtr + k] = angle.cos();
68            }
69
70            // W-axis encoding: indices [dim_half, dim_half+dim_qtr) sin, ...
71            for k in 0..dim_qtr {
72                let angle = w as f32 * freqs[k];
73                out[base + dim_half + k] = angle.sin();
74                out[base + dim_half + dim_qtr + k] = angle.cos();
75            }
76        }
77    }
78
79    Ok(out)
80}
81
82// ─── Learnable position embedding ────────────────────────────────────────────
83
84/// Learnable position embedding table: `[n_positions, embed_dim]`.
85///
86/// Row `i` is the position embedding for the `i`-th token (index 0 is
87/// conventionally the CLS token position).
88#[derive(Debug, Clone)]
89pub struct LearnablePosEmbed {
90    /// Flat `[n_positions × embed_dim]` parameter table.
91    pub table: Vec<f32>,
92    /// Number of positions (including CLS if present).
93    pub n_positions: usize,
94    /// Embedding dimension.
95    pub embed_dim: usize,
96}
97
98impl LearnablePosEmbed {
99    /// Create a learnable position embedding with small Gaussian init.
100    pub fn new(n_positions: usize, embed_dim: usize, rng: &mut LcgRng) -> VisionResult<Self> {
101        if embed_dim == 0 {
102            return Err(VisionError::InvalidEmbedDim(embed_dim));
103        }
104        if n_positions == 0 {
105            return Err(VisionError::EmptyInput("n_positions"));
106        }
107        let mut table = vec![0.0f32; n_positions * embed_dim];
108        rng.fill_normal(&mut table);
109        let scale = 0.02;
110        for v in &mut table {
111            *v *= scale;
112        }
113        Ok(Self {
114            table,
115            n_positions,
116            embed_dim,
117        })
118    }
119
120    /// Return the embedding for position `i` as a slice of length `embed_dim`.
121    pub fn position_embedding(&self, i: usize) -> VisionResult<&[f32]> {
122        if i >= self.n_positions {
123            return Err(VisionError::DimensionMismatch {
124                expected: self.n_positions - 1,
125                got: i,
126            });
127        }
128        let start = i * self.embed_dim;
129        Ok(&self.table[start..start + self.embed_dim])
130    }
131}
132
133// ─── add_pos_embed ────────────────────────────────────────────────────────────
134
135/// Add positional embeddings to a token sequence in-place.
136///
137/// `tokens` is flat `[n_tokens × embed_dim]`.
138/// `pos_embed` is flat `[n_tokens × embed_dim]` (or a prefix thereof).
139///
140/// Validates shape compatibility and returns an error on mismatch.
141pub fn add_pos_embed(tokens: &mut [f32], pos_embed: &[f32], embed_dim: usize) -> VisionResult<()> {
142    if tokens.len() != pos_embed.len() {
143        return Err(VisionError::DimensionMismatch {
144            expected: tokens.len(),
145            got: pos_embed.len(),
146        });
147    }
148    if embed_dim == 0 || tokens.len() % embed_dim != 0 {
149        return Err(VisionError::InvalidEmbedDim(embed_dim));
150    }
151    for (t, p) in tokens.iter_mut().zip(pos_embed.iter()) {
152        *t += p;
153    }
154    Ok(())
155}
156
157// ─── Tests ───────────────────────────────────────────────────────────────────
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Encoding width shared by most of the sin/cos tests.
    const DIM: usize = 64;

    /// A 4 × 4 grid yields 16 positions, each `DIM` values wide.
    #[test]
    fn pos_2d_sincos_shape() {
        let encoding = pos_2d_sincos(4, 4, DIM).expect("ok");
        assert_eq!(encoding.len(), 4 * 4 * DIM);
    }

    /// No NaN or infinity may appear anywhere in the encoding.
    #[test]
    fn pos_2d_sincos_finite() {
        let encoding = pos_2d_sincos(8, 8, DIM).expect("ok");
        let has_non_finite = encoding.iter().any(|v| !v.is_finite());
        assert!(!has_non_finite, "non-finite pos embed");
    }

    /// Every entry is a sine or cosine, hence bounded by [-1, 1].
    #[test]
    fn pos_2d_sincos_in_range() {
        let encoding = pos_2d_sincos(4, 4, DIM).expect("ok");
        let unit_interval = -1.0f32..=1.0;
        assert!(
            encoding.iter().all(|v| unit_interval.contains(v)),
            "out of [-1,1]"
        );
    }

    /// A width that is not a multiple of four is rejected (6 % 4 != 0).
    #[test]
    fn pos_2d_sincos_invalid_dim_not_div4() {
        assert!(matches!(
            pos_2d_sincos(4, 4, 6),
            Err(VisionError::InvalidEmbedDim(6))
        ));
    }

    /// A zero-height grid is rejected.
    #[test]
    fn pos_2d_sincos_invalid_grid_zero() {
        assert!(matches!(
            pos_2d_sincos(0, 4, DIM),
            Err(VisionError::InvalidImageSize { .. })
        ));
    }

    /// Cells (0, 0) and (0, 1) must not share the same encoding.
    #[test]
    fn pos_2d_sincos_distinct_positions() {
        let encoding = pos_2d_sincos(4, 4, DIM).expect("ok");
        let (p00, rest) = encoding.split_at(DIM);
        let p01 = &rest[..DIM];
        let diff: f32 = p00.iter().zip(p01).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-3,
            "adjacent positions should differ; total diff={diff}"
        );
    }

    /// With k = 0 the frequency is 1 / 10000^0 = 1, so the first entry of
    /// the cell in row `h` equals sin(h).
    #[test]
    fn pos_2d_sincos_periodicity_check() {
        // 4 rows, 1 column, dim = 4 → row `h` starts at offset `4 * h`.
        let encoding = pos_2d_sincos(4, 1, 4).expect("ok");
        let row0_first = encoding[0];
        let row1_first = encoding[4];
        assert!((row0_first - 0.0_f32.sin()).abs() < 1e-6);
        assert!((row1_first - 1.0_f32.sin()).abs() < 1e-6);
    }

    /// 64 patches plus one CLS slot, 64 values each.
    #[test]
    fn learnable_pos_embed_shape() {
        let mut rng = LcgRng::new(1);
        let embed = LearnablePosEmbed::new(65, 64, &mut rng).expect("ok");
        assert_eq!(embed.table.len(), 65 * 64);
    }

    /// Freshly initialised parameters are all finite.
    #[test]
    fn learnable_pos_embed_finite() {
        let mut rng = LcgRng::new(2);
        let embed = LearnablePosEmbed::new(17, 32, &mut rng).expect("ok");
        for value in &embed.table {
            assert!(value.is_finite());
        }
    }

    /// Looking up position 3 returns exactly the fourth row of the table.
    #[test]
    fn learnable_pos_embed_access() {
        let mut rng = LcgRng::new(3);
        let embed = LearnablePosEmbed::new(8, 16, &mut rng).expect("ok");
        let row = embed.position_embedding(3).expect("ok");
        assert_eq!(row.len(), 16);
        assert_eq!(row, &embed.table[3 * 16..4 * 16]);
    }

    /// Index 8 is one past the last of 8 positions and must be an error.
    #[test]
    fn learnable_pos_embed_out_of_bounds_errors() {
        let mut rng = LcgRng::new(4);
        let embed = LearnablePosEmbed::new(8, 16, &mut rng).expect("ok");
        assert!(matches!(
            embed.position_embedding(8),
            Err(VisionError::DimensionMismatch { .. })
        ));
    }

    /// 1.0 + 0.5 for every element of a 4-token, 8-wide sequence.
    #[test]
    fn add_pos_embed_in_place() {
        let mut tokens = vec![1.0f32; 4 * 8];
        let offsets = vec![0.5f32; 4 * 8];
        add_pos_embed(&mut tokens, &offsets, 8).expect("ok");
        let all_shifted = tokens.iter().all(|&v| (v - 1.5).abs() < 1e-6);
        assert!(all_shifted);
    }

    /// Three rows of embeddings cannot be added to four rows of tokens.
    #[test]
    fn add_pos_embed_shape_mismatch_errors() {
        let mut tokens = vec![1.0f32; 4 * 8];
        let too_short = vec![0.5f32; 3 * 8];
        let outcome = add_pos_embed(&mut tokens, &too_short, 8);
        assert!(matches!(
            outcome,
            Err(VisionError::DimensionMismatch { .. })
        ));
    }
}