Skip to main content

cake_core/models/common/
text_model.rs

1use std::collections::HashSet;
2
3use anyhow::Result;
4use candle_core::{Device, IndexOp, Tensor};
5use candle_nn::{linear_no_bias as linear, Embedding, Linear, Module, RmsNorm};
6use candle_transformers::generation::{LogitsProcessor, Sampling};
7use tokenizers::Tokenizer;
8
9use super::EosTokenId;
10use crate::{
11    cake::{Context, Forwarder},
12    models::Token,
13};
14
15/// Load the tokenizer and resolve EOS token ID(s).
16/// `default_eos_token` is the model-specific fallback (e.g. "<|eot_id|>" for LLaMA,
17/// "<|endoftext|>" for Qwen2).
18pub fn load_tokenizer(
19    ctx: &Context,
20    default_eos_token: &str,
21) -> Result<(Tokenizer, Option<EosTokenId>)> {
22    let tokenizer_filename = ctx.data_path.join("tokenizer.json");
23
24    log::info!("loading tokenizer from {}", tokenizer_filename.display());
25
26    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;
27
28    let config = ctx.config.as_ref().expect("No config specified");
29
30    let eos_token_id = if config.eos_token_id.is_some() {
31        config.eos_token_id.clone()
32    } else {
33        // Fallback: try to resolve from tokenizer vocabulary
34        tokenizer
35            .token_to_id(default_eos_token)
36            .map(EosTokenId::Single)
37    };
38
39    Ok((tokenizer, eos_token_id))
40}
41
42/// Apply repeat penalty entirely on GPU to avoid costly GPU↔CPU round-trips.
43///
44/// The upstream `candle_transformers::utils::apply_repeat_penalty` copies the entire
45/// logits tensor (vocab_size × 4 bytes ≈ 600 KB) to CPU, modifies a handful of elements,
46/// then copies everything back.  This forces a full GPU synchronisation and two large PCIe
47/// transfers per token.
48///
49/// This implementation stays on-device: it selects only the penalty positions, computes
50/// sign-aware multipliers, and scatters the deltas back with `index_add`.
51fn apply_repeat_penalty_gpu(
52    logits: &Tensor,
53    penalty: f32,
54    context: &[u32],
55) -> Result<Tensor> {
56    // Deduplicate tokens (same semantics as the upstream version).
57    let mut seen = HashSet::new();
58    let unique: Vec<u32> = context
59        .iter()
60        .filter(|t| seen.insert(**t))
61        .copied()
62        .collect();
63
64    if unique.is_empty() {
65        return Ok(logits.clone());
66    }
67
68    let device = logits.device();
69    let dtype = logits.dtype();
70    let indices = Tensor::new(unique.as_slice(), device)?;
71
72    // Gather logits at penalty positions  (N elements, tiny).
73    let selected = logits.index_select(&indices, 0)?;
74
75    // Sign-aware multiplier:  1/penalty for logits ≥ 0,  penalty for logits < 0.
76    let is_non_negative = selected.ge(0f32)?;
77    let recip = Tensor::new(1.0f32 / penalty, device)?
78        .to_dtype(dtype)?
79        .broadcast_as(selected.shape())?;
80    let pen = Tensor::new(penalty, device)?
81        .to_dtype(dtype)?
82        .broadcast_as(selected.shape())?;
83    let mult = is_non_negative.where_cond(&recip, &pen)?;
84
85    // delta = selected * mult - selected  (what to add to the original logits).
86    let penalized = (&selected * &mult)?;
87    let delta = (&penalized - &selected)?;
88
89    Ok(logits.index_add(&indices, &delta, 0)?)
90}
91
92/// Create the logit sampling logic from the context.
93pub fn create_logits_processor(ctx: &Context) -> LogitsProcessor {
94    let temperature = ctx.args.temperature;
95    let sampling = if temperature <= 0. {
96        Sampling::ArgMax
97    } else {
98        match (ctx.args.top_k, ctx.args.top_p) {
99            // Gumbel-Softmax keeps everything on GPU: generates random noise,
100            // adds to logits/temperature, and takes argmax — only 4 bytes
101            // transferred instead of the full 600 KB vocabulary vector.
102            (None, None) => Sampling::GumbelSoftmax { temperature },
103            (Some(k), None) => Sampling::TopK { k, temperature },
104            (None, Some(p)) => Sampling::TopP { p, temperature },
105            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
106        }
107    };
108    LogitsProcessor::from_sampling(ctx.args.seed, sampling)
109}
110
/// Shared base for decoder-only text models (LLaMA, Qwen2, Qwen3.5, etc.).
///
/// Contains all the state and logic that is identical across model architectures:
/// embedding, transformer blocks, final norm, lm_head, tokenizer, sampling, and
/// the forward/generation loop.
pub struct TextModelBase {
    /// Shared runtime context (device, cache, args, topology, ...).
    pub ctx: Context,

