Skip to main content

rlx_sam2/
memory_encoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM 2 memory encoder — host-side.
17//!
18//! Mirrors `sam2/modeling/memory_encoder.py` exactly:
19//!
20//! ```text
21//!   MemoryEncoder(pix_feat, masks):
22//!     masks = sigmoid(masks) if not skip_mask_sigmoid
23//!     masks = MaskDownSampler(masks)        # 1×1024×1024 → 256×64×64
24//!     pix_feat = pix_feat_proj(pix_feat)    # 1×1 conv 256→256
25//!     x = pix_feat + masks
26//!     x = Fuser(x)                          # 2 × CXBlock
27//!     x = out_proj(x)                       # 1×1 conv 256→out_dim (64)
28//!     pos = PositionEmbeddingSine(x)        # sinusoidal 2-D PE
29//!     return (x, pos)
30//! ```
31//!
32//! `MaskDownSampler` is a stack of `log_stride(total_stride)` blocks of
33//! `Conv2d(k,s,p) → LayerNorm2d → GELU` that grow the channel dim by
34//! `stride²` each step (1 → 4 → 16 → 64 → 256 for the default
35//! stride=2, total_stride=16). A final 1×1 conv projects to
36//! `embed_dim=in_dim=256`.
37//!
38//! `Fuser` is a ConvNeXt-style stack — depthwise Conv k=7 → LN →
39//! pointwise Linear (4× expansion) → GELU → pointwise Linear → optional
40//! per-channel `gamma` (LayerScale) → residual.
41
42use super::config::{SAM2_IMG_SIZE, Sam2MemoryEncoderConfig};
43use super::memory_mask_ir::{
44    Sam2MemoryConv1x1Compiled, Sam2MemoryFuserCompiled, Sam2MemoryMaskDownCompiled,
45    Sam2MemoryPrefixCompiled,
46};
47use super::prompt_encoder::{conv2d_1x1, gelu_erf_inplace, layernorm2d_nchw, sigmoid_inplace};
48use anyhow::{Result, ensure};
49use rlx_core::weight_map::WeightMap;
50use rlx_runtime::Device;
51use std::f32::consts::PI;
52
53// ─── Weight structs ─────────────────────────────────────────────────
54
55pub struct Sam2MaskDownSamplerWeights {
56    /// Per-level `(conv_w, conv_b, ln_g, ln_b)` for the down-sampling
57    /// Conv → LN2d → GELU pattern.
58    pub levels: Vec<DownSampleLevel>,
59    /// Final 1×1 conv `[embed_dim, last_chans]`.
60    pub final_conv_w: Vec<f32>,
61    pub final_conv_b: Vec<f32>,
62    pub kernel: usize,
63    pub stride: usize,
64    pub padding: usize,
65    pub embed_dim: usize,
66}
67
/// Parameters of one `Conv2d(k, s, p) → LayerNorm2d` down-sampling step.
pub struct DownSampleLevel {
    /// Conv weight, `[out_c, in_c, k, k]`.
    pub conv_w: Vec<f32>,
    /// Conv bias, `[out_c]`.
    pub conv_b: Vec<f32>,
    /// LayerNorm2d scale, `[out_c]`.
    pub ln_g: Vec<f32>,
    /// LayerNorm2d shift, `[out_c]`.
    pub ln_b: Vec<f32>,
    /// Channels entering this step.
    pub in_c: usize,
    /// Channels leaving this step.
    pub out_c: usize,
}
76
/// Parameters of one ConvNeXt-style `CXBlock` inside the fuser.
pub struct Sam2CXBlockWeights {
    /// Spatial conv weight: `[dim, 1, k, k]` for the depthwise layout
    /// (`extract_fuser` also accepts a dense `[dim, dim, k, k]` weight when
    /// `fuser_use_dwconv` is off).
    pub dw_conv_w: Vec<f32>,
    /// Spatial conv bias, `[dim]`.
    pub dw_conv_b: Vec<f32>,
    /// LayerNorm2d scale.
    pub ln_g: Vec<f32>,
    /// LayerNorm2d shift.
    pub ln_b: Vec<f32>,
    /// Expansion Linear weight, `[4·dim, dim]`.
    pub pw1_w: Vec<f32>,
    /// Expansion Linear bias, `[4·dim]`.
    pub pw1_b: Vec<f32>,
    /// Contraction Linear weight, `[dim, 4·dim]`.
    pub pw2_w: Vec<f32>,
    /// Contraction Linear bias, `[dim]`.
    pub pw2_b: Vec<f32>,
    /// Per-channel LayerScale gain. The reference only creates it when
    /// `layer_scale_init_value > 0`, hence the `Option`.
    pub gamma: Option<Vec<f32>>,
    /// Channel count of the block's input and output.
    pub dim: usize,
    /// Spatial conv kernel size.
    pub kernel: usize,
    /// Spatial conv zero padding per side.
    pub padding: usize,
}
93
94pub struct Sam2FuserWeights {
95    /// Optional input-projection 1×1 conv (rarely used).
96    pub input_proj_w: Option<Vec<f32>>,
97    pub input_proj_b: Option<Vec<f32>>,
98    pub layers: Vec<Sam2CXBlockWeights>,
99    pub dim: usize,
100}
101
102pub struct Sam2MemoryEncoderWeights {
103    pub mask_downsampler: Sam2MaskDownSamplerWeights,
104    pub prefix: Option<Sam2MemoryPrefixCompiled>,
105    pub mask_down: Option<Sam2MemoryMaskDownCompiled>,
106    pub pix_proj: Option<Sam2MemoryConv1x1Compiled>,
107    pub fuser_ir: Option<Sam2MemoryFuserCompiled>,
108    pub out_proj_ir: Option<Sam2MemoryConv1x1Compiled>,
109    pub pix_feat_proj_w: Vec<f32>, // [in_dim, in_dim, 1, 1]
110    pub pix_feat_proj_b: Vec<f32>,
111    pub fuser: Sam2FuserWeights,
112    /// `out_proj`: 1×1 conv `in_dim → out_dim`. None when in_dim == out_dim
113    /// (PyTorch `nn.Identity` in the reference).
114    pub out_proj_w: Option<Vec<f32>>,
115    pub out_proj_b: Option<Vec<f32>>,
116    pub in_dim: usize,
117    pub out_dim: usize,
118    pub pe_num_pos_feats: usize,
119    pub pe_temperature: f32,
120}
121
122// ─── Weight extraction ─────────────────────────────────────────────
123
124pub fn extract_memory_encoder_weights(
125    weights: &mut WeightMap,
126    cfg: &Sam2MemoryEncoderConfig,
127) -> Result<Sam2MemoryEncoderWeights> {
128    let mask_downsampler = extract_mask_downsampler(weights, cfg)?;
129
130    let (pix_feat_proj_w, sh) = weights.take("memory_encoder.pix_feat_proj.weight")?;
131    ensure!(
132        sh == vec![cfg.in_dim, cfg.in_dim, 1, 1],
133        "pix_feat_proj.weight shape {sh:?} not [{}, {}, 1, 1]",
134        cfg.in_dim,
135        cfg.in_dim
136    );
137    let (pix_feat_proj_b, _) = weights.take("memory_encoder.pix_feat_proj.bias")?;
138
139    let fuser = extract_fuser(weights, cfg)?;
140
141    let (out_proj_w, out_proj_b) = if cfg.in_dim == cfg.out_dim {
142        (None, None)
143    } else {
144        let (w, sh) = weights.take("memory_encoder.out_proj.weight")?;
145        ensure!(
146            sh == vec![cfg.out_dim, cfg.in_dim, 1, 1],
147            "out_proj.weight shape {sh:?} not [{}, {}, 1, 1]",
148            cfg.out_dim,
149            cfg.in_dim
150        );
151        let (b, _) = weights.take("memory_encoder.out_proj.bias")?;
152        (Some(w), Some(b))
153    };
154
155    Ok(Sam2MemoryEncoderWeights {
156        mask_downsampler,
157        prefix: None,
158        mask_down: None,
159        pix_proj: None,
160        fuser_ir: None,
161        out_proj_ir: None,
162        pix_feat_proj_w,
163        pix_feat_proj_b,
164        fuser,
165        out_proj_w,
166        out_proj_b,
167        in_dim: cfg.in_dim,
168        out_dim: cfg.out_dim,
169        pe_num_pos_feats: cfg.pe_num_pos_feats,
170        pe_temperature: cfg.pe_temperature,
171    })
172}
173
174fn extract_mask_downsampler(
175    weights: &mut WeightMap,
176    cfg: &Sam2MemoryEncoderConfig,
177) -> Result<Sam2MaskDownSamplerWeights> {
178    // num_layers = log_stride(total_stride). Reference asserts
179    // `stride ** num_layers == total_stride`.
180    let mut num_layers = 0;
181    let mut acc = 1usize;
182    while acc < cfg.mask_downsampler_total_stride {
183        acc *= cfg.mask_downsampler_stride;
184        num_layers += 1;
185    }
186    ensure!(
187        acc == cfg.mask_downsampler_total_stride,
188        "mask_downsampler total_stride {} must be a power of stride {}",
189        cfg.mask_downsampler_total_stride,
190        cfg.mask_downsampler_stride
191    );
192
193    let mut levels = Vec::with_capacity(num_layers);
194    let mut in_c = 1usize;
195    let stride2 = cfg.mask_downsampler_stride * cfg.mask_downsampler_stride;
196    // Reference's MaskDownSampler `encoder` is `nn.Sequential` of
197    // groups (Conv2d, LayerNorm2d, GELU) per level, plus a final
198    // 1×1 conv. Group index increment is 3 per level.
199    for li in 0..num_layers {
200        let out_c = in_c * stride2;
201        let conv_idx = li * 3;
202        let ln_idx = conv_idx + 1;
203        let (conv_w, sh) = weights.take(&format!(
204            "memory_encoder.mask_downsampler.encoder.{conv_idx}.weight"
205        ))?;
206        ensure!(
207            sh == vec![
208                out_c,
209                in_c,
210                cfg.mask_downsampler_kernel,
211                cfg.mask_downsampler_kernel
212            ],
213            "mask_downsampler conv {li} weight shape {sh:?} not [{out_c}, {in_c}, {}, {}]",
214            cfg.mask_downsampler_kernel,
215            cfg.mask_downsampler_kernel
216        );
217        let (conv_b, _) = weights.take(&format!(
218            "memory_encoder.mask_downsampler.encoder.{conv_idx}.bias"
219        ))?;
220        let (ln_g, _) = weights.take(&format!(
221            "memory_encoder.mask_downsampler.encoder.{ln_idx}.weight"
222        ))?;
223        let (ln_b, _) = weights.take(&format!(
224            "memory_encoder.mask_downsampler.encoder.{ln_idx}.bias"
225        ))?;
226        levels.push(DownSampleLevel {
227            conv_w,
228            conv_b,
229            ln_g,
230            ln_b,
231            in_c,
232            out_c,
233        });
234        in_c = out_c;
235    }
236    // Final 1×1 conv goes at index num_layers*3.
237    let final_idx = num_layers * 3;
238    let (final_conv_w, sh) = weights.take(&format!(
239        "memory_encoder.mask_downsampler.encoder.{final_idx}.weight"
240    ))?;
241    ensure!(
242        sh == vec![cfg.in_dim, in_c, 1, 1],
243        "mask_downsampler final conv weight shape {sh:?} not [{}, {in_c}, 1, 1]",
244        cfg.in_dim
245    );
246    let (final_conv_b, _) = weights.take(&format!(
247        "memory_encoder.mask_downsampler.encoder.{final_idx}.bias"
248    ))?;
249
250    Ok(Sam2MaskDownSamplerWeights {
251        levels,
252        final_conv_w,
253        final_conv_b,
254        kernel: cfg.mask_downsampler_kernel,
255        stride: cfg.mask_downsampler_stride,
256        padding: cfg.mask_downsampler_padding,
257        embed_dim: cfg.in_dim,
258    })
259}
260
261fn extract_fuser(
262    weights: &mut WeightMap,
263    cfg: &Sam2MemoryEncoderConfig,
264) -> Result<Sam2FuserWeights> {
265    let (input_proj_w, input_proj_b) = if cfg.fuser_input_projection {
266        let (w, sh) = weights.take("memory_encoder.fuser.proj.weight")?;
267        ensure!(
268            sh == vec![cfg.fuser_dim, cfg.fuser_dim, 1, 1],
269            "fuser.proj.weight shape {sh:?} not [{}, {}, 1, 1]",
270            cfg.fuser_dim,
271            cfg.fuser_dim
272        );
273        let (b, _) = weights.take("memory_encoder.fuser.proj.bias")?;
274        (Some(w), Some(b))
275    } else {
276        (None, None)
277    };
278
279    let mut layers = Vec::with_capacity(cfg.fuser_num_layers);
280    for i in 0..cfg.fuser_num_layers {
281        let p = format!("memory_encoder.fuser.layers.{i}");
282        let (dw_conv_w, sh) = weights.take(&format!("{p}.dwconv.weight"))?;
283        // Depthwise conv: groups=dim → weight shape [dim, 1, k, k].
284        let dim = cfg.fuser_dim;
285        let k = cfg.fuser_kernel;
286        if cfg.fuser_use_dwconv {
287            ensure!(
288                sh == vec![dim, 1, k, k],
289                "{p}.dwconv.weight shape {sh:?} not [{dim}, 1, {k}, {k}]"
290            );
291        } else {
292            ensure!(
293                sh == vec![dim, dim, k, k],
294                "{p}.dwconv.weight shape {sh:?} not [{dim}, {dim}, {k}, {k}]"
295            );
296        }
297        let (dw_conv_b, _) = weights.take(&format!("{p}.dwconv.bias"))?;
298        let (ln_g, _) = weights.take(&format!("{p}.norm.weight"))?;
299        let (ln_b, _) = weights.take(&format!("{p}.norm.bias"))?;
300        let (pw1_w, sh) = weights.take(&format!("{p}.pwconv1.weight"))?;
301        ensure!(
302            sh == vec![4 * dim, dim],
303            "{p}.pwconv1.weight shape {sh:?} not [{}, {dim}]",
304            4 * dim
305        );
306        let (pw1_b, _) = weights.take(&format!("{p}.pwconv1.bias"))?;
307        let (pw2_w, _) = weights.take(&format!("{p}.pwconv2.weight"))?;
308        let (pw2_b, _) = weights.take(&format!("{p}.pwconv2.bias"))?;
309        let gamma = if cfg.fuser_layer_scale_init_value > 0.0 {
310            let (g, _) = weights.take(&format!("{p}.gamma"))?;
311            Some(g)
312        } else {
313            None
314        };
315        layers.push(Sam2CXBlockWeights {
316            dw_conv_w,
317            dw_conv_b,
318            ln_g,
319            ln_b,
320            pw1_w,
321            pw1_b,
322            pw2_w,
323            pw2_b,
324            gamma,
325            dim,
326            kernel: k,
327            padding: cfg.fuser_padding,
328        });
329    }
330    Ok(Sam2FuserWeights {
331        input_proj_w,
332        input_proj_b,
333        layers,
334        dim: cfg.fuser_dim,
335    })
336}
337
338/// Compile memory-encoder IR subgraphs (mask down, pix 1×1, fuser, optional out 1×1).
339pub fn compile_memory_encoder_ir(
340    weights: &mut Sam2MemoryEncoderWeights,
341    mask_in_h: usize,
342    mask_in_w: usize,
343    feat_h: usize,
344    feat_w: usize,
345    device: Device,
346    profile: &rlx_flow::CompileProfile,
347) -> Result<()> {
348    weights.prefix = Some(Sam2MemoryPrefixCompiled::compile_with_profile(
349        &weights.mask_downsampler,
350        weights.in_dim,
351        mask_in_h,
352        mask_in_w,
353        feat_h,
354        feat_w,
355        &weights.pix_feat_proj_w,
356        &weights.pix_feat_proj_b,
357        device,
358        profile,
359    )?);
360    weights.fuser_ir = Some(Sam2MemoryFuserCompiled::compile_with_profile(
361        &weights.fuser,
362        feat_h,
363        feat_w,
364        device,
365        profile,
366    )?);
367    if let (Some(opw), Some(opb)) = (&weights.out_proj_w, &weights.out_proj_b) {
368        weights.out_proj_ir = Some(Sam2MemoryConv1x1Compiled::compile_with_profile(
369            weights.in_dim,
370            weights.out_dim,
371            feat_h,
372            feat_w,
373            opw,
374            opb,
375            device,
376            profile,
377        )?);
378    }
379    Ok(())
380}
381
382/// Back-compat alias for mask-downsampler-only compile.
383pub fn compile_memory_mask_ir(
384    weights: &mut Sam2MemoryEncoderWeights,
385    mask_in_h: usize,
386    mask_in_w: usize,
387    device: Device,
388) -> Result<()> {
389    compile_memory_encoder_ir(
390        weights,
391        mask_in_h,
392        mask_in_w,
393        mask_in_h / total_stride(&weights.mask_downsampler),
394        mask_in_w / total_stride(&weights.mask_downsampler),
395        device,
396        &rlx_flow::CompileProfile::sam2(),
397    )
398}
399
400// ─── Forward ────────────────────────────────────────────────────────
401
/// Result of [`memory_encoder_forward`].
pub struct Sam2MemoryEncoderOutput {
    /// Memory features, `[out_dim, h, w]` in CHW order (typically 64×64×64).
    pub features: Vec<f32>,
    /// Sinusoidal positional encoding for `features`,
    /// `[2·pe_num_pos_feats, h, w]`.
    pub pos: Vec<f32>,
    /// Grid height; equal to the `pix_feat` height passed in.
    pub h: usize,
    /// Grid width; equal to the `pix_feat` width passed in.
    pub w: usize,
}
410
411/// Run the SAM 2 memory encoder.
412///
413/// `pix_feat`: stride-16 features `[in_dim, h, w]` (typically 256×64×64
414/// from the FpnNeck level 2).
415/// `masks`: mask logits `[1, H_full, W_full]` (or sigmoid probs, with
416/// `skip_mask_sigmoid=true`). H_full = W_full = `SAM2_IMG_SIZE` (1024).
417/// After MaskDownSampler the masks are at stride `total_stride=16`,
418/// giving shape `[in_dim, h, w]` matching pix_feat.
419pub fn memory_encoder_forward(
420    w: &mut Sam2MemoryEncoderWeights,
421    pix_feat: &[f32],
422    masks: &[f32],
423    pix_h: usize,
424    pix_w: usize,
425    skip_mask_sigmoid: bool,
426) -> Result<Sam2MemoryEncoderOutput> {
427    ensure!(
428        pix_feat.len() == w.in_dim * pix_h * pix_w,
429        "pix_feat len {} ≠ in_dim·h·w ({}·{pix_h}·{pix_w})",
430        pix_feat.len(),
431        w.in_dim
432    );
433    let in_h = SAM2_IMG_SIZE;
434    let in_w = SAM2_IMG_SIZE;
435    ensure!(
436        masks.len() == in_h * in_w,
437        "masks len {} ≠ H·W ({in_h}·{in_w}); pass a full-resolution mask",
438        masks.len()
439    );
440
441    // 1) Sigmoid (optional).
442    let mut m: Vec<f32> = masks.to_vec();
443    if !skip_mask_sigmoid {
444        sigmoid_inplace(&mut m);
445    }
446
447    // 2–4) MaskDownSampler + pix_feat_proj + add (fused or split).
448    let x = if let Some(ref mut prefix) = w.prefix {
449        prefix.run(&m, pix_feat)?
450    } else {
451        let m_down = if let Some(ref mut md) = w.mask_down {
452            md.run(&m)?
453        } else {
454            mask_downsampler_forward(&w.mask_downsampler, &m, in_h, in_w)?
455        };
456        let down_h = in_h / total_stride(&w.mask_downsampler);
457        let down_w = in_w / total_stride(&w.mask_downsampler);
458        ensure!(
459            down_h == pix_h && down_w == pix_w,
460            "mask after downsampling ({down_h}×{down_w}) doesn't match pix_feat ({pix_h}×{pix_w})"
461        );
462        let mut x = if let Some(ref mut p) = w.pix_proj {
463            p.run(pix_feat)?
464        } else {
465            conv2d_1x1(
466                pix_feat,
467                w.in_dim,
468                w.in_dim,
469                pix_h,
470                pix_w,
471                &w.pix_feat_proj_w,
472                &w.pix_feat_proj_b,
473            )
474        };
475        for i in 0..x.len() {
476            x[i] += m_down[i];
477        }
478        x
479    };
480
481    // 5) Fuser.
482    let x = if let Some(ref mut f) = w.fuser_ir {
483        f.run(&x)?
484    } else {
485        fuser_forward(&w.fuser, x, pix_h, pix_w)
486    };
487
488    // 6) Optional out_proj.
489    let features = if let Some(ref mut o) = w.out_proj_ir {
490        o.run(&x)?
491    } else if let (Some(opw), Some(opb)) = (&w.out_proj_w, &w.out_proj_b) {
492        conv2d_1x1(&x, w.in_dim, w.out_dim, pix_h, pix_w, opw, opb)
493    } else {
494        x
495    };
496
497    // 7) Sinusoidal PE.
498    let pos = sinusoidal_pos_2d(2 * w.pe_num_pos_feats, pix_h, pix_w, w.pe_temperature);
499
500    Ok(Sam2MemoryEncoderOutput {
501        features,
502        pos,
503        h: pix_h,
504        w: pix_w,
505    })
506}
507
508fn total_stride(d: &Sam2MaskDownSamplerWeights) -> usize {
509    d.stride.pow(d.levels.len() as u32)
510}
511
512/// MaskDownSampler forward. `in`: `[1, H, W]`. Repeats
513/// Conv(k,s,p) → LN2d → GELU `num_levels` times, then a final 1×1 conv
514/// to `embed_dim`.
515fn mask_downsampler_forward(
516    w: &Sam2MaskDownSamplerWeights,
517    input: &[f32],
518    h: usize,
519    ww: usize,
520) -> Result<Vec<f32>> {
521    let mut cur = input.to_vec();
522    let mut cur_c = 1usize;
523    let mut cur_h = h;
524    let mut cur_w = ww;
525    for level in &w.levels {
526        let out_h = (cur_h + 2 * w.padding - w.kernel) / w.stride + 1;
527        let out_w = (cur_w + 2 * w.padding - w.kernel) / w.stride + 1;
528        cur = conv2d_general(
529            &cur,
530            cur_c,
531            level.out_c,
532            cur_h,
533            cur_w,
534            w.kernel,
535            w.stride,
536            w.padding,
537            &level.conv_w,
538            &level.conv_b,
539        );
540        cur_c = level.out_c;
541        cur_h = out_h;
542        cur_w = out_w;
543        layernorm2d_nchw(
544            &mut cur,
545            cur_c,
546            cur_h,
547            cur_w,
548            &level.ln_g,
549            &level.ln_b,
550            1e-6,
551        );
552        gelu_erf_inplace(&mut cur);
553    }
554    // Final 1×1 conv.
555    let out = conv2d_1x1(
556        &cur,
557        cur_c,
558        w.embed_dim,
559        cur_h,
560        cur_w,
561        &w.final_conv_w,
562        &w.final_conv_b,
563    );
564    Ok(out)
565}
566
567fn fuser_forward(w: &Sam2FuserWeights, mut x: Vec<f32>, h: usize, ww: usize) -> Vec<f32> {
568    if let (Some(pw), Some(pb)) = (&w.input_proj_w, &w.input_proj_b) {
569        x = conv2d_1x1(&x, w.dim, w.dim, h, ww, pw, pb);
570    }
571    for layer in &w.layers {
572        x = cx_block_forward(layer, x, h, ww);
573    }
574    x
575}
576
577fn cx_block_forward(w: &Sam2CXBlockWeights, x: Vec<f32>, h: usize, ww: usize) -> Vec<f32> {
578    let dim = w.dim;
579    // Depthwise conv k×k pad=padding.
580    let mut y = conv2d_depthwise_k_pad(
581        &x,
582        dim,
583        h,
584        ww,
585        w.kernel,
586        w.padding,
587        &w.dw_conv_w,
588        &w.dw_conv_b,
589    );
590    // LN over channel dim (NCHW per spatial pos).
591    layernorm2d_nchw(&mut y, dim, h, ww, &w.ln_g, &w.ln_b, 1e-6);
592    // Permute NCHW → NHWC, apply pointwise Linear(dim → 4·dim) → GELU
593    // → Linear(4·dim → dim), permute back.
594    let mut nhwc = vec![0f32; h * ww * dim];
595    for c in 0..dim {
596        for yy in 0..h {
597            for xx in 0..ww {
598                nhwc[(yy * ww + xx) * dim + c] = y[c * h * ww + yy * ww + xx];
599            }
600        }
601    }
602    let four_d = 4 * dim;
603    let mut up = vec![0f32; h * ww * four_d];
604    for r in 0..h * ww {
605        for o in 0..four_d {
606            let mut acc = w.pw1_b[o];
607            for k in 0..dim {
608                acc += nhwc[r * dim + k] * w.pw1_w[o * dim + k];
609            }
610            up[r * four_d + o] = acc;
611        }
612    }
613    gelu_erf_inplace(&mut up);
614    let mut down = vec![0f32; h * ww * dim];
615    for r in 0..h * ww {
616        for o in 0..dim {
617            let mut acc = w.pw2_b[o];
618            for k in 0..four_d {
619                acc += up[r * four_d + k] * w.pw2_w[o * four_d + k];
620            }
621            down[r * dim + o] = acc;
622        }
623    }
624    if let Some(gamma) = &w.gamma {
625        for r in 0..h * ww {
626            for c in 0..dim {
627                down[r * dim + c] *= gamma[c];
628            }
629        }
630    }
631    // Permute NHWC → NCHW, add residual.
632    let mut out = x;
633    for c in 0..dim {
634        for yy in 0..h {
635            for xx in 0..ww {
636                out[c * h * ww + yy * ww + xx] += down[(yy * ww + xx) * dim + c];
637            }
638        }
639    }
640    out
641}
642
643// ─── Generic conv helpers ───────────────────────────────────────────
644
/// Dense 2-D convolution over a single NCHW image (no dilation, groups = 1).
///
/// Maps `input` laid out as `[in_c, h, w]` to `[out_c, out_h, out_w]`, with
/// `out_h = (h + 2·p − k) / s + 1` (likewise for `out_w`). Taps that land in
/// the zero padding are skipped rather than materialised.
///
/// * `weight` — `[out_c, in_c, k, k]`, row-major.
/// * `bias` — `[out_c]`.
fn conv2d_general(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    k: usize,
    s: usize,
    p: usize,
    weight: &[f32], // [out_c, in_c, k, k]
    bias: &[f32],   // [out_c]
) -> Vec<f32> {
    let out_h = (h + 2 * p - k) / s + 1;
    let out_w = (w + 2 * p - k) / s + 1;
    // Input coordinate for output index `o` and kernel tap `t`, or `None`
    // when the tap falls into the padding.
    let src_coord = |o: usize, t: usize, extent: usize| -> Option<usize> {
        (o * s + t).checked_sub(p).filter(|&i| i < extent)
    };

    let mut out = Vec::with_capacity(out_c * out_h * out_w);
    for oc in 0..out_c {
        let bias_v = bias[oc];
        for oy in 0..out_h {
            for ox in 0..out_w {
                let mut sum = bias_v;
                for ic in 0..in_c {
                    let plane_base = ic * h * w;
                    let kernel_base = (oc * in_c + ic) * k * k;
                    for ky in 0..k {
                        let Some(iy) = src_coord(oy, ky, h) else {
                            continue;
                        };
                        for kx in 0..k {
                            let Some(ix) = src_coord(ox, kx, w) else {
                                continue;
                            };
                            sum += input[plane_base + iy * w + ix]
                                * weight[kernel_base + ky * k + kx];
                        }
                    }
                }
                out.push(sum);
            }
        }
    }
    out
}
690
/// Depthwise (groups = `dim`) 2-D convolution, stride 1, symmetric zero padding `p`.
///
/// `input` is `[dim, h, w]`. `weight` is `[dim, 1, k, k]`, so each channel
/// owns a contiguous `k·k` filter. `bias` is `[dim]`.
///
/// The output always has the input's spatial size `[dim, h, w]`. It matches
/// PyTorch's output size only for "same" padding (`2·p = k − 1`).
fn conv2d_depthwise_k_pad(
    input: &[f32],
    dim: usize,
    h: usize,
    w: usize,
    k: usize,
    p: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let plane = h * w;
    let mut out = Vec::with_capacity(dim * plane);
    for ch in 0..dim {
        let start = bias[ch];
        let src = ch * plane;
        let filter = ch * k * k;
        for oy in 0..h {
            for ox in 0..w {
                let mut sum = start;
                for ky in 0..k {
                    let Some(iy) = (oy + ky).checked_sub(p).filter(|&v| v < h) else {
                        continue;
                    };
                    for kx in 0..k {
                        let Some(ix) = (ox + kx).checked_sub(p).filter(|&v| v < w) else {
                            continue;
                        };
                        sum += input[src + iy * w + ix] * weight[filter + ky * k + kx];
                    }
                }
                out.push(sum);
            }
        }
    }
    out
}
729
730/// Reference `PositionEmbeddingSine` forward — same code path as the
731/// FpnNeck PE but kept here so the memory-encoder output owns its PE
732/// generator with its own `temperature` config option.
733pub(super) fn sinusoidal_pos_2d(d_model: usize, h: usize, w: usize, temperature: f32) -> Vec<f32> {
734    let nf = d_model / 2;
735    let scale: f32 = 2.0 * PI;
736    let eps: f32 = 1e-6;
737    let mut out = vec![0f32; d_model * h * w];
738    let mut dim_t = vec![0f32; nf];
739    for i in 0..nf {
740        let exp = 2.0 * ((i / 2) as f32) / (nf as f32);
741        dim_t[i] = temperature.powf(exp);
742    }
743    for y in 0..h {
744        let y_emb = ((y + 1) as f32) / ((h as f32) + eps) * scale;
745        for x in 0..w {
746            let x_emb = ((x + 1) as f32) / ((w as f32) + eps) * scale;
747            for i in 0..nf {
748                let py = y_emb / dim_t[i];
749                let v = if i % 2 == 0 { py.sin() } else { py.cos() };
750                out[i * h * w + y * w + x] = v;
751            }
752            for i in 0..nf {
753                let px = x_emb / dim_t[i];
754                let v = if i % 2 == 0 { px.sin() } else { px.cos() };
755                out[(nf + i) * h * w + y * w + x] = v;
756            }
757        }
758    }
759    out
760}