
rig/providers/openai/responses_api/streaming.rs

//! The streaming module for the OpenAI Responses API.
//! Please see the `openai_streaming` or `openai_streaming_with_tools` examples for practical usage.
use crate::completion::{CompletionError, GetTokenUsage};
use crate::http_client::HttpClientExt;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::providers::openai::responses_api::{
    ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
};
use crate::streaming;
use crate::streaming::RawStreamingChoice;
use crate::wasm_compat::WasmCompatSend;
use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::{Level, debug, enabled, info_span};
use tracing_futures::Instrument as _;

use super::{CompletionResponse, Output};

// ================================================================
// OpenAI Responses Streaming API
// ================================================================

/// A streaming completion chunk.
/// Streaming chunks come in one of two forms:
/// - A response chunk (the completed response carries the total token usage)
/// - An item chunk, commonly referred to as a delta. In the Completions API this would be called the message delta.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    Response(Box<ResponseChunk>),
    Delta(ItemChunk),
}
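
// A minimal sketch of how the untagged enum resolves (hypothetical payload,
// not a verbatim API frame): serde tries `Response` first, which requires a
// full `response` object, and falls back to `Delta` for item-level events.
#[cfg(test)]
mod chunk_shape_test {
    use super::StreamingCompletionChunk;

    #[test]
    fn item_level_event_parses_as_delta() {
        let raw = r#"{
            "type": "response.output_text.delta",
            "item_id": "msg_abc123",
            "output_index": 0,
            "content_index": 0,
            "sequence_number": 3,
            "delta": "Hello"
        }"#;
        let chunk: StreamingCompletionChunk =
            serde_json::from_str(raw).expect("should parse as a delta chunk");
        assert!(matches!(chunk, StreamingCompletionChunk::Delta(_)));
    }
}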

/// The final streaming response from the OpenAI Responses API.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    /// Token usage
    pub usage: ResponsesUsage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = self.usage.input_tokens;
        usage.output_tokens = self.usage.output_tokens;
        usage.total_tokens = self.usage.total_tokens;
        usage.cached_input_tokens = self
            .usage
            .input_tokens_details
            .as_ref()
            .map(|d| d.cached_tokens)
            .unwrap_or(0);
        Some(usage)
    }
}
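
// A quick sanity check of the usage mapping above: a sketch that only mirrors
// whatever counters `ResponsesUsage::new()` yields, so it makes no assumption
// about their concrete values.
#[cfg(test)]
mod usage_mapping_test {
    use super::StreamingCompletionResponse;
    use crate::completion::GetTokenUsage;
    use crate::providers::openai::responses_api::ResponsesUsage;

    #[test]
    fn token_usage_mirrors_provider_counters() {
        let usage = ResponsesUsage::new();
        let response = StreamingCompletionResponse { usage: usage.clone() };
        let mapped = response.token_usage().expect("token_usage is always Some");
        assert_eq!(mapped.input_tokens, usage.input_tokens);
        assert_eq!(mapped.output_tokens, usage.output_tokens);
        assert_eq!(mapped.total_tokens, usage.total_tokens);
    }
}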

/// A response chunk from OpenAI's Responses API.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// The response chunk type
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// The response itself
    pub response: CompletionResponse,
    /// The item sequence number
    pub sequence_number: u64,
}

/// Response chunk type.
/// Renames are used to ensure that this type gets (de)serialized properly.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    #[serde(rename = "response.created")]
    ResponseCreated,
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    #[serde(rename = "response.failed")]
    ResponseFailed,
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
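
// A small round-trip sketch of the renames: each variant (de)serializes to
// its `response.*` event-type string on the wire.
#[cfg(test)]
mod response_chunk_kind_test {
    use super::ResponseChunkKind;

    #[test]
    fn kind_serializes_to_wire_tag() {
        let tag = serde_json::to_string(&ResponseChunkKind::ResponseCompleted)
            .expect("serialization should succeed");
        assert_eq!(tag, r#""response.completed""#);
    }
}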

/// An item message chunk from OpenAI's Responses API.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// Item ID. Optional.
    pub item_id: Option<String>,
    /// The output index of the item from a given streamed response.
    pub output_index: u64,
    /// The item chunk type, along with its inner data.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}

/// The item chunk type from OpenAI's Responses API.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ItemChunkKind {
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded(StreamingItemDoneOutput),
    #[serde(rename = "response.output_item.done")]
    OutputItemDone(StreamingItemDoneOutput),
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded(ContentPartChunk),
    #[serde(rename = "response.content_part.done")]
    ContentPartDone(ContentPartChunk),
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta(DeltaTextChunk),
    #[serde(rename = "response.output_text.done")]
    OutputTextDone(OutputTextChunk),
    #[serde(rename = "response.refusal.delta")]
    RefusalDelta(DeltaTextChunk),
    #[serde(rename = "response.refusal.done")]
    RefusalDone(RefusalTextChunk),
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgsDone(ArgsTextChunk),
    #[serde(rename = "response.reasoning_summary_part.added")]
    ReasoningSummaryPartAdded(SummaryPartChunk),
    #[serde(rename = "response.reasoning_summary_part.done")]
    ReasoningSummaryPartDone(SummaryPartChunk),
    #[serde(rename = "response.reasoning_summary_text.delta")]
    ReasoningSummaryTextDelta(SummaryTextChunk),
    #[serde(rename = "response.reasoning_summary_text.done")]
    ReasoningSummaryTextDone(SummaryTextChunk),
}
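
// A sketch of how the `type` tag combines with `#[serde(flatten)]` on
// `ItemChunk`: the outer struct takes `item_id` and `output_index`, and the
// remaining keys select and fill the `ItemChunkKind` variant. The payload is
// illustrative, not a verbatim API frame.
#[cfg(test)]
mod item_chunk_dispatch_test {
    use super::{ItemChunk, ItemChunkKind};

    #[test]
    fn output_text_done_selects_expected_variant() {
        let raw = r#"{
            "type": "response.output_text.done",
            "item_id": "msg_abc123",
            "output_index": 0,
            "content_index": 0,
            "sequence_number": 7,
            "text": "Hello world"
        }"#;
        let chunk: ItemChunk = serde_json::from_str(raw).expect("should deserialize");
        match chunk.data {
            ItemChunkKind::OutputTextDone(done) => assert_eq!(done.text, "Hello world"),
            other => panic!("unexpected variant: {other:?}"),
        }
    }
}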

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    pub sequence_number: u64,
    pub item: Output,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub part: ContentPartChunkPart,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ContentPartChunkPart {
    OutputText { text: String },
    SummaryText { text: String },
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub delta: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    pub item_id: String,
    pub content_index: u64,
    pub sequence_number: u64,
    pub delta: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub text: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub refusal: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub arguments: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    pub summary_index: u64,
    pub sequence_number: u64,
    pub part: SummaryPartChunkPart,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryTextChunk {
    pub summary_index: u64,
    pub sequence_number: u64,
    pub delta: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum SummaryPartChunkPart {
    SummaryText { text: String },
}

impl<T> ResponsesCompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
{
    pub(crate) async fn stream(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let mut request = self.create_completion_request(completion_request)?;
        request.stream = Some(true);

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Responses streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        span.record("gen_ai.provider.name", "openai");
        span.record("gen_ai.request.model", &self.model);

        // Clone the client so the event source can own it for the duration of the stream
        let client = self.client.clone();

        let mut event_source = GenericEventSource::new(client, req);

        let stream = stream! {
            let mut final_usage = ResponsesUsage::new();

            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
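            // Correlate provider tool-call ids with stable internal ids so that
            // argument deltas and the finished call can be stitched back together.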
            let mut tool_call_internal_ids: std::collections::HashMap<String, String> = std::collections::HashMap::new();
            let mut combined_text = String::new();
            let span = tracing::Span::current();

            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::trace!("SSE connection opened");
                        tracing::info!("OpenAI stream started");
                        continue;
                    }
                    Ok(Event::Message(evt)) => {
                        // Skip heartbeat messages or empty data
                        if evt.data.trim().is_empty() {
                            continue;
                        }

                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);

                        let Ok(data) = data else {
                            let err = data.unwrap_err();
                            debug!("Couldn't deserialize data as StreamingCompletionChunk: {:?}", err);
                            continue;
                        };

                        if let StreamingCompletionChunk::Delta(chunk) = &data {
                            match &chunk.data {
                                ItemChunkKind::OutputItemAdded(message) => {
                                    if let StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } = message {
                                        let internal_call_id = tool_call_internal_ids
                                            .entry(func.id.clone())
                                            .or_insert_with(|| nanoid::nanoid!())
                                            .clone();
                                        yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                            id: func.id.clone(),
                                            internal_call_id,
                                            content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                        });
                                    }
                                }
                                ItemChunkKind::OutputItemDone(message) => {
                                    match message {
                                        StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                                .with_internal_call_id(internal_id)
                                                .with_call_id(func.call_id.clone());
                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall(raw_tool_call));
                                        }

                                        StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. } => {
                                            let reasoning = summary
                                                .iter()
                                                .map(|x| {
                                                    let ReasoningSummary::SummaryText { text } = x;
                                                    text.to_owned()
                                                })
                                                .collect::<Vec<String>>()
                                                .join("\n");
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: Some(id.to_string()),
                                                reasoning,
                                                signature: None,
                                            })
                                        }
                                        _ => continue
                                    }
                                }
                                ItemChunkKind::OutputTextDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                }
                                ItemChunkKind::RefusalDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                    let internal_call_id = tool_call_internal_ids
                                        .entry(delta.item_id.clone())
                                        .or_insert_with(|| nanoid::nanoid!())
                                        .clone();
                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                        id: delta.item_id.clone(),
                                        internal_call_id,
                                        content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                    })
                                }

                                _ => { continue }
                            }
                        }

                        if let StreamingCompletionChunk::Response(chunk) = data {
                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
                                span.record("gen_ai.response.id", response.id);
                                span.record("gen_ai.response.model", response.model);
                                if let Some(usage) = response.usage {
                                    final_usage = usage;
                                }
                            } else {
                                continue;
                            }
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        event_source.close();
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure the event source is closed when the stream ends
            event_source.close();

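            // Tool calls were buffered while streaming; emit them only now that
            // their arguments are complete.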
            for tool_call in &tool_calls {
                yield Ok(tool_call.to_owned())
            }

            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
            tracing::info!("OpenAI stream finished");

            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use rig::{client::CompletionClient, providers::openai, streaming::StreamingChat};
    use serde_json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolError},
    };

    struct ExampleTool;

    impl Tool for ExampleTool {
        type Args = ();
        type Error = ToolError;
        type Output = String;
        const NAME: &'static str = "example_tool";

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: self.name(),
                description: "A tool that returns some example text.".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            }
        }

        async fn call(&self, _input: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok("Example answer".to_string())
        }
    }

    #[tokio::test]
    #[ignore = "requires API key"]
    async fn test_openai_streaming_tools_reasoning() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY env var should exist");
        let client: openai::Client<rig::http_client::ReqwestClient> =
            openai::Client::new(&api_key).expect("Failed to build client");
        let agent = client
            .agent("gpt-5.2")
            .max_tokens(8192)
            .tool(ExampleTool)
            .additional_params(serde_json::json!({
                "reasoning": {"effort": "high"}
            }))
            .build();

        let chat_history = Vec::new();
        let mut stream = agent
            .stream_chat("Call my example tool", chat_history)
            .multi_turn(5)
            .await;

        while let Some(item) = stream.next().await {
            println!("Got item: {item:?}");
        }
    }
}