artificial_openai/
provider_impl_chat_stream.rs1use std::pin::Pin;
2
3use crate::OpenAiAdapter;
4use crate::api_v1::ChatCompletionMessage;
5use crate::api_v1::ChatCompletionRequest;
6use crate::api_v1::FinishReason;
7use artificial_core::error::{ArtificialError, Result};
8use artificial_core::generic::{GenericFunctionCall, GenericFunctionCallIntent, StreamEvent};
9use artificial_core::provider::StreamingEventsProvider;
10use artificial_core::provider::{ChatCompleteParameters, StreamingChatProvider};
11use futures_core::stream::Stream;
12use std::collections::HashMap;
13
14impl StreamingChatProvider for OpenAiAdapter {
15 type Message = ChatCompletionMessage;
16
17 type Delta<'s>
18 = Pin<Box<dyn Stream<Item = Result<String>> + Send + 's>>
19 where
20 Self: 's;
21
22 fn chat_complete_stream<'s, M>(&'s self, params: ChatCompleteParameters<M>) -> Self::Delta<'s>
23 where
24 M: Into<Self::Message> + Clone + Send + Sync + 's,
25 {
26 let client = self.client.clone();
27
28 Box::pin(async_stream::try_stream! {
29 use futures_util::StreamExt;
30
31 let request: ChatCompletionRequest = params.try_into()?;
32
33
34 let stream = client.chat_completion_stream(request);
35 futures_util::pin_mut!(stream);
36
37 while let Some(chunk) = stream.next().await {
38 let chunk = chunk.map_err(ArtificialError::from)?;
39 for choice in chunk.choices {
40 if let Some(text) = choice.delta.content {
41 yield text;
42 }
43 }
44 }
45
46 })
47 }
48}
49
50impl StreamingEventsProvider for OpenAiAdapter {
51 type EventStream<'s>
52 = Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + 's>>
53 where
54 Self: 's;
55
56 fn chat_complete_events_stream<'s, M>(
57 &'s self,
58 params: ChatCompleteParameters<M>,
59 ) -> Self::EventStream<'s>
60 where
61 M: Into<Self::Message> + Clone + Send + Sync + 's,
62 {
63 let client = self.client.clone();
64
65 Box::pin(async_stream::try_stream! {
66 use futures_util::StreamExt;
67
68 let request: ChatCompletionRequest = params.try_into()?;
69
70 let mut tool_args: HashMap<usize, String> = HashMap::new();
72 let mut tool_seen: HashMap<usize, (Option<String>, Option<String>)> = HashMap::new();
73
74 let stream = client.chat_completion_stream(request);
75 futures_util::pin_mut!(stream);
76
77 while let Some(chunk) = stream.next().await {
78 let chunk = chunk.map_err(ArtificialError::from)?;
79
80 for choice in chunk.choices {
81 if choice.index != 0 { continue; }
83
84 if let Some(delta) = choice.delta.content
86 && !delta.is_empty() {
87 yield StreamEvent::TextDelta(delta);
88 }
89
90 if let Some(tool_calls) = choice.delta.tool_calls {
92 for tc in tool_calls {
93 let entry = tool_seen.entry(tc.index).or_insert((None, None));
94
95 if let Some(id) = tc.id.clone() {
96 if entry.0.is_none() {
97 entry.0 = Some(id.clone());
98 yield StreamEvent::ToolCallStart {
99 index: tc.index,
100 id: Some(id),
101 name: entry.1.clone(),
102 };
103 } else {
104 entry.0 = Some(id);
105 }
106 }
107
108 if let Some(func) = tc.function {
109 if let Some(name) = func.name {
110 if entry.1.is_none() {
111 entry.1 = Some(name.clone());
112 yield StreamEvent::ToolCallStart {
113 index: tc.index,
114 id: entry.0.clone(),
115 name: Some(name),
116 };
117 } else {
118 entry.1 = Some(name);
119 }
120 }
121
122 if let Some(arguments) = func.arguments {
123 let buf = tool_args.entry(tc.index).or_default();
124 buf.push_str(&arguments);
125 if !arguments.is_empty() {
126 yield StreamEvent::ToolCallArgumentsDelta {
127 index: tc.index,
128 arguments_fragment: arguments,
129 };
130 }
131 }
132 }
133 }
134 }
135
136 if let Some(reason) = choice.finish_reason {
138 match reason {
139 FinishReason::ToolCalls => {
140 for (index, buf) in tool_args.iter() {
142 let (id_opt, name_opt) = tool_seen
143 .get(index)
144 .cloned()
145 .unwrap_or((None, None));
146
147 let name = name_opt.unwrap_or_else(|| "tool".to_string());
148 let args_json: serde_json::Value = serde_json::from_str(buf)
149 .map_err(|e| ArtificialError::Invalid(format!("invalid tool arguments JSON: {e}")))?;
150
151 let intent = GenericFunctionCallIntent {
152 id: id_opt.unwrap_or_else(|| format!("toolcall-{index}")),
153 function: GenericFunctionCall { name, arguments: args_json },
154 };
155
156 yield StreamEvent::ToolCallComplete { index: *index, intent };
157 }
158
159 yield StreamEvent::MessageEnd;
160 return;
161 }
162 FinishReason::Stop | FinishReason::Length | FinishReason::ContentFilter => {
163 yield StreamEvent::MessageEnd;
164 return;
165 }
166 }
167 }
168 }
169 }
170 })
171 }
172}