Skip to main content

rlx_sam3/
vision_encoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native SAM3 ViT trunk.
17//!
18//! Full numerical port of `sam3.model.vitdet.ViT` (the base 1008² model):
19//!
20//!   - Patch embed (Conv2d k=14, s=14, no bias) → [B, H, W, C] in NHWC.
21//!   - Tiled absolute positional embedding (24×24 → 72×72).
22//!   - 32 transformer blocks. Blocks 0..32 are window-attention with
23//!     `window_size=24` except for `global_att_blocks=[7,15,23,31]` which
24//!     run full-resolution attention.
25//!   - 2D RoPE applied to Q/K, interpolated when input differs from
26//!     `rope_pt_size=(24, 24)`.
27//!   - LayerNorm `ln_pre`, identity `ln_post`, `mlp_ratio=4.625`.
28//!
29//! Output: final block features in NHWC `[1, 72, 72, 1024]`, flattened
30//! row-major as `[grid*grid, embed_dim]`.
31
32use super::config::{SAM3_PATCH_GRID, Sam3VitConfig};
33use super::preprocess::{Sam3PreprocessWeights, assemble_patch_tokens, extract_preprocess_weights};
34use super::tensor::{gelu_tanh, layer_norm, linear, matmul, matmul_bt, softmax_rows};
35use anyhow::{Result, ensure};
36use rlx_core::weight_map::WeightMap;
37use rlx_flow::{GgufPackedLinear, GgufPackedParams};
38
/// Parameters of a single SAM3 ViT transformer block.
///
/// Every linear layer is held in one of two forms:
///
///   - dense F32, in the `*_w_t` vector, laid out `[in_dim, out_dim]`
///     (i.e. already transposed for `x @ W`), with `*_gguf_prefix == None`;
///   - GGUF-packed, in which case `*_w_t` is empty and `*_gguf_prefix` holds
///     the tensor name without its `.weight` suffix, to be resolved against
///     the packed-parameter table at inference time.
///
/// Biases and LayerNorm parameters are always dense F32.
#[derive(Debug, Clone)]
pub struct Sam3VitBlockWeights {
    /// Pre-attention LayerNorm scale, `[embed_dim]`.
    pub norm1_w: Vec<f32>,
    /// Pre-attention LayerNorm shift, `[embed_dim]`.
    pub norm1_b: Vec<f32>,
    /// Fused QKV projection, `[embed_dim, 3 * embed_dim]`; empty when packed.
    pub qkv_w_t: Vec<f32>,
    /// Fused QKV bias, `[3 * embed_dim]`.
    pub qkv_b: Vec<f32>,
    /// GGUF prefix for `attn.qkv` (no `.weight` suffix) when `qkv_w_t` is empty.
    pub qkv_gguf_prefix: Option<String>,
    /// Attention output projection, `[embed_dim, embed_dim]`; empty when packed.
    pub proj_w_t: Vec<f32>,
    /// Attention output projection bias, `[embed_dim]`.
    pub proj_b: Vec<f32>,
    /// GGUF prefix for `attn.proj` (no `.weight` suffix) when `proj_w_t` is empty.
    pub proj_gguf_prefix: Option<String>,
    /// Pre-MLP LayerNorm scale, `[embed_dim]`.
    pub norm2_w: Vec<f32>,
    /// Pre-MLP LayerNorm shift, `[embed_dim]`.
    pub norm2_b: Vec<f32>,
    /// First MLP layer, `[embed_dim, hidden]`; empty when packed.
    pub mlp_fc1_w_t: Vec<f32>,
    /// First MLP layer bias, `[hidden]`. Its length defines `hidden` at inference.
    pub mlp_fc1_b: Vec<f32>,
    /// GGUF prefix for the first MLP layer when `mlp_fc1_w_t` is empty.
    pub mlp_fc1_gguf_prefix: Option<String>,
    /// Second MLP layer, `[hidden, embed_dim]`; empty when packed.
    pub mlp_fc2_w_t: Vec<f32>,
    /// Second MLP layer bias, `[embed_dim]`.
    pub mlp_fc2_b: Vec<f32>,
    /// GGUF prefix for the second MLP layer when `mlp_fc2_w_t` is empty.
    pub mlp_fc2_gguf_prefix: Option<String>,
}
59
60#[derive(Clone)]
61pub struct Sam3VisionEncoderWeights {
62    pub pre: Sam3PreprocessWeights,
63    pub ln_pre_w: Vec<f32>,
64    pub ln_pre_b: Vec<f32>,
65    pub blocks: Vec<Sam3VitBlockWeights>,
66}
67
/// Result of [`encode_image_native`]: the final block features.
#[derive(Debug, Clone)]
pub struct Sam3VisionOutput {
    /// Token features, flattened row-major as `[grid * grid, dim]`.
    pub tokens: Vec<f32>,
    /// Side length of the square patch grid.
    pub grid: usize,
    /// Feature width of each token (the model's `embed_dim`).
    pub dim: usize,
}
73
74pub fn extract_vision_encoder_weights(
75    weights: &mut WeightMap,
76    cfg: &Sam3VitConfig,
77    gguf_packed: Option<&GgufPackedParams>,
78) -> Result<Sam3VisionEncoderWeights> {
79    let pre = extract_preprocess_weights(weights, cfg)?;
80    let e = cfg.embed_dim;
81    let (ln_pre_w, ln_pre_b) = take_layer_norm(weights, &prefixes("ln_pre"), e)?;
82    let hidden = (e as f64 * cfg.mlp_ratio) as usize;
83    let mut blocks = Vec::with_capacity(cfg.depth);
84    for i in 0..cfg.depth {
85        let p = format!("blocks.{i}");
86        let pref = prefixes(&p);
87        let (norm1_w, norm1_b) = take_layer_norm(weights, &prefixed(&pref, "norm1"), e)?;
88        let (qkv_w_t, qkv_gguf_prefix) =
89            take_linear_w_or_gguf(weights, gguf_packed, &prefixed(&pref, "attn.qkv"), e, 3 * e)?;
90        let qkv_b = take_linear_b(weights, &prefixed(&pref, "attn.qkv"), 3 * e)?;
91        let (proj_w_t, proj_gguf_prefix) =
92            take_linear_w_or_gguf(weights, gguf_packed, &prefixed(&pref, "attn.proj"), e, e)?;
93        let proj_b = take_linear_b(weights, &prefixed(&pref, "attn.proj"), e)?;
94        let (norm2_w, norm2_b) = take_layer_norm(weights, &prefixed(&pref, "norm2"), e)?;
95        let (mlp_fc1_w_t, mlp_fc1_gguf_prefix) = take_linear_w_any_or_gguf(
96            weights,
97            gguf_packed,
98            &pref,
99            &["mlp.fc1", "mlp.lin1"],
100            e,
101            hidden,
102        )?;
103        let mlp_fc1_b = take_linear_b_any(weights, &pref, &["mlp.fc1", "mlp.lin1"], hidden)?;
104        let (mlp_fc2_w_t, mlp_fc2_gguf_prefix) = take_linear_w_any_or_gguf(
105            weights,
106            gguf_packed,
107            &pref,
108            &["mlp.fc2", "mlp.lin2"],
109            hidden,
110            e,
111        )?;
112        let mlp_fc2_b = take_linear_b_any(weights, &pref, &["mlp.fc2", "mlp.lin2"], e)?;
113        blocks.push(Sam3VitBlockWeights {
114            norm1_w,
115            norm1_b,
116            qkv_w_t,
117            qkv_b,
118            qkv_gguf_prefix,
119            proj_w_t,
120            proj_b,
121            proj_gguf_prefix,
122            norm2_w,
123            norm2_b,
124            mlp_fc1_w_t,
125            mlp_fc1_b,
126            mlp_fc1_gguf_prefix,
127            mlp_fc2_w_t,
128            mlp_fc2_b,
129            mlp_fc2_gguf_prefix,
130        });
131    }
132    Ok(Sam3VisionEncoderWeights {
133        pre,
134        ln_pre_w,
135        ln_pre_b,
136        blocks,
137    })
138}
139
140pub fn encode_image_native(
141    weights: &Sam3VisionEncoderWeights,
142    gguf_packed: Option<&GgufPackedParams>,
143    cfg: &Sam3VitConfig,
144    image_nchw: &[f32],
145) -> Result<Sam3VisionOutput> {
146    let e = cfg.embed_dim;
147    let grid = cfg.patch_grid();
148    ensure!(
149        grid == SAM3_PATCH_GRID,
150        "SAM3 base grid must be {SAM3_PATCH_GRID}"
151    );
152    let head_dim = e / cfg.num_heads;
153    ensure!(
154        head_dim * cfg.num_heads == e,
155        "embed_dim {e} not divisible by num_heads {}",
156        cfg.num_heads
157    );
158    let rope_pt = if cfg.window_size > 0 {
159        cfg.window_size
160    } else {
161        grid
162    };
163
164    // Patch embed (+ tiled abs pos), flat [grid*grid, embed_dim] NHWC.
165    let mut x = assemble_patch_tokens(&weights.pre, image_nchw)?;
166    x = layer_norm(
167        &x,
168        &weights.ln_pre_w,
169        &weights.ln_pre_b,
170        e,
171        cfg.layer_norm_eps as f32,
172    )?;
173
174    let global_set: std::collections::HashSet<usize> =
175        cfg.global_att_blocks.iter().copied().collect();
176    let rope_global = build_rope_freqs(head_dim, grid, grid, 10000.0, rope_pt as f32 / grid as f32);
177    let rope_window = build_rope_freqs(head_dim, cfg.window_size, cfg.window_size, 10000.0, 1.0);
178
179    for (i, block) in weights.blocks.iter().enumerate() {
180        let is_global = global_set.contains(&i);
181        block_forward(
182            &mut x,
183            block,
184            gguf_packed,
185            cfg,
186            grid,
187            if is_global { 0 } else { cfg.window_size },
188            if is_global {
189                &rope_global
190            } else {
191                &rope_window
192            },
193            head_dim,
194            cfg.num_heads,
195        )?;
196    }
197    // ln_post is Identity for SAM3 base, no-op.
198
199    Ok(Sam3VisionOutput {
200        tokens: x,
201        grid,
202        dim: e,
203    })
204}
205
/// Compute the 2D RoPE frequency table. The layout is `[L, head_dim]` flat
/// with `L = end_x * end_y`, and each `head_dim`-long stride stores
/// `head_dim/2` interleaved `(cos, sin)` pairs — first half from the x axis,
/// second from the y axis.
///
/// Position `pos` maps to `x = pos % end_x`, `y = pos / end_x`; both
/// coordinates are multiplied by `scale_pos` before the angles are taken.
/// Pair `k` on either axis uses frequency `theta^(-4k / head_dim)`.
///
/// # Panics
///
/// Panics if `head_dim` is not a multiple of 4.
fn build_rope_freqs(
    head_dim: usize,
    end_x: usize,
    end_y: usize,
    theta: f32,
    scale_pos: f32,
) -> Vec<f32> {
    assert!(
        head_dim.is_multiple_of(4),
        "RoPE head_dim must be divisible by 4"
    );
    let pair_per_axis = head_dim / 4;
    let freqs_per_pair: Vec<f32> = (0..pair_per_axis)
        .map(|k| 1.0 / theta.powf((4 * k) as f32 / head_dim as f32))
        .collect();

    let l = end_x * end_y;
    let mut out = vec![0f32; l * head_dim];
    for pos in 0..l {
        let t_x = (pos % end_x) as f32 * scale_pos;
        let t_y = (pos / end_x) as f32 * scale_pos;
        let row = &mut out[pos * head_dim..(pos + 1) * head_dim];
        // Each axis owns `pair_per_axis` (cos, sin) pairs, i.e. half the row.
        let (x_half, y_half) = row.split_at_mut(2 * pair_per_axis);
        for (k, &freq) in freqs_per_pair.iter().enumerate() {
            let ang_x = t_x * freq;
            let ang_y = t_y * freq;
            x_half[2 * k] = ang_x.cos();
            x_half[2 * k + 1] = ang_x.sin();
            y_half[2 * k] = ang_y.cos();
            y_half[2 * k + 1] = ang_y.sin();
        }
    }
    out
}
244
/// Rotate `qk` in place by the RoPE table `freqs_cis`.
///
/// `qk` holds `rows` vectors of `head_dim` values each, interpreted as
/// consecutive `(real, imag)` pairs. Row `r` is multiplied, pair by pair as a
/// complex number, with the `(cos, sin)` pairs of table row `r % seq_len`, so
/// the `[seq_len, head_dim]` table repeats across the outer batch×head axis.
fn rope_apply_inplace(
    qk: &mut [f32],
    freqs_cis: &[f32],
    rows: usize,
    seq_len: usize,
    head_dim: usize,
) {
    for row in 0..rows {
        let pos = row % seq_len;
        let table = &freqs_cis[pos * head_dim..(pos + 1) * head_dim];
        let vector = &mut qk[row * head_dim..(row + 1) * head_dim];
        for (pair, rot) in vector.chunks_exact_mut(2).zip(table.chunks_exact(2)) {
            let (re, im) = (pair[0], pair[1]);
            let (cos, sin) = (rot[0], rot[1]);
            pair[0] = re * cos - im * sin;
            pair[1] = re * sin + im * cos;
        }
    }
}
270
271/// One transformer block: norm → (windowed) attention → residual →
272/// norm → MLP → residual. `x` is `[grid*grid, embed_dim]` NHWC flat.
273fn linear_maybe_gguf(
274    x: &[f32],
275    m: usize,
276    k: usize,
277    w_t: &[f32],
278    gguf: Option<&GgufPackedLinear>,
279    n: usize,
280    b: &[f32],
281) -> Result<Vec<f32>> {
282    let mut out = vec![0f32; m * n];
283    if let Some(p) = gguf {
284        ensure!(
285            p.in_dim == k && p.out_dim == n,
286            "packed linear shape {k}x{n} vs gguf {}x{}",
287            p.in_dim,
288            p.out_dim
289        );
290        rlx_cpu::gguf_matmul::gguf_matmul_bt(x, &p.w_q, &mut out, m, k, n, p.scheme);
291    } else {
292        ensure!(
293            !w_t.is_empty(),
294            "linear: missing F32 weights and no GGUF packed entry"
295        );
296        return linear(x, m, k, w_t, n, b);
297    }
298    for row in 0..m {
299        for col in 0..n {
300            out[row * n + col] += b[col];
301        }
302    }
303    Ok(out)
304}
305
306fn packed_for_prefix<'a>(
307    packed: Option<&'a GgufPackedParams>,
308    prefix: Option<&String>,
309) -> Option<&'a GgufPackedLinear> {
310    let key = format!("{}.weight", prefix.as_ref()?);
311    packed?.get_linear(&key)
312}
313
314fn block_forward(
315    x: &mut [f32],
316    block: &Sam3VitBlockWeights,
317    gguf_packed: Option<&GgufPackedParams>,
318    cfg: &Sam3VitConfig,
319    grid: usize,
320    window_size: usize,
321    freqs_cis: &[f32],
322    head_dim: usize,
323    num_heads: usize,
324) -> Result<()> {
325    let e = cfg.embed_dim;
326    let n = grid * grid;
327    let eps = cfg.layer_norm_eps as f32;
328
329    // shortcut: x as-is. Compute attention(norm1(x)) in attn_out.
330    let n1 = layer_norm(x, &block.norm1_w, &block.norm1_b, e, eps)?;
331    let qkv_gguf = packed_for_prefix(gguf_packed, block.qkv_gguf_prefix.as_ref());
332    let proj_gguf = packed_for_prefix(gguf_packed, block.proj_gguf_prefix.as_ref());
333    let attn_out = if window_size == 0 {
334        attention_native(
335            &n1,
336            1,
337            n,
338            &block.qkv_w_t,
339            qkv_gguf,
340            &block.qkv_b,
341            &block.proj_w_t,
342            proj_gguf,
343            &block.proj_b,
344            freqs_cis,
345            num_heads,
346            head_dim,
347        )?
348    } else {
349        attention_windowed(
350            &n1,
351            grid,
352            grid,
353            window_size,
354            e,
355            &block.qkv_w_t,
356            qkv_gguf,
357            &block.qkv_b,
358            &block.proj_w_t,
359            proj_gguf,
360            &block.proj_b,
361            freqs_cis,
362            num_heads,
363            head_dim,
364        )?
365    };
366    for i in 0..x.len() {
367        x[i] += attn_out[i];
368    }
369
370    let n2 = layer_norm(x, &block.norm2_w, &block.norm2_b, e, eps)?;
371    let hidden = block.mlp_fc1_b.len();
372    let fc1_gguf = packed_for_prefix(gguf_packed, block.mlp_fc1_gguf_prefix.as_ref());
373    let fc2_gguf = packed_for_prefix(gguf_packed, block.mlp_fc2_gguf_prefix.as_ref());
374    let mut mlp = linear_maybe_gguf(
375        &n2,
376        n,
377        e,
378        &block.mlp_fc1_w_t,
379        fc1_gguf,
380        hidden,
381        &block.mlp_fc1_b,
382    )?;
383    gelu_tanh(&mut mlp);
384    let ffn = linear_maybe_gguf(
385        &mlp,
386        n,
387        hidden,
388        &block.mlp_fc2_w_t,
389        fc2_gguf,
390        e,
391        &block.mlp_fc2_b,
392    )?;
393    for i in 0..x.len() {
394        x[i] += ffn[i];
395    }
396    Ok(())
397}
398
399fn attention_windowed(
400    x: &[f32],
401    h: usize,
402    w: usize,
403    ws: usize,
404    e: usize,
405    qkv_w_t: &[f32],
406    qkv_gguf: Option<&GgufPackedLinear>,
407    qkv_b: &[f32],
408    proj_w_t: &[f32],
409    proj_gguf: Option<&GgufPackedLinear>,
410    proj_b: &[f32],
411    freqs_cis: &[f32],
412    num_heads: usize,
413    head_dim: usize,
414) -> Result<Vec<f32>> {
415    let pad_h = (ws - h % ws) % ws;
416    let pad_w = (ws - w % ws) % ws;
417    let hp = h + pad_h;
418    let wp = w + pad_w;
419    let nh = hp / ws;
420    let nw = wp / ws;
421    let num_windows = nh * nw;
422    let win_len = ws * ws;
423
424    // Partition: produce [num_windows, ws, ws, e].
425    let mut win = vec![0f32; num_windows * win_len * e];
426    for y in 0..hp {
427        for xc in 0..wp {
428            let wy = y / ws;
429            let wx = xc / ws;
430            let ry = y % ws;
431            let rx = xc % ws;
432            let widx = wy * nw + wx;
433            let dst = ((widx * ws + ry) * ws + rx) * e;
434            if y < h && xc < w {
435                let src = (y * w + xc) * e;
436                win[dst..dst + e].copy_from_slice(&x[src..src + e]);
437            }
438            // else: padding stays zero (matches F.pad with 0).
439        }
440    }
441
442    let attn = attention_native(
443        &win,
444        num_windows,
445        win_len,
446        qkv_w_t,
447        qkv_gguf,
448        qkv_b,
449        proj_w_t,
450        proj_gguf,
451        proj_b,
452        freqs_cis,
453        num_heads,
454        head_dim,
455    )?;
456
457    // Unpartition: stitch [num_windows, ws, ws, e] back into [h, w, e],
458    // dropping padding.
459    let mut out = vec![0f32; h * w * e];
460    for y in 0..h {
461        for xc in 0..w {
462            let wy = y / ws;
463            let wx = xc / ws;
464            let ry = y % ws;
465            let rx = xc % ws;
466            let widx = wy * nw + wx;
467            let src = ((widx * ws + ry) * ws + rx) * e;
468            let dst = (y * w + xc) * e;
469            out[dst..dst + e].copy_from_slice(&attn[src..src + e]);
470        }
471    }
472    Ok(out)
473}
474
475/// Multi-head self-attention with 2D RoPE for `b` independent sequences of
476/// length `l`. `x` is `[b, l, e]`; output is `[b, l, e]`.
477fn attention_native(
478    x: &[f32],
479    b: usize,
480    l: usize,
481    qkv_w_t: &[f32],
482    qkv_gguf: Option<&GgufPackedLinear>,
483    qkv_b: &[f32],
484    proj_w_t: &[f32],
485    proj_gguf: Option<&GgufPackedLinear>,
486    proj_b: &[f32],
487    freqs_cis: &[f32],
488    num_heads: usize,
489    head_dim: usize,
490) -> Result<Vec<f32>> {
491    let e = num_heads * head_dim;
492    let rows = b * l;
493    let qkv = linear_maybe_gguf(x, rows, e, qkv_w_t, qkv_gguf, 3 * e, qkv_b)?;
494
495    // Split into [b, num_heads, l, head_dim] for q, k, v. We keep them as
496    // [b*num_heads, l, head_dim] = [bh, l, head_dim] to feed sgemm.
497    let bh = b * num_heads;
498    let mut q = vec![0f32; bh * l * head_dim];
499    let mut k = vec![0f32; bh * l * head_dim];
500    let mut v = vec![0f32; bh * l * head_dim];
501    for bi in 0..b {
502        for li in 0..l {
503            let src = (bi * l + li) * 3 * e;
504            for hd in 0..num_heads {
505                let qd_src = src + hd * head_dim;
506                let kd_src = src + e + hd * head_dim;
507                let vd_src = src + 2 * e + hd * head_dim;
508                let dst = ((bi * num_heads + hd) * l + li) * head_dim;
509                q[dst..dst + head_dim].copy_from_slice(&qkv[qd_src..qd_src + head_dim]);
510                k[dst..dst + head_dim].copy_from_slice(&qkv[kd_src..kd_src + head_dim]);
511                v[dst..dst + head_dim].copy_from_slice(&qkv[vd_src..vd_src + head_dim]);
512            }
513        }
514    }
515
516    rope_apply_inplace(&mut q, freqs_cis, bh * l, l, head_dim);
517    rope_apply_inplace(&mut k, freqs_cis, bh * l, l, head_dim);
518
519    let scale = 1.0f32 / (head_dim as f32).sqrt();
520    let mut attn_out = vec![0f32; bh * l * head_dim];
521    let mut scores = vec![0f32; l * l];
522
523    for bhi in 0..bh {
524        let q_h = &q[bhi * l * head_dim..(bhi + 1) * l * head_dim];
525        let k_h = &k[bhi * l * head_dim..(bhi + 1) * l * head_dim];
526        let v_h = &v[bhi * l * head_dim..(bhi + 1) * l * head_dim];
527        // scores[l, l] = scale * Q[l, hd] @ K[l, hd]^T
528        matmul_bt(q_h, k_h, &mut scores, l, head_dim, l, scale);
529        softmax_rows(&mut scores, l, l);
530        // out[l, hd] = scores[l, l] @ V[l, hd]
531        let out_h = &mut attn_out[bhi * l * head_dim..(bhi + 1) * l * head_dim];
532        matmul(&scores, v_h, out_h, l, l, head_dim);
533    }
534
535    // Repack [b, num_heads, l, head_dim] → [b, l, num_heads*head_dim] for proj.
536    let mut packed = vec![0f32; rows * e];
537    for bi in 0..b {
538        for li in 0..l {
539            for hd in 0..num_heads {
540                let src = ((bi * num_heads + hd) * l + li) * head_dim;
541                let dst = (bi * l + li) * e + hd * head_dim;
542                packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
543            }
544        }
545    }
546    linear_maybe_gguf(&packed, rows, e, proj_w_t, proj_gguf, e, proj_b)
547}
548
/// Candidate fully-qualified names for a trunk tensor.
///
/// Returns `"{root}.{suffix}"` for every known location of the ViT trunk in
/// a checkpoint, most deeply nested root first.
fn prefixes(suffix: &str) -> Vec<String> {
    const TRUNK_ROOTS: [&str; 6] = [
        "detector.backbone.vision_backbone.trunk",
        "detector.backbone.visual.trunk",
        "backbone.vision_backbone.trunk",
        "backbone.visual.trunk",
        "visual.trunk",
        "trunk",
    ];
    let mut names = Vec::with_capacity(TRUNK_ROOTS.len());
    for root in TRUNK_ROOTS {
        names.push(format!("{root}.{suffix}"));
    }
    names
}
562
/// Append `.{suffix}` to each entry of `prefixes`, preserving order.
fn prefixed(prefixes: &[String], suffix: &str) -> Vec<String> {
    let mut joined = Vec::with_capacity(prefixes.len());
    for base in prefixes {
        joined.push(format!("{base}.{suffix}"));
    }
    joined
}
566
567fn take_layer_norm(
568    weights: &mut WeightMap,
569    bases: &[String],
570    dim: usize,
571) -> Result<(Vec<f32>, Vec<f32>)> {
572    let w = take_shape(weights, &suffixes(bases, "weight"), &[dim])?;
573    let b = take_shape(weights, &suffixes(bases, "bias"), &[dim])?;
574    Ok((w, b))
575}
576
577fn take_linear_w_or_gguf(
578    weights: &mut WeightMap,
579    gguf_packed: Option<&GgufPackedParams>,
580    bases: &[String],
581    in_dim: usize,
582    out_dim: usize,
583) -> Result<(Vec<f32>, Option<String>)> {
584    let keys = suffixes(bases, "weight");
585    for key in &keys {
586        if weights.has(key) {
587            let w = take_linear_w(weights, bases, in_dim, out_dim)?;
588            return Ok((w, None));
589        }
590        if let Some(packed) = gguf_packed {
591            if let Some(prefix) = key.strip_suffix(".weight") {
592                if packed.get_linear(key).is_some() {
593                    return Ok((Vec::new(), Some(prefix.to_string())));
594                }
595            }
596        }
597    }
598    anyhow::bail!("none of the SAM3 linear weight keys were found: {keys:?}")
599}
600
601fn take_linear_w_any_or_gguf(
602    weights: &mut WeightMap,
603    gguf_packed: Option<&GgufPackedParams>,
604    block_prefixes: &[String],
605    names: &[&str],
606    in_dim: usize,
607    out_dim: usize,
608) -> Result<(Vec<f32>, Option<String>)> {
609    let bases: Vec<String> = block_prefixes
610        .iter()
611        .flat_map(|p| names.iter().map(move |name| format!("{p}.{name}")))
612        .collect();
613    take_linear_w_or_gguf(weights, gguf_packed, &bases, in_dim, out_dim)
614}
615
616fn take_linear_w(
617    weights: &mut WeightMap,
618    bases: &[String],
619    in_dim: usize,
620    out_dim: usize,
621) -> Result<Vec<f32>> {
622    let keys = suffixes(bases, "weight");
623    for key in &keys {
624        if weights.has(key) {
625            let (data, shape) = weights.take_transposed(key)?;
626            ensure!(
627                shape == vec![in_dim, out_dim],
628                "{key} expected [{in_dim}, {out_dim}], got {shape:?}"
629            );
630            return Ok(data);
631        }
632    }
633    anyhow::bail!("none of the SAM3 linear weight keys were found: {keys:?}")
634}
635
636#[allow(dead_code)]
637fn take_linear_w_any(
638    weights: &mut WeightMap,
639    block_prefixes: &[String],
640    names: &[&str],
641    in_dim: usize,
642    out_dim: usize,
643) -> Result<Vec<f32>> {
644    let bases: Vec<String> = block_prefixes
645        .iter()
646        .flat_map(|p| names.iter().map(move |name| format!("{p}.{name}")))
647        .collect();
648    take_linear_w(weights, &bases, in_dim, out_dim)
649}
650
651fn take_linear_b(weights: &mut WeightMap, bases: &[String], dim: usize) -> Result<Vec<f32>> {
652    take_shape(weights, &suffixes(bases, "bias"), &[dim])
653}
654
655fn take_linear_b_any(
656    weights: &mut WeightMap,
657    block_prefixes: &[String],
658    names: &[&str],
659    dim: usize,
660) -> Result<Vec<f32>> {
661    let bases: Vec<String> = block_prefixes
662        .iter()
663        .flat_map(|p| names.iter().map(move |name| format!("{p}.{name}")))
664        .collect();
665    take_linear_b(weights, &bases, dim)
666}
667
/// Turn tensor base names into full keys by appending `.{suffix}` to each.
fn suffixes(bases: &[String], suffix: &str) -> Vec<String> {
    bases
        .iter()
        .map(|base| [base.as_str(), suffix].join("."))
        .collect()
}
671
672fn take_shape(weights: &mut WeightMap, keys: &[String], expected: &[usize]) -> Result<Vec<f32>> {
673    for key in keys {
674        if weights.has(key) {
675            let (data, shape) = weights.take(key)?;
676            ensure!(
677                shape == expected,
678                "{key} expected {expected:?}, got {shape:?}"
679            );
680            return Ok(data);
681        }
682    }
683    anyhow::bail!("none of the SAM3 weight keys were found: {keys:?}")
684}