dynamo_llm/protocols/openai/completions/
aggregator.rs1use std::{collections::HashMap, str::FromStr};
17
18use anyhow::Result;
19use futures::StreamExt;
20
21use super::{CompletionChoice, CompletionResponse, CompletionUsage, LogprobResult};
22use crate::protocols::{
23 codec::{Message, SseCodecError},
24 common::FinishReason,
25 convert_sse_stream, Annotated, DataStream,
26};
27
28pub struct DeltaAggregator {
30 id: String,
31 model: String,
32 created: u64,
33 usage: Option<CompletionUsage>,
34 system_fingerprint: Option<String>,
35 choices: HashMap<u64, DeltaChoice>,
36 error: Option<String>,
37}
38
39struct DeltaChoice {
40 index: u64,
41 text: String,
42 finish_reason: Option<FinishReason>,
43 logprobs: Option<LogprobResult>,
44}
45
46impl Default for DeltaAggregator {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl DeltaAggregator {
53 pub fn new() -> Self {
54 Self {
55 id: "".to_string(),
56 model: "".to_string(),
57 created: 0,
58 usage: None,
59 system_fingerprint: None,
60 choices: HashMap::new(),
61 error: None,
62 }
63 }
64
65 pub async fn apply(
67 stream: DataStream<Annotated<CompletionResponse>>,
68 ) -> Result<CompletionResponse> {
69 let aggregator = stream
70 .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
71 let delta = match delta.ok() {
72 Ok(delta) => delta,
73 Err(error) => {
74 aggregator.error = Some(error);
75 return aggregator;
76 }
77 };
78
79 if aggregator.error.is_none() && delta.data.is_some() {
80 let delta = delta.data.unwrap();
86 aggregator.id = delta.id;
87 aggregator.model = delta.model;
88 aggregator.created = delta.created;
89 if let Some(usage) = delta.usage {
90 aggregator.usage = Some(usage);
91 }
92 if let Some(system_fingerprint) = delta.system_fingerprint {
93 aggregator.system_fingerprint = Some(system_fingerprint);
94 }
95
96 for choice in delta.choices {
98 let state_choice =
99 aggregator
100 .choices
101 .entry(choice.index)
102 .or_insert(DeltaChoice {
103 index: choice.index,
104 text: "".to_string(),
105 finish_reason: None,
106 logprobs: choice.logprobs,
107 });
108
109 state_choice.text.push_str(&choice.text);
110
111 if let Some(finish_reason) = choice.finish_reason {
114 let reason = FinishReason::from_str(&finish_reason).ok();
115 state_choice.finish_reason = reason;
116 }
117 }
118 }
119 aggregator
120 })
121 .await;
122
123 let aggregator = if let Some(error) = aggregator.error {
125 return Err(anyhow::anyhow!(error));
126 } else {
127 aggregator
128 };
129
130 let mut choices: Vec<_> = aggregator
132 .choices
133 .into_values()
134 .map(CompletionChoice::from)
135 .collect();
136
137 choices.sort_by(|a, b| a.index.cmp(&b.index));
138
139 Ok(CompletionResponse {
140 id: aggregator.id,
141 created: aggregator.created,
142 usage: aggregator.usage,
143 model: aggregator.model,
144 object: "text_completion".to_string(),
145 system_fingerprint: aggregator.system_fingerprint,
146 choices,
147 })
148 }
149}
150
151impl From<DeltaChoice> for CompletionChoice {
152 fn from(delta: DeltaChoice) -> Self {
153 let finish_reason = delta.finish_reason.map(|reason| reason.to_string());
154
155 CompletionChoice {
156 index: delta.index,
157 text: delta.text,
158 finish_reason,
159 logprobs: delta.logprobs,
160 }
161 }
162}
163
164impl CompletionResponse {
165 pub async fn from_sse_stream(
166 stream: DataStream<Result<Message, SseCodecError>>,
167 ) -> Result<CompletionResponse> {
168 let stream = convert_sse_stream::<CompletionResponse>(stream);
169 CompletionResponse::from_annotated_stream(stream).await
170 }
171
172 pub async fn from_annotated_stream(
173 stream: DataStream<Annotated<CompletionResponse>>,
174 ) -> Result<CompletionResponse> {
175 DeltaAggregator::apply(stream).await
176 }
177}
178
#[cfg(test)]
mod tests {
    use crate::protocols::openai::completions::{CompletionChoice, CompletionResponse};

