
hot_loop/backend/session/generation.rs

use candle_core::Tensor;
use candle_transformers::generation::LogitsProcessor;

use crate::settings::Settings;
use crate::utils::kv_cache::KvCache;
use crate::utils::token_output_stream::TokenOutputStream;
use crate::Error;
use crate::Model;
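/// Streams text for a single turn: prefills the KV cache with the prompt on
/// the first forward pass, then samples one token per pass until
/// `sample_len` tokens have been produced or the EOS token appears.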
#[non_exhaustive]
pub struct Generation<'session, M: Model> {
    model: &'session M,
    index: usize,
    next_token: u32,
    tokens_prefill: Option<Vec<u32>>,
    all_tokens: Vec<u32>,
    settings: Settings,
    eos_token: u32,
    logits_processor: LogitsProcessor,
    tos: &'session mut TokenOutputStream,
    kv_cache: &'session mut KvCache,
}

impl<'session, M: Model> Generation<'session, M> {
    pub(crate) fn new(
        model: &'session M,
        tokens_prefill: Vec<u32>,
        logits_processor: LogitsProcessor,
        settings: Settings,
        tos: &'session mut TokenOutputStream,
        kv_cache: &'session mut KvCache,
        eos_token: u32,
    ) -> Self {
        Self {
            model,
            index: 0,
            next_token: 0,
            all_tokens: Vec::new(),
            tokens_prefill: Some(tokens_prefill),
            logits_processor,
            settings,
            tos,
            kv_cache,
            eos_token,
        }
    }

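    /// Advances generation until the tokenizer can emit a printable chunk.
    /// Returns `Ok(None)` once the sampling budget is exhausted or the EOS
    /// token has been produced.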
    pub fn next_chunk(&mut self) -> Result<Option<String>, Error> {
        loop {
            // Stop once the sampling budget is spent or the previous
            // iteration produced the end-of-sequence token.
            if self.settings.sample_len <= self.index || self.next_token == self.eos_token {
                return Ok(None);
            }

            let current_pos = self.kv_cache.current_pos();

            // First forward pass: prefill the KV cache with the whole
            // prompt at once. Every later pass decodes a single token.
            let input = if self.index == 0 && self.tokens_prefill.is_some() {
                let tokens_prefill = self.tokens_prefill.take().expect("checked above");
                Tensor::new(tokens_prefill.as_slice(), self.model.device())?.unsqueeze(0)?
            } else {
                Tensor::new(&[self.next_token], self.model.device())?.unsqueeze(0)?
            };

            let logits = self.model.forward(&input, current_pos, self.kv_cache)?;
            let logits = logits.squeeze(0)?;

            // A repeat penalty of exactly 1.0 is the identity, so skip the
            // extra pass; otherwise penalize the last `repeat_last_n` tokens.
            let logits = if self.settings.repeat_penalty == 1. {
                logits
            } else {
                let start_at = self.all_tokens.len().saturating_sub(self.settings.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.settings.repeat_penalty,
                    &self.all_tokens[start_at..],
                )?
            };

            self.next_token = self.logits_processor.sample(&logits)?;
            self.all_tokens.push(self.next_token);

            self.index += 1;

            // The tokenizer may need several tokens before it can emit a
            // complete UTF-8 chunk; keep looping until one is ready.
            if let Some(chunk) = self.tos.next_token(self.next_token)? {
                return Ok(Some(chunk));
            }
        }
    }
}

impl<'session, M: Model> Drop for Generation<'session, M> {
    fn drop(&mut self) {
        // Reset the shared token stream so the next generation on this
        // session starts from a clean slate.
        self.tos.clear();
    }
}
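
// Usage sketch (illustrative, not part of this module): assumes a session
// type exposing a hypothetical `generate(prompt)` constructor that wraps
// `Generation::new`; only `next_chunk` above is real. Callers stream chunks
// until it returns `Ok(None)`:
//
//     let mut generation = session.generate("Why is the sky blue?")?;
//     while let Some(chunk) = generation.next_chunk()? {
//         print!("{chunk}");
//     }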