rig/providers/openai/responses_api/
streaming.rs

1//! The streaming module for the OpenAI Responses API.
2//! Please see the `openai_streaming` or `openai_streaming_with_tools` example for more practical usage.
3use crate::completion::{CompletionError, GetTokenUsage};
4use crate::providers::openai::responses_api::{
5    ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
6};
7use crate::streaming;
8use crate::streaming::RawStreamingChoice;
9use async_stream::stream;
10use futures::StreamExt;
11use reqwest::RequestBuilder;
12use reqwest_eventsource::Event;
13use reqwest_eventsource::RequestBuilderExt;
14use serde::{Deserialize, Serialize};
15use tracing::debug;
16
17use super::{CompletionResponse, Output};
18
19// ================================================================
20// OpenAI Responses Streaming API
21// ================================================================
22
23/// A streaming completion chunk.
24/// Streaming chunks can come in one of two forms:
25/// - A response chunk (where the completed response will have the total token usage)
26/// - An item chunk commonly referred to as a delta. In the completions API this would be referred to as the message delta.
27#[derive(Debug, Serialize, Deserialize, Clone)]
28#[serde(untagged)]
29pub enum StreamingCompletionChunk {
30    Response(Box<ResponseChunk>),
31    Delta(ItemChunk),
32}
33
34/// The final streaming response from the OpenAI Responses API.
35#[derive(Debug, Serialize, Deserialize, Clone)]
36pub struct StreamingCompletionResponse {
37    /// Token usage
38    pub usage: ResponsesUsage,
39}
40
41impl GetTokenUsage for StreamingCompletionResponse {
42    fn token_usage(&self) -> Option<crate::completion::Usage> {
43        let mut usage = crate::completion::Usage::new();
44        usage.input_tokens = self.usage.input_tokens;
45        usage.output_tokens = self.usage.output_tokens;
46        usage.total_tokens = self.usage.total_tokens;
47        Some(usage)
48    }
49}
50
51/// A response chunk from OpenAI's response API.
52#[derive(Debug, Serialize, Deserialize, Clone)]
53pub struct ResponseChunk {
54    /// The response chunk type
55    #[serde(rename = "type")]
56    pub kind: ResponseChunkKind,
57    /// The response itself
58    pub response: CompletionResponse,
59    /// The item sequence
60    pub sequence_number: u64,
61}
62
63/// Response chunk type.
64/// Renames are used to ensure that this type gets (de)serialized properly.
65#[derive(Debug, Serialize, Deserialize, Clone)]
66pub enum ResponseChunkKind {
67    #[serde(rename = "response.created")]
68    ResponseCreated,
69    #[serde(rename = "response.in_progress")]
70    ResponseInProgress,
71    #[serde(rename = "response.completed")]
72    ResponseCompleted,
73    #[serde(rename = "response.failed")]
74    ResponseFailed,
75    #[serde(rename = "response.incomplete")]
76    ResponseIncomplete,
77}
78
79/// An item message chunk from OpenAI's Responses API.
80/// See
81#[derive(Debug, Serialize, Deserialize, Clone)]
82pub struct ItemChunk {
83    /// Item ID. Optional.
84    pub item_id: Option<String>,
85    /// The output index of the item from a given streamed response.
86    pub output_index: u64,
87    /// The item type chunk, as well as the inner data.
88    #[serde(flatten)]
89    pub data: ItemChunkKind,
90}
91
92/// The item chunk type from OpenAI's Responses API.
93#[derive(Debug, Serialize, Deserialize, Clone)]
94#[serde(tag = "type")]
95pub enum ItemChunkKind {
96    #[serde(rename = "response.output_item.added")]
97    OutputItemAdded(StreamingItemDoneOutput),
98    #[serde(rename = "response.output_item.done")]
99    OutputItemDone(StreamingItemDoneOutput),
100    #[serde(rename = "response.content_part.added")]
101    ContentPartAdded(ContentPartChunk),
102    #[serde(rename = "response.content_part.done")]
103    ContentPartDone(ContentPartChunk),
104    #[serde(rename = "response.output_text.delta")]
105    OutputTextDelta(DeltaTextChunk),
106    #[serde(rename = "response.output_text.done")]
107    OutputTextDone(OutputTextChunk),
108    #[serde(rename = "response.refusal.delta")]
109    RefusalDelta(DeltaTextChunk),
110    #[serde(rename = "response.refusal.done")]
111    RefusalDone(RefusalTextChunk),
112    #[serde(rename = "response.function_call_arguments.delta")]
113    FunctionCallArgsDelta(DeltaTextChunk),
114    #[serde(rename = "response.function_call_arguments.done")]
115    FunctionCallArgsDone(ArgsTextChunk),
116    #[serde(rename = "response.reasoning_summary_part.added")]
117    ReasoningSummaryPartAdded(SummaryPartChunk),
118    #[serde(rename = "response.reasoning_summary_part.done")]
119    ReasoningSummaryPartDone(SummaryPartChunk),
120    #[serde(rename = "response.reasoning_summary_text.added")]
121    ReasoningSummaryTextAdded(SummaryTextChunk),
122    #[serde(rename = "response.reasoning_summary_text.done")]
123    ReasoningSummaryTextDone(SummaryTextChunk),
124}
125
126#[derive(Debug, Serialize, Deserialize, Clone)]
127pub struct StreamingItemDoneOutput {
128    pub sequence_number: u64,
129    pub item: Output,
130}
131
132#[derive(Debug, Serialize, Deserialize, Clone)]
133pub struct ContentPartChunk {
134    pub content_index: u64,
135    pub sequence_number: u64,
136    pub part: ContentPartChunkPart,
137}
138
139#[derive(Debug, Serialize, Deserialize, Clone)]
140#[serde(tag = "type")]
141pub enum ContentPartChunkPart {
142    OutputText { text: String },
143    SummaryText { text: String },
144}
145
146#[derive(Debug, Serialize, Deserialize, Clone)]
147pub struct DeltaTextChunk {
148    pub content_index: u64,
149    pub sequence_number: u64,
150    pub delta: String,
151}
152
153#[derive(Debug, Serialize, Deserialize, Clone)]
154pub struct OutputTextChunk {
155    pub content_index: u64,
156    pub sequence_number: u64,
157    pub text: String,
158}
159
160#[derive(Debug, Serialize, Deserialize, Clone)]
161pub struct RefusalTextChunk {
162    pub content_index: u64,
163    pub sequence_number: u64,
164    pub refusal: String,
165}
166
167#[derive(Debug, Serialize, Deserialize, Clone)]
168pub struct ArgsTextChunk {
169    pub content_index: u64,
170    pub sequence_number: u64,
171    pub arguments: serde_json::Value,
172}
173
174#[derive(Debug, Serialize, Deserialize, Clone)]
175pub struct SummaryPartChunk {
176    pub summary_index: u64,
177    pub sequence_number: u64,
178    pub part: SummaryPartChunkPart,
179}
180
181#[derive(Debug, Serialize, Deserialize, Clone)]
182pub struct SummaryTextChunk {
183    pub summary_index: u64,
184    pub sequence_number: u64,
185    pub delta: String,
186}
187
188#[derive(Debug, Serialize, Deserialize, Clone)]
189#[serde(tag = "type")]
190pub enum SummaryPartChunkPart {
191    SummaryText { text: String },
192}
193
194impl ResponsesCompletionModel {
195    pub(crate) async fn stream(
196        &self,
197        completion_request: crate::completion::CompletionRequest,
198    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
199    {
200        let mut request = self.create_completion_request(completion_request)?;
201        request.stream = Some(true);
202
203        tracing::debug!("Input: {}", serde_json::to_string_pretty(&request)?);
204
205        let builder = self.client.post("/responses").json(&request);
206        send_compatible_streaming_request(builder).await
207    }
208}
209
210/// Send a compatible streaming request.
211/// The following are assumed to already be set:
212/// - The URL to send a POST request to
213/// - The JSON body
214pub async fn send_compatible_streaming_request(
215    request_builder: RequestBuilder,
216) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
217    // Build the request with proper headers for SSE
218    let mut event_source = request_builder
219        .eventsource()
220        .expect("Cloning request must always succeed");
221
222    let stream = Box::pin(stream! {
223        let mut final_usage = ResponsesUsage::new();
224
225        let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
226
227        while let Some(event_result) = event_source.next().await {
228            match event_result {
229                Ok(Event::Open) => {
230                    tracing::trace!("SSE connection opened");
231                    continue;
232                }
233                Ok(Event::Message(message)) => {
234                    // Skip heartbeat messages or empty data
235                    if message.data.trim().is_empty() {
236                        continue;
237                    }
238
239                    let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
240
241                    let Ok(data) = data else {
242                        let err = data.unwrap_err();
243                        debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
244                        continue;
245                    };
246
247                    if let StreamingCompletionChunk::Delta(chunk) = &data {
248                        match &chunk.data {
249                            ItemChunkKind::OutputItemDone(message) => {
250                                match message {
251                                    StreamingItemDoneOutput {  item: Output::FunctionCall(func), .. } => {
252                                        tracing::debug!("Function call received: {func:?}");
253                                        tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
254                                    }
255
256                                    StreamingItemDoneOutput {  item: Output::Reasoning {  summary, id }, .. } => {
257                                        let reasoning = summary
258                                            .iter()
259                                            .map(|x| {
260                                                let ReasoningSummary::SummaryText { text } = x;
261                                                text.to_owned()
262                                            })
263                                            .collect::<Vec<String>>()
264                                            .join("\n");
265                                        yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()) })
266                                    }
267                                    _ => continue
268                                }
269                            }
270                            ItemChunkKind::OutputTextDelta(delta) => {
271                                yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
272                            }
273                            ItemChunkKind::RefusalDelta(delta) => {
274                                yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
275                            }
276
277                            _ => { continue }
278                        }
279                    }
280
281                    if let StreamingCompletionChunk::Response(chunk) = data {
282                        if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
283                            if let Some(usage) = response.usage {
284                                final_usage = usage;
285                            }
286                        } else {
287                            continue;
288                        }
289                    }
290                }
291                Err(reqwest_eventsource::Error::StreamEnded) => {
292                    break;
293                }
294                Err(error) => {
295                    tracing::error!(?error, "SSE error");
296                    yield Err(CompletionError::ResponseError(error.to_string()));
297                    break;
298                }
299            }
300        }
301
302        // Ensure event source is closed when stream ends
303        event_source.close();
304
305        for tool_call in &tool_calls {
306            yield Ok(tool_call.to_owned())
307        }
308
309        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
310            usage: final_usage.clone()
311        }));
312    });
313
314    Ok(streaming::StreamingCompletionResponse::stream(stream))
315}