ferrum-models 0.7.2

Model architectures (LLaMA, Qwen, BERT) for Ferrum inference
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
//! Qwen3-TTS Talker using the `Backend<B>` trait — the Model-as-Code port
//! that replaces `ferrum_attention::FusedTransformer` for Phase F.
//!
//! Why this exists: the old `qwen3_tts.rs` uses `FusedTransformer` which
//! only has Metal + CPU backends. Its CUDA module is a stub that falls back
//! to CPU, and on Linux that CPU path uses naive fp64 O(n³) matmul. Through
//! 20 decoder layers + ~128 decode steps the accumulated rounding diverges
//! enough from training numerics to select wrong codec tokens, producing
//! YouTube-outro garbage instead of the target text.
//!
//! This file reuses `LlamaFamilyModel<B>` (via `new_backbone_only`) for the
//! transformer stack — inheriting CUDA/Metal/CPU kernels, KV cache pooling,
//! and batched decode — and adds TTS-specific wiring on top: dual text /
//! codec embeddings, a text projection MLP, and the codec output head.

use ferrum_kernels::backend::Backend;
use ferrum_quantization::loader::WeightLoader;
use ferrum_quantization::traits::Linear;
use ferrum_quantization::PrefixedLoader;
use ferrum_types::Result;
use std::collections::HashMap;

use crate::architectures::qwen3_tts::TalkerConfig;
use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyModel};

/// Qwen3-TTS Talker, Model-as-Code implementation over `Backend<B>`.
///
/// Ownership model: owns a backbone `LlamaFamilyModel<B>` whose shape comes
/// from `TalkerConfig` (originally written for a 20-layer Qwen3 shape with
/// hidden=1024, heads=16/2, head_dim=64 — NOTE(review): confirm against the
/// checkpoint config) plus the TTS-specific head/tail layers (text embedding
/// + projection, codec embedding, codec head).
///
/// Forward flow per step:
///   text_ids → text_embed → silu(fc1) → fc2 → mixed_embeds
///   codec_ids → codec_embed → mixed_embeds
///   mixed_embeds → backbone.{prefill,decode}_from_embed → hidden
///   hidden → final_norm (via backbone.final_norm_w) → codec_head → logits
pub struct Qwen3TtsTalker<B: Backend> {
    /// Talker configuration (dimensions, vocab sizes, norm eps, RoPE theta).
    pub cfg: TalkerConfig,

    /// Transformer backbone built with `new_backbone_only`. This type only
    /// touches its `final_norm_w`, `scratch`, `kv_caches` and the
    /// `*_from_embed(s)` entry points (which use `layers` / `rope`). Per the
    /// original design, `embed` and `lm_head` are `None` because TTS embeds
    /// externally and applies `codec_head` separately.
    pub backbone: LlamaFamilyModel<B>,

    /// Text token embedding table: `[text_vocab * text_hidden]`.
    /// Qwen3-TTS: text_vocab=151936, text_hidden=2048.
    pub text_embedding: B::Buffer,

    /// Text projection: `text_hidden -> text_hidden` (linear_fc1) then SiLU
    /// then `text_hidden -> hidden` (linear_fc2). Both have bias.
    pub text_proj_fc1: Box<dyn Linear<B>>,
    pub text_proj_fc2: Box<dyn Linear<B>>,

    /// Codec token embedding table: `[vocab * hidden]`. Qwen3-TTS has
    /// vocab=3072, hidden=1024.
    pub codec_embedding: B::Buffer,

    /// Output head: `hidden -> vocab` (no bias on codec_head).
    pub codec_head: Box<dyn Linear<B>>,

    /// Per-sequence position tracking, keyed by cache_id. `prefill` sets the
    /// entry to the prompt length; each `decode_codec` call advances it by
    /// one. The backbone keeps its own KV cache under the same cache_id.
    positions: HashMap<String, u32>,
}

