1use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::message::ReasoningContent;
7use crate::providers::openai::responses_api::{
8 ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
9};
10use crate::streaming;
11use crate::streaming::RawStreamingChoice;
12use crate::wasm_compat::WasmCompatSend;
13use async_stream::stream;
14use futures::StreamExt;
15use serde::{Deserialize, Serialize};
16use tracing::{Level, debug, enabled, info_span};
17use tracing_futures::Instrument as _;
18
19use super::{CompletionResponse, Output};
20
/// One parsed SSE payload from the OpenAI Responses streaming endpoint.
///
/// The wire events share no common envelope, so deserialization is untagged:
/// serde first tries the response-lifecycle shape ([`ResponseChunk`]), then
/// falls back to the item-level delta shape ([`ItemChunk`]).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// Response-lifecycle event (`response.created`, `response.completed`, ...)
    /// carrying a full response snapshot. Boxed, presumably to keep the enum
    /// small since `CompletionResponse` is much larger than `ItemChunk`.
    Response(Box<ResponseChunk>),
    /// Incremental item-level event: text/refusal/tool-argument deltas,
    /// item added/done notifications, reasoning summary updates.
    Delta(ItemChunk),
}
35
/// Terminal payload of a streamed completion, yielded via
/// `RawStreamingChoice::FinalResponse` once the SSE stream ends.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    /// Token usage reported by the final `response.completed` event
    /// (defaults to `ResponsesUsage::new()` if none was received).
    pub usage: ResponsesUsage,
}
42
43pub(crate) fn reasoning_choices_from_done_item(
44 id: &str,
45 summary: &[ReasoningSummary],
46 encrypted_content: Option<&str>,
47) -> Vec<RawStreamingChoice<StreamingCompletionResponse>> {
48 let mut choices = summary
49 .iter()
50 .map(|reasoning_summary| match reasoning_summary {
51 ReasoningSummary::SummaryText { text } => RawStreamingChoice::Reasoning {
52 id: Some(id.to_owned()),
53 content: ReasoningContent::Summary(text.to_owned()),
54 },
55 })
56 .collect::<Vec<_>>();
57
58 if let Some(encrypted_content) = encrypted_content {
59 choices.push(RawStreamingChoice::Reasoning {
60 id: Some(id.to_owned()),
61 content: ReasoningContent::Encrypted(encrypted_content.to_owned()),
62 });
63 }
64
65 choices
66}
67
68impl GetTokenUsage for StreamingCompletionResponse {
69 fn token_usage(&self) -> Option<crate::completion::Usage> {
70 let mut usage = crate::completion::Usage::new();
71 usage.input_tokens = self.usage.input_tokens;
72 usage.output_tokens = self.usage.output_tokens;
73 usage.total_tokens = self.usage.total_tokens;
74 usage.cached_input_tokens = self
75 .usage
76 .input_tokens_details
77 .as_ref()
78 .map(|d| d.cached_tokens)
79 .unwrap_or(0);
80 Some(usage)
81 }
82}
83
/// A response-lifecycle SSE event: carries a full snapshot of the response
/// at the moment of the transition.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// Which lifecycle transition this event represents (wire field `type`).
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// The response snapshot attached to this event.
    pub response: CompletionResponse,
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
}
95
/// The lifecycle transitions a [`ResponseChunk`] can announce; each variant
/// maps to the corresponding `response.*` SSE event type string.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    /// The response object has been created.
    #[serde(rename = "response.created")]
    ResponseCreated,
    /// The response is being generated.
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    /// The response finished successfully; usage is read from this event.
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    /// The response failed.
    #[serde(rename = "response.failed")]
    ResponseFailed,
    /// The response stopped before completing (e.g. token limit).
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
111
/// An item-level SSE event: an incremental change to one output item.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// Id of the output item this event refers to, when the event carries one.
    pub item_id: Option<String>,
    /// Index of the item within the response's output array.
    pub output_index: u64,
    /// Event-specific payload; the flattened `type` field selects the variant.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
