use std::io::Write;

use crate::models::{chat::Message, ImageGenerator, TextGenerator};

use super::{api, Context};

use crate::{ImageGenerationArgs, ModelType};
use anyhow::Result;
use image::{ImageBuffer, Rgb};

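/// Drives a single loaded model: either a text generator (LLM) or an image
/// (SD) generator, selected by the configured `ModelType`.
///
/// A minimal usage sketch, assuming concrete `TextGenerator` /
/// `ImageGenerator` implementations (`MyLlm` and `MySd` are placeholder
/// names) and a populated `Context`:
///
/// ```ignore
/// let master = Master::<MyLlm, MySd>::new(ctx).await?;
/// master.run().await?;
/// ```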
pub struct Master<TG, IG> {
    pub ctx: Context,
    pub llm_model: Option<Box<TG>>,
    pub sd_model: Option<Box<IG>>,
}

impl<TG: TextGenerator + Send + Sync + 'static, IG: ImageGenerator + Send + Sync + 'static>
    Master<TG, IG>
{
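    /// Loads the model selected by `ctx.args.model_type`; only the matching
    /// slot (`llm_model` or `sd_model`) is populated, the other stays `None`.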
    pub async fn new(mut ctx: Context) -> Result<Self> {
        match ctx.args.model_type {
            ModelType::ImageModel => {
                let sd_model = IG::load(&mut ctx).await?;
                Ok(Self {
                    ctx,
                    sd_model,
                    llm_model: None,
                })
            }
            ModelType::TextModel => {
                let llm_model = TG::load(&mut ctx).await?;
                Ok(Self {
                    ctx,
                    llm_model,
                    sd_model: None,
                })
            }
        }
    }

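    /// Entry point after loading: serves the API when one is configured,
    /// otherwise runs a one-shot CLI generation (text streamed to stdout, or
    /// images written to the `images/` directory).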
    pub async fn run(mut self) -> Result<()> {
        if self.ctx.args.api.is_some() {
            api::start(self).await?;
        } else if self.ctx.args.model_type == ModelType::TextModel {
            let llm_model = self.llm_model.as_mut().expect("LLM model not found");
            llm_model.add_message(Message::system(self.ctx.args.system_prompt.clone()))?;
            llm_model.add_message(Message::user(self.ctx.args.prompt.clone()))?;

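            // Stream tokens to stdout as they arrive; the final empty chunk
            // ends the line before flushing.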
            self.generate_text(None, |data| {
                if data.is_empty() {
                    println!();
                } else {
                    print!("{data}");
                }
                std::io::stdout().flush().unwrap();
            })
            .await?;
        } else {
            let mut step_num = 0;

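            // Write each decoded batch to disk as
            // images/image_<batch_index>_<step>.png.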
            self.generate_image(self.ctx.args.sd_img_gen_args.clone(), move |images| {
                for (batched_num, image) in images.into_iter().enumerate() {
                    image
                        .save(format!("images/image_{}_{}.png", batched_num, step_num))
                        .expect("Error saving image to disk");
                }
                step_num += 1;
            })
            .await?;
        }

        Ok(())
    }

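    /// Resets the LLM's conversation state.
    ///
    /// Panics if no text model is loaded.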
    pub fn reset(&mut self) -> Result<()> {
        self.llm_model
            .as_mut()
            .expect("LLM model not found")
            .reset()
    }

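    /// Forwards `goodbye` to the underlying LLM so it can run any
    /// end-of-session logic.
    ///
    /// Panics if no text model is loaded.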
    pub async fn goodbye(&mut self) -> Result<()> {
        self.llm_model
            .as_mut()
            .expect("LLM model not found")
            .goodbye()
            .await
    }

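    /// Generates up to `max_tokens` tokens (falling back to the configured
    /// `sample_len`), calling `stream` with each decoded chunk and once more
    /// with an empty string when generation ends.
    ///
    /// A sketch of a non-printing consumer that collects the output instead:
    ///
    /// ```ignore
    /// let mut out = String::new();
    /// master.generate_text(Some(128), |chunk| out.push_str(chunk)).await?;
    /// ```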
    pub async fn generate_text<S>(&mut self, max_tokens: Option<usize>, mut stream: S) -> Result<()>
    where
        S: FnMut(&str),
    {
        log::info!(
            "starting the inference loop (mem={})\n\n",
            human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64)
        );

        let sample_len = max_tokens.unwrap_or(self.ctx.args.sample_len);
        log::debug!("sample_len = {}", sample_len);

        let mut start_gen = std::time::Instant::now();
        let llm_model = self.llm_model.as_mut().expect("LLM model not found");

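        // Restart the timer after the first token so warm-up (including
        // prompt processing) does not skew the tok/s figure.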
        for index in 0..sample_len {
            if index == 1 {
                start_gen = std::time::Instant::now();
            }

            let token_start = std::time::Instant::now();
            let token = llm_model.next_token(index).await?;
            let token_elapsed = token_start.elapsed();

            log::debug!(
                "token {} generated in {:.1}ms ({:.1} tok/s)",
                index,
                token_elapsed.as_secs_f64() * 1000.0,
                1.0 / token_elapsed.as_secs_f64(),
            );

            if token.is_end_of_stream {
                break;
            } else {
                stream(&token.to_string());
            }
        }

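        // Emit a final empty chunk so the consumer knows the stream is done.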
        stream("");

        let dt = start_gen.elapsed();
        let generated = llm_model.generated_tokens();

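        // start_gen was restarted after the first token, so `dt` spans one
        // token fewer than `generated`.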
        log::info!(
            "{} tokens generated ({:.2} token/s) - mem={}",
            generated,
            generated.saturating_sub(1) as f64 / dt.as_secs_f64(),
            human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64)
        );

        Ok(())
    }

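    /// Generates images with the loaded SD model, invoking `callback` with
    /// each batch of decoded images. Sketch (`args` is an
    /// `ImageGenerationArgs` built elsewhere):
    ///
    /// ```ignore
    /// master
    ///     .generate_image(args, |batch| println!("{} image(s)", batch.len()))
    ///     .await?;
    /// ```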
    pub async fn generate_image<F>(&mut self, args: ImageGenerationArgs, callback: F) -> Result<()>
    where
        F: FnMut(Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>) + Send + 'static,
    {
        let sd_model = self.sd_model.as_mut().expect("SD model not found");
        sd_model.generate_image(&args, callback).await
    }
}
167}