// dynamo_llm/protocols/openai/chat_completions/aggregator.rs

use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};

use crate::protocols::{
    codec::{Message, SseCodecError},
    convert_sse_stream, Annotated,
};

use futures::{Stream, StreamExt};
use std::{collections::HashMap, pin::Pin};
/// A pinned, boxed stream of items; the common shape for response streams here.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
27
/// Accumulates streamed chat-completion deltas into a single final response.
pub struct DeltaAggregator {
    /// Response id; taken from the most recent delta seen.
    id: String,
    /// Model name; taken from the most recent delta seen.
    model: String,
    /// Creation timestamp from the most recent delta seen.
    created: u32,
    /// Token usage; set whenever a delta carries a usage payload.
    usage: Option<async_openai::types::CompletionUsage>,
    /// System fingerprint; set whenever a delta carries one.
    system_fingerprint: Option<String>,
    /// Per-choice accumulation state, keyed by choice index.
    choices: HashMap<u32, DeltaChoice>,
    /// First stream error observed, if any; halts further aggregation.
    error: Option<String>,
    /// Service tier; taken from the most recent delta seen.
    service_tier: Option<async_openai::types::ServiceTierResponse>,
}
49
/// Accumulation state for one choice within a streamed response.
struct DeltaChoice {
    /// Choice index as reported by the stream.
    index: u32,
    /// Concatenated content fragments received so far.
    text: String,
    /// Role observed for this choice, if any delta carried one.
    role: Option<async_openai::types::Role>,
    /// Finish reason, set once a delta reports it.
    finish_reason: Option<async_openai::types::FinishReason>,
    /// Log probabilities captured for this choice, if provided.
    logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
}
63
64impl Default for DeltaAggregator {
65 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl DeltaAggregator {
72 pub fn new() -> Self {
74 Self {
75 id: "".to_string(),
76 model: "".to_string(),
77 created: 0,
78 usage: None,
79 system_fingerprint: None,
80 choices: HashMap::new(),
81 error: None,
82 service_tier: None,
83 }
84 }
85
86 pub async fn apply(
96 stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
97 ) -> Result<NvCreateChatCompletionResponse, String> {
98 let aggregator = stream
99 .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
100 let delta = match delta.ok() {
102 Ok(delta) => delta,
103 Err(error) => {
104 aggregator.error = Some(error);
105 return aggregator;
106 }
107 };
108
109 if aggregator.error.is_none() && delta.data.is_some() {
110 let delta = delta.data.unwrap();
112 aggregator.id = delta.inner.id;
113 aggregator.model = delta.inner.model;
114 aggregator.created = delta.inner.created;
115 aggregator.service_tier = delta.inner.service_tier;
116
117 if let Some(usage) = delta.inner.usage {
119 aggregator.usage = Some(usage);
120 }
121 if let Some(system_fingerprint) = delta.inner.system_fingerprint {
122 aggregator.system_fingerprint = Some(system_fingerprint);
123 }
124
125 for choice in delta.inner.choices {
127 let state_choice =
128 aggregator
129 .choices
130 .entry(choice.index)
131 .or_insert(DeltaChoice {
132 index: choice.index,
133 text: "".to_string(),
134 role: choice.delta.role,
135 finish_reason: None,
136 logprobs: choice.logprobs,
137 });
138
139 if let Some(content) = &choice.delta.content {
141 state_choice.text.push_str(content);
142 }
143
144 if let Some(finish_reason) = choice.finish_reason {
146 state_choice.finish_reason = Some(finish_reason);
147 }
148 }
149 }
150 aggregator
151 })
152 .await;
153
154 let aggregator = if let Some(error) = aggregator.error {
156 return Err(error);
157 } else {
158 aggregator
159 };
160
161 let mut choices: Vec<_> = aggregator
163 .choices
164 .into_values()
165 .map(async_openai::types::ChatChoice::from)
166 .collect();
167
168 choices.sort_by(|a, b| a.index.cmp(&b.index));
169
170 let inner = async_openai::types::CreateChatCompletionResponse {
172 id: aggregator.id,
173 created: aggregator.created,
174 usage: aggregator.usage,
175 model: aggregator.model,
176 object: "chat.completion".to_string(),
177 system_fingerprint: aggregator.system_fingerprint,
178 choices,
179 service_tier: aggregator.service_tier,
180 };
181
182 let response = NvCreateChatCompletionResponse { inner };
183
184 Ok(response)
185 }
186}
187
188#[allow(deprecated)]
189impl From<DeltaChoice> for async_openai::types::ChatChoice {
190 fn from(delta: DeltaChoice) -> Self {
195 async_openai::types::ChatChoice {
196 message: async_openai::types::ChatCompletionResponseMessage {
197 role: delta.role.expect("delta should have a Role"),
198 content: Some(delta.text),
199 tool_calls: None,
200 refusal: None,
201 function_call: None,
202 audio: None,
203 },
204 index: delta.index,
205 finish_reason: delta.finish_reason,
206 logprobs: delta.logprobs,
207 }
208 }
209}
210
impl NvCreateChatCompletionResponse {
    /// Aggregates a raw SSE message stream into one complete response.
    ///
    /// Decodes each SSE [`Message`] into an annotated stream-response delta,
    /// then folds the deltas via [`DeltaAggregator::apply`].
    pub async fn from_sse_stream(
        stream: DataStream<Result<Message, SseCodecError>>,
    ) -> Result<NvCreateChatCompletionResponse, String> {
        let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
        NvCreateChatCompletionResponse::from_annotated_stream(stream).await
    }

