// rig/providers/openai/responses_api/streaming.rs

1//! The streaming module for the OpenAI Responses API.
2//! Please see the `openai_streaming` or `openai_streaming_with_tools` example for more practical usage.
3use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::providers::openai::responses_api::{
7    ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
8};
9use crate::streaming;
10use crate::streaming::RawStreamingChoice;
11use crate::wasm_compat::WasmCompatSend;
12use async_stream::stream;
13use futures::StreamExt;
14use serde::{Deserialize, Serialize};
15use tracing::{Level, debug, enabled, info_span};
16use tracing_futures::Instrument as _;
17
18use super::{CompletionResponse, Output};
19
20// ================================================================
21// OpenAI Responses Streaming API
22// ================================================================
23
/// A streaming completion chunk.
/// Streaming chunks can come in one of two forms:
/// - A response chunk (where the completed response will have the total token usage)
/// - An item chunk commonly referred to as a delta. In the completions API this would be referred to as the message delta.
#[derive(Debug, Serialize, Deserialize, Clone)]
// `untagged`: deserialization tries variants in declaration order and keeps the
// first one that matches the JSON shape (`Response` first, then `Delta`).
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// A response lifecycle event carrying the full response object.
    /// Boxed to keep this enum small, since `CompletionResponse` is large.
    Response(Box<ResponseChunk>),
    /// A per-output-item event (text delta, tool-call arguments, reasoning, ...).
    Delta(ItemChunk),
}
34
/// The final streaming response from the OpenAI Responses API.
/// Emitted once, as the stream's `FinalResponse` item, after the SSE stream ends.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    /// Token usage, taken from the `response.completed` event (zeroed if the
    /// stream ended without one).
    pub usage: ResponsesUsage,
}
41
42impl GetTokenUsage for StreamingCompletionResponse {
43    fn token_usage(&self) -> Option<crate::completion::Usage> {
44        let mut usage = crate::completion::Usage::new();
45        usage.input_tokens = self.usage.input_tokens;
46        usage.output_tokens = self.usage.output_tokens;
47        usage.total_tokens = self.usage.total_tokens;
48        Some(usage)
49    }
50}
51
/// A response chunk from OpenAI's response API.
/// Carries the full response object at a lifecycle transition
/// (created / in progress / completed / failed / incomplete).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// The response chunk type
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// The response itself
    pub response: CompletionResponse,
    /// The item sequence
    pub sequence_number: u64,
}
63
/// Response chunk type.
/// Renames are used to ensure that this type gets (de)serialized properly:
/// each variant maps to the SSE event `type` string sent by OpenAI.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    #[serde(rename = "response.created")]
    ResponseCreated,
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    /// Terminal success state; only this variant carries final usage data
    /// that the streaming loop records.
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    #[serde(rename = "response.failed")]
    ResponseFailed,
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
79
/// An item message chunk from OpenAI's Responses API: one streamed event
/// scoped to a single output item of the response.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// Item ID. Optional.
    pub item_id: Option<String>,
    /// The output index of the item from a given streamed response.
    pub output_index: u64,
    /// The item type chunk, as well as the inner data.
    /// Flattened so the event's `type` tag and payload fields sit at the
    /// same JSON level as `item_id`/`output_index`.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
