dynamo_llm/protocols/openai/completions/
aggregator.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5
6use anyhow::Result;
7use futures::{Stream, StreamExt};
8
9use super::NvCreateCompletionResponse;
10use crate::protocols::{
11    Annotated, DataStream,
12    codec::{Message, SseCodecError},
13    common::FinishReason,
14    convert_sse_stream,
15    openai::ParsingOptions,
16};
17
/// Aggregates a stream of [`NvCreateCompletionResponse`] deltas into a single
/// [`NvCreateCompletionResponse`].
///
/// The aggregator is the accumulator state folded over the stream by
/// `DeltaAggregator::apply`.
pub struct DeltaAggregator {
    /// Response id; overwritten by every data-carrying delta, so the last one wins.
    id: String,
    /// Model name; overwritten by every data-carrying delta, so the last one wins.
    model: String,
    /// Creation timestamp; overwritten by every data-carrying delta, so the last one wins.
    created: u32,
    /// Usage statistics; replaced whenever a delta carries a `usage` value.
    usage: Option<dynamo_async_openai::types::CompletionUsage>,
    /// System fingerprint; replaced whenever a delta carries one.
    system_fingerprint: Option<String>,
    /// Per-choice accumulated state, keyed by the choice index.
    choices: HashMap<u32, DeltaChoice>,
    /// Error message recorded when a stream item is rejected by `Annotated::ok`;
    /// when set, `apply` returns it as an error instead of a response.
    error: Option<String>,
}
28
/// Accumulated state for one completion choice while its deltas are being aggregated.
struct DeltaChoice {
    /// Index of the choice within the response.
    index: u32,
    /// Text fragments of this choice concatenated in arrival order.
    text: String,
    /// Finish reason reported for this choice, if any.
    finish_reason: Option<FinishReason>,
    /// Accumulated log-probability data; `None` until a delta carries logprobs for this choice.
    logprobs: Option<dynamo_async_openai::types::Logprobs>,
}
35
36impl Default for DeltaAggregator {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl DeltaAggregator {
43    pub fn new() -> Self {
44        Self {
45            id: "".to_string(),
46            model: "".to_string(),
47            created: 0,
48            usage: None,
49            system_fingerprint: None,
50            choices: HashMap::new(),
51            error: None,
52        }
53    }
54
55    /// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
56    pub async fn apply(
57        stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
58        parsing_options: ParsingOptions,
59    ) -> Result<NvCreateCompletionResponse> {
60        tracing::debug!("Tool Call Parser: {:?}", parsing_options.tool_call_parser); // TODO: remove this once completion has tool call support
61        let aggregator = stream
62            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
63                let delta = match delta.ok() {
64                    Ok(delta) => delta,
65                    Err(error) => {
66                        aggregator.error = Some(error);
67                        return aggregator;
68                    }
69                };
70
71                if aggregator.error.is_none() && delta.data.is_some() {
72                    // note: we could extract annotations here and add them to the aggregator
73                    // to be return as part of the NIM Response Extension
74                    // TODO(#14) - Aggregate Annotation
75
76                    // these are cheap to move so we do it every time since we are consuming the delta
77                    let delta = delta.data.unwrap();
78                    aggregator.id = delta.inner.id;
79                    aggregator.model = delta.inner.model;
80                    aggregator.created = delta.inner.created;
81                    if let Some(usage) = delta.inner.usage {
82                        aggregator.usage = Some(usage);
83                    }
84                    if let Some(system_fingerprint) = delta.inner.system_fingerprint {
85                        aggregator.system_fingerprint = Some(system_fingerprint);
86                    }
87
88                    // handle the choices
89                    for choice in delta.inner.choices {
90                        let state_choice =
91                            aggregator
92                                .choices
93                                .entry(choice.index)
94                                .or_insert(DeltaChoice {
95                                    index: choice.index,
96                                    text: "".to_string(),
97                                    finish_reason: None,
98                                    logprobs: None,
99                                });
100
101                        state_choice.text.push_str(&choice.text);
102
103                        // TODO - handle logprobs
104
105                        // Handle CompletionFinishReason -> FinishReason conversation
106                        state_choice.finish_reason = match choice.finish_reason {
107                            Some(dynamo_async_openai::types::CompletionFinishReason::Stop) => {
108                                Some(FinishReason::Stop)
109                            }
110                            Some(dynamo_async_openai::types::CompletionFinishReason::Length) => {
111                                Some(FinishReason::Length)
112                            }
113                            Some(
114                                dynamo_async_openai::types::CompletionFinishReason::ContentFilter,
115                            ) => Some(FinishReason::ContentFilter),
116                            None => None,
117                        };
118
119                        // Update logprobs
120                        if let Some(logprobs) = &choice.logprobs {
121                            let state_lps = state_choice.logprobs.get_or_insert(
122                                dynamo_async_openai::types::Logprobs {
123                                    tokens: Vec::new(),
124                                    token_logprobs: Vec::new(),
125                                    top_logprobs: Vec::new(),
126                                    text_offset: Vec::new(),
127                                },
128                            );
129                            state_lps.tokens.extend(logprobs.tokens.clone());
130                            state_lps
131                                .token_logprobs
132                                .extend(logprobs.token_logprobs.clone());
133                            state_lps.top_logprobs.extend(logprobs.top_logprobs.clone());
134                            state_lps.text_offset.extend(logprobs.text_offset.clone());
135                        }
136                    }
137                }
138                aggregator
139            })
140            .await;
141
142        // If we have an error, return it
143        let aggregator = if let Some(error) = aggregator.error {
144            return Err(anyhow::anyhow!(error));
145        } else {
146            aggregator
147        };
148
149        // extra the aggregated deltas and sort by index
150        let mut choices: Vec<_> = aggregator
151            .choices
152            .into_values()
153            .map(dynamo_async_openai::types::Choice::from)
154            .collect();
155
156        choices.sort_by(|a, b| a.index.cmp(&b.index));
157
158        let inner = dynamo_async_openai::types::CreateCompletionResponse {
159            id: aggregator.id,
160            created: aggregator.created,
161            usage: aggregator.usage,
162            model: aggregator.model,
163            object: "text_completion".to_string(),
164            system_fingerprint: aggregator.system_fingerprint,
165            choices,
166        };
167
168        let response = NvCreateCompletionResponse { inner };
169
170        Ok(response)
171    }
172}
173
174impl From<DeltaChoice> for dynamo_async_openai::types::Choice {
175    fn from(delta: DeltaChoice) -> Self {
176        let finish_reason = delta.finish_reason.map(Into::into);
177
178        dynamo_async_openai::types::Choice {
179            index: delta.index,
180            text: delta.text,
181            finish_reason,
182            logprobs: delta.logprobs,
183        }
184    }
185}
186
187impl NvCreateCompletionResponse {
188    pub async fn from_sse_stream(
189        stream: DataStream<Result<Message, SseCodecError>>,
190        parsing_options: ParsingOptions,
191    ) -> Result<NvCreateCompletionResponse> {
192        let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
193        NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options).await
194    }
195
196    pub async fn from_annotated_stream(
197        stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
198        parsing_options: ParsingOptions,
199    ) -> Result<NvCreateCompletionResponse> {
200        DeltaAggregator::apply(stream, parsing_options).await
201    }
202}
203
204#[cfg(test)]
205mod tests {
206    use std::str::FromStr;
207
208    use futures::stream;
209
210    use super::*;
211    use crate::protocols::openai::completions::NvCreateCompletionResponse;
212
213    fn create_test_delta(
214        index: u32,
215        text: &str,
216        finish_reason: Option<String>,
217        logprob: Option<f32>,
218    ) -> Annotated<NvCreateCompletionResponse> {
219        // This will silently discard invalid_finish reason values and fall back
220        // to None - totally fine since this is test code
221        let finish_reason = finish_reason
222            .as_deref()
223            .and_then(|s| FinishReason::from_str(s).ok())
224            .map(Into::into);
225
226        let logprobs = logprob.map(|lp| dynamo_async_openai::types::Logprobs {
227            tokens: vec![text.to_string()],
228            token_logprobs: vec![Some(lp)],
229            top_logprobs: vec![
230                serde_json::to_value(dynamo_async_openai::types::TopLogprobs {
231                    token: text.to_string(),
232                    logprob: lp,
233                    bytes: None,
234                })
235                .unwrap(),
236            ],
237            text_offset: vec![0],
238        });
239
240        let inner = dynamo_async_openai::types::CreateCompletionResponse {
241            id: "test_id".to_string(),
242            model: "meta/llama-3.1-8b".to_string(),
243            created: 1234567890,
244            usage: None,
245            system_fingerprint: None,
246            choices: vec![dynamo_async_openai::types::Choice {
247                index,
248                text: text.to_string(),
249                finish_reason,
250                logprobs,
251            }],
252            object: "text_completion".to_string(),
253        };
254
255        let response = NvCreateCompletionResponse { inner };
256
257        Annotated {
258            data: Some(response),
259            id: Some("test_id".to_string()),
260            event: None,
261            comment: None,
262        }
263    }
264
265    #[tokio::test]
266    async fn test_empty_stream() {
267        // Create an empty stream
268        let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());
269
270        // Call DeltaAggregator::apply
271        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
272
273        // Check the result
274        assert!(result.is_ok());
275        let response = result.unwrap();
276
277        // Verify that the response is empty and has default values
278        assert_eq!(response.inner.id, "");
279        assert_eq!(response.inner.model, "");
280        assert_eq!(response.inner.created, 0);
281        assert!(response.inner.usage.is_none());
282        assert!(response.inner.system_fingerprint.is_none());
283        assert_eq!(response.inner.choices.len(), 0);
284    }
285
286    #[tokio::test]
287    async fn test_single_delta() {
288        // Create a sample delta
289        let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()), None);
290
291        // Create a stream
292        let stream = Box::pin(stream::iter(vec![annotated_delta]));
293
294        // Call DeltaAggregator::apply
295        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
296
297        // Check the result
298        assert!(result.is_ok());
299        let response = result.unwrap();
300
301        // Verify the response fields
302        assert_eq!(response.inner.id, "test_id");
303        assert_eq!(response.inner.model, "meta/llama-3.1-8b");
304        assert_eq!(response.inner.created, 1234567890);
305        assert!(response.inner.usage.is_none());
306        assert!(response.inner.system_fingerprint.is_none());
307        assert_eq!(response.inner.choices.len(), 1);
308        let choice = &response.inner.choices[0];
309        assert_eq!(choice.index, 0);
310        assert_eq!(choice.text, "Hello,".to_string());
311        assert_eq!(
312            choice.finish_reason,
313            Some(dynamo_async_openai::types::CompletionFinishReason::Length)
314        );
315        assert_eq!(
316            choice.finish_reason,
317            Some(dynamo_async_openai::types::CompletionFinishReason::Length)
318        );
319        assert!(choice.logprobs.is_none());
320    }
321
322    #[tokio::test]
323    async fn test_multiple_deltas_same_choice() {
324        // Create multiple deltas with the same choice index
325        // One will have a MessageRole and no FinishReason,
326        // the other will have a FinishReason and no MessageRole
327        let annotated_delta1 = create_test_delta(0, "Hello,", None, Some(-0.1));
328        let annotated_delta2 =
329            create_test_delta(0, " world!", Some("stop".to_string()), Some(-0.2));
330
331        // Create a stream
332        let annotated_deltas = vec![annotated_delta1, annotated_delta2];
333        let stream = Box::pin(stream::iter(annotated_deltas));
334
335        // Call DeltaAggregator::apply
336        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
337
338        // Check the result
339        assert!(result.is_ok());
340        let response = result.unwrap();
341
342        // Verify the response fields
343        assert_eq!(response.inner.choices.len(), 1);
344        let choice = &response.inner.choices[0];
345        assert_eq!(choice.index, 0);
346        assert_eq!(choice.text, "Hello, world!".to_string());
347        assert_eq!(
348            choice.finish_reason,
349            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
350        );
351        assert_eq!(choice.logprobs.as_ref().unwrap().tokens.len(), 2);
352        assert_eq!(
353            choice.logprobs.as_ref().unwrap().token_logprobs,
354            vec![Some(-0.1), Some(-0.2)]
355        );
356    }
357
358    #[tokio::test]
359    async fn test_multiple_choices() {
360        // Create a delta with multiple choices
361        let inner = dynamo_async_openai::types::CreateCompletionResponse {
362            id: "test_id".to_string(),
363            model: "meta/llama-3.1-8b".to_string(),
364            created: 1234567890,
365            usage: None,
366            system_fingerprint: None,
367            choices: vec![
368                dynamo_async_openai::types::Choice {
369                    index: 0,
370                    text: "Choice 0".to_string(),
371                    finish_reason: Some(dynamo_async_openai::types::CompletionFinishReason::Stop),
372                    logprobs: None,
373                },
374                dynamo_async_openai::types::Choice {
375                    index: 1,
376                    text: "Choice 1".to_string(),
377                    finish_reason: Some(dynamo_async_openai::types::CompletionFinishReason::Stop),
378                    logprobs: None,
379                },
380            ],
381            object: "text_completion".to_string(),
382        };
383
384        let response = NvCreateCompletionResponse { inner };
385
386        let annotated_delta = Annotated {
387            data: Some(response),
388            id: Some("test_id".to_string()),
389            event: None,
390            comment: None,
391        };
392
393        // Create a stream
394        let stream = Box::pin(stream::iter(vec![annotated_delta]));
395
396        // Call DeltaAggregator::apply
397        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
398
399        // Check the result
400        assert!(result.is_ok());
401        let mut response = result.unwrap();
402
403        // Verify the response fields
404        assert_eq!(response.inner.choices.len(), 2);
405        response.inner.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
406        let choice0 = &response.inner.choices[0];
407        assert_eq!(choice0.index, 0);
408        assert_eq!(choice0.text, "Choice 0".to_string());
409        assert_eq!(
410            choice0.finish_reason,
411            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
412        );
413        assert_eq!(
414            choice0.finish_reason,
415            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
416        );
417
418        let choice1 = &response.inner.choices[1];
419        assert_eq!(choice1.index, 1);
420        assert_eq!(choice1.text, "Choice 1".to_string());
421        assert_eq!(
422            choice1.finish_reason,
423            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
424        );
425        assert_eq!(
426            choice1.finish_reason,
427            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
428        );
429    }
430}