oxicuda_vision/patch_embed/
pos_embed.rs1use crate::{
11 error::{VisionError, VisionResult},
12 handle::LcgRng,
13};
14
15pub fn pos_2d_sincos(grid_h: usize, grid_w: usize, dim: usize) -> VisionResult<Vec<f32>> {
36 if dim == 0 || dim % 4 != 0 {
37 return Err(VisionError::InvalidEmbedDim(dim));
38 }
39 if grid_h == 0 || grid_w == 0 {
40 return Err(VisionError::InvalidImageSize {
41 height: grid_h,
42 width: grid_w,
43 channels: 1,
44 });
45 }
46
47 let n = grid_h * grid_w;
48 let dim_half = dim / 2; let dim_qtr = dim / 4; let mut out = vec![0.0f32; n * dim];
52
53 let freqs: Vec<f32> = (0..dim_qtr)
55 .map(|k| 1.0 / 10000_f32.powf(2.0 * k as f32 / dim_half as f32))
56 .collect();
57
58 for h in 0..grid_h {
59 for w in 0..grid_w {
60 let pos = h * grid_w + w;
61 let base = pos * dim;
62
63 for k in 0..dim_qtr {
65 let angle = h as f32 * freqs[k];
66 out[base + k] = angle.sin();
67 out[base + dim_qtr + k] = angle.cos();
68 }
69
70 for k in 0..dim_qtr {
72 let angle = w as f32 * freqs[k];
73 out[base + dim_half + k] = angle.sin();
74 out[base + dim_half + dim_qtr + k] = angle.cos();
75 }
76 }
77 }
78
79 Ok(out)
80}
81
82#[derive(Debug, Clone)]
89pub struct LearnablePosEmbed {
90 pub table: Vec<f32>,
92 pub n_positions: usize,
94 pub embed_dim: usize,
96}
97
98impl LearnablePosEmbed {
99 pub fn new(n_positions: usize, embed_dim: usize, rng: &mut LcgRng) -> VisionResult<Self> {
101 if embed_dim == 0 {
102 return Err(VisionError::InvalidEmbedDim(embed_dim));
103 }
104 if n_positions == 0 {
105 return Err(VisionError::EmptyInput("n_positions"));
106 }
107 let mut table = vec![0.0f32; n_positions * embed_dim];
108 rng.fill_normal(&mut table);
109 let scale = 0.02;
110 for v in &mut table {
111 *v *= scale;
112 }
113 Ok(Self {
114 table,
115 n_positions,
116 embed_dim,
117 })
118 }
119
120 pub fn position_embedding(&self, i: usize) -> VisionResult<&[f32]> {
122 if i >= self.n_positions {
123 return Err(VisionError::DimensionMismatch {
124 expected: self.n_positions - 1,
125 got: i,
126 });
127 }
128 let start = i * self.embed_dim;
129 Ok(&self.table[start..start + self.embed_dim])
130 }
131}
132
133pub fn add_pos_embed(tokens: &mut [f32], pos_embed: &[f32], embed_dim: usize) -> VisionResult<()> {
142 if tokens.len() != pos_embed.len() {
143 return Err(VisionError::DimensionMismatch {
144 expected: tokens.len(),
145 got: pos_embed.len(),
146 });
147 }
148 if embed_dim == 0 || tokens.len() % embed_dim != 0 {
149 return Err(VisionError::InvalidEmbedDim(embed_dim));
150 }
151 for (t, p) in tokens.iter_mut().zip(pos_embed.iter()) {
152 *t += p;
153 }
154 Ok(())
155}
156
#[cfg(test)]
mod tests {
    //! Unit tests for the sin/cos table, the learnable table and the in-place
    //! addition helper.

    use super::*;
    use crate::handle::LcgRng;

    // ── pos_2d_sincos ───────────────────────────────────────────────────────

    /// Output holds one `dim`-wide row per grid cell.
    #[test]
    fn pos_2d_sincos_shape() {
        let table = pos_2d_sincos(4, 4, 64).expect("ok");
        assert_eq!(table.len(), 4 * 4 * 64);
    }

    /// No NaN or infinity anywhere in the table.
    #[test]
    fn pos_2d_sincos_finite() {
        let table = pos_2d_sincos(8, 8, 64).expect("ok");
        let all_finite = table.iter().all(|value| value.is_finite());
        assert!(all_finite, "non-finite pos embed");
    }

