rig/providers/openrouter/
streaming.rs1use std::collections::HashMap;
2
3use crate::{
4 completion::GetTokenUsage,
5 json_utils,
6 message::{ToolCall, ToolFunction},
7 streaming::{self},
8};
9use async_stream::stream;
10use futures::StreamExt;
11use reqwest::RequestBuilder;
12use serde_json::{Value, json};
13
14use crate::completion::{CompletionError, CompletionRequest};
15use serde::{Deserialize, Serialize};
16
17#[derive(Serialize, Deserialize, Debug)]
18pub struct StreamingCompletionResponse {
19 pub id: String,
20 pub choices: Vec<StreamingChoice>,
21 pub created: u64,
22 pub model: String,
23 pub object: String,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub system_fingerprint: Option<String>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub usage: Option<ResponseUsage>,
28}
29
30impl GetTokenUsage for FinalCompletionResponse {
31 fn token_usage(&self) -> Option<crate::completion::Usage> {
32 let mut usage = crate::completion::Usage::new();
33
34 usage.input_tokens = self.usage.prompt_tokens as u64;
35 usage.output_tokens = self.usage.completion_tokens as u64;
36 usage.total_tokens = self.usage.total_tokens as u64;
37
38 Some(usage)
39 }
40}
41
42#[derive(Serialize, Deserialize, Debug)]
43pub struct StreamingChoice {
44 #[serde(skip_serializing_if = "Option::is_none")]
45 pub finish_reason: Option<String>,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub native_finish_reason: Option<String>,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 pub logprobs: Option<Value>,
50 pub index: usize,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub message: Option<MessageResponse>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub delta: Option<DeltaResponse>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub error: Option<ErrorResponse>,
57}
58
59#[derive(Serialize, Deserialize, Debug)]
60pub struct MessageResponse {
61 pub role: String,
62 pub content: String,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 pub refusal: Option<Value>,
65 #[serde(default)]
66 pub tool_calls: Vec<OpenRouterToolCall>,
67}
68
69#[derive(Serialize, Deserialize, Debug)]
70pub struct OpenRouterToolFunction {
71 pub name: Option<String>,
72 pub arguments: Option<String>,
73}
74
75#[derive(Serialize, Deserialize, Debug)]
76pub struct OpenRouterToolCall {
77 pub index: usize,
78 pub id: Option<String>,
79 pub r#type: Option<String>,
80 pub function: OpenRouterToolFunction,
81}
82
83#[derive(Serialize, Deserialize, Debug, Clone, Default)]
84pub struct ResponseUsage {
85 pub prompt_tokens: u32,
86 pub completion_tokens: u32,
87 pub total_tokens: u32,
88}
89
90#[derive(Serialize, Deserialize, Debug)]
91pub struct ErrorResponse {
92 pub code: i32,
93 pub message: String,
94 #[serde(skip_serializing_if = "Option::is_none")]
95 pub metadata: Option<HashMap<String, Value>>,
96}
97
98#[derive(Serialize, Deserialize, Debug)]
99pub struct DeltaResponse {
100 pub role: Option<String>,
101 #[serde(skip_serializing_if = "Option::is_none")]
102 pub content: Option<String>,
103 #[serde(default)]
104 pub tool_calls: Vec<OpenRouterToolCall>,
105 #[serde(skip_serializing_if = "Option::is_none")]
106 pub native_finish_reason: Option<String>,
107}
108
109#[derive(Clone, Deserialize, Serialize)]
110pub struct FinalCompletionResponse {
111 pub usage: ResponseUsage,
112}
113
114impl super::CompletionModel {
115 pub(crate) async fn stream(
116 &self,
117 completion_request: CompletionRequest,
118 ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
119 {
120 let request = self.create_completion_request(completion_request)?;
121
122 let request = json_utils::merge(request, json!({"stream": true}));
123
124 let builder = self.client.post("/chat/completions").json(&request);
125
126 send_streaming_request(builder).await
127 }
128}
129
130pub async fn send_streaming_request(
131 request_builder: RequestBuilder,
132) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
133 let response = request_builder.send().await?;
134
135 if !response.status().is_success() {
136 return Err(CompletionError::ProviderError(format!(
137 "{}: {}",
138 response.status(),
139 response.text().await?
140 )));
141 }
142
143 let stream = Box::pin(stream! {
145 let mut stream = response.bytes_stream();
146 let mut tool_calls = HashMap::new();
147 let mut partial_line = String::new();
148 let mut final_usage = None;
149
150 while let Some(chunk_result) = stream.next().await {
151 let chunk = match chunk_result {
152 Ok(c) => c,
153 Err(e) => {
154 yield Err(CompletionError::from(e));
155 break;
156 }
157 };
158
159 let text = match String::from_utf8(chunk.to_vec()) {
160 Ok(t) => t,
161 Err(e) => {
162 yield Err(CompletionError::ResponseError(e.to_string()));
163 break;
164 }
165 };
166
167 for line in text.lines() {
168 let mut line = line.to_string();
169
170 if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
172 continue;
173 }
174
175 line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
177
178 if line.starts_with('{') && !line.ends_with('}') {
180 partial_line = line;
181 continue;
182 }
183
184 if !partial_line.is_empty() {
186 if line.ends_with('}') {
187 partial_line.push_str(&line);
188 line = partial_line;
189 partial_line = String::new();
190 } else {
191 partial_line.push_str(&line);
192 continue;
193 }
194 }
195
196 let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
197 Ok(data) => data,
198 Err(_) => {
199 continue;
200 }
201 };
202
203
204 let choice = data.choices.first().expect("Should have at least one choice");
205
206 if let Some(delta) = &choice.delta {
216 if !delta.tool_calls.is_empty() {
217 for tool_call in &delta.tool_calls {
218 let index = tool_call.index;
219
220 let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
222 id: String::new(),
223 call_id: None,
224 function: ToolFunction {
225 name: String::new(),
226 arguments: serde_json::Value::Null,
227 },
228 });
229
230 if let Some(id) = &tool_call.id && !id.is_empty() {
232 existing_tool_call.id = id.clone();
233 }
234
235 if let Some(name) = &tool_call.function.name && !name.is_empty() {
236 existing_tool_call.function.name = name.clone();
237 }
238
239 if let Some(chunk) = &tool_call.function.arguments {
240 let current_args = match &existing_tool_call.function.arguments {
242 serde_json::Value::Null => String::new(),
243 serde_json::Value::String(s) => s.clone(),
244 v => v.to_string(),
245 };
246
247 let combined = format!("{current_args}{chunk}");
249
250 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
252 match serde_json::from_str(&combined) {
253 Ok(parsed) => existing_tool_call.function.arguments = parsed,
254 Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
255 }
256 } else {
257 existing_tool_call.function.arguments = serde_json::Value::String(combined);
258 }
259 }
260 }
261 }
262
263 if let Some(content) = &delta.content &&!content.is_empty() {
264 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
265 }
266
267 if let Some(usage) = data.usage {
268 final_usage = Some(usage);
269 }
270 }
271
272 if let Some(message) = &choice.message {
274 if !message.tool_calls.is_empty() {
275 for tool_call in &message.tool_calls {
276 let name = tool_call.function.name.clone();
277 let id = tool_call.id.clone();
278 let arguments = if let Some(args) = &tool_call.function.arguments {
279 match serde_json::from_str(args) {
281 Ok(v) => v,
282 Err(_) => serde_json::Value::String(args.to_string()),
283 }
284 } else {
285 serde_json::Value::Null
286 };
287 let index = tool_call.index;
288
289 tool_calls.insert(index, ToolCall {
290 id: id.unwrap_or_default(),
291 call_id: None,
292 function: ToolFunction {
293 name: name.unwrap_or_default(),
294 arguments,
295 },
296 });
297 }
298 }
299
300 if !message.content.is_empty() {
301 yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
302 }
303 }
304 }
305 }
306
307 for (_, tool_call) in tool_calls.into_iter() {
308
309 yield Ok(streaming::RawStreamingChoice::ToolCall{
310 name: tool_call.function.name,
311 id: tool_call.id,
312 arguments: tool_call.function.arguments,
313 call_id: None
314 });
315 }
316
317 yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
318 usage: final_usage.unwrap_or_default()
319 }))
320
321 });
322
323 Ok(streaming::StreamingCompletionResponse::stream(stream))
324}