rig/providers/openrouter/
streaming.rs1use std::collections::HashMap;
2
3use crate::{
4 json_utils,
5 message::{ToolCall, ToolFunction},
6 streaming::{self},
7};
8use async_stream::stream;
9use futures::StreamExt;
10use reqwest::RequestBuilder;
11use serde_json::{Value, json};
12
13use crate::completion::{CompletionError, CompletionRequest};
14use serde::{Deserialize, Serialize};
15
16#[derive(Serialize, Deserialize, Debug)]
17pub struct StreamingCompletionResponse {
18 pub id: String,
19 pub choices: Vec<StreamingChoice>,
20 pub created: u64,
21 pub model: String,
22 pub object: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub system_fingerprint: Option<String>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub usage: Option<ResponseUsage>,
27}
28
29#[derive(Serialize, Deserialize, Debug)]
30pub struct StreamingChoice {
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub finish_reason: Option<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub native_finish_reason: Option<String>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub logprobs: Option<Value>,
37 pub index: usize,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub message: Option<MessageResponse>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub delta: Option<DeltaResponse>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub error: Option<ErrorResponse>,
44}
45
46#[derive(Serialize, Deserialize, Debug)]
47pub struct MessageResponse {
48 pub role: String,
49 pub content: String,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 pub refusal: Option<Value>,
52 #[serde(default)]
53 pub tool_calls: Vec<OpenRouterToolCall>,
54}
55
56#[derive(Serialize, Deserialize, Debug)]
57pub struct OpenRouterToolFunction {
58 pub name: Option<String>,
59 pub arguments: Option<String>,
60}
61
62#[derive(Serialize, Deserialize, Debug)]
63pub struct OpenRouterToolCall {
64 pub index: usize,
65 pub id: Option<String>,
66 pub r#type: Option<String>,
67 pub function: OpenRouterToolFunction,
68}
69
70#[derive(Serialize, Deserialize, Debug, Clone, Default)]
71pub struct ResponseUsage {
72 pub prompt_tokens: u32,
73 pub completion_tokens: u32,
74 pub total_tokens: u32,
75}
76
77#[derive(Serialize, Deserialize, Debug)]
78pub struct ErrorResponse {
79 pub code: i32,
80 pub message: String,
81 #[serde(skip_serializing_if = "Option::is_none")]
82 pub metadata: Option<HashMap<String, Value>>,
83}
84
85#[derive(Serialize, Deserialize, Debug)]
86pub struct DeltaResponse {
87 pub role: Option<String>,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 pub content: Option<String>,
90 #[serde(default)]
91 pub tool_calls: Vec<OpenRouterToolCall>,
92 #[serde(skip_serializing_if = "Option::is_none")]
93 pub native_finish_reason: Option<String>,
94}
95
96#[derive(Clone)]
97pub struct FinalCompletionResponse {
98 pub usage: ResponseUsage,
99}
100
101impl super::CompletionModel {
102 pub(crate) async fn stream(
103 &self,
104 completion_request: CompletionRequest,
105 ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
106 {
107 let request = self.create_completion_request(completion_request)?;
108
109 let request = json_utils::merge(request, json!({"stream": true}));
110
111 let builder = self.client.post("/chat/completions").json(&request);
112
113 send_streaming_request(builder).await
114 }
115}
116
117pub async fn send_streaming_request(
118 request_builder: RequestBuilder,
119) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
120 let response = request_builder.send().await?;
121
122 if !response.status().is_success() {
123 return Err(CompletionError::ProviderError(format!(
124 "{}: {}",
125 response.status(),
126 response.text().await?
127 )));
128 }
129
130 let stream = Box::pin(stream! {
132 let mut stream = response.bytes_stream();
133 let mut tool_calls = HashMap::new();
134 let mut partial_line = String::new();
135 let mut final_usage = None;
136
137 while let Some(chunk_result) = stream.next().await {
138 let chunk = match chunk_result {
139 Ok(c) => c,
140 Err(e) => {
141 yield Err(CompletionError::from(e));
142 break;
143 }
144 };
145
146 let text = match String::from_utf8(chunk.to_vec()) {
147 Ok(t) => t,
148 Err(e) => {
149 yield Err(CompletionError::ResponseError(e.to_string()));
150 break;
151 }
152 };
153
154 for line in text.lines() {
155 let mut line = line.to_string();
156
157 if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
159 continue;
160 }
161
162 line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
164
165 if line.starts_with('{') && !line.ends_with('}') {
167 partial_line = line;
168 continue;
169 }
170
171 if !partial_line.is_empty() {
173 if line.ends_with('}') {
174 partial_line.push_str(&line);
175 line = partial_line;
176 partial_line = String::new();
177 } else {
178 partial_line.push_str(&line);
179 continue;
180 }
181 }
182
183 let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
184 Ok(data) => data,
185 Err(_) => {
186 continue;
187 }
188 };
189
190
191 let choice = data.choices.first().expect("Should have at least one choice");
192
193 if let Some(delta) = &choice.delta {
203 if !delta.tool_calls.is_empty() {
204 for tool_call in &delta.tool_calls {
205 let index = tool_call.index;
206
207 let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
209 id: String::new(),
210 call_id: None,
211 function: ToolFunction {
212 name: String::new(),
213 arguments: serde_json::Value::Null,
214 },
215 });
216
217 if let Some(id) = &tool_call.id {
219 if !id.is_empty() {
220 existing_tool_call.id = id.clone();
221 }
222 }
223 if let Some(name) = &tool_call.function.name {
224 if !name.is_empty() {
225 existing_tool_call.function.name = name.clone();
226 }
227 }
228 if let Some(chunk) = &tool_call.function.arguments {
229 let current_args = match &existing_tool_call.function.arguments {
231 serde_json::Value::Null => String::new(),
232 serde_json::Value::String(s) => s.clone(),
233 v => v.to_string(),
234 };
235
236 let combined = format!("{current_args}{chunk}");
238
239 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
241 match serde_json::from_str(&combined) {
242 Ok(parsed) => existing_tool_call.function.arguments = parsed,
243 Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
244 }
245 } else {
246 existing_tool_call.function.arguments = serde_json::Value::String(combined);
247 }
248 }
249 }
250 }
251
252 if let Some(content) = &delta.content {
253 if !content.is_empty() {
254 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
255 }
256 }
257
258 if let Some(usage) = data.usage {
259 final_usage = Some(usage);
260 }
261 }
262
263 if let Some(message) = &choice.message {
265 if !message.tool_calls.is_empty() {
266 for tool_call in &message.tool_calls {
267 let name = tool_call.function.name.clone();
268 let id = tool_call.id.clone();
269 let arguments = if let Some(args) = &tool_call.function.arguments {
270 match serde_json::from_str(args) {
272 Ok(v) => v,
273 Err(_) => serde_json::Value::String(args.to_string()),
274 }
275 } else {
276 serde_json::Value::Null
277 };
278 let index = tool_call.index;
279
280 tool_calls.insert(index, ToolCall {
281 id: id.unwrap_or_default(),
282 call_id: None,
283 function: ToolFunction {
284 name: name.unwrap_or_default(),
285 arguments,
286 },
287 });
288 }
289 }
290
291 if !message.content.is_empty() {
292 yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
293 }
294 }
295 }
296 }
297
298 for (_, tool_call) in tool_calls.into_iter() {
299
300 yield Ok(streaming::RawStreamingChoice::ToolCall{
301 name: tool_call.function.name,
302 id: tool_call.id,
303 arguments: tool_call.function.arguments,
304 call_id: None
305 });
306 }
307
308 yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
309 usage: final_usage.unwrap_or_default()
310 }))
311
312 });
313
314 Ok(streaming::StreamingCompletionResponse::stream(stream))
315}