// dynamo_llm/protocols/openai/chat_completions/aggregator.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
17use crate::protocols::{
18    codec::{Message, SseCodecError},
19    convert_sse_stream, Annotated,
20};
21
22use futures::{Stream, StreamExt};
23use std::{collections::HashMap, pin::Pin};
24
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
/// Used for both the annotated-delta input stream and the raw SSE message stream below.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
27
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
/// [`NvCreateChatCompletionResponse`]. This struct accumulates incremental responses
/// from a streaming OpenAI API call into a complete final response.
pub struct DeltaAggregator {
    /// Unique identifier for the chat completion (last chunk seen wins).
    id: String,
    /// Model name used for the chat completion (last chunk seen wins).
    model: String,
    /// Timestamp (Unix epoch) indicating when the response was created.
    created: u32,
    /// Optional usage statistics for the completion request; updated whenever a
    /// chunk carries a `usage` payload.
    usage: Option<async_openai::types::CompletionUsage>,
    /// Optional system fingerprint for version tracking.
    system_fingerprint: Option<String>,
    /// Map of incremental response choices, keyed by choice index.
    choices: HashMap<u32, DeltaChoice>,
    /// Optional error message if an error occurs during aggregation; once set,
    /// further data chunks are ignored and `apply` returns `Err`.
    error: Option<String>,
    /// Optional service tier information for the response.
    service_tier: Option<async_openai::types::ServiceTierResponse>,
}
49
/// Represents the accumulated state of a single chat choice during streaming aggregation.
struct DeltaChoice {
    /// The index of the choice in the completion.
    index: u32,
    /// The accumulated text content for the choice (concatenation of all deltas).
    text: String,
    /// The role associated with this message (e.g., `system`, `user`, `assistant`).
    /// `None` until some chunk supplies it; conversion to `ChatChoice` panics if it
    /// is still `None` at the end of the stream.
    role: Option<async_openai::types::Role>,
    /// The reason the completion was finished (if applicable).
    finish_reason: Option<async_openai::types::FinishReason>,
    /// Optional log probabilities for the chat choice.
    logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
}
63
64impl Default for DeltaAggregator {
65    /// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl DeltaAggregator {
72    /// Creates a new, empty [`DeltaAggregator`] instance.
73    pub fn new() -> Self {
74        Self {
75            id: "".to_string(),
76            model: "".to_string(),
77            created: 0,
78            usage: None,
79            system_fingerprint: None,
80            choices: HashMap::new(),
81            error: None,
82            service_tier: None,
83        }
84    }
85
86    /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
87    /// [`NvCreateChatCompletionResponse`].
88    ///
89    /// # Arguments
90    /// * `stream` - A stream of annotated chat completion responses.
91    ///
92    /// # Returns
93    /// * `Ok(NvCreateChatCompletionResponse)` if aggregation is successful.
94    /// * `Err(String)` if an error occurs during processing.
95    pub async fn apply(
96        stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
97    ) -> Result<NvCreateChatCompletionResponse, String> {
98        let aggregator = stream
99            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
100                // Attempt to unwrap the delta, capturing any errors.
101                let delta = match delta.ok() {
102                    Ok(delta) => delta,
103                    Err(error) => {
104                        aggregator.error = Some(error);
105                        return aggregator;
106                    }
107                };
108
109                if aggregator.error.is_none() && delta.data.is_some() {
110                    // Extract the data payload from the delta.
111                    let delta = delta.data.unwrap();
112                    aggregator.id = delta.inner.id;
113                    aggregator.model = delta.inner.model;
114                    aggregator.created = delta.inner.created;
115                    aggregator.service_tier = delta.inner.service_tier;
116
117                    // Aggregate usage statistics if available.
118                    if let Some(usage) = delta.inner.usage {
119                        aggregator.usage = Some(usage);
120                    }
121                    if let Some(system_fingerprint) = delta.inner.system_fingerprint {
122                        aggregator.system_fingerprint = Some(system_fingerprint);
123                    }
124
125                    // Aggregate choices incrementally.
126                    for choice in delta.inner.choices {
127                        let state_choice =
128                            aggregator
129                                .choices
130                                .entry(choice.index)
131                                .or_insert(DeltaChoice {
132                                    index: choice.index,
133                                    text: "".to_string(),
134                                    role: choice.delta.role,
135                                    finish_reason: None,
136                                    logprobs: choice.logprobs,
137                                });
138
139                        // Append content if available.
140                        if let Some(content) = &choice.delta.content {
141                            state_choice.text.push_str(content);
142                        }
143
144                        // Update finish reason if provided.
145                        if let Some(finish_reason) = choice.finish_reason {
146                            state_choice.finish_reason = Some(finish_reason);
147                        }
148                    }
149                }
150                aggregator
151            })
152            .await;
153
154        // Return early if an error was encountered.
155        let aggregator = if let Some(error) = aggregator.error {
156            return Err(error);
157        } else {
158            aggregator
159        };
160
161        // Extract aggregated choices and sort them by index.
162        let mut choices: Vec<_> = aggregator
163            .choices
164            .into_values()
165            .map(async_openai::types::ChatChoice::from)
166            .collect();
167
168        choices.sort_by(|a, b| a.index.cmp(&b.index));
169
170        // Construct the final response object.
171        let inner = async_openai::types::CreateChatCompletionResponse {
172            id: aggregator.id,
173            created: aggregator.created,
174            usage: aggregator.usage,
175            model: aggregator.model,
176            object: "chat.completion".to_string(),
177            system_fingerprint: aggregator.system_fingerprint,
178            choices,
179            service_tier: aggregator.service_tier,
180        };
181
182        let response = NvCreateChatCompletionResponse { inner };
183
184        Ok(response)
185    }
186}
187
188#[allow(deprecated)]
189impl From<DeltaChoice> for async_openai::types::ChatChoice {
190    /// Converts a [`DeltaChoice`] into an [`async_openai::types::ChatChoice`].
191    ///
192    /// # Note
193    /// The `function_call` field is deprecated.
194    fn from(delta: DeltaChoice) -> Self {
195        async_openai::types::ChatChoice {
196            message: async_openai::types::ChatCompletionResponseMessage {
197                role: delta.role.expect("delta should have a Role"),
198                content: Some(delta.text),
199                tool_calls: None,
200                refusal: None,
201                function_call: None,
202                audio: None,
203            },
204            index: delta.index,
205            finish_reason: delta.finish_reason,
206            logprobs: delta.logprobs,
207        }
208    }
209}
210
211impl NvCreateChatCompletionResponse {
212    /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
213    ///
214    /// # Arguments
215    /// * `stream` - A stream of SSE messages containing chat completion responses.
216    ///
217    /// # Returns
218    /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
219    /// * `Err(String)` if an error occurs.
220    pub async fn from_sse_stream(
221        stream: DataStream<Result<Message, SseCodecError>>,
222    ) -> Result<NvCreateChatCompletionResponse, String> {
223        let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
224        NvCreateChatCompletionResponse::from_annotated_stream(stream).await
225    }
226
227    /// Aggregates an annotated stream of chat completion responses into a final response.
228    ///
229    /// # Arguments
230    /// * `stream` - A stream of annotated chat completion responses.
231    ///
232    /// # Returns
233    /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
234    /// * `Err(String)` if an error occurs.
235    pub async fn from_annotated_stream(
236        stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
237    ) -> Result<NvCreateChatCompletionResponse, String> {
238        DeltaAggregator::apply(stream).await
239    }
240}
241
#[cfg(test)]
mod tests {

