hot_loop/core/
model_weights.rs1use candle_core::{Device, Result as CandleResult, Tensor};
2use crate::utils::kv_cache::ConcatKvCache;
3use crate::Error;
4use tokenizers::Tokenizer;
5
6pub(crate) type KvCache = ConcatKvCache;
7
/// Author of a chat message, passed to [`ModelWeights::fmt_prompt`] so the model can
/// format the text for that speaker.
///
/// Derives the common value traits so roles can be compared, hashed (e.g. used as map
/// keys) and printed in logs and test assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    /// A system message.
    System,
    /// A message from the user.
    User,
    /// A message from the assistant (the model itself).
    Assistant,
}
14
15pub trait ModelWeights {
16 fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut Vec<KvCache>) -> CandleResult<Tensor>;
17
18 fn create_kv_cache(&self) -> Vec<KvCache>;
19
20 fn tokenizer(&self) -> &Tokenizer;
21
22 fn current_device(&self) -> &Device;
23
24 fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error>;
25 fn assistant_start_template(&self) -> Vec<u32>;
26 fn eos_token(&self) -> u32;
27}
28
29pub trait Model: ModelWeights {}
32
33impl<T: ModelWeights> Model for T {}