artificial_openai/
provider_impl_chat_stream.rs1use std::pin::Pin;
2
3use crate::OpenAiAdapter;
4use crate::api_v1::ChatCompletionRequest;
5use crate::api_v1::FinishReason;
6use artificial_core::error::{ArtificialError, Result};
7use artificial_core::generic::{GenericFunctionCall, GenericFunctionCallIntent, StreamEvent};
8use artificial_core::provider::StreamingEventsProvider;
9use artificial_core::provider::{ChatCompleteParameters, StreamingChatProvider};
10use futures_core::stream::Stream;
11use std::collections::HashMap;
12
13impl StreamingChatProvider for OpenAiAdapter {
14 type Delta<'s>
15 = Pin<Box<dyn Stream<Item = Result<String>> + Send + 's>>
16 where
17 Self: 's;
18
19 fn chat_complete_stream<'s, M>(&'s self, params: ChatCompleteParameters<M>) -> Self::Delta<'s>
20 where
21 M: Into<Self::Message> + Clone + Send + Sync + 's,
22 {
23 let client = self.client.clone();
24
25 Box::pin(async_stream::try_stream! {
26 use futures_util::StreamExt;
27
28 let request: ChatCompletionRequest = params.try_into()?;
29
30
31 let stream = client.chat_completion_stream(request);
32 futures_util::pin_mut!(stream);
33
34 while let Some(chunk) = stream.next().await {
35 let chunk = chunk.map_err(ArtificialError::from)?;
36 for choice in chunk.choices {
37 if let Some(text) = choice.delta.content {
38 yield text;
39 }
40 }
41 }
42
43 })
44 }
45}
46
47impl StreamingEventsProvider for OpenAiAdapter {
48 type EventStream<'s>
49 = Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + 's>>
50 where
51 Self: 's;
52
53 fn chat_complete_events_stream<'s, M>(
54 &'s self,
55 params: ChatCompleteParameters<M>,
56 ) -> Self::EventStream<'s>
57 where
58 M: Into<Self::Message> + Clone + Send + Sync + 's,
59 {
60 let client = self.client.clone();
61
62 Box::pin(async_stream::try_stream! {
63 use futures_util::StreamExt;
64
65 let request: ChatCompletionRequest = params.try_into()?;
66
67 let mut tool_args: HashMap<usize, String> = HashMap::new();
69 let mut tool_seen: HashMap<usize, (Option<String>, Option<String>)> = HashMap::new();
70
71 let stream = client.chat_completion_stream(request);
72 futures_util::pin_mut!(stream);
73
74 while let Some(chunk) = stream.next().await {
75 let chunk = chunk.map_err(ArtificialError::from)?;
76
77 for choice in chunk.choices {
78 if choice.index != 0 { continue; }
80
81 if let Some(delta) = choice.delta.content
83 && !delta.is_empty() {
84 yield StreamEvent::TextDelta(delta);
85 }
86
87 if let Some(tool_calls) = choice.delta.tool_calls {
89 for tc in tool_calls {
90 let entry = tool_seen.entry(tc.index).or_insert((None, None));
91
92 if let Some(id) = tc.id.clone() {
93 if entry.0.is_none() {
94 entry.0 = Some(id.clone());
95 yield StreamEvent::ToolCallStart {
96 index: tc.index,
97 id: Some(id),
98 name: entry.1.clone(),
99 };
100 } else {
101 entry.0 = Some(id);
102 }
103 }
104
105 if let Some(func) = tc.function {
106 if let Some(name) = func.name {
107 if entry.1.is_none() {
108 entry.1 = Some(name.clone());
109 yield StreamEvent::ToolCallStart {
110 index: tc.index,
111 id: entry.0.clone(),
112 name: Some(name),
113 };
114 } else {
115 entry.1 = Some(name);
116 }
117 }
118
119 if let Some(arguments) = func.arguments {
120 let buf = tool_args.entry(tc.index).or_default();
121 buf.push_str(&arguments);
122 if !arguments.is_empty() {
123 yield StreamEvent::ToolCallArgumentsDelta {
124 index: tc.index,
125 arguments_fragment: arguments,
126 };
127 }
128 }
129 }
130 }
131 }
132
133 if let Some(reason) = choice.finish_reason {
135 match reason {
136 FinishReason::ToolCalls => {
137 for (index, buf) in tool_args.iter() {
139 let (id_opt, name_opt) = tool_seen
140 .get(index)
141 .cloned()
142 .unwrap_or((None, None));
143
144 let name = name_opt.unwrap_or_else(|| "tool".to_string());
145 let args_json: serde_json::Value = serde_json::from_str(buf)
146 .map_err(|e| ArtificialError::Invalid(format!("invalid tool arguments JSON: {e}")))?;
147
148 let intent = GenericFunctionCallIntent {
149 id: id_opt.unwrap_or_else(|| format!("toolcall-{index}")),
150 function: GenericFunctionCall { name, arguments: args_json },
151 };
152
153 yield StreamEvent::ToolCallComplete { index: *index, intent };
154 }
155
156 yield StreamEvent::MessageEnd;
157 return;
158 }
159 FinishReason::Stop | FinishReason::Length | FinishReason::ContentFilter => {
160 yield StreamEvent::MessageEnd;
161 return;
162 }
163 }
164 }
165 }
166 }
167 })
168 }
169}