rig/providers/openai/responses_api/
streaming.rs1use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::providers::openai::responses_api::{
7 ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
8};
9use crate::streaming;
10use crate::streaming::RawStreamingChoice;
11use crate::wasm_compat::WasmCompatSend;
12use async_stream::stream;
13use futures::StreamExt;
14use serde::{Deserialize, Serialize};
15use tracing::{Level, debug, enabled, info_span};
16use tracing_futures::Instrument as _;
17
18use super::{CompletionResponse, Output};
19
/// One parsed SSE payload from the Responses streaming endpoint.
///
/// `untagged`: serde attempts the variants in declaration order, so a payload
/// is first tried as a response-level lifecycle event (`Response`) and only
/// then as an item-level delta (`Delta`).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// Response-level lifecycle event carrying a full `CompletionResponse` snapshot.
    Response(Box<ResponseChunk>),
    /// Item-level incremental event (text delta, tool-call args, reasoning summary, ...).
    Delta(ItemChunk),
}
34
/// Final aggregated result yielded at the end of a completion stream.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    /// Token usage captured from the provider's `response.completed` event.
    pub usage: ResponsesUsage,
}
41
42impl GetTokenUsage for StreamingCompletionResponse {
43 fn token_usage(&self) -> Option<crate::completion::Usage> {
44 let mut usage = crate::completion::Usage::new();
45 usage.input_tokens = self.usage.input_tokens;
46 usage.output_tokens = self.usage.output_tokens;
47 usage.total_tokens = self.usage.total_tokens;
48 Some(usage)
49 }
50}
51
/// A response-level streaming event: the lifecycle kind plus a full snapshot
/// of the response at that point in the stream.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// Which lifecycle event this is (wire field `type`).
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// Full response snapshot carried by the event.
    pub response: CompletionResponse,
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
}
63
/// Lifecycle event names emitted at the response level, matching the
/// `type` strings on the OpenAI Responses SSE wire format.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    /// The response object has been created.
    #[serde(rename = "response.created")]
    ResponseCreated,
    /// The response is being generated.
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    /// Terminal: generation finished successfully (carries final usage).
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    /// Terminal: generation failed.
    #[serde(rename = "response.failed")]
    ResponseFailed,
    /// Terminal: generation stopped before completion.
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
79
/// An item-level streaming event: common addressing fields plus a
/// type-tagged payload.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// Id of the output item this event belongs to, when the provider sends one.
    pub item_id: Option<String>,
    /// Index of the output item within the response's output list.
    pub output_index: u64,
    /// Event-specific payload; flattened so its `type` tag and fields sit at
    /// the same JSON level as `item_id` / `output_index`.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
