Skip to main content

hot_loop/models/models_core/
model.rs

use std::sync::Arc;

use candle_core::{Device, Result as CandleResult, Tensor};
use tokenizers::Tokenizer;

use crate::session::history::Role;
use crate::utils::kv_cache::KvCache;
use crate::{session::Session, Error};

8#[doc(hidden)]
9pub trait ModelWeights {
10    fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut KvCache) -> CandleResult<Tensor>;
11
12    fn layers_len(&self) -> usize;
13
14    fn tokenizer(&self) -> Arc<Tokenizer>;
15
16    fn device(&self) -> &Device;
17
18    fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error>;
19    fn assistant_start_template(&self) -> Vec<u32>;
20    fn eos_token(&self) -> u32;
21}
22
23pub trait Model: ModelWeights {
24    fn new_session(&self) -> Session<Self> where Self: Sized;
25}
26
27impl<M: ModelWeights + Clone> Model for M {
28    fn new_session(&self) -> Session<M> {
29        Session::new(self.clone())
30    }
31}