Skip to main content

rig/providers/openai/responses_api/
streaming.rs

1//! The streaming module for the OpenAI Responses API.
2//! Please see the `openai_streaming` or `openai_streaming_with_tools` example for more practical usage.
3use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::message::ReasoningContent;
7use crate::providers::openai::responses_api::{
8    ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
9};
10use crate::streaming;
11use crate::streaming::RawStreamingChoice;
12use crate::wasm_compat::WasmCompatSend;
13use async_stream::stream;
14use futures::StreamExt;
15use serde::{Deserialize, Serialize};
16use tracing::{Level, debug, enabled, info_span};
17use tracing_futures::Instrument as _;
18
19use super::{CompletionResponse, Output};
20
21// ================================================================
22// OpenAI Responses Streaming API
23// ================================================================
24
/// A streaming completion chunk.
/// Streaming chunks can come in one of two forms:
/// - A response chunk (where the completed response will have the total token usage)
/// - An item chunk commonly referred to as a delta. In the completions API this would be referred to as the message delta.
///
/// Deserialization is untagged: a payload is first tried as a [`ResponseChunk`]
/// (which requires a `response` field), then as an [`ItemChunk`].
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// A response-lifecycle event carrying a full response snapshot.
    /// Boxed because [`ResponseChunk`] is much larger than [`ItemChunk`].
    Response(Box<ResponseChunk>),
    /// An incremental output-item event (text/tool-call/reasoning deltas).
    Delta(ItemChunk),
}
35
/// The final streaming response from the OpenAI Responses API.
/// Emitted once at the end of the stream so callers can read the total token usage.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    /// Token usage
    pub usage: ResponsesUsage,
}
42
43pub(crate) fn reasoning_choices_from_done_item(
44    id: &str,
45    summary: &[ReasoningSummary],
46    encrypted_content: Option<&str>,
47) -> Vec<RawStreamingChoice<StreamingCompletionResponse>> {
48    let mut choices = summary
49        .iter()
50        .map(|reasoning_summary| match reasoning_summary {
51            ReasoningSummary::SummaryText { text } => RawStreamingChoice::Reasoning {
52                id: Some(id.to_owned()),
53                content: ReasoningContent::Summary(text.to_owned()),
54            },
55        })
56        .collect::<Vec<_>>();
57
58    if let Some(encrypted_content) = encrypted_content {
59        choices.push(RawStreamingChoice::Reasoning {
60            id: Some(id.to_owned()),
61            content: ReasoningContent::Encrypted(encrypted_content.to_owned()),
62        });
63    }
64
65    choices
66}
67
68impl GetTokenUsage for StreamingCompletionResponse {
69    fn token_usage(&self) -> Option<crate::completion::Usage> {
70        let mut usage = crate::completion::Usage::new();
71        usage.input_tokens = self.usage.input_tokens;
72        usage.output_tokens = self.usage.output_tokens;
73        usage.total_tokens = self.usage.total_tokens;
74        usage.cached_input_tokens = self
75            .usage
76            .input_tokens_details
77            .as_ref()
78            .map(|d| d.cached_tokens)
79            .unwrap_or(0);
80        Some(usage)
81    }
82}
83
/// A response chunk from OpenAI's response API.
/// Carries a full snapshot of the response at a lifecycle boundary
/// (created / in-progress / completed / failed / incomplete).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// The response chunk type
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// The response itself
    pub response: CompletionResponse,
    /// The item sequence
    // NOTE(review): presumably monotonically increasing per stream — not relied on here.
    pub sequence_number: u64,
}
95
/// Response chunk type.
/// Renames are used to ensure that this type gets (de)serialized properly.
/// These mirror the response-lifecycle SSE event names; only
/// `response.completed` is acted upon by the streaming loop in this module.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    #[serde(rename = "response.created")]
    ResponseCreated,
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    #[serde(rename = "response.failed")]
    ResponseFailed,
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
111
/// An item message chunk from OpenAI's Responses API.
/// See the OpenAI Responses API streaming-events reference for the full list
/// of event payloads this can carry.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// Item ID. Optional.
    pub item_id: Option<String>,
    /// The output index of the item from a given streamed response.
    pub output_index: u64,
    /// The item type chunk, as well as the inner data.
    /// Flattened so the event's `type` tag and payload fields sit at this level.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
