// optical_embeddings/vision/sam.rs

//! SAM (Segment Anything Model) vision encoder used by Optical Embeddings.
//!
//! This module implements the SAM-based encoder from the DeepSeek-OCR paper, which uses
//! window attention and convolutional compression to efficiently encode images into
//! compact vision tokens.
//!
//! # Architecture
//!
//! The SAM encoder consists of:
//! - **Patch Embedding**: Converts input images to 16×16 patches
//! - **Transformer Blocks**: Process patches with window/global attention
//! - **Neck**: Projects features to output channels (typically 256)
//! - **Compressor**: 16× spatial reduction via two stride-2 conv layers
//!
//! # Example
//!
//! ```ignore
//! let config = SamConfig::base();
//! let encoder = SamEncoder::new(&config, &device);
//! let features = encoder.forward(image_tensor);
//! // Output: [batch, 1024, H/16, W/16] compressed features
//! ```
use burn::nn::{
    conv::{Conv2d, Conv2dConfig},
    LayerNorm, LayerNormConfig, Linear, LinearConfig, PaddingConfig2d,
};
use burn::prelude::*;
use burn::tensor::{backend::Backend, Distribution, Tensor};
use log::{debug, info, trace};

use super::attention::WindowAttention;
use crate::config::SamConfig;
34
/// Patch embedding layer that converts images to patch tokens.
///
/// Implements a convolutional layer with kernel and stride equal to patch size,
/// effectively partitioning the image into non-overlapping patches.
#[derive(Module, Debug)]
struct PatchEmbed<B: Backend> {
    /// Projection convolution; kernel == stride == patch size, so each output
    /// spatial position corresponds to exactly one non-overlapping patch.
    proj: Conv2d<B>,
    /// Patch size retained for reference; unused after construction
    /// (underscore prefix silences dead-code warnings).
    _patch: usize,
    /// Embedding dimension retained for reference; unused after construction.
    _embed: usize,
}
45
46impl<B: Backend> PatchEmbed<B> {
47    /// Creates a new patch embedding layer.
48    ///
49    /// # Arguments
50    /// * `patch_size` - Size of each patch (typically 16)
51    /// * `in_chans` - Number of input channels (3 for RGB)
52    /// * `embed_dim` - Output embedding dimension
53    /// * `device` - Device to create the layer on
54    fn new(patch_size: usize, in_chans: usize, embed_dim: usize, device: &B::Device) -> Self {
55        debug!(
56            "Creating PatchEmbed: patch_size={}, in_chans={}, embed_dim={}",
57            patch_size, in_chans, embed_dim
58        );
59        let proj = Conv2dConfig::new([in_chans, embed_dim], [patch_size, patch_size])
60            .with_stride([patch_size, patch_size])
61            .with_bias(true)
62            .init(device);
63        Self {
64            proj,
65            _patch: patch_size,
66            _embed: embed_dim,
67        }
68    }
69
70    /// Forward pass: converts image to patch embeddings.
71    ///
72    /// # Shape
73    /// - Input: `[B, C, H, W]`
74    /// - Output: `[B, embed_dim, H/patch_size, W/patch_size]`
75    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
76        trace!("PatchEmbed input shape: {:?}", x.dims());
77        let out = self.proj.forward(x);
78        trace!("PatchEmbed output shape: {:?}", out.dims());
79        out
80    }
81}
82
/// Multi-Layer Perceptron (feedforward) block with GELU activation.
///
/// Standard two-layer MLP used in transformer architectures.
#[derive(Module, Debug)]
struct MlpBlock<B: Backend> {
    /// Expansion layer: `dim -> hidden`.
    fc1: Linear<B>,
    /// Projection layer back down: `hidden -> dim`.
    fc2: Linear<B>,
}
91
92impl<B: Backend> MlpBlock<B> {
93    /// Creates a new MLP block.
94    ///
95    /// # Arguments
96    /// * `dim` - Input/output dimension
97    /// * `hidden` - Hidden layer dimension (typically 4× the input dim)
98    /// * `device` - Device to create the layer on
99    fn new(dim: usize, hidden: usize, device: &B::Device) -> Self {
100        debug!("Creating MlpBlock: dim={}, hidden={}", dim, hidden);
101        let fc1 = LinearConfig::new(dim, hidden).with_bias(true).init(device);
102        let fc2 = LinearConfig::new(hidden, dim).with_bias(true).init(device);
103        Self { fc1, fc2 }
104    }
105
106    /// Forward pass with GELU activation.
107    ///
108    /// # Shape
109    /// - Input: `[B, N, dim]`
110    /// - Output: `[B, N, dim]`
111    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
112        trace!("MlpBlock input shape: {:?}", x.dims());
113        let x = self.fc1.forward(x);
114        let x = burn::tensor::activation::gelu(x);
115        self.fc2.forward(x)
116    }
117}
118
/// SAM transformer block with window or global attention.
///
/// Implements a standard transformer block with:
/// - Layer normalization
/// - Window/global multi-head attention
/// - MLP with residual connections
#[derive(Module, Debug)]
struct SamBlock<B: Backend> {
    /// Pre-norm for the attention sub-layer.
    norm1: LayerNorm<B>,
    /// Multi-head attention (windowed when `window_size > 0`, global otherwise).
    attn: WindowAttention<B>,
    /// Pre-norm for the MLP sub-layer.
    norm2: LayerNorm<B>,
    /// Feedforward block applied after attention.
    mlp: MlpBlock<B>,
    /// Window size for attention; 0 selects global attention in `forward`.
    window_size: usize,
}
133
134impl<B: Backend> SamBlock<B> {
135    /// Creates a new SAM transformer block.
136    ///
137    /// # Arguments
138    /// * `dim` - Embedding dimension
139    /// * `heads` - Number of attention heads
140    /// * `mlp_ratio` - MLP hidden dimension ratio (typically 4.0)
141    /// * `window_size` - Window size for attention (0 = global attention)
142    /// * `use_rel_pos` - Whether to use relative position bias
143    /// * `input_size` - Spatial size of input features
144    /// * `device` - Device to create the layer on
145    fn new(
146        dim: usize,
147        heads: usize,
148        mlp_ratio: f32,
149        window_size: usize,
150        use_rel_pos: bool,
151        input_size: (usize, usize),
152        device: &B::Device,
153    ) -> Self {
154        debug!(
155            "Creating SamBlock: dim={}, heads={}, window_size={}, input_size={:?}",
156            dim, heads, window_size, input_size
157        );
158        let norm1 = LayerNormConfig::new(dim).init(device);
159        let norm2 = LayerNormConfig::new(dim).init(device);
160        let mlp = MlpBlock::new(dim, (dim as f32 * mlp_ratio) as usize, device);
161        let attn = WindowAttention::new(
162            dim,
163            heads,
164            use_rel_pos,
165            if window_size == 0 {
166                input_size
167            } else {
168                (window_size, window_size)
169            },
170            device,
171        );
172        Self {
173            norm1,
174            attn,
175            norm2,
176            mlp,
177            window_size,
178        }
179    }
180
181    /// Forward pass with pre-norm architecture and residual connections.
182    ///
183    /// # Shape
184    /// - Input: `[B, H, W, C]`
185    /// - Output: `[B, H, W, C]`
186    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
187        trace!("SamBlock input shape: {:?}", x.dims());
188        let [b, h, w, c] = {
189            let d = x.dims();
190            [d[0], d[1], d[2], d[3]]
191        };
192        let x1 = self
193            .norm1
194            .forward(x.clone().reshape([b * h * w, c]))
195            .reshape([b, h, w, c]);
196        let x_attn = if self.window_size > 0 {
197            self.attn.forward_windowed(x1, self.window_size)
198        } else {
199            self.attn.forward(x1)
200        };
201        let x = x + x_attn;
202        let x2 = self
203            .norm2
204            .forward(x.clone().reshape([b * h * w, c]))
205            .reshape([b, h, w, c]);
206        let x_mlp = self
207            .mlp
208            .forward(x2.reshape([b, h * w, c]))
209            .reshape([b, h, w, c]);
210        x + x_mlp
211    }
212}
213
/// SAM vision encoder for the DeepSeek-OCR DeepEncoder architecture.
///
/// Implements the complete SAM encoding pipeline:
/// 1. Patch embedding: Image → patches
/// 2. Position embedding: Add learned position encodings
/// 3. Transformer blocks: Window/global attention processing
/// 4. Neck: Project to output channels
/// 5. Compression: 16× spatial reduction via convolutional layers
///
/// # Compression Pipeline
///
/// For a 1024×1024 image with 16×16 patches:
/// - Input: `[B, 3, 1024, 1024]`
/// - After patch embed: `[B, 768, 64, 64]` (1024/16 = 64)
/// - After transformers: `[B, 768, 64, 64]`
/// - After neck: `[B, 256, 64, 64]`
/// - After comp1: `[B, 512, 32, 32]` (stride=2)
/// - After comp2: `[B, 1024, 16, 16]` (stride=2)
/// - Result: **16× compression** (64² → 16² tokens)
#[derive(Module, Debug)]
pub struct SamEncoder<B: Backend> {
    /// Converts the input image into a grid of patch embeddings.
    patch_embed: PatchEmbed<B>,
    /// Position embedding, shape `[1, grid, grid, embed_dim]`.
    /// NOTE(review): stored as a plain `Tensor`, not a `Param`, so in burn it
    /// is treated as a constant (excluded from optimization) — confirm this
    /// is intentional rather than a missed `Param` wrapper.
    pos_embed: Tensor<B, 4>,
    /// Transformer blocks; attention style per layer is chosen at build time.
    blocks: Vec<SamBlock<B>>,
    /// Neck stage 1: 1x1 conv projecting embed_dim -> out_chans.
    neck1: Conv2d<B>,
    /// Neck stage 2: 3x3 same-padded conv at out_chans.
    neck2: Conv2d<B>,
    /// Compression stage 1: 3x3 stride-2 conv, out_chans -> 512.
    comp1: Conv2d<B>,
    /// Compression stage 2: 3x3 stride-2 conv, 512 -> 1024.
    comp2: Conv2d<B>,
}
243
244impl<B: Backend> SamEncoder<B> {
245    /// Creates a new SAM encoder from configuration.
246    ///
247    /// # Arguments
248    /// * `config` - SAM configuration (see `SamConfig`)
249    /// * `device` - Device to create the model on
250    ///
251    /// # Example
252    /// ```ignore
253    /// let config = SamConfig::base();
254    /// let encoder = SamEncoder::new(&config, &device);
255    /// ```
256    pub fn new(config: &SamConfig, device: &B::Device) -> Self {
257        debug!(
258            "Creating PatchEmbed: patch_size={}, in_chans=3, embed_dim={}",
259            config.patch_size, config.embed_dim
260        );
261        let patch_embed = PatchEmbed::new(config.patch_size, 3, config.embed_dim, device);
262
263        let g = config.img_size / config.patch_size;
264        info!("Patch grid size: {}x{} = {} patches", g, g, g * g);
265
266        let pos_embed = Tensor::random(
267            [1, g, g, config.embed_dim],
268            Distribution::Normal(0.0, 0.02),
269            device,
270        );
271
272        let mut blocks = Vec::new();
273        for i in 0..config.depth {
274            if config.global_attn_indexes.contains(&i) {
275                debug!("Layer {}: using global attention", i);
276            } else {
277                debug!(
278                    "Layer {}: using window attention (size={})",
279                    i, config.window_size
280                );
281            }
282
283            let window = if config.global_attn_indexes.contains(&i) {
284                0
285            } else {
286                config.window_size
287            };
288            blocks.push(SamBlock::new(
289                config.embed_dim,
290                config.num_heads,
291                config.mlp_ratio,
292                window,
293                config.use_rel_pos,
294                (g, g),
295                device,
296            ));
297        }
298
299        debug!(
300            "Creating neck layers: embed_dim={} -> out_chans={}",
301            config.embed_dim, config.out_chans
302        );
303        let neck1 = Conv2dConfig::new([config.embed_dim, config.out_chans], [1, 1])
304            .with_bias(false)
305            .init(device);
306        let neck2 = Conv2dConfig::new([config.out_chans, config.out_chans], [3, 3])
307            .with_padding(PaddingConfig2d::Same)
308            .with_bias(false)
309            .init(device);
310
311        debug!(
312            "Creating compression layers: {} -> 512 -> 1024 (stride=2 each)",
313            config.out_chans
314        );
315
316        // Use Explicit padding to allow spatial reduction with stride=2
317        let comp1 = Conv2dConfig::new([config.out_chans, 512], [3, 3])
318            .with_stride([2, 2])
319            .with_padding(PaddingConfig2d::Explicit(1, 1))
320            .with_bias(false)
321            .init(device);
322
323        let comp2 = Conv2dConfig::new([512, 1024], [3, 3])
324            .with_stride([2, 2])
325            .with_padding(PaddingConfig2d::Explicit(1, 1))
326            .with_bias(false)
327            .init(device);
328
329        Self {
330            patch_embed,
331            pos_embed,
332            blocks,
333            neck1,
334            neck2,
335            comp1,
336            comp2,
337        }
338    }
339
340    /// Encodes an image into compressed vision features.
341    ///
342    /// # Arguments
343    /// * `x` - Input image tensor
344    ///
345    /// # Shape
346    /// - Input: `[B, 3, H, W]`
347    /// - Output: `[B, 1024, H/P/4, W/P/4]` where P is patch_size
348    ///
349    /// For default settings (patch_size=16, 16× compression):
350    /// - 1024×1024 image → `[B, 1024, 16, 16]` (256 tokens)
351    /// - 512×512 image → `[B, 1024, 8, 8]` (64 tokens)
352    ///
353    /// # Example
354    /// ```ignore
355    /// let image = Tensor::zeros(, &device);[1]
356    /// let features = encoder.forward(image);
357    /// // Shape:  = 256 vision tokens[2][1]
358    /// ```
359    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
360        info!("SamEncoder forward: input shape {:?}", x.dims());
361
362        let x = self.patch_embed.forward(x); // [B, C, H', W']
363        debug!("After patch_embed: {:?}", x.dims());
364
365        let x = x.swap_dims(1, 3).swap_dims(1, 2); // [B, H', W', C]
366        debug!("After transpose: {:?}", x.dims());
367
368        let x = x.clone() + self.interpolate_pos_embed(&x);
369        debug!("After pos_embed: {:?}", x.dims());
370
371        let mut x = x;
372        for (i, blk) in self.blocks.iter().enumerate() {
373            x = blk.forward(x);
374            trace!("After block {}: {:?}", i, x.dims());
375        }
376        info!(
377            "After all {} transformer blocks: {:?}",
378            self.blocks.len(),
379            x.dims()
380        );
381
382        let x = x.swap_dims(1, 3).swap_dims(2, 3); // [B, C, H', W']
383        debug!("After transpose back: {:?}", x.dims());
384
385        let x = self.neck1.forward(x);
386        debug!("After neck1: {:?}", x.dims());
387
388        let x = self.neck2.forward(x);
389        debug!("After neck2: {:?}", x.dims());
390
391        let x = self.comp1.forward(x);
392        info!("After comp1 (stride=2): {:?}", x.dims());
393
394        let out = self.comp2.forward(x);
395        info!("After comp2 (stride=2): {:?}", out.dims());
396        info!(
397            "Final compression: {}x{} = {} tokens",
398            out.dims()[2],
399            out.dims()[3],
400            out.dims()[2] * out.dims()[3]
401        );
402
403        out // [B, 1024, H'', W'']
404    }
405
406    /// Interpolates position embeddings to match input size.
407    ///
408    /// Currently returns zeros; full implementation would use bicubic interpolation
409    /// for dynamic resolution support.
410    fn interpolate_pos_embed(&self, x: &Tensor<B, 4>) -> Tensor<B, 4> {
411        let d = x.dims();
412        Tensor::zeros([d[0], d[1], d[2], d[3]], &x.device())
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    /// CPU test backend.
    type TB = NdArray<f32>;

    /// End-to-end shape check: a 1024x1024 RGB image must produce
    /// 1024-channel compressed features.
    #[test]
    fn shape_sanity() {
        let dev = Default::default();
        // Struct-update syntax instead of mutate-after-default
        // (avoids clippy::field_reassign_with_default).
        let cfg = SamConfig {
            img_size: 1024,
            ..SamConfig::default()
        };
        let enc = SamEncoder::<TB>::new(&cfg, &dev);
        let out = enc.forward(Tensor::<TB, 4>::zeros([1, 3, 1024, 1024], &dev));
        assert_eq!(out.dims()[1], 1024);
    }
}