rig/providers/openai/responses_api/
streaming.rs1use crate::completion::{CompletionError, GetTokenUsage};
4use crate::http_client::HttpClientExt;
5use crate::http_client::sse::{Event, GenericEventSource};
6use crate::providers::openai::responses_api::{
7 ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
8};
9use crate::streaming;
10use crate::streaming::RawStreamingChoice;
11use async_stream::stream;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info_span};
15use tracing_futures::Instrument as _;
16
17use super::{CompletionResponse, Output};
18
19#[derive(Debug, Serialize, Deserialize, Clone)]
28#[serde(untagged)]
29pub enum StreamingCompletionChunk {
30 Response(Box<ResponseChunk>),
31 Delta(ItemChunk),
32}
33
34#[derive(Debug, Serialize, Deserialize, Clone)]
36pub struct StreamingCompletionResponse {
37 pub usage: ResponsesUsage,
39}
40
41impl GetTokenUsage for StreamingCompletionResponse {
42 fn token_usage(&self) -> Option<crate::completion::Usage> {
43 let mut usage = crate::completion::Usage::new();
44 usage.input_tokens = self.usage.input_tokens;
45 usage.output_tokens = self.usage.output_tokens;
46 usage.total_tokens = self.usage.total_tokens;
47 Some(usage)
48 }
49}
50
51#[derive(Debug, Serialize, Deserialize, Clone)]
53pub struct ResponseChunk {
54 #[serde(rename = "type")]
56 pub kind: ResponseChunkKind,
57 pub response: CompletionResponse,
59 pub sequence_number: u64,
61}
62
63#[derive(Debug, Serialize, Deserialize, Clone)]
66pub enum ResponseChunkKind {
67 #[serde(rename = "response.created")]
68 ResponseCreated,
69 #[serde(rename = "response.in_progress")]
70 ResponseInProgress,
71 #[serde(rename = "response.completed")]
72 ResponseCompleted,
73 #[serde(rename = "response.failed")]
74 ResponseFailed,
75 #[serde(rename = "response.incomplete")]
76 ResponseIncomplete,
77}
78
79#[derive(Debug, Serialize, Deserialize, Clone)]
82pub struct ItemChunk {
83 pub item_id: Option<String>,
85 pub output_index: u64,
87 #[serde(flatten)]
89 pub data: ItemChunkKind,
90}
91
92#[derive(Debug, Serialize, Deserialize, Clone)]
94#[serde(tag = "type")]
95pub enum ItemChunkKind {
96 #[serde(rename = "response.output_item.added")]
97 OutputItemAdded(StreamingItemDoneOutput),
98 #[serde(rename = "response.output_item.done")]
99 OutputItemDone(StreamingItemDoneOutput),
100 #[serde(rename = "response.content_part.added")]
101 ContentPartAdded(ContentPartChunk),
102 #[serde(rename = "response.content_part.done")]
103 ContentPartDone(ContentPartChunk),
104 #[serde(rename = "response.output_text.delta")]
105 OutputTextDelta(DeltaTextChunk),
106 #[serde(rename = "response.output_text.done")]
107 OutputTextDone(OutputTextChunk),
108 #[serde(rename = "response.refusal.delta")]
109 RefusalDelta(DeltaTextChunk),
110 #[serde(rename = "response.refusal.done")]
111 RefusalDone(RefusalTextChunk),
112 #[serde(rename = "response.function_call_arguments.delta")]
113 FunctionCallArgsDelta(DeltaTextChunkWithItemId),
114 #[serde(rename = "response.function_call_arguments.done")]
115 FunctionCallArgsDone(ArgsTextChunk),
116 #[serde(rename = "response.reasoning_summary_part.added")]
117 ReasoningSummaryPartAdded(SummaryPartChunk),
118 #[serde(rename = "response.reasoning_summary_part.done")]
119 ReasoningSummaryPartDone(SummaryPartChunk),
120 #[serde(rename = "response.reasoning_summary_text.added")]
121 ReasoningSummaryTextAdded(SummaryTextChunk),
122 #[serde(rename = "response.reasoning_summary_text.done")]
123 ReasoningSummaryTextDone(SummaryTextChunk),
124}
125
126#[derive(Debug, Serialize, Deserialize, Clone)]
127pub struct StreamingItemDoneOutput {
128 pub sequence_number: u64,
129 pub item: Output,
130}
131
132#[derive(Debug, Serialize, Deserialize, Clone)]
133pub struct ContentPartChunk {
134 pub content_index: u64,
135 pub sequence_number: u64,
136 pub part: ContentPartChunkPart,
137}
138
139#[derive(Debug, Serialize, Deserialize, Clone)]
140#[serde(tag = "type")]
141pub enum ContentPartChunkPart {
142 OutputText { text: String },
143 SummaryText { text: String },
144}
145
146#[derive(Debug, Serialize, Deserialize, Clone)]
147pub struct DeltaTextChunk {
148 pub content_index: u64,
149 pub sequence_number: u64,
150 pub delta: String,
151}
152
153#[derive(Debug, Serialize, Deserialize, Clone)]
154pub struct DeltaTextChunkWithItemId {
155 pub item_id: String,
156 pub content_index: u64,
157 pub sequence_number: u64,
158 pub delta: String,
159}
160
161#[derive(Debug, Serialize, Deserialize, Clone)]
162pub struct OutputTextChunk {
163 pub content_index: u64,
164 pub sequence_number: u64,
165 pub text: String,
166}
167
168#[derive(Debug, Serialize, Deserialize, Clone)]
169pub struct RefusalTextChunk {
170 pub content_index: u64,
171 pub sequence_number: u64,
172 pub refusal: String,
173}
174
175#[derive(Debug, Serialize, Deserialize, Clone)]
176pub struct ArgsTextChunk {
177 pub content_index: u64,
178 pub sequence_number: u64,
179 pub arguments: serde_json::Value,
180}
181
182#[derive(Debug, Serialize, Deserialize, Clone)]
183pub struct SummaryPartChunk {
184 pub summary_index: u64,
185 pub sequence_number: u64,
186 pub part: SummaryPartChunkPart,
187}
188
189#[derive(Debug, Serialize, Deserialize, Clone)]
190pub struct SummaryTextChunk {
191 pub summary_index: u64,
192 pub sequence_number: u64,
193 pub delta: String,
194}
195
196#[derive(Debug, Serialize, Deserialize, Clone)]
197#[serde(tag = "type")]
198pub enum SummaryPartChunkPart {
199 SummaryText { text: String },
200}
201
202impl<T> ResponsesCompletionModel<T>
203where
204 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
205{
206 pub(crate) async fn stream(
207 &self,
208 completion_request: crate::completion::CompletionRequest,
209 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
210 {
211 let mut request = self.create_completion_request(completion_request)?;
212 request.stream = Some(true);
213
214 let body = serde_json::to_vec(&request)?;
215
216 let req = self
217 .client
218 .post("/responses")?
219 .header("Content-Type", "application/json")
220 .body(body)
221 .map_err(|e| CompletionError::HttpError(e.into()))?;
222
223 let span = if tracing::Span::current().is_disabled() {
226 info_span!(
227 target: "rig::completions",
228 "chat_streaming",
229 gen_ai.operation.name = "chat_streaming",
230 gen_ai.provider.name = tracing::field::Empty,
231 gen_ai.request.model = tracing::field::Empty,
232 gen_ai.response.id = tracing::field::Empty,
233 gen_ai.response.model = tracing::field::Empty,
234 gen_ai.usage.output_tokens = tracing::field::Empty,
235 gen_ai.usage.input_tokens = tracing::field::Empty,
236 gen_ai.input.messages = tracing::field::Empty,
237 gen_ai.output.messages = tracing::field::Empty,
238 )
239 } else {
240 tracing::Span::current()
241 };
242 span.record("gen_ai.provider.name", "openai");
243 span.record("gen_ai.request.model", &self.model);
244 span.record(
245 "gen_ai.input.messages",
246 serde_json::to_string(&request.input).expect("This should always work"),
247 );
248 let client = self.clone().client.http_client;
250
251 let mut event_source = GenericEventSource::new(client, req);
252
253 let stream = stream! {
254 let mut final_usage = ResponsesUsage::new();
255
256 let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
257 let mut combined_text = String::new();
258 let span = tracing::Span::current();
259
260 while let Some(event_result) = event_source.next().await {
261 match event_result {
262 Ok(Event::Open) => {
263 tracing::trace!("SSE connection opened");
264 tracing::info!("OpenAI stream started");
265 continue;
266 }
267 Ok(Event::Message(evt)) => {
268 if evt.data.trim().is_empty() {
270 continue;
271 }
272
273 let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);
274
275 let Ok(data) = data else {
276 let err = data.unwrap_err();
277 debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
278 continue;
279 };
280
281 if let StreamingCompletionChunk::Delta(chunk) = &data {
282 match &chunk.data {
283 ItemChunkKind::OutputItemDone(message) => {
284 match message {
285 StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
286 tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
287 }
288
289 StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. } => {
290 let reasoning = summary
291 .iter()
292 .map(|x| {
293 let ReasoningSummary::SummaryText { text } = x;
294 text.to_owned()
295 })
296 .collect::<Vec<String>>()
297 .join("\n");
298 yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()), signature: None })
299 }
300 _ => continue
301 }
302 }
303 ItemChunkKind::OutputTextDelta(delta) => {
304 combined_text.push_str(&delta.delta);
305 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
306 }
307 ItemChunkKind::RefusalDelta(delta) => {
308 combined_text.push_str(&delta.delta);
309 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
310 }
311 ItemChunkKind::FunctionCallArgsDelta(delta) => {
312 yield Ok(streaming::RawStreamingChoice::ToolCallDelta { id: delta.item_id.clone(), delta: delta.delta.clone() })
313 }
314
315 _ => { continue }
316 }
317 }
318
319 if let StreamingCompletionChunk::Response(chunk) = data {
320 if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
321 span.record("gen_ai.output.messages", serde_json::to_string(&response.output).unwrap());
322 span.record("gen_ai.response.id", response.id);
323 span.record("gen_ai.response.model", response.model);
324 if let Some(usage) = response.usage {
325 final_usage = usage;
326 }
327 } else {
328 continue;
329 }
330 }
331 }
332 Err(crate::http_client::Error::StreamEnded) => {
333 event_source.close();
334 }
335 Err(error) => {
336 tracing::error!(?error, "SSE error");
337 yield Err(CompletionError::ResponseError(error.to_string()));
338 break;
339 }
340 }
341 }
342
343 for tool_call in &tool_calls {
347 yield Ok(tool_call.to_owned())
348 }
349
350 span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
351 span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
352 tracing::info!("OpenAI stream finished");
353
354 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
355 usage: final_usage.clone()
356 }));
357 }.instrument(span);
358
359 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
360 stream,
361 )))
362 }
363}