dynamo_llm/protocols/openai/chat_completions/
aggregator.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use futures::{Stream, StreamExt};
5use std::collections::HashMap;
6
7use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
8use crate::protocols::{
9    Annotated,
10    codec::{Message, SseCodecError},
11    convert_sse_stream,
12    openai::ParsingOptions,
13};
14
15use dynamo_runtime::engine::DataStream;
16
17/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
18/// [`NvCreateChatCompletionResponse`]. This struct accumulates incremental responses
19/// from a streaming OpenAI API call into a complete final response.
20pub struct DeltaAggregator {
21    /// Unique identifier for the chat completion.
22    id: String,
23    /// Model name used for the chat completion.
24    model: String,
25    /// Timestamp (Unix epoch) indicating when the response was created.
26    created: u32,
27    /// Optional usage statistics for the completion request.
28    usage: Option<dynamo_async_openai::types::CompletionUsage>,
29    /// Optional system fingerprint for version tracking.
30    system_fingerprint: Option<String>,
31    /// Map of incremental response choices, keyed by index.
32    choices: HashMap<u32, DeltaChoice>,
33    /// Optional error message if an error occurs during aggregation.
34    error: Option<String>,
35    /// Optional service tier information for the response.
36    service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
37}
38
39/// Represents the accumulated state of a single chat choice during streaming aggregation.
40#[derive(Debug)]
41struct DeltaChoice {
42    /// The index of the choice in the completion.
43    index: u32,
44    /// The accumulated text content for the choice.
45    text: String,
46    /// The role associated with this message (e.g., `system`, `user`, `assistant`).
47    role: Option<dynamo_async_openai::types::Role>,
48    /// The reason the completion was finished (if applicable).
49    finish_reason: Option<dynamo_async_openai::types::FinishReason>,
50    /// Optional log probabilities for the chat choice.
51    logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
52    // Optional tool calls for the chat choice.
53    tool_calls: Option<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>>,
54
55    /// Optional reasoning content for the chat choice.
56    reasoning_content: Option<String>,
57}
58
59impl Default for DeltaAggregator {
60    /// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66fn convert_tool_chunk_to_message_tool_call(
67    chunk: &dynamo_async_openai::types::ChatCompletionMessageToolCallChunk,
68) -> Option<dynamo_async_openai::types::ChatCompletionMessageToolCall> {
69    // Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall
70    if let (Some(id), Some(r#type), Some(function)) = (&chunk.id, &chunk.r#type, &chunk.function) {
71        if let (Some(name), Some(arguments)) = (&function.name, &function.arguments) {
72            Some(dynamo_async_openai::types::ChatCompletionMessageToolCall {
73                id: id.clone(),
74                r#type: r#type.clone(),
75                function: dynamo_async_openai::types::FunctionCall {
76                    name: name.clone(),
77                    arguments: arguments.clone(),
78                },
79            })
80        } else {
81            None
82        }
83    } else {
84        None
85    }
86}
87
88impl DeltaAggregator {
89    /// Creates a new, empty [`DeltaAggregator`] instance.
90    pub fn new() -> Self {
91        Self {
92            id: "".to_string(),
93            model: "".to_string(),
94            created: 0,
95            usage: None,
96            system_fingerprint: None,
97            choices: HashMap::new(),
98            error: None,
99            service_tier: None,
100        }
101    }
102
103    /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
104    /// [`NvCreateChatCompletionResponse`].
105    ///
106    /// # Arguments
107    /// * `stream` - A stream of annotated chat completion responses.
108    ///
109    /// # Returns
110    /// * `Ok(NvCreateChatCompletionResponse)` if aggregation is successful.
111    /// * `Err(String)` if an error occurs during processing.
112    pub async fn apply(
113        stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
114        _parsing_options: ParsingOptions,
115    ) -> Result<NvCreateChatCompletionResponse, String> {
116        let aggregator = stream
117            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
118                // Attempt to unwrap the delta, capturing any errors.
119                let delta = match delta.ok() {
120                    Ok(delta) => delta,
121                    Err(error) => {
122                        aggregator.error = Some(error);
123                        return aggregator;
124                    }
125                };
126
127                if aggregator.error.is_none() && delta.data.is_some() {
128                    // Extract the data payload from the delta.
129                    let delta = delta.data.unwrap();
130                    aggregator.id = delta.id;
131                    aggregator.model = delta.model;
132                    aggregator.created = delta.created;
133                    aggregator.service_tier = delta.service_tier;
134
135                    // Aggregate usage statistics if available.
136                    if let Some(usage) = delta.usage {
137                        aggregator.usage = Some(usage);
138                    }
139                    if let Some(system_fingerprint) = delta.system_fingerprint {
140                        aggregator.system_fingerprint = Some(system_fingerprint);
141                    }
142
143                    // Aggregate choices incrementally.
144                    for choice in delta.choices {
145                        let state_choice =
146                            aggregator
147                                .choices
148                                .entry(choice.index)
149                                .or_insert(DeltaChoice {
150                                    index: choice.index,
151                                    text: "".to_string(),
152                                    role: choice.delta.role,
153                                    finish_reason: None,
154                                    logprobs: None,
155                                    tool_calls: None,
156                                    reasoning_content: None,
157                                });
158                        // Append content if available.
159                        if let Some(content) = &choice.delta.content {
160                            state_choice.text.push_str(content.trim_end());
161                        }
162
163                        if let Some(reasoning_content) = &choice.delta.reasoning_content {
164                            state_choice
165                                .reasoning_content
166                                .get_or_insert_with(String::new)
167                                .push_str(reasoning_content);
168                        }
169
170                        // Since one tool call is one chunk, we don't need to aggregate them
171                        // We just need to convert the ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall and append to the state_choice.tool_calls
172                        if let Some(tool_calls) = &choice.delta.tool_calls
173                            && !tool_calls.is_empty()
174                        {
175                            // Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall
176                            let converted_tool_calls: Vec<
177                                dynamo_async_openai::types::ChatCompletionMessageToolCall,
178                            > = tool_calls
179                                .iter()
180                                .filter_map(convert_tool_chunk_to_message_tool_call)
181                                .collect();
182
183                            // Initialize and push the converted tool calls to state_choice.tool_calls
184                            // Only set tool_calls to Some if there are actual tool calls
185                            if !converted_tool_calls.is_empty() {
186                                if let Some(existing_tool_calls) = &mut state_choice.tool_calls {
187                                    existing_tool_calls.extend(converted_tool_calls);
188                                } else {
189                                    state_choice.tool_calls = Some(converted_tool_calls);
190                                }
191                            }
192                        }
193
194                        // Update finish reason if provided.
195                        if let Some(finish_reason) = choice.finish_reason {
196                            state_choice.finish_reason = Some(finish_reason);
197                        }
198
199                        // Update logprobs
200                        if let Some(logprobs) = &choice.logprobs {
201                            let state_lps = state_choice.logprobs.get_or_insert(
202                                dynamo_async_openai::types::ChatChoiceLogprobs {
203                                    content: None,
204                                    refusal: None,
205                                },
206                            );
207                            if let Some(content_lps) = &logprobs.content {
208                                state_lps
209                                    .content
210                                    .get_or_insert(Vec::new())
211                                    .extend(content_lps.clone());
212                            }
213                            if let Some(refusal_lps) = &logprobs.refusal {
214                                state_lps
215                                    .refusal
216                                    .get_or_insert(Vec::new())
217                                    .extend(refusal_lps.clone());
218                            }
219                        }
220                    }
221                }
222                aggregator
223            })
224            .await;
225
226        // Return early if an error was encountered.
227        if let Some(error) = aggregator.error {
228            return Err(error);
229        }
230
231        // Extract aggregated choices and sort them by index.
232        let mut choices: Vec<_> = aggregator
233            .choices
234            .into_values()
235            .map(dynamo_async_openai::types::ChatChoice::from)
236            .collect();
237
238        choices.sort_by(|a, b| a.index.cmp(&b.index));
239
240        // Construct the final response object.
241        let response = NvCreateChatCompletionResponse {
242            id: aggregator.id,
243            created: aggregator.created,
244            usage: aggregator.usage,
245            model: aggregator.model,
246            object: "chat.completion".to_string(),
247            system_fingerprint: aggregator.system_fingerprint,
248            choices,
249            service_tier: aggregator.service_tier,
250        };
251
252        Ok(response)
253    }
254}
255
256#[allow(deprecated)]
257impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
258    /// Converts a [`DeltaChoice`] into an [`dynamo_async_openai::types::ChatChoice`].
259    ///
260    /// # Note
261    /// The `function_call` field is deprecated.
262    fn from(delta: DeltaChoice) -> Self {
263        // If tool calls are present and non-empty, finish reason should be ToolCalls
264        let finish_reason = if delta
265            .tool_calls
266            .as_ref()
267            .is_some_and(|calls| !calls.is_empty())
268        {
269            Some(dynamo_async_openai::types::FinishReason::ToolCalls)
270        } else {
271            delta.finish_reason
272        };
273
274        dynamo_async_openai::types::ChatChoice {
275            message: dynamo_async_openai::types::ChatCompletionResponseMessage {
276                role: delta.role.expect("delta should have a Role"),
277                content: if delta.text.is_empty() {
278                    None
279                } else {
280                    Some(delta.text)
281                },
282                tool_calls: delta.tool_calls,
283                refusal: None,
284                function_call: None,
285                audio: None,
286                reasoning_content: delta.reasoning_content,
287            },
288            index: delta.index,
289            finish_reason,
290            logprobs: delta.logprobs,
291        }
292    }
293}
294
295/// Trait for aggregating chat completion responses from streams.
296/// Setting this macro because our async functions are not used outside of the library
297#[allow(async_fn_in_trait)]
298pub trait ChatCompletionAggregator {
299    /// Aggregates an annotated stream of chat completion responses into a final response.
300    ///
301    /// # Arguments
302    /// * `stream` - A stream of annotated chat completion responses.
303    ///
304    /// # Returns
305    /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
306    /// * `Err(String)` if an error occurs.
307    async fn from_annotated_stream(
308        stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
309        parsing_options: ParsingOptions,
310    ) -> Result<NvCreateChatCompletionResponse, String>;
311
312    /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
313    ///
314    /// # Arguments
315    /// * `stream` - A stream of SSE messages containing chat completion responses.
316    ///
317    /// # Returns
318    /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
319    /// * `Err(String)` if an error occurs.
320    async fn from_sse_stream(
321        stream: DataStream<Result<Message, SseCodecError>>,
322        parsing_options: ParsingOptions,
323    ) -> Result<NvCreateChatCompletionResponse, String>;
324}
325
326impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse {
327    async fn from_annotated_stream(
328        stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
329        parsing_options: ParsingOptions,
330    ) -> Result<NvCreateChatCompletionResponse, String> {
331        DeltaAggregator::apply(stream, parsing_options).await
332    }
333
334    async fn from_sse_stream(
335        stream: DataStream<Result<Message, SseCodecError>>,
336        parsing_options: ParsingOptions,
337    ) -> Result<NvCreateChatCompletionResponse, String> {
338        let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
339        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options).await
340    }
341}
342
343#[cfg(test)]
344mod tests {
345
346    use super::*;
347    use futures::stream;
348
349    #[allow(deprecated)]
350    fn create_test_delta(
351        index: u32,
352        text: &str,
353        role: Option<dynamo_async_openai::types::Role>,
354        finish_reason: Option<dynamo_async_openai::types::FinishReason>,
355        logprob: Option<f32>,
356        tool_calls: Option<&str>,
357    ) -> Annotated<NvCreateChatCompletionStreamResponse> {
358        // ALLOW: function_call is deprecated
359
360        let tool_calls: Option<serde_json::Value> =
361            tool_calls.map(|tool_calls| serde_json::from_str(tool_calls).unwrap());
362
363        let tool_call_chunks = if let Some(tool_calls) = tool_calls {
364            Some(vec![
365                dynamo_async_openai::types::ChatCompletionMessageToolCallChunk {
366                    index: 0,
367                    id: Some("test_id".to_string()),
368                    r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function),
369                    function: Some(dynamo_async_openai::types::FunctionCallStream {
370                        name: tool_calls["name"].as_str().map(|s| s.to_string()),
371                        arguments: Some(serde_json::to_string(&tool_calls["arguments"]).unwrap()),
372                    }),
373                },
374            ])
375        } else {
376            None
377        };
378
379        let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
380            content: Some(text.to_string()),
381            function_call: None,
382            tool_calls: tool_call_chunks,
383            role,
384            refusal: None,
385            reasoning_content: None,
386        };
387        let logprobs = logprob.map(|lp| dynamo_async_openai::types::ChatChoiceLogprobs {
388            content: Some(vec![
389                dynamo_async_openai::types::ChatCompletionTokenLogprob {
390                    token: text.to_string(),
391                    logprob: lp,
392                    bytes: None,
393                    top_logprobs: vec![],
394                },
395            ]),
396            refusal: None,
397        });
398        let choice = dynamo_async_openai::types::ChatChoiceStream {
399            index,
400            delta,
401            finish_reason,
402            logprobs,
403        };
404
405        let data = NvCreateChatCompletionStreamResponse {
406            id: "test_id".to_string(),
407            model: "meta/llama-3.1-8b-instruct".to_string(),
408            created: 1234567890,
409            service_tier: None,
410            usage: None,
411            system_fingerprint: None,
412            choices: vec![choice],
413            object: "chat.completion".to_string(),
414        };
415
416        Annotated {
417            data: Some(data),
418            id: Some("test_id".to_string()),
419            event: None,
420            comment: None,
421        }
422    }
423
424    #[tokio::test]
425    async fn test_empty_stream() {
426        // Create an empty stream
427        let stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>> =
428            Box::pin(stream::empty());
429
430        // Call DeltaAggregator::apply
431        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
432
433        // Check the result
434        assert!(result.is_ok());
435        let response = result.unwrap();
436
437        // Verify that the response is empty and has default values
438        assert_eq!(response.id, "");
439        assert_eq!(response.model, "");
440        assert_eq!(response.created, 0);
441        assert!(response.usage.is_none());
442        assert!(response.system_fingerprint.is_none());
443        assert_eq!(response.choices.len(), 0);
444        assert!(response.service_tier.is_none());
445    }
446
447    #[tokio::test]
448    async fn test_single_delta() {
449        // Create a sample delta
450        let annotated_delta = create_test_delta(
451            0,
452            "Hello,",
453            Some(dynamo_async_openai::types::Role::User),
454            None,
455            None,
456            None,
457        );
458
459        // Create a stream
460        let stream = Box::pin(stream::iter(vec![annotated_delta]));
461
462        // Call DeltaAggregator::apply
463        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
464
465        // Check the result
466        assert!(result.is_ok());
467        let response = result.unwrap();
468
469        // Verify the response fields
470        assert_eq!(response.id, "test_id");
471        assert_eq!(response.model, "meta/llama-3.1-8b-instruct");
472        assert_eq!(response.created, 1234567890);
473        assert!(response.usage.is_none());
474        assert!(response.system_fingerprint.is_none());
475        assert_eq!(response.choices.len(), 1);
476        let choice = &response.choices[0];
477        assert_eq!(choice.index, 0);
478        assert_eq!(choice.message.content.as_ref().unwrap(), "Hello,");
479        assert!(choice.finish_reason.is_none());
480        assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User);
481        assert!(response.service_tier.is_none());
482    }
483
484    #[tokio::test]
485    async fn test_multiple_deltas_same_choice() {
486        // Create multiple deltas with the same choice index
487        // One will have a MessageRole and no FinishReason,
488        // the other will have a FinishReason and no MessageRole
489        let annotated_delta1 = create_test_delta(
490            0,
491            "Hello,",
492            Some(dynamo_async_openai::types::Role::User),
493            None,
494            Some(-0.1),
495            None,
496        );
497        let annotated_delta2 = create_test_delta(
498            0,
499            " world!",
500            None,
501            Some(dynamo_async_openai::types::FinishReason::Stop),
502            Some(-0.2),
503            None,
504        );
505
506        // Create a stream
507        let annotated_deltas = vec![annotated_delta1, annotated_delta2];
508        let stream = Box::pin(stream::iter(annotated_deltas));
509
510        // Call DeltaAggregator::apply
511        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
512
513        // Check the result
514        assert!(result.is_ok());
515        let response = result.unwrap();
516
517        // Verify the response fields
518        assert_eq!(response.choices.len(), 1);
519        let choice = &response.choices[0];
520        assert_eq!(choice.index, 0);
521        assert_eq!(choice.message.content.as_ref().unwrap(), "Hello, world!");
522        assert_eq!(
523            choice.finish_reason,
524            Some(dynamo_async_openai::types::FinishReason::Stop)
525        );
526        assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User);
527        assert_eq!(
528            choice
529                .logprobs
530                .as_ref()
531                .unwrap()
532                .content
533                .as_ref()
534                .unwrap()
535                .len(),
536            2
537        );
538        assert_eq!(
539            choice.logprobs.as_ref().unwrap().content.as_ref().unwrap()[0].logprob,
540            -0.1
541        );
542        assert_eq!(
543            choice.logprobs.as_ref().unwrap().content.as_ref().unwrap()[1].logprob,
544            -0.2
545        );
546    }
547
548    #[allow(deprecated)]
549    #[tokio::test]
550    async fn test_multiple_choices() {
551        // Create a delta with multiple choices
552        // ALLOW: function_call is deprecated
553        let data = NvCreateChatCompletionStreamResponse {
554            id: "test_id".to_string(),
555            model: "test_model".to_string(),
556            created: 1234567890,
557            service_tier: None,
558            usage: None,
559            system_fingerprint: None,
560            choices: vec![
561                dynamo_async_openai::types::ChatChoiceStream {
562                    index: 0,
563                    delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
564                        role: Some(dynamo_async_openai::types::Role::Assistant),
565                        content: Some("Choice 0".to_string()),
566                        function_call: None,
567                        tool_calls: None,
568                        refusal: None,
569                        reasoning_content: None,
570                    },
571                    finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop),
572                    logprobs: None,
573                },
574                dynamo_async_openai::types::ChatChoiceStream {
575                    index: 1,
576                    delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
577                        role: Some(dynamo_async_openai::types::Role::Assistant),
578                        content: Some("Choice 1".to_string()),
579                        function_call: None,
580                        tool_calls: None,
581                        refusal: None,
582                        reasoning_content: None,
583                    },
584                    finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop),
585                    logprobs: None,
586                },
587            ],
588            object: "chat.completion".to_string(),
589        };
590
591        // Wrap it in Annotated and create a stream
592        let annotated_delta = Annotated {
593            data: Some(data),
594            id: Some("test_id".to_string()),
595            event: None,
596            comment: None,
597        };
598        let stream = Box::pin(stream::iter(vec![annotated_delta]));
599
600        // Call DeltaAggregator::apply
601        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
602
603        // Check the result
604        assert!(result.is_ok());
605        let mut response = result.unwrap();
606
607        // Verify the response fields
608        assert_eq!(response.choices.len(), 2);
609        response.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
610        let choice0 = &response.choices[0];
611        assert_eq!(choice0.index, 0);
612        assert_eq!(choice0.message.content.as_ref().unwrap(), "Choice 0");
613        assert_eq!(
614            choice0.finish_reason,
615            Some(dynamo_async_openai::types::FinishReason::Stop)
616        );
617        assert_eq!(
618            choice0.message.role,
619            dynamo_async_openai::types::Role::Assistant
620        );
621
622        let choice1 = &response.choices[1];
623        assert_eq!(choice1.index, 1);
624        assert_eq!(choice1.message.content.as_ref().unwrap(), "Choice 1");
625        assert_eq!(
626            choice1.finish_reason,
627            Some(dynamo_async_openai::types::FinishReason::Stop)
628        );
629        assert_eq!(
630            choice1.message.role,
631            dynamo_async_openai::types::Role::Assistant
632        );
633    }
634
635    #[tokio::test]
636    async fn test_tool_calling_finish_reason_override_from_stop() {
637        // Test that when tool calls are present but finish reason is Stop, it gets overridden to ToolCalls
638        let tool_call_json =
639            r#"{"name": "get_weather", "arguments": {"location": "New York", "unit": "celsius"}}"#;
640
641        let annotated_delta = create_test_delta(
642            0,
643            "I'll check the weather for you.",
644            Some(dynamo_async_openai::types::Role::Assistant),
645            Some(dynamo_async_openai::types::FinishReason::Stop), // Original finish reason is Stop
646            None,
647            Some(tool_call_json),
648        );
649
650        let data = annotated_delta.data.unwrap();
651        let annotated_delta = Annotated {
652            data: Some(data),
653            id: Some("test_id".to_string()),
654            event: None,
655            comment: None,
656        };
657        let stream = Box::pin(stream::iter(vec![annotated_delta]));
658
659        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
660
661        assert!(result.is_ok());
662        let response = result.unwrap();
663        assert_eq!(response.choices.len(), 1);
664        let choice = &response.choices[0];
665
666        // Verify tool calls are present
667        assert!(choice.message.tool_calls.is_some());
668        let tool_calls = choice.message.tool_calls.as_ref().unwrap();
669        assert_eq!(tool_calls.len(), 1);
670
671        // Most importantly, verify that finish reason was overridden to ToolCalls despite original being Stop
672        assert_eq!(
673            choice.finish_reason,
674            Some(dynamo_async_openai::types::FinishReason::ToolCalls)
675        );
676    }
677
678    #[tokio::test]
679    async fn test_tool_calling_finish_reason_override_from_length() {
680        // Test that when tool calls are present but finish reason is Length, it gets overridden to ToolCalls
681        let tool_call_json = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#;
682
683        let annotated_delta = create_test_delta(
684            0,
685            "Let me search for that.",
686            Some(dynamo_async_openai::types::Role::Assistant),
687            Some(dynamo_async_openai::types::FinishReason::Length), // Original finish reason is Length
688            None,
689            Some(tool_call_json),
690        );
691
692        let data = annotated_delta.data.unwrap();
693        let annotated_delta = Annotated {
694            data: Some(data),
695            id: Some("test_id".to_string()),
696            event: None,
697            comment: None,
698        };
699        let stream = Box::pin(stream::iter(vec![annotated_delta]));
700
701        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
702
703        assert!(result.is_ok());
704        let response = result.unwrap();
705        assert_eq!(response.choices.len(), 1);
706        let choice = &response.choices[0];
707
708        // Verify tool calls are present
709        assert!(choice.message.tool_calls.is_some());
710        let tool_calls = choice.message.tool_calls.as_ref().unwrap();
711        assert_eq!(tool_calls.len(), 1);
712
713        // Verify that finish reason was overridden to ToolCalls despite original being Length
714        assert_eq!(
715            choice.finish_reason,
716            Some(dynamo_async_openai::types::FinishReason::ToolCalls)
717        );
718    }
719
720    #[tokio::test]
721    async fn test_tool_calling_finish_reason_override_from_none() {
722        // Test that when tool calls are present but finish reason is None, it gets set to ToolCalls
723        let tool_call_json = r#"{"name": "calculate", "arguments": {"expression": "2+2"}}"#;
724
725        let annotated_delta = create_test_delta(
726            0,
727            "I'll calculate that for you.",
728            Some(dynamo_async_openai::types::Role::Assistant),
729            None, // Original finish reason is None
730            None,
731            Some(tool_call_json),
732        );
733
734        let data = annotated_delta.data.unwrap();
735        let annotated_delta = Annotated {
736            data: Some(data),
737            id: Some("test_id".to_string()),
738            event: None,
739            comment: None,
740        };
741        let stream = Box::pin(stream::iter(vec![annotated_delta]));
742
743        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
744
745        assert!(result.is_ok());
746        let response = result.unwrap();
747        assert_eq!(response.choices.len(), 1);
748        let choice = &response.choices[0];
749
750        // Verify tool calls are present
751        assert!(choice.message.tool_calls.is_some());
752        let tool_calls = choice.message.tool_calls.as_ref().unwrap();
753        assert_eq!(tool_calls.len(), 1);
754
755        // Verify that finish reason was set to ToolCalls despite original being None
756        assert_eq!(
757            choice.finish_reason,
758            Some(dynamo_async_openai::types::FinishReason::ToolCalls)
759        );
760    }
761
762    #[tokio::test]
763    async fn test_no_tool_calling_preserves_original_finish_reason() {
764        // Test that when no tool calls are present, the original finish reason is preserved
765        let annotated_delta = create_test_delta(
766            0,
767            "This is a regular response without tool calls.",
768            Some(dynamo_async_openai::types::Role::Assistant),
769            Some(dynamo_async_openai::types::FinishReason::Stop),
770            None,
771            None, // No tool calls
772        );
773
774        let data = annotated_delta.data.unwrap();
775        let annotated_delta = Annotated {
776            data: Some(data),
777            id: Some("test_id".to_string()),
778            event: None,
779            comment: None,
780        };
781        let stream = Box::pin(stream::iter(vec![annotated_delta]));
782
783        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
784
785        assert!(result.is_ok());
786        let response = result.unwrap();
787        assert_eq!(response.choices.len(), 1);
788        let choice = &response.choices[0];
789
790        // Verify no tool calls are present
791        assert!(choice.message.tool_calls.is_none());
792
793        // Verify that original finish reason (Stop) is preserved
794        assert_eq!(
795            choice.finish_reason,
796            Some(dynamo_async_openai::types::FinishReason::Stop)
797        );
798    }
799
800    #[tokio::test]
801    async fn test_empty_tool_calls_preserves_original_finish_reason() {
802        // Test that when tool calls array is empty, the original finish reason is preserved
803        // Create a delta with empty tool calls by modifying the create_test_delta output
804        let mut annotated_delta = create_test_delta(
805            0,
806            "Response with empty tool calls array.",
807            Some(dynamo_async_openai::types::Role::Assistant),
808            Some(dynamo_async_openai::types::FinishReason::Length),
809            None,
810            None,
811        );
812
813        // Manually set empty tool calls array
814        if let Some(ref mut data) = annotated_delta.data {
815            data.choices[0].delta.tool_calls = Some(vec![]); // Empty tool calls array
816        }
817
818        let data = annotated_delta.data.unwrap();
819        let annotated_delta = Annotated {
820            data: Some(data),
821            id: Some("test_id".to_string()),
822            event: None,
823            comment: None,
824        };
825        let stream = Box::pin(stream::iter(vec![annotated_delta]));
826
827        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
828
829        assert!(result.is_ok());
830        let response = result.unwrap();
831        assert_eq!(response.choices.len(), 1);
832        let choice = &response.choices[0];
833
834        // Verify tool calls array is empty
835        assert!(choice.message.tool_calls.is_none());
836
837        // Verify that original finish reason (Length) is preserved since tool calls are empty
838        assert_eq!(
839            choice.finish_reason,
840            Some(dynamo_async_openai::types::FinishReason::Length)
841        );
842    }
843
844    #[tokio::test]
845    async fn test_tool_calling_output() {
846        // Simulate a delta with a tool call in the content
847        let tool_call_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
848
849        // Use create_test_delta to generate the annotated delta, then extract the inner delta for the test
850        let annotated_delta = create_test_delta(
851            0,
852            "Hey Dude ! What's the weather in San Francisco in Fahrenheit?",
853            Some(dynamo_async_openai::types::Role::Assistant),
854            Some(dynamo_async_openai::types::FinishReason::ToolCalls),
855            None,
856            Some(tool_call_json),
857        );
858        let data = annotated_delta.data.unwrap();
859
860        // Wrap it in Annotated and create a stream
861        let annotated_delta = Annotated {
862            data: Some(data),
863            id: Some("test_id".to_string()),
864            event: None,
865            comment: None,
866        };
867        let stream = Box::pin(stream::iter(vec![annotated_delta]));
868
869        // Call DeltaAggregator::apply
870        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
871
872        // Check the result
873        assert!(result.is_ok());
874        let response = result.unwrap();
875
876        // There should be one choice
877        assert_eq!(response.choices.len(), 1);
878        let choice = &response.choices[0];
879
880        // The tool_calls field should be present and parsed
881        assert!(choice.message.tool_calls.is_some());
882        let tool_calls = choice.message.tool_calls.as_ref().unwrap();
883        assert_eq!(tool_calls.len(), 1);
884
885        let tool_call = &tool_calls[0];
886        assert_eq!(tool_call.function.name, "get_weather");
887        // The arguments should be a JSON string containing the expected keys
888        let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments).unwrap();
889        assert_eq!(args["location"], "San Francisco, CA");
890        assert_eq!(args["unit"], "fahrenheit");
891
892        assert_eq!(
893            choice.message.content.as_ref().unwrap(),
894            "Hey Dude ! What's the weather in San Francisco in Fahrenheit?"
895        );
896
897        // The finish_reason should be ToolCalls
898        assert_eq!(
899            choice.finish_reason,
900            Some(dynamo_async_openai::types::FinishReason::ToolCalls)
901        );
902        assert_eq!(
903            choice.message.role,
904            dynamo_async_openai::types::Role::Assistant
905        );
906    }
907
908    #[tokio::test]
909    async fn test_tool_calling_finish_reason_override_from_stop_alternative() {
910        // Test that when tool calls are present but finish reason is Stop, it gets overridden to ToolCalls
911        let tool_call_json =
912            r#"{"name": "get_weather", "arguments": {"location": "New York", "unit": "celsius"}}"#;
913
914        let annotated_delta = create_test_delta(
915            0,
916            "Getting weather for New York",
917            Some(dynamo_async_openai::types::Role::Assistant),
918            Some(dynamo_async_openai::types::FinishReason::Stop), // This should be overridden
919            None,
920            Some(tool_call_json),
921        );
922
923        let stream = Box::pin(stream::iter(vec![annotated_delta]));
924
925        // Call DeltaAggregator::apply
926        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
927
928        // Check the result
929        assert!(result.is_ok());
930        let response = result.unwrap();
931
932        // There should be one choice
933        assert_eq!(response.choices.len(), 1);
934        let choice = &response.choices[0];
935
936        // The finish_reason should be ToolCalls, not Stop, because tool calls are present
937        assert_eq!(
938            choice.finish_reason,
939            Some(dynamo_async_openai::types::FinishReason::ToolCalls)
940        );
941
942        // Verify tool calls are present
943        assert!(choice.message.tool_calls.is_some());
944        let tool_calls = choice.message.tool_calls.as_ref().unwrap();
945        assert_eq!(tool_calls.len(), 1);
946        assert_eq!(tool_calls[0].function.name, "get_weather");
947    }
948}