use burn::nn::{
    conv::{Conv2d, Conv2dConfig},
    LayerNorm, LayerNormConfig, Linear, LinearConfig, PaddingConfig2d,
};
use burn::prelude::*;
use burn::tensor::{backend::Backend, Distribution, Tensor};
use log::{debug, info, trace, warn};

use super::attention::WindowAttention;
use crate::config::SamConfig;
34
/// Patch embedding: splits an image into non-overlapping patches and projects
/// each patch to `embed_dim` channels with a single strided convolution
/// (kernel size == stride == patch size).
#[derive(Module, Debug)]
struct PatchEmbed<B: Backend> {
    // Strided conv that performs the patch projection.
    proj: Conv2d<B>,
    // Patch size kept for reference; not read at runtime (hence the `_` prefix).
    _patch: usize,
    // Embedding dimension kept for reference; not read at runtime.
    _embed: usize,
}
45
46impl<B: Backend> PatchEmbed<B> {
47 fn new(patch_size: usize, in_chans: usize, embed_dim: usize, device: &B::Device) -> Self {
55 debug!(
56 "Creating PatchEmbed: patch_size={}, in_chans={}, embed_dim={}",
57 patch_size, in_chans, embed_dim
58 );
59 let proj = Conv2dConfig::new([in_chans, embed_dim], [patch_size, patch_size])
60 .with_stride([patch_size, patch_size])
61 .with_bias(true)
62 .init(device);
63 Self {
64 proj,
65 _patch: patch_size,
66 _embed: embed_dim,
67 }
68 }
69
70 fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
76 trace!("PatchEmbed input shape: {:?}", x.dims());
77 let out = self.proj.forward(x);
78 trace!("PatchEmbed output shape: {:?}", out.dims());
79 out
80 }
81}
82
/// Two-layer feed-forward block (`dim -> hidden -> dim`) used inside each
/// transformer block; GELU activation is applied between the layers in
/// `forward`.
#[derive(Module, Debug)]
struct MlpBlock<B: Backend> {
    // Expansion layer: dim -> hidden.
    fc1: Linear<B>,
    // Projection layer: hidden -> dim.
    fc2: Linear<B>,
}
91
92impl<B: Backend> MlpBlock<B> {
93 fn new(dim: usize, hidden: usize, device: &B::Device) -> Self {
100 debug!("Creating MlpBlock: dim={}, hidden={}", dim, hidden);
101 let fc1 = LinearConfig::new(dim, hidden).with_bias(true).init(device);
102 let fc2 = LinearConfig::new(hidden, dim).with_bias(true).init(device);
103 Self { fc1, fc2 }
104 }
105
106 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
112 trace!("MlpBlock input shape: {:?}", x.dims());
113 let x = self.fc1.forward(x);
114 let x = burn::tensor::activation::gelu(x);
115 self.fc2.forward(x)
116 }
117}
118
/// One pre-norm transformer block of the encoder: windowed (or global)
/// attention followed by an MLP, each with a residual connection.
#[derive(Module, Debug)]
struct SamBlock<B: Backend> {
    // LayerNorm applied before the attention branch.
    norm1: LayerNorm<B>,
    // Window (or global, when window_size == 0) self-attention.
    attn: WindowAttention<B>,
    // LayerNorm applied before the MLP branch.
    norm2: LayerNorm<B>,
    // Feed-forward sub-block.
    mlp: MlpBlock<B>,
    // 0 selects global attention in `forward`; otherwise the square window edge.
    window_size: usize,
}
133
134impl<B: Backend> SamBlock<B> {
135 fn new(
146 dim: usize,
147 heads: usize,
148 mlp_ratio: f32,
149 window_size: usize,
150 use_rel_pos: bool,
151 input_size: (usize, usize),
152 device: &B::Device,
153 ) -> Self {
154 debug!(
155 "Creating SamBlock: dim={}, heads={}, window_size={}, input_size={:?}",
156 dim, heads, window_size, input_size
157 );
158 let norm1 = LayerNormConfig::new(dim).init(device);
159 let norm2 = LayerNormConfig::new(dim).init(device);
160 let mlp = MlpBlock::new(dim, (dim as f32 * mlp_ratio) as usize, device);
161 let attn = WindowAttention::new(
162 dim,
163 heads,
164 use_rel_pos,
165 if window_size == 0 {
166 input_size
167 } else {
168 (window_size, window_size)
169 },
170 device,
171 );
172 Self {
173 norm1,
174 attn,
175 norm2,
176 mlp,
177 window_size,
178 }
179 }
180
181 fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
187 trace!("SamBlock input shape: {:?}", x.dims());
188 let [b, h, w, c] = {
189 let d = x.dims();
190 [d[0], d[1], d[2], d[3]]
191 };
192 let x1 = self
193 .norm1
194 .forward(x.clone().reshape([b * h * w, c]))
195 .reshape([b, h, w, c]);
196 let x_attn = if self.window_size > 0 {
197 self.attn.forward_windowed(x1, self.window_size)
198 } else {
199 self.attn.forward(x1)
200 };
201 let x = x + x_attn;
202 let x2 = self
203 .norm2
204 .forward(x.clone().reshape([b * h * w, c]))
205 .reshape([b, h, w, c]);
206 let x_mlp = self
207 .mlp
208 .forward(x2.reshape([b, h * w, c]))
209 .reshape([b, h, w, c]);
210 x + x_mlp
211 }
212}
213
/// SAM-style ViT image encoder: patch embedding, learned positional
/// embedding, a stack of transformer blocks, a two-conv neck, and two
/// stride-2 "compression" convs that further downsample the feature map.
#[derive(Module, Debug)]
pub struct SamEncoder<B: Backend> {
    // Strided-conv patch projection (image -> patch grid).
    patch_embed: PatchEmbed<B>,
    // Learned positional embedding, shape [1, grid, grid, embed_dim].
    pos_embed: Tensor<B, 4>,
    // Transformer blocks; windowed or global attention per layer.
    blocks: Vec<SamBlock<B>>,
    // Neck: 1x1 channel projection to out_chans.
    neck1: Conv2d<B>,
    // Neck: 3x3 "same"-padded conv at out_chans.
    neck2: Conv2d<B>,
    // Compression: out_chans -> 512, stride 2.
    comp1: Conv2d<B>,
    // Compression: 512 -> 1024, stride 2.
    comp2: Conv2d<B>,
}
243
244impl<B: Backend> SamEncoder<B> {
245 pub fn new(config: &SamConfig, device: &B::Device) -> Self {
257 debug!(
258 "Creating PatchEmbed: patch_size={}, in_chans=3, embed_dim={}",
259 config.patch_size, config.embed_dim
260 );
261 let patch_embed = PatchEmbed::new(config.patch_size, 3, config.embed_dim, device);
262
263 let g = config.img_size / config.patch_size;
264 info!("Patch grid size: {}x{} = {} patches", g, g, g * g);
265
266 let pos_embed = Tensor::random(
267 [1, g, g, config.embed_dim],
268 Distribution::Normal(0.0, 0.02),
269 device,
270 );
271
272 let mut blocks = Vec::new();
273 for i in 0..config.depth {
274 if config.global_attn_indexes.contains(&i) {
275 debug!("Layer {}: using global attention", i);
276 } else {
277 debug!(
278 "Layer {}: using window attention (size={})",
279 i, config.window_size
280 );
281 }
282
283 let window = if config.global_attn_indexes.contains(&i) {
284 0
285 } else {
286 config.window_size
287 };
288 blocks.push(SamBlock::new(
289 config.embed_dim,
290 config.num_heads,
291 config.mlp_ratio,
292 window,
293 config.use_rel_pos,
294 (g, g),
295 device,
296 ));
297 }
298
299 debug!(
300 "Creating neck layers: embed_dim={} -> out_chans={}",
301 config.embed_dim, config.out_chans
302 );
303 let neck1 = Conv2dConfig::new([config.embed_dim, config.out_chans], [1, 1])
304 .with_bias(false)
305 .init(device);
306 let neck2 = Conv2dConfig::new([config.out_chans, config.out_chans], [3, 3])
307 .with_padding(PaddingConfig2d::Same)
308 .with_bias(false)
309 .init(device);
310
311 debug!(
312 "Creating compression layers: {} -> 512 -> 1024 (stride=2 each)",
313 config.out_chans
314 );
315
316 let comp1 = Conv2dConfig::new([config.out_chans, 512], [3, 3])
318 .with_stride([2, 2])
319 .with_padding(PaddingConfig2d::Explicit(1, 1))
320 .with_bias(false)
321 .init(device);
322
323 let comp2 = Conv2dConfig::new([512, 1024], [3, 3])
324 .with_stride([2, 2])
325 .with_padding(PaddingConfig2d::Explicit(1, 1))
326 .with_bias(false)
327 .init(device);
328
329 Self {
330 patch_embed,
331 pos_embed,
332 blocks,
333 neck1,
334 neck2,
335 comp1,
336 comp2,
337 }
338 }
339
340 pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
360 info!("SamEncoder forward: input shape {:?}", x.dims());
361
362 let x = self.patch_embed.forward(x); debug!("After patch_embed: {:?}", x.dims());
364
365 let x = x.swap_dims(1, 3).swap_dims(1, 2); debug!("After transpose: {:?}", x.dims());
367
368 let x = x.clone() + self.interpolate_pos_embed(&x);
369 debug!("After pos_embed: {:?}", x.dims());
370
371 let mut x = x;
372 for (i, blk) in self.blocks.iter().enumerate() {
373 x = blk.forward(x);
374 trace!("After block {}: {:?}", i, x.dims());
375 }
376 info!(
377 "After all {} transformer blocks: {:?}",
378 self.blocks.len(),
379 x.dims()
380 );
381
382 let x = x.swap_dims(1, 3).swap_dims(2, 3); debug!("After transpose back: {:?}", x.dims());
384
385 let x = self.neck1.forward(x);
386 debug!("After neck1: {:?}", x.dims());
387
388 let x = self.neck2.forward(x);
389 debug!("After neck2: {:?}", x.dims());
390
391 let x = self.comp1.forward(x);
392 info!("After comp1 (stride=2): {:?}", x.dims());
393
394 let out = self.comp2.forward(x);
395 info!("After comp2 (stride=2): {:?}", out.dims());
396 info!(
397 "Final compression: {}x{} = {} tokens",
398 out.dims()[2],
399 out.dims()[3],
400 out.dims()[2] * out.dims()[3]
401 );
402
403 out }
405
406 fn interpolate_pos_embed(&self, x: &Tensor<B, 4>) -> Tensor<B, 4> {
411 let d = x.dims();
412 Tensor::zeros([d[0], d[1], d[2], d[3]], &x.device())
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    type TB = NdArray<f32>;

    /// Smoke test: a 1024x1024 RGB input must produce 1024 output channels
    /// after the two compression convs.
    #[test]
    fn shape_sanity() {
        let device = Default::default();
        let mut config = SamConfig::default();
        config.img_size = 1024;

        let encoder = SamEncoder::<TB>::new(&config, &device);
        let input = Tensor::<TB, 4>::zeros([1, 3, 1024, 1024], &device);
        let output = encoder.forward(input);

        assert_eq!(output.dims()[1], 1024);
    }
}