impl<B: Backend> Qwen3TtsTalker<B> {
    /// Build a Qwen3-TTS Talker from weights.
    ///
    /// The transformer backbone is loaded through a `PrefixedLoader` with the
    /// `"talker."` prefix so it can reuse its standard `model.layers.{i}.*` /
    /// `model.norm.weight` lookups. The TTS-specific tensors (embeddings,
    /// text projection, codec head) are loaded from `loader` by their full
    /// `talker.*` names.
    ///
    /// # Errors
    /// Returns any error reported by the weight loader.
    pub fn new(cfg: TalkerConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
        let backbone_cfg = LlamaFamilyConfig {
            hidden_size: cfg.hidden_size,
            intermediate_size: cfg.intermediate_size,
            num_heads: cfg.num_attention_heads,
            num_kv_heads: cfg.num_key_value_heads,
            head_dim: cfg.head_dim,
            num_layers: cfg.num_hidden_layers,
            vocab_size: cfg.vocab_size,
            max_seq_len: cfg.max_position_embeddings,
            rms_norm_eps: cfg.rms_norm_eps as f32,
            rope_theta: cfg.rope_theta,
            has_qk_norm: true,
            sliding_window: 0,
        };

        // Backbone transformer — loads model.layers.{i}.* and model.norm
        // under the `talker.` prefix.
        let talker_loader = PrefixedLoader::new(loader, "talker.");
        let backbone = LlamaFamilyModel::<B>::new_backbone_only(backbone_cfg, &talker_loader)?;

        // Dual embeddings.
        let text_embedding = loader.load_tensor("talker.model.text_embedding.weight")?;
        let codec_embedding = loader.load_tensor("talker.model.codec_embedding.weight")?;

        // Text projection: text_hidden → text_hidden → hidden, both with bias.
        let text_proj_fc1 = loader.load_linear("talker.text_projection.linear_fc1")?;
        let text_proj_fc2 = loader.load_linear("talker.text_projection.linear_fc2")?;

        // Output head: hidden → vocab, no bias (Qwen3-TTS stores as linear_no_bias).
        let codec_head = loader.load_linear("talker.codec_head")?;

        Ok(Self {
            cfg,
            backbone,
            text_embedding,
            text_proj_fc1,
            text_proj_fc2,
            codec_embedding,
            codec_head,
            positions: HashMap::new(),
        })
    }

    /// Reset per-sequence state so the next call starts from position 0.
    ///
    /// Drops both the tracked position and the backbone KV cache for
    /// `cache_id`.
    pub fn reset(&mut self, cache_id: &str) {
        self.positions.remove(cache_id);
        self.backbone.kv_caches.remove(cache_id);
    }

    /// Clear all sessions (positions and backbone KV caches).
    pub fn reset_all(&mut self) {
        self.positions.clear();
        self.backbone.kv_caches.clear();
    }

    /// Embed a text token as `[hidden]` f32 — goes through text_embedding →
    /// silu(fc1) → fc2. Runs once per text token at prefill time; not in
    /// the decode hot loop, so the SiLU CPU roundtrip is acceptable.
    fn embed_text_token(&self, token: u32) -> Vec<f32> {
        let text_hidden = self.cfg.text_hidden_size;
        let hidden = self.cfg.hidden_size;
        let mut ctx = B::new_context();

        // text_embedding[token] → [text_hidden]
        let mut embed_out = B::alloc(text_hidden);
        B::embedding_lookup(
            &mut ctx,
            &self.text_embedding,
            &[token],
            &mut embed_out,
            text_hidden,
        );

        // fc1: text_hidden → text_hidden (+ bias)
        let mut fc1_out = B::alloc(text_hidden);
        self.text_proj_fc1
            .forward(&mut ctx, &embed_out, &mut fc1_out, 1);

        // SiLU(x) = x * sigmoid(x). Done on CPU — prefill-only path.
        B::sync(&mut ctx);
        let fc1_host = B::to_vec(&fc1_out, text_hidden);
        let silu_host: Vec<f32> = fc1_host
            .iter()
            .map(|&x| x * (1.0f32 / (1.0f32 + (-x).exp())))
            .collect();
        let silu_dev = B::from_slice(&silu_host);

        // fc2: text_hidden → hidden (+ bias)
        let mut fc2_out = B::alloc(hidden);
        self.text_proj_fc2
            .forward(&mut ctx, &silu_dev, &mut fc2_out, 1);
        B::sync(&mut ctx);

        B::to_vec(&fc2_out, hidden)
    }

    /// Embed a codec token as `[hidden]` f32 (plain table lookup).
    fn embed_codec_token(&self, token: u32) -> Vec<f32> {
        self.codec_embed_lookup(token)
    }

