// rig/providers/openai/responses_api/streaming.rs

1//! The streaming module for the OpenAI Responses API.
2//! Please see the `openai_streaming` or `openai_streaming_with_tools` example for more practical usage.
3use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::providers::openai::responses_api::{
7    ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
8};
9use crate::streaming;
10use crate::streaming::RawStreamingChoice;
11use async_stream::stream;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info_span};
15use tracing_futures::Instrument as _;
16
17use super::{CompletionResponse, Output};
18
19// ================================================================
20// OpenAI Responses Streaming API
21// ================================================================
22
23/// A streaming completion chunk.
24/// Streaming chunks can come in one of two forms:
25/// - A response chunk (where the completed response will have the total token usage)
26/// - An item chunk commonly referred to as a delta. In the completions API this would be referred to as the message delta.
/// A streaming completion chunk.
/// Streaming chunks can come in one of two forms:
/// - A response chunk (where the completed response will have the total token usage)
/// - An item chunk commonly referred to as a delta. In the completions API this would be referred to as the message delta.
///
/// NOTE: `untagged` means serde tries variants in declaration order —
/// `Response` first, then `Delta` — so keep the order as-is.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// A whole-response lifecycle event (created/in_progress/completed/...).
    /// Boxed because `ResponseChunk` is much larger than `ItemChunk`.
    Response(Box<ResponseChunk>),
    /// An incremental output-item event (text delta, tool-call args, etc.).
    Delta(ItemChunk),
}
33
34/// The final streaming response from the OpenAI Responses API.
/// The final streaming response from the OpenAI Responses API.
/// Yielded once as the `FinalResponse` chunk after the SSE stream ends.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    /// Token usage, taken from the terminal `response.completed` event
    /// (a fresh `ResponsesUsage::new()` if no such event arrived).
    pub usage: ResponsesUsage,
}
40
41impl GetTokenUsage for StreamingCompletionResponse {
42    fn token_usage(&self) -> Option<crate::completion::Usage> {
43        let mut usage = crate::completion::Usage::new();
44        usage.input_tokens = self.usage.input_tokens;
45        usage.output_tokens = self.usage.output_tokens;
46        usage.total_tokens = self.usage.total_tokens;
47        Some(usage)
48    }
49}
50
/// A response chunk from OpenAI's response API.
/// Carries a full response snapshot for lifecycle events such as
/// `response.created` or `response.completed`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// The response chunk type
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// The response itself
    pub response: CompletionResponse,
    /// The item sequence (monotonically increasing per SSE event)
    pub sequence_number: u64,
}
62
/// Response chunk type.
/// Renames are used to ensure that this type gets (de)serialized properly.
/// Each variant maps to the `type` field of a response-lifecycle SSE event.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    /// The response has been created.
    #[serde(rename = "response.created")]
    ResponseCreated,
    /// The response is still being generated.
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    /// Terminal event: generation finished successfully. The stream handler
    /// reads the final output and token usage from this event.
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    /// Terminal event: generation failed.
    #[serde(rename = "response.failed")]
    ResponseFailed,
    /// Terminal event: generation stopped before completing.
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
78
/// An item message chunk from OpenAI's Responses API.
/// See the OpenAI Responses API streaming events reference for the full
/// set of event types this can represent.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// Item ID. Optional.
    pub item_id: Option<String>,
    /// The output index of the item from a given streamed response.
    pub output_index: u64,
    /// The item type chunk, as well as the inner data.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