    /// Aggregates a stream of annotated deltas into one complete response.
    pub async fn from_annotated_stream(
        stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
    ) -> Result<NvCreateChatCompletionResponse, String> {
        DeltaAggregator::apply(stream).await
    }
}
241
#[cfg(test)]
mod tests {

    use super::*;
    use futures::stream;

    /// Builds a single-choice annotated stream delta for use in tests.
    #[allow(deprecated)]
    fn create_test_delta(
        index: u32,
        text: &str,
        role: Option<async_openai::types::Role>,
        finish_reason: Option<async_openai::types::FinishReason>,
    ) -> Annotated<NvCreateChatCompletionStreamResponse> {
        let delta = async_openai::types::ChatCompletionStreamResponseDelta {
            content: Some(text.to_string()),
            function_call: None,
            tool_calls: None,
            role,
            refusal: None,
        };
        let choice = async_openai::types::ChatChoiceStream {
            index,
            delta,
            finish_reason,
            logprobs: None,
        };

        let inner = async_openai::types::CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b-instruct".to_string(),
            created: 1234567890,
            service_tier: None,
            usage: None,
            system_fingerprint: None,
            choices: vec![choice],
            // NOTE(review): streamed chunks conventionally use
            // "chat.completion.chunk" — harmless here since the aggregator
            // never reads `object`, but worth confirming intent.
            object: "chat.completion".to_string(),
        };

        let data = NvCreateChatCompletionStreamResponse { inner };