    /// Tokenizer loaded from `tokenizer.json` in the model data directory.
    pub tokenizer: Tokenizer,
    /// Token embedding table (`{prefix}.embed_tokens`).
    pub embedding: Embedding,
    /// Resolved EOS token id(s); `None` disables end-of-stream detection.
    pub eos_token_id: Option<EosTokenId>,
    /// Running position index fed to blocks when the KV cache is active;
    /// advanced by the number of context tokens processed per step.
    pub index_pos: usize,
    /// Number of tokens generated so far in the current session.
    pub generated: usize,
    /// Length of the encoded prompt; repeat penalty is applied only to
    /// tokens past this offset.
    pub prompt_len: usize,

    /// Transformer blocks: local layers or remote-worker proxies, in layer order.
    pub blocks: Vec<Box<dyn Forwarder>>,

    /// Final RMS norm applied before the output head.
    pub ln_f: RmsNorm,
    /// Output projection from hidden states to vocabulary logits.
    pub lm_head: Linear,

    /// Sampling strategy used to pick the next token from the logits.
    pub logits_processor: LogitsProcessor,

    /// Full token history of the current session: prompt followed by generated tokens.
    pub tokens: Vec<u32>,
}
135
136impl TextModelBase {
137    /// Load the shared model structure from the context.
138    /// `default_eos_token` is the model-specific fallback EOS string.
139    /// The type parameter `B` determines which block type to use for local layers.
140    pub async fn load<B: Forwarder + 'static>(
141        ctx: &mut Context,
142        default_eos_token: &str,
143    ) -> Result<Self> {
144        let config = ctx.config.as_ref().expect("No config specified");
145        let var_builder = ctx.var_builder.as_ref().expect("No var_builder specified");
146        let prefix = &config.model_prefix;
147
148        log::info!("loading embeddings (prefix={}) ...", prefix);
149        let embedding: Embedding = candle_nn::embedding(
150            config.vocab_size,
151            config.hidden_size,
152            var_builder.pp(format!("{prefix}.embed_tokens")),
153        )?;
154
155        log::info!("loading lm_head ...");
156        let lm_head = if config.tie_word_embeddings {
157            log::info!("  using tied word embeddings (lm_head = embed_tokens)");
158            Linear::new(embedding.embeddings().clone(), None)
159        } else {
160            // Try root-level lm_head first (LLaMA/Qwen2), then prefixed (Qwen3.5)
161            match linear(
162                config.hidden_size,
163                config.vocab_size,
164                var_builder.pp("lm_head"),
165            ) {
166                Ok(l) => l,
167                Err(_) => linear(
168                    config.hidden_size,
169                    config.vocab_size,
170                    var_builder.pp(format!("{prefix}.lm_head")),
171                )?,
172            }
173        };
174
175        log::info!("loading {prefix}.norm ...");
176        let ln_f = crate::models::common::load_rms_norm(
177            config.hidden_size,
178            config.rms_norm_eps,
179            config.residual_rms_norm,
180            var_builder.pp(format!("{prefix}.norm")),
181        )?;
182
183        log::info!("loading {} blocks ...", config.num_hidden_layers);
184
185        // Two-pass loading: local layers first (no network wait), then remote
186        // layers (may block until workers finish loading). This overlaps
187        // master's local layer loading with worker startup time.
188        let mut blocks: Vec<Option<Box<dyn Forwarder>>> =
189            (0..config.num_hidden_layers).map(|_| None).collect();
190
191        // Pass 1: load local layers
192        for i in 0..config.num_hidden_layers {
193            let block_layer_name = format!("{prefix}.layers.{i}");
194            if ctx.topology.get_node_for_layer(&block_layer_name).is_none() {
195                log::info!("loading {} ...", &block_layer_name);
196                blocks[i] = Some(B::load(block_layer_name, ctx)?);
197            }
198        }
199
200        // Pass 2: connect to remote layers
201        for i in 0..config.num_hidden_layers {
202            let block_layer_name = format!("{prefix}.layers.{i}");
203            if let Some((_node_name, node)) = ctx.topology.get_node_for_layer(&block_layer_name) {
204                log::info!("connecting {} to {} ...", &block_layer_name, &node.host);
205                blocks[i] = Some(Box::new(
206                    crate::cake::Client::new(
207                        ctx.device.clone(),
208                        &node.host,
209                        &block_layer_name,
210                        ctx.args.cluster_key.as_deref(),
211                    )
212                    .await?,
213                ));
214            }
215        }
216
217        let blocks: Vec<Box<dyn Forwarder>> = blocks.into_iter().map(|b| b.unwrap()).collect();
218
219        for block in &blocks {
220            log::info!("  {}", block)
221        }
222
223        let (tokenizer, eos_token_id) = load_tokenizer(ctx, default_eos_token)?;
224        let tokens = vec![];
225
226        let logits_processor = create_logits_processor(ctx);
227        let index_pos = 0;
228
229        log::info!(
230            "model loaded - mem={}",
231            human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64)
232        );
233
234        let generated = 0;
235
236        Ok(Self {
237            tokenizer,
238            tokens,
239            generated,
240            eos_token_id,
241            index_pos,
242            prompt_len: 0,
243            ctx: ctx.clone(),
244            embedding,
245            blocks,
246            ln_f,
247            lm_head,
248            logits_processor,
249        })
250    }
251
    /// Forward pass through all blocks.
    ///
    /// `x` holds token ids with a leading batch dimension (checked via `dims2`);
    /// `idx` is the position index handed to every block. Contiguous runs of
    /// remote layers on the same worker are sent as a single batched request.
    /// Returns the lm_head logits for the last sequence position only.
    pub async fn forward(&mut self, x: &Tensor, idx: usize) -> Result<Tensor> {
        let forward_start = std::time::Instant::now();
        let (_batch_size, seq_len) = x.dims2()?;

        // Token ids -> embeddings.
        let emb_start = std::time::Instant::now();
        let mut x = self.embedding.forward(x)?;
        let emb_elapsed = emb_start.elapsed();

        let num_blocks = self.blocks.len();
        let mut block_idx = 0;
        let mut local_elapsed = std::time::Duration::ZERO;
        let mut local_count: usize = 0;

        while block_idx < num_blocks {
            let curr_block_id = self.blocks[block_idx].ident().to_owned();
            if curr_block_id == "local" {
                // Locally hosted layer: execute one block and advance.
                let local_start = std::time::Instant::now();
                x = self.blocks[block_idx]
                    .forward_mut(&x, idx, block_idx, &mut self.ctx)
                    .await
                    .map_err(|e| {
                        anyhow!("error in forward operation of local block {block_idx}: {e}")
                    })?;
                local_elapsed += local_start.elapsed();
                local_count += 1;

                block_idx += 1;
            } else {
                // collect all contiguous layers running on the same worker
                // so they can be executed with a single network round-trip
                let mut batch = vec![];
                let first = block_idx;
                while block_idx < num_blocks && self.blocks[block_idx].ident() == curr_block_id {
                    batch.push((
                        self.blocks[block_idx].layer_name().to_string(),
                        idx,
                        block_idx,
                    ));
                    block_idx += 1;
                }

                // Execute the whole run of layers via the first block's client.
                let num_layers = batch.len();
                let batch_start = std::time::Instant::now();
                x = self.blocks[first]
                    .forward_batch(&x, batch, &mut self.ctx)
                    .await
                    .map_err(|e| {
                        anyhow!(
                            "error in forward batch for blocks {first}..{block_idx} on {}: {e}",
                            &curr_block_id
                        )
                    })?;
                let batch_elapsed = batch_start.elapsed();
                log::debug!(
                    "  worker {} layers {}-{} ({} layers): {:.1}ms",
                    &curr_block_id,
                    first,
                    block_idx - 1,
                    num_layers,
                    batch_elapsed.as_secs_f64() * 1000.0
                );
            }
        }

        // Final norm + output head on the last position only.
        let head_start = std::time::Instant::now();
        let x = self
            .ln_f
            .forward(&x)
            .map_err(|e| anyhow!("error in ln_f.forward: {e}"))?;

        let x = x
            .i((.., seq_len - 1, ..))
            .map_err(|e| anyhow!("error in x.i: {e}"))?
            .contiguous()
            .map_err(|e| anyhow!("error in x.i.contiguous: {e}"))?;

        let logits = self
            .lm_head
            .forward(&x)
            .map_err(|e| anyhow!("error in lm_head.forward: {e}"))?;
        let head_elapsed = head_start.elapsed();

        let total_elapsed = forward_start.elapsed();
        log::debug!(
            "  forward total={:.1}ms emb={:.1}ms local={:.1}ms ({} blocks) head={:.1}ms",
            total_elapsed.as_secs_f64() * 1000.0,
            emb_elapsed.as_secs_f64() * 1000.0,
            local_elapsed.as_secs_f64() * 1000.0,
            local_count,
            head_elapsed.as_secs_f64() * 1000.0,
        );

        Ok(logits)
    }