124
/// The item-level SSE event payloads this module understands, internally
/// tagged by the wire `type` field. Variants not handled by the streaming
/// loop (e.g. content-part events) are still deserialized so unknown data
/// does not abort the stream.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ItemChunkKind {
    /// A new output item appeared (used to announce tool-call names early).
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded(StreamingItemDoneOutput),
    /// An output item finished (complete tool calls, reasoning, message ids).
    #[serde(rename = "response.output_item.done")]
    OutputItemDone(StreamingItemDoneOutput),
    /// A content part was added to an item.
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded(ContentPartChunk),
    /// A content part finished.
    #[serde(rename = "response.content_part.done")]
    ContentPartDone(ContentPartChunk),
    /// Incremental assistant text.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta(DeltaTextChunk),
    /// Final full text of an output-text part.
    #[serde(rename = "response.output_text.done")]
    OutputTextDone(OutputTextChunk),
    /// Incremental refusal text.
    #[serde(rename = "response.refusal.delta")]
    RefusalDelta(DeltaTextChunk),
    /// Final full refusal text.
    #[serde(rename = "response.refusal.done")]
    RefusalDone(RefusalTextChunk),
    /// Incremental JSON fragment of a function call's arguments.
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
    /// Final parsed function-call arguments.
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgsDone(ArgsTextChunk),
    /// A reasoning summary part was added.
    #[serde(rename = "response.reasoning_summary_part.added")]
    ReasoningSummaryPartAdded(SummaryPartChunk),
    /// A reasoning summary part finished.
    #[serde(rename = "response.reasoning_summary_part.done")]
    ReasoningSummaryPartDone(SummaryPartChunk),
    /// Incremental reasoning summary text.
    #[serde(rename = "response.reasoning_summary_text.delta")]
    ReasoningSummaryTextDelta(SummaryTextChunk),
    /// Final reasoning summary text.
    #[serde(rename = "response.reasoning_summary_text.done")]
    ReasoningSummaryTextDone(SummaryTextChunk),
}
158
/// Payload of `response.output_item.added` / `.done`: the affected item.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
    /// The output item (message, function call, reasoning, ...).
    pub item: Output,
}

/// Payload of `response.content_part.added` / `.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    /// Index of the part within the item's content array.
    pub content_index: u64,
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
    /// The content part itself.
    pub part: ContentPartChunkPart,
}

/// The content-part shapes carried by [`ContentPartChunk`]; tagged by the
/// snake_case `type` field (`output_text` / `summary_text`).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartChunkPart {
    /// Assistant-visible output text.
    OutputText { text: String },
    /// Reasoning summary text.
    SummaryText { text: String },
}
178
/// A text delta event without an inline item id (`response.output_text.delta`,
/// `response.refusal.delta`); the item id lives on the enclosing [`ItemChunk`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    /// Index of the content part the delta belongs to.
    pub content_index: u64,
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
    /// The incremental text fragment.
    pub delta: String,
}

/// A text delta event that carries its own `item_id`
/// (`response.function_call_arguments.delta`) — used to correlate argument
/// fragments with the function-call item they extend.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    /// Id of the function-call item these argument fragments belong to.
    pub item_id: String,
    /// Index of the content part the delta belongs to.
    pub content_index: u64,
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
    /// The incremental JSON fragment of the arguments.
    pub delta: String,
}

/// Payload of `response.output_text.done`: the complete text of a part.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    /// Index of the content part.
    pub content_index: u64,
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
    /// The final, full text.
    pub text: String,
}

/// Payload of `response.refusal.done`: the complete refusal text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    /// Index of the content part.
    pub content_index: u64,
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
    /// The final refusal message.
    pub refusal: String,
}
207
/// Payload of `response.function_call_arguments.done`: the fully assembled
/// function-call arguments as JSON.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    /// Index of the content part.
    pub content_index: u64,
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
    /// The complete arguments object.
    pub arguments: serde_json::Value,
}

/// Payload of `response.reasoning_summary_part.added` / `.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    /// Index of the summary entry the part belongs to.
    pub summary_index: u64,
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
    /// The summary part itself.
    pub part: SummaryPartChunkPart,
}