    /// Every entry is a sine or cosine, so it must lie in `[-1, 1]`.
    #[test]
    fn pos_2d_sincos_in_range() {
        let table = pos_2d_sincos(4, 4, 64).expect("ok");
        let unit_range = -1.0f32..=1.0;
        assert!(
            table.iter().all(|&value| unit_range.contains(&value)),
            "out of [-1,1]"
        );
    }

    /// A width that is not a multiple of four is rejected.
    #[test]
    fn pos_2d_sincos_invalid_dim_not_div4() {
        let result = pos_2d_sincos(4, 4, 6);
        assert!(matches!(result, Err(VisionError::InvalidEmbedDim(6))));
    }

    /// A zero grid extent is rejected.
    #[test]
    fn pos_2d_sincos_invalid_grid_zero() {
        let result = pos_2d_sincos(0, 4, 64);
        assert!(matches!(result, Err(VisionError::InvalidImageSize { .. })));
    }

    /// Rows for cells (0, 0) and (0, 1) must not coincide.
    #[test]
    fn pos_2d_sincos_distinct_positions() {
        let embed_dim = 64;
        let table = pos_2d_sincos(4, 4, 64).expect("ok");
        let (first, second) = (&table[..embed_dim], &table[embed_dim..2 * embed_dim]);
        let diff: f32 = first
            .iter()
            .zip(second)
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-3,
            "adjacent positions should differ; total diff={diff}"
        );
    }

    /// With `dim = 4` the first column of row `h` is `sin(h)`.
    #[test]
    fn pos_2d_sincos_periodicity_check() {
        let table = pos_2d_sincos(4, 1, 4).expect("ok");
        let row0_sin = table[0];
        let row1_sin = table[4];
        assert!((row0_sin - 0.0_f32.sin()).abs() < 1e-6);
        assert!((row1_sin - 1.0_f32.sin()).abs() < 1e-6);
    }

    // ── LearnablePosEmbed ───────────────────────────────────────────────────

    /// The table holds `n_positions * embed_dim` values.
    #[test]
    fn learnable_pos_embed_shape() {
        let mut rng = LcgRng::new(1);
        let embed = LearnablePosEmbed::new(65, 64, &mut rng).expect("ok");
        assert_eq!(embed.table.len(), 65 * 64);
    }

    /// Initialisation never produces NaN or infinity.
    #[test]
    fn learnable_pos_embed_finite() {
        let mut rng = LcgRng::new(2);
        let embed = LearnablePosEmbed::new(17, 32, &mut rng).expect("ok");
        assert!(embed.table.iter().all(|value| value.is_finite()));
    }

    /// `position_embedding(i)` returns exactly row `i` of the table.
    #[test]
    fn learnable_pos_embed_access() {
        let mut rng = LcgRng::new(3);
        let embed = LearnablePosEmbed::new(8, 16, &mut rng).expect("ok");
        let row = embed.position_embedding(3).expect("ok");
        assert_eq!(row.len(), 16);
        assert_eq!(row, &embed.table[3 * 16..4 * 16]);
    }

    /// An index equal to `n_positions` is out of range.
    #[test]
    fn learnable_pos_embed_out_of_bounds_errors() {
        let mut rng = LcgRng::new(4);
        let embed = LearnablePosEmbed::new(8, 16, &mut rng).expect("ok");
        let result = embed.position_embedding(8);
        assert!(matches!(result, Err(VisionError::DimensionMismatch { .. })));
    }

    // ── add_pos_embed ───────────────────────────────────────────────────────

    /// The addition mutates `tokens` in place.
    #[test]
    fn add_pos_embed_in_place() {
        let mut tokens = vec![1.0f32; 4 * 8];
        let offsets = vec![0.5f32; 4 * 8];
        add_pos_embed(&mut tokens, &offsets, 8).expect("ok");
        assert!(tokens.iter().all(|&value| (value - 1.5).abs() < 1e-6));
    }

    /// Slices of different lengths are rejected.
    #[test]
    fn add_pos_embed_shape_mismatch_errors() {
        let mut tokens = vec![1.0f32; 4 * 8];
        let offsets = vec![0.5f32; 3 * 8];
        let result = add_pos_embed(&mut tokens, &offsets, 8);
        assert!(matches!(result, Err(VisionError::DimensionMismatch { .. })));
    }
}