1use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::providers::openai::responses_api::{
7 ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
8};
9use crate::streaming;
10use crate::streaming::RawStreamingChoice;
11use crate::wasm_compat::WasmCompatSend;
12use async_stream::stream;
13use futures::StreamExt;
14use serde::{Deserialize, Serialize};
15use tracing::{Level, debug, enabled, info_span};
16use tracing_futures::Instrument as _;
17
18use super::{CompletionResponse, Output};
19
/// A single SSE payload received from the Responses API stream.
///
/// `untagged`: serde tries each variant in declaration order, so
/// response-lifecycle events (which carry a full `response` object) are
/// attempted before the per-item delta shape.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// Response-lifecycle event (`response.created`, `response.completed`, ...).
    /// Boxed because `CompletionResponse` is a large payload.
    Response(Box<ResponseChunk>),
    /// Incremental per-output-item event (text/tool-call/reasoning deltas).
    Delta(ItemChunk),
}
34
/// Final aggregate yielded once the SSE stream ends; carries the token
/// usage captured from the `response.completed` event.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    pub usage: ResponsesUsage,
}
41
42impl GetTokenUsage for StreamingCompletionResponse {
43 fn token_usage(&self) -> Option<crate::completion::Usage> {
44 let mut usage = crate::completion::Usage::new();
45 usage.input_tokens = self.usage.input_tokens;
46 usage.output_tokens = self.usage.output_tokens;
47 usage.total_tokens = self.usage.total_tokens;
48 usage.cached_input_tokens = self
49 .usage
50 .input_tokens_details
51 .as_ref()
52 .map(|d| d.cached_tokens)
53 .unwrap_or(0);
54 Some(usage)
55 }
56}
57
/// Envelope for response-lifecycle SSE events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// Which lifecycle event this is; deserialized from the wire `type` field.
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// Snapshot of the full response object at the time of the event.
    pub response: CompletionResponse,
    /// Monotonic ordering number assigned by the provider to each SSE event.
    pub sequence_number: u64,
}
69
/// The lifecycle states a streamed response can report; each variant maps
/// to the corresponding `response.*` SSE event type string.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    #[serde(rename = "response.created")]
    ResponseCreated,
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    #[serde(rename = "response.failed")]
    ResponseFailed,
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
85
/// Envelope for per-output-item SSE events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// ID of the output item this event belongs to, when the provider sends one.
    pub item_id: Option<String>,
    /// Index of the item within the response's output array.
    pub output_index: u64,
    /// Event-specific payload; `flatten` merges its `type` tag and fields
    /// directly into this object on the wire.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
98
/// All per-item SSE event shapes, internally tagged by the wire `type`
/// field. Each variant name mirrors its `response.*` event type string.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ItemChunkKind {
    // Item lifecycle: an output item (message, function call, reasoning)
    // was added to / finished in the response's output array.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded(StreamingItemDoneOutput),
    #[serde(rename = "response.output_item.done")]
    OutputItemDone(StreamingItemDoneOutput),
    // Content parts within a message item.
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded(ContentPartChunk),
    #[serde(rename = "response.content_part.done")]
    ContentPartDone(ContentPartChunk),
    // Incremental text output and its terminal "done" form.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta(DeltaTextChunk),
    #[serde(rename = "response.output_text.done")]
    OutputTextDone(OutputTextChunk),
    // Refusal text (model declining to answer), streamed like output text.
    #[serde(rename = "response.refusal.delta")]
    RefusalDelta(DeltaTextChunk),
    #[serde(rename = "response.refusal.done")]
    RefusalDone(RefusalTextChunk),
    // Function-call argument streaming; the delta form carries an `item_id`
    // so deltas can be correlated to their tool call.
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgsDone(ArgsTextChunk),
    // Reasoning summaries (for reasoning-capable models).
    #[serde(rename = "response.reasoning_summary_part.added")]
    ReasoningSummaryPartAdded(SummaryPartChunk),
    #[serde(rename = "response.reasoning_summary_part.done")]
    ReasoningSummaryPartDone(SummaryPartChunk),
    #[serde(rename = "response.reasoning_summary_text.delta")]
    ReasoningSummaryTextDelta(SummaryTextChunk),
    #[serde(rename = "response.reasoning_summary_text.done")]
    ReasoningSummaryTextDone(SummaryTextChunk),
}
132
/// Payload of `response.output_item.added` / `.done`: the (possibly
/// complete) output item plus the event's ordering number.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    pub sequence_number: u64,
    pub item: Output,
}