/// Payload of `response.reasoning_summary_text.delta` / `.done`.
/// NOTE(review): the `.done` event reuses this shape, so even the final
/// event exposes its text via the `delta` field — confirm against the wire.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryTextChunk {
    /// Index of the summary entry the text belongs to.
    pub summary_index: u64,
    /// Position of this event within the SSE stream.
    pub sequence_number: u64,
    /// The summary text fragment.
    pub delta: String,
}

/// The part shapes carried by [`SummaryPartChunk`]; tagged by the snake_case
/// `type` field (currently only `summary_text`).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SummaryPartChunkPart {
    /// A reasoning summary text part.
    SummaryText { text: String },
}
234
impl<T> ResponsesCompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
{
    /// Sends the completion request with `stream: true` and adapts the SSE
    /// event stream into rig's generic streaming interface.
    ///
    /// Behavior notes (all grounded in the code below):
    /// - SSE payloads that fail to deserialize are logged at DEBUG and
    ///   skipped; they never abort the stream.
    /// - Completed tool calls are buffered and re-emitted only after the SSE
    ///   stream ends, so consumers always receive them after all deltas.
    /// - Usage is captured from the terminal `response.completed` event and
    ///   yielded as the `FinalResponse` choice.
    /// - Refusal deltas are surfaced as plain `Message` chunks — presumably
    ///   intentional (no dedicated refusal choice); confirm with maintainers.
    pub(crate) async fn stream(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse the non-streaming request builder, then switch on streaming.
        let mut request = self.create_completion_request(completion_request)?;
        request.stream = Some(true);

        // Only build the pretty-printed JSON when TRACE is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Responses streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the caller's active span if there is one; otherwise open a
        // fresh span following the gen_ai attribute conventions. Empty fields
        // are recorded later, as values become known.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        span.record("gen_ai.provider.name", "openai");
        span.record("gen_ai.request.model", &self.model);
        let client = self.client.clone();

        let mut event_source = GenericEventSource::new(client, req);

        let stream = stream! {
            // Usage reported by the terminal response.completed event.
            let mut final_usage = ResponsesUsage::new();

            // Completed tool calls, buffered until the stream ends.
            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
            // Maps provider tool-call item ids to locally generated internal
            // ids, so deltas and the final call share one internal id.
            let mut tool_call_internal_ids: std::collections::HashMap<String, String> = std::collections::HashMap::new();
            // Current() here resolves to the span attached via .instrument below.
            let span = tracing::Span::current();

            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::trace!("SSE connection opened");
                        tracing::info!("OpenAI stream started");
                        continue;
                    }
                    Ok(Event::Message(evt)) => {
                        // Keep-alive / empty frames carry no data.
                        if evt.data.trim().is_empty() {
                            continue;
                        }

                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);

                        // Unknown or malformed payloads are skipped, not fatal.
                        let Ok(data) = data else {
                            let err = data.unwrap_err();
                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
                            continue;
                        };

                        if let StreamingCompletionChunk::Delta(chunk) = &data {
                            match &chunk.data {
                                // A new function-call item: announce its name
                                // immediately as a tool-call delta.
                                ItemChunkKind::OutputItemAdded(message) => {
                                    if let StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } = message {
                                        let internal_call_id = tool_call_internal_ids
                                            .entry(func.id.clone())
                                            .or_insert_with(|| nanoid::nanoid!())
                                            .clone();
                                        yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                            id: func.id.clone(),
                                            internal_call_id,
                                            content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                        });
                                    }
                                }
                                ItemChunkKind::OutputItemDone(message) => {
                                    match message {
                                        // Completed tool call: buffer it; it is
                                        // emitted after the stream ends.
                                        StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                            .with_internal_call_id(internal_id)
                                            .with_call_id(func.call_id.clone());
                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall(raw_tool_call));
                                        }