        Annotated {
            data: Some(data),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        }
    }

    /// An empty stream must aggregate to the default (empty) response.
    #[tokio::test]
    async fn test_empty_stream() {
        let stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>> =
            Box::pin(stream::empty());

        let result = DeltaAggregator::apply(stream).await;

        assert!(result.is_ok());
        let response = result.unwrap();

        // Everything stays at its `DeltaAggregator::new()` initial value.
        assert_eq!(response.inner.id, "");
        assert_eq!(response.inner.model, "");
        assert_eq!(response.inner.created, 0);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 0);
        assert!(response.inner.service_tier.is_none());
    }

    /// A single delta carries all scalar fields and one choice through intact.
    #[tokio::test]
    async fn test_single_delta() {
        let annotated_delta =
            create_test_delta(0, "Hello,", Some(async_openai::types::Role::User), None);

        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        let result = DeltaAggregator::apply(stream).await;

        assert!(result.is_ok());
        let response = result.unwrap();

        assert_eq!(response.inner.id, "test_id");
        assert_eq!(response.inner.model, "meta/llama-3.1-8b-instruct");
        assert_eq!(response.inner.created, 1234567890);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 1);
        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.content.as_ref().unwrap(), "Hello,");
        assert!(choice.finish_reason.is_none());
        assert_eq!(choice.message.role, async_openai::types::Role::User);
        assert!(response.inner.service_tier.is_none());
    }

    /// Two deltas for the same choice index concatenate their content and
    /// keep the finish reason from the delta that reported it.
    #[tokio::test]
    async fn test_multiple_deltas_same_choice() {
        // First chunk carries the role; second carries the finish reason.
        let annotated_delta1 =
            create_test_delta(0, "Hello,", Some(async_openai::types::Role::User), None);
        let annotated_delta2 = create_test_delta(
            0,
            " world!",
            None,
            Some(async_openai::types::FinishReason::Stop),
        );

        let annotated_deltas = vec![annotated_delta1, annotated_delta2];
        let stream = Box::pin(stream::iter(annotated_deltas));

        let result = DeltaAggregator::apply(stream).await;

        assert!(result.is_ok());
        let response = result.unwrap();

        assert_eq!(response.inner.choices.len(), 1);
        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.content.as_ref().unwrap(), "Hello, world!");
        assert_eq!(
            choice.finish_reason,
            Some(async_openai::types::FinishReason::Stop)
        );
        assert_eq!(choice.message.role, async_openai::types::Role::User);
    }

    /// A delta with two choices produces two aggregated choices, each with
    /// its own content, role, and finish reason.
    #[allow(deprecated)]
    #[tokio::test]
    async fn test_multiple_choices() {
        // Hand-rolled delta (not via create_test_delta) so it can carry two
        // choices in a single chunk.
        let delta = async_openai::types::CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            model: "test_model".to_string(),
            created: 1234567890,
            service_tier: None,
            usage: None,
            system_fingerprint: None,
            choices: vec![
                async_openai::types::ChatChoiceStream {
                    index: 0,
                    delta: async_openai::types::ChatCompletionStreamResponseDelta {
                        role: Some(async_openai::types::Role::Assistant),
                        content: Some("Choice 0".to_string()),
                        function_call: None,
                        tool_calls: None,
                        refusal: None,
                    },
                    finish_reason: Some(async_openai::types::FinishReason::Stop),
                    logprobs: None,
                },
                async_openai::types::ChatChoiceStream {
                    index: 1,
                    delta: async_openai::types::ChatCompletionStreamResponseDelta {
                        role: Some(async_openai::types::Role::Assistant),
                        content: Some("Choice 1".to_string()),
                        function_call: None,
                        tool_calls: None,
                        refusal: None,
                    },
                    finish_reason: Some(async_openai::types::FinishReason::Stop),
                    logprobs: None,
                },
            ],
            object: "chat.completion".to_string(),
        };

        let data = NvCreateChatCompletionStreamResponse { inner: delta };

        let annotated_delta = Annotated {
            data: Some(data),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        };
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        let result = DeltaAggregator::apply(stream).await;

        assert!(result.is_ok());
        let mut response = result.unwrap();

        assert_eq!(response.inner.choices.len(), 2);
        // Sort defensively before indexing, so assertions don't depend on
        // the aggregator's output order.
        response.inner.choices.sort_by(|a, b| a.index.cmp(&b.index));
        let choice0 = &response.inner.choices[0];
        assert_eq!(choice0.index, 0);
        assert_eq!(choice0.message.content.as_ref().unwrap(), "Choice 0");
        assert_eq!(
            choice0.finish_reason,
            Some(async_openai::types::FinishReason::Stop)
        );
        assert_eq!(choice0.message.role, async_openai::types::Role::Assistant);

        let choice1 = &response.inner.choices[1];
        assert_eq!(choice1.index, 1);
        assert_eq!(choice1.message.content.as_ref().unwrap(), "Choice 1");
        assert_eq!(
            choice1.finish_reason,
            Some(async_openai::types::FinishReason::Stop)
        );
        assert_eq!(choice1.message.role, async_openai::types::Role::Assistant);
    }
}