// hot_loop/backend/model_weights.rs
use candle_core::{Device, Result as CandleResult, Tensor};
use crate::utils::kv_cache::ConcatKvCache;
use crate::{Error, session::Session};
use tokenizers::Tokenizer;

/// Crate-wide alias for the per-layer key/value cache.
/// Backed by [`ConcatKvCache`] from `crate::utils::kv_cache` — presumably a
/// concatenation-based cache (grows along one tensor axis); confirm there.
pub(crate) type KvCache = ConcatKvCache;
7
/// Chat-message role used when formatting a prompt for the model.
///
/// Derives `Debug` (public types should be debuggable) and `PartialEq`/`Eq`
/// so callers can compare roles directly; all additions are
/// backward-compatible with the previous `Clone, Copy` derives.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// System-level instructions that steer the assistant.
    System,
    /// A message authored by the end user.
    User,
    /// A message produced by the model itself.
    Assistant,
}
14
15pub trait ModelWeights {
16 fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut Vec<KvCache>) -> CandleResult<Tensor>;
17
18 fn layers_len(&self) -> usize;
19
20 fn create_kv_cache(&self) -> Vec<KvCache> {
21 let layers_len = self.layers_len();
22
23 let mut kv_cache = Vec::with_capacity(layers_len);
24
25 for _ in 0..layers_len {
26 kv_cache.push(KvCache::new(2));
27 }
28
29 kv_cache
30 }
31
32 fn tokenizer(&self) -> &Tokenizer;
33
34 fn current_device(&self) -> &Device;
35
36 fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error>;
37 fn assistant_start_template(&self) -> Vec<u32>;
38 fn eos_token(&self) -> u32;
39}
40
41pub trait Model: ModelWeights {
44 fn new_session(&self) -> Session<'_, Self>
45 where
46 Self: Sized
47 {
48 Session::new(self)
49 }
50}
51
52impl<T: ModelWeights> Model for T {}