rig/providers/openai/responses_api/
streaming.rs1use crate::completion::{CompletionError, GetTokenUsage};
4use crate::providers::openai::responses_api::{
5 ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
6};
7use crate::streaming;
8use crate::streaming::RawStreamingChoice;
9use async_stream::stream;
10use futures::StreamExt;
11use reqwest_eventsource::Event;
12use reqwest_eventsource::RequestBuilderExt;
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info_span};
15use tracing_futures::Instrument as _;
16
17use super::{CompletionResponse, Output};
18
19#[derive(Debug, Serialize, Deserialize, Clone)]
28#[serde(untagged)]
29pub enum StreamingCompletionChunk {
30 Response(Box<ResponseChunk>),
31 Delta(ItemChunk),
32}
33
34#[derive(Debug, Serialize, Deserialize, Clone)]
36pub struct StreamingCompletionResponse {
37 pub usage: ResponsesUsage,
39}
40
41impl GetTokenUsage for StreamingCompletionResponse {
42 fn token_usage(&self) -> Option<crate::completion::Usage> {
43 let mut usage = crate::completion::Usage::new();
44 usage.input_tokens = self.usage.input_tokens;
45 usage.output_tokens = self.usage.output_tokens;
46 usage.total_tokens = self.usage.total_tokens;
47 Some(usage)
48 }
49}
50
51#[derive(Debug, Serialize, Deserialize, Clone)]
53pub struct ResponseChunk {
54 #[serde(rename = "type")]
56 pub kind: ResponseChunkKind,
57 pub response: CompletionResponse,
59 pub sequence_number: u64,
61}
62
63#[derive(Debug, Serialize, Deserialize, Clone)]
66pub enum ResponseChunkKind {
67 #[serde(rename = "response.created")]
68 ResponseCreated,
69 #[serde(rename = "response.in_progress")]
70 ResponseInProgress,
71 #[serde(rename = "response.completed")]
72 ResponseCompleted,
73 #[serde(rename = "response.failed")]
74 ResponseFailed,
75 #[serde(rename = "response.incomplete")]
76 ResponseIncomplete,
77}
78
79#[derive(Debug, Serialize, Deserialize, Clone)]
82pub struct ItemChunk {
83 pub item_id: Option<String>,
85 pub output_index: u64,
87 #[serde(flatten)]
89 pub data: ItemChunkKind,
90}
91
92#[derive(Debug, Serialize, Deserialize, Clone)]
94#[serde(tag = "type")]
95pub enum ItemChunkKind {
96 #[serde(rename = "response.output_item.added")]
97 OutputItemAdded(StreamingItemDoneOutput),
98 #[serde(rename = "response.output_item.done")]
99 OutputItemDone(StreamingItemDoneOutput),
100 #[serde(rename = "response.content_part.added")]
101 ContentPartAdded(ContentPartChunk),
102 #[serde(rename = "response.content_part.done")]
103 ContentPartDone(ContentPartChunk),
104 #[serde(rename = "response.output_text.delta")]
105 OutputTextDelta(DeltaTextChunk),
106 #[serde(rename = "response.output_text.done")]
107 OutputTextDone(OutputTextChunk),
108 #[serde(rename = "response.refusal.delta")]
109 RefusalDelta(DeltaTextChunk),
110 #[serde(rename = "response.refusal.done")]
111 RefusalDone(RefusalTextChunk),
112 #[serde(rename = "response.function_call_arguments.delta")]
113 FunctionCallArgsDelta(DeltaTextChunk),
114 #[serde(rename = "response.function_call_arguments.done")]
115 FunctionCallArgsDone(ArgsTextChunk),
116 #[serde(rename = "response.reasoning_summary_part.added")]
117 ReasoningSummaryPartAdded(SummaryPartChunk),
118 #[serde(rename = "response.reasoning_summary_part.done")]
119 ReasoningSummaryPartDone(SummaryPartChunk),
120 #[serde(rename = "response.reasoning_summary_text.added")]
121 ReasoningSummaryTextAdded(SummaryTextChunk),
122 #[serde(rename = "response.reasoning_summary_text.done")]
123 ReasoningSummaryTextDone(SummaryTextChunk),
124}
125
126#[derive(Debug, Serialize, Deserialize, Clone)]
127pub struct StreamingItemDoneOutput {
128 pub sequence_number: u64,
129 pub item: Output,
130}
131
132#[derive(Debug, Serialize, Deserialize, Clone)]
133pub struct ContentPartChunk {
134 pub content_index: u64,
135 pub sequence_number: u64,
136 pub part: ContentPartChunkPart,
137}
138
139#[derive(Debug, Serialize, Deserialize, Clone)]
140#[serde(tag = "type")]
141pub enum ContentPartChunkPart {
142 OutputText { text: String },
143 SummaryText { text: String },
144}
145
146#[derive(Debug, Serialize, Deserialize, Clone)]
147pub struct DeltaTextChunk {
148 pub content_index: u64,
149 pub sequence_number: u64,
150 pub delta: String,
151}
152
153#[derive(Debug, Serialize, Deserialize, Clone)]
154pub struct OutputTextChunk {
155 pub content_index: u64,
156 pub sequence_number: u64,
157 pub text: String,
158}
159
160#[derive(Debug, Serialize, Deserialize, Clone)]
161pub struct RefusalTextChunk {
162 pub content_index: u64,
163 pub sequence_number: u64,
164 pub refusal: String,
165}
166
167#[derive(Debug, Serialize, Deserialize, Clone)]
168pub struct ArgsTextChunk {
169 pub content_index: u64,
170 pub sequence_number: u64,
171 pub arguments: serde_json::Value,
172}
173
174#[derive(Debug, Serialize, Deserialize, Clone)]
175pub struct SummaryPartChunk {
176 pub summary_index: u64,
177 pub sequence_number: u64,
178 pub part: SummaryPartChunkPart,
179}
180
181#[derive(Debug, Serialize, Deserialize, Clone)]
182pub struct SummaryTextChunk {
183 pub summary_index: u64,
184 pub sequence_number: u64,
185 pub delta: String,
186}
187
188#[derive(Debug, Serialize, Deserialize, Clone)]
189#[serde(tag = "type")]
190pub enum SummaryPartChunkPart {
191 SummaryText { text: String },
192}
193
194impl ResponsesCompletionModel<reqwest::Client> {
195 pub(crate) async fn stream(
196 &self,
197 completion_request: crate::completion::CompletionRequest,
198 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
199 {
200 let mut request = self.create_completion_request(completion_request)?;
201 request.stream = Some(true);
202
203 let request_builder = self.client.post_reqwest("/responses").json(&request);
204
205 let span = if tracing::Span::current().is_disabled() {
206 info_span!(
207 target: "rig::completions",
208 "chat_streaming",
209 gen_ai.operation.name = "chat_streaming",
210 gen_ai.provider.name = tracing::field::Empty,
211 gen_ai.request.model = tracing::field::Empty,
212 gen_ai.response.id = tracing::field::Empty,
213 gen_ai.response.model = tracing::field::Empty,
214 gen_ai.usage.output_tokens = tracing::field::Empty,
215 gen_ai.usage.input_tokens = tracing::field::Empty,
216 gen_ai.input.messages = tracing::field::Empty,
217 gen_ai.output.messages = tracing::field::Empty,
218 )
219 } else {
220 tracing::Span::current()
221 };
222 span.record("gen_ai.provider.name", "openai");
223 span.record("gen_ai.request.model", &self.model);
224 span.record(
225 "gen_ai.input.messages",
226 serde_json::to_string(&request.input).expect("This should always work"),
227 );
228 let mut event_source = request_builder
230 .eventsource()
231 .expect("Cloning request must always succeed");
232
233 let stream = stream! {
234 let mut final_usage = ResponsesUsage::new();
235
236 let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
237 let mut combined_text = String::new();
238 let span = tracing::Span::current();
239
240 while let Some(event_result) = event_source.next().await {
241 match event_result {
242 Ok(Event::Open) => {
243 tracing::trace!("SSE connection opened");
244 tracing::info!("OpenAI stream started");
245 continue;
246 }
247 Ok(Event::Message(message)) => {
248 if message.data.trim().is_empty() {
250 continue;
251 }
252
253 let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
254
255 let Ok(data) = data else {
256 let err = data.unwrap_err();
257 debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
258 continue;
259 };
260
261 if let StreamingCompletionChunk::Delta(chunk) = &data {
262 match &chunk.data {
263 ItemChunkKind::OutputItemDone(message) => {
264 match message {
265 StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
266 tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
267 }
268
269 StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. } => {
270 let reasoning = summary
271 .iter()
272 .map(|x| {
273 let ReasoningSummary::SummaryText { text } = x;
274 text.to_owned()
275 })
276 .collect::<Vec<String>>()
277 .join("\n");
278 yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()) })
279 }
280 _ => continue
281 }
282 }
283 ItemChunkKind::OutputTextDelta(delta) => {
284 combined_text.push_str(&delta.delta);
285 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
286 }
287 ItemChunkKind::RefusalDelta(delta) => {
288 combined_text.push_str(&delta.delta);
289 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
290 }
291
292 _ => { continue }
293 }
294 }
295
296 if let StreamingCompletionChunk::Response(chunk) = data {
297 if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
298 span.record("gen_ai.output.messages", serde_json::to_string(&response.output).unwrap());
299 span.record("gen_ai.response.id", response.id);
300 span.record("gen_ai.response.model", response.model);
301 if let Some(usage) = response.usage {
302 final_usage = usage;
303 }
304 } else {
305 continue;
306 }
307 }
308 }
309 Err(reqwest_eventsource::Error::StreamEnded) => {
310 break;
311 }
312 Err(error) => {
313 tracing::error!(?error, "SSE error");
314 yield Err(CompletionError::ResponseError(error.to_string()));
315 break;
316 }
317 }
318 }
319
320 event_source.close();
322
323 for tool_call in &tool_calls {
324 yield Ok(tool_call.to_owned())
325 }
326
327 span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
328 span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
329 tracing::info!("OpenAI stream finished");
330
331 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
332 usage: final_usage.clone()
333 }));
334 }.instrument(span);
335
336 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
337 stream,
338 )))
339 }
340}