1use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::message::ReasoningContent;
7use crate::providers::openai::responses_api::{
8 ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
9};
10use crate::streaming;
11use crate::streaming::RawStreamingChoice;
12use crate::wasm_compat::WasmCompatSend;
13use async_stream::stream;
14use futures::StreamExt;
15use serde::{Deserialize, Serialize};
16use tracing::{Level, debug, enabled, info_span};
17use tracing_futures::Instrument as _;
18
19use super::{CompletionResponse, Output};
20
/// One deserialized SSE payload from the OpenAI Responses streaming API.
///
/// Untagged: serde tries the variants in order, first as a response-level
/// lifecycle event ([`ResponseChunk`]) and otherwise as a per-item delta
/// ([`ItemChunk`]). `Response` is boxed because [`ResponseChunk`] embeds a
/// full [`CompletionResponse`], which is large relative to `ItemChunk`.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// Whole-response lifecycle event (created / in-progress / completed / …).
    Response(Box<ResponseChunk>),
    /// Incremental event for a single output item.
    Delta(ItemChunk),
}
35
/// Final aggregate yielded at the end of a stream.
///
/// Carries the token usage captured from the terminal `response.completed`
/// event (see `stream()` below, where `final_usage` is recorded).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    // Token usage reported by the API for the whole response.
    pub usage: ResponsesUsage,
}
42
43pub(crate) fn reasoning_choices_from_done_item(
44 id: &str,
45 summary: &[ReasoningSummary],
46 encrypted_content: Option<&str>,
47) -> Vec<RawStreamingChoice<StreamingCompletionResponse>> {
48 let mut choices = summary
49 .iter()
50 .map(|reasoning_summary| match reasoning_summary {
51 ReasoningSummary::SummaryText { text } => RawStreamingChoice::Reasoning {
52 id: Some(id.to_owned()),
53 content: ReasoningContent::Summary(text.to_owned()),
54 },
55 })
56 .collect::<Vec<_>>();
57
58 if let Some(encrypted_content) = encrypted_content {
59 choices.push(RawStreamingChoice::Reasoning {
60 id: Some(id.to_owned()),
61 content: ReasoningContent::Encrypted(encrypted_content.to_owned()),
62 });
63 }
64
65 choices
66}
67
68impl GetTokenUsage for StreamingCompletionResponse {
69 fn token_usage(&self) -> Option<crate::completion::Usage> {
70 let mut usage = crate::completion::Usage::new();
71 usage.input_tokens = self.usage.input_tokens;
72 usage.output_tokens = self.usage.output_tokens;
73 usage.total_tokens = self.usage.total_tokens;
74 usage.cached_input_tokens = self
75 .usage
76 .input_tokens_details
77 .as_ref()
78 .map(|d| d.cached_tokens)
79 .unwrap_or(0);
80 Some(usage)
81 }
82}
83
/// A response-level lifecycle event (e.g. `response.created`,
/// `response.completed`) carrying a full snapshot of the response.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    // Wire field is `type`; renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    // Full response snapshot as of this event; `stream()` reads the id,
    // model and usage off the `response.completed` snapshot.
    pub response: CompletionResponse,
    // Event sequence number assigned by the API.
    pub sequence_number: u64,
}
95
/// Discriminates the response-level lifecycle events; the strings match the
/// `type` values on the wire.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    #[serde(rename = "response.created")]
    ResponseCreated,
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    /// Terminal success event — the only variant `stream()` acts on.
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    #[serde(rename = "response.failed")]
    ResponseFailed,
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
111
/// An incremental event scoped to a single output item of the response.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    // Id of the output item this event belongs to; not present on every
    // event kind, hence `Option`.
    pub item_id: Option<String>,
    // Index of the item within the response's output array.
    pub output_index: u64,
    // The event's `type` tag and payload are flattened into this struct,
    // so `ItemChunkKind`'s internal tag comes from the same JSON object.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
