artificial_openai/
provider_impl_chat_stream.rs

1use std::pin::Pin;
2
3use crate::OpenAiAdapter;
4use crate::api_v1::ChatCompletionRequest;
5use crate::api_v1::FinishReason;
6use artificial_core::error::{ArtificialError, Result};
7use artificial_core::generic::{GenericFunctionCall, GenericFunctionCallIntent, StreamEvent};
8use artificial_core::provider::StreamingEventsProvider;
9use artificial_core::provider::{ChatCompleteParameters, StreamingChatProvider};
10use futures_core::stream::Stream;
11use std::collections::HashMap;
12
13impl StreamingChatProvider for OpenAiAdapter {
14    type Delta<'s>
15        = Pin<Box<dyn Stream<Item = Result<String>> + Send + 's>>
16    where
17        Self: 's;
18
19    fn chat_complete_stream<'s, M>(&'s self, params: ChatCompleteParameters<M>) -> Self::Delta<'s>
20    where
21        M: Into<Self::Message> + Clone + Send + Sync + 's,
22    {
23        let client = self.client.clone();
24
25        Box::pin(async_stream::try_stream! {
26        use futures_util::StreamExt;
27
28        let request: ChatCompletionRequest = params.try_into()?;
29
30
31            let stream = client.chat_completion_stream(request);
32            futures_util::pin_mut!(stream);
33
34            while let Some(chunk) = stream.next().await {
35                let chunk = chunk.map_err(ArtificialError::from)?;
36                for choice in chunk.choices {
37                    if let Some(text) = choice.delta.content {
38                        yield text;
39                    }
40                }
41            }
42
43        })
44    }
45}
46
47impl StreamingEventsProvider for OpenAiAdapter {
48    type EventStream<'s>
49        = Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + 's>>
50    where
51        Self: 's;
52
53    fn chat_complete_events_stream<'s, M>(
54        &'s self,
55        params: ChatCompleteParameters<M>,
56    ) -> Self::EventStream<'s>
57    where
58        M: Into<Self::Message> + Clone + Send + Sync + 's,
59    {
60        let client = self.client.clone();
61
62        Box::pin(async_stream::try_stream! {
63            use futures_util::StreamExt;
64
65            let request: ChatCompletionRequest = params.try_into()?;
66
67            // Track tool-call argument fragments and first-seen id/name per tool index.
68            let mut tool_args: HashMap<usize, String> = HashMap::new();
69            let mut tool_seen: HashMap<usize, (Option<String>, Option<String>)> = HashMap::new();
70
71            let stream = client.chat_completion_stream(request);
72            futures_util::pin_mut!(stream);
73
74            while let Some(chunk) = stream.next().await {
75                let chunk = chunk.map_err(ArtificialError::from)?;
76
77                for choice in chunk.choices {
78                    // Process only the first choice to match current non-streaming behavior.
79                    if choice.index != 0 { continue; }
80
81                    // Text deltas
82                    if let Some(delta) = choice.delta.content
83                        && !delta.is_empty() {
84                            yield StreamEvent::TextDelta(delta);
85                        }
86
87                    // Tool-call deltas
88                    if let Some(tool_calls) = choice.delta.tool_calls {
89                        for tc in tool_calls {
90                            let entry = tool_seen.entry(tc.index).or_insert((None, None));
91
92                            if let Some(id) = tc.id.clone() {
93                                if entry.0.is_none() {
94                                    entry.0 = Some(id.clone());
95                                    yield StreamEvent::ToolCallStart {
96                                        index: tc.index,
97                                        id: Some(id),
98                                        name: entry.1.clone(),
99                                    };
100                                } else {
101                                    entry.0 = Some(id);
102                                }
103                            }
104
105                            if let Some(func) = tc.function {
106                                if let Some(name) = func.name {
107                                    if entry.1.is_none() {
108                                        entry.1 = Some(name.clone());
109                                        yield StreamEvent::ToolCallStart {
110                                            index: tc.index,
111                                            id: entry.0.clone(),
112                                            name: Some(name),
113                                        };
114                                    } else {
115                                        entry.1 = Some(name);
116                                    }
117                                }
118
119                                if let Some(arguments) = func.arguments {
120                                    let buf = tool_args.entry(tc.index).or_default();
121                                    buf.push_str(&arguments);
122                                    if !arguments.is_empty() {
123                                        yield StreamEvent::ToolCallArgumentsDelta {
124                                            index: tc.index,
125                                            arguments_fragment: arguments,
126                                        };
127                                    }
128                                }
129                            }
130                        }
131                    }
132
133                    // Finish conditions
134                    if let Some(reason) = choice.finish_reason {
135                        match reason {
136                            FinishReason::ToolCalls => {
137                                // Finalize tool calls by parsing accumulated argument buffers.
138                                for (index, buf) in tool_args.iter() {
139                                    let (id_opt, name_opt) = tool_seen
140                                        .get(index)
141                                        .cloned()
142                                        .unwrap_or((None, None));
143
144                                    let name = name_opt.unwrap_or_else(|| "tool".to_string());
145                                    let args_json: serde_json::Value = serde_json::from_str(buf)
146                                        .map_err(|e| ArtificialError::Invalid(format!("invalid tool arguments JSON: {e}")))?;
147
148                                    let intent = GenericFunctionCallIntent {
149                                        id: id_opt.unwrap_or_else(|| format!("toolcall-{index}")),
150                                        function: GenericFunctionCall { name, arguments: args_json },
151                                    };
152
153                                    yield StreamEvent::ToolCallComplete { index: *index, intent };
154                                }
155
156                                yield StreamEvent::MessageEnd;
157                                return;
158                            }
159                            FinishReason::Stop | FinishReason::Length | FinishReason::ContentFilter => {
160                                yield StreamEvent::MessageEnd;
161                                return;
162                            }
163                        }
164                    }
165                }
166            }
167        })
168    }
169}