hot_loop/core/session/
generation.rs1use candle_core::{Device, Tensor};
2use candle_transformers::generation::LogitsProcessor;
3use crate::{
4 Error, ModelWeights, KvCache,
5 settings::Settings,
6};
7use crate::utils::token_output_stream::TokenOutputStream;
8
9#[non_exhaustive]
10pub struct Generation<'a, 'b, M: ModelWeights> {
11 pub(crate) model: &'b M,
12 pub(crate) index: usize,
13 pub(crate) next_token: u32,
14 pub(crate) tokens: Vec<u32>,
15 pub(crate) all_tokens: Vec<u32>,
16 pub(crate) parameters: Settings,
17 pub(crate) device: &'b Device,
18 pub(crate) eos_token: u32,
19 pub(crate) logits_processor: LogitsProcessor,
20 pub(crate) tos: &'a mut TokenOutputStream<'b>,
21 pub(crate) kv_cache: &'a mut Vec<KvCache>
22}
23
24impl<'a, 'b, M: ModelWeights> Generation<'a, 'b, M> {
25 pub fn next_chunk(&mut self) -> Result<Option<String>, Error> {
26 loop {
27 if self.parameters.sample_len <= self.index || self.next_token == self.eos_token {
28 return Ok(None);
29 }
30
31 let current_pos = self.kv_cache
32 .get(0)
33 .ok_or(Error::None)?
34 .current_seq_len();
35
36 let logits = if self.index == 0 {
37 let input = Tensor::new(self.tokens.as_slice(), &self.device)?.unsqueeze(0)?;
38 self.model.forward(&input, current_pos, &mut self.kv_cache)?
39 } else {
40 let input = Tensor::new(&[self.next_token], self.device)?.unsqueeze(0)?;
41 self.model.forward(&input, current_pos, &mut self.kv_cache)?
42 };
43
44 let logits = logits.squeeze(0)?;
45
46 let logits = if self.parameters.repeat_penalty == 1. {
47 logits
48 } else {
49 let start_at = self.all_tokens.len().saturating_sub(self.parameters.repeat_last_n);
50 candle_transformers::utils::apply_repeat_penalty(
51 &logits,
52 self.parameters.repeat_penalty,
53 &self.all_tokens[start_at..],
54 )?
55 };
56
57 self.next_token = self.logits_processor.sample(&logits)?;
58 self.all_tokens.push(self.next_token);
59
60 self.index += 1;
61
62 if let Some(chunk) = self.tos.next_token(self.next_token)? {
63 return Ok(Some(chunk))
64 }
65 }
66 }
67}
68
69impl<'a, 'b, M: ModelWeights> Drop for Generation<'a, 'b, M> {
70 fn drop(&mut self) {
71 self.tos.clear();
72 }
73}