rig/providers/openai/responses_api/
streaming.rs1use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::providers::openai::responses_api::{
7 ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
8};
9use crate::streaming;
10use crate::streaming::RawStreamingChoice;
11use crate::wasm_compat::WasmCompatSend;
12use async_stream::stream;
13use futures::StreamExt;
14use serde::{Deserialize, Serialize};
15use tracing::{debug, info_span};
16use tracing_futures::Instrument as _;
17
18use super::{CompletionResponse, Output};
19
/// One parsed server-sent event from the OpenAI Responses streaming API.
///
/// Deserialized with `#[serde(untagged)]`, so serde tries the variants in
/// declaration order: the event is first parsed as a whole-response event
/// ([`ResponseChunk`]) and, only if that fails, as an output-item event
/// ([`ItemChunk`]).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// A `response.*` event that carries a full response object.
    Response(Box<ResponseChunk>),
    /// An event scoped to a single output item (deltas, item/part boundaries).
    Delta(ItemChunk),
}
34
/// Provider-specific payload attached to the final item of an OpenAI Responses
/// stream.
///
/// It carries the token usage reported for the response; the
/// [`GetTokenUsage`] implementation converts it into rig's generic usage type.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    /// Token usage in the Responses API's own format.
    pub usage: ResponsesUsage,
}
41
42impl GetTokenUsage for StreamingCompletionResponse {
43 fn token_usage(&self) -> Option<crate::completion::Usage> {
44 let mut usage = crate::completion::Usage::new();
45 usage.input_tokens = self.usage.input_tokens;
46 usage.output_tokens = self.usage.output_tokens;
47 usage.total_tokens = self.usage.total_tokens;
48 Some(usage)
49 }
50}
51
/// An event about the response as a whole (`response.created`,
/// `response.completed`, ...), carrying the response object as of that event.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    /// The event type, read from the JSON `type` key.
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    /// The response object attached to the event.
    pub response: CompletionResponse,
    /// Sequence number of the event, as sent by the server.
    pub sequence_number: u64,
}
63
/// The `type` of a whole-response event. Each variant maps to the wire name
/// given in its `rename` attribute.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    /// Wire name `response.created`.
    #[serde(rename = "response.created")]
    ResponseCreated,
    /// Wire name `response.in_progress`.
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    /// Wire name `response.completed`.
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    /// Wire name `response.failed`.
    #[serde(rename = "response.failed")]
    ResponseFailed,
    /// Wire name `response.incomplete`.
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
79
/// An event scoped to a single output item of the response.
///
/// The variant-specific payload is flattened into the same JSON object. Because
/// `item_id` and `output_index` are fields of this struct, serde consumes those
/// keys here; the payload types inside [`ItemChunkKind`] never receive them.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    /// ID of the output item the event refers to; may be absent.
    pub item_id: Option<String>,
    /// Index of the output item within the response's output list.
    pub output_index: u64,
    /// The event type (selected by the JSON `type` key) and its payload.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