    use super::*;
    use futures::stream;

    /// Wraps `choices` in an annotated response carrying the fixed test metadata.
    fn annotated_response(choices: Vec<CompletionChoice>) -> Annotated<CompletionResponse> {
        let response = CompletionResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b".to_string(),
            created: 1234567890,
            usage: None,
            system_fingerprint: None,
            choices,
            object: "text_completion".to_string(),
        };

        Annotated {
            data: Some(response),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        }
    }

    /// Builds an annotated delta holding exactly one choice without logprobs.
    fn create_test_delta(
        index: u64,
        text: &str,
        finish_reason: Option<String>,
    ) -> Annotated<CompletionResponse> {
        annotated_response(vec![CompletionChoice {
            index,
            text: text.to_string(),
            finish_reason,
            logprobs: None,
        }])
    }

    /// Runs the aggregator over `deltas` and returns the response, failing the test
    /// if aggregation reports an error.
    async fn aggregate(deltas: Vec<Annotated<CompletionResponse>>) -> CompletionResponse {
        let result = DeltaAggregator::apply(Box::pin(stream::iter(deltas))).await;
        assert!(result.is_ok());
        result.unwrap()
    }

    #[tokio::test]
    async fn test_empty_stream() {
        // An empty stream must produce the aggregator's initial values.
        let stream: DataStream<Annotated<CompletionResponse>> = Box::pin(stream::empty());

        let result = DeltaAggregator::apply(stream).await;
        assert!(result.is_ok());
        let response = result.unwrap();

        assert_eq!(response.id, "");
        assert_eq!(response.model, "");
        assert_eq!(response.created, 0);
        assert!(response.usage.is_none());
        assert!(response.system_fingerprint.is_none());
        assert!(response.choices.is_empty());
    }

    #[tokio::test]
    async fn test_single_delta() {
        let delta = create_test_delta(0, "Hello,", Some("length".to_string()));

        let response = aggregate(vec![delta]).await;

        // Response-level metadata comes straight from the delta.
        assert_eq!(response.id, "test_id");
        assert_eq!(response.model, "meta/llama-3.1-8b");
        assert_eq!(response.created, 1234567890);
        assert!(response.usage.is_none());
        assert!(response.system_fingerprint.is_none());

        assert_eq!(response.choices.len(), 1);
        let choice = &response.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello,");
        assert_eq!(choice.finish_reason.as_deref(), Some("length"));
        assert!(choice.logprobs.is_none());
    }

    #[tokio::test]
    async fn test_multiple_deltas_same_choice() {
        // Two fragments for the same choice index; only the second carries a finish reason.
        let deltas = vec![
            create_test_delta(0, "Hello,", None),
            create_test_delta(0, " world!", Some("stop".to_string())),
        ];

        let response = aggregate(deltas).await;

        assert_eq!(response.choices.len(), 1);
        let choice = &response.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello, world!");
        assert_eq!(choice.finish_reason.as_deref(), Some("stop"));
    }

    #[tokio::test]
    async fn test_multiple_choices() {
        // A single delta that carries two distinct choices.
        let expected = [(0u64, "Choice 0"), (1u64, "Choice 1")];
        let choices = expected
            .iter()
            .map(|&(index, text)| CompletionChoice {
                index,
                text: text.to_string(),
                finish_reason: Some("stop".to_string()),
                logprobs: None,
            })
            .collect();

        let mut response = aggregate(vec![annotated_response(choices)]).await;

        assert_eq!(response.choices.len(), 2);
        response.choices.sort_by_key(|choice| choice.index);

        for (choice, &(index, text)) in response.choices.iter().zip(expected.iter()) {
            assert_eq!(choice.index, index);
            assert_eq!(choice.text, text);
            assert_eq!(choice.finish_reason.as_deref(), Some("stop"));
        }
    }
}