92
/// Item-level streaming events, tagged by the wire `type` field.
///
/// The `serde(rename = ...)` strings are the exact OpenAI Responses SSE event
/// names; a payload whose `type` matches none of them fails to deserialize
/// and is skipped by the stream loop.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ItemChunkKind {
    /// A new output item started.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded(StreamingItemDoneOutput),
    /// An output item finished (complete function calls / reasoning arrive here).
    #[serde(rename = "response.output_item.done")]
    OutputItemDone(StreamingItemDoneOutput),
    /// A content part was added to an output item.
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded(ContentPartChunk),
    /// A content part is complete.
    #[serde(rename = "response.content_part.done")]
    ContentPartDone(ContentPartChunk),
    /// Incremental output text.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta(DeltaTextChunk),
    /// Final full text of a content part.
    #[serde(rename = "response.output_text.done")]
    OutputTextDone(OutputTextChunk),
    /// Incremental refusal text.
    #[serde(rename = "response.refusal.delta")]
    RefusalDelta(DeltaTextChunk),
    /// Final refusal text.
    #[serde(rename = "response.refusal.done")]
    RefusalDone(RefusalTextChunk),
    /// Incremental function-call argument text (carries its own `item_id`).
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
    /// Completed function-call arguments.
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgsDone(ArgsTextChunk),
    /// A reasoning-summary part was added.
    #[serde(rename = "response.reasoning_summary_part.added")]
    ReasoningSummaryPartAdded(SummaryPartChunk),
    /// A reasoning-summary part is complete.
    #[serde(rename = "response.reasoning_summary_part.done")]
    ReasoningSummaryPartDone(SummaryPartChunk),
    /// NOTE(review): OpenAI's documented event name is
    /// `response.reasoning_summary_text.delta`, not `.added` — if so, these
    /// chunks never deserialize and are silently skipped. Confirm against
    /// live payloads before changing the wire string.
    #[serde(rename = "response.reasoning_summary_text.added")]
    ReasoningSummaryTextAdded(SummaryTextChunk),
    /// Reasoning-summary text finished for one summary entry.
    #[serde(rename = "response.reasoning_summary_text.done")]
    ReasoningSummaryTextDone(SummaryTextChunk),
}
126
/// Payload of `response.output_item.added` / `response.output_item.done`:
/// the output item carried by the event.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
    /// The output item (message, function call, reasoning, ...).
    pub item: Output,
}
132
/// Payload of `response.content_part.added` / `response.content_part.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
    /// The content part itself.
    pub part: ContentPartChunkPart,
}
139
140#[derive(Debug, Serialize, Deserialize, Clone)]
141#[serde(tag = "type")]
142pub enum ContentPartChunkPart {
143 OutputText { text: String },
144 SummaryText { text: String },
145}
146
/// An incremental text fragment for a content part
/// (`response.output_text.delta`, `response.refusal.delta`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
    /// The new text fragment to append.
    pub delta: String,
}
153
/// Like [`DeltaTextChunk`] but with a mandatory `item_id`; used for
/// `response.function_call_arguments.delta` so the consumer can route each
/// argument fragment to the correct tool call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    /// Id of the output item (function call) this delta belongs to.
    pub item_id: String,
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
    /// The new argument-text fragment to append.
    pub delta: String,
}
161
/// Final full text of a content part (`response.output_text.done`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
    /// The complete text of the part.
    pub text: String,
}
168
/// Final refusal text of a content part (`response.refusal.done`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
    /// The complete refusal message.
    pub refusal: String,
}
175
/// Completed function-call arguments
/// (`response.function_call_arguments.done`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
    /// The assembled arguments as arbitrary JSON.
    pub arguments: serde_json::Value,
}
182
/// Payload of `response.reasoning_summary_part.added` / `.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    /// Index of the summary entry within the reasoning item.
    pub summary_index: u64,
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
    /// The summary part itself.
    pub part: SummaryPartChunkPart,
}
189
/// Text payload for reasoning-summary text events.
///
/// NOTE(review): the same struct (with a `delta` field) is used for both the
/// "added" and "done" variants; OpenAI documents the `.done` event as
/// carrying a `text` field — confirm against live payloads.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryTextChunk {
    /// Index of the summary entry within the reasoning item.
    pub summary_index: u64,
    /// Provider-assigned sequence number for this event.
    pub sequence_number: u64,
    /// The summary-text fragment.
    pub delta: String,
}
196
197#[derive(Debug, Serialize, Deserialize, Clone)]
198#[serde(tag = "type")]
199pub enum SummaryPartChunkPart {
200 SummaryText { text: String },
201}
202
impl<T> ResponsesCompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
{
    /// Issues a streaming `POST /responses` request and adapts the SSE event
    /// stream into rig's streaming-completion interface.
    ///
    /// Text and refusal deltas are yielded as they arrive; completed tool
    /// calls are buffered and emitted only after the SSE stream ends; the
    /// last item yielded is a `FinalResponse` carrying the usage reported by
    /// the provider's `response.completed` event.
    ///
    /// # Errors
    /// Fails up front if the request cannot be built or serialized; transport
    /// errors during streaming are yielded as `CompletionError::ProviderError`
    /// items inside the stream.
    pub(crate) async fn stream(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Build the same request as the non-streaming path, then enable SSE.
        let mut request = self.create_completion_request(completion_request)?;
        request.stream = Some(true);

        // Guard so the pretty-printed request is only serialized when TRACE
        // logging is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Responses streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the caller's active span when there is one; otherwise open a
        // fresh gen_ai-annotated span for this streaming call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        span.record("gen_ai.provider.name", "openai");
        span.record("gen_ai.request.model", &self.model);
        let client = self.client.clone();

        let mut event_source = GenericEventSource::new(client, req);

        let stream = stream! {
            // Usage totals captured from the terminal `response.completed` event.
            let mut final_usage = ResponsesUsage::new();

            // Completed tool calls are buffered and emitted after all text.
            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
            // Running concatenation of text/refusal deltas; written but not
            // otherwise read within this block.
            let mut combined_text = String::new();
            let span = tracing::Span::current();

            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::trace!("SSE connection opened");
                        tracing::info!("OpenAI stream started");
                        continue;
                    }
                    Ok(Event::Message(evt)) => {
                        // Skip empty keep-alive frames.
                        if evt.data.trim().is_empty() {
                            continue;
                        }

                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);

                        // Unrecognized event payloads are logged and skipped so a
                        // single unknown event type doesn't kill the stream.
                        let Ok(data) = data else {
                            let err = data.unwrap_err();
                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
                            continue;
                        };

                        if let StreamingCompletionChunk::Delta(chunk) = &data {
                            match &chunk.data {
                                ItemChunkKind::OutputItemDone(message) => {
                                    match message {
                                        // A finished function call: buffer it for
                                        // emission after the stream ends.
                                        StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
                                        }

                                        // A finished reasoning item: join its summary
                                        // texts with newlines and yield immediately.
                                        StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. } => {
                                            let reasoning = summary
                                                .iter()
                                                .map(|x| {
                                                    let ReasoningSummary::SummaryText { text } = x;
                                                    text.to_owned()
                                                })
                                                .collect::<Vec<String>>()
                                                .join("\n");
                                            yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()), signature: None })
                                        }
                                        _ => continue
                                    }
                                }
                                ItemChunkKind::OutputTextDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                // Refusal deltas are surfaced as plain message text.
                                ItemChunkKind::RefusalDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                // Incremental tool-call arguments, routed by item id.
                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta { id: delta.item_id.clone(), delta: delta.delta.clone() })
                                }

                                _ => { continue }
                            }
                        }

                        // Terminal lifecycle event: record response metadata on
                        // the span and capture the final usage counters.
                        if let StreamingCompletionChunk::Response(chunk) = data {
                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
                                span.record("gen_ai.response.id", response.id);
                                span.record("gen_ai.response.model", response.model);
                                if let Some(usage) = response.usage {
                                    final_usage = usage;
                                }
                            } else {
                                continue;
                            }
                        }
                    }
                    // Normal end of stream: close the source; the loop exits when
                    // the closed source stops yielding events (assumes a closed
                    // source returns `None` — matches GenericEventSource usage).
                    Err(crate::http_client::Error::StreamEnded) => {
                        event_source.close();
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Emit the buffered tool calls after all streamed text chunks.
            for tool_call in &tool_calls {
                yield Ok(tool_call.to_owned())
            }

            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
            tracing::info!("OpenAI stream finished");

            // The final item carries the aggregated usage for the response.
            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}