    /// Prefill with a mixed text / codec token sequence. Each input is
    /// `(token_id, is_text)` — text tokens route through text_embedding +
    /// projection; codec tokens route through codec_embedding.
    ///
    /// A prefill always starts the sequence at position 0: any KV cache and
    /// position previously held under `cache_id` are discarded first, so a
    /// reused id cannot mix stale cache entries with the fresh positions.
    ///
    /// Returns `[vocab_size]` logits for the last position, ready for
    /// sampling the first codec token.
    ///
    /// # Panics
    /// Panics if `tokens` is empty — there is no last position to score.
    pub fn prefill(&mut self, cache_id: &str, tokens: &[(u32, bool)]) -> Vec<f32> {
        assert!(
            !tokens.is_empty(),
            "Qwen3TtsTalker::prefill: token sequence must not be empty"
        );
        let h = self.cfg.hidden_size;
        let seq_len = tokens.len();

        // Start from a clean session (mirrors the SubTalker, which clears its
        // cache before prefilling).
        self.reset(cache_id);

        // Build [seq_len * hidden] mixed embedding on host (embeds are
        // tiny-per-token; this stays out of the hot decode loop).
        let mut mixed = Vec::with_capacity(seq_len * h);
        for &(tok, is_text) in tokens {
            let emb = if is_text {
                self.embed_text_token(tok)
            } else {
                self.embed_codec_token(tok)
            };
            mixed.extend(emb);
        }

        // Backbone transformer + last-pos hidden extract (no final norm,
        // no output head — we apply those ourselves below).
        let pre_norm_hidden = self.backbone.prefill_from_embeds(cache_id, &mixed, seq_len);

        self.positions.insert(cache_id.to_string(), seq_len as u32);

        // Final norm + codec head.
        self.apply_head(&pre_norm_hidden)
    }

    /// Decode one codec token — advances the sequence by 1 and returns
    /// `[vocab_size]` logits. A `cache_id` that was never prefilled starts
    /// at position 0.
    pub fn decode_codec(&mut self, cache_id: &str, token: u32) -> Vec<f32> {
        let pos = self.positions.get(cache_id).copied().unwrap_or(0);
        let embed = self.embed_codec_token(token);
        let pre_norm = self.backbone.decode_from_embed(cache_id, &embed, pos);
        self.positions.insert(cache_id.to_string(), pos + 1);
        self.apply_head(&pre_norm)
    }

    /// Apply final_norm + codec_head on a `[hidden]` f32 vector, return
    /// `[vocab_size]` logits.
    fn apply_head(&self, hidden_f32: &[f32]) -> Vec<f32> {
        let h = self.cfg.hidden_size;
        let vocab = self.cfg.vocab_size;
        debug_assert_eq!(hidden_f32.len(), h);

        let mut ctx = B::new_context();
        let hidden_buf = B::from_slice(hidden_f32);
        let mut normed = B::alloc(h);
        B::rms_norm(
            &mut ctx,
            &hidden_buf,
            &self.backbone.final_norm_w,
            self.cfg.rms_norm_eps as f32,
            &mut normed,
            1,
            h,
        );

        let mut logits = B::alloc(vocab);
        self.codec_head.forward(&mut ctx, &normed, &mut logits, 1);
        B::sync(&mut ctx);
        B::to_vec(&logits, vocab)
    }

    /// Expose hidden state for the last position (after final_norm, before
    /// codec_head). SubTalker needs this to run its own transformer.
    ///
    /// This re-normalises the backbone's single shared `scratch.last_hidden`,
    /// i.e. the most recent prefill/decode on this model regardless of
    /// session. `cache_id` is currently ignored, so call this immediately
    /// after the Talker step whose hidden state you want.
    pub fn last_hidden_normed(&self, cache_id: &str) -> Vec<f32> {
        let h = self.cfg.hidden_size;
        let mut ctx = B::new_context();
        let mut normed = B::alloc(h);
        B::rms_norm(
            &mut ctx,
            &self.backbone.scratch.last_hidden,
            &self.backbone.final_norm_w,
            self.cfg.rms_norm_eps as f32,
            &mut normed,
            1,
            h,
        );
        B::sync(&mut ctx);
        let _ = cache_id; // cache_id not used yet; reserved for per-session variants
        B::to_vec(&normed, h)
    }

    /// Look up a codec token in the codec embedding table, returning
    /// `[hidden]` f32 (e.g. SubTalker needs the first codec token's
    /// embedding as its starting input).
    pub fn codec_embed_lookup(&self, token: u32) -> Vec<f32> {
        let mut ctx = B::new_context();
        let h = self.cfg.hidden_size;
        let mut out = B::alloc(h);
        B::embedding_lookup(&mut ctx, &self.codec_embedding, &[token], &mut out, h);
        B::sync(&mut ctx);
        B::to_vec(&out, h)
    }
}

// ── SubTalker (Code Predictor) ──────────────────────────────────────────
//
// Smaller transformer (`code_predictor_num_layers` layers; 4-5 in the
// checkpoints this was written against — NOTE(review): confirm) that
// predicts codec tokens 1..N-1, where N = `num_code_groups`, given the
// Talker's post-norm hidden state and the first codec token's embedding.
// Has per-codebook embedding tables and per-codebook output heads (N-1 of
// each).

