rig/providers/openai/responses_api/streaming.rs

//! The streaming module for the OpenAI Responses API.
//! Please see the `openai_streaming` or `openai_streaming_with_tools` example for more practical usage.
3use crate::completion::{CompletionError, GetTokenUsage};
4use crate::providers::openai::responses_api::{
5    ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
6};
7use crate::streaming;
8use crate::streaming::RawStreamingChoice;
9use async_stream::stream;
10use futures::StreamExt;
11use reqwest_eventsource::Event;
12use reqwest_eventsource::RequestBuilderExt;
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info_span};
15use tracing_futures::Instrument as _;
16
17use super::{CompletionResponse, Output};
18
19// ================================================================
20// OpenAI Responses Streaming API
21// ================================================================
22
/// A streaming completion chunk.
/// Streaming chunks can come in one of two forms:
/// - A response chunk (where the completed response will have the total token usage)
/// - An item chunk commonly referred to as a delta. In the completions API this would be referred to as the message delta.
///
/// NOTE: with `#[serde(untagged)]`, variants are tried in declaration order
/// during deserialization — keep `Response` before `Delta`.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// A full response snapshot event (created / in-progress / completed / failed / incomplete).
    /// Boxed because [`ResponseChunk`] embeds a whole [`CompletionResponse`].
    Response(Box<ResponseChunk>),
    /// An incremental per-item event (text deltas, refusals, tool-call args, ...).
    Delta(ItemChunk),
}
33
/// The final streaming response from the OpenAI Responses API.
/// Yielded once as the stream's terminal item (see `stream()` below) so that
/// callers can read the aggregated token usage for the whole completion.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    /// Token usage
    pub usage: ResponsesUsage,
}
40
41impl GetTokenUsage for StreamingCompletionResponse {
42    fn token_usage(&self) -> Option<crate::completion::Usage> {
43        let mut usage = crate::completion::Usage::new();
44        usage.input_tokens = self.usage.input_tokens;
45        usage.output_tokens = self.usage.output_tokens;
46        usage.total_tokens = self.usage.total_tokens;
47        Some(usage)
48    }
49}
50
/// A response chunk from OpenAI's response API.
/// Carries a full snapshot of the response at a lifecycle transition; the
/// `response.completed` snapshot is the one that carries the final usage.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// The response chunk type
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// The response itself
    pub response: CompletionResponse,
    /// The item sequence
    pub sequence_number: u64,
}
62
/// Response chunk type.
/// Renames are used to ensure that this type gets (de)serialized properly —
/// each variant maps to the exact SSE event name emitted by the API.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    /// Emitted once when the response is first created.
    #[serde(rename = "response.created")]
    ResponseCreated,
    /// Emitted while the response is being generated.
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    /// Terminal: generation finished successfully (read usage from this one).
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    /// Terminal: generation failed.
    #[serde(rename = "response.failed")]
    ResponseFailed,
    /// Terminal: generation stopped before completing.
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
78
/// An item message chunk from OpenAI's Responses API.
/// Wraps a single streamed item-level event; the concrete event type and its
/// payload live in the flattened [`ItemChunkKind`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// Item ID. Optional.
    pub item_id: Option<String>,
    /// The output index of the item from a given streamed response.
    pub output_index: u64,
    /// The item type chunk, as well as the inner data.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
