use std::collections::HashSet;

use anyhow::{anyhow, Result};
use candle_core::{Device, IndexOp, Tensor};
use candle_nn::{linear_no_bias as linear, Embedding, Linear, Module, RmsNorm};
use candle_transformers::generation::{LogitsProcessor, Sampling};
use tokenizers::Tokenizer;

use super::EosTokenId;
use crate::{
    cake::{Context, Forwarder},
    models::Token,
};

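/// Loads `tokenizer.json` from the context's data path and resolves the
/// end-of-stream token id: an id from the model config takes precedence,
/// otherwise the id of `default_eos_token` is looked up in the tokenizer.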
pub fn load_tokenizer(
    ctx: &Context,
    default_eos_token: &str,
) -> Result<(Tokenizer, Option<EosTokenId>)> {
    let tokenizer_filename = ctx.data_path.join("tokenizer.json");

    log::info!("loading tokenizer from {}", tokenizer_filename.display());

    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;

    let config = ctx.config.as_ref().expect("No config specified");

    let eos_token_id = if config.eos_token_id.is_some() {
        config.eos_token_id.clone()
    } else {
        tokenizer
            .token_to_id(default_eos_token)
            .map(EosTokenId::Single)
    };

    Ok((tokenizer, eos_token_id))
}

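/// Applies a repetition penalty to `logits` for every distinct token id in
/// `context`, entirely on the logits' device (no copy back to the host).
/// Following the usual convention, a penalized logit is divided by `penalty`
/// when it is non-negative and multiplied by `penalty` when it is negative,
/// which always lowers the token's probability for `penalty > 1`.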
fn apply_repeat_penalty_gpu(
    logits: &Tensor,
    penalty: f32,
    context: &[u32],
) -> Result<Tensor> {
    let mut seen = HashSet::new();
    let unique: Vec<u32> = context
        .iter()
        .filter(|t| seen.insert(**t))
        .copied()
        .collect();

    if unique.is_empty() {
        return Ok(logits.clone());
    }

    let device = logits.device();
    let dtype = logits.dtype();
    let indices = Tensor::new(unique.as_slice(), device)?;

    let selected = logits.index_select(&indices, 0)?;

    let is_non_negative = selected.ge(0f32)?;
    let recip = Tensor::new(1.0f32 / penalty, device)?
        .to_dtype(dtype)?
        .broadcast_as(selected.shape())?;
    let pen = Tensor::new(penalty, device)?
        .to_dtype(dtype)?
        .broadcast_as(selected.shape())?;
    let mult = is_non_negative.where_cond(&recip, &pen)?;

    let penalized = (&selected * &mult)?;
    let delta = (&penalized - &selected)?;

    Ok(logits.index_add(&indices, &delta, 0)?)
}

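/// Builds a `LogitsProcessor` from the sampling arguments in the context:
/// greedy argmax when the temperature is zero or negative; otherwise
/// temperature-scaled sampling, restricted by top-k and/or top-p when those
/// arguments are set and plain Gumbel-softmax sampling when they are not.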
pub fn create_logits_processor(ctx: &Context) -> LogitsProcessor {
    let temperature = ctx.args.temperature;
    let sampling = if temperature <= 0. {
        Sampling::ArgMax
    } else {
        match (ctx.args.top_k, ctx.args.top_p) {
            (None, None) => Sampling::GumbelSoftmax { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
        }
    };
    LogitsProcessor::from_sampling(ctx.args.seed, sampling)
}

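/// Shared state for a text-generation model: tokenizer, token embeddings,
/// the per-layer `Forwarder` blocks (local layers or remote workers), the
/// final norm and lm_head, plus sampling and generation bookkeeping.
///
/// A minimal usage sketch; the `MyBlock` type, the `Context` setup, and the
/// `"</s>"` end-of-sequence token are assumptions, not part of this file:
///
/// ```ignore
/// let mut model = TextModelBase::load::<MyBlock>(&mut ctx, "</s>").await?;
/// model.prepare_prompt(&prompt)?;
/// for i in 0.. {
///     let token = model.next_token(i).await?;
///     if token.is_end_of_stream {
///         break;
///     }
/// }
/// ```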
pub struct TextModelBase {
    pub ctx: Context,

    pub tokenizer: Tokenizer,
    pub embedding: Embedding,
    pub eos_token_id: Option<EosTokenId>,
    pub index_pos: usize,
    pub generated: usize,
    pub prompt_len: usize,

    pub blocks: Vec<Box<dyn Forwarder>>,

    pub ln_f: RmsNorm,
    pub lm_head: Linear,

    pub logits_processor: LogitsProcessor,

    pub tokens: Vec<u32>,
}

impl TextModelBase {
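    /// Loads every model component: the token embeddings, the (possibly tied)
    /// lm_head, the final RMS norm, and one `Forwarder` per hidden layer.
    /// Layers not assigned to a remote node in the topology are loaded locally
    /// as blocks of type `B`; the others are reached through a `cake::Client`.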
    pub async fn load<B: Forwarder + 'static>(
        ctx: &mut Context,
        default_eos_token: &str,
    ) -> Result<Self> {
        let config = ctx.config.as_ref().expect("No config specified");
        let var_builder = ctx.var_builder.as_ref().expect("No var_builder specified");
        let prefix = &config.model_prefix;

        log::info!("loading embeddings (prefix={}) ...", prefix);
        let embedding: Embedding = candle_nn::embedding(
            config.vocab_size,
            config.hidden_size,
            var_builder.pp(format!("{prefix}.embed_tokens")),
        )?;

        log::info!("loading lm_head ...");
        let lm_head = if config.tie_word_embeddings {
            log::info!(" using tied word embeddings (lm_head = embed_tokens)");
            Linear::new(embedding.embeddings().clone(), None)
        } else {
            match linear(
                config.hidden_size,
                config.vocab_size,
                var_builder.pp("lm_head"),
            ) {
                Ok(l) => l,
                Err(_) => linear(
                    config.hidden_size,
                    config.vocab_size,
                    var_builder.pp(format!("{prefix}.lm_head")),
                )?,
            }
        };

        log::info!("loading {prefix}.norm ...");
        let ln_f = crate::models::common::load_rms_norm(
            config.hidden_size,
            config.rms_norm_eps,
            config.residual_rms_norm,
            var_builder.pp(format!("{prefix}.norm")),
        )?;

        log::info!("loading {} blocks ...", config.num_hidden_layers);

        let mut blocks: Vec<Option<Box<dyn Forwarder>>> =
            (0..config.num_hidden_layers).map(|_| None).collect();

        for i in 0..config.num_hidden_layers {
            let block_layer_name = format!("{prefix}.layers.{i}");
            if ctx.topology.get_node_for_layer(&block_layer_name).is_none() {
                log::info!("loading {} ...", &block_layer_name);
                blocks[i] = Some(B::load(block_layer_name, ctx)?);
            }
        }

        for i in 0..config.num_hidden_layers {
            let block_layer_name = format!("{prefix}.layers.{i}");
            if let Some((_node_name, node)) = ctx.topology.get_node_for_layer(&block_layer_name) {
                log::info!("connecting {} to {} ...", &block_layer_name, &node.host);
                blocks[i] = Some(Box::new(
                    crate::cake::Client::new(
                        ctx.device.clone(),
                        &node.host,
                        &block_layer_name,
                        ctx.args.cluster_key.as_deref(),
                    )
                    .await?,
                ));
            }
        }

        let blocks: Vec<Box<dyn Forwarder>> = blocks.into_iter().map(|b| b.unwrap()).collect();

        for block in &blocks {
            log::info!(" {}", block);
        }

        let (tokenizer, eos_token_id) = load_tokenizer(ctx, default_eos_token)?;
        let tokens = vec![];

        let logits_processor = create_logits_processor(ctx);
        let index_pos = 0;

        log::info!(
            "model loaded - mem={}",
            human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64)
        );

        let generated = 0;

        Ok(Self {
            tokenizer,
            tokens,
            generated,
            eos_token_id,
            index_pos,
            prompt_len: 0,
            ctx: ctx.clone(),
            embedding,
            blocks,
            ln_f,
            lm_head,
            logits_processor,
        })
    }

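    /// Runs one forward pass: embeds the input ids, walks the blocks in order
    /// (running local blocks one by one and batching consecutive layers that
    /// live on the same remote worker into a single request), then applies the
    /// final norm and lm_head to the last position and returns its logits.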
    pub async fn forward(&mut self, x: &Tensor, idx: usize) -> Result<Tensor> {
        let forward_start = std::time::Instant::now();
        let (_batch_size, seq_len) = x.dims2()?;

        let emb_start = std::time::Instant::now();
        let mut x = self.embedding.forward(x)?;
        let emb_elapsed = emb_start.elapsed();

        let num_blocks = self.blocks.len();
        let mut block_idx = 0;
        let mut local_elapsed = std::time::Duration::ZERO;
        let mut local_count: usize = 0;

        while block_idx < num_blocks {
            let curr_block_id = self.blocks[block_idx].ident().to_owned();
            if curr_block_id == "local" {
                let local_start = std::time::Instant::now();
                x = self.blocks[block_idx]
                    .forward_mut(&x, idx, block_idx, &mut self.ctx)
                    .await
                    .map_err(|e| {
                        anyhow!("error in forward operation of local block {block_idx}: {e}")
                    })?;
                local_elapsed += local_start.elapsed();
                local_count += 1;

                block_idx += 1;
            } else {
                let mut batch = vec![];
                let first = block_idx;
                while block_idx < num_blocks && self.blocks[block_idx].ident() == curr_block_id {
                    batch.push((
                        self.blocks[block_idx].layer_name().to_string(),
                        idx,
                        block_idx,
                    ));
                    block_idx += 1;
                }

                let num_layers = batch.len();
                let batch_start = std::time::Instant::now();
                x = self.blocks[first]
                    .forward_batch(&x, batch, &mut self.ctx)
                    .await
                    .map_err(|e| {
                        anyhow!(
                            "error in forward batch for blocks {first}..{block_idx} on {}: {e}",
                            &curr_block_id
                        )
                    })?;
                let batch_elapsed = batch_start.elapsed();
                log::debug!(
                    " worker {} layers {}-{} ({} layers): {:.1}ms",
                    &curr_block_id,
                    first,
                    block_idx - 1,
                    num_layers,
                    batch_elapsed.as_secs_f64() * 1000.0
                );
            }
        }

        let head_start = std::time::Instant::now();
        let x = self
            .ln_f
            .forward(&x)
            .map_err(|e| anyhow!("error in ln_f.forward: {e}"))?;

        let x = x
            .i((.., seq_len - 1, ..))
            .map_err(|e| anyhow!("error in x.i: {e}"))?
            .contiguous()
            .map_err(|e| anyhow!("error in x.i.contiguous: {e}"))?;

        let logits = self
            .lm_head
            .forward(&x)
            .map_err(|e| anyhow!("error in lm_head.forward: {e}"))?;
        let head_elapsed = head_start.elapsed();

        let total_elapsed = forward_start.elapsed();
        log::debug!(
            " forward total={:.1}ms emb={:.1}ms local={:.1}ms ({} blocks) head={:.1}ms",
            total_elapsed.as_secs_f64() * 1000.0,
            emb_elapsed.as_secs_f64() * 1000.0,
            local_elapsed.as_secs_f64() * 1000.0,
            local_count,
            head_elapsed.as_secs_f64() * 1000.0,
        );

        Ok(logits)
    }

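    /// Resets the token history and KV cache, then tokenizes `dialog` and
    /// stores the resulting ids as the prompt for the next generation run.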
    pub fn prepare_prompt(&mut self, dialog: &str) -> Result<()> {
        self.tokens.clear();
        self.ctx.cache.as_mut().expect("No cache specified").clear();
        self.index_pos = 0;

        log::debug!("dialog={}", dialog);

        self.tokens = self
            .tokenizer
            .encode(dialog, false)
            .map_err(anyhow::Error::msg)?
            .get_ids()
            .to_vec();

        log::debug!("encoded={:?}", &self.tokens);
        log::debug!("history tokens: {}", self.tokens.len());

        self.prompt_len = self.tokens.len();

        Ok(())
    }

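    /// Generates the next token: builds the input window (a single token when
    /// the KV cache is active after the first step, the full history otherwise),
    /// runs the forward pass, optionally applies the repetition penalty to the
    /// most recently generated tokens, samples, and decodes the sampled id.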
    pub async fn next_token(&mut self, index: usize) -> Result<Token> {
        log::trace!("model.next_token({index})");

        let num_tokens = self.tokens.len();
        let (context_size, context_index) = if self
            .ctx
            .cache
            .as_ref()
            .expect("No cache specified")
            .with_kv_cache()
            && index > 0
        {
            (1, self.index_pos)
        } else {
            (num_tokens, 0)
        };

        let context_offset = num_tokens.saturating_sub(context_size);
        let context_tokens = &self.tokens[context_offset..];
        let num_context_tokens = context_tokens.len();

        let input = Tensor::new(context_tokens, &self.ctx.device)?
            .unsqueeze(0)
            .map_err(|e| anyhow!("error unsqueezing context tokens: {e}"))?;

        let logits = self
            .forward(&input, context_index)
            .await
            .map_err(|e| anyhow!("error in model.forward: {e}"))?;

        let post_start = std::time::Instant::now();

        let logits = logits
            .squeeze(0)
            .map_err(|e| anyhow!("error squeezing logits: {e}"))?;

        let penalty_start = std::time::Instant::now();
        let logits = if self.ctx.args.repeat_penalty == 1. {
            logits
        } else {
            let generated_start = self.prompt_len;
            let penalty_tokens = &self.tokens[generated_start..];
            if penalty_tokens.is_empty() {
                logits
            } else {
                let start_at = penalty_tokens
                    .len()
                    .saturating_sub(self.ctx.args.repeat_last_n);
                apply_repeat_penalty_gpu(
                    &logits,
                    self.ctx.args.repeat_penalty,
                    &penalty_tokens[start_at..],
                )?
            }
        };
        let penalty_elapsed = penalty_start.elapsed();
        self.index_pos += num_context_tokens;

        let sample_start = std::time::Instant::now();
        let next_token = self
            .logits_processor
            .sample(&logits)
            .map_err(|e| anyhow!("error sampling logits {logits}: {e}"))?;
        let sample_elapsed = sample_start.elapsed();

        self.generated += 1;
        self.tokens.push(next_token);

        let is_end_of_stream = self
            .eos_token_id
            .as_ref()
            .map_or(false, |eos| eos.is_eos(next_token));

        let decode_start = std::time::Instant::now();
        let text = match self.tokenizer.decode(&[next_token], false) {
            Ok(s) => Some(s),
            Err(e) => {
                log::error!("could not decode token {next_token}: {e}");
                None
            }
        };
        let decode_elapsed = decode_start.elapsed();
        let post_elapsed = post_start.elapsed();

        log::debug!(
            " post-forward: total={:.1}ms penalty={:.1}ms sample={:.1}ms decode={:.1}ms",
            post_elapsed.as_secs_f64() * 1000.0,
            penalty_elapsed.as_secs_f64() * 1000.0,
            sample_elapsed.as_secs_f64() * 1000.0,
            decode_elapsed.as_secs_f64() * 1000.0,
        );

        Ok(Token {
            id: next_token,
            text,
            is_end_of_stream,
        })
    }

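    /// Clears the token history, KV cache, and generation counters so a new
    /// prompt can be processed. On CUDA builds it also re-binds the device's
    /// CUDA context to the current thread before the next run.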
    pub fn reset(&mut self) {
        self.tokens.clear();
        self.ctx.cache.as_mut().expect("No cache specified").clear();
        self.index_pos = 0;
        self.generated = 0;
        self.prompt_len = 0;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(cuda_dev) = &self.ctx.device {
            let _ = cuda_dev.cuda_stream().context().bind_to_thread();
        }
    }

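    /// Calls `goodbye` on every block in order, typically so that remote
    /// workers can release any state associated with this client.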
    pub async fn goodbye(&mut self) -> Result<()> {
        let num_blocks = self.blocks.len();
        let mut block_idx = 0;
        while block_idx < num_blocks {
            self.blocks[block_idx]
                .goodbye()
                .await
                .map_err(|e| anyhow!("error in goodbye operation for block {block_idx}: {e}"))?;
            block_idx += 1;
        }
        Ok(())
    }
}