124
/// Per-item streaming events, internally tagged by the wire-level `type`
/// field (e.g. `response.output_text.delta`).
///
/// `stream()` below acts on item added/done, text/refusal/summary deltas and
/// function-call argument deltas; the remaining variants are parsed but
/// currently skipped (`_ => continue`).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ItemChunkKind {
    /// A new output item appeared; for function calls this announces the
    /// tool name before its argument deltas.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded(StreamingItemDoneOutput),
    /// An output item finished; carries the item's final state.
    #[serde(rename = "response.output_item.done")]
    OutputItemDone(StreamingItemDoneOutput),
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded(ContentPartChunk),
    #[serde(rename = "response.content_part.done")]
    ContentPartDone(ContentPartChunk),
    /// Incremental assistant text.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta(DeltaTextChunk),
    #[serde(rename = "response.output_text.done")]
    OutputTextDone(OutputTextChunk),
    #[serde(rename = "response.refusal.delta")]
    RefusalDelta(DeltaTextChunk),
    #[serde(rename = "response.refusal.done")]
    RefusalDone(RefusalTextChunk),
    /// Incremental JSON for a function call's arguments.
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgsDone(ArgsTextChunk),
    #[serde(rename = "response.reasoning_summary_part.added")]
    ReasoningSummaryPartAdded(SummaryPartChunk),
    #[serde(rename = "response.reasoning_summary_part.done")]
    ReasoningSummaryPartDone(SummaryPartChunk),
    /// Incremental reasoning-summary text.
    #[serde(rename = "response.reasoning_summary_text.delta")]
    ReasoningSummaryTextDelta(SummaryTextChunk),
    #[serde(rename = "response.reasoning_summary_text.done")]
    ReasoningSummaryTextDone(SummaryTextChunk),
}
158
/// Payload of `response.output_item.added` / `response.output_item.done`:
/// the item itself (message, function call or reasoning) plus the event's
/// sequence number.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    pub sequence_number: u64,
    // The output item in its state as of this event.
    pub item: Output,
}
164
/// Payload of `response.content_part.added` / `response.content_part.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    // Index of the part within the item's content array.
    pub content_index: u64,
    pub sequence_number: u64,
    pub part: ContentPartChunkPart,
}
171
172#[derive(Debug, Serialize, Deserialize, Clone)]
173#[serde(tag = "type")]
174pub enum ContentPartChunkPart {
175 OutputText { text: String },
176 SummaryText { text: String },
177}
178
/// Generic text delta used by `response.output_text.delta` and
/// `response.refusal.delta` events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    // The newly appended text fragment.
    pub delta: String,
}
185
/// Text delta that additionally carries the owning item's id; used for
/// `response.function_call_arguments.delta`, where the id correlates the
/// delta with its tool call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    // Id of the function-call item these argument bytes belong to.
    pub item_id: String,
    pub content_index: u64,
    pub sequence_number: u64,
    // The newly appended fragment of the arguments JSON.
    pub delta: String,
}
193
/// Payload of `response.output_text.done`: the complete final text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub text: String,
}
200
/// Payload of `response.refusal.done`: the complete refusal message.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub refusal: String,
}
207
/// Payload of `response.function_call_arguments.done`: the complete parsed
/// arguments value.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub arguments: serde_json::Value,
}
214
/// Payload of `response.reasoning_summary_part.added` /
/// `response.reasoning_summary_part.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    // Index of this part within the reasoning item's summary array.
    pub summary_index: u64,
    pub sequence_number: u64,
    pub part: SummaryPartChunkPart,
}
221
/// Payload of `response.reasoning_summary_text.delta` / `.done`.
///
/// NOTE(review): the `.done` event reuses this struct, so its final text also
/// arrives in a field named `delta` — confirm against the wire format.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryTextChunk {
    pub summary_index: u64,
    pub sequence_number: u64,
    pub delta: String,
}
228
229#[derive(Debug, Serialize, Deserialize, Clone)]
230#[serde(tag = "type")]
231pub enum SummaryPartChunkPart {
232 SummaryText { text: String },
233}
234
impl<T> ResponsesCompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
{
    /// Performs a streaming completion request against the OpenAI Responses
    /// API (`POST /responses`) and adapts the SSE event stream into rig's
    /// streaming choice stream.
    ///
    /// Behavior notes (all visible in the body below):
    /// - Text, reasoning-summary and refusal deltas are yielded as they
    ///   arrive; completed tool calls are buffered and only replayed after
    ///   the SSE stream ends.
    /// - Token usage is taken from the terminal `response.completed` event
    ///   and emitted last as the `FinalResponse` choice.
    /// - Chunks that fail to deserialize are logged at debug level and
    ///   skipped rather than failing the stream.
    ///
    /// # Errors
    /// Returns a [`CompletionError`] if the request cannot be built or
    /// serialized; transport/SSE failures surface as `Err` items inside the
    /// returned stream.
    pub(crate) async fn stream(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let mut request = self.create_completion_request(completion_request)?;
        // Opt in to server-sent events.
        request.stream = Some(true);

        // Guarded so the pretty-printing cost is only paid when TRACE is on.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Responses streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the caller's span when one is active; otherwise open a fresh
        // span following the gen_ai naming conventions.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        span.record("gen_ai.provider.name", "openai");
        span.record("gen_ai.request.model", &self.model);
        let client = self.client.clone();

        let mut event_source = GenericEventSource::new(client, req);

        let stream = stream! {
            // Usage captured from the terminal `response.completed` event.
            let mut final_usage = ResponsesUsage::new();

            // Completed tool calls are buffered here and replayed after the
            // SSE stream ends, once all deltas have been yielded.
            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
            // Maps the provider's tool-call item id to a locally generated
            // internal id so deltas and the finished call stay correlated.
            let mut tool_call_internal_ids: std::collections::HashMap<String, String> = std::collections::HashMap::new();
            let span = tracing::Span::current();

            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::trace!("SSE connection opened");
                        tracing::info!("OpenAI stream started");
                        continue;
                    }
                    Ok(Event::Message(evt)) => {
                        // Skip keep-alive / empty frames.
                        if evt.data.trim().is_empty() {
                            continue;
                        }

                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);

                        // Unknown or malformed chunks are skipped, not fatal.
                        let Ok(data) = data else {
                            let err = data.unwrap_err();
                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
                            continue;
                        };

                        if let StreamingCompletionChunk::Delta(chunk) = &data {
                            match &chunk.data {
                                ItemChunkKind::OutputItemAdded(message) => {
                                    // A new function-call item announces the tool's
                                    // name before any argument deltas arrive.
                                    if let StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } = message {
                                        let internal_call_id = tool_call_internal_ids
                                            .entry(func.id.clone())
                                            .or_insert_with(|| nanoid::nanoid!())
                                            .clone();
                                        yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                            id: func.id.clone(),
                                            internal_call_id,
                                            content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                        });
                                    }
                                }
                                ItemChunkKind::OutputItemDone(message) => {
                                    match message {
                                        StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
                                            // Completed call: buffer for emission after
                                            // the stream ends (see replay loop below).
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                            .with_internal_call_id(internal_id)
                                            .with_call_id(func.call_id.clone());
                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall(raw_tool_call));
                                        }

