rig/providers/openai/responses_api/
streaming.rs1use crate::completion::{CompletionError, GetTokenUsage};
4use crate::providers::openai::responses_api::{
5 ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
6};
7use crate::streaming;
8use crate::streaming::RawStreamingChoice;
9use async_stream::stream;
10use futures::StreamExt;
11use reqwest::RequestBuilder;
12use reqwest_eventsource::Event;
13use reqwest_eventsource::RequestBuilderExt;
14use serde::{Deserialize, Serialize};
15use tracing::debug;
16
17use super::{CompletionResponse, Output};
18
19#[derive(Debug, Serialize, Deserialize, Clone)]
28#[serde(untagged)]
29pub enum StreamingCompletionChunk {
30 Response(Box<ResponseChunk>),
31 Delta(ItemChunk),
32}
33
34#[derive(Debug, Serialize, Deserialize, Clone)]
36pub struct StreamingCompletionResponse {
37 pub usage: ResponsesUsage,
39}
40
41impl GetTokenUsage for StreamingCompletionResponse {
42 fn token_usage(&self) -> Option<crate::completion::Usage> {
43 let mut usage = crate::completion::Usage::new();
44 usage.input_tokens = self.usage.input_tokens;
45 usage.output_tokens = self.usage.output_tokens;
46 usage.total_tokens = self.usage.total_tokens;
47 Some(usage)
48 }
49}
50
51#[derive(Debug, Serialize, Deserialize, Clone)]
53pub struct ResponseChunk {
54 #[serde(rename = "type")]
56 pub kind: ResponseChunkKind,
57 pub response: CompletionResponse,
59 pub sequence_number: u64,
61}
62
63#[derive(Debug, Serialize, Deserialize, Clone)]
66pub enum ResponseChunkKind {
67 #[serde(rename = "response.created")]
68 ResponseCreated,
69 #[serde(rename = "response.in_progress")]
70 ResponseInProgress,
71 #[serde(rename = "response.completed")]
72 ResponseCompleted,
73 #[serde(rename = "response.failed")]
74 ResponseFailed,
75 #[serde(rename = "response.incomplete")]
76 ResponseIncomplete,
77}
78
79#[derive(Debug, Serialize, Deserialize, Clone)]
82pub struct ItemChunk {
83 pub item_id: Option<String>,
85 pub output_index: u64,
87 #[serde(flatten)]
89 pub data: ItemChunkKind,
90}
91
92#[derive(Debug, Serialize, Deserialize, Clone)]
94#[serde(tag = "type")]
95pub enum ItemChunkKind {
96 #[serde(rename = "response.output_item.added")]
97 OutputItemAdded(StreamingItemDoneOutput),
98 #[serde(rename = "response.output_item.done")]
99 OutputItemDone(StreamingItemDoneOutput),
100 #[serde(rename = "response.content_part.added")]
101 ContentPartAdded(ContentPartChunk),
102 #[serde(rename = "response.content_part.done")]
103 ContentPartDone(ContentPartChunk),
104 #[serde(rename = "response.output_text.delta")]
105 OutputTextDelta(DeltaTextChunk),
106 #[serde(rename = "response.output_text.done")]
107 OutputTextDone(OutputTextChunk),
108 #[serde(rename = "response.refusal.delta")]
109 RefusalDelta(DeltaTextChunk),
110 #[serde(rename = "response.refusal.done")]
111 RefusalDone(RefusalTextChunk),
112 #[serde(rename = "response.function_call_arguments.delta")]
113 FunctionCallArgsDelta(DeltaTextChunk),
114 #[serde(rename = "response.function_call_arguments.done")]
115 FunctionCallArgsDone(ArgsTextChunk),
116 #[serde(rename = "response.reasoning_summary_part.added")]
117 ReasoningSummaryPartAdded(SummaryPartChunk),
118 #[serde(rename = "response.reasoning_summary_part.done")]
119 ReasoningSummaryPartDone(SummaryPartChunk),
120 #[serde(rename = "response.reasoning_summary_text.added")]
121 ReasoningSummaryTextAdded(SummaryTextChunk),
122 #[serde(rename = "response.reasoning_summary_text.done")]
123 ReasoningSummaryTextDone(SummaryTextChunk),
124}
125
126#[derive(Debug, Serialize, Deserialize, Clone)]
127pub struct StreamingItemDoneOutput {
128 pub sequence_number: u64,
129 pub item: Output,
130}
131
132#[derive(Debug, Serialize, Deserialize, Clone)]
133pub struct ContentPartChunk {
134 pub content_index: u64,
135 pub sequence_number: u64,
136 pub part: ContentPartChunkPart,
137}
138
139#[derive(Debug, Serialize, Deserialize, Clone)]
140#[serde(tag = "type")]
141pub enum ContentPartChunkPart {
142 OutputText { text: String },
143 SummaryText { text: String },
144}
145
146#[derive(Debug, Serialize, Deserialize, Clone)]
147pub struct DeltaTextChunk {
148 pub content_index: u64,
149 pub sequence_number: u64,
150 pub delta: String,
151}
152
153#[derive(Debug, Serialize, Deserialize, Clone)]
154pub struct OutputTextChunk {
155 pub content_index: u64,
156 pub sequence_number: u64,
157 pub text: String,
158}
159
160#[derive(Debug, Serialize, Deserialize, Clone)]
161pub struct RefusalTextChunk {
162 pub content_index: u64,
163 pub sequence_number: u64,
164 pub refusal: String,
165}
166
167#[derive(Debug, Serialize, Deserialize, Clone)]
168pub struct ArgsTextChunk {
169 pub content_index: u64,
170 pub sequence_number: u64,
171 pub arguments: serde_json::Value,
172}
173
174#[derive(Debug, Serialize, Deserialize, Clone)]
175pub struct SummaryPartChunk {
176 pub summary_index: u64,
177 pub sequence_number: u64,
178 pub part: SummaryPartChunkPart,
179}
180
181#[derive(Debug, Serialize, Deserialize, Clone)]
182pub struct SummaryTextChunk {
183 pub summary_index: u64,
184 pub sequence_number: u64,
185 pub delta: String,
186}
187
188#[derive(Debug, Serialize, Deserialize, Clone)]
189#[serde(tag = "type")]
190pub enum SummaryPartChunkPart {
191 SummaryText { text: String },
192}
193
194impl ResponsesCompletionModel {
195 pub(crate) async fn stream(
196 &self,
197 completion_request: crate::completion::CompletionRequest,
198 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
199 {
200 let mut request = self.create_completion_request(completion_request)?;
201 request.stream = Some(true);
202
203 tracing::debug!("Input: {}", serde_json::to_string_pretty(&request)?);
204
205 let builder = self.client.post("/responses").json(&request);
206 send_compatible_streaming_request(builder).await
207 }
208}
209
210pub async fn send_compatible_streaming_request(
215 request_builder: RequestBuilder,
216) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
217 let mut event_source = request_builder
219 .eventsource()
220 .expect("Cloning request must always succeed");
221
222 let stream = Box::pin(stream! {
223 let mut final_usage = ResponsesUsage::new();
224
225 let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
226
227 while let Some(event_result) = event_source.next().await {
228 match event_result {
229 Ok(Event::Open) => {
230 tracing::trace!("SSE connection opened");
231 continue;
232 }
233 Ok(Event::Message(message)) => {
234 if message.data.trim().is_empty() {
236 continue;
237 }
238
239 let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
240
241 let Ok(data) = data else {
242 let err = data.unwrap_err();
243 debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
244 continue;
245 };
246
247 if let StreamingCompletionChunk::Delta(chunk) = &data {
248 match &chunk.data {
249 ItemChunkKind::OutputItemDone(message) => {
250 match message {
251 StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
252 tracing::debug!("Function call received: {func:?}");
253 tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
254 }
255
256 StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. } => {
257 let reasoning = summary
258 .iter()
259 .map(|x| {
260 let ReasoningSummary::SummaryText { text } = x;
261 text.to_owned()
262 })
263 .collect::<Vec<String>>()
264 .join("\n");
265 yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()) })
266 }
267 _ => continue
268 }
269 }
270 ItemChunkKind::OutputTextDelta(delta) => {
271 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
272 }
273 ItemChunkKind::RefusalDelta(delta) => {
274 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
275 }
276
277 _ => { continue }
278 }
279 }
280
281 if let StreamingCompletionChunk::Response(chunk) = data {
282 if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
283 if let Some(usage) = response.usage {
284 final_usage = usage;
285 }
286 } else {
287 continue;
288 }
289 }
290 }
291 Err(reqwest_eventsource::Error::StreamEnded) => {
292 break;
293 }
294 Err(error) => {
295 tracing::error!(?error, "SSE error");
296 yield Err(CompletionError::ResponseError(error.to_string()));
297 break;
298 }
299 }
300 }
301
302 event_source.close();
304
305 for tool_call in &tool_calls {
306 yield Ok(tool_call.to_owned())
307 }
308
309 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
310 usage: final_usage.clone()
311 }));
312 });
313
314 Ok(streaming::StreamingCompletionResponse::stream(stream))
315}