Skip to main content

cake_core/models/
mod.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use image::{ImageBuffer, Rgb};
4
5use chat::Message;
6
7use crate::cake::{Context, Forwarder};
8use crate::ImageGenerationArgs;
9
10pub mod chat;
11pub mod common;
12#[cfg(feature = "llama")]
13pub mod llama3;
14#[cfg(feature = "qwen2")]
15pub mod qwen2;
16#[cfg(feature = "qwen3_5")]
17pub mod qwen3_5;
18pub mod sd;
19
/// A single token produced by a text generator.
#[derive(Debug, Clone)]
pub struct Token {
    /// Numerical identifier.
    pub id: u32,
    /// Resolved text token or `None` if not present in the tokenizer.
    pub text: Option<String>,
    /// Set to `true` if the stream of tokens is over.
    pub is_end_of_stream: bool,
}

impl std::fmt::Display for Token {
    /// Writes the resolved text, or `<token ID>` when no text is available for the id.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write straight into the formatter instead of cloning the text (or allocating a
        // temporary `String` with `format!`) just to hand it to `write!(f, "{}", ..)`.
        match &self.text {
            Some(text) => f.write_str(text),
            None => write!(f, "<token {}>", self.id),
        }
    }
}
43
44/// A model must implement this trait in order to be usable by the Cake framework.
45#[async_trait]
46pub trait Generator {
47    /// This associated type determines which part of the model can be sharded.
48    type Shardable: Forwarder;
49
50    /// The model name.
51    const MODEL_NAME: &'static str;
52
53    /// Load the model from the context.
54    async fn load(context: &mut Context) -> Result<Option<Box<Self>>>;
55}
56
57#[async_trait]
58pub trait TextGenerator: Generator {
59    /// Add a message to the chat.
60    fn add_message(&mut self, message: Message) -> Result<()>;
61    /// Clear chat history.
62    fn reset(&mut self) -> Result<()>;
63    /// clear worker kv cache
64    async fn goodbye(&mut self) -> Result<()>;
65
66    /// Return the next token.
67    async fn next_token(&mut self, index: usize) -> Result<Token>;
68    /// Return the number of generated tokens so far.
69    fn generated_tokens(&self) -> usize;
70}
71
72#[async_trait]
73pub trait ImageGenerator: Generator {
74    async fn generate_image<F>(
75        &mut self,
76        args: &ImageGenerationArgs,
77        mut callback: F,
78    ) -> Result<(), anyhow::Error>
79    where
80        F: FnMut(Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>) + Send + 'static;
81}