// rig/providers/openai/responses_api/streaming.rs

1//! The streaming module for the OpenAI Responses API.
2//! Please see the `openai_streaming` or `openai_streaming_with_tools` example for more practical usage.
3use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::providers::openai::responses_api::{
7    ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
8};
9use crate::streaming;
10use crate::streaming::RawStreamingChoice;
11use crate::wasm_compat::WasmCompatSend;
12use async_stream::stream;
13use futures::StreamExt;
14use serde::{Deserialize, Serialize};
15use tracing::{debug, info_span};
16use tracing_futures::Instrument as _;
17
18use super::{CompletionResponse, Output};
19
20// ================================================================
21// OpenAI Responses Streaming API
22// ================================================================
23
/// A streaming completion chunk.
/// Streaming chunks can come in one of two forms:
/// - A response chunk (where the completed response will have the total token usage)
/// - An item chunk commonly referred to as a delta. In the completions API this would be referred to as the message delta.
#[derive(Debug, Serialize, Deserialize, Clone)]
// Untagged: serde tries the variants in declaration order (`Response`
// first, then `Delta`), so the order below is load-bearing.
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// A whole-response lifecycle event (created / in_progress / completed / ...).
    Response(Box<ResponseChunk>),
    /// A per-item event: text deltas, tool-call arguments, reasoning summaries.
    Delta(ItemChunk),
}
34
/// The final streaming response from the OpenAI Responses API.
/// Yielded once as the very last item of the stream, after all deltas and
/// buffered tool calls.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    /// Token usage, taken from the `response.completed` event when present.
    pub usage: ResponsesUsage,
}
41
42impl GetTokenUsage for StreamingCompletionResponse {
43    fn token_usage(&self) -> Option<crate::completion::Usage> {
44        let mut usage = crate::completion::Usage::new();
45        usage.input_tokens = self.usage.input_tokens;
46        usage.output_tokens = self.usage.output_tokens;
47        usage.total_tokens = self.usage.total_tokens;
48        Some(usage)
49    }
50}
51
/// A response chunk from OpenAI's response API.
/// Wraps a full [`CompletionResponse`] snapshot for response-lifecycle events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// The response chunk type
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// The response itself
    pub response: CompletionResponse,
    /// The item sequence
    pub sequence_number: u64,
}
63
/// Response chunk type.
/// Renames are used to ensure that this type gets (de)serialized properly.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    /// The response has been created.
    #[serde(rename = "response.created")]
    ResponseCreated,
    /// The response is still being generated.
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    /// The response finished; this event carries the final output and usage.
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    /// The response failed.
    #[serde(rename = "response.failed")]
    ResponseFailed,
    /// The response ended without completing.
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
79
/// An item message chunk from OpenAI's Responses API.
/// See the OpenAI Responses API streaming event reference for the event shapes.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// Item ID. Optional.
    pub item_id: Option<String>,
    /// The output index of the item from a given streamed response.
    pub output_index: u64,
    /// The item type chunk, as well as the inner data.
    /// Flattened so the event's `type` tag and payload sit at the same JSON level.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
