rig/providers/openai/completion/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::json_utils;
3use crate::json_utils::merge;
4use crate::providers::openai::completion::{CompletionModel, Usage};
5use crate::streaming;
6use crate::streaming::RawStreamingChoice;
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use reqwest_eventsource::Event;
11use reqwest_eventsource::RequestBuilderExt;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use std::collections::HashMap;
15use tracing::debug;
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
21pub struct StreamingFunction {
22 #[serde(default)]
23 pub name: Option<String>,
24 #[serde(default)]
25 pub arguments: String,
26}
27
28#[derive(Debug, Serialize, Deserialize, Clone)]
29pub struct StreamingToolCall {
30 pub index: usize,
31 pub id: Option<String>,
32 pub function: StreamingFunction,
33}
34
35#[derive(Deserialize, Debug)]
36struct StreamingDelta {
37 #[serde(default)]
38 content: Option<String>,
39 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
40 tool_calls: Vec<StreamingToolCall>,
41}
42
43#[derive(Deserialize, Debug)]
44struct StreamingChoice {
45 delta: StreamingDelta,
46}
47
48#[derive(Deserialize, Debug)]
49struct StreamingCompletionChunk {
50 choices: Vec<StreamingChoice>,
51 usage: Option<Usage>,
52}
53
54#[derive(Clone, Serialize, Deserialize)]
55pub struct StreamingCompletionResponse {
56 pub usage: Usage,
57}
58
59impl GetTokenUsage for StreamingCompletionResponse {
60 fn token_usage(&self) -> Option<crate::completion::Usage> {
61 let mut usage = crate::completion::Usage::new();
62 usage.input_tokens = self.usage.prompt_tokens as u64;
63 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
64 usage.total_tokens = self.usage.total_tokens as u64;
65 Some(usage)
66 }
67}
68
69impl CompletionModel {
70 pub(crate) async fn stream(
71 &self,
72 completion_request: CompletionRequest,
73 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
74 {
75 let mut request = self.create_completion_request(completion_request)?;
76 request = merge(
77 request,
78 json!({"stream": true, "stream_options": {"include_usage": true}}),
79 );
80
81 let builder = self.client.post("/chat/completions").json(&request);
82 send_compatible_streaming_request(builder).await
83 }
84}
85
86pub async fn send_compatible_streaming_request(
87 request_builder: RequestBuilder,
88) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
89 let mut event_source = request_builder
91 .eventsource()
92 .expect("Cloning request must always succeed");
93
94 let stream = Box::pin(stream! {
95 let mut final_usage = Usage::new();
96
97 let mut tool_calls: HashMap<usize, (String, String, String)> = HashMap::new();
99
100 while let Some(event_result) = event_source.next().await {
101 match event_result {
102 Ok(Event::Open) => {
103 tracing::trace!("SSE connection opened");
104 continue;
105 }
106 Ok(Event::Message(message)) => {
107 if message.data.trim().is_empty() || message.data == "[DONE]" {
108 continue;
109 }
110
111 let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
112 let Ok(data) = data else {
113 let err = data.unwrap_err();
114 debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
115 continue;
116 };
117
118 if let Some(choice) = data.choices.first() {
119 let delta = &choice.delta;
120
121 if !delta.tool_calls.is_empty() {
123 for tool_call in &delta.tool_calls {
124 let function = tool_call.function.clone();
125
126 if function.name.is_some() && function.arguments.is_empty() {
128 let id = tool_call.id.clone().unwrap_or_default();
129 tool_calls.insert(
130 tool_call.index,
131 (id, function.name.clone().unwrap(), "".to_string()),
132 );
133 }
134 else if function.name.clone().is_none_or(|s| s.is_empty())
138 && !function.arguments.is_empty()
139 {
140 if let Some((id, name, arguments)) =
141 tool_calls.get(&tool_call.index)
142 {
143 let new_arguments = &tool_call.function.arguments;
144 let arguments = format!("{arguments}{new_arguments}");
145 tool_calls.insert(
146 tool_call.index,
147 (id.clone(), name.clone(), arguments),
148 );
149 } else {
150 debug!("Partial tool call received but tool call was never started.");
151 }
152 }
153 else {
155 let id = tool_call.id.clone().unwrap_or_default();
156 let name = function.name.expect("tool call should have a name");
157 let arguments = function.arguments;
158 let Ok(arguments) = serde_json::from_str(&arguments) else {
159 debug!("Couldn't serialize '{arguments}' as JSON");
160 continue;
161 };
162
163 yield Ok(streaming::RawStreamingChoice::ToolCall {
164 id,
165 name,
166 arguments,
167 call_id: None,
168 });
169 }
170 }
171 }
172
173 if let Some(content) = &choice.delta.content {
175 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
176 }
177 }
178
179 if let Some(usage) = data.usage {
181 final_usage = usage.clone();
182 }
183 }
184 Err(reqwest_eventsource::Error::StreamEnded) => {
185 break;
186 }
187 Err(error) => {
188 tracing::error!(?error, "SSE error");
189 yield Err(CompletionError::ResponseError(error.to_string()));
190 break;
191 }
192 }
193 }
194
195 event_source.close();
197
198 for (_, (id, name, arguments)) in tool_calls {
200 let Ok(arguments) = serde_json::from_str(&arguments) else {
201 continue;
202 };
203
204 yield Ok(RawStreamingChoice::ToolCall {
205 id,
206 name,
207 arguments,
208 call_id: None,
209 });
210 }
211
212 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
213 usage: final_usage.clone()
214 }));
215 });
216
217 Ok(streaming::StreamingCompletionResponse::stream(stream))
218}