cake_core/cake/master.rs

use std::io::Write;

use crate::models::{chat::Message, ImageGenerator, TextGenerator};

use super::{api, Context};

use crate::{ImageGenerationArgs, ModelType};
use anyhow::Result;
use image::{ImageBuffer, Rgb};

/// A master connects to, communicates with, and orchestrates the workers.
pub struct Master<TG, IG> {
    pub ctx: Context,
    pub llm_model: Option<Box<TG>>,
    pub sd_model: Option<Box<IG>>,
}

impl<TG: TextGenerator + Send + Sync + 'static, IG: ImageGenerator + Send + Sync + 'static>
    Master<TG, IG>
{
    /// Create a new instance.
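    ///
    /// # Example
    ///
    /// A minimal usage sketch (not part of the original file), assuming
    /// `TextModel` and `ImageModel` are hypothetical concrete `TextGenerator`
    /// and `ImageGenerator` implementations and `ctx` was already built from
    /// the parsed CLI arguments:
    ///
    /// ```ignore
    /// let master: Master<TextModel, ImageModel> = Master::new(ctx).await?;
    /// master.run().await?;
    /// ```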
    pub async fn new(mut ctx: Context) -> Result<Self> {
        match ctx.args.model_type {
            ModelType::ImageModel => {
                let sd_model = IG::load(&mut ctx).await?;
                Ok(Self {
                    ctx,
                    sd_model,
                    llm_model: None,
                })
            }
            ModelType::TextModel => {
                let llm_model = TG::load(&mut ctx).await?;
                Ok(Self {
                    ctx,
                    llm_model,
                    sd_model: None,
                })
            }
        }
    }

    /// Run the master: serve the REST API if one is configured, otherwise
    /// perform a single generation in CLI mode.
    pub async fn run(mut self) -> Result<()> {
        if self.ctx.args.api.is_some() {
            // run as a REST API
            api::start(self).await?;
        } else {
            // in CLI mode, pre-add the system and user prompts
            if self.ctx.args.model_type == ModelType::TextModel {
                let llm_model = self.llm_model.as_mut().expect("LLM model not found");
                llm_model.add_message(Message::system(self.ctx.args.system_prompt.clone()))?;
                llm_model.add_message(Message::user(self.ctx.args.prompt.clone()))?;

                // run a single generation, streaming tokens to stdout
                self.generate_text(None, |data| {
                    if data.is_empty() {
                        println!();
                    } else {
                        print!("{data}");
                    }
                    std::io::stdout().flush().unwrap();
                })
                .await?;
            } else {
                let mut step_num = 0;

                self.generate_image(self.ctx.args.sd_img_gen_args.clone(), move |images| {
                    // save every image in this step's batch to disk
                    for (batch_num, image) in images.iter().enumerate() {
                        image
                            .save(format!("images/image_{}_{}.png", batch_num, step_num))
                            .expect("Error saving image to disk");
                    }
                    step_num += 1;
                })
                .await?;
            }
        }

        Ok(())
    }

    /// Reset the master state for a new inference.
    pub fn reset(&mut self) -> Result<()> {
        self.llm_model
            .as_mut()
            .expect("LLM model not found")
            .reset()
    }

    /// Clear the workers' KV caches.
    pub async fn goodbye(&mut self) -> Result<()> {
        self.llm_model
            .as_mut()
            .expect("LLM model not found")
            .goodbye()
            .await
    }

    /// Start the generation loop and call the stream function for every token.
    /// `max_tokens` overrides the default sample length if provided.
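    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `master` already holds a loaded text model;
    /// the closure receives each decoded token and finally an empty string
    /// that signals the end of the stream:
    ///
    /// ```ignore
    /// master
    ///     .generate_text(Some(64), |chunk| {
    ///         if !chunk.is_empty() {
    ///             print!("{chunk}");
    ///         }
    ///     })
    ///     .await?;
    /// ```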
    pub async fn generate_text<S>(&mut self, max_tokens: Option<usize>, mut stream: S) -> Result<()>
    where
        S: FnMut(&str),
    {
        log::info!(
            "starting the inference loop (mem={})\n\n",
            human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64)
        );

        let sample_len = max_tokens.unwrap_or(self.ctx.args.sample_len);
        log::debug!("  sample_len = {}", sample_len);

        // stream(&self.ctx.args.prompt);

        let mut start_gen = std::time::Instant::now();
        let llm_model = self.llm_model.as_mut().expect("LLM model not found");

        for index in 0..sample_len {
            if index == 1 {
                // record the start time again since the first token is the warmup
                start_gen = std::time::Instant::now();
            }

            let token_start = std::time::Instant::now();
            let token = llm_model.next_token(index).await?;
            let token_elapsed = token_start.elapsed();

            log::debug!(
                "token {} generated in {:.1}ms ({:.1} tok/s)",
                index,
                token_elapsed.as_secs_f64() * 1000.0,
                1.0 / token_elapsed.as_secs_f64(),
            );

            if token.is_end_of_stream {
                break;
            } else {
                stream(&token.to_string());
            }
        }

        // signal the end of the stream with an empty chunk
        stream("");

        let dt = start_gen.elapsed();
        let generated = llm_model.generated_tokens();

        log::info!(
            "{} tokens generated ({:.2} token/s) - mem={}",
            generated,
            (generated - 1) as f64 / dt.as_secs_f64(),
            human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64)
        );

        Ok(())
    }

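    /// Run image generation with `args`, passing batches of generated images
    /// to `callback`.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `master` holds a loaded image model and
    /// `args` is an `ImageGenerationArgs` built elsewhere; the callback is
    /// invoked with the images produced at each reported step:
    ///
    /// ```ignore
    /// master
    ///     .generate_image(args, |images| {
    ///         println!("received {} image(s)", images.len());
    ///     })
    ///     .await?;
    /// ```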
    pub async fn generate_image<F>(&mut self, args: ImageGenerationArgs, callback: F) -> Result<()>
    where
        F: FnMut(Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>) + Send + 'static,
    {
        let sd_model = self.sd_model.as_mut().expect("SD model not found");
        sd_model.generate_image(&args, callback).await
    }
}