rig/providers/openrouter/
streaming.rs1use std::collections::HashMap;
2
3use crate::{
4 json_utils,
5 message::{ToolCall, ToolFunction},
6 streaming::{self},
7};
8use async_stream::stream;
9use futures::StreamExt;
10use reqwest::RequestBuilder;
11use serde_json::{json, Value};
12
13use crate::completion::{CompletionError, CompletionRequest};
14use serde::{Deserialize, Serialize};
15
16#[derive(Serialize, Deserialize, Debug)]
17pub struct StreamingCompletionResponse {
18 pub id: String,
19 pub choices: Vec<StreamingChoice>,
20 pub created: u64,
21 pub model: String,
22 pub object: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub system_fingerprint: Option<String>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub usage: Option<ResponseUsage>,
27}
28
29#[derive(Serialize, Deserialize, Debug)]
30pub struct StreamingChoice {
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub finish_reason: Option<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub native_finish_reason: Option<String>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub logprobs: Option<Value>,
37 pub index: usize,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub message: Option<MessageResponse>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub delta: Option<DeltaResponse>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub error: Option<ErrorResponse>,
44}
45
46#[derive(Serialize, Deserialize, Debug)]
47pub struct MessageResponse {
48 pub role: String,
49 pub content: String,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 pub refusal: Option<Value>,
52 #[serde(default)]
53 pub tool_calls: Vec<OpenRouterToolCall>,
54}
55
56#[derive(Serialize, Deserialize, Debug)]
57pub struct OpenRouterToolFunction {
58 pub name: Option<String>,
59 pub arguments: Option<String>,
60}
61
62#[derive(Serialize, Deserialize, Debug)]
63pub struct OpenRouterToolCall {
64 pub index: usize,
65 pub id: Option<String>,
66 pub r#type: Option<String>,
67 pub function: OpenRouterToolFunction,
68}
69
70#[derive(Serialize, Deserialize, Debug, Clone, Default)]
71pub struct ResponseUsage {
72 pub prompt_tokens: u32,
73 pub completion_tokens: u32,
74 pub total_tokens: u32,
75}
76
77#[derive(Serialize, Deserialize, Debug)]
78pub struct ErrorResponse {
79 pub code: i32,
80 pub message: String,
81 #[serde(skip_serializing_if = "Option::is_none")]
82 pub metadata: Option<HashMap<String, Value>>,
83}
84
85#[derive(Serialize, Deserialize, Debug)]
86pub struct DeltaResponse {
87 pub role: Option<String>,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 pub content: Option<String>,
90 #[serde(default)]
91 pub tool_calls: Vec<OpenRouterToolCall>,
92 #[serde(skip_serializing_if = "Option::is_none")]
93 pub native_finish_reason: Option<String>,
94}
95
96#[derive(Clone)]
97pub struct FinalCompletionResponse {
98 pub usage: ResponseUsage,
99}
100
101impl super::CompletionModel {
102 pub(crate) async fn stream(
103 &self,
104 completion_request: CompletionRequest,
105 ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
106 {
107 let request = self.create_completion_request(completion_request)?;
108
109 let request = json_utils::merge(request, json!({"stream": true}));
110
111 let builder = self.client.post("/chat/completions").json(&request);
112
113 send_streaming_request(builder).await
114 }
115}
116
117pub async fn send_streaming_request(
118 request_builder: RequestBuilder,
119) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
120 let response = request_builder.send().await?;
121
122 if !response.status().is_success() {
123 return Err(CompletionError::ProviderError(format!(
124 "{}: {}",
125 response.status(),
126 response.text().await?
127 )));
128 }
129
130 let stream = Box::pin(stream! {
132 let mut stream = response.bytes_stream();
133 let mut tool_calls = HashMap::new();
134 let mut partial_line = String::new();
135 let mut final_usage = None;
136
137 while let Some(chunk_result) = stream.next().await {
138 let chunk = match chunk_result {
139 Ok(c) => c,
140 Err(e) => {
141 yield Err(CompletionError::from(e));
142 break;
143 }
144 };
145
146 let text = match String::from_utf8(chunk.to_vec()) {
147 Ok(t) => t,
148 Err(e) => {
149 yield Err(CompletionError::ResponseError(e.to_string()));
150 break;
151 }
152 };
153
154 for line in text.lines() {
155 let mut line = line.to_string();
156
157 if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
159 continue;
160 }
161
162 line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
164
165 if line.starts_with('{') && !line.ends_with('}') {
167 partial_line = line;
168 continue;
169 }
170
171 if !partial_line.is_empty() {
173 if line.ends_with('}') {
174 partial_line.push_str(&line);
175 line = partial_line;
176 partial_line = String::new();
177 } else {
178 partial_line.push_str(&line);
179 continue;
180 }
181 }
182
183 let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
184 Ok(data) => data,
185 Err(_) => {
186 continue;
187 }
188 };
189
190
191 let choice = data.choices.first().expect("Should have at least one choice");
192
193 if let Some(delta) = &choice.delta {
203 if !delta.tool_calls.is_empty() {
204 for tool_call in &delta.tool_calls {
205 let index = tool_call.index;
206
207 let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
209 id: String::new(),
210 function: ToolFunction {
211 name: String::new(),
212 arguments: serde_json::Value::Null,
213 },
214 });
215
216 if let Some(id) = &tool_call.id {
218 if !id.is_empty() {
219 existing_tool_call.id = id.clone();
220 }
221 }
222 if let Some(name) = &tool_call.function.name {
223 if !name.is_empty() {
224 existing_tool_call.function.name = name.clone();
225 }
226 }
227 if let Some(chunk) = &tool_call.function.arguments {
228 let current_args = match &existing_tool_call.function.arguments {
230 serde_json::Value::Null => String::new(),
231 serde_json::Value::String(s) => s.clone(),
232 v => v.to_string(),
233 };
234
235 let combined = format!("{current_args}{chunk}");
237
238 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
240 match serde_json::from_str(&combined) {
241 Ok(parsed) => existing_tool_call.function.arguments = parsed,
242 Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
243 }
244 } else {
245 existing_tool_call.function.arguments = serde_json::Value::String(combined);
246 }
247 }
248 }
249 }
250
251 if let Some(content) = &delta.content {
252 if !content.is_empty() {
253 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
254 }
255 }
256
257 if let Some(usage) = data.usage {
258 final_usage = Some(usage);
259 }
260 }
261
262 if let Some(message) = &choice.message {
264 if !message.tool_calls.is_empty() {
265 for tool_call in &message.tool_calls {
266 let name = tool_call.function.name.clone();
267 let id = tool_call.id.clone();
268 let arguments = if let Some(args) = &tool_call.function.arguments {
269 match serde_json::from_str(args) {
271 Ok(v) => v,
272 Err(_) => serde_json::Value::String(args.to_string()),
273 }
274 } else {
275 serde_json::Value::Null
276 };
277 let index = tool_call.index;
278
279 tool_calls.insert(index, ToolCall{
280 id: id.unwrap_or_default(),
281 function: ToolFunction {
282 name: name.unwrap_or_default(),
283 arguments,
284 },
285 });
286 }
287 }
288
289 if !message.content.is_empty() {
290 yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
291 }
292 }
293 }
294 }
295
296 for (_, tool_call) in tool_calls.into_iter() {
297
298 yield Ok(streaming::RawStreamingChoice::ToolCall{
299 name: tool_call.function.name,
300 id: tool_call.id,
301 arguments: tool_call.function.arguments
302 });
303 }
304
305 yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
306 usage: final_usage.unwrap_or_default()
307 }))
308
309 });
310
311 Ok(streaming::StreamingCompletionResponse::stream(stream))
312}