rig/providers/openai/completion/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::json_utils;
3use crate::json_utils::merge;
4use crate::providers::openai::completion::{CompletionModel, Usage};
5use crate::streaming;
6use crate::streaming::RawStreamingChoice;
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use reqwest_eventsource::Event;
11use reqwest_eventsource::RequestBuilderExt;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use std::collections::HashMap;
15use tracing::{debug, info_span};
16use tracing_futures::Instrument;
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
22pub struct StreamingFunction {
23 #[serde(default)]
24 pub name: Option<String>,
25 #[serde(default)]
26 pub arguments: String,
27}
28
29#[derive(Debug, Serialize, Deserialize, Clone)]
30pub struct StreamingToolCall {
31 pub index: usize,
32 pub id: Option<String>,
33 pub function: StreamingFunction,
34}
35
36#[derive(Deserialize, Debug)]
37struct StreamingDelta {
38 #[serde(default)]
39 content: Option<String>,
40 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
41 tool_calls: Vec<StreamingToolCall>,
42}
43
44#[derive(Deserialize, Debug)]
45struct StreamingChoice {
46 delta: StreamingDelta,
47}
48
49#[derive(Deserialize, Debug)]
50struct StreamingCompletionChunk {
51 choices: Vec<StreamingChoice>,
52 usage: Option<Usage>,
53}
54
55#[derive(Clone, Serialize, Deserialize)]
56pub struct StreamingCompletionResponse {
57 pub usage: Usage,
58}
59
60impl GetTokenUsage for StreamingCompletionResponse {
61 fn token_usage(&self) -> Option<crate::completion::Usage> {
62 let mut usage = crate::completion::Usage::new();
63 usage.input_tokens = self.usage.prompt_tokens as u64;
64 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
65 usage.total_tokens = self.usage.total_tokens as u64;
66 Some(usage)
67 }
68}
69
70impl CompletionModel<reqwest::Client> {
71 pub(crate) async fn stream(
72 &self,
73 completion_request: CompletionRequest,
74 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
75 {
76 let request = super::CompletionRequest::try_from((self.model.clone(), completion_request))?;
77 let request_messages = serde_json::to_string(&request.messages)
78 .expect("Converting to JSON from a Rust struct shouldn't fail");
79 let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
80
81 request_as_json = merge(
82 request_as_json,
83 json!({"stream": true, "stream_options": {"include_usage": true}}),
84 );
85
86 let builder = self
87 .client
88 .post_reqwest("/chat/completions")
89 .json(&request_as_json);
90
91 let span = if tracing::Span::current().is_disabled() {
92 info_span!(
93 target: "rig::completions",
94 "chat",
95 gen_ai.operation.name = "chat",
96 gen_ai.provider.name = "openai",
97 gen_ai.request.model = self.model,
98 gen_ai.response.id = tracing::field::Empty,
99 gen_ai.response.model = self.model,
100 gen_ai.usage.output_tokens = tracing::field::Empty,
101 gen_ai.usage.input_tokens = tracing::field::Empty,
102 gen_ai.input.messages = request_messages,
103 gen_ai.output.messages = tracing::field::Empty,
104 )
105 } else {
106 tracing::Span::current()
107 };
108
109 tracing::Instrument::instrument(send_compatible_streaming_request(builder), span).await
110 }
111}
112
113pub async fn send_compatible_streaming_request(
114 request_builder: RequestBuilder,
115) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
116 let span = tracing::Span::current();
117 let mut event_source = request_builder
119 .eventsource()
120 .expect("Cloning request must always succeed");
121
122 let stream = stream! {
123 let span = tracing::Span::current();
124 let mut final_usage = Usage::new();
125
126 let mut tool_calls: HashMap<usize, (String, String, String)> = HashMap::new();
128
129 let mut text_content = String::new();
130
131 while let Some(event_result) = event_source.next().await {
132 match event_result {
133 Ok(Event::Open) => {
134 tracing::trace!("SSE connection opened");
135 continue;
136 }
137 Ok(Event::Message(message)) => {
138 if message.data.trim().is_empty() || message.data == "[DONE]" {
139 continue;
140 }
141
142 let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
143 let Ok(data) = data else {
144 let err = data.unwrap_err();
145 debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
146 continue;
147 };
148
149 if let Some(choice) = data.choices.first() {
150 let delta = &choice.delta;
151
152 if !delta.tool_calls.is_empty() {
154 for tool_call in &delta.tool_calls {
155 let function = tool_call.function.clone();
156
157 if function.name.is_some() && function.arguments.is_empty() {
159 let id = tool_call.id.clone().unwrap_or_default();
160 tool_calls.insert(
161 tool_call.index,
162 (id, function.name.clone().unwrap(), "".to_string()),
163 );
164 }
165 else if function.name.clone().is_none_or(|s| s.is_empty())
169 && !function.arguments.is_empty()
170 {
171 if let Some((id, name, arguments)) =
172 tool_calls.get(&tool_call.index)
173 {
174 let new_arguments = &tool_call.function.arguments;
175 let arguments = format!("{arguments}{new_arguments}");
176 tool_calls.insert(
177 tool_call.index,
178 (id.clone(), name.clone(), arguments),
179 );
180 } else {
181 debug!("Partial tool call received but tool call was never started.");
182 }
183 }
184 else {
186 let id = tool_call.id.clone().unwrap_or_default();
187 let name = function.name.expect("tool call should have a name");
188 let arguments = function.arguments;
189 let Ok(arguments) = serde_json::from_str(&arguments) else {
190 debug!("Couldn't serialize '{arguments}' as JSON");
191 continue;
192 };
193
194 yield Ok(streaming::RawStreamingChoice::ToolCall {
195 id,
196 name,
197 arguments,
198 call_id: None,
199 });
200 }
201 }
202 }
203
204 if let Some(content) = &choice.delta.content {
206 text_content += content;
207 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
208 }
209 }
210
211 if let Some(usage) = data.usage {
213 final_usage = usage.clone();
214 }
215 }
216 Err(reqwest_eventsource::Error::StreamEnded) => {
217 break;
218 }
219 Err(error) => {
220 tracing::error!(?error, "SSE error");
221 yield Err(CompletionError::ResponseError(error.to_string()));
222 break;
223 }
224 }
225 }
226
227 event_source.close();
229
230 let mut vec_toolcalls = vec![];
231
232 for (_, (id, name, arguments)) in tool_calls {
234 let Ok(arguments) = serde_json::from_str::<serde_json::Value>(&arguments) else {
235 continue;
236 };
237
238 vec_toolcalls.push(super::ToolCall {
239 r#type: super::ToolType::Function,
240 id: id.clone(),
241 function: super::Function {
242 name: name.clone(), arguments: arguments.clone()
243 },
244 });
245
246 yield Ok(RawStreamingChoice::ToolCall {
247 id,
248 name,
249 arguments,
250 call_id: None,
251 });
252 }
253
254 let message_output = super::Message::Assistant {
255 content: vec![super::AssistantContent::Text { text: text_content }],
256 refusal: None,
257 audio: None,
258 name: None,
259 tool_calls: vec_toolcalls
260 };
261
262 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
263 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
264 span.record("gen_ai.output.messages", serde_json::to_string(&vec![message_output]).expect("Converting from a Rust struct should always convert to JSON without failing"));
265
266 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
267 usage: final_usage.clone()
268 }));
269 }.instrument(span);
270
271 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
272 stream,
273 )))
274}