//! hot_loop/core/session/generation.rs
1use candle_core::{Device, Tensor};
2use candle_transformers::generation::LogitsProcessor;
3use crate::{
4    Error, ModelWeights, KvCache,
5    settings::Settings,
6};
7use crate::utils::token_output_stream::TokenOutputStream;
8
/// In-flight state for one token-generation session.
///
/// Borrows the model, device, KV caches, and token output stream for the
/// duration of the generation; the output stream is cleared when this
/// struct is dropped.
#[non_exhaustive]
pub struct Generation<'a, 'b, M: ModelWeights> {
    /// Model weights used for every forward pass.
    pub(crate) model: &'b M,
    /// Number of sampling steps taken so far (0 = prompt not yet processed).
    pub(crate) index: usize,
    /// Most recently sampled token id.
    pub(crate) next_token: u32,
    /// Prompt tokens, fed in full on the first forward pass.
    pub(crate) tokens: Vec<u32>,
    /// Prompt + generated tokens; the tail window feeds the repeat penalty.
    pub(crate) all_tokens: Vec<u32>,
    /// Sampling settings (`sample_len`, `repeat_penalty`, `repeat_last_n`, …).
    pub(crate) parameters: Settings,
    /// Device on which input tensors are created.
    pub(crate) device: &'b Device,
    /// Token id that terminates generation when sampled.
    pub(crate) eos_token: u32,
    /// Sampler that turns logits into the next token id.
    pub(crate) logits_processor: LogitsProcessor,
    /// Incremental detokenizer turning token ids into text chunks.
    pub(crate) tos: &'a mut TokenOutputStream<'b>,
    /// Per-layer KV caches; the first cache's seq len is used as the
    /// current position for the forward pass.
    pub(crate) kv_cache: &'a mut Vec<KvCache>
}
23
24impl<'a, 'b, M: ModelWeights> Generation<'a, 'b, M> {
25    pub fn next_chunk(&mut self) -> Result<Option<String>, Error> {
26        loop {
27            if self.parameters.sample_len <= self.index || self.next_token == self.eos_token {
28                return Ok(None);
29            }
30
31            let current_pos = self.kv_cache
32                .get(0)
33                .ok_or(Error::None)?
34                .current_seq_len();
35
36            let logits = if self.index == 0 {
37                let input = Tensor::new(self.tokens.as_slice(), &self.device)?.unsqueeze(0)?;
38                self.model.forward(&input, current_pos, &mut self.kv_cache)?
39            } else {
40                let input = Tensor::new(&[self.next_token], self.device)?.unsqueeze(0)?;
41                self.model.forward(&input, current_pos, &mut self.kv_cache)?
42            };
43
44            let logits = logits.squeeze(0)?;
45
46            let logits = if self.parameters.repeat_penalty == 1. {
47                logits
48            } else {
49                let start_at = self.all_tokens.len().saturating_sub(self.parameters.repeat_last_n);
50                candle_transformers::utils::apply_repeat_penalty(
51                    &logits,
52                    self.parameters.repeat_penalty,
53                    &self.all_tokens[start_at..],
54                )?
55            };
56
57            self.next_token = self.logits_processor.sample(&logits)?;
58            self.all_tokens.push(self.next_token);
59
60            self.index += 1;
61
62            if let Some(chunk) = self.tos.next_token(self.next_token)? {
63                return Ok(Some(chunk))
64            }
65        }
66    }
67}
68
impl<'a, 'b, M: ModelWeights> Drop for Generation<'a, 'b, M> {
    fn drop(&mut self) {
        // Reset the borrowed token output stream when this generation ends
        // — presumably so the stream can be reused for the next session;
        // confirm against TokenOutputStream::clear.
        self.tos.clear();
    }
}