124
/// The item chunk type from OpenAI's Responses API.
/// Internally tagged on the event's `type` field; each rename matches the
/// provider's dotted SSE event name. Variants not handled by the streaming
/// loop (part-added/done markers, `*.done` text events) are parsed but skipped.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ItemChunkKind {
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded(StreamingItemDoneOutput),
    #[serde(rename = "response.output_item.done")]
    OutputItemDone(StreamingItemDoneOutput),
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded(ContentPartChunk),
    #[serde(rename = "response.content_part.done")]
    ContentPartDone(ContentPartChunk),
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta(DeltaTextChunk),
    #[serde(rename = "response.output_text.done")]
    OutputTextDone(OutputTextChunk),
    #[serde(rename = "response.refusal.delta")]
    RefusalDelta(DeltaTextChunk),
    #[serde(rename = "response.refusal.done")]
    RefusalDone(RefusalTextChunk),
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgsDone(ArgsTextChunk),
    #[serde(rename = "response.reasoning_summary_part.added")]
    ReasoningSummaryPartAdded(SummaryPartChunk),
    #[serde(rename = "response.reasoning_summary_part.done")]
    ReasoningSummaryPartDone(SummaryPartChunk),
    #[serde(rename = "response.reasoning_summary_text.delta")]
    ReasoningSummaryTextDelta(SummaryTextChunk),
    #[serde(rename = "response.reasoning_summary_text.done")]
    ReasoningSummaryTextDone(SummaryTextChunk),
}
158
/// Payload for `response.output_item.added` / `response.output_item.done`:
/// a complete output item (message, function call, or reasoning block).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    /// Event ordering counter within the stream.
    pub sequence_number: u64,
    /// The output item itself.
    pub item: Output,
}
164
/// Payload for `response.content_part.added` / `response.content_part.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    /// Index of the content part within its parent item.
    pub content_index: u64,
    /// Event ordering counter within the stream.
    pub sequence_number: u64,
    /// The content part that was added or completed.
    pub part: ContentPartChunkPart,
}
171
/// The inner part of a content-part event, tagged on `type`
/// (`output_text` or `summary_text`).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartChunkPart {
    /// Assistant output text.
    OutputText { text: String },
    /// Reasoning summary text.
    SummaryText { text: String },
}
178
/// A plain text delta (used for `response.output_text.delta` and
/// `response.refusal.delta` events).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    /// Index of the content part this delta belongs to.
    pub content_index: u64,
    /// Event ordering counter within the stream.
    pub sequence_number: u64,
    /// The text fragment to append.
    pub delta: String,
}
185
/// A text delta that also carries its item id (used for
/// `response.function_call_arguments.delta`, where the id is needed to
/// correlate argument fragments with their tool-call item).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    /// Id of the output item the delta belongs to.
    pub item_id: String,
    /// Index of the content part this delta belongs to.
    pub content_index: u64,
    /// Event ordering counter within the stream.
    pub sequence_number: u64,
    /// The text fragment to append.
    pub delta: String,
}
193
/// Payload for `response.output_text.done`: the complete output text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    /// Index of the content part the text belongs to.
    pub content_index: u64,
    /// Event ordering counter within the stream.
    pub sequence_number: u64,
    /// The full accumulated text.
    pub text: String,
}
200
/// Payload for `response.refusal.done`: the complete refusal message.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    /// Index of the content part the refusal belongs to.
    pub content_index: u64,
    /// Event ordering counter within the stream.
    pub sequence_number: u64,
    /// The full refusal text.
    pub refusal: String,
}
207
/// Payload for `response.function_call_arguments.done`: the finalized
/// tool-call arguments.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    /// Index of the content part the arguments belong to.
    pub content_index: u64,
    /// Event ordering counter within the stream.
    pub sequence_number: u64,
    /// The complete arguments.
    // NOTE(review): the API documents `arguments` as a JSON-encoded string;
    // `Value` accepts both that and structured JSON — confirm intended shape.
    pub arguments: serde_json::Value,
}
214
/// Payload for `response.reasoning_summary_part.added` / `.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    /// Index of the summary entry this part belongs to.
    pub summary_index: u64,
    /// Event ordering counter within the stream.
    pub sequence_number: u64,
    /// The summary part that was added or completed.
    pub part: SummaryPartChunkPart,
}
221
222#[derive(Debug, Serialize, Deserialize, Clone)]
223pub struct SummaryTextChunk {
224    pub summary_index: u64,
225    pub sequence_number: u64,
226    pub delta: String,
227}
228
/// The inner part of a reasoning-summary-part event, tagged on `type`.
/// Currently only `summary_text` is defined.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SummaryPartChunkPart {
    SummaryText { text: String },
}
234
impl<T> ResponsesCompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
{
    /// Run a streaming completion against OpenAI's `/responses` endpoint.
    ///
    /// Sets `stream: true` on the request, opens an SSE connection, and maps
    /// provider events onto [`RawStreamingChoice`] values:
    /// - output-text / refusal / reasoning-summary deltas are yielded as they arrive;
    /// - tool calls are announced via `ToolCallDelta` items while streaming and
    ///   buffered as complete `ToolCall`s, which are re-emitted once the SSE
    ///   stream has ended;
    /// - the `response.completed` event supplies the final usage totals, which
    ///   are recorded on the tracing span and yielded as the `FinalResponse`.
    ///
    /// # Errors
    /// Fails up front if the request cannot be serialized or the HTTP request
    /// cannot be built. Mid-stream SSE errors are yielded as
    /// [`CompletionError::ProviderError`] items and terminate the loop.
    pub(crate) async fn stream(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let mut request = self.create_completion_request(completion_request)?;
        request.stream = Some(true);

        // Pretty-printing the body is not free; only do it when TRACE logging
        // is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Responses streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the caller's span when one is active; otherwise open a fresh
        // GenAI-convention span whose fields are filled in as the stream runs.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        span.record("gen_ai.provider.name", "openai");
        span.record("gen_ai.request.model", &self.model);
        // Build the request with proper headers for SSE
        let client = self.client.clone();

        let mut event_source = GenericEventSource::new(client, req);

        let stream = stream! {
            let mut final_usage = ResponsesUsage::new();

            // Completed tool calls are buffered and re-emitted after the SSE
            // stream ends, so they always follow the streamed deltas.
            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
            // Maps the provider's item id to a locally generated id so that
            // deltas and the finished call for one tool invocation share an id.
            let mut tool_call_internal_ids: std::collections::HashMap<String, String> = std::collections::HashMap::new();
            let span = tracing::Span::current();

            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::trace!("SSE connection opened");
                        tracing::info!("OpenAI stream started");
                        continue;
                    }
                    Ok(Event::Message(evt)) => {
                        // Skip heartbeat messages or empty data
                        if evt.data.trim().is_empty() {
                            continue;
                        }

                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);

                        // Unknown or malformed events are logged and skipped
                        // rather than aborting the whole stream.
                        let Ok(data) = data else {
                            let err = data.unwrap_err();
                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
                            continue;
                        };

                        if let StreamingCompletionChunk::Delta(chunk) = &data {
                            match &chunk.data {
                                // A new function-call item: announce the tool's
                                // name so consumers can render the call early.
                                ItemChunkKind::OutputItemAdded(message) => {
                                    if let StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } = message {
                                        let internal_call_id = tool_call_internal_ids
                                            .entry(func.id.clone())
                                            .or_insert_with(|| nanoid::nanoid!())
                                            .clone();
                                        yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                            id: func.id.clone(),
                                            internal_call_id,
                                            content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                        });
                                    }
                                }
                                ItemChunkKind::OutputItemDone(message) => {
                                    match message {
                                        // Completed tool call: buffer it for
                                        // emission after the stream finishes.
                                        StreamingItemDoneOutput {  item: Output::FunctionCall(func), .. } => {
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                                .with_internal_call_id(internal_id)
                                                .with_call_id(func.call_id.clone());
                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall(raw_tool_call));
                                        }

                                        // Completed reasoning item: emit its summary
                                        // entries and any encrypted payload.
                                        StreamingItemDoneOutput {  item: Output::Reasoning {  summary, id, encrypted_content, .. }, .. } => {
                                            for reasoning_choice in reasoning_choices_from_done_item(
                                                id,
                                                summary,
                                                encrypted_content.as_deref(),
                                            ) {
                                                yield Ok(reasoning_choice);
                                            }
                                        }
                                        StreamingItemDoneOutput { item: Output::Message(msg), .. } => {
                                            yield Ok(streaming::RawStreamingChoice::MessageId(msg.id.clone()));
                                        }
                                    }
                                }
                                ItemChunkKind::OutputTextDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                }
                                // NOTE(review): refusal deltas are surfaced as plain
                                // message text; there is no dedicated refusal choice.
                                ItemChunkKind::RefusalDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                    let internal_call_id = tool_call_internal_ids
                                        .entry(delta.item_id.clone())
                                        .or_insert_with(|| nanoid::nanoid!())
                                        .clone();
                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                        id: delta.item_id.clone(),
                                        internal_call_id,
                                        content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                    })
                                }

                                // Remaining events (content-part markers, `*.done`
                                // text events, ...) add no information here.
                                _ => { continue }
                            }
                        }

                        // Only the completed response carries the authoritative
                        // usage totals and response metadata.
                        if let StreamingCompletionChunk::Response(chunk) = data {
                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
                                span.record("gen_ai.response.id", response.id);
                                span.record("gen_ai.response.model", response.model);
                                if let Some(usage) = response.usage {
                                    final_usage = usage;
                                }
                            } else {
                                continue;
                            }
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        // Normal termination: close the source and let the loop
                        // end when `next()` yields `None`.
                        event_source.close();
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

            // Flush the buffered, fully-assembled tool calls.
            for tool_call in &tool_calls {
                yield Ok(tool_call.to_owned())
            }

            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
            span.record(
                "gen_ai.usage.cached_tokens",
                final_usage
                    .input_tokens_details
                    .as_ref()
                    .map(|d| d.cached_tokens)
                    .unwrap_or(0),
            );
            tracing::info!("OpenAI stream finished");

            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
