dynamo_llm/protocols/openai/completions/
aggregator.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{collections::HashMap, str::FromStr};
17
18use anyhow::Result;
19use futures::StreamExt;
20
21use super::{CompletionChoice, CompletionResponse, CompletionUsage, LogprobResult};
22use crate::protocols::{
23    codec::{Message, SseCodecError},
24    common::FinishReason,
25    convert_sse_stream, Annotated, DataStream,
26};
27
/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
///
/// The aggregator is the accumulator folded over the stream by
/// [`DeltaAggregator::apply`]; its fields hold the running state of the
/// response being assembled.
pub struct DeltaAggregator {
    /// Response id; overwritten by each aggregated delta, so the last one wins.
    id: String,
    /// Model name; overwritten by each aggregated delta, so the last one wins.
    model: String,
    /// Creation timestamp; overwritten by each aggregated delta, so the last one wins.
    created: u64,
    /// Usage statistics; replaced only when a delta actually carries a value.
    usage: Option<CompletionUsage>,
    /// System fingerprint; replaced only when a delta actually carries a value.
    system_fingerprint: Option<String>,
    /// Per-choice accumulated state, keyed by the choice index.
    choices: HashMap<u64, DeltaChoice>,
    /// Error message reported by the stream, if any. While this is set no
    /// further deltas are aggregated and `apply` returns it as an `Err`.
    error: Option<String>,
}
38
/// Running state for a single completion choice while its deltas are being merged.
struct DeltaChoice {
    /// Index of the choice within the response; also the key it is stored under.
    index: u64,
    /// Text of every delta seen for this choice, appended in arrival order.
    text: String,
    /// Finish reason parsed from the most recent delta that carried one
    /// (`None` if no delta carried one or the string could not be parsed).
    finish_reason: Option<FinishReason>,
    /// Logprobs taken from the first delta seen for this choice; logprobs on
    /// later deltas are currently not merged.
    logprobs: Option<LogprobResult>,
}
45
46impl Default for DeltaAggregator {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl DeltaAggregator {
53    pub fn new() -> Self {
54        Self {
55            id: "".to_string(),
56            model: "".to_string(),
57            created: 0,
58            usage: None,
59            system_fingerprint: None,
60            choices: HashMap::new(),
61            error: None,
62        }
63    }
64
65    /// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
66    pub async fn apply(
67        stream: DataStream<Annotated<CompletionResponse>>,
68    ) -> Result<CompletionResponse> {
69        let aggregator = stream
70            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
71                let delta = match delta.ok() {
72                    Ok(delta) => delta,
73                    Err(error) => {
74                        aggregator.error = Some(error);
75                        return aggregator;
76                    }
77                };
78
79                if aggregator.error.is_none() && delta.data.is_some() {
80                    // note: we could extract annotations here and add them to the aggregator
81                    // to be return as part of the NIM Response Extension
82                    // TODO(#14) - Aggregate Annotation
83
84                    // these are cheap to move so we do it every time since we are consuming the delta
85                    let delta = delta.data.unwrap();
86                    aggregator.id = delta.id;
87                    aggregator.model = delta.model;
88                    aggregator.created = delta.created;
89                    if let Some(usage) = delta.usage {
90                        aggregator.usage = Some(usage);
91                    }
92                    if let Some(system_fingerprint) = delta.system_fingerprint {
93                        aggregator.system_fingerprint = Some(system_fingerprint);
94                    }
95
96                    // handle the choices
97                    for choice in delta.choices {
98                        let state_choice =
99                            aggregator
100                                .choices
101                                .entry(choice.index)
102                                .or_insert(DeltaChoice {
103                                    index: choice.index,
104                                    text: "".to_string(),
105                                    finish_reason: None,
106                                    logprobs: choice.logprobs,
107                                });
108
109                        state_choice.text.push_str(&choice.text);
110
111                        // todo - handle logprobs
112
113                        if let Some(finish_reason) = choice.finish_reason {
114                            let reason = FinishReason::from_str(&finish_reason).ok();
115                            state_choice.finish_reason = reason;
116                        }
117                    }
118                }
119                aggregator
120            })
121            .await;
122
123        // If we have an error, return it
124        let aggregator = if let Some(error) = aggregator.error {
125            return Err(anyhow::anyhow!(error));
126        } else {
127            aggregator
128        };
129
130        // extra the aggregated deltas and sort by index
131        let mut choices: Vec<_> = aggregator
132            .choices
133            .into_values()
134            .map(CompletionChoice::from)
135            .collect();
136
137        choices.sort_by(|a, b| a.index.cmp(&b.index));
138
139        Ok(CompletionResponse {
140            id: aggregator.id,
141            created: aggregator.created,
142            usage: aggregator.usage,
143            model: aggregator.model,
144            object: "text_completion".to_string(),
145            system_fingerprint: aggregator.system_fingerprint,
146            choices,
147        })
148    }
149}
150
151impl From<DeltaChoice> for CompletionChoice {
152    fn from(delta: DeltaChoice) -> Self {
153        let finish_reason = delta.finish_reason.map(|reason| reason.to_string());
154
155        CompletionChoice {
156            index: delta.index,
157            text: delta.text,
158            finish_reason,
159            logprobs: delta.logprobs,
160        }
161    }
162}
163
164impl CompletionResponse {
165    pub async fn from_sse_stream(
166        stream: DataStream<Result<Message, SseCodecError>>,
167    ) -> Result<CompletionResponse> {
168        let stream = convert_sse_stream::<CompletionResponse>(stream);
169        CompletionResponse::from_annotated_stream(stream).await
170    }
171
172    pub async fn from_annotated_stream(
173        stream: DataStream<Annotated<CompletionResponse>>,
174    ) -> Result<CompletionResponse> {
175        DeltaAggregator::apply(stream).await
176    }
177}
178
#[cfg(test)]
mod tests {
    use crate::protocols::openai::completions::{CompletionChoice, CompletionResponse};

