use crate::cake::Master;
use crate::models::chat::Message;
use crate::models::{ImageGenerator, TextGenerator};
use actix_web::{web, HttpRequest, HttpResponse, Responder};
use serde::{Deserialize, Serialize};
use std::io::Write;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;

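/// OpenAI-compatible chat completion request body.
///
/// `model` and `temperature` are accepted so standard clients can send them,
/// but this handler currently ignores both; only `messages`, `stream` and
/// `max_tokens` affect generation.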
#[derive(Deserialize)]
pub struct ChatRequest {
    pub messages: Vec<Message>,
    #[serde(default)]
    pub model: Option<String>,
    #[serde(default)]
    pub stream: Option<bool>,
    #[serde(default)]
    pub max_tokens: Option<usize>,
    #[serde(default)]
    pub temperature: Option<f64>,
}

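/// Token accounting section of the response, mirroring OpenAI's `usage` object.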
#[derive(Serialize)]
struct Usage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}

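/// A single completion choice in a non-streaming response.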
#[derive(Serialize)]
struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}

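/// Full (non-streaming) chat completion response, shaped like OpenAI's
/// `chat.completion` object.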
#[derive(Serialize)]
struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

impl ChatResponse {
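    /// Builds a response with a fresh `chatcmpl-<uuid>` id, the current UNIX
    /// timestamp and a single assistant choice wrapping `message`.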
    pub fn new(
        model: String,
        message: String,
        prompt_tokens: usize,
        completion_tokens: usize,
        finish_reason: String,
    ) -> Self {
        let id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
        let object = String::from("chat.completion");
        let created = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let choices = vec![Choice {
            index: 0,
            message: Message::assistant(message),
            finish_reason,
        }];

        Self {
            id,
            object,
            created,
            model,
            choices,
            usage: Usage {
                prompt_tokens,
                completion_tokens,
                total_tokens: prompt_tokens + completion_tokens,
            },
        }
    }
}

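/// A single choice inside a streaming chunk.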
#[derive(Serialize)]
struct StreamChoice {
    pub index: usize,
    pub delta: StreamDelta,
    pub finish_reason: Option<String>,
}

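/// Incremental payload of a streaming chunk; `role` is only set on the first
/// chunk, and `content` only on chunks that actually carry text.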
#[derive(Serialize)]
struct StreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

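/// One server-sent chunk, shaped like OpenAI's `chat.completion.chunk` object.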
#[derive(Serialize)]
struct StreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

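/// Entry point for the chat completion route: dispatches to the streaming or
/// blocking implementation depending on the request's `stream` flag.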
pub async fn generate_text<TG, IG>(
    state: web::Data<Arc<RwLock<Master<TG, IG>>>>,
    req: HttpRequest,
    body: web::Json<ChatRequest>,
) -> impl Responder
where
    TG: TextGenerator + Send + Sync + 'static,
    IG: ImageGenerator + Send + Sync + 'static,
{
    let client = req
        .peer_addr()
        .map(|a| a.to_string())
        .unwrap_or_else(|| "unknown".to_string());
    let stream = body.0.stream.unwrap_or(false);

    log::info!("starting chat for {} (stream={}) ...", &client, stream);

    if stream {
        generate_text_stream(state, body.0).await
    } else {
        generate_text_blocking(state, body.0).await
    }
}

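/// Runs a full generation pass and returns the whole completion in a single
/// JSON response.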
async fn generate_text_blocking<TG, IG>(
    state: web::Data<Arc<RwLock<Master<TG, IG>>>>,
    request: ChatRequest,
) -> HttpResponse
where
    TG: TextGenerator + Send + Sync + 'static,
    IG: ImageGenerator + Send + Sync + 'static,
{
    let mut master = state.write().await;

    if let Err(e) = master.reset() {
        return HttpResponse::InternalServerError()
            .json(serde_json::json!({"error": format!("{e}")}));
    }

    let num_messages = request.messages.len();
    let llm_model = master.llm_model.as_mut().expect("LLM model not found");
    for message in request.messages {
        if let Err(e) = llm_model.add_message(message) {
            return HttpResponse::InternalServerError()
                .json(serde_json::json!({"error": format!("{e}")}));
        }
    }

    let mut resp = String::new();
    let mut finish_reason = "length".to_string();

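    // The generator invokes this callback once per decoded chunk; an empty
    // chunk is the end-of-generation marker, so reaching it means the model
    // stopped on its own ("stop") rather than at the token limit ("length").
    // Chunks are also echoed to the server's stdout for visibility.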
    let gen_result = master
        .generate_text(request.max_tokens, |data| {
            if data.is_empty() {
                finish_reason = "stop".to_string();
            } else {
                resp += data;
                print!("{data}");
            }
            let _ = std::io::stdout().flush();
        })
        .await;

    println!();

    let completion_tokens = master
        .llm_model
        .as_ref()
        .expect("LLM model not found")
        .generated_tokens();

    let _ = master.goodbye().await;

    if let Err(e) = gen_result {
        return HttpResponse::InternalServerError()
            .json(serde_json::json!({"error": format!("{e}")}));
    }

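    // NOTE: prompt_tokens is approximated by the number of input messages,
    // not by an actual token count of the prompt.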
    let response = ChatResponse::new(
        TG::MODEL_NAME.to_string(),
        resp,
        num_messages,
        completion_tokens,
        finish_reason,
    );

    HttpResponse::Ok().json(response)
}

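/// Streams the completion as server-sent events, emitting OpenAI-style
/// `chat.completion.chunk` objects terminated by a `[DONE]` sentinel.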
async fn generate_text_stream<TG, IG>(
    state: web::Data<Arc<RwLock<Master<TG, IG>>>>,
    request: ChatRequest,
) -> HttpResponse
where
    TG: TextGenerator + Send + Sync + 'static,
    IG: ImageGenerator + Send + Sync + 'static,
{
    let id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
    let created = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let model = TG::MODEL_NAME.to_string();

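    // Channel protocol between the generator task and the SSE stream:
    // `Some(text)` carries a content chunk, `None` signals end of generation.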
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Option<String>>();

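    // Run generation on a background task so the HTTP response (and its SSE
    // stream) can be returned immediately.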
    let state_clone = state.clone();
    tokio::spawn(async move {
        let mut master = state_clone.write().await;

        if let Err(e) = master.reset() {
            log::error!("reset error: {e}");
            let _ = tx.send(None);
            return;
        }

        let llm_model = master.llm_model.as_mut().expect("LLM model not found");
        for message in request.messages {
            if let Err(e) = llm_model.add_message(message) {
                log::error!("add_message error: {e}");
                let _ = tx.send(None);
                return;
            }
        }

        if let Err(e) = master
            .generate_text(request.max_tokens, |data| {
                if data.is_empty() {
                    let _ = tx.send(None);
                } else {
                    let _ = tx.send(Some(data.to_string()));
                }
            })
            .await
        {
            log::error!("generate_text error: {e}");
            let _ = tx.send(None);
        }

        let _ = master.goodbye().await;

        // Make sure the stream always sees a terminal marker: if generation
        // stopped at the token limit, the callback never receives the empty
        // end-of-generation chunk. A duplicate `None` is harmless because the
        // receiving loop breaks after the first one.
        let _ = tx.send(None);
    });

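    // Build the SSE body: each event is framed as `data: <json>\n\n`. The
    // first chunk announces the assistant role, subsequent chunks carry
    // content deltas, and the final chunk sets `finish_reason` before the
    // `data: [DONE]` sentinel closes the stream.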
    let stream = async_stream::stream! {
        let initial = StreamResponse {
            id: id.clone(),
            object: "chat.completion.chunk".to_string(),
            created,
            model: model.clone(),
            choices: vec![StreamChoice {
                index: 0,
                delta: StreamDelta {
                    role: Some("assistant".to_string()),
                    content: None,
                },
                finish_reason: None,
            }],
        };
        yield Ok::<_, actix_web::Error>(web::Bytes::from(format!(
            "data: {}\n\n",
            serde_json::to_string(&initial).unwrap()
        )));

        while let Some(msg) = rx.recv().await {
            match msg {
                Some(content) => {
                    let chunk = StreamResponse {
                        id: id.clone(),
                        object: "chat.completion.chunk".to_string(),
                        created,
                        model: model.clone(),
                        choices: vec![StreamChoice {
                            index: 0,
                            delta: StreamDelta {
                                role: None,
                                content: Some(content),
                            },
                            finish_reason: None,
                        }],
                    };
                    yield Ok(web::Bytes::from(format!(
                        "data: {}\n\n",
                        serde_json::to_string(&chunk).unwrap()
                    )));
                }
                None => {
                    let done = StreamResponse {
                        id: id.clone(),
                        object: "chat.completion.chunk".to_string(),
                        created,
                        model: model.clone(),
                        choices: vec![StreamChoice {
                            index: 0,
                            delta: StreamDelta {
                                role: None,
                                content: None,
                            },
                            finish_reason: Some("stop".to_string()),
                        }],
                    };
                    yield Ok(web::Bytes::from(format!(
                        "data: {}\n\n",
                        serde_json::to_string(&done).unwrap()
                    )));
                    yield Ok(web::Bytes::from("data: [DONE]\n\n"));
                    break;
                }
            }
        }
    };

    HttpResponse::Ok()
        .content_type("text/event-stream")
        .insert_header(("Cache-Control", "no-cache"))
        .insert_header(("Connection", "keep-alive"))
        .streaming(stream)
}