91
92/// The item chunk type from OpenAI's Responses API.
93#[derive(Debug, Serialize, Deserialize, Clone)]
94#[serde(tag = "type")]
95pub enum ItemChunkKind {
96    #[serde(rename = "response.output_item.added")]
97    OutputItemAdded(StreamingItemDoneOutput),
98    #[serde(rename = "response.output_item.done")]
99    OutputItemDone(StreamingItemDoneOutput),
100    #[serde(rename = "response.content_part.added")]
101    ContentPartAdded(ContentPartChunk),
102    #[serde(rename = "response.content_part.done")]
103    ContentPartDone(ContentPartChunk),
104    #[serde(rename = "response.output_text.delta")]
105    OutputTextDelta(DeltaTextChunk),
106    #[serde(rename = "response.output_text.done")]
107    OutputTextDone(OutputTextChunk),
108    #[serde(rename = "response.refusal.delta")]
109    RefusalDelta(DeltaTextChunk),
110    #[serde(rename = "response.refusal.done")]
111    RefusalDone(RefusalTextChunk),
112    #[serde(rename = "response.function_call_arguments.delta")]
113    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
114    #[serde(rename = "response.function_call_arguments.done")]
115    FunctionCallArgsDone(ArgsTextChunk),
116    #[serde(rename = "response.reasoning_summary_part.added")]
117    ReasoningSummaryPartAdded(SummaryPartChunk),
118    #[serde(rename = "response.reasoning_summary_part.done")]
119    ReasoningSummaryPartDone(SummaryPartChunk),
120    #[serde(rename = "response.reasoning_summary_text.added")]
121    ReasoningSummaryTextAdded(SummaryTextChunk),
122    #[serde(rename = "response.reasoning_summary_text.done")]
123    ReasoningSummaryTextDone(SummaryTextChunk),
124}
125
/// Payload of the `response.output_item.added` / `.done` events: a full
/// output item (message, function call, reasoning, ...) at that point in
/// the stream.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    /// Monotonic event sequence number within the stream.
    pub sequence_number: u64,
    /// The output item itself.
    pub item: Output,
}
131
/// Payload of the `response.content_part.added` / `.done` events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Monotonic event sequence number within the stream.
    pub sequence_number: u64,
    /// The content part data.
    pub part: ContentPartChunkPart,
}
138
139#[derive(Debug, Serialize, Deserialize, Clone)]
140#[serde(tag = "type")]
141pub enum ContentPartChunkPart {
142    OutputText { text: String },
143    SummaryText { text: String },
144}
145
/// An incremental text fragment (`response.output_text.delta` and
/// `response.refusal.delta`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    /// Index of the content part this delta belongs to.
    pub content_index: u64,
    /// Monotonic event sequence number within the stream.
    pub sequence_number: u64,
    /// The new text fragment to append.
    pub delta: String,
}
152
/// An incremental text fragment that also carries the originating item ID
/// (`response.function_call_arguments.delta`), so deltas can be correlated
/// to a specific tool call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    /// ID of the output item (function call) this delta belongs to.
    pub item_id: String,
    /// Index of the content part this delta belongs to.
    pub content_index: u64,
    /// Monotonic event sequence number within the stream.
    pub sequence_number: u64,
    /// The new arguments fragment to append.
    pub delta: String,
}
160
/// Final accumulated text for a content part (`response.output_text.done`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    /// Index of the content part.
    pub content_index: u64,
    /// Monotonic event sequence number within the stream.
    pub sequence_number: u64,
    /// The complete text.
    pub text: String,
}
167
/// Final refusal text (`response.refusal.done`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    /// Index of the content part.
    pub content_index: u64,
    /// Monotonic event sequence number within the stream.
    pub sequence_number: u64,
    /// The complete refusal message.
    pub refusal: String,
}
174
/// Final function-call arguments (`response.function_call_arguments.done`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    /// Index of the content part.
    pub content_index: u64,
    /// Monotonic event sequence number within the stream.
    pub sequence_number: u64,
    /// The finished arguments. NOTE(review): the API documents `arguments`
    /// as a JSON-encoded *string*; `serde_json::Value` accepts that too,
    /// but confirm whether downstream expects a string or a parsed object.
    pub arguments: serde_json::Value,
}
181
/// Payload of the `response.reasoning_summary_part.added` / `.done` events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    /// Index of the summary part within the reasoning item.
    pub summary_index: u64,
    /// Monotonic event sequence number within the stream.
    pub sequence_number: u64,
    /// The summary part data.
    pub part: SummaryPartChunkPart,
}
188
/// Payload of the reasoning summary text events.
/// NOTE(review): the API's `reasoning_summary_text.done` event carries a
/// `text` field rather than `delta`, so deserializing
/// `ReasoningSummaryTextDone` into this struct may fail (the stream handler
/// silently skips undeserializable events) — confirm against the OpenAI
/// streaming events reference.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryTextChunk {
    /// Index of the summary part within the reasoning item.
    pub summary_index: u64,
    /// Monotonic event sequence number within the stream.
    pub sequence_number: u64,
    /// The summary text fragment.
    pub delta: String,
}
195
196#[derive(Debug, Serialize, Deserialize, Clone)]
197#[serde(tag = "type")]
198pub enum SummaryPartChunkPart {
199    SummaryText { text: String },
200}
201
202impl<T> ResponsesCompletionModel<T>
203where
204    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
205{
206    pub(crate) async fn stream(
207        &self,
208        completion_request: crate::completion::CompletionRequest,
209    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
210    {
211        let mut request = self.create_completion_request(completion_request)?;
212        request.stream = Some(true);
213
214        let body = serde_json::to_vec(&request)?;
215
216        let req = self
217            .client
218            .post("/responses")?
219            .header("Content-Type", "application/json")
220            .body(body)
221            .map_err(|e| CompletionError::HttpError(e.into()))?;
222
223        // let request_builder = self.client.post_reqwest("/responses").json(&request);
224
225        let span = if tracing::Span::current().is_disabled() {
226            info_span!(
227                target: "rig::completions",
228                "chat_streaming",
229                gen_ai.operation.name = "chat_streaming",
230                gen_ai.provider.name = tracing::field::Empty,
231                gen_ai.request.model = tracing::field::Empty,
232                gen_ai.response.id = tracing::field::Empty,
233                gen_ai.response.model = tracing::field::Empty,
234                gen_ai.usage.output_tokens = tracing::field::Empty,
235                gen_ai.usage.input_tokens = tracing::field::Empty,
236                gen_ai.input.messages = tracing::field::Empty,
237                gen_ai.output.messages = tracing::field::Empty,
238            )
239        } else {
240            tracing::Span::current()
241        };
242        span.record("gen_ai.provider.name", "openai");
243        span.record("gen_ai.request.model", &self.model);
244        span.record(
245            "gen_ai.input.messages",
246            serde_json::to_string(&request.input).expect("This should always work"),
247        );
248        // Build the request with proper headers for SSE
249        let client = self.clone().client.http_client;
250
251        let mut event_source = GenericEventSource::new(client, req);
252
253        let stream = stream! {
254            let mut final_usage = ResponsesUsage::new();
255
256            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
257            let mut combined_text = String::new();
258            let span = tracing::Span::current();
259
260            while let Some(event_result) = event_source.next().await {
261                match event_result {
262                    Ok(Event::Open) => {
263                        tracing::trace!("SSE connection opened");
264                        tracing::info!("OpenAI stream started");
265                        continue;
266                    }
267                    Ok(Event::Message(evt)) => {
268                        // Skip heartbeat messages or empty data
269                        if evt.data.trim().is_empty() {
270                            continue;
271                        }
272
273                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);
274
275                        let Ok(data) = data else {
276                            let err = data.unwrap_err();
277                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
278                            continue;
279                        };
280
281                        if let StreamingCompletionChunk::Delta(chunk) = &data {
282                            match &chunk.data {
283                                ItemChunkKind::OutputItemDone(message) => {
284                                    match message {
285                                        StreamingItemDoneOutput {  item: Output::FunctionCall(func), .. } => {
286                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
287                                        }
288
289                                        StreamingItemDoneOutput {  item: Output::Reasoning {  summary, id }, .. } => {
290                                            let reasoning = summary
291                                                .iter()
292                                                .map(|x| {
293                                                    let ReasoningSummary::SummaryText { text } = x;
294                                                    text.to_owned()
295                                                })
296                                                .collect::<Vec<String>>()
297                                                .join("\n");
298                                            yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()), signature: None })
299                                        }
300                                        _ => continue
301                                    }
302                                }
303                                ItemChunkKind::OutputTextDelta(delta) => {
304                                    combined_text.push_str(&delta.delta);
305                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
306                                }
307                                ItemChunkKind::RefusalDelta(delta) => {
308                                    combined_text.push_str(&delta.delta);
309                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
310                                }
311                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
312                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta { id: delta.item_id.clone(), delta: delta.delta.clone() })
313                                }
314
315                                _ => { continue }
316                            }
317                        }
318
319                        if let StreamingCompletionChunk::Response(chunk) = data {
320                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
321                                span.record("gen_ai.output.messages", serde_json::to_string(&response.output).unwrap());
322                                span.record("gen_ai.response.id", response.id);
323                                span.record("gen_ai.response.model", response.model);
324                                if let Some(usage) = response.usage {
325                                    final_usage = usage;
326                                }
327                            } else {
328                                continue;
329                            }
330                        }
331                    }
332                    Err(crate::http_client::Error::StreamEnded) => {
333                        event_source.close();
334                    }
335                    Err(error) => {
336                        tracing::error!(?error, "SSE error");
337                        yield Err(CompletionError::ResponseError(error.to_string()));
338                        break;
339                    }
340                }
341            }
342
343            // // Ensure event source is closed when stream ends
344            // event_source.close();
345
346            for tool_call in &tool_calls {
347                yield Ok(tool_call.to_owned())
348            }
349
350            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
351            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
352            tracing::info!("OpenAI stream finished");
353
354            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
355                usage: final_usage.clone()
356            }));
357        }.instrument(span);
358
359        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
360            stream,
361        )))
362    }
363}