rig/providers/openrouter/
streaming.rs1use reqwest_eventsource::{Event, RequestBuilderExt};
2use std::collections::HashMap;
3
4use crate::{
5 completion::GetTokenUsage,
6 json_utils,
7 message::{ToolCall, ToolFunction},
8 streaming::{self},
9};
10use async_stream::stream;
11use futures::StreamExt;
12use reqwest::RequestBuilder;
13use serde_json::{Value, json};
14
15use crate::completion::{CompletionError, CompletionRequest};
16use serde::{Deserialize, Serialize};
17
18#[derive(Serialize, Deserialize, Debug)]
19pub struct StreamingCompletionResponse {
20 pub id: String,
21 pub choices: Vec<StreamingChoice>,
22 pub created: u64,
23 pub model: String,
24 pub object: String,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub system_fingerprint: Option<String>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub usage: Option<ResponseUsage>,
29}
30
31impl GetTokenUsage for FinalCompletionResponse {
32 fn token_usage(&self) -> Option<crate::completion::Usage> {
33 let mut usage = crate::completion::Usage::new();
34
35 usage.input_tokens = self.usage.prompt_tokens as u64;
36 usage.output_tokens = self.usage.completion_tokens as u64;
37 usage.total_tokens = self.usage.total_tokens as u64;
38
39 Some(usage)
40 }
41}
42
43#[derive(Serialize, Deserialize, Debug)]
44pub struct StreamingChoice {
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub finish_reason: Option<String>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 pub native_finish_reason: Option<String>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 pub logprobs: Option<Value>,
51 pub index: usize,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub message: Option<MessageResponse>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub delta: Option<DeltaResponse>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub error: Option<ErrorResponse>,
58}
59
60#[derive(Serialize, Deserialize, Debug)]
61pub struct MessageResponse {
62 pub role: String,
63 pub content: String,
64 #[serde(skip_serializing_if = "Option::is_none")]
65 pub refusal: Option<Value>,
66 #[serde(default)]
67 pub tool_calls: Vec<OpenRouterToolCall>,
68}
69
70#[derive(Serialize, Deserialize, Debug)]
71pub struct OpenRouterToolFunction {
72 pub name: Option<String>,
73 pub arguments: Option<String>,
74}
75
76#[derive(Serialize, Deserialize, Debug)]
77pub struct OpenRouterToolCall {
78 pub index: usize,
79 pub id: Option<String>,
80 pub r#type: Option<String>,
81 pub function: OpenRouterToolFunction,
82}
83
84#[derive(Serialize, Deserialize, Debug, Clone, Default)]
85pub struct ResponseUsage {
86 pub prompt_tokens: u32,
87 pub completion_tokens: u32,
88 pub total_tokens: u32,
89}
90
91#[derive(Serialize, Deserialize, Debug)]
92pub struct ErrorResponse {
93 pub code: i32,
94 pub message: String,
95 #[serde(skip_serializing_if = "Option::is_none")]
96 pub metadata: Option<HashMap<String, Value>>,
97}
98
99#[derive(Serialize, Deserialize, Debug)]
100pub struct DeltaResponse {
101 pub role: Option<String>,
102 #[serde(skip_serializing_if = "Option::is_none")]
103 pub content: Option<String>,
104 #[serde(default)]
105 pub tool_calls: Vec<OpenRouterToolCall>,
106 #[serde(skip_serializing_if = "Option::is_none")]
107 pub native_finish_reason: Option<String>,
108}
109
110#[derive(Clone, Deserialize, Serialize)]
111pub struct FinalCompletionResponse {
112 pub usage: ResponseUsage,
113}
114
115impl super::CompletionModel {
116 pub(crate) async fn stream(
117 &self,
118 completion_request: CompletionRequest,
119 ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
120 {
121 let request = self.create_completion_request(completion_request)?;
122
123 let request = json_utils::merge(request, json!({"stream": true}));
124
125 let builder = self.client.post("/chat/completions").json(&request);
126
127 send_streaming_request(builder).await
128 }
129}
130
131pub async fn send_streaming_request(
132 request_builder: RequestBuilder,
133) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
134 let response = request_builder.send().await?;
135
136 if !response.status().is_success() {
137 return Err(CompletionError::ProviderError(format!(
138 "{}: {}",
139 response.status(),
140 response.text().await?
141 )));
142 }
143
144 let stream = Box::pin(stream! {
146 let mut stream = response.bytes_stream();
147 let mut tool_calls = HashMap::new();
148 let mut partial_line = String::new();
149 let mut final_usage = None;
150
151 while let Some(chunk_result) = stream.next().await {
152 let chunk = match chunk_result {
153 Ok(c) => c,
154 Err(e) => {
155 yield Err(CompletionError::from(e));
156 break;
157 }
158 };
159
160 let text = match String::from_utf8(chunk.to_vec()) {
161 Ok(t) => t,
162 Err(e) => {
163 yield Err(CompletionError::ResponseError(e.to_string()));
164 break;
165 }
166 };
167
168 for line in text.lines() {
169 let mut line = line.to_string();
170
171 if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
173 continue;
174 }
175
176 line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
178
179 if line.starts_with('{') && !line.ends_with('}') {
181 partial_line = line;
182 continue;
183 }
184
185 if !partial_line.is_empty() {
187 if line.ends_with('}') {
188 partial_line.push_str(&line);
189 line = partial_line;
190 partial_line = String::new();
191 } else {
192 partial_line.push_str(&line);
193 continue;
194 }
195 }
196
197 let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
198 Ok(data) => data,
199 Err(_) => {
200 continue;
201 }
202 };
203
204
205 let choice = data.choices.first().expect("Should have at least one choice");
206
207 if let Some(delta) = &choice.delta {
217 if !delta.tool_calls.is_empty() {
218 for tool_call in &delta.tool_calls {
219 let index = tool_call.index;
220
221 let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
223 id: String::new(),
224 call_id: None,
225 function: ToolFunction {
226 name: String::new(),
227 arguments: serde_json::Value::Null,
228 },
229 });
230
231 if let Some(id) = &tool_call.id && !id.is_empty() {
233 existing_tool_call.id = id.clone();
234 }
235
236 if let Some(name) = &tool_call.function.name && !name.is_empty() {
237 existing_tool_call.function.name = name.clone();
238 }
239
240 if let Some(chunk) = &tool_call.function.arguments {
241 let current_args = match &existing_tool_call.function.arguments {
243 serde_json::Value::Null => String::new(),
244 serde_json::Value::String(s) => s.clone(),
245 v => v.to_string(),
246 };
247
248 let combined = format!("{current_args}{chunk}");
250
251 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
253 match serde_json::from_str(&combined) {
254 Ok(parsed) => existing_tool_call.function.arguments = parsed,
255 Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
256 }
257 } else {
258 existing_tool_call.function.arguments = serde_json::Value::String(combined);
259 }
260 }
261 }
262 }
263
264 if let Some(content) = &delta.content &&!content.is_empty() {
265 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
266 }
267
268 if let Some(usage) = data.usage {
269 final_usage = Some(usage);
270 }
271 }
272
273 if let Some(message) = &choice.message {
275 if !message.tool_calls.is_empty() {
276 for tool_call in &message.tool_calls {
277 let name = tool_call.function.name.clone();
278 let id = tool_call.id.clone();
279 let arguments = if let Some(args) = &tool_call.function.arguments {
280 match serde_json::from_str(args) {
282 Ok(v) => v,
283 Err(_) => serde_json::Value::String(args.to_string()),
284 }
285 } else {
286 serde_json::Value::Null
287 };
288 let index = tool_call.index;
289
290 tool_calls.insert(index, ToolCall {
291 id: id.unwrap_or_default(),
292 call_id: None,
293 function: ToolFunction {
294 name: name.unwrap_or_default(),
295 arguments,
296 },
297 });
298 }
299 }
300
301 if !message.content.is_empty() {
302 yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
303 }
304 }
305 }
306 }
307
308 for (_, tool_call) in tool_calls.into_iter() {
309
310 yield Ok(streaming::RawStreamingChoice::ToolCall{
311 name: tool_call.function.name,
312 id: tool_call.id,
313 arguments: tool_call.function.arguments,
314 call_id: None
315 });
316 }
317
318 yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
319 usage: final_usage.unwrap_or_default()
320 }))
321
322 });
323
324 Ok(streaming::StreamingCompletionResponse::stream(stream))
325}
326
327pub async fn send_streaming_request1(
328 request_builder: RequestBuilder,
329) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
330 let mut event_source = request_builder
331 .eventsource()
332 .expect("Cloning request must always succeed");
333
334 let stream = Box::pin(stream! {
335 let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
337 let mut final_usage = None;
338
339 while let Some(event_result) = event_source.next().await {
340 match event_result {
341 Ok(Event::Open) => {
342 tracing::trace!("SSE connection opened");
343 continue;
344 }
345
346 Ok(Event::Message(event_message)) => {
347 let raw = event_message.data;
348
349 let parsed = serde_json::from_str::<StreamingCompletionResponse>(&raw);
350 let Ok(data) = parsed else {
351 tracing::debug!("Couldn't parse OpenRouter payload as StreamingCompletionResponse; skipping chunk");
352 continue;
353 };
354
355 let choice = match data.choices.first() {
357 Some(c) => c,
358 None => continue,
359 };
360
361 if let Some(delta) = &choice.delta {
363 if !delta.tool_calls.is_empty() {
364 for tc in &delta.tool_calls {
365 let index = tc.index;
366
367 let existing = tool_calls.entry(index).or_insert_with(|| ToolCall {
369 id: String::new(),
370 call_id: None,
371 function: ToolFunction {
372 name: String::new(),
373 arguments: Value::Null,
374 },
375 });
376
377 if let Some(id) = &tc.id && !id.is_empty() {
379 existing.id = id.clone();
380 }
381
382 if let Some(name) = &tc.function.name && !name.is_empty() {
384 existing.function.name = name.clone();
385 }
386
387 if let Some(chunk) = &tc.function.arguments {
389 let current_args = match &existing.function.arguments {
391 Value::Null => String::new(),
392 Value::String(s) => s.clone(),
393 v => v.to_string(),
394 };
395
396 let combined = format!("{}{}", current_args, chunk);
397
398 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
400 match serde_json::from_str::<Value>(&combined) {
401 Ok(parsed_value) => existing.function.arguments = parsed_value,
402 Err(_) => existing.function.arguments = Value::String(combined),
403 }
404 } else {
405 existing.function.arguments = Value::String(combined);
406 }
407 }
408 }
409 }
410
411 if let Some(content) = &delta.content && !content.is_empty() {
413 yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
414 }
415
416 if let Some(usage) = data.usage {
418 final_usage = Some(usage);
419 }
420 }
421
422 if let Some(message) = &choice.message {
424 if !message.tool_calls.is_empty() {
425 for tc in &message.tool_calls {
426 let idx = tc.index;
427 let name = tc.function.name.clone().unwrap_or_default();
428 let id = tc.id.clone().unwrap_or_default();
429
430 let args_value = if let Some(args_str) = &tc.function.arguments {
431 match serde_json::from_str::<Value>(args_str) {
432 Ok(v) => v,
433 Err(_) => Value::String(args_str.clone()),
434 }
435 } else {
436 Value::Null
437 };
438
439 tool_calls.insert(idx, ToolCall {
440 id,
441 call_id: None,
442 function: ToolFunction {
443 name,
444 arguments: args_value,
445 },
446 });
447 }
448 }
449
450 if !message.content.is_empty() {
451 yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()));
452 }
453 }
454 }
455
456 Err(reqwest_eventsource::Error::StreamEnded) => {
457 break;
458 }
459
460 Err(error) => {
461 tracing::error!(?error, "SSE error from OpenRouter event source");
462 yield Err(CompletionError::ResponseError(error.to_string()));
463 break;
464 }
465 }
466 }
467
468 event_source.close();
470
471 for (_idx, tool_call) in tool_calls.into_iter() {
473 yield Ok(streaming::RawStreamingChoice::ToolCall {
474 name: tool_call.function.name,
475 id: tool_call.id,
476 arguments: tool_call.function.arguments,
477 call_id: None,
478 });
479 }
480
481 yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
483 usage: final_usage.unwrap_or_default(),
484 }));
485 });
486
487 Ok(streaming::StreamingCompletionResponse::stream(stream))
488}