embacle_server/
streaming.rs1use std::convert::Infallible;
8
9use axum::response::sse::{Event, Sse};
10use axum::response::{IntoResponse, Response};
11use embacle::types::ChatStream;
12use futures::StreamExt;
13
14use crate::completions::{generate_id, unix_timestamp};
15use crate::openai_types::{ChatCompletionChunk, ChunkChoice, Delta, ResponseMessage};
16
17pub fn sse_response(stream: ChatStream, model: &str) -> Response {
25 let completion_id = generate_id();
26 let created = unix_timestamp();
27 let model = model.to_owned();
28
29 let sse_stream = {
30 let mut sent_role = false;
31
32 stream.map(move |chunk_result| {
33 match chunk_result {
34 Ok(chunk) => {
35 let (role, content, finish_reason) = if !sent_role {
36 sent_role = true;
37 if chunk.delta.is_empty() && !chunk.is_final {
38 (Some("assistant"), None, None)
40 } else {
41 (Some("assistant"), Some(chunk.delta), chunk.finish_reason)
43 }
44 } else if chunk.is_final {
45 (
46 None,
47 if chunk.delta.is_empty() {
48 None
49 } else {
50 Some(chunk.delta)
51 },
52 Some(chunk.finish_reason.unwrap_or_else(|| "stop".to_owned())),
53 )
54 } else {
55 (None, Some(chunk.delta), None)
56 };
57
58 let content = content.map(|c| {
61 if !c.is_empty() && !c.ends_with('\n') {
62 let mut normalized = c;
63 normalized.push('\n');
64 normalized
65 } else {
66 c
67 }
68 });
69
70 let data = ChatCompletionChunk {
71 id: completion_id.clone(),
72 object: "chat.completion.chunk",
73 created,
74 model: model.clone(),
75 choices: vec![ChunkChoice {
76 index: 0,
77 delta: Delta {
78 role,
79 content,
80 tool_calls: None,
81 },
82 finish_reason,
83 }],
84 };
85
86 let json = serde_json::to_string(&data).unwrap_or_default();
87 Ok::<_, Infallible>(Event::default().data(json))
88 }
89 Err(e) => {
90 let error_json = serde_json::json!({
91 "error": {
92 "message": e.message,
93 "type": "stream_error"
94 }
95 });
96 Ok(Event::default().data(error_json.to_string()))
97 }
98 }
99 })
100 };
101
102 let done_stream =
104 futures::stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
105
106 let combined = sse_stream.chain(done_stream);
107
108 Sse::new(combined)
109 .keep_alive(axum::response::sse::KeepAlive::default())
110 .into_response()
111}
112
113pub fn sse_single_response(message: ResponseMessage, finish_reason: &str, model: &str) -> Response {
121 let completion_id = generate_id();
122 let created = unix_timestamp();
123
124 let content_chunk = ChatCompletionChunk {
125 id: completion_id.clone(),
126 object: "chat.completion.chunk",
127 created,
128 model: model.to_owned(),
129 choices: vec![ChunkChoice {
130 index: 0,
131 delta: Delta {
132 role: Some("assistant"),
133 content: message.content,
134 tool_calls: message.tool_calls,
135 },
136 finish_reason: None,
137 }],
138 };
139
140 let final_chunk = ChatCompletionChunk {
141 id: completion_id,
142 object: "chat.completion.chunk",
143 created,
144 model: model.to_owned(),
145 choices: vec![ChunkChoice {
146 index: 0,
147 delta: Delta {
148 role: None,
149 content: None,
150 tool_calls: None,
151 },
152 finish_reason: Some(finish_reason.to_owned()),
153 }],
154 };
155
156 let events = vec![
157 serde_json::to_string(&content_chunk).unwrap_or_default(),
158 serde_json::to_string(&final_chunk).unwrap_or_default(),
159 ];
160
161 let event_stream = futures::stream::iter(
162 events
163 .into_iter()
164 .map(|json| Ok::<_, Infallible>(Event::default().data(json))),
165 );
166 let done_stream =
167 futures::stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
168
169 let combined = event_stream.chain(done_stream);
170
171 Sse::new(combined)
172 .keep_alive(axum::response::sse::KeepAlive::default())
173 .into_response()
174}