//! `hot_loop/backend/model_weights.rs` — model-weight traits for the hot_loop backend.
use candle_core::{Device, Result as CandleResult, Tensor};
use tokenizers::Tokenizer;

use crate::utils::kv_cache::ConcatKvCache;
use crate::{Error, session::Session};
5
// Crate-wide alias so the KV-cache implementation can be swapped in one place.
pub(crate) type KvCache = ConcatKvCache;
7
/// Speaker of a chat message; passed to [`ModelWeights::fmt_prompt`] to pick
/// the right chat-template wrapping for the text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// System-level instructions (presumably steering text — confirm against
    /// each model's chat template).
    System,
    /// A message authored by the end user.
    User,
    /// A message authored by the model.
    Assistant,
}
14
15pub trait ModelWeights {
16    fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut Vec<KvCache>) -> CandleResult<Tensor>;
17    
18    fn layers_len(&self) -> usize;
19    
20    fn create_kv_cache(&self) -> Vec<KvCache> {
21        let layers_len = self.layers_len();
22        
23        let mut kv_cache = Vec::with_capacity(layers_len);
24
25        for _ in 0..layers_len {
26            kv_cache.push(KvCache::new(2));
27        }
28
29        kv_cache
30    }
31
32    fn tokenizer(&self) -> &Tokenizer;
33
34    fn current_device(&self) -> &Device;
35
36    fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error>;
37    fn assistant_start_template(&self) -> Vec<u32>;
38    fn eos_token(&self) -> u32;
39}
40
// TODO: add an `extend_from_history` method to `ModelWeights`.
42
/// High-level entry point: anything implementing [`ModelWeights`] can open a
/// [`Session`] (a blanket impl below provides this trait for all of them).
pub trait Model: ModelWeights {
    /// Starts a fresh generation session that borrows these weights.
    // `Self: Sized` is required because `Session<'_, Self>` takes `Self` by value
    // as a type parameter, which an unsized trait object could not satisfy.
    fn new_session(&self) -> Session<'_, Self>
    where
        Self: Sized
    {
        Session::new(self)
    }
}
51
// Blanket impl: every `ModelWeights` implementor automatically gets `Model`.
impl<T: ModelWeights> Model for T {}