                                        // Completed reasoning item: emit its
                                        // summary entries then encrypted blob.
                                        StreamingItemDoneOutput { item: Output::Reasoning { summary, id, encrypted_content, .. }, .. } => {
                                            for reasoning_choice in reasoning_choices_from_done_item(
                                                id,
                                                summary,
                                                encrypted_content.as_deref(),
                                            ) {
                                                yield Ok(reasoning_choice);
                                            }
                                        }
                                        // Completed message: surface its id.
                                        StreamingItemDoneOutput { item: Output::Message(msg), .. } => {
                                            yield Ok(streaming::RawStreamingChoice::MessageId(msg.id.clone()));
                                        }
                                    }
                                }
                                ItemChunkKind::OutputTextDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                }
                                // NOTE(review): refusal text is forwarded as a
                                // regular Message chunk.
                                ItemChunkKind::RefusalDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                // Argument fragments keyed by the call's item_id.
                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                    let internal_call_id = tool_call_internal_ids
                                        .entry(delta.item_id.clone())
                                        .or_insert_with(|| nanoid::nanoid!())
                                        .clone();
                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                        id: delta.item_id.clone(),
                                        internal_call_id,
                                        content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                    })
                                }

                                // Other item events carry nothing we forward.
                                _ => { continue }
                            }
                        }

                        // Lifecycle events: only response.completed matters —
                        // it carries the final usage and response identifiers.
                        if let StreamingCompletionChunk::Response(chunk) = data {
                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
                                span.record("gen_ai.response.id", response.id);
                                span.record("gen_ai.response.model", response.model);
                                if let Some(usage) = response.usage {
                                    final_usage = usage;
                                }
                            } else {
                                continue;
                            }
                        }
                    }
                    // Normal end of stream: close and let the next poll of
                    // the event source terminate the while-let loop.
                    Err(crate::http_client::Error::StreamEnded) => {
                        event_source.close();
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Flush the buffered, completed tool calls after all deltas.
            for tool_call in &tool_calls {
                yield Ok(tool_call.to_owned())
            }

            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
            span.record(
                "gen_ai.usage.cached_tokens",
                final_usage
                    .input_tokens_details
                    .as_ref()
                    .map(|d| d.cached_tokens)
                    .unwrap_or(0),
            );
            tracing::info!("OpenAI stream finished");

            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
439
#[cfg(test)]
mod tests {
    use super::{ItemChunkKind, StreamingCompletionChunk, reasoning_choices_from_done_item};
    use crate::message::ReasoningContent;
    use crate::providers::openai::responses_api::ReasoningSummary;
    use crate::streaming::RawStreamingChoice;
    use futures::StreamExt;
    // NOTE(review): imports the `rig` facade by name from inside what appears
    // to be the rig crate itself — presumably resolved via a dev-dependency;
    // confirm against Cargo.toml.
    use rig::{client::CompletionClient, providers::openai, streaming::StreamingChat};
    use serde_json::{self, json};

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolError},
    };

    // Minimal no-argument tool used by the ignored live-API integration test.
    struct ExampleTool;

    impl Tool for ExampleTool {
        type Args = ();
        type Error = ToolError;
        type Output = String;
        const NAME: &'static str = "example_tool";

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: self.name(),
                description: "A tool that returns some example text.".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            }
        }

        async fn call(&self, _input: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = "Example answer".to_string();
            Ok(result)
        }
    }

    // Summaries must be emitted in order, with the encrypted blob last.
    #[test]
    fn reasoning_done_item_emits_summary_then_encrypted() {
        let summary = vec![
            ReasoningSummary::SummaryText {
                text: "step 1".to_string(),
            },
            ReasoningSummary::SummaryText {
                text: "step 2".to_string(),
            },
        ];
        let choices = reasoning_choices_from_done_item("rs_1", &summary, Some("enc_blob"));

        assert_eq!(choices.len(), 3);
        assert!(matches!(
            choices.first(),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Summary(text),
            }) if id == "rs_1" && text == "step 1"
        ));
        assert!(matches!(
            choices.get(1),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Summary(text),
            }) if id == "rs_1" && text == "step 2"
        ));
        assert!(matches!(
            choices.get(2),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Encrypted(data),
            }) if id == "rs_1" && data == "enc_blob"
        ));
    }

    // Without encrypted content, only the summary choices are produced.
    #[test]
    fn reasoning_done_item_without_encrypted_emits_summary_only() {
        let summary = vec![ReasoningSummary::SummaryText {
            text: "only summary".to_string(),
        }];
        let choices = reasoning_choices_from_done_item("rs_2", &summary, None);

        assert_eq!(choices.len(), 1);
        assert!(matches!(
            choices.first(),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Summary(text),
            }) if id == "rs_2" && text == "only summary"
        ));
    }

    // Wire-format check: content_part.added with snake_case part type.
    #[test]
    fn content_part_added_deserializes_snake_case_part_type() {
        let chunk: StreamingCompletionChunk = serde_json::from_value(json!({
            "type": "response.content_part.added",
            "item_id": "msg_1",
            "output_index": 0,
            "content_index": 0,
            "sequence_number": 3,
            "part": {
                "type": "output_text",
                "text": "hello"
            }
        }))
        .expect("content part event should deserialize");

        assert!(matches!(
            chunk,
            StreamingCompletionChunk::Delta(chunk)
                if matches!(
                    chunk.data,
                    ItemChunkKind::ContentPartAdded(_)
                )
        ));
    }

    // Wire-format check: content_part.done with a summary_text part.
    #[test]
    fn content_part_done_deserializes_snake_case_part_type() {
        let chunk: StreamingCompletionChunk = serde_json::from_value(json!({
            "type": "response.content_part.done",
            "item_id": "msg_1",
            "output_index": 0,
            "content_index": 0,
            "sequence_number": 4,
            "part": {
                "type": "summary_text",
                "text": "done"
            }
        }))
        .expect("content part done event should deserialize");

        assert!(matches!(
            chunk,
            StreamingCompletionChunk::Delta(chunk)
                if matches!(
                    chunk.data,
                    ItemChunkKind::ContentPartDone(_)
                )
        ));
    }

    // Wire-format check: reasoning_summary_part.added events.
    #[test]
    fn reasoning_summary_part_added_deserializes_snake_case_part_type() {
        let chunk: StreamingCompletionChunk = serde_json::from_value(json!({
            "type": "response.reasoning_summary_part.added",
            "item_id": "rs_1",
            "output_index": 0,
            "summary_index": 0,
            "sequence_number": 5,
            "part": {
                "type": "summary_text",
                "text": "step 1"
            }
        }))
        .expect("reasoning summary part event should deserialize");

        assert!(matches!(
            chunk,
            StreamingCompletionChunk::Delta(chunk)
                if matches!(
                    chunk.data,
                    ItemChunkKind::ReasoningSummaryPartAdded(_)
                )
        ));
    }

    // Wire-format check: reasoning_summary_part.done events.
    #[test]
    fn reasoning_summary_part_done_deserializes_snake_case_part_type() {
        let chunk: StreamingCompletionChunk = serde_json::from_value(json!({
            "type": "response.reasoning_summary_part.done",
            "item_id": "rs_1",
            "output_index": 0,
            "summary_index": 0,
            "sequence_number": 6,
            "part": {
                "type": "summary_text",
                "text": "step 2"
            }
        }))
        .expect("reasoning summary part done event should deserialize");

        assert!(matches!(
            chunk,
            StreamingCompletionChunk::Delta(chunk)
                if matches!(
                    chunk.data,
                    ItemChunkKind::ReasoningSummaryPartDone(_)
                )
        ));
    }

    // Live-API smoke test: streams a multi-turn chat that exercises tool
    // calling and reasoning. Ignored by default; needs OPENAI_API_KEY.
    #[tokio::test]
    #[ignore = "requires API key"]
    async fn test_openai_streaming_tools_reasoning() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY env var should exist");
        let client = openai::Client::new(&api_key).expect("Failed to build client");
        let agent = client
            .agent("gpt-5.2")
            .max_tokens(8192)
            .tool(ExampleTool)
            .additional_params(serde_json::json!({
                "reasoning": {"effort": "high"}
            }))
            .build();

        let chat_history = Vec::new();
        let mut stream = agent
            .stream_chat("Call my example tool", chat_history)
            .multi_turn(5)
            .await;

        while let Some(item) = stream.next().await {
            println!("Got item: {item:?}");
        }
    }
}