// hot_loop/backend/session/session.rs
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_core::Tensor;
use super::Generation;
use crate::settings::{Settings, Seed};
use crate::Error;
use crate::Model;
use crate::session::history::Role;
use crate::utils::kv_cache::KvCache;
use crate::utils::token_output_stream::TokenOutputStream;

11#[non_exhaustive]
12pub struct Session<M: Model> {
13 model: M, settings: Settings,
15 kv_cache: KvCache,
16 tos: TokenOutputStream,
17 system_prompt_pos: Option<usize>,
18}
19
20impl<M: Model> Session<M> {
21 pub(crate) fn new(model: M) -> Self {
22 let settings = Settings::default();
23
24 let layers_len = model.layers_len();
25 let kv_cache = KvCache::new(layers_len, 2);
26
27 let tos = TokenOutputStream::new(model.tokenizer());
28
29 Self {
30 model,
31 settings,
32 kv_cache,
33 tos,
34 system_prompt_pos: None,
35 }
36 }
37
38 pub fn generate(&mut self, prompt: &str) -> Result<Generation<'_, M>, Error> {
39 let user_tokens = self.model.fmt_prompt(prompt, Role::User)?;
40 let assistant_start_tokens = self.model.assistant_start_template();
41
42 let mut tokens = Vec::with_capacity(
43 user_tokens.len() + assistant_start_tokens.len()
44 );
45
46 tokens.extend_from_slice(&user_tokens);
47 tokens.extend_from_slice(&assistant_start_tokens);
48
49 let logits_processor = {
50 let temperature = self.settings.temperature;
51 let sampling = if temperature <= 0. {
52 Sampling::ArgMax
53 } else {
54 match (self.settings.top_k, self.settings.top_p) {
55 (None, None) => Sampling::All { temperature },
56 (Some(k), None) => Sampling::TopK { k, temperature },
57 (None, Some(p)) => Sampling::TopP { p, temperature },
58 (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
59 }
60 };
61
62 let seed = match self.settings.seed {
63 Seed::Custom(seed) => seed,
64 Seed::Default => 299792458 };
66
67 LogitsProcessor::from_sampling(seed, sampling)
68 };
69
70 Ok(Generation::new(
71 &self.model,
72 tokens,
73 logits_processor,
74 self.settings,
75 &mut self.tos,
76 &mut self.kv_cache,
77 self.model.eos_token(),
78 ))
79 }
80
81 pub fn with_settings(mut self, settings: Settings) -> Self {
82 self.settings = settings;
83 self
84 }
85
86 pub fn set_settings(&mut self, settings: Settings) {
87 self.settings = settings;
88 }
89
90 pub fn with_system_prompt(mut self, system_prompt: &str) -> Result<Self, Error> {
91 self.set_system_prompt_and_clear_history(system_prompt)?;
92 Ok(self)
93 }
94
95 pub fn set_system_prompt_and_clear_history(&mut self, system_prompt: &str) -> Result<(), Error> {
96 self.kv_cache.clear();
97
98 let sys_tokens = self.model.fmt_prompt(system_prompt, Role::System)?;
99 let input = Tensor::new(sys_tokens, self.model.device())?.unsqueeze(0)?;
100 let _ = self.model.forward(&input, 0, &mut self.kv_cache)?;
101
102 let current_pos = self.kv_cache.current_pos();
103
104 self.system_prompt_pos = Some(current_pos);
105
106 Ok(())
107 }
108
109 pub fn clear_history(&mut self) -> Result<(), Error> {
110 match self.system_prompt_pos {
111 Some(pos) => self.kv_cache.truncate(pos)?,
112 None => self.kv_cache.clear()
113 }
114 Ok(())
115 }
116
117 pub fn clear_system_prompt_and_history(&mut self) {
118 self.kv_cache.clear();
119 }
120}