92
93#[derive(Debug, Serialize, Deserialize, Clone)]
95#[serde(tag = "type")]
96pub enum ItemChunkKind {
97 #[serde(rename = "response.output_item.added")]
98 OutputItemAdded(StreamingItemDoneOutput),
99 #[serde(rename = "response.output_item.done")]
100 OutputItemDone(StreamingItemDoneOutput),
101 #[serde(rename = "response.content_part.added")]
102 ContentPartAdded(ContentPartChunk),
103 #[serde(rename = "response.content_part.done")]
104 ContentPartDone(ContentPartChunk),
105 #[serde(rename = "response.output_text.delta")]
106 OutputTextDelta(DeltaTextChunk),
107 #[serde(rename = "response.output_text.done")]
108 OutputTextDone(OutputTextChunk),
109 #[serde(rename = "response.refusal.delta")]
110 RefusalDelta(DeltaTextChunk),
111 #[serde(rename = "response.refusal.done")]
112 RefusalDone(RefusalTextChunk),
113 #[serde(rename = "response.function_call_arguments.delta")]
114 FunctionCallArgsDelta(DeltaTextChunkWithItemId),
115 #[serde(rename = "response.function_call_arguments.done")]
116 FunctionCallArgsDone(ArgsTextChunk),
117 #[serde(rename = "response.reasoning_summary_part.added")]
118 ReasoningSummaryPartAdded(SummaryPartChunk),
119 #[serde(rename = "response.reasoning_summary_part.done")]
120 ReasoningSummaryPartDone(SummaryPartChunk),
121 #[serde(rename = "response.reasoning_summary_text.added")]
122 ReasoningSummaryTextAdded(SummaryTextChunk),
123 #[serde(rename = "response.reasoning_summary_text.done")]
124 ReasoningSummaryTextDone(SummaryTextChunk),
125}
126
/// Payload of `response.output_item.added` and `response.output_item.done`
/// events: the output item itself (e.g. a function call or reasoning item).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    /// Sequence number of the event, as sent by the server.
    pub sequence_number: u64,
    /// The output item.
    pub item: Output,
}
132
/// Payload of `response.content_part.added` and `response.content_part.done`
/// events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    /// Index of the content part within its output item.
    pub content_index: u64,
    /// Sequence number of the event, as sent by the server.
    pub sequence_number: u64,
    /// The content part.
    pub part: ContentPartChunkPart,
}
139
140#[derive(Debug, Serialize, Deserialize, Clone)]
141#[serde(tag = "type")]
142pub enum ContentPartChunkPart {
143 OutputText { text: String },
144 SummaryText { text: String },
145}
146
/// Payload of `response.output_text.delta` and `response.refusal.delta`
/// events: one incremental fragment of text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    /// Index of the content part the fragment belongs to.
    pub content_index: u64,
    /// Sequence number of the event, as sent by the server.
    pub sequence_number: u64,
    /// The newly generated text fragment.
    pub delta: String,
}
153
154#[derive(Debug, Serialize, Deserialize, Clone)]
155pub struct DeltaTextChunkWithItemId {
156 pub item_id: String,
157 pub content_index: u64,
158 pub sequence_number: u64,
159 pub delta: String,
160}
161
162#[derive(Debug, Serialize, Deserialize, Clone)]
163pub struct OutputTextChunk {
164 pub content_index: u64,
165 pub sequence_number: u64,
166 pub text: String,
167}
168
169#[derive(Debug, Serialize, Deserialize, Clone)]
170pub struct RefusalTextChunk {
171 pub content_index: u64,
172 pub sequence_number: u64,
173 pub refusal: String,
174}
175
176#[derive(Debug, Serialize, Deserialize, Clone)]
177pub struct ArgsTextChunk {
178 pub content_index: u64,
179 pub sequence_number: u64,
180 pub arguments: serde_json::Value,
181}
182
183#[derive(Debug, Serialize, Deserialize, Clone)]
184pub struct SummaryPartChunk {
185 pub summary_index: u64,
186 pub sequence_number: u64,
187 pub part: SummaryPartChunkPart,
188}
189
190#[derive(Debug, Serialize, Deserialize, Clone)]
191pub struct SummaryTextChunk {
192 pub summary_index: u64,
193 pub sequence_number: u64,
194 pub delta: String,
195}
196
197#[derive(Debug, Serialize, Deserialize, Clone)]
198#[serde(tag = "type")]
199pub enum SummaryPartChunkPart {
200 SummaryText { text: String },
201}
202
203impl<T> ResponsesCompletionModel<T>
204where
205 T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
206{
207 pub(crate) async fn stream(
208 &self,
209 completion_request: crate::completion::CompletionRequest,
210 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
211 {
212 let mut request = self.create_completion_request(completion_request)?;
213 request.stream = Some(true);
214
215 let body = serde_json::to_vec(&request)?;
216
217 let req = self
218 .client
219 .post("/responses")?
220 .body(body)
221 .map_err(|e| CompletionError::HttpError(e.into()))?;
222
223 let span = if tracing::Span::current().is_disabled() {
226 info_span!(
227 target: "rig::completions",
228 "chat_streaming",
229 gen_ai.operation.name = "chat_streaming",
230 gen_ai.provider.name = tracing::field::Empty,
231 gen_ai.request.model = tracing::field::Empty,
232 gen_ai.response.id = tracing::field::Empty,
233 gen_ai.response.model = tracing::field::Empty,
234 gen_ai.usage.output_tokens = tracing::field::Empty,
235 gen_ai.usage.input_tokens = tracing::field::Empty,
236 gen_ai.input.messages = tracing::field::Empty,
237 gen_ai.output.messages = tracing::field::Empty,
238 )
239 } else {
240 tracing::Span::current()
241 };
242 span.record("gen_ai.provider.name", "openai");
243 span.record("gen_ai.request.model", &self.model);
244 span.record(
245 "gen_ai.input.messages",
246 serde_json::to_string(&request.input).expect("This should always work"),
247 );
248 let client = self.client.http_client().clone();
250
251 let mut event_source = GenericEventSource::new(client, req);
252
253 let stream = stream! {
254 let mut final_usage = ResponsesUsage::new();
255
256 let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
257 let mut combined_text = String::new();
258 let span = tracing::Span::current();
259
260 while let Some(event_result) = event_source.next().await {
261 match event_result {
262 Ok(Event::Open) => {
263 tracing::trace!("SSE connection opened");
264 tracing::info!("OpenAI stream started");
265 continue;
266 }
267 Ok(Event::Message(evt)) => {
268 if evt.data.trim().is_empty() {
270 continue;
271 }
272
273 let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);
274
275 let Ok(data) = data else {
276 let err = data.unwrap_err();
277 debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
278 continue;
279 };
280
281 if let StreamingCompletionChunk::Delta(chunk) = &data {
282 match &chunk.data {
283 ItemChunkKind::OutputItemDone(message) => {
284 match message {
285 StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
286 tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
287 }
288
289 StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. } => {
290 let reasoning = summary
291 .iter()
292 .map(|x| {
293 let ReasoningSummary::SummaryText { text } = x;
294 text.to_owned()
295 })
296 .collect::<Vec<String>>()
297 .join("\n");
298 yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()), signature: None })
299 }
300 _ => continue
301 }
302 }
303 ItemChunkKind::OutputTextDelta(delta) => {
304 combined_text.push_str(&delta.delta);
305 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
306 }
307 ItemChunkKind::RefusalDelta(delta) => {
308 combined_text.push_str(&delta.delta);
309 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
310 }
311 ItemChunkKind::FunctionCallArgsDelta(delta) => {
312 yield Ok(streaming::RawStreamingChoice::ToolCallDelta { id: delta.item_id.clone(), delta: delta.delta.clone() })
313 }
314
315 _ => { continue }
316 }
317 }
318
319 if let StreamingCompletionChunk::Response(chunk) = data {
320 if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
321 span.record("gen_ai.output.messages", serde_json::to_string(&response.output).unwrap());
322 span.record("gen_ai.response.id", response.id);
323 span.record("gen_ai.response.model", response.model);
324 if let Some(usage) = response.usage {
325 final_usage = usage;
326 }
327 } else {
328 continue;
329 }
330 }
331 }
332 Err(crate::http_client::Error::StreamEnded) => {
333 event_source.close();
334 }
335 Err(error) => {
336 tracing::error!(?error, "SSE error");
337 yield Err(CompletionError::ProviderError(error.to_string()));
338 break;
339 }
340 }
341 }
342
343 event_source.close();
345
346 for tool_call in &tool_calls {
347 yield Ok(tool_call.to_owned())
348 }
349
350 span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
351 span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
352 tracing::info!("OpenAI stream finished");
353
354 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
355 usage: final_usage
356 }));
357 }.instrument(span);
358
359 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
360 stream,
361 )))
362 }
363}