    use super::*;
    use futures::stream;

    /// Builds a single-choice streaming chunk wrapped in [`Annotated`], with fixed
    /// id/model/created values so tests can assert against them.
    #[allow(deprecated)]
    fn create_test_delta(
        index: u32,
        text: &str,
        role: Option<async_openai::types::Role>,
        finish_reason: Option<async_openai::types::FinishReason>,
    ) -> Annotated<NvCreateChatCompletionStreamResponse> {
        // ALLOW: function_call is deprecated
        let delta = async_openai::types::ChatCompletionStreamResponseDelta {
            content: Some(text.to_string()),
            function_call: None,
            tool_calls: None,
            role,
            refusal: None,
        };
        let choice = async_openai::types::ChatChoiceStream {
            index,
            delta,
            finish_reason,
            logprobs: None,
        };

        let inner = async_openai::types::CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b-instruct".to_string(),
            created: 1234567890,
            service_tier: None,
            usage: None,
            system_fingerprint: None,
            choices: vec![choice],
            object: "chat.completion".to_string(),
        };

        let data = NvCreateChatCompletionStreamResponse { inner };

        Annotated {
            data: Some(data),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        }
    }

    /// An empty stream must aggregate into a default (empty) response, not an error.
    #[tokio::test]
    async fn test_empty_stream() {
        // Create an empty stream
        let stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>> =
            Box::pin(stream::empty());

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify that the response is empty and has default values
        assert_eq!(response.inner.id, "");
        assert_eq!(response.inner.model, "");
        assert_eq!(response.inner.created, 0);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 0);
        assert!(response.inner.service_tier.is_none());
    }

    /// A single chunk must carry its id/model/created and choice content through
    /// to the aggregated response unchanged.
    #[tokio::test]
    async fn test_single_delta() {
        // Create a sample delta
        let annotated_delta =
            create_test_delta(0, "Hello,", Some(async_openai::types::Role::User), None);

        // Create a stream
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify the response fields
        assert_eq!(response.inner.id, "test_id");
        assert_eq!(response.inner.model, "meta/llama-3.1-8b-instruct");
        assert_eq!(response.inner.created, 1234567890);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 1);
        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.content.as_ref().unwrap(), "Hello,");
        assert!(choice.finish_reason.is_none());
        assert_eq!(choice.message.role, async_openai::types::Role::User);
        assert!(response.inner.service_tier.is_none());
    }

    /// Two chunks sharing a choice index must concatenate their text, and the
    /// role/finish_reason must survive regardless of which chunk carried them.
    #[tokio::test]
    async fn test_multiple_deltas_same_choice() {
        // Create multiple deltas with the same choice index
        // One will have a MessageRole and no FinishReason,
        // the other will have a FinishReason and no MessageRole
        let annotated_delta1 =
            create_test_delta(0, "Hello,", Some(async_openai::types::Role::User), None);
        let annotated_delta2 = create_test_delta(
            0,
            " world!",
            None,
            Some(async_openai::types::FinishReason::Stop),
        );

        // Create a stream
        let annotated_deltas = vec![annotated_delta1, annotated_delta2];
        let stream = Box::pin(stream::iter(annotated_deltas));

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify the response fields
        assert_eq!(response.inner.choices.len(), 1);
        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.content.as_ref().unwrap(), "Hello, world!");
        assert_eq!(
            choice.finish_reason,
            Some(async_openai::types::FinishReason::Stop)
        );
        assert_eq!(choice.message.role, async_openai::types::Role::User);
    }

    /// A chunk carrying two choices must yield two aggregated choices, each with
    /// its own content and finish reason.
    #[allow(deprecated)]
    #[tokio::test]
    async fn test_multiple_choices() {
        // Create a delta with multiple choices
        // ALLOW: function_call is deprecated
        let delta = async_openai::types::CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            model: "test_model".to_string(),
            created: 1234567890,
            service_tier: None,
            usage: None,
            system_fingerprint: None,
            choices: vec![
                async_openai::types::ChatChoiceStream {
                    index: 0,
                    delta: async_openai::types::ChatCompletionStreamResponseDelta {
                        role: Some(async_openai::types::Role::Assistant),
                        content: Some("Choice 0".to_string()),
                        function_call: None,
                        tool_calls: None,
                        refusal: None,
                    },
                    finish_reason: Some(async_openai::types::FinishReason::Stop),
                    logprobs: None,
                },
                async_openai::types::ChatChoiceStream {
                    index: 1,
                    delta: async_openai::types::ChatCompletionStreamResponseDelta {
                        role: Some(async_openai::types::Role::Assistant),
                        content: Some("Choice 1".to_string()),
                        function_call: None,
                        tool_calls: None,
                        refusal: None,
                    },
                    finish_reason: Some(async_openai::types::FinishReason::Stop),
                    logprobs: None,
                },
            ],
            object: "chat.completion".to_string(),
        };

        let data = NvCreateChatCompletionStreamResponse { inner: delta };

        // Wrap it in Annotated and create a stream
        let annotated_delta = Annotated {
            data: Some(data),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        };
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let mut response = result.unwrap();

        // Verify the response fields
        assert_eq!(response.inner.choices.len(), 2);
        response.inner.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
        let choice0 = &response.inner.choices[0];
        assert_eq!(choice0.index, 0);
        assert_eq!(choice0.message.content.as_ref().unwrap(), "Choice 0");
        assert_eq!(
            choice0.finish_reason,
            Some(async_openai::types::FinishReason::Stop)
        );
        assert_eq!(choice0.message.role, async_openai::types::Role::Assistant);

        let choice1 = &response.inner.choices[1];
        assert_eq!(choice1.index, 1);
        assert_eq!(choice1.message.content.as_ref().unwrap(), "Choice 1");
        assert_eq!(
            choice1.finish_reason,
            Some(async_openai::types::FinishReason::Stop)
        );
        assert_eq!(choice1.message.role, async_openai::types::Role::Assistant);
    }
}