rig/providers/openai/
streaming.rs1use super::completion::CompletionModel;
2use crate::completion::{CompletionError, CompletionRequest};
3use crate::json_utils;
4use crate::json_utils::merge;
5use crate::streaming;
6use crate::streaming::{StreamingCompletionModel, StreamingResult};
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13
14#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct StreamingFunction {
19 #[serde(default)]
20 name: Option<String>,
21 #[serde(default)]
22 arguments: String,
23}
24
25#[derive(Debug, Serialize, Deserialize, Clone)]
26pub struct StreamingToolCall {
27 pub index: usize,
28 pub function: StreamingFunction,
29}
30
31#[derive(Deserialize)]
32struct StreamingDelta {
33 #[serde(default)]
34 content: Option<String>,
35 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
36 tool_calls: Vec<StreamingToolCall>,
37}
38
39#[derive(Deserialize)]
40struct StreamingChoice {
41 delta: StreamingDelta,
42}
43
44#[derive(Deserialize)]
45struct StreamingCompletionResponse {
46 choices: Vec<StreamingChoice>,
47}
48
49impl StreamingCompletionModel for CompletionModel {
50 async fn stream(
51 &self,
52 completion_request: CompletionRequest,
53 ) -> Result<StreamingResult, CompletionError> {
54 let mut request = self.create_completion_request(completion_request)?;
55 request = merge(request, json!({"stream": true}));
56
57 let builder = self.client.post("/chat/completions").json(&request);
58 send_compatible_streaming_request(builder).await
59 }
60}
61
62pub async fn send_compatible_streaming_request(
63 request_builder: RequestBuilder,
64) -> Result<StreamingResult, CompletionError> {
65 let response = request_builder.send().await?;
66
67 if !response.status().is_success() {
68 return Err(CompletionError::ProviderError(format!(
69 "{}: {}",
70 response.status(),
71 response.text().await?
72 )));
73 }
74
75 Ok(Box::pin(stream! {
77 let mut stream = response.bytes_stream();
78
79 let mut partial_data = None;
80 let mut calls: HashMap<usize, (String, String)> = HashMap::new();
81
82 while let Some(chunk_result) = stream.next().await {
83 let chunk = match chunk_result {
84 Ok(c) => c,
85 Err(e) => {
86 yield Err(CompletionError::from(e));
87 break;
88 }
89 };
90
91 let text = match String::from_utf8(chunk.to_vec()) {
92 Ok(t) => t,
93 Err(e) => {
94 yield Err(CompletionError::ResponseError(e.to_string()));
95 break;
96 }
97 };
98
99
100 for line in text.lines() {
101 let mut line = line.to_string();
102
103
104
105 if partial_data.is_some() {
107 line = format!("{}{}", partial_data.unwrap(), line);
108 partial_data = None;
109 }
110 else {
112 let Some(data) = line.strip_prefix("data: ") else {
113 continue;
114 };
115
116 if !line.ends_with("}") {
118 partial_data = Some(data.to_string());
119 } else {
120 line = data.to_string();
121 }
122 }
123
124 let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
125
126 let Ok(data) = data else {
127 continue;
128 };
129
130 let choice = data.choices.first().expect("Should have at least one choice");
131
132 let delta = &choice.delta;
133
134 if !delta.tool_calls.is_empty() {
135 for tool_call in &delta.tool_calls {
136 let function = tool_call.function.clone();
137
138 if function.name.is_some() && function.arguments.is_empty() {
142 calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
143 }
144 else if function.name.is_none() && !function.arguments.is_empty() {
148 let Some((name, arguments)) = calls.get(&tool_call.index) else {
149 continue;
150 };
151
152 let new_arguments = &tool_call.function.arguments;
153 let arguments = format!("{}{}", arguments, new_arguments);
154
155 calls.insert(tool_call.index, (name.clone(), arguments));
156 }
157 else {
159 let name = function.name.unwrap();
160 let arguments = function.arguments;
161 let Ok(arguments) = serde_json::from_str(&arguments) else {
162 continue;
163 };
164
165 yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
166 }
167 }
168 }
169
170 if let Some(content) = &choice.delta.content {
171 yield Ok(streaming::StreamingChoice::Message(content.clone()))
172 }
173 }
174 }
175
176 for (_, (name, arguments)) in calls {
177 let Ok(arguments) = serde_json::from_str(&arguments) else {
178 continue;
179 };
180
181 yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
182 }
183 }))
184}