
cake_core/cake/api/text.rs

use crate::cake::Master;
use crate::models::chat::Message;
use crate::models::{ImageGenerator, TextGenerator};
use actix_web::{web, HttpRequest, HttpResponse, Responder};
use serde::{Deserialize, Serialize};
use std::io::Write;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;

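/// An OpenAI-compatible chat completion request.
///
/// Example body (assuming `Message` follows the usual `{"role", "content"}`
/// shape; values are illustrative):
///
/// ```json
/// {
///   "messages": [{"role": "user", "content": "Hello"}],
///   "stream": true,
///   "max_tokens": 256
/// }
/// ```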
#[derive(Deserialize)]
pub struct ChatRequest {
    pub messages: Vec<Message>,
    /// Accepted for API compatibility; not read by the handlers in this
    /// file, which always report `TG::MODEL_NAME`.
    #[serde(default)]
    pub model: Option<String>,
    #[serde(default)]
    pub stream: Option<bool>,
    #[serde(default)]
    pub max_tokens: Option<usize>,
    /// Accepted for API compatibility; not currently used by these handlers.
    #[serde(default)]
    pub temperature: Option<f64>,
}

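/// Token accounting, mirroring the OpenAI `usage` object.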
#[derive(Serialize)]
struct Usage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}

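/// A single completion choice in a non-streaming response.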
#[derive(Serialize)]
struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}

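/// A complete (non-streaming) response in the OpenAI `chat.completion` format.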
#[derive(Serialize)]
struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

impl ChatResponse {
    /// Builds a response with a fresh `chatcmpl-` id, the current Unix
    /// timestamp, and a single choice wrapping the generated message.
    pub fn new(
        model: String,
        message: String,
        prompt_tokens: usize,
        completion_tokens: usize,
        finish_reason: String,
    ) -> Self {
        let id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
        let object = String::from("chat.completion");
        let created = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let choices = vec![Choice {
            index: 0,
            message: Message::assistant(message),
            finish_reason,
        }];

        Self {
            id,
            object,
            created,
            model,
            choices,
            usage: Usage {
                prompt_tokens,
                completion_tokens,
                total_tokens: prompt_tokens + completion_tokens,
            },
        }
    }
}

// SSE streaming types. A streamed reply is a sequence of
// `chat.completion.chunk` objects: a first chunk carrying only the
// assistant role, then one chunk per generated text fragment, and a
// final chunk with `finish_reason: "stop"` followed by `data: [DONE]`.
#[derive(Serialize)]
struct StreamChoice {
    pub index: usize,
    pub delta: StreamDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize)]
struct StreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize)]
struct StreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

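// An example content chunk as it appears on the wire (one SSE event; values
// are illustrative):
//
//   data: {"id":"chatcmpl-<uuid>","object":"chat.completion.chunk","created":1700000000,"model":"<model>","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}

/// Entry point for chat completions: dispatches to the streaming or blocking
/// implementation based on the request's `stream` flag.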
pub async fn generate_text<TG, IG>(
    state: web::Data<Arc<RwLock<Master<TG, IG>>>>,
    req: HttpRequest,
    body: web::Json<ChatRequest>,
) -> impl Responder
where
    TG: TextGenerator + Send + Sync + 'static,
    IG: ImageGenerator + Send + Sync + 'static,
{
    let client = req
        .peer_addr()
        .map(|a| a.to_string())
        .unwrap_or_else(|| "unknown".to_string());
    let stream = body.0.stream.unwrap_or(false);

    log::info!("starting chat for {} (stream={}) ...", &client, stream);

    if stream {
        generate_text_stream(state, body.0).await
    } else {
        generate_text_blocking(state, body.0).await
    }
}

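// How the `generate_text` handler above might be registered (hypothetical
// sketch; the actual route path and app setup live elsewhere in the crate,
// and `cfg` is assumed to be an actix `web::ServiceConfig`):
//
//   cfg.route(
//       "/api/v1/chat/completions",
//       web::post().to(generate_text::<TG, IG>),
//   );
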
async fn generate_text_blocking<TG, IG>(
    state: web::Data<Arc<RwLock<Master<TG, IG>>>>,
    request: ChatRequest,
) -> HttpResponse
where
    TG: TextGenerator + Send + Sync + 'static,
    IG: ImageGenerator + Send + Sync + 'static,
{
    let mut master = state.write().await;

    if let Err(e) = master.reset() {
        return HttpResponse::InternalServerError().json(serde_json::json!({"error": format!("{e}")}));
    }

    let num_messages = request.messages.len();
    let llm_model = master.llm_model.as_mut().expect("LLM model not found");
    for message in request.messages {
        if let Err(e) = llm_model.add_message(message) {
            return HttpResponse::InternalServerError().json(serde_json::json!({"error": format!("{e}")}));
        }
    }

    let mut resp = String::new();
    let mut finish_reason = "length".to_string();

    // An empty callback payload signals that generation stopped naturally;
    // otherwise the fragment is accumulated (and echoed to stdout).
    let gen_result = master
        .generate_text(request.max_tokens, |data| {
            if data.is_empty() {
                finish_reason = "stop".to_string();
            } else {
                resp += data;
                print!("{data}");
            }
            let _ = std::io::stdout().flush();
        })
        .await;

    println!();

    let completion_tokens = master
        .llm_model
        .as_ref()
        .expect("LLM model not found")
        .generated_tokens();

    let _ = master.goodbye().await;

    if let Err(e) = gen_result {
        return HttpResponse::InternalServerError().json(serde_json::json!({"error": format!("{e}")}));
    }

    // NOTE: the message count is used as a stand-in for `prompt_tokens`.
    let response = ChatResponse::new(
        TG::MODEL_NAME.to_string(),
        resp,
        num_messages,
        completion_tokens,
        finish_reason,
    );

    HttpResponse::Ok().json(response)
}

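/// Streams the reply as OpenAI-style server-sent events. Generation runs in
/// a spawned task and hands text fragments to the response stream over a
/// channel.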
async fn generate_text_stream<TG, IG>(
    state: web::Data<Arc<RwLock<Master<TG, IG>>>>,
    request: ChatRequest,
) -> HttpResponse
where
    TG: TextGenerator + Send + Sync + 'static,
    IG: ImageGenerator + Send + Sync + 'static,
{
    let id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
    let created = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let model = TG::MODEL_NAME.to_string();

    // `Some(text)` carries a generated fragment; `None` is the end-of-stream
    // sentinel (sent when generation reports an empty fragment or fails).
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Option<String>>();

    let state_clone = state.clone();
    tokio::spawn(async move {
        let mut master = state_clone.write().await;

        if let Err(e) = master.reset() {
            log::error!("reset error: {e}");
            let _ = tx.send(None);
            return;
        }

        let llm_model = master.llm_model.as_mut().expect("LLM model not found");
        for message in request.messages {
            if let Err(e) = llm_model.add_message(message) {
                log::error!("add_message error: {e}");
                let _ = tx.send(None);
                return;
            }
        }

        if let Err(e) = master
            .generate_text(request.max_tokens, |data| {
                if data.is_empty() {
                    let _ = tx.send(None);
                } else {
                    let _ = tx.send(Some(data.to_string()));
                }
            })
            .await
        {
            log::error!("generate_text error: {e}");
            let _ = tx.send(None);
        }

        let _ = master.goodbye().await;
    });

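    // Bridge the channel into an SSE byte stream in the
    // `chat.completion.chunk` format.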
    let stream = async_stream::stream! {
        // Send initial role chunk
        let initial = StreamResponse {
            id: id.clone(),
            object: "chat.completion.chunk".to_string(),
            created,
            model: model.clone(),
            choices: vec![StreamChoice {
                index: 0,
                delta: StreamDelta {
                    role: Some("assistant".to_string()),
                    content: None,
                },
                finish_reason: None,
            }],
        };
        yield Ok::<_, actix_web::Error>(
            web::Bytes::from(format!("data: {}\n\n", serde_json::to_string(&initial).unwrap()))
        );

        // Stream content chunks
        while let Some(msg) = rx.recv().await {
            match msg {
                Some(content) => {
                    let chunk = StreamResponse {
                        id: id.clone(),
                        object: "chat.completion.chunk".to_string(),
                        created,
                        model: model.clone(),
                        choices: vec![StreamChoice {
                            index: 0,
                            delta: StreamDelta {
                                role: None,
                                content: Some(content),
                            },
                            finish_reason: None,
                        }],
                    };
                    yield Ok(web::Bytes::from(format!("data: {}\n\n", serde_json::to_string(&chunk).unwrap())));
                }
                None => {
                    // Final chunk with finish_reason
                    let done = StreamResponse {
                        id: id.clone(),
                        object: "chat.completion.chunk".to_string(),
                        created,
                        model: model.clone(),
                        choices: vec![StreamChoice {
                            index: 0,
                            delta: StreamDelta {
                                role: None,
                                content: None,
                            },
                            finish_reason: Some("stop".to_string()),
                        }],
                    };
                    yield Ok(web::Bytes::from(format!("data: {}\n\n", serde_json::to_string(&done).unwrap())));
                    yield Ok(web::Bytes::from("data: [DONE]\n\n"));
                    break;
                }
            }
        }
    };

    HttpResponse::Ok()
        .content_type("text/event-stream")
        .insert_header(("Cache-Control", "no-cache"))
        .insert_header(("Connection", "keep-alive"))
        .streaming(stream)
}
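
// Example streaming request against this endpoint (the path is illustrative;
// the actual route is configured where the server is assembled):
//
//   curl -N http://localhost:8080/api/v1/chat/completions \
//     -H 'Content-Type: application/json' \
//     -d '{"messages":[{"role":"user","content":"Hi"}],"stream":true}'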