439
#[cfg(test)]
mod tests {
    use super::{ItemChunkKind, StreamingCompletionChunk, reasoning_choices_from_done_item};
    use crate::message::ReasoningContent;
    use crate::providers::openai::responses_api::ReasoningSummary;
    use crate::streaming::RawStreamingChoice;
    use futures::StreamExt;
    // NOTE(review): this mixes `crate::` and `rig::` paths — presumably the
    // crate is importable under the `rig` name in tests; confirm this builds.
    use rig::{client::CompletionClient, providers::openai, streaming::StreamingChat};
    use serde_json::{self, json};

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolError},
    };

    // Minimal no-argument tool used by the ignored live-API streaming test.
    struct ExampleTool;

    impl Tool for ExampleTool {
        type Args = ();
        type Error = ToolError;
        type Output = String;
        const NAME: &'static str = "example_tool";

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: self.name(),
                description: "A tool that returns some example text.".to_string(),
                parameters: serde_json::json!({
                        "type": "object",
                        "properties": {},
                        "required": []
                }),
            }
        }

        async fn call(&self, _input: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = "Example answer".to_string();
            Ok(result)
        }
    }

    // Summary entries must come out in order, with the encrypted payload last.
    #[test]
    fn reasoning_done_item_emits_summary_then_encrypted() {
        let summary = vec![
            ReasoningSummary::SummaryText {
                text: "step 1".to_string(),
            },
            ReasoningSummary::SummaryText {
                text: "step 2".to_string(),
            },
        ];
        let choices = reasoning_choices_from_done_item("rs_1", &summary, Some("enc_blob"));

        assert_eq!(choices.len(), 3);
        assert!(matches!(
            choices.first(),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Summary(text),
            }) if id == "rs_1" && text == "step 1"
        ));
        assert!(matches!(
            choices.get(1),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Summary(text),
            }) if id == "rs_1" && text == "step 2"
        ));
        assert!(matches!(
            choices.get(2),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Encrypted(data),
            }) if id == "rs_1" && data == "enc_blob"
        ));
    }

    // Without encrypted content, only the summary choices are produced.
    #[test]
    fn reasoning_done_item_without_encrypted_emits_summary_only() {
        let summary = vec![ReasoningSummary::SummaryText {
            text: "only summary".to_string(),
        }];
        let choices = reasoning_choices_from_done_item("rs_2", &summary, None);

        assert_eq!(choices.len(), 1);
        assert!(matches!(
            choices.first(),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Summary(text),
            }) if id == "rs_2" && text == "only summary"
        ));
    }

    #[test]
    fn content_part_added_deserializes_snake_case_part_type() {
        let chunk: StreamingCompletionChunk = serde_json::from_value(json!({
            "type": "response.content_part.added",
            "item_id": "msg_1",
            "output_index": 0,
            "content_index": 0,
            "sequence_number": 3,
            "part": {
                "type": "output_text",
                "text": "hello"
            }
        }))
        .expect("content part event should deserialize");

        assert!(matches!(
            chunk,
            StreamingCompletionChunk::Delta(chunk)
                if matches!(
                    chunk.data,
                    ItemChunkKind::ContentPartAdded(_)
                )
        ));
    }

    #[test]
    fn content_part_done_deserializes_snake_case_part_type() {
        let chunk: StreamingCompletionChunk = serde_json::from_value(json!({
            "type": "response.content_part.done",
            "item_id": "msg_1",
            "output_index": 0,
            "content_index": 0,
            "sequence_number": 4,
            "part": {
                "type": "summary_text",
                "text": "done"
            }
        }))
        .expect("content part done event should deserialize");

        assert!(matches!(
            chunk,
            StreamingCompletionChunk::Delta(chunk)
                if matches!(
                    chunk.data,
                    ItemChunkKind::ContentPartDone(_)
                )
        ));
    }

    #[test]
    fn reasoning_summary_part_added_deserializes_snake_case_part_type() {
        let chunk: StreamingCompletionChunk = serde_json::from_value(json!({
            "type": "response.reasoning_summary_part.added",
            "item_id": "rs_1",
            "output_index": 0,
            "summary_index": 0,
            "sequence_number": 5,
            "part": {
                "type": "summary_text",
                "text": "step 1"
            }
        }))
        .expect("reasoning summary part event should deserialize");

        assert!(matches!(
            chunk,
            StreamingCompletionChunk::Delta(chunk)
                if matches!(
                    chunk.data,
                    ItemChunkKind::ReasoningSummaryPartAdded(_)
                )
        ));
    }

    #[test]
    fn reasoning_summary_part_done_deserializes_snake_case_part_type() {
        let chunk: StreamingCompletionChunk = serde_json::from_value(json!({
            "type": "response.reasoning_summary_part.done",
            "item_id": "rs_1",
            "output_index": 0,
            "summary_index": 0,
            "sequence_number": 6,
            "part": {
                "type": "summary_text",
                "text": "step 2"
            }
        }))
        .expect("reasoning summary part done event should deserialize");

        assert!(matches!(
            chunk,
            StreamingCompletionChunk::Delta(chunk)
                if matches!(
                    chunk.data,
                    ItemChunkKind::ReasoningSummaryPartDone(_)
                )
        ));
    }

    // requires `derive` rig-core feature due to using tool macro
    // Live smoke test against the real API; run manually with OPENAI_API_KEY set.
    #[tokio::test]
    #[ignore = "requires API key"]
    async fn test_openai_streaming_tools_reasoning() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY env var should exist");
        let client = openai::Client::new(&api_key).expect("Failed to build client");
        let agent = client
            .agent("gpt-5.2")
            .max_tokens(8192)
            .tool(ExampleTool)
            .additional_params(serde_json::json!({
                "reasoning": {"effort": "high"}
            }))
            .build();

        let chat_history = Vec::new();
        let mut stream = agent
            .stream_chat("Call my example tool", chat_history)
            .multi_turn(5)
            .await;

        while let Some(item) = stream.next().await {
            println!("Got item: {item:?}");
        }
    }
}