rig/providers/openai/responses_api/
streaming.rs1use crate::completion::{CompletionError, GetTokenUsage};
4use crate::providers::openai::responses_api::{
5 ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
6};
7use crate::streaming;
8use crate::streaming::RawStreamingChoice;
9use async_stream::stream;
10use futures::StreamExt;
11use reqwest::RequestBuilder;
12use serde::{Deserialize, Serialize};
13use tracing::debug;
14
15use super::{CompletionResponse, Output};
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
26#[serde(untagged)]
27pub enum StreamingCompletionChunk {
28 Response(Box<ResponseChunk>),
29 Delta(ItemChunk),
30}
31
32#[derive(Debug, Serialize, Deserialize, Clone)]
34pub struct StreamingCompletionResponse {
35 pub usage: ResponsesUsage,
37}
38
39impl GetTokenUsage for StreamingCompletionResponse {
40 fn token_usage(&self) -> Option<crate::completion::Usage> {
41 let mut usage = crate::completion::Usage::new();
42 usage.input_tokens = self.usage.input_tokens;
43 usage.output_tokens = self.usage.output_tokens;
44 usage.total_tokens = self.usage.total_tokens;
45 Some(usage)
46 }
47}
48
49#[derive(Debug, Serialize, Deserialize, Clone)]
51pub struct ResponseChunk {
52 #[serde(rename = "type")]
54 pub kind: ResponseChunkKind,
55 pub response: CompletionResponse,
57 pub sequence_number: u64,
59}
60
61#[derive(Debug, Serialize, Deserialize, Clone)]
64pub enum ResponseChunkKind {
65 #[serde(rename = "response.created")]
66 ResponseCreated,
67 #[serde(rename = "response.in_progress")]
68 ResponseInProgress,
69 #[serde(rename = "response.completed")]
70 ResponseCompleted,
71 #[serde(rename = "response.failed")]
72 ResponseFailed,
73 #[serde(rename = "response.incomplete")]
74 ResponseIncomplete,
75}
76
77#[derive(Debug, Serialize, Deserialize, Clone)]
80pub struct ItemChunk {
81 pub item_id: Option<String>,
83 pub output_index: u64,
85 #[serde(flatten)]
87 pub data: ItemChunkKind,
88}
89
90#[derive(Debug, Serialize, Deserialize, Clone)]
92#[serde(tag = "type")]
93pub enum ItemChunkKind {
94 #[serde(rename = "response.output_item.added")]
95 OutputItemAdded(StreamingItemDoneOutput),
96 #[serde(rename = "response.output_item.done")]
97 OutputItemDone(StreamingItemDoneOutput),
98 #[serde(rename = "response.content_part.added")]
99 ContentPartAdded(ContentPartChunk),
100 #[serde(rename = "response.content_part.done")]
101 ContentPartDone(ContentPartChunk),
102 #[serde(rename = "response.output_text.delta")]
103 OutputTextDelta(DeltaTextChunk),
104 #[serde(rename = "response.output_text.done")]
105 OutputTextDone(OutputTextChunk),
106 #[serde(rename = "response.refusal.delta")]
107 RefusalDelta(DeltaTextChunk),
108 #[serde(rename = "response.refusal.done")]
109 RefusalDone(RefusalTextChunk),
110 #[serde(rename = "response.function_call_arguments.delta")]
111 FunctionCallArgsDelta(DeltaTextChunk),
112 #[serde(rename = "response.function_call_arguments.done")]
113 FunctionCallArgsDone(ArgsTextChunk),
114 #[serde(rename = "response.reasoning_summary_part.added")]
115 ReasoningSummaryPartAdded(SummaryPartChunk),
116 #[serde(rename = "response.reasoning_summary_part.done")]
117 ReasoningSummaryPartDone(SummaryPartChunk),
118 #[serde(rename = "response.reasoning_summary_text.added")]
119 ReasoningSummaryTextAdded(SummaryTextChunk),
120 #[serde(rename = "response.reasoning_summary_text.done")]
121 ReasoningSummaryTextDone(SummaryTextChunk),
122}
123
124#[derive(Debug, Serialize, Deserialize, Clone)]
125pub struct StreamingItemDoneOutput {
126 pub sequence_number: u64,
127 pub item: Output,
128}
129
130#[derive(Debug, Serialize, Deserialize, Clone)]
131pub struct ContentPartChunk {
132 pub content_index: u64,
133 pub sequence_number: u64,
134 pub part: ContentPartChunkPart,
135}
136
137#[derive(Debug, Serialize, Deserialize, Clone)]
138#[serde(tag = "type")]
139pub enum ContentPartChunkPart {
140 OutputText { text: String },
141 SummaryText { text: String },
142}
143
144#[derive(Debug, Serialize, Deserialize, Clone)]
145pub struct DeltaTextChunk {
146 pub content_index: u64,
147 pub sequence_number: u64,
148 pub delta: String,
149}
150
151#[derive(Debug, Serialize, Deserialize, Clone)]
152pub struct OutputTextChunk {
153 pub content_index: u64,
154 pub sequence_number: u64,
155 pub text: String,
156}
157
158#[derive(Debug, Serialize, Deserialize, Clone)]
159pub struct RefusalTextChunk {
160 pub content_index: u64,
161 pub sequence_number: u64,
162 pub refusal: String,
163}
164
165#[derive(Debug, Serialize, Deserialize, Clone)]
166pub struct ArgsTextChunk {
167 pub content_index: u64,
168 pub sequence_number: u64,
169 pub arguments: serde_json::Value,
170}
171
172#[derive(Debug, Serialize, Deserialize, Clone)]
173pub struct SummaryPartChunk {
174 pub summary_index: u64,
175 pub sequence_number: u64,
176 pub part: SummaryPartChunkPart,
177}
178
179#[derive(Debug, Serialize, Deserialize, Clone)]
180pub struct SummaryTextChunk {
181 pub summary_index: u64,
182 pub sequence_number: u64,
183 pub delta: String,
184}
185
186#[derive(Debug, Serialize, Deserialize, Clone)]
187#[serde(tag = "type")]
188pub enum SummaryPartChunkPart {
189 SummaryText { text: String },
190}
191
192impl ResponsesCompletionModel {
193 pub(crate) async fn stream(
194 &self,
195 completion_request: crate::completion::CompletionRequest,
196 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
197 {
198 let mut request = self.create_completion_request(completion_request)?;
199 request.stream = Some(true);
200
201 tracing::debug!("Input: {}", serde_json::to_string_pretty(&request)?);
202
203 let builder = self.client.post("/responses").json(&request);
204 send_compatible_streaming_request(builder).await
205 }
206}
207
208pub async fn send_compatible_streaming_request(
209 request_builder: RequestBuilder,
210) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
211 let response = request_builder.send().await?;
212
213 if !response.status().is_success() {
214 return Err(CompletionError::ProviderError(format!(
215 "{}: {}",
216 response.status(),
217 response.text().await?
218 )));
219 }
220
221 let inner = Box::pin(stream! {
223 let mut stream = response.bytes_stream();
224
225 let mut final_usage = ResponsesUsage::new();
226
227 let mut partial_data = None;
228
229 let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
230
231 while let Some(chunk_result) = stream.next().await {
232 let chunk = match chunk_result {
233 Ok(c) => c,
234 Err(e) => {
235 yield Err(CompletionError::from(e));
236 break;
237 }
238 };
239
240 let text = match String::from_utf8(chunk.to_vec()) {
241 Ok(t) => t,
242 Err(e) => {
243 yield Err(CompletionError::ResponseError(e.to_string()));
244 break;
245 }
246 };
247
248 for line in text.lines() {
249 let mut line = line.to_string();
250
251 if partial_data.is_some() {
253 line = format!("{}{}", partial_data.unwrap(), line);
254 partial_data = None;
255 }
256 else {
258 let Some(data) = line.strip_prefix("data: ") else {
259 continue;
260 };
261
262 if !line.ends_with("}") {
264 partial_data = Some(data.to_string());
265 } else {
266 line = data.to_string();
267 }
268 }
269
270 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
271
272 let Ok(data) = data else {
273 let err = data.unwrap_err();
274 debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
275 continue;
276 };
277
278 debug!("Data get: {data:?}");
279
280
281 if let StreamingCompletionChunk::Delta(chunk) = &data {
282 match &chunk.data {
283 ItemChunkKind::OutputItemDone(message) => {
284 match message {
285 StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
286 tracing::debug!("Function call received: {func:?}");
287 tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() });
288 }
289
290 StreamingItemDoneOutput { item: Output::Reasoning { summary, .. }, .. } => {
291 let reasoning = summary
292 .iter()
293 .map(|x| {
294 let ReasoningSummary::SummaryText { text } = x;
295 text.to_owned()
296 })
297 .collect::<Vec<String>>()
298 .join("\n");
299 yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: None })
300 }
301 _ => continue
302 }
303 }
304 ItemChunkKind::OutputTextDelta(delta) => {
305 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
306 }
307 ItemChunkKind::RefusalDelta(delta) => {
308 yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
309 }
310
311 _ => { continue }
312 }
313 }
314
315 if let StreamingCompletionChunk::Response(chunk) = data && let Some(usage) = chunk.response.usage {
316 final_usage = usage;
317 }
318 }
319 }
320
321 for tool_call in tool_calls {
322 yield Ok(tool_call)
323 }
324
325 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
326 usage: final_usage.clone()
327 }))
328 });
329
330 Ok(streaming::StreamingCompletionResponse::stream(inner))
331}