92
93/// The item chunk type from OpenAI's Responses API.
94#[derive(Debug, Serialize, Deserialize, Clone)]
95#[serde(tag = "type")]
96pub enum ItemChunkKind {
97    #[serde(rename = "response.output_item.added")]
98    OutputItemAdded(StreamingItemDoneOutput),
99    #[serde(rename = "response.output_item.done")]
100    OutputItemDone(StreamingItemDoneOutput),
101    #[serde(rename = "response.content_part.added")]
102    ContentPartAdded(ContentPartChunk),
103    #[serde(rename = "response.content_part.done")]
104    ContentPartDone(ContentPartChunk),
105    #[serde(rename = "response.output_text.delta")]
106    OutputTextDelta(DeltaTextChunk),
107    #[serde(rename = "response.output_text.done")]
108    OutputTextDone(OutputTextChunk),
109    #[serde(rename = "response.refusal.delta")]
110    RefusalDelta(DeltaTextChunk),
111    #[serde(rename = "response.refusal.done")]
112    RefusalDone(RefusalTextChunk),
113    #[serde(rename = "response.function_call_arguments.delta")]
114    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
115    #[serde(rename = "response.function_call_arguments.done")]
116    FunctionCallArgsDone(ArgsTextChunk),
117    #[serde(rename = "response.reasoning_summary_part.added")]
118    ReasoningSummaryPartAdded(SummaryPartChunk),
119    #[serde(rename = "response.reasoning_summary_part.done")]
120    ReasoningSummaryPartDone(SummaryPartChunk),
121    #[serde(rename = "response.reasoning_summary_text.added")]
122    ReasoningSummaryTextAdded(SummaryTextChunk),
123    #[serde(rename = "response.reasoning_summary_text.done")]
124    ReasoningSummaryTextDone(SummaryTextChunk),
125}
126
/// Payload of the `response.output_item.added` / `.done` events: the output
/// item plus its position in the event sequence.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    /// Position of this event in the stream's event sequence.
    pub sequence_number: u64,
    /// The output item (message, function call, reasoning, ...).
    pub item: Output,
}
132
/// Payload of the `response.content_part.added` / `.done` events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Position of this event in the stream's event sequence.
    pub sequence_number: u64,
    /// The content part itself (output text or summary text).
    pub part: ContentPartChunkPart,
}
139
140#[derive(Debug, Serialize, Deserialize, Clone)]
141#[serde(tag = "type")]
142pub enum ContentPartChunkPart {
143    OutputText { text: String },
144    SummaryText { text: String },
145}
146
/// Payload of text-delta events (`response.output_text.delta`,
/// `response.refusal.delta`): one incremental text fragment.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    /// Index of the content part the fragment belongs to.
    pub content_index: u64,
    /// Position of this event in the stream's event sequence.
    pub sequence_number: u64,
    /// The incremental text fragment.
    pub delta: String,
}
153
/// Payload of `response.function_call_arguments.delta`: like [`DeltaTextChunk`]
/// but with the owning item's ID, used to correlate argument fragments with a
/// specific tool call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    /// ID of the output item (tool call) the fragment belongs to.
    pub item_id: String,
    /// Index of the content part the fragment belongs to.
    pub content_index: u64,
    /// Position of this event in the stream's event sequence.
    pub sequence_number: u64,
    /// The incremental argument-JSON fragment.
    pub delta: String,
}
161
/// Payload of `response.output_text.done`: the fully assembled output text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Position of this event in the stream's event sequence.
    pub sequence_number: u64,
    /// The complete text of the content part.
    pub text: String,
}
168
/// Payload of `response.refusal.done`: the fully assembled refusal message.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Position of this event in the stream's event sequence.
    pub sequence_number: u64,
    /// The complete refusal text.
    pub refusal: String,
}
175
/// Payload of `response.function_call_arguments.done`: the completed tool-call
/// arguments as parsed JSON.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Position of this event in the stream's event sequence.
    pub sequence_number: u64,
    /// The complete function-call arguments.
    pub arguments: serde_json::Value,
}
182
/// Payload of the `response.reasoning_summary_part.added` / `.done` events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    /// Index of the summary entry within the reasoning item.
    pub summary_index: u64,
    /// Position of this event in the stream's event sequence.
    pub sequence_number: u64,
    /// The summary part itself.
    pub part: SummaryPartChunkPart,
}
189
/// Payload of the `response.reasoning_summary_text.*` events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryTextChunk {
    /// Index of the summary entry within the reasoning item.
    pub summary_index: u64,
    /// Position of this event in the stream's event sequence.
    pub sequence_number: u64,
    // NOTE(review): this struct is reused for the `.done` event, which on the
    // wire carries the full text under a `text` key rather than `delta` —
    // verify against the API's event payloads.
    /// The incremental summary-text fragment.
    pub delta: String,
}
196
197#[derive(Debug, Serialize, Deserialize, Clone)]
198#[serde(tag = "type")]
199pub enum SummaryPartChunkPart {
200    SummaryText { text: String },
201}
202
impl<T> ResponsesCompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
{
    /// Execute `completion_request` as a streaming request against the OpenAI
    /// Responses API (`POST /responses`) and adapt the SSE events into rig's
    /// provider-agnostic streaming choices.
    ///
    /// The returned stream yields, in order:
    /// - `Message` chunks for `response.output_text.delta` and
    ///   `response.refusal.delta` events;
    /// - `ToolCallDelta` chunks for streamed function-call argument fragments;
    /// - `Reasoning` when a completed output item is a reasoning block;
    /// - buffered, fully-assembled `ToolCall`s (flushed after the SSE stream
    ///   ends);
    /// - a terminating `FinalResponse` carrying token usage from the
    ///   `response.completed` event.
    ///
    /// # Errors
    /// Fails immediately with `CompletionError` if the request cannot be
    /// serialized or the HTTP request cannot be built; SSE transport errors
    /// occurring mid-stream are yielded as `Err` items through the stream.
    pub(crate) async fn stream(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let mut request = self.create_completion_request(completion_request)?;
        request.stream = Some(true);

        // Pretty-printing the request body is only worth paying for when
        // TRACE logging is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Responses streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the caller's span when one is active so streaming telemetry
        // nests under it; otherwise open a fresh `chat_streaming` span whose
        // gen_ai.* fields are recorded as the response arrives.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        span.record("gen_ai.provider.name", "openai");
        span.record("gen_ai.request.model", &self.model);
        // Clone the client so the generator below can own it for the lifetime
        // of the SSE connection.
        let client = self.client.clone();

        let mut event_source = GenericEventSource::new(client, req);

        let stream = stream! {
            // Usage reported by the final `response.completed` event; stays
            // zeroed if the stream ends without one.
            let mut final_usage = ResponsesUsage::new();

            // Completed tool calls are buffered and flushed only after the SSE
            // stream ends, so their argument deltas are always emitted first.
            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
            // NOTE(review): accumulated but never read afterwards — candidate
            // for removal.
            let mut combined_text = String::new();
            let span = tracing::Span::current();

            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::trace!("SSE connection opened");
                        tracing::info!("OpenAI stream started");
                        continue;
                    }
                    Ok(Event::Message(evt)) => {
                        // Skip heartbeat messages or empty data
                        if evt.data.trim().is_empty() {
                            continue;
                        }

                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);

                        // Events this module doesn't model are skipped rather
                        // than failing the whole stream.
                        let Ok(data) = data else {
                            let err = data.unwrap_err();
                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
                            continue;
                        };

                        if let StreamingCompletionChunk::Delta(chunk) = &data {
                            match &chunk.data {
                                ItemChunkKind::OutputItemDone(message) => {
                                    match message {
                                        // A finished function call: buffer it for the post-loop flush.
                                        StreamingItemDoneOutput {  item: Output::FunctionCall(func), .. } => {
                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
                                        }

                                        // A finished reasoning item: join its summary parts into one string.
                                        StreamingItemDoneOutput {  item: Output::Reasoning {  summary, id }, .. } => {
                                            let reasoning = summary
                                                .iter()
                                                .map(|x| {
                                                    let ReasoningSummary::SummaryText { text } = x;
                                                    text.to_owned()
                                                })
                                                .collect::<Vec<String>>()
                                                .join("\n");
                                            yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()), signature: None })
                                        }
                                        _ => continue
                                    }
                                }
                                ItemChunkKind::OutputTextDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                // Refusals are surfaced as ordinary message text.
                                ItemChunkKind::RefusalDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta { id: delta.item_id.clone(), delta: delta.delta.clone() })
                                }

                                _ => { continue }
                            }
                        }

                        if let StreamingCompletionChunk::Response(chunk) = data {
                            // Only the completed response carries the final
                            // id/model/usage worth recording on the span.
                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
                                span.record("gen_ai.response.id", response.id);
                                span.record("gen_ai.response.model", response.model);
                                if let Some(usage) = response.usage {
                                    final_usage = usage;
                                }
                            } else {
                                continue;
                            }
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        // Normal termination: close the source; the loop exits
                        // once `next()` is exhausted and we flush below.
                        event_source.close();
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

            // Flush buffered, fully-assembled tool calls.
            for tool_call in &tool_calls {
                yield Ok(tool_call.to_owned())
            }

            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
            tracing::info!("OpenAI stream finished");

            // Terminal item: hand the accumulated usage back to the caller.
            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}