92
93/// The item chunk type from OpenAI's Responses API.
94#[derive(Debug, Serialize, Deserialize, Clone)]
95#[serde(tag = "type")]
96pub enum ItemChunkKind {
97    #[serde(rename = "response.output_item.added")]
98    OutputItemAdded(StreamingItemDoneOutput),
99    #[serde(rename = "response.output_item.done")]
100    OutputItemDone(StreamingItemDoneOutput),
101    #[serde(rename = "response.content_part.added")]
102    ContentPartAdded(ContentPartChunk),
103    #[serde(rename = "response.content_part.done")]
104    ContentPartDone(ContentPartChunk),
105    #[serde(rename = "response.output_text.delta")]
106    OutputTextDelta(DeltaTextChunk),
107    #[serde(rename = "response.output_text.done")]
108    OutputTextDone(OutputTextChunk),
109    #[serde(rename = "response.refusal.delta")]
110    RefusalDelta(DeltaTextChunk),
111    #[serde(rename = "response.refusal.done")]
112    RefusalDone(RefusalTextChunk),
113    #[serde(rename = "response.function_call_arguments.delta")]
114    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
115    #[serde(rename = "response.function_call_arguments.done")]
116    FunctionCallArgsDone(ArgsTextChunk),
117    #[serde(rename = "response.reasoning_summary_part.added")]
118    ReasoningSummaryPartAdded(SummaryPartChunk),
119    #[serde(rename = "response.reasoning_summary_part.done")]
120    ReasoningSummaryPartDone(SummaryPartChunk),
121    #[serde(rename = "response.reasoning_summary_text.added")]
122    ReasoningSummaryTextAdded(SummaryTextChunk),
123    #[serde(rename = "response.reasoning_summary_text.done")]
124    ReasoningSummaryTextDone(SummaryTextChunk),
125}
126
/// Payload of `response.output_item.added` / `response.output_item.done`:
/// a fully-formed output item (message, function call, reasoning, ...).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    /// The event's sequence number within the stream.
    pub sequence_number: u64,
    /// The output item itself.
    pub item: Output,
}
132
/// Payload of `response.content_part.added` / `response.content_part.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    /// Index of the content part within the output item.
    pub content_index: u64,
    /// The event's sequence number within the stream.
    pub sequence_number: u64,
    /// The content part itself.
    pub part: ContentPartChunkPart,
}
139
140#[derive(Debug, Serialize, Deserialize, Clone)]
141#[serde(tag = "type")]
142pub enum ContentPartChunkPart {
143    OutputText { text: String },
144    SummaryText { text: String },
145}
146
/// Payload of text delta events (`response.output_text.delta`,
/// `response.refusal.delta`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    /// Index of the content part this delta belongs to.
    pub content_index: u64,
    /// The event's sequence number within the stream.
    pub sequence_number: u64,
    /// The incremental text fragment.
    pub delta: String,
}
153
/// Payload of `response.function_call_arguments.delta`: like
/// [`DeltaTextChunk`] but also identifies which tool-call item the
/// argument fragment belongs to.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    /// ID of the function-call item this delta belongs to.
    pub item_id: String,
    /// Index of the content part this delta belongs to.
    pub content_index: u64,
    /// The event's sequence number within the stream.
    pub sequence_number: u64,
    /// The incremental arguments fragment.
    pub delta: String,
}
161
/// Payload of `response.output_text.done`: the complete, final text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    /// Index of the content part within the output item.
    pub content_index: u64,
    /// The event's sequence number within the stream.
    pub sequence_number: u64,
    /// The complete output text.
    pub text: String,
}
168
/// Payload of `response.refusal.done`: the complete refusal message.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    /// Index of the content part within the output item.
    pub content_index: u64,
    /// The event's sequence number within the stream.
    pub sequence_number: u64,
    /// The complete refusal text.
    pub refusal: String,
}
175
/// Payload of `response.function_call_arguments.done`: the completed
/// tool-call arguments.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    /// Index of the content part within the output item.
    pub content_index: u64,
    /// The event's sequence number within the stream.
    pub sequence_number: u64,
    /// The finalized arguments as a JSON value.
    pub arguments: serde_json::Value,
}
182
/// Payload of `response.reasoning_summary_part.added` / `.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    /// Index of the summary part within the reasoning item.
    pub summary_index: u64,
    /// The event's sequence number within the stream.
    pub sequence_number: u64,
    /// The summary part itself.
    pub part: SummaryPartChunkPart,
}
189
190#[derive(Debug, Serialize, Deserialize, Clone)]
191pub struct SummaryTextChunk {
192    pub summary_index: u64,
193    pub sequence_number: u64,
194    pub delta: String,
195}
196
197#[derive(Debug, Serialize, Deserialize, Clone)]
198#[serde(tag = "type")]
199pub enum SummaryPartChunkPart {
200    SummaryText { text: String },
201}
202
203impl<T> ResponsesCompletionModel<T>
204where
205    T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
206{
207    pub(crate) async fn stream(
208        &self,
209        completion_request: crate::completion::CompletionRequest,
210    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
211    {
212        let mut request = self.create_completion_request(completion_request)?;
213        request.stream = Some(true);
214
215        let body = serde_json::to_vec(&request)?;
216
217        let req = self
218            .client
219            .post("/responses")?
220            .body(body)
221            .map_err(|e| CompletionError::HttpError(e.into()))?;
222
223        // let request_builder = self.client.post_reqwest("/responses").json(&request);
224
225        let span = if tracing::Span::current().is_disabled() {
226            info_span!(
227                target: "rig::completions",
228                "chat_streaming",
229                gen_ai.operation.name = "chat_streaming",
230                gen_ai.provider.name = tracing::field::Empty,
231                gen_ai.request.model = tracing::field::Empty,
232                gen_ai.response.id = tracing::field::Empty,
233                gen_ai.response.model = tracing::field::Empty,
234                gen_ai.usage.output_tokens = tracing::field::Empty,
235                gen_ai.usage.input_tokens = tracing::field::Empty,
236                gen_ai.input.messages = tracing::field::Empty,
237                gen_ai.output.messages = tracing::field::Empty,
238            )
239        } else {
240            tracing::Span::current()
241        };
242        span.record("gen_ai.provider.name", "openai");
243        span.record("gen_ai.request.model", &self.model);
244        span.record(
245            "gen_ai.input.messages",
246            serde_json::to_string(&request.input).expect("This should always work"),
247        );
248        // Build the request with proper headers for SSE
249        let client = self.client.http_client().clone();
250
251        let mut event_source = GenericEventSource::new(client, req);
252
253        let stream = stream! {
254            let mut final_usage = ResponsesUsage::new();
255
256            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
257            let mut combined_text = String::new();
258            let span = tracing::Span::current();
259
260            while let Some(event_result) = event_source.next().await {
261                match event_result {
262                    Ok(Event::Open) => {
263                        tracing::trace!("SSE connection opened");
264                        tracing::info!("OpenAI stream started");
265                        continue;
266                    }
267                    Ok(Event::Message(evt)) => {
268                        // Skip heartbeat messages or empty data
269                        if evt.data.trim().is_empty() {
270                            continue;
271                        }
272
273                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);
274
275                        let Ok(data) = data else {
276                            let err = data.unwrap_err();
277                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
278                            continue;
279                        };
280
281                        if let StreamingCompletionChunk::Delta(chunk) = &data {
282                            match &chunk.data {
283                                ItemChunkKind::OutputItemDone(message) => {
284                                    match message {
285                                        StreamingItemDoneOutput {  item: Output::FunctionCall(func), .. } => {
286                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
287                                        }
288
289                                        StreamingItemDoneOutput {  item: Output::Reasoning {  summary, id }, .. } => {
290                                            let reasoning = summary
291                                                .iter()
292                                                .map(|x| {
293                                                    let ReasoningSummary::SummaryText { text } = x;
294                                                    text.to_owned()
295                                                })
296                                                .collect::<Vec<String>>()
297                                                .join("\n");
298                                            yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()), signature: None })
299                                        }
300                                        _ => continue
301                                    }
302                                }
303                                ItemChunkKind::OutputTextDelta(delta) => {
304                                    combined_text.push_str(&delta.delta);
305                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
306                                }
307                                ItemChunkKind::RefusalDelta(delta) => {
308                                    combined_text.push_str(&delta.delta);
309                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
310                                }
311                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
312                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta { id: delta.item_id.clone(), delta: delta.delta.clone() })
313                                }
314
315                                _ => { continue }
316                            }
317                        }
318
319                        if let StreamingCompletionChunk::Response(chunk) = data {
320                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
321                                span.record("gen_ai.output.messages", serde_json::to_string(&response.output).unwrap());
322                                span.record("gen_ai.response.id", response.id);
323                                span.record("gen_ai.response.model", response.model);
324                                if let Some(usage) = response.usage {
325                                    final_usage = usage;
326                                }
327                            } else {
328                                continue;
329                            }
330                        }
331                    }
332                    Err(crate::http_client::Error::StreamEnded) => {
333                        event_source.close();
334                    }
335                    Err(error) => {
336                        tracing::error!(?error, "SSE error");
337                        yield Err(CompletionError::ProviderError(error.to_string()));
338                        break;
339                    }
340                }
341            }
342
343            // Ensure event source is closed when stream ends
344            event_source.close();
345
346            for tool_call in &tool_calls {
347                yield Ok(tool_call.to_owned())
348            }
349
350            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
351            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
352            tracing::info!("OpenAI stream finished");
353
354            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
355                usage: final_usage
356            }));
357        }.instrument(span);
358
359        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
360            stream,
361        )))
362    }
363}