/// Payload of `response.content_part.added` / `.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    /// Index of the part within the item's content array.
    pub content_index: u64,
    pub sequence_number: u64,
    pub part: ContentPartChunkPart,
}

/// The kinds of content part a message item can carry.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ContentPartChunkPart {
    OutputText { text: String },
    SummaryText { text: String },
}

/// Generic incremental-text payload (output text and refusal deltas).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub delta: String,
}

/// Incremental-text payload that also identifies its parent item — used by
/// function-call argument deltas so they can be matched to a tool call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    pub item_id: String,
    pub content_index: u64,
    pub sequence_number: u64,
    pub delta: String,
}

/// Terminal payload of `response.output_text.done` — the full text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub text: String,
}

/// Terminal payload of `response.refusal.done` — the full refusal text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub refusal: String,
}

/// Terminal payload of `response.function_call_arguments.done` — the
/// complete JSON arguments object.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub arguments: serde_json::Value,
}

/// Payload of `response.reasoning_summary_part.added` / `.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    pub summary_index: u64,
    pub sequence_number: u64,
    pub part: SummaryPartChunkPart,
}

/// Incremental payload of reasoning-summary text events.
/// NOTE: reused for both `.delta` and `.done` events; the field is named
/// `delta` in both cases.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryTextChunk {
    pub summary_index: u64,
    pub sequence_number: u64,
    pub delta: String,
}

/// The single part kind a reasoning summary carries.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum SummaryPartChunkPart {
    SummaryText { text: String },
}
208
impl<T> ResponsesCompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
{
    /// Sends a streaming completion request (`POST /responses` with
    /// `stream: true`) and adapts the SSE event stream into
    /// [`RawStreamingChoice`] items.
    ///
    /// Behavior, as implemented below:
    /// - Text, refusal, reasoning-summary, and tool-call-argument deltas are
    ///   yielded as soon as they arrive.
    /// - Completed tool calls are buffered and only yielded after the SSE
    ///   stream ends, followed by a `FinalResponse` carrying token usage.
    /// - Chunks that fail to deserialize are logged at DEBUG and skipped.
    pub(crate) async fn stream(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let mut request = self.create_completion_request(completion_request)?;
        request.stream = Some(true);

        // Pretty-printing the request body is gated so the serialization
        // cost is only paid when TRACE is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Responses streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the caller's span when one is active; otherwise create a
        // fresh GenAI-convention span with fields filled in as the stream
        // progresses (response id/model and usage arrive at the end).
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        span.record("gen_ai.provider.name", "openai");
        span.record("gen_ai.request.model", &self.model);
        let client = self.client.clone();

        let mut event_source = GenericEventSource::new(client, req);

        let stream = stream! {
            // Usage from the `response.completed` event; emitted with the
            // FinalResponse after the stream drains.
            let mut final_usage = ResponsesUsage::new();

            // Completed tool calls are buffered and yielded after the loop,
            // so they always follow any streamed text/deltas.
            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
            // Maps the provider's item/call id to a locally generated nanoid,
            // so deltas and the final call share one internal id.
            let mut tool_call_internal_ids: std::collections::HashMap<String, String> = std::collections::HashMap::new();
            let mut combined_text = String::new();
            let span = tracing::Span::current();

            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::trace!("SSE connection opened");
                        tracing::info!("OpenAI stream started");
                        continue;
                    }
                    Ok(Event::Message(evt)) => {
                        // Skip keep-alive / empty data frames.
                        if evt.data.trim().is_empty() {
                            continue;
                        }

                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);

                        // Unknown event shapes are tolerated: log and move on
                        // rather than failing the whole stream.
                        // NOTE(review): message says "serialize" but this is a
                        // deserialization failure of StreamingCompletionChunk.
                        let Ok(data) = data else {
                            let err = data.unwrap_err();
                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
                            continue;
                        };

                        if let StreamingCompletionChunk::Delta(chunk) = &data {
                            match &chunk.data {
                                // A new function-call item appeared: announce the
                                // tool call's name before its argument deltas.
                                ItemChunkKind::OutputItemAdded(message) => {
                                    if let StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } = message {
                                        let internal_call_id = tool_call_internal_ids
                                            .entry(func.id.clone())
                                            .or_insert_with(|| nanoid::nanoid!())
                                            .clone();
                                        yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                            id: func.id.clone(),
                                            internal_call_id,
                                            content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                        });
                                    }
                                }
                                ItemChunkKind::OutputItemDone(message) => {
                                    match message {
                                        // Completed tool call: buffer it (yielded after
                                        // the stream ends, see below).
                                        StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                            .with_internal_call_id(internal_id)
                                            .with_call_id(func.call_id.clone());
                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall(raw_tool_call));
                                        }

