// hot_loop/backend/session/generation.rs
use candle_core::Tensor;
2use candle_transformers::generation::LogitsProcessor;
3use crate::settings::Settings;
4use crate::Error;
5use crate::Model;
6use crate::utils::token_output_stream::TokenOutputStream;
7use crate::utils::kv_cache::KvCache;
8
/// State for one in-flight text generation, borrowing the model and the
/// session's mutable decoding state for the `'session` lifetime.
#[non_exhaustive]
pub struct Generation<'session, M: Model> {
    /// Model used for every forward pass.
    model: &'session M,
    /// Number of sampling steps performed so far (0 before the prefill step).
    index: usize,
    /// Most recently sampled token id; 0 until the first sample.
    next_token: u32,
    /// Tokenized prompt, consumed (set to `None`) by the first forward pass.
    tokens_prefill: Option<Vec<u32>>,
    /// Every token sampled so far; window source for the repeat penalty.
    all_tokens: Vec<u32>,
    /// Sampling configuration (`sample_len`, `repeat_penalty`, `repeat_last_n` are read here).
    settings: Settings,
    /// Token id that terminates generation.
    eos_token: u32,
    /// Samples the next token id from the model's logits.
    logits_processor: LogitsProcessor,
    /// Token-to-text stream; cleared when this generation is dropped.
    tos: &'session mut TokenOutputStream,
    /// Key/value attention cache borrowed from the session.
    kv_cache: &'session mut KvCache
}
22
23impl<'session, M: Model> Generation<'session, M> {
24 pub(crate) fn new(
25 model: &'session M,
26 tokens_prefill: Vec<u32>,
27 logits_processor: LogitsProcessor,
28 settings: Settings,
29 tos: &'session mut TokenOutputStream,
30 kv_cache: &'session mut KvCache,
31 eos_token: u32,
32 ) -> Self {
33 Self {
34 model,
35 index: 0,
36 next_token: 0,
37 all_tokens: Vec::new(),
38 tokens_prefill: Some(tokens_prefill),
39 logits_processor,
40 settings,
41 tos,
42 kv_cache,
43 eos_token,
44 }
45 }
46
47 pub fn next_chunk(&mut self) -> Result<Option<String>, Error> {
48 loop {
49 if self.settings.sample_len <= self.index || self.next_token == self.eos_token {
50 return Ok(None);
51 }
52
53 let current_pos = self.kv_cache.current_pos();
54
55 let input = if self.index == 0 &&
56 let Some(tokens_prefill) = self.tokens_prefill.as_ref() {
57 let input = Tensor::new(tokens_prefill.as_slice(), self.model.device())?.unsqueeze(0)?;
58 self.tokens_prefill = None;
59 input
60
61 } else {
62 Tensor::new(&[self.next_token], self.model.device())?.unsqueeze(0)?
63 };
64
65 let logits = self.model.forward(&input, current_pos, &mut self.kv_cache)?;
66 let logits = logits.squeeze(0)?;
67
68 let logits = if self.settings.repeat_penalty == 1. {
69 logits
70 } else {
71 let start_at = self.all_tokens.len().saturating_sub(self.settings.repeat_last_n);
72 candle_transformers::utils::apply_repeat_penalty(
73 &logits,
74 self.settings.repeat_penalty,
75 &self.all_tokens[start_at..],
76 )?
77 };
78
79 self.next_token = self.logits_processor.sample(&logits)?;
80 self.all_tokens.push(self.next_token);
81
82 self.index += 1;
83
84 if let Some(chunk) = self.tos.next_token(self.next_token)? {
85 return Ok(Some(chunk))
86 }
87 }
88 }
89}
90
impl<'session, M: Model> Drop for Generation<'session, M> {
    /// Resets the borrowed token output stream when this generation ends,
    /// whether it finished normally or was abandoned mid-stream.
    // NOTE(review): `clear()` is defined in `TokenOutputStream`; presumably it
    // drops buffered tokens / partial decode state so the session can start a
    // fresh generation — confirm against that implementation.
    fn drop(&mut self) {
        self.tos.clear();
    }
}