// hot_loop/core/session/session.rs

1use candle_transformers::generation::{LogitsProcessor, Sampling};
2use candle_core::Tensor;
3use super::Generation;
4use crate::{
5    Error, ModelWeights, KvCache,
6    settings::{Settings, Seed},
7    Role
8};
9use crate::utils::token_output_stream::TokenOutputStream;
10
11#[non_exhaustive]
12#[derive(Clone)]
13pub struct Session<'a, M: ModelWeights> {
14    model: &'a M, // read only
15    settings: Settings,
16    kv_cache: Vec<KvCache>,
17    tos: TokenOutputStream<'a>
18}
19
20impl<'a, M: ModelWeights> Session<'a, M> {
21    pub(crate) fn new(model: &'a M) -> Self {
22        let settings = Settings::default();
23        let kv_cache = model.create_kv_cache();
24        let tos = TokenOutputStream::new(model.tokenizer());
25        
26        Self {
27            model,
28            settings,
29            kv_cache,
30            tos
31        }
32    }
33
34    pub fn generate(&mut self, prompt: &str) -> Result<Generation<'_, 'a, M>, Error> {
35        let user_tokens = self.model.fmt_prompt(prompt, Role::User)?;
36        let assistant_start_tokens = self.model.assistant_start_template();
37
38        let mut tokens = Vec::with_capacity(
39            user_tokens.len() + assistant_start_tokens.len()
40        );
41
42        tokens.extend_from_slice(&user_tokens);
43        tokens.extend_from_slice(&assistant_start_tokens);
44
45        let logits_processor = {
46            let temperature = self.settings.temperature;
47            let sampling = if temperature <= 0. {
48                Sampling::ArgMax
49            } else {
50                match (self.settings.top_k, self.settings.top_p) {
51                    (None, None) => Sampling::All { temperature },
52                    (Some(k), None) => Sampling::TopK { k, temperature },
53                    (None, Some(p)) => Sampling::TopP { p, temperature },
54                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
55                }
56            };
57
58            let seed = match self.settings.seed {
59                Seed::Custom(seed) => seed,
60                Seed::Default => 299792458 // temporary
61            };
62
63            LogitsProcessor::from_sampling(seed, sampling)
64        };
65
66        Ok(Generation {
67            model: self.model,
68            index: 0,
69            next_token: 0,
70            tokens,
71            all_tokens: Vec::new(),
72            parameters: self.settings,
73            device: self.model.current_device(),
74            eos_token: self.model.eos_token(),
75            logits_processor,
76            tos: &mut self.tos,
77            kv_cache: &mut self.kv_cache
78        })
79    }
80
81    pub fn with_settings(mut self, settings: Settings) -> Self {
82        self.settings = settings;
83        self
84    }
85
86    pub fn set_settings(&mut self, settings: Settings) {
87        self.settings = settings;
88    }
89
90    pub fn with_system_prompt(mut self, system_prompt: &str) -> Result<Self, Error> {
91        self.set_system_prompt_and_clear_history(system_prompt)?;
92        Ok(self)
93    }
94
95    pub fn set_system_prompt_and_clear_history(&mut self, system_prompt: &str) -> Result<(), Error> {
96        self.reset_all_cache();
97
98        let sys_tokens = self.model.fmt_prompt(system_prompt, Role::System)?;
99
100        let input = Tensor::new(sys_tokens, self.model.current_device())?.unsqueeze(0)?;
101        let _ = self.model.forward(&input, 0, &mut self.kv_cache)?;
102
103        self.set_cache_rollback();
104
105        Ok(())
106    }
107
108    pub fn clear_history(&mut self) {
109        self.cache_rollback();
110    }
111
112    pub fn clear_system_prompt_and_history(&mut self) {
113        self.reset_all_cache();
114    }
115
116    fn set_cache_rollback(&mut self) {
117        for cache in &mut self.kv_cache {
118            cache.set_rollback();
119        }
120    }
121
122    fn cache_rollback(&mut self) {
123        for cache in &mut self.kv_cache {
124            cache.rollback();
125        }
126    }
127
128    fn reset_all_cache(&mut self) {
129        for cache in &mut self.kv_cache {
130            cache.reset_all();
131        }
132    }
133}