embacle_server/
streaming.rs1use std::convert::Infallible;
8
9use axum::response::sse::{Event, Sse};
10use axum::response::{IntoResponse, Response};
11use embacle::types::ChatStream;
12use futures::StreamExt;
13
14use crate::completions::{generate_id, unix_timestamp};
15use crate::openai_types::{ChatCompletionChunk, ChunkChoice, Delta, ResponseMessage};
16
17pub fn sse_response(stream: ChatStream, model: &str) -> Response {
25 let completion_id = generate_id();
26 let created = unix_timestamp();
27 let model = model.to_owned();
28
29 let sse_stream = {
30 let mut sent_role = false;
31
32 stream.map(move |chunk_result| {
33 match chunk_result {
34 Ok(chunk) => {
35 let (role, content, finish_reason) = if !sent_role {
36 sent_role = true;
37 if chunk.delta.is_empty() && !chunk.is_final {
38 (Some("assistant"), None, None)
40 } else {
41 (Some("assistant"), Some(chunk.delta), chunk.finish_reason)
43 }
44 } else if chunk.is_final {
45 (
46 None,
47 if chunk.delta.is_empty() {
48 None
49 } else {
50 Some(chunk.delta)
51 },
52 Some(chunk.finish_reason.unwrap_or_else(|| "stop".to_owned())),
53 )
54 } else {
55 (None, Some(chunk.delta), None)
56 };
57
58 let content = content.map(|c| {
61 if !c.is_empty() && !c.ends_with('\n') {
62 let mut normalized = c;
63 normalized.push('\n');
64 normalized
65 } else {
66 c
67 }
68 });
69
70 let data = ChatCompletionChunk {
71 id: completion_id.clone(),
72 object: "chat.completion.chunk",
73 created,
74 model: model.clone(),
75 choices: vec![ChunkChoice {
76 index: 0,
77 delta: Delta {
78 role,
79 content,
80 tool_calls: None,
81 },
82 finish_reason,
83 }],
84 };
85
86 let json = serde_json::to_string(&data).unwrap_or_default();
87 Ok::<_, Infallible>(Event::default().data(json))
88 }
89 Err(e) => {
90 let error_json = serde_json::json!({
91 "error": {
92 "message": e.message,
93 "type": "stream_error"
94 }
95 });
96 Ok(Event::default().data(error_json.to_string()))
97 }
98 }
99 })
100 };
101
102 let done_stream =
104 futures::stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
105
106 let combined = sse_stream.chain(done_stream);
107
108 Sse::new(combined)
109 .keep_alive(axum::response::sse::KeepAlive::default())
110 .into_response()
111}
112
113pub fn sse_response_strip_fences(stream: ChatStream, model: &str) -> Response {
119 let filtered = strip_fence_chunks(stream);
120 sse_response(filtered, model)
121}
122
123fn strip_fence_chunks(stream: ChatStream) -> ChatStream {
129 use embacle::types::StreamChunk;
130
131 Box::pin(stream.filter_map(|result| async move {
132 match result {
133 Ok(chunk) => {
134 if is_markdown_fence(&chunk.delta) {
135 if chunk.is_final {
136 Some(Ok(StreamChunk {
138 delta: String::new(),
139 is_final: true,
140 finish_reason: chunk.finish_reason,
141 }))
142 } else {
143 None
144 }
145 } else {
146 Some(Ok(chunk))
147 }
148 }
149 Err(e) => Some(Err(e)),
150 }
151 }))
152}
153
/// Reports whether `text` is nothing but a markdown code-fence marker.
///
/// After trimming surrounding whitespace, the text must be three backticks
/// optionally followed by ASCII letters/digits only (e.g. a language tag such
/// as `json`). Anything else after the backticks makes it a non-fence.
fn is_markdown_fence(text: &str) -> bool {
    match text.trim().strip_prefix("```") {
        Some(info) => info.bytes().all(|byte| byte.is_ascii_alphanumeric()),
        None => false,
    }
}
159
160pub fn sse_single_response(message: ResponseMessage, finish_reason: &str, model: &str) -> Response {
168 let completion_id = generate_id();
169 let created = unix_timestamp();
170
171 let content_chunk = ChatCompletionChunk {
172 id: completion_id.clone(),
173 object: "chat.completion.chunk",
174 created,
175 model: model.to_owned(),
176 choices: vec![ChunkChoice {
177 index: 0,
178 delta: Delta {
179 role: Some("assistant"),
180 content: message.content,
181 tool_calls: message.tool_calls,
182 },
183 finish_reason: None,
184 }],
185 };
186
187 let final_chunk = ChatCompletionChunk {
188 id: completion_id,
189 object: "chat.completion.chunk",
190 created,
191 model: model.to_owned(),
192 choices: vec![ChunkChoice {
193 index: 0,
194 delta: Delta {
195 role: None,
196 content: None,
197 tool_calls: None,
198 },
199 finish_reason: Some(finish_reason.to_owned()),
200 }],
201 };
202
203 let events = vec![
204 serde_json::to_string(&content_chunk).unwrap_or_default(),
205 serde_json::to_string(&final_chunk).unwrap_or_default(),
206 ];
207
208 let event_stream = futures::stream::iter(
209 events
210 .into_iter()
211 .map(|json| Ok::<_, Infallible>(Event::default().data(json))),
212 );
213 let done_stream =
214 futures::stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
215
216 let combined = event_stream.chain(done_stream);
217
218 Sse::new(combined)
219 .keep_alive(axum::response::sse::KeepAlive::default())
220 .into_response()
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Bare fences, with or without a language tag and surrounding whitespace.
    #[test]
    fn is_markdown_fence_detects_fences() {
        let fences = ["```json\n", "```\n", "```json", "```", " ```json "];
        for fence in fences.iter() {
            assert!(is_markdown_fence(fence), "expected a fence: {:?}", fence);
        }
    }

    /// Ordinary content, empty input, and fences with trailing text are not fences.
    #[test]
    fn is_markdown_fence_rejects_non_fences() {
        let non_fences = [
            "{\"key\": \"value\"}",
            "some text",
            "",
            "```json is cool```",
            "``` code here",
        ];
        for text in non_fences.iter() {
            assert!(!is_markdown_fence(text), "expected a non-fence: {:?}", text);
        }
    }

    /// An opening fence is dropped, content passes through, and a final closing
    /// fence is replaced by an empty final chunk that keeps its finish reason.
    #[tokio::test]
    async fn strip_fence_chunks_removes_fences() {
        use embacle::types::StreamChunk;

        let chunk = |delta: &str, is_final: bool, finish_reason: Option<&str>| StreamChunk {
            delta: delta.to_owned(),
            is_final,
            finish_reason: finish_reason.map(str::to_owned),
        };

        let input = vec![
            Ok(chunk("```json\n", false, None)),
            Ok(chunk("{\"key\":\"value\"}\n", false, None)),
            Ok(chunk("```\n", true, Some("stop"))),
        ];

        let source: ChatStream = Box::pin(futures::stream::iter(input));
        let output: Vec<_> = strip_fence_chunks(source).collect().await;
        assert_eq!(output.len(), 2);

        // The JSON payload survives unchanged.
        let payload = output[0].as_ref().unwrap();
        assert_eq!(payload.delta, "{\"key\":\"value\"}\n");
        assert!(!payload.is_final);

        // The closing fence collapses into an empty terminator.
        let terminator = output[1].as_ref().unwrap();
        assert!(terminator.delta.is_empty());
        assert!(terminator.is_final);
        assert_eq!(terminator.finish_reason.as_deref(), Some("stop"));
    }
}