/// Qwen3-TTS code predictor ("SubTalker") over `Backend<B>`, loaded from
/// the `talker.code_predictor.*` weights.
pub struct Qwen3TtsSubTalker<B: Backend> {
    /// Shared Talker configuration; the `code_predictor_*` fields size this
    /// model, `hidden_size` is the Talker-side input width.
    pub cfg: TalkerConfig,

    /// Backbone — Qwen3-style stack built with the `code_predictor_*` dims.
    pub backbone: LlamaFamilyModel<B>,

    /// Projection from Talker's `hidden_size` to SubTalker's
    /// `code_predictor_hidden_size`. `None` if sizes match.
    pub projection: Option<Box<dyn Linear<B>>>,

    /// Per-codebook embedding tables: `codec_embeddings[i]` is
    /// `[code_predictor_vocab_size * hidden_size]` (the Talker's hidden size;
    /// looked-up rows go through `projection` when present) for i in
    /// 0..num_code_groups-1.
    pub codec_embeddings: Vec<B::Buffer>,

    /// Per-codebook output heads: `lm_heads[i]` is
    /// `code_predictor_hidden_size -> code_predictor_vocab_size`.
    pub lm_heads: Vec<Box<dyn Linear<B>>>,
}

impl<B: Backend> Qwen3TtsSubTalker<B> {
    /// Build the SubTalker (code predictor) from `talker.code_predictor.*`
    /// weights.
    ///
    /// # Errors
    /// Returns any error reported by the weight loader.
    pub fn new(cfg: TalkerConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
        let cp_h = cfg.code_predictor_hidden_size;
        let cp_im = cp_h * 3; // ~3072 for 1024 hidden; matches existing impl

        // Backbone config: same pattern as Qwen3 but smaller dims.
        // NOTE(review): head_dim is derived as hidden / heads here, while the
        // Talker uses the explicit `cfg.head_dim`; confirm this matches the
        // checkpoint's code-predictor head_dim. Norm eps and RoPE theta are
        // shared with the Talker config.
        let backbone_cfg = LlamaFamilyConfig {
            hidden_size: cp_h,
            intermediate_size: cp_im,
            num_heads: cfg.code_predictor_num_heads,
            num_kv_heads: cfg.code_predictor_num_kv_heads,
            head_dim: cp_h / cfg.code_predictor_num_heads,
            num_layers: cfg.code_predictor_num_layers,
            vocab_size: cfg.code_predictor_vocab_size,
            max_seq_len: cfg.max_position_embeddings,
            rms_norm_eps: cfg.rms_norm_eps as f32,
            rope_theta: cfg.rope_theta,
            has_qk_norm: true,
            sliding_window: 0,
        };

        let cp_loader = PrefixedLoader::new(loader, "talker.code_predictor.");
        let backbone = LlamaFamilyModel::<B>::new_backbone_only(backbone_cfg, &cp_loader)?;

        // Optional projection talker.code_predictor.small_to_mtp_projection.
        let projection = if cfg.hidden_size != cp_h {
            Some(loader.load_linear("talker.code_predictor.small_to_mtp_projection")?)
        } else {
            None
        };

        // Per-codebook embeddings + heads (num_code_groups - 1). Saturating so
        // a degenerate `num_code_groups == 0` config yields no extra
        // codebooks instead of a usize underflow.
        let n_extra = cfg.num_code_groups.saturating_sub(1);
        let mut codec_embeddings = Vec::with_capacity(n_extra);
        for i in 0..n_extra {
            let name = format!("talker.code_predictor.model.codec_embedding.{i}.weight");
            codec_embeddings.push(loader.load_tensor(&name)?);
        }
        let mut lm_heads = Vec::with_capacity(n_extra);
        for i in 0..n_extra {
            let name = format!("talker.code_predictor.lm_head.{i}");
            lm_heads.push(loader.load_linear(&name)?);
        }

        Ok(Self {
            cfg,
            backbone,
            projection,
            codec_embeddings,
            lm_heads,
        })
    }