91
92/// The item chunk type from OpenAI's Responses API.
93#[derive(Debug, Serialize, Deserialize, Clone)]
94#[serde(tag = "type")]
95pub enum ItemChunkKind {
96    #[serde(rename = "response.output_item.added")]
97    OutputItemAdded(StreamingItemDoneOutput),
98    #[serde(rename = "response.output_item.done")]
99    OutputItemDone(StreamingItemDoneOutput),
100    #[serde(rename = "response.content_part.added")]
101    ContentPartAdded(ContentPartChunk),
102    #[serde(rename = "response.content_part.done")]
103    ContentPartDone(ContentPartChunk),
104    #[serde(rename = "response.output_text.delta")]
105    OutputTextDelta(DeltaTextChunk),
106    #[serde(rename = "response.output_text.done")]
107    OutputTextDone(OutputTextChunk),
108    #[serde(rename = "response.refusal.delta")]
109    RefusalDelta(DeltaTextChunk),
110    #[serde(rename = "response.refusal.done")]
111    RefusalDone(RefusalTextChunk),
112    #[serde(rename = "response.function_call_arguments.delta")]
113    FunctionCallArgsDelta(DeltaTextChunk),
114    #[serde(rename = "response.function_call_arguments.done")]
115    FunctionCallArgsDone(ArgsTextChunk),
116    #[serde(rename = "response.reasoning_summary_part.added")]
117    ReasoningSummaryPartAdded(SummaryPartChunk),
118    #[serde(rename = "response.reasoning_summary_part.done")]
119    ReasoningSummaryPartDone(SummaryPartChunk),
120    #[serde(rename = "response.reasoning_summary_text.added")]
121    ReasoningSummaryTextAdded(SummaryTextChunk),
122    #[serde(rename = "response.reasoning_summary_text.done")]
123    ReasoningSummaryTextDone(SummaryTextChunk),
124}
125
/// Payload of the `response.output_item.added` / `response.output_item.done`
/// events: the streamed output item itself.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    /// Sequence number of this event within the stream.
    pub sequence_number: u64,
    /// The output item (message, function call, reasoning, ...).
    pub item: Output,
}
131
/// Payload of the `response.content_part.added` / `response.content_part.done`
/// events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Sequence number of this event within the stream.
    pub sequence_number: u64,
    /// The content part itself.
    pub part: ContentPartChunkPart,
}
138
139#[derive(Debug, Serialize, Deserialize, Clone)]
140#[serde(tag = "type")]
141pub enum ContentPartChunkPart {
142    OutputText { text: String },
143    SummaryText { text: String },
144}
145
/// A generic incremental-text payload, shared by the output-text, refusal and
/// function-call-argument delta events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    /// Index of the content part this delta belongs to.
    // NOTE(review): the function_call_arguments.delta event may not carry a
    // content_index in the API payload — confirm against the event reference.
    pub content_index: u64,
    /// Sequence number of this event within the stream.
    pub sequence_number: u64,
    /// The incremental text fragment.
    pub delta: String,
}
152
/// Payload of `response.output_text.done`: the complete, final text of a
/// content part.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Sequence number of this event within the stream.
    pub sequence_number: u64,
    /// The full text of the finished content part.
    pub text: String,
}
159
/// Payload of `response.refusal.done`: the complete refusal text of a
/// content part.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Sequence number of this event within the stream.
    pub sequence_number: u64,
    /// The full refusal message.
    pub refusal: String,
}
166
/// Payload of `response.function_call_arguments.done`: the finished tool-call
/// arguments.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    /// Index of the content part this payload belongs to.
    pub content_index: u64,
    /// Sequence number of this event within the stream.
    pub sequence_number: u64,
    /// The completed arguments.
    // NOTE(review): the API documents `arguments` as a JSON-encoded *string*;
    // a `serde_json::Value` would capture it as `Value::String` — confirm
    // downstream consumers expect that rather than a parsed object.
    pub arguments: serde_json::Value,
}
173
/// Payload of the `response.reasoning_summary_part.added` / `.done` events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    /// Index of the summary part within its reasoning item.
    pub summary_index: u64,
    /// Sequence number of this event within the stream.
    pub sequence_number: u64,
    /// The summary part itself.
    pub part: SummaryPartChunkPart,
}
180
/// Payload of the reasoning-summary text events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryTextChunk {
    /// Index of the summary part within its reasoning item.
    pub summary_index: u64,
    /// Sequence number of this event within the stream.
    pub sequence_number: u64,
    /// The summary text fragment.
    // NOTE(review): per the API reference the `.done` event carries a `text`
    // field rather than `delta` — confirm both events deserialize into this
    // struct as intended.
    pub delta: String,
}
187
188#[derive(Debug, Serialize, Deserialize, Clone)]
189#[serde(tag = "type")]
190pub enum SummaryPartChunkPart {
191    SummaryText { text: String },
192}
193
194impl ResponsesCompletionModel<reqwest::Client> {
195    pub(crate) async fn stream(
196        &self,
197        completion_request: crate::completion::CompletionRequest,
198    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
199    {
200        let mut request = self.create_completion_request(completion_request)?;
201        request.stream = Some(true);
202
203        let request_builder = self.client.post_reqwest("/responses").json(&request);
204
205        let span = if tracing::Span::current().is_disabled() {
206            info_span!(
207                target: "rig::completions",
208                "chat_streaming",
209                gen_ai.operation.name = "chat_streaming",
210                gen_ai.provider.name = tracing::field::Empty,
211                gen_ai.request.model = tracing::field::Empty,
212                gen_ai.response.id = tracing::field::Empty,
213                gen_ai.response.model = tracing::field::Empty,
214                gen_ai.usage.output_tokens = tracing::field::Empty,
215                gen_ai.usage.input_tokens = tracing::field::Empty,
216                gen_ai.input.messages = tracing::field::Empty,
217                gen_ai.output.messages = tracing::field::Empty,
218            )
219        } else {
220            tracing::Span::current()
221        };
222        span.record("gen_ai.provider.name", "openai");
223        span.record("gen_ai.request.model", &self.model);
224        span.record(
225            "gen_ai.input.messages",
226            serde_json::to_string(&request.input).expect("This should always work"),
227        );
228        // Build the request with proper headers for SSE
229        let mut event_source = request_builder
230            .eventsource()
231            .expect("Cloning request must always succeed");
232
233        let stream = stream! {
234            let mut final_usage = ResponsesUsage::new();
235
236            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
237            let mut combined_text = String::new();
238            let span = tracing::Span::current();
239
240            while let Some(event_result) = event_source.next().await {
241                match event_result {
242                    Ok(Event::Open) => {
243                        tracing::trace!("SSE connection opened");
244                        tracing::info!("OpenAI stream started");
245                        continue;
246                    }
247                    Ok(Event::Message(message)) => {
248                        // Skip heartbeat messages or empty data
249                        if message.data.trim().is_empty() {
250                            continue;
251                        }
252
253                        let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
254
255                        let Ok(data) = data else {
256                            let err = data.unwrap_err();
257                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
258                            continue;
259                        };
260
261                        if let StreamingCompletionChunk::Delta(chunk) = &data {
262                            match &chunk.data {
263                                ItemChunkKind::OutputItemDone(message) => {
264                                    match message {
265                                        StreamingItemDoneOutput {  item: Output::FunctionCall(func), .. } => {
266                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
267                                        }
268
269                                        StreamingItemDoneOutput {  item: Output::Reasoning {  summary, id }, .. } => {
270                                            let reasoning = summary
271                                                .iter()
272                                                .map(|x| {
273                                                    let ReasoningSummary::SummaryText { text } = x;
274                                                    text.to_owned()
275                                                })
276                                                .collect::<Vec<String>>()
277                                                .join("\n");
278                                            yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()) })
279                                        }
280                                        _ => continue
281                                    }
282                                }
283                                ItemChunkKind::OutputTextDelta(delta) => {
284                                    combined_text.push_str(&delta.delta);
285                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
286                                }
287                                ItemChunkKind::RefusalDelta(delta) => {
288                                    combined_text.push_str(&delta.delta);
289                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
290                                }
291
292                                _ => { continue }
293                            }
294                        }
295
296                        if let StreamingCompletionChunk::Response(chunk) = data {
297                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
298                                span.record("gen_ai.output.messages", serde_json::to_string(&response.output).unwrap());
299                                span.record("gen_ai.response.id", response.id);
300                                span.record("gen_ai.response.model", response.model);
301                                if let Some(usage) = response.usage {
302                                    final_usage = usage;
303                                }
304                            } else {
305                                continue;
306                            }
307                        }
308                    }
309                    Err(reqwest_eventsource::Error::StreamEnded) => {
310                        break;
311                    }
312                    Err(error) => {
313                        tracing::error!(?error, "SSE error");
314                        yield Err(CompletionError::ResponseError(error.to_string()));
315                        break;
316                    }
317                }
318            }
319
320            // Ensure event source is closed when stream ends
321            event_source.close();
322
323            for tool_call in &tool_calls {
324                yield Ok(tool_call.to_owned())
325            }
326
327            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
328            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
329            tracing::info!("OpenAI stream finished");
330
331            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
332                usage: final_usage.clone()
333            }));
334        }.instrument(span);
335
336        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
337            stream,
338        )))
339    }
340}