346
347    /// Tokenize a prompt string and set up token state for generation.
348    pub fn prepare_prompt(&mut self, dialog: &str) -> Result<()> {
349        // make sure we start clean
350        self.tokens.clear();
351        self.ctx.cache.as_mut().expect("No cache specified").clear();
352        self.index_pos = 0;
353
354        log::debug!("dialog={}", dialog);
355
356        // tokenize raw
357        self.tokens = self
358            .tokenizer
359            .encode(dialog, false) // do not add special tokens as we already added them
360            .map_err(anyhow::Error::msg)?
361            .get_ids()
362            .to_vec();
363
364        log::debug!("encoded={:?}", &self.tokens);
365        log::debug!("history tokens: {}", self.tokens.len());
366
367        // Track prompt length for repeat penalty scoping
368        self.prompt_len = self.tokens.len();
369
370        Ok(())
371    }
372
373    /// Generate the next token. Assumes `prepare_prompt()` has been called for the first token.
374    pub async fn next_token(&mut self, index: usize) -> Result<Token> {
375        log::trace!("model.next_token({index})");
376
377        let num_tokens = self.tokens.len();
378        let (context_size, context_index) = if self
379            .ctx
380            .cache
381            .as_ref()
382            .expect("No cache specified")
383            .with_kv_cache()
384            && index > 0
385        {
386            (1, self.index_pos)
387        } else {
388            (num_tokens, 0)
389        };
390
391        let context_offset = num_tokens.saturating_sub(context_size);
392        let context_tokens = &self.tokens[context_offset..];
393        let num_context_tokens = context_tokens.len();
394
395        let input = Tensor::new(context_tokens, &self.ctx.device)?
396            .unsqueeze(0)
397            .map_err(|e| anyhow!("error squeezing context tokens: {e}"))?;
398
399        let logits = self
400            .forward(&input, context_index)
401            .await
402            .map_err(|e| anyhow!("error in model.forward: {e}"))?;
403
404        let post_start = std::time::Instant::now();
405
406        let logits = logits
407            .squeeze(0)
408            .map_err(|e| anyhow!("error squeezing logits: {e}"))?;
409
410        // Apply repeat penalty only to generated tokens (not prompt tokens)
411        let penalty_start = std::time::Instant::now();
412        let logits = if self.ctx.args.repeat_penalty == 1. {
413            logits
414        } else {
415            let generated_start = self.prompt_len;
416            let penalty_tokens = &self.tokens[generated_start..];
417            if penalty_tokens.is_empty() {
418                logits
419            } else {
420                let start_at = penalty_tokens
421                    .len()
422                    .saturating_sub(self.ctx.args.repeat_last_n);
423                apply_repeat_penalty_gpu(
424                    &logits,
425                    self.ctx.args.repeat_penalty,
426                    &penalty_tokens[start_at..],
427                )?
428            }
429        };
430        let penalty_elapsed = penalty_start.elapsed();
431        self.index_pos += num_context_tokens;
432
433        let sample_start = std::time::Instant::now();
434        let next_token = self
435            .logits_processor
436            .sample(&logits)
437            .map_err(|e| anyhow!("error sampling logits {logits}: {e}"))?;
438        let sample_elapsed = sample_start.elapsed();
439
440        self.generated += 1;
441        self.tokens.push(next_token);
442
443        let is_end_of_stream = self
444            .eos_token_id
445            .as_ref()
446            .map_or(false, |eos| eos.is_eos(next_token));
447
448        let decode_start = std::time::Instant::now();
449        let text = match self.tokenizer.decode(&[next_token], false) {
450            Ok(s) => Some(s),
451            Err(e) => {
452                log::error!("could not decode token {next_token}: {e}");
453                None
454            }
455        };
456        let decode_elapsed = decode_start.elapsed();
457        let post_elapsed = post_start.elapsed();
458
459        log::debug!(
460            "  post-forward: total={:.1}ms penalty={:.1}ms sample={:.1}ms decode={:.1}ms",
461            post_elapsed.as_secs_f64() * 1000.0,
462            penalty_elapsed.as_secs_f64() * 1000.0,
463            sample_elapsed.as_secs_f64() * 1000.0,
464            decode_elapsed.as_secs_f64() * 1000.0,
465        );
466
467        Ok(Token {
468            id: next_token,
469            text,
470            is_end_of_stream,
471        })
472    }
473
    /// Reset all generation state.
    ///
    /// Clears the token history and KV cache and zeroes the position/counters,
    /// so the next `prepare_prompt()` starts from a clean slate.
    pub fn reset(&mut self) {
        self.tokens.clear();
        self.ctx.cache.as_mut().expect("No cache specified").clear();
        self.index_pos = 0;
        self.generated = 0;
        self.prompt_len = 0;

        // Clear any stale CUDA error state left by tensor cleanup (CudaSlice drops).
        // cudarc's error_state is an atomic that gets poisoned by internal operations
        // (e.g. SyncOnDrop event recording, async memory frees) and causes the NEXT
        // inference request to fail via check_err(). Clearing it here prevents the
        // alternating success/failure pattern.
        // Result intentionally ignored: this is a best-effort cleanup.
        #[cfg(feature = "cuda")]
        if let Device::Cuda(cuda_dev) = &self.ctx.device {
            let _ = cuda_dev.cuda_stream().context().bind_to_thread();
        }
    }
492
493    /// Notify all remote blocks of session end (clears their KV caches).
494    pub async fn goodbye(&mut self) -> Result<()> {
495        let num_blocks = self.blocks.len();
496        let mut block_idx = 0;
497        while block_idx < num_blocks {
498            self.blocks[block_idx]
499                .goodbye()
500                .await
501                .map_err(|e| anyhow!("error in goodbye operation for block {block_idx}: {e}"))?;
502            block_idx += 1;
503        }
504        Ok(())
505    }
506}