//! `hot_loop/backend/session/session.rs`
//!
//! Interactive generation session: wraps a [`Model`] together with sampling
//! [`Settings`], a [`KvCache`], and a streaming token decoder.

1use candle_transformers::generation::{LogitsProcessor, Sampling};
2use candle_core::Tensor;
3use super::Generation;
4use crate::settings::{Settings, Seed};
5use crate::Error;
6use crate::Model;
7use crate::session::history::Role;
8use crate::utils::kv_cache::KvCache;
9use crate::utils::token_output_stream::TokenOutputStream;
10
/// A stateful chat session over a single model instance.
///
/// Owns the model, the active sampling settings, the KV cache that carries
/// conversation state between calls, and the incremental token decoder.
/// Construct via the crate-internal `Session::new` and configure with the
/// `with_*` builder methods.
#[non_exhaustive]
pub struct Session<M: Model> {
    model: M, // read only
    settings: Settings,
    // Cached keys/values for all processed tokens; cleared or truncated to
    // manage conversation history.
    kv_cache: KvCache,
    // Streams decoded text out of incrementally generated token ids.
    tos: TokenOutputStream,
    // KV-cache position right after the system prompt, if one was set;
    // `clear_history` truncates back to this point.
    system_prompt_pos: Option<usize>,
}
19
20impl<M: Model> Session<M> {
21    pub(crate) fn new(model: M) -> Self {
22        let settings = Settings::default();
23
24        let layers_len = model.layers_len();
25        let kv_cache = KvCache::new(layers_len, 2);
26
27        let tos = TokenOutputStream::new(model.tokenizer());
28        
29        Self {
30            model,
31            settings,
32            kv_cache,
33            tos,
34            system_prompt_pos: None,
35        }
36    }
37
38    pub fn generate(&mut self, prompt: &str) -> Result<Generation<'_, M>, Error> {
39        let user_tokens = self.model.fmt_prompt(prompt, Role::User)?;
40        let assistant_start_tokens = self.model.assistant_start_template();
41
42        let mut tokens = Vec::with_capacity(
43            user_tokens.len() + assistant_start_tokens.len()
44        );
45
46        tokens.extend_from_slice(&user_tokens);
47        tokens.extend_from_slice(&assistant_start_tokens);
48
49        let logits_processor = {
50            let temperature = self.settings.temperature;
51            let sampling = if temperature <= 0. {
52                Sampling::ArgMax
53            } else {
54                match (self.settings.top_k, self.settings.top_p) {
55                    (None, None) => Sampling::All { temperature },
56                    (Some(k), None) => Sampling::TopK { k, temperature },
57                    (None, Some(p)) => Sampling::TopP { p, temperature },
58                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
59                }
60            };
61
62            let seed = match self.settings.seed {
63                Seed::Custom(seed) => seed,
64                Seed::Default => 299792458 // temporary
65            };
66
67            LogitsProcessor::from_sampling(seed, sampling)
68        };
69
70        Ok(Generation::new(
71            &self.model,
72            tokens,
73            logits_processor,
74            self.settings,
75            &mut self.tos,
76            &mut self.kv_cache,
77            self.model.eos_token(),
78        ))
79    }
80
81    pub fn with_settings(mut self, settings: Settings) -> Self {
82        self.settings = settings;
83        self
84    }
85
86    pub fn set_settings(&mut self, settings: Settings) {
87        self.settings = settings;
88    }
89
90    pub fn with_system_prompt(mut self, system_prompt: &str) -> Result<Self, Error> {
91        self.set_system_prompt_and_clear_history(system_prompt)?;
92        Ok(self)
93    }
94
95    pub fn set_system_prompt_and_clear_history(&mut self, system_prompt: &str) -> Result<(), Error> {
96        self.kv_cache.clear();
97
98        let sys_tokens = self.model.fmt_prompt(system_prompt, Role::System)?;
99        let input = Tensor::new(sys_tokens, self.model.device())?.unsqueeze(0)?;
100        let _ = self.model.forward(&input, 0, &mut self.kv_cache)?;
101
102        let current_pos = self.kv_cache.current_pos();
103
104        self.system_prompt_pos = Some(current_pos);
105
106        Ok(())
107    }
108
109    pub fn clear_history(&mut self) -> Result<(), Error> {
110        match self.system_prompt_pos {
111            Some(pos) => self.kv_cache.truncate(pos)?,
112            None => self.kv_cache.clear()
113        }
114        Ok(())
115    }
116
117    pub fn clear_system_prompt_and_history(&mut self) {
118        self.kv_cache.clear();
119    }
120}