    /// Predict codec tokens 1..num_code_groups-1 given Talker's post-norm
    /// hidden state `[hidden_size]` and the first codec token's embedding
    /// `[hidden_size]`. Returns `num_code_groups - 1` token IDs (empty when
    /// there are no extra codebooks).
    ///
    /// Sampling uses greedy argmax: NaN logits are ignored and ties resolve
    /// to the lowest index. Callers needing top-k / temperature sampling
    /// must decode logits themselves.
    pub fn predict_greedy(
        &mut self,
        cache_id: &str,
        talker_hidden: &[f32],
        first_codec_embed: &[f32],
    ) -> Vec<u32> {
        let h_talker = self.cfg.hidden_size;
        let n_extra = self.cfg.num_code_groups.saturating_sub(1);

        debug_assert_eq!(talker_hidden.len(), h_talker);
        debug_assert_eq!(first_codec_embed.len(), h_talker);

        // Fresh KV cache per predict call.
        self.backbone.kv_caches.remove(cache_id);
        if n_extra == 0 {
            return Vec::new();
        }

        // 1. Concat [talker_hidden | first_codec_embed] → [2, h_talker], then
        //    project each row to cp_h if needed.
        let mut combined = Vec::with_capacity(2 * h_talker);
        combined.extend_from_slice(talker_hidden);
        combined.extend_from_slice(first_codec_embed);
        let projected = self.project_rows(combined, 2);

        // 2. Prefill through SubTalker backbone (2 tokens → pre-norm last).
        let _ = self.backbone.prefill_from_embeds(cache_id, &projected, 2);

        // 3. Autoregressive loop: for each codebook i, apply lm_heads[i] on
        //    the current post-norm hidden, greedy sample, embed via
        //    codec_embeddings[i], decode next step.
        let mut pos: u32 = 2;
        let mut predicted = Vec::with_capacity(n_extra);
        for i in 0..n_extra {
            let logits = self.codebook_logits(i);
            let token = Self::argmax(&logits);
            predicted.push(token);

            // The last codebook needs no further decode step.
            if i + 1 == n_extra {
                break;
            }

            let next_embed = self.embed_codebook_token(i, token);
            // Decode one step → advances scratch.last_hidden.
            let _ = self.backbone.decode_from_embed(cache_id, &next_embed, pos);
            pos += 1;
        }

        predicted
    }

    /// Apply the optional Talker→SubTalker projection to `rows` row-major
    /// vectors of width `hidden_size`; returns the input unchanged when no
    /// projection is configured.
    fn project_rows(&self, input: Vec<f32>, rows: usize) -> Vec<f32> {
        match self.projection.as_ref() {
            None => input,
            Some(proj) => {
                let cp_h = self.cfg.code_predictor_hidden_size;
                let mut ctx = B::new_context();
                let in_buf = B::from_slice(&input);
                let mut out = B::alloc(rows * cp_h);
                proj.forward(&mut ctx, &in_buf, &mut out, rows);
                B::sync(&mut ctx);
                B::to_vec(&out, rows * cp_h)
            }
        }
    }

    /// Final-norm the backbone's last hidden state and score it with
    /// `lm_heads[codebook]`, returning `[code_predictor_vocab_size]` logits.
    fn codebook_logits(&self, codebook: usize) -> Vec<f32> {
        let cp_h = self.cfg.code_predictor_hidden_size;
        let vocab = self.cfg.code_predictor_vocab_size;
        let mut ctx = B::new_context();
        let mut normed = B::alloc(cp_h);
        B::rms_norm(
            &mut ctx,
            &self.backbone.scratch.last_hidden,
            &self.backbone.final_norm_w,
            self.cfg.rms_norm_eps as f32,
            &mut normed,
            1,
            cp_h,
        );
        let mut logits = B::alloc(vocab);
        self.lm_heads[codebook].forward(&mut ctx, &normed, &mut logits, 1);
        B::sync(&mut ctx);
        B::to_vec(&logits, vocab)
    }

    /// Embed `token` with `codec_embeddings[codebook]` (width = Talker
    /// `hidden_size`) and project it to the SubTalker width if needed.
    fn embed_codebook_token(&self, codebook: usize, token: u32) -> Vec<f32> {
        let h_talker = self.cfg.hidden_size;
        let mut ctx = B::new_context();
        let mut emb = B::alloc(h_talker);
        B::embedding_lookup(
            &mut ctx,
            &self.codec_embeddings[codebook],
            &[token],
            &mut emb,
            h_talker,
        );
        B::sync(&mut ctx);
        let emb_host = B::to_vec(&emb, h_talker);
        self.project_rows(emb_host, 1)
    }

    /// Index of the largest non-NaN logit. Ties resolve to the FIRST maximal
    /// index (the usual argmax convention; `Iterator::max_by` would return
    /// the last). Returns 0 for an empty or all-NaN slice.
    fn argmax(logits: &[f32]) -> u32 {
        let mut best: Option<(usize, f32)> = None;
        for (idx, &value) in logits.iter().enumerate() {
            if value.is_nan() {
                continue;
            }
            match best {
                Some((_, current)) if value <= current => {}
                _ => best = Some((idx, value)),
            }
        }
        best.map_or(0, |(idx, _)| idx as u32)
    }
}