    use super::*;
    use futures::stream;

    /// Builds an annotated delta carrying exactly one choice with the given
    /// index, text and finish reason; all other fields use fixed test values.
    fn create_test_delta(
        index: u64,
        text: &str,
        finish_reason: Option<String>,
    ) -> Annotated<CompletionResponse> {
        Annotated {
            data: Some(CompletionResponse {
                id: "test_id".to_string(),
                model: "meta/llama-3.1-8b".to_string(),
                created: 1234567890,
                usage: None,
                system_fingerprint: None,
                choices: vec![CompletionChoice {
                    index,
                    text: text.to_string(),
                    finish_reason,
                    logprobs: None,
                }],
                object: "text_completion".to_string(),
            }),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        }
    }

    #[tokio::test]
    async fn test_empty_stream() {
        // Create an empty stream
        let stream: DataStream<Annotated<CompletionResponse>> = Box::pin(stream::empty());

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify that the response is empty and has default values
        assert_eq!(response.id, "");
        assert_eq!(response.model, "");
        assert_eq!(response.created, 0);
        assert!(response.usage.is_none());
        assert!(response.system_fingerprint.is_none());
        assert_eq!(response.choices.len(), 0);
    }

    #[tokio::test]
    async fn test_single_delta() {
        // Create a sample delta
        let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()));

        // Create a stream
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify the response fields
        assert_eq!(response.id, "test_id");
        assert_eq!(response.model, "meta/llama-3.1-8b");
        assert_eq!(response.created, 1234567890);
        assert!(response.usage.is_none());
        assert!(response.system_fingerprint.is_none());
        assert_eq!(response.choices.len(), 1);
        let choice = &response.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello,".to_string());
        assert_eq!(choice.finish_reason, Some("length".to_string()));
        assert!(choice.logprobs.is_none());
    }

    #[tokio::test]
    async fn test_multiple_deltas_same_choice() {
        // Create multiple deltas with the same choice index:
        // the first carries text and no finish reason,
        // the second carries more text together with the finish reason.
        let annotated_delta1 = create_test_delta(0, "Hello,", None);
        let annotated_delta2 = create_test_delta(0, " world!", Some("stop".to_string()));

        // Create a stream
        let annotated_deltas = vec![annotated_delta1, annotated_delta2];
        let stream = Box::pin(stream::iter(annotated_deltas));

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify the response fields
        assert_eq!(response.choices.len(), 1);
        let choice = &response.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello, world!".to_string());
        assert_eq!(choice.finish_reason, Some("stop".to_string()));
    }

    #[tokio::test]
    async fn test_multiple_choices() {
        // Create a delta with multiple choices
        let annotated_delta = Annotated {
            data: Some(CompletionResponse {
                id: "test_id".to_string(),
                model: "meta/llama-3.1-8b".to_string(),
                created: 1234567890,
                usage: None,
                system_fingerprint: None,
                choices: vec![
                    CompletionChoice {
                        index: 0,
                        text: "Choice 0".to_string(),
                        finish_reason: Some("stop".to_string()),
                        logprobs: None,
                    },
                    CompletionChoice {
                        index: 1,
                        text: "Choice 1".to_string(),
                        finish_reason: Some("stop".to_string()),
                        logprobs: None,
                    },
                ],
                object: "text_completion".to_string(),
            }),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        };

        // Create a stream
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify the response fields. The choices are deliberately NOT
        // re-sorted here: `apply` sorts them by index, and asserting on the
        // positions directly is what checks that ordering.
        assert_eq!(response.choices.len(), 2);
        let choice0 = &response.choices[0];
        assert_eq!(choice0.index, 0);
        assert_eq!(choice0.text, "Choice 0".to_string());
        assert_eq!(choice0.finish_reason, Some("stop".to_string()));

        let choice1 = &response.choices[1];
        assert_eq!(choice1.index, 1);
        assert_eq!(choice1.text, "Choice 1".to_string());
        assert_eq!(choice1.finish_reason, Some("stop".to_string()));
    }
}