Skip to main content

artificial_openai/
provider_impl_chat_stream.rs

1use std::pin::Pin;
2
3use crate::OpenAiAdapter;
4use crate::api_v1::ChatCompletionMessage;
5use crate::api_v1::ChatCompletionRequest;
6use crate::api_v1::FinishReason;
7use artificial_core::error::{ArtificialError, Result};
8use artificial_core::generic::{GenericFunctionCall, GenericFunctionCallIntent, StreamEvent};
9use artificial_core::provider::StreamingEventsProvider;
10use artificial_core::provider::{ChatCompleteParameters, StreamingChatProvider};
11use futures_core::stream::Stream;
12use std::collections::HashMap;
13
14impl StreamingChatProvider for OpenAiAdapter {
15    type Message = ChatCompletionMessage;
16
17    type Delta<'s>
18        = Pin<Box<dyn Stream<Item = Result<String>> + Send + 's>>
19    where
20        Self: 's;
21
22    fn chat_complete_stream<'s, M>(&'s self, params: ChatCompleteParameters<M>) -> Self::Delta<'s>
23    where
24        M: Into<Self::Message> + Clone + Send + Sync + 's,
25    {
26        let client = self.client.clone();
27
28        Box::pin(async_stream::try_stream! {
29        use futures_util::StreamExt;
30
31        let request: ChatCompletionRequest = params.try_into()?;
32
33
34            let stream = client.chat_completion_stream(request);
35            futures_util::pin_mut!(stream);
36
37            while let Some(chunk) = stream.next().await {
38                let chunk = chunk.map_err(ArtificialError::from)?;
39                for choice in chunk.choices {
40                    if let Some(text) = choice.delta.content {
41                        yield text;
42                    }
43                }
44            }
45
46        })
47    }
48}
49
50impl StreamingEventsProvider for OpenAiAdapter {
51    type EventStream<'s>
52        = Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + 's>>
53    where
54        Self: 's;
55
56    fn chat_complete_events_stream<'s, M>(
57        &'s self,
58        params: ChatCompleteParameters<M>,
59    ) -> Self::EventStream<'s>
60    where
61        M: Into<Self::Message> + Clone + Send + Sync + 's,
62    {
63        let client = self.client.clone();
64
65        Box::pin(async_stream::try_stream! {
66            use futures_util::StreamExt;
67
68            let request: ChatCompletionRequest = params.try_into()?;
69
70            // Track tool-call argument fragments and first-seen id/name per tool index.
71            let mut tool_args: HashMap<usize, String> = HashMap::new();
72            let mut tool_seen: HashMap<usize, (Option<String>, Option<String>)> = HashMap::new();
73
74            let stream = client.chat_completion_stream(request);
75            futures_util::pin_mut!(stream);
76
77            while let Some(chunk) = stream.next().await {
78                let chunk = chunk.map_err(ArtificialError::from)?;
79
80                for choice in chunk.choices {
81                    // Process only the first choice to match current non-streaming behavior.
82                    if choice.index != 0 { continue; }
83
84                    // Text deltas
85                    if let Some(delta) = choice.delta.content
86                        && !delta.is_empty() {
87                            yield StreamEvent::TextDelta(delta);
88                        }
89
90                    // Tool-call deltas
91                    if let Some(tool_calls) = choice.delta.tool_calls {
92                        for tc in tool_calls {
93                            let entry = tool_seen.entry(tc.index).or_insert((None, None));
94
95                            if let Some(id) = tc.id.clone() {
96                                if entry.0.is_none() {
97                                    entry.0 = Some(id.clone());
98                                    yield StreamEvent::ToolCallStart {
99                                        index: tc.index,
100                                        id: Some(id),
101                                        name: entry.1.clone(),
102                                    };
103                                } else {
104                                    entry.0 = Some(id);
105                                }
106                            }
107
108                            if let Some(func) = tc.function {
109                                if let Some(name) = func.name {
110                                    if entry.1.is_none() {
111                                        entry.1 = Some(name.clone());
112                                        yield StreamEvent::ToolCallStart {
113                                            index: tc.index,
114                                            id: entry.0.clone(),
115                                            name: Some(name),
116                                        };
117                                    } else {
118                                        entry.1 = Some(name);
119                                    }
120                                }
121
122                                if let Some(arguments) = func.arguments {
123                                    let buf = tool_args.entry(tc.index).or_default();
124                                    buf.push_str(&arguments);
125                                    if !arguments.is_empty() {
126                                        yield StreamEvent::ToolCallArgumentsDelta {
127                                            index: tc.index,
128                                            arguments_fragment: arguments,
129                                        };
130                                    }
131                                }
132                            }
133                        }
134                    }
135
136                    // Finish conditions
137                    if let Some(reason) = choice.finish_reason {
138                        match reason {
139                            FinishReason::ToolCalls => {
140                                // Finalize tool calls by parsing accumulated argument buffers.
141                                for (index, buf) in tool_args.iter() {
142                                    let (id_opt, name_opt) = tool_seen
143                                        .get(index)
144                                        .cloned()
145                                        .unwrap_or((None, None));
146
147                                    let name = name_opt.unwrap_or_else(|| "tool".to_string());
148                                    let args_json: serde_json::Value = serde_json::from_str(buf)
149                                        .map_err(|e| ArtificialError::Invalid(format!("invalid tool arguments JSON: {e}")))?;
150
151                                    let intent = GenericFunctionCallIntent {
152                                        id: id_opt.unwrap_or_else(|| format!("toolcall-{index}")),
153                                        function: GenericFunctionCall { name, arguments: args_json },
154                                    };
155
156                                    yield StreamEvent::ToolCallComplete { index: *index, intent };
157                                }
158
159                                yield StreamEvent::MessageEnd;
160                                return;
161                            }
162                            FinishReason::Stop | FinishReason::Length | FinishReason::ContentFilter => {
163                                yield StreamEvent::MessageEnd;
164                                return;
165                            }
166                        }
167                    }
168                }
169            }
170        })
171    }
172}