                                        StreamingItemDoneOutput { item: Output::Reasoning { summary, id, encrypted_content, .. }, .. } => {
                                            // One choice per summary entry, then the
                                            // encrypted payload (if any).
                                            for reasoning_choice in reasoning_choices_from_done_item(
                                                id,
                                                summary,
                                                encrypted_content.as_deref(),
                                            ) {
                                                yield Ok(reasoning_choice);
                                            }
                                        }
                                        StreamingItemDoneOutput { item: Output::Message(msg), .. } => {
                                            yield Ok(streaming::RawStreamingChoice::MessageId(msg.id.clone()));
                                        }
                                    }
                                }
                                ItemChunkKind::OutputTextDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                }
                                ItemChunkKind::RefusalDelta(delta) => {
                                    // NOTE(review): refusal deltas are surfaced as plain
                                    // message text — confirm this is intended.
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                    let internal_call_id = tool_call_internal_ids
                                        .entry(delta.item_id.clone())
                                        .or_insert_with(|| nanoid::nanoid!())
                                        .clone();
                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                        id: delta.item_id.clone(),
                                        internal_call_id,
                                        content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                    })
                                }

                                // Remaining event kinds are parsed but not surfaced.
                                _ => { continue }
                            }
                        }

                        if let StreamingCompletionChunk::Response(chunk) = data {
                            // Only the terminal event carries the authoritative
                            // response id/model and token usage.
                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
                                span.record("gen_ai.response.id", response.id);
                                span.record("gen_ai.response.model", response.model);
                                if let Some(usage) = response.usage {
                                    final_usage = usage;
                                }
                            } else {
                                continue;
                            }
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        // Normal termination signalled by the transport.
                        event_source.close();
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Replay buffered tool calls now that all deltas have been seen.
            for tool_call in &tool_calls {
                yield Ok(tool_call.to_owned())
            }

            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
            tracing::info!("OpenAI stream finished");

            // The final item carries the aggregate usage for the request.
            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
