hot_loop/core/session/
session.rs1use candle_transformers::generation::{LogitsProcessor, Sampling};
2use candle_core::Tensor;
3use super::Generation;
4use crate::{
5 Error, ModelWeights, KvCache,
6 settings::{Settings, Seed},
7 Role
8};
9use crate::utils::token_output_stream::TokenOutputStream;
10
11#[non_exhaustive]
12#[derive(Clone)]
13pub struct Session<'a, M: ModelWeights> {
14 model: &'a M, settings: Settings,
16 kv_cache: Vec<KvCache>,
17 tos: TokenOutputStream<'a>
18}
19
20impl<'a, M: ModelWeights> Session<'a, M> {
21 pub(crate) fn new(model: &'a M) -> Self {
22 let settings = Settings::default();
23 let kv_cache = model.create_kv_cache();
24 let tos = TokenOutputStream::new(model.tokenizer());
25
26 Self {
27 model,
28 settings,
29 kv_cache,
30 tos
31 }
32 }
33
34 pub fn generate(&mut self, prompt: &str) -> Result<Generation<'_, 'a, M>, Error> {
35 let user_tokens = self.model.fmt_prompt(prompt, Role::User)?;
36 let assistant_start_tokens = self.model.assistant_start_template();
37
38 let mut tokens = Vec::with_capacity(
39 user_tokens.len() + assistant_start_tokens.len()
40 );
41
42 tokens.extend_from_slice(&user_tokens);
43 tokens.extend_from_slice(&assistant_start_tokens);
44
45 let logits_processor = {
46 let temperature = self.settings.temperature;
47 let sampling = if temperature <= 0. {
48 Sampling::ArgMax
49 } else {
50 match (self.settings.top_k, self.settings.top_p) {
51 (None, None) => Sampling::All { temperature },
52 (Some(k), None) => Sampling::TopK { k, temperature },
53 (None, Some(p)) => Sampling::TopP { p, temperature },
54 (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
55 }
56 };
57
58 let seed = match self.settings.seed {
59 Seed::Custom(seed) => seed,
60 Seed::Default => 299792458 };
62
63 LogitsProcessor::from_sampling(seed, sampling)
64 };
65
66 Ok(Generation {
67 model: self.model,
68 index: 0,
69 next_token: 0,
70 tokens,
71 all_tokens: Vec::new(),
72 parameters: self.settings,
73 device: self.model.current_device(),
74 eos_token: self.model.eos_token(),
75 logits_processor,
76 tos: &mut self.tos,
77 kv_cache: &mut self.kv_cache
78 })
79 }
80
81 pub fn with_settings(mut self, settings: Settings) -> Self {
82 self.settings = settings;
83 self
84 }
85
86 pub fn set_settings(&mut self, settings: Settings) {
87 self.settings = settings;
88 }
89
90 pub fn with_system_prompt(mut self, system_prompt: &str) -> Result<Self, Error> {
91 self.set_system_prompt_and_clear_history(system_prompt)?;
92 Ok(self)
93 }
94
95 pub fn set_system_prompt_and_clear_history(&mut self, system_prompt: &str) -> Result<(), Error> {
96 self.reset_all_cache();
97
98 let sys_tokens = self.model.fmt_prompt(system_prompt, Role::System)?;
99
100 let input = Tensor::new(sys_tokens, self.model.current_device())?.unsqueeze(0)?;
101 let _ = self.model.forward(&input, 0, &mut self.kv_cache)?;
102
103 self.set_cache_rollback();
104
105 Ok(())
106 }
107
108 pub fn clear_history(&mut self) {
109 self.cache_rollback();
110 }
111
112 pub fn clear_system_prompt_and_history(&mut self) {
113 self.reset_all_cache();
114 }
115
116 fn set_cache_rollback(&mut self) {
117 for cache in &mut self.kv_cache {
118 cache.set_rollback();
119 }
120 }
121
122 fn cache_rollback(&mut self) {
123 for cache in &mut self.kv_cache {
124 cache.rollback();
125 }
126 }
127
128 fn reset_all_cache(&mut self) {
129 for cache in &mut self.kv_cache {
130 cache.reset_all();
131 }
132 }
133}