Skip to main content

rlx_sam3/
text_encoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native SAM3 text encoder (`VETextEncoder`).
17//!
18//! Architecture (matches `facebookresearch/sam3.model.text_encoder_ve`):
19//!
20//!   - `token_embedding`         : `[49408, 1024]`
21//!   - `positional_embedding`    : `[32, 1024]`
22//!   - 24 × ResidualAttentionBlock(width=1024, heads=16, mlp_ratio=4)
23//!     using `nn.MultiheadAttention` (`in_proj_weight [3*W, W]`,
24//!     `in_proj_bias [3*W]`, `out_proj.weight [W, W]`, `out_proj.bias [W]`)
25//!     and a 32×32 upper-triangular `-inf` causal mask.
26//!   - `ln_final`                : LayerNorm 1024
27//!   - `resizer = Linear(1024, 256)` outside the encoder.
28//!
29//! Output: `text_memory_resized` of shape `[seq_len, batch, 256]`.
30//!
31//! Token IDs are accepted as input — the BPE tokenizer port is deferred to
32//! a follow-up that ships an embedded BPE vocab.
33
34use super::config::Sam3TextConfig;
35use super::tensor::{layer_norm, matmul, matmul_bt, softmax_rows};
36use rlx_core::weight_map::WeightMap;
37use rlx_flow::GgufPackedParams;
38
39use crate::packed_gguf::{linear_maybe_gguf, take_or_gguf, take_transposed_with_gguf_key};
40use anyhow::{Result, ensure};
41
/// Weights of one residual attention block of the SAM3 text encoder.
///
/// Filled by `extract_text_encoder_weights` from the
/// `<prefix>.encoder.transformer.resblocks.{i}.*` tensors and consumed by
/// `block_forward`. The `*_w_t` matrices are loaded through
/// `take_transposed_with_gguf_key`; they are presumably stored transposed
/// relative to the checkpoint layout (confirm against that helper).
#[derive(Debug, Clone)]
pub struct Sam3TextBlock {
    /// `ln_1.weight`: scale of the LayerNorm applied before attention.
    pub ln1_w: Vec<f32>,
    /// `ln_1.bias`: shift of the LayerNorm applied before attention.
    pub ln1_b: Vec<f32>,
    /// `attn.in_proj_weight`: fused Q/K/V projection matrix.
    pub qkv_w_t: Vec<f32>,
    /// `attn.in_proj_bias`: fused Q/K/V projection bias.
    pub qkv_b: Vec<f32>,
    /// `attn.out_proj.weight`: attention output projection matrix.
    pub proj_w_t: Vec<f32>,
    /// `attn.out_proj.bias`: attention output projection bias.
    pub proj_b: Vec<f32>,
    /// `ln_2.weight`: scale of the LayerNorm applied before the MLP.
    pub ln2_w: Vec<f32>,
    /// `ln_2.bias`: shift of the LayerNorm applied before the MLP.
    pub ln2_b: Vec<f32>,
    /// `mlp.c_fc.weight`: first MLP linear layer.
    pub mlp_fc_w_t: Vec<f32>,
    /// `mlp.c_fc.bias`; its length is used as the MLP hidden width.
    pub mlp_fc_b: Vec<f32>,
    /// `mlp.c_proj.weight`: second MLP linear layer.
    pub mlp_proj_w_t: Vec<f32>,
    /// `mlp.c_proj.bias`: bias of the second MLP linear layer.
    pub mlp_proj_b: Vec<f32>,
    /// Optional GGUF key handed to `linear_maybe_gguf` together with `qkv_w_t`.
    pub qkv_gguf_key: Option<String>,
    /// Optional GGUF key handed to `linear_maybe_gguf` together with `proj_w_t`.
    pub proj_gguf_key: Option<String>,
    /// Optional GGUF key handed to `linear_maybe_gguf` together with `mlp_fc_w_t`.
    pub mlp_fc_gguf_key: Option<String>,
    /// Optional GGUF key handed to `linear_maybe_gguf` together with `mlp_proj_w_t`.
    pub mlp_proj_gguf_key: Option<String>,
}
61
62#[derive(Clone, Default)]
63pub struct Sam3TextEncoderWeights {
64    pub loaded: bool,
65    pub width: usize,
66    pub heads: usize,
67    pub context_length: usize,
68    pub d_model: usize,
69    pub vocab_size: usize,
70    pub token_embedding: Vec<f32>,
71    pub positional_embedding: Vec<f32>,
72    pub ln_final_w: Vec<f32>,
73    pub ln_final_b: Vec<f32>,
74    pub blocks: Vec<Sam3TextBlock>,
75    pub resizer_w_t: Vec<f32>,
76    pub resizer_b: Vec<f32>,
77    pub resizer_gguf_key: Option<String>,
78}
79
/// Result of running the SAM3 text encoder on one batch of token ids.
#[derive(Debug, Clone, Default)]
pub struct Sam3TextEncoded {
    /// Byte mask laid out `[batch, seq]`; an entry is 1 where the token is
    /// PAD (token id 0) and 0 otherwise.
    pub attention_mask: Vec<u8>,
    /// Resized text memory, laid out `[seq, batch, d_model]`.
    pub text_memory_resized: Vec<f32>,
    /// Raw token embeddings (no positional term), laid out `[seq, batch, width]`.
    pub inputs_embeds: Vec<f32>,
    /// Sequence length of the encoded batch.
    pub seq_len: usize,
    /// Number of sequences in the encoded batch.
    pub batch: usize,
    /// Feature size of `text_memory_resized`.
    pub d_model: usize,
    /// Feature size of `inputs_embeds`.
    pub width: usize,
}
93
94pub fn extract_text_encoder_weights(
95    weights: &mut WeightMap,
96    cfg: &Sam3TextConfig,
97    gguf_packed: Option<&GgufPackedParams>,
98) -> Result<Sam3TextEncoderWeights> {
99    let width = cfg.width;
100    let heads = cfg.heads;
101    let layers = cfg.layers;
102    let d_model = cfg.d_model;
103    let context_length = 32usize;
104    let vocab_size = 49408usize;
105    let _mlp_width = width * 4;
106
107    let prefixes = [
108        "detector.backbone.language_backbone",
109        "backbone.language_backbone",
110        "language_backbone",
111    ];
112    let enc_prefix = {
113        let mut found = None;
114        for p in prefixes {
115            let key = format!("{p}.encoder.token_embedding.weight");
116            if weights.has(&key) {
117                found = Some(p);
118                break;
119            }
120        }
121        found.ok_or_else(|| anyhow::anyhow!("SAM3 language_backbone not found"))?
122    };
123
124    let (token_embedding, te_shape) = take_or_gguf(
125        weights,
126        gguf_packed,
127        &format!("{enc_prefix}.encoder.token_embedding.weight"),
128    )?;
129    ensure!(
130        te_shape == vec![vocab_size, width],
131        "token_embedding shape {te_shape:?}"
132    );
133    let (positional_embedding, pe_shape) = take_or_gguf(
134        weights,
135        gguf_packed,
136        &format!("{enc_prefix}.encoder.positional_embedding"),
137    )?;
138    ensure!(
139        pe_shape == vec![context_length, width],
140        "positional_embedding shape {pe_shape:?}"
141    );
142    let (ln_final_w, _) = take_or_gguf(
143        weights,
144        gguf_packed,
145        &format!("{enc_prefix}.encoder.ln_final.weight"),
146    )?;
147    let (ln_final_b, _) = take_or_gguf(
148        weights,
149        gguf_packed,
150        &format!("{enc_prefix}.encoder.ln_final.bias"),
151    )?;
152
153    // text_projection is unused by VETextEncoder.forward (returns the
154    // per-token sequence, not pooled) — drop it without checking shape.
155    let _ = weights.take(&format!("{enc_prefix}.encoder.text_projection"));
156
157    let mut blocks = Vec::with_capacity(layers);
158    for i in 0..layers {
159        let bp = format!("{enc_prefix}.encoder.transformer.resblocks.{i}");
160        let (ln1_w, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.ln_1.weight"))?;
161        let (ln1_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.ln_1.bias"))?;
162        let (qkv_w_t, qkv_gguf_key) = take_transposed_with_gguf_key(
163            weights,
164            gguf_packed,
165            &format!("{bp}.attn.in_proj_weight"),
166        )?;
167        let (qkv_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.attn.in_proj_bias"))?;
168        let (proj_w_t, proj_gguf_key) = take_transposed_with_gguf_key(
169            weights,
170            gguf_packed,
171            &format!("{bp}.attn.out_proj.weight"),
172        )?;
173        let (proj_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.attn.out_proj.bias"))?;
174        let (ln2_w, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.ln_2.weight"))?;
175        let (ln2_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.ln_2.bias"))?;
176        let (mlp_fc_w_t, mlp_fc_gguf_key) =
177            take_transposed_with_gguf_key(weights, gguf_packed, &format!("{bp}.mlp.c_fc.weight"))?;
178        let (mlp_fc_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.mlp.c_fc.bias"))?;
179        let (mlp_proj_w_t, mlp_proj_gguf_key) = take_transposed_with_gguf_key(
180            weights,
181            gguf_packed,
182            &format!("{bp}.mlp.c_proj.weight"),
183        )?;
184        let (mlp_proj_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.mlp.c_proj.bias"))?;
185        blocks.push(Sam3TextBlock {
186            ln1_w,
187            ln1_b,
188            qkv_w_t,
189            qkv_b,
190            proj_w_t,
191            proj_b,
192            ln2_w,
193            ln2_b,
194            mlp_fc_w_t,
195            mlp_fc_b,
196            mlp_proj_w_t,
197            mlp_proj_b,
198            qkv_gguf_key,
199            proj_gguf_key,
200            mlp_fc_gguf_key,
201            mlp_proj_gguf_key,
202        });
203    }
204
205    let (resizer_w_t, resizer_gguf_key) = take_transposed_with_gguf_key(
206        weights,
207        gguf_packed,
208        &format!("{enc_prefix}.resizer.weight"),
209    )?;
210    let (resizer_b, _) = take_or_gguf(weights, gguf_packed, &format!("{enc_prefix}.resizer.bias"))?;
211
212    Ok(Sam3TextEncoderWeights {
213        loaded: true,
214        width,
215        heads,
216        context_length,
217        d_model,
218        vocab_size,
219        token_embedding,
220        positional_embedding,
221        ln_final_w,
222        ln_final_b,
223        blocks,
224        resizer_w_t,
225        resizer_b,
226        resizer_gguf_key,
227    })
228}
229
230/// Run the text encoder on already-tokenized inputs (`[batch, seq_len]`).
231pub fn encode_tokens(
232    weights: &Sam3TextEncoderWeights,
233    tokens: &[u32],
234    batch: usize,
235    seq_len: usize,
236    gguf_packed: Option<&GgufPackedParams>,
237) -> Result<Sam3TextEncoded> {
238    ensure!(weights.loaded, "SAM3 text encoder weights not loaded");
239    ensure!(
240        tokens.len() == batch * seq_len,
241        "expected {} tokens, got {}",
242        batch * seq_len,
243        tokens.len()
244    );
245    ensure!(
246        seq_len <= weights.context_length,
247        "seq_len {seq_len} exceeds context_length {}",
248        weights.context_length
249    );
250    let w = weights.width;
251    let h = weights.heads;
252    let head_dim = w / h;
253    ensure!(head_dim * h == w, "width {w} not divisible by heads {h}");
254
255    let mut x = vec![0f32; batch * seq_len * w];
256    let mut inputs_embeds = vec![0f32; batch * seq_len * w];
257    for b in 0..batch {
258        for l in 0..seq_len {
259            let tok = tokens[b * seq_len + l] as usize;
260            ensure!(tok < weights.vocab_size, "token id {tok} out of vocab");
261            let src = &weights.token_embedding[tok * w..(tok + 1) * w];
262            let dst_x = &mut x[(b * seq_len + l) * w..(b * seq_len + l + 1) * w];
263            let dst_emb = &mut inputs_embeds[(b * seq_len + l) * w..(b * seq_len + l + 1) * w];
264            dst_emb.copy_from_slice(src);
265            let pos = &weights.positional_embedding[l * w..(l + 1) * w];
266            for k in 0..w {
267                dst_x[k] = src[k] + pos[k];
268            }
269        }
270    }
271
272    // Causal additive mask [seq_len, seq_len].
273    let neg_inf = f32::NEG_INFINITY;
274    let mut mask = vec![0f32; seq_len * seq_len];
275    for i in 0..seq_len {
276        for j in (i + 1)..seq_len {
277            mask[i * seq_len + j] = neg_inf;
278        }
279    }
280
281    for block in &weights.blocks {
282        x = block_forward(
283            &x,
284            block,
285            batch,
286            seq_len,
287            w,
288            h,
289            head_dim,
290            &mask,
291            gguf_packed,
292        )?;
293    }
294    x = layer_norm(&x, &weights.ln_final_w, &weights.ln_final_b, w, 1e-5)?;
295
296    // Reorder [B, L, W] → [L, B, W] (sequence-first), then resize.
297    let mut text_memory_seq_first = vec![0f32; seq_len * batch * w];
298    for b in 0..batch {
299        for l in 0..seq_len {
300            let src = &x[(b * seq_len + l) * w..(b * seq_len + l + 1) * w];
301            let dst = &mut text_memory_seq_first[(l * batch + b) * w..(l * batch + b + 1) * w];
302            dst.copy_from_slice(src);
303        }
304    }
305    let mut inputs_embeds_seq_first = vec![0f32; seq_len * batch * w];
306    for b in 0..batch {
307        for l in 0..seq_len {
308            let src = &inputs_embeds[(b * seq_len + l) * w..(b * seq_len + l + 1) * w];
309            let dst = &mut inputs_embeds_seq_first[(l * batch + b) * w..(l * batch + b + 1) * w];
310            dst.copy_from_slice(src);
311        }
312    }
313
314    let text_memory_resized = linear_maybe_gguf(
315        &text_memory_seq_first,
316        seq_len * batch,
317        w,
318        &weights.resizer_w_t,
319        weights.resizer_gguf_key.as_deref(),
320        gguf_packed,
321        weights.d_model,
322        &weights.resizer_b,
323    )?;
324
325    let mut attention_mask = vec![0u8; batch * seq_len];
326    for i in 0..batch * seq_len {
327        attention_mask[i] = if tokens[i] == 0 { 1 } else { 0 };
328    }
329
330    Ok(Sam3TextEncoded {
331        attention_mask,
332        text_memory_resized,
333        inputs_embeds: inputs_embeds_seq_first,
334        seq_len,
335        batch,
336        d_model: weights.d_model,
337        width: w,
338    })
339}
340
341fn block_forward(
342    x_in: &[f32],
343    block: &Sam3TextBlock,
344    batch: usize,
345    seq_len: usize,
346    width: usize,
347    heads: usize,
348    head_dim: usize,
349    mask: &[f32],
350    gguf_packed: Option<&GgufPackedParams>,
351) -> Result<Vec<f32>> {
352    let rows = batch * seq_len;
353    let n1 = layer_norm(x_in, &block.ln1_w, &block.ln1_b, width, 1e-5)?;
354    let qkv = linear_maybe_gguf(
355        &n1,
356        rows,
357        width,
358        &block.qkv_w_t,
359        block.qkv_gguf_key.as_deref(),
360        gguf_packed,
361        3 * width,
362        &block.qkv_b,
363    )?;
364
365    let bh = batch * heads;
366    let mut q = vec![0f32; bh * seq_len * head_dim];
367    let mut k = vec![0f32; bh * seq_len * head_dim];
368    let mut v = vec![0f32; bh * seq_len * head_dim];
369    for b in 0..batch {
370        for l in 0..seq_len {
371            let src = (b * seq_len + l) * 3 * width;
372            for hd in 0..heads {
373                let qd_src = src + hd * head_dim;
374                let kd_src = src + width + hd * head_dim;
375                let vd_src = src + 2 * width + hd * head_dim;
376                let dst = ((b * heads + hd) * seq_len + l) * head_dim;
377                q[dst..dst + head_dim].copy_from_slice(&qkv[qd_src..qd_src + head_dim]);
378                k[dst..dst + head_dim].copy_from_slice(&qkv[kd_src..kd_src + head_dim]);
379                v[dst..dst + head_dim].copy_from_slice(&qkv[vd_src..vd_src + head_dim]);
380            }
381        }
382    }
383
384    let scale = 1.0f32 / (head_dim as f32).sqrt();
385    let mut attn_out = vec![0f32; bh * seq_len * head_dim];
386    let mut scores = vec![0f32; seq_len * seq_len];
387    for bhi in 0..bh {
388        let q_h = &q[bhi * seq_len * head_dim..(bhi + 1) * seq_len * head_dim];
389        let k_h = &k[bhi * seq_len * head_dim..(bhi + 1) * seq_len * head_dim];
390        let v_h = &v[bhi * seq_len * head_dim..(bhi + 1) * seq_len * head_dim];
391        matmul_bt(q_h, k_h, &mut scores, seq_len, head_dim, seq_len, scale);
392        for r in 0..seq_len {
393            for c in 0..seq_len {
394                scores[r * seq_len + c] += mask[r * seq_len + c];
395            }
396        }
397        softmax_rows(&mut scores, seq_len, seq_len);
398        let out_h = &mut attn_out[bhi * seq_len * head_dim..(bhi + 1) * seq_len * head_dim];
399        matmul(&scores, v_h, out_h, seq_len, seq_len, head_dim);
400    }
401
402    let mut packed = vec![0f32; rows * width];
403    for b in 0..batch {
404        for l in 0..seq_len {
405            for hd in 0..heads {
406                let src = ((b * heads + hd) * seq_len + l) * head_dim;
407                let dst = (b * seq_len + l) * width + hd * head_dim;
408                packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
409            }
410        }
411    }
412    let attn_proj = linear_maybe_gguf(
413        &packed,
414        rows,
415        width,
416        &block.proj_w_t,
417        block.proj_gguf_key.as_deref(),
418        gguf_packed,
419        width,
420        &block.proj_b,
421    )?;
422
423    let mut x = x_in.to_vec();
424    for i in 0..x.len() {
425        x[i] += attn_proj[i];
426    }
427    let n2 = layer_norm(&x, &block.ln2_w, &block.ln2_b, width, 1e-5)?;
428    let mlp_hidden = block.mlp_fc_b.len();
429    let mut mlp = linear_maybe_gguf(
430        &n2,
431        rows,
432        width,
433        &block.mlp_fc_w_t,
434        block.mlp_fc_gguf_key.as_deref(),
435        gguf_packed,
436        mlp_hidden,
437        &block.mlp_fc_b,
438    )?;
439    gelu_exact_inplace(&mut mlp);
440    let ffn = linear_maybe_gguf(
441        &mlp,
442        rows,
443        mlp_hidden,
444        &block.mlp_proj_w_t,
445        block.mlp_proj_gguf_key.as_deref(),
446        gguf_packed,
447        width,
448        &block.mlp_proj_b,
449    )?;
450    for i in 0..x.len() {
451        x[i] += ffn[i];
452    }
453    Ok(x)
454}
455
456fn gelu_exact_inplace(x: &mut [f32]) {
457    let inv_sqrt2 = 1.0f32 / std::f32::consts::SQRT_2;
458    for v in x.iter_mut() {
459        *v = 0.5 * *v * (1.0 + erf_approx(*v * inv_sqrt2));
460    }
461}
462
/// Polynomial approximation of the error function.
///
/// Evaluates `1 - poly(t) * t * exp(-x^2)` with `t = 1 / (1 + p * |x|)` on
/// the magnitude of `x`, then restores the sign so the result is odd in `x`.
/// The constants look like the Abramowitz & Stegun 7.1.26 coefficients —
/// confirm before relying on a specific error bound.
fn erf_approx(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    // Highest-order coefficient, then the remaining ones in Horner order.
    const A5: f32 = 1.0614054;
    const LOWER: [f32; 4] = [-1.4531521, 1.4214138, -0.2844967, 0.2548296];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = LOWER.iter().fold(A5, |acc, &coeff| acc * t + coeff);
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();
    if x < 0.0 { -y } else { y }
}
476
477/// Legacy shim for `Sam3::predict_image`. Returns an empty/PAD encoding —
478/// the BPE tokenizer port is a separate task; for now real prompts must go
479/// through `encode_tokens` with externally-tokenized inputs.
480pub fn encode_text_native(
481    weights: &Sam3TextEncoderWeights,
482    cfg: &Sam3TextConfig,
483    _prompt: Option<&str>,
484    gguf_packed: Option<&GgufPackedParams>,
485) -> Result<Sam3TextEncoded> {
486    if !weights.loaded {
487        return Ok(Sam3TextEncoded {
488            d_model: cfg.d_model,
489            width: cfg.width,
490            ..Default::default()
491        });
492    }
493    let seq_len = weights.context_length;
494    let tokens = vec![0u32; seq_len];
495    encode_tokens(weights, &tokens, 1, seq_len, gguf_packed)
496}