430
#[cfg(test)]
mod tests {
    use super::reasoning_choices_from_done_item;
    use crate::message::ReasoningContent;
    use crate::providers::openai::responses_api::ReasoningSummary;
    use crate::streaming::RawStreamingChoice;
    use futures::StreamExt;
    // NOTE(review): imports `rig` as an external crate alongside `crate::`
    // paths — presumably relies on a dev-dependency on the crate itself;
    // confirm against Cargo.toml.
    use rig::{client::CompletionClient, providers::openai, streaming::StreamingChat};
    use serde_json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolError},
    };

    // Minimal tool used by the (ignored) live streaming test below; takes no
    // arguments and returns a fixed string.
    struct ExampleTool;

    impl Tool for ExampleTool {
        type Args = ();
        type Error = ToolError;
        type Output = String;
        const NAME: &'static str = "example_tool";

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: self.name(),
                description: "A tool that returns some example text.".to_string(),
                // Empty-object schema: the tool accepts no parameters.
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            }
        }

        async fn call(&self, _input: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = "Example answer".to_string();
            Ok(result)
        }
    }

    // Summaries must be emitted in input order, with the encrypted payload
    // appended last.
    #[test]
    fn reasoning_done_item_emits_summary_then_encrypted() {
        let summary = vec![
            ReasoningSummary::SummaryText {
                text: "step 1".to_string(),
            },
            ReasoningSummary::SummaryText {
                text: "step 2".to_string(),
            },
        ];
        let choices = reasoning_choices_from_done_item("rs_1", &summary, Some("enc_blob"));

        // Two summaries + one encrypted choice.
        assert_eq!(choices.len(), 3);
        assert!(matches!(
            choices.first(),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Summary(text),
            }) if id == "rs_1" && text == "step 1"
        ));
        assert!(matches!(
            choices.get(1),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Summary(text),
            }) if id == "rs_1" && text == "step 2"
        ));
        assert!(matches!(
            choices.get(2),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Encrypted(data),
            }) if id == "rs_1" && data == "enc_blob"
        ));
    }

    // With no encrypted payload, only the summary choices are produced.
    #[test]
    fn reasoning_done_item_without_encrypted_emits_summary_only() {
        let summary = vec![ReasoningSummary::SummaryText {
            text: "only summary".to_string(),
        }];
        let choices = reasoning_choices_from_done_item("rs_2", &summary, None);

        assert_eq!(choices.len(), 1);
        assert!(matches!(
            choices.first(),
            Some(RawStreamingChoice::Reasoning {
                id: Some(id),
                content: ReasoningContent::Summary(text),
            }) if id == "rs_2" && text == "only summary"
        ));
    }

    // Live end-to-end smoke test against the real API; ignored by default
    // and only prints the streamed items rather than asserting on them.
    #[tokio::test]
    #[ignore = "requires API key"]
    async fn test_openai_streaming_tools_reasoning() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY env var should exist");
        let client: openai::Client<rig::http_client::ReqwestClient> =
            openai::Client::new(&api_key).expect("Failed to build client");
        let agent = client
            .agent("gpt-5.2")
            .max_tokens(8192)
            .tool(ExampleTool)
            .additional_params(serde_json::json!({
                "reasoning": {"effort": "high"}
            }))
            .build();

        let chat_history = Vec::new();
        let mut stream = agent
            .stream_chat("Call my example tool", chat_history)
            .multi_turn(5)
            .await;

        while let Some(item) = stream.next().await {
            println!("Got item: {item:?}");
        }
    }
}