1use anyhow::Result;
2use async_trait::async_trait;
3use image::{ImageBuffer, Rgb};
4
5use chat::Message;
6
7use crate::cake::{Context, Forwarder};
8use crate::ImageGenerationArgs;
9
10pub mod chat;
11pub mod common;
12#[cfg(feature = "llama")]
13pub mod llama3;
14#[cfg(feature = "qwen2")]
15pub mod qwen2;
16#[cfg(feature = "qwen3_5")]
17pub mod qwen3_5;
18pub mod sd;
19
/// A single token produced by a text generator.
///
/// Derives the standard traits so tokens can be logged (`Debug`), buffered
/// (`Clone`) and compared (`PartialEq`/`Eq`); every field supports them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
    /// Numeric token identifier.
    pub id: u32,
    /// Decoded text for this token, or `None` when no text is available for `id`
    /// (rendered as `<token ID>` by the `Display` impl).
    pub text: Option<String>,
    /// Set when this token marks the end of the generated stream.
    pub is_end_of_stream: bool,
}
29
30impl std::fmt::Display for Token {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 write!(
33 f,
34 "{}",
35 if let Some(text) = &self.text {
36 text.clone()
37 } else {
38 format!("<token {}>", self.id)
39 }
40 )
41 }
42}
43
44#[async_trait]
46pub trait Generator {
47 type Shardable: Forwarder;
49
50 const MODEL_NAME: &'static str;
52
53 async fn load(context: &mut Context) -> Result<Option<Box<Self>>>;
55}
56
57#[async_trait]
58pub trait TextGenerator: Generator {
59 fn add_message(&mut self, message: Message) -> Result<()>;
61 fn reset(&mut self) -> Result<()>;
63 async fn goodbye(&mut self) -> Result<()>;
65
66 async fn next_token(&mut self, index: usize) -> Result<Token>;
68 fn generated_tokens(&self) -> usize;
70}
71
72#[async_trait]
73pub trait ImageGenerator: Generator {
74 async fn generate_image<F>(
75 &mut self,
76 args: &ImageGenerationArgs,
77 mut callback: F,
78 ) -> Result<(), anyhow::Error>
79 where
80 F: FnMut(Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>) + Send + 'static;
81}