                                        // Completed reasoning item: join its summary
                                        // texts into one newline-separated string.
                                        StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. } => {
                                            let reasoning = summary
                                                .iter()
                                                .map(|x| {
                                                    let ReasoningSummary::SummaryText { text } = x;
                                                    text.to_owned()
                                                })
                                                .collect::<Vec<String>>()
                                                .join("\n");
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: Some(id.to_string()),
                                                reasoning,
                                                signature: None,
                                            })
                                        }
                                        _ => continue
                                    }
                                }
                                ItemChunkKind::OutputTextDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                }
                                // Refusal text is surfaced as ordinary message text.
                                ItemChunkKind::RefusalDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                    let internal_call_id = tool_call_internal_ids
                                        .entry(delta.item_id.clone())
                                        .or_insert_with(|| nanoid::nanoid!())
                                        .clone();
                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                        id: delta.item_id.clone(),
                                        internal_call_id,
                                        content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                    })
                                }

                                // Remaining event kinds (content parts, *.done text,
                                // summary parts) carry no extra information here.
                                _ => { continue }
                            }
                        }

                        // Response-lifecycle events: only `completed` is acted on,
                        // recording response metadata and capturing final usage.
                        if let StreamingCompletionChunk::Response(chunk) = data {
                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
                                span.record("gen_ai.response.id", response.id);
                                span.record("gen_ai.response.model", response.model);
                                if let Some(usage) = response.usage {
                                    final_usage = usage;
                                }
                            } else {
                                continue;
                            }
                        }
                    }
                    // Normal end-of-stream: close the source; the next poll is
                    // expected to return None and exit the loop.
                    Err(crate::http_client::Error::StreamEnded) => {
                        event_source.close();
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Flush buffered tool calls after all deltas have been emitted.
            for tool_call in &tool_calls {
                yield Ok(tool_call.to_owned())
            }

            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
            tracing::info!("OpenAI stream finished");

            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
411
#[cfg(test)]
mod tests {
    // NOTE(review): this module imports via `rig::` while the surrounding file
    // uses `crate::` paths — presumably relying on the crate re-exporting
    // itself (or these being written as external-style tests); verify that
    // `rig::` resolves when compiling this crate's tests.
    use futures::StreamExt;
    use rig::{client::CompletionClient, providers::openai, streaming::StreamingChat};
    use serde_json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolError},
    };

    /// Minimal no-argument tool used to exercise streamed tool calling.
    struct ExampleTool;

    impl Tool for ExampleTool {
        type Args = ();
        type Error = ToolError;
        type Output = String;
        const NAME: &'static str = "example_tool";

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: self.name(),
                description: "A tool that returns some example text.".to_string(),
                // Empty object schema: the tool takes no parameters.
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            }
        }

        async fn call(&self, _input: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = "Example answer".to_string();
            Ok(result)
        }
    }

    /// Live integration test: streams a multi-turn chat with tool use and
    /// reasoning enabled, printing each streamed item. Ignored by default
    /// because it performs real network calls against the OpenAI API.
    #[tokio::test]
    #[ignore = "requires API key"]
    async fn test_openai_streaming_tools_reasoning() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY env var should exist");
        let client: openai::Client<rig::http_client::ReqwestClient> =
            openai::Client::new(&api_key).expect("Failed to build client");
        let agent = client
            .agent("gpt-5.2")
            .max_tokens(8192)
            .tool(ExampleTool)
            .additional_params(serde_json::json!({
                "reasoning": {"effort": "high"}
            }))
            .build();

        let chat_history = Vec::new();
        let mut stream = agent
            .stream_chat("Call my example tool", chat_history)
            .multi_turn(5)
            .await;

        // Drain the stream; the test passes if streaming completes without panicking.
        while let Some(item) = stream.next().await {
            println!("Got item: {item:?}");
        }
    }
}