rig/providers/openrouter/
streaming.rs

1use reqwest_eventsource::{Event, RequestBuilderExt};
2use std::collections::HashMap;
3
4use crate::{
5    completion::GetTokenUsage,
6    json_utils,
7    message::{ToolCall, ToolFunction},
8    streaming::{self},
9};
10use async_stream::stream;
11use futures::StreamExt;
12use reqwest::RequestBuilder;
13use serde_json::{Value, json};
14
15use crate::completion::{CompletionError, CompletionRequest};
16use serde::{Deserialize, Serialize};
17
18#[derive(Serialize, Deserialize, Debug)]
19pub struct StreamingCompletionResponse {
20    pub id: String,
21    pub choices: Vec<StreamingChoice>,
22    pub created: u64,
23    pub model: String,
24    pub object: String,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub system_fingerprint: Option<String>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub usage: Option<ResponseUsage>,
29}
30
31impl GetTokenUsage for FinalCompletionResponse {
32    fn token_usage(&self) -> Option<crate::completion::Usage> {
33        let mut usage = crate::completion::Usage::new();
34
35        usage.input_tokens = self.usage.prompt_tokens as u64;
36        usage.output_tokens = self.usage.completion_tokens as u64;
37        usage.total_tokens = self.usage.total_tokens as u64;
38
39        Some(usage)
40    }
41}
42
43#[derive(Serialize, Deserialize, Debug)]
44pub struct StreamingChoice {
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub finish_reason: Option<String>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub native_finish_reason: Option<String>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub logprobs: Option<Value>,
51    pub index: usize,
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub message: Option<MessageResponse>,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub delta: Option<DeltaResponse>,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub error: Option<ErrorResponse>,
58}
59
60#[derive(Serialize, Deserialize, Debug)]
61pub struct MessageResponse {
62    pub role: String,
63    pub content: String,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub refusal: Option<Value>,
66    #[serde(default)]
67    pub tool_calls: Vec<OpenRouterToolCall>,
68}
69
70#[derive(Serialize, Deserialize, Debug)]
71pub struct OpenRouterToolFunction {
72    pub name: Option<String>,
73    pub arguments: Option<String>,
74}
75
76#[derive(Serialize, Deserialize, Debug)]
77pub struct OpenRouterToolCall {
78    pub index: usize,
79    pub id: Option<String>,
80    pub r#type: Option<String>,
81    pub function: OpenRouterToolFunction,
82}
83
84#[derive(Serialize, Deserialize, Debug, Clone, Default)]
85pub struct ResponseUsage {
86    pub prompt_tokens: u32,
87    pub completion_tokens: u32,
88    pub total_tokens: u32,
89}
90
91#[derive(Serialize, Deserialize, Debug)]
92pub struct ErrorResponse {
93    pub code: i32,
94    pub message: String,
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub metadata: Option<HashMap<String, Value>>,
97}
98
99#[derive(Serialize, Deserialize, Debug)]
100pub struct DeltaResponse {
101    pub role: Option<String>,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub content: Option<String>,
104    #[serde(default)]
105    pub tool_calls: Vec<OpenRouterToolCall>,
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub native_finish_reason: Option<String>,
108}
109
110#[derive(Clone, Deserialize, Serialize)]
111pub struct FinalCompletionResponse {
112    pub usage: ResponseUsage,
113}
114
115impl super::CompletionModel {
116    pub(crate) async fn stream(
117        &self,
118        completion_request: CompletionRequest,
119    ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
120    {
121        let request = self.create_completion_request(completion_request)?;
122
123        let request = json_utils::merge(request, json!({"stream": true}));
124
125        let builder = self.client.post("/chat/completions").json(&request);
126
127        send_streaming_request(builder).await
128    }
129}
130
131pub async fn send_streaming_request(
132    request_builder: RequestBuilder,
133) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
134    let response = request_builder.send().await?;
135
136    if !response.status().is_success() {
137        return Err(CompletionError::ProviderError(format!(
138            "{}: {}",
139            response.status(),
140            response.text().await?
141        )));
142    }
143
144    // Handle OpenAI Compatible SSE chunks
145    let stream = Box::pin(stream! {
146        let mut stream = response.bytes_stream();
147        let mut tool_calls = HashMap::new();
148        let mut partial_line = String::new();
149        let mut final_usage = None;
150
151        while let Some(chunk_result) = stream.next().await {
152            let chunk = match chunk_result {
153                Ok(c) => c,
154                Err(e) => {
155                    yield Err(CompletionError::from(e));
156                    break;
157                }
158            };
159
160            let text = match String::from_utf8(chunk.to_vec()) {
161                Ok(t) => t,
162                Err(e) => {
163                    yield Err(CompletionError::ResponseError(e.to_string()));
164                    break;
165                }
166            };
167
168            for line in text.lines() {
169                let mut line = line.to_string();
170
171                // Skip empty lines and processing messages, as well as [DONE] (might be useful though)
172                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
173                    continue;
174                }
175
176                // Handle data: prefix
177                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
178
179                // If line starts with { but doesn't end with }, it's a partial JSON
180                if line.starts_with('{') && !line.ends_with('}') {
181                    partial_line = line;
182                    continue;
183                }
184
185                // If we have a partial line and this line ends with }, complete it
186                if !partial_line.is_empty() {
187                    if line.ends_with('}') {
188                        partial_line.push_str(&line);
189                        line = partial_line;
190                        partial_line = String::new();
191                    } else {
192                        partial_line.push_str(&line);
193                        continue;
194                    }
195                }
196
197                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
198                    Ok(data) => data,
199                    Err(_) => {
200                        continue;
201                    }
202                };
203
204
205                let choice = data.choices.first().expect("Should have at least one choice");
206
207                // TODO this has to handle outputs like this:
208                // [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
209                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
210                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
211                // [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
212                // [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
213                // [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
214                // [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
215                // [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
216                if let Some(delta) = &choice.delta {
217                    if !delta.tool_calls.is_empty() {
218                        for tool_call in &delta.tool_calls {
219                            let index = tool_call.index;
220
221                            // Get or create tool call entry
222                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
223                                id: String::new(),
224                                call_id: None,
225                                function: ToolFunction {
226                                    name: String::new(),
227                                    arguments: serde_json::Value::Null,
228                                },
229                            });
230
231                            // Update fields if present
232                            if let Some(id) = &tool_call.id && !id.is_empty() {
233                                    existing_tool_call.id = id.clone();
234                            }
235
236                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
237                                    existing_tool_call.function.name = name.clone();
238                            }
239
240                            if let Some(chunk) = &tool_call.function.arguments {
241                                // Convert current arguments to string if needed
242                                let current_args = match &existing_tool_call.function.arguments {
243                                    serde_json::Value::Null => String::new(),
244                                    serde_json::Value::String(s) => s.clone(),
245                                    v => v.to_string(),
246                                };
247
248                                // Concatenate the new chunk
249                                let combined = format!("{current_args}{chunk}");
250
251                                // Try to parse as JSON if it looks complete
252                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
253                                    match serde_json::from_str(&combined) {
254                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
255                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
256                                    }
257                                } else {
258                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
259                                }
260                            }
261                        }
262                    }
263
264                    if let Some(content) = &delta.content &&!content.is_empty() {
265                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
266                    }
267
268                    if let Some(usage) = data.usage {
269                        final_usage = Some(usage);
270                    }
271                }
272
273                // Handle message format
274                if let Some(message) = &choice.message {
275                    if !message.tool_calls.is_empty() {
276                        for tool_call in &message.tool_calls {
277                            let name = tool_call.function.name.clone();
278                            let id = tool_call.id.clone();
279                            let arguments = if let Some(args) = &tool_call.function.arguments {
280                                // Try to parse the string as JSON, fallback to string value
281                                match serde_json::from_str(args) {
282                                    Ok(v) => v,
283                                    Err(_) => serde_json::Value::String(args.to_string()),
284                                }
285                            } else {
286                                serde_json::Value::Null
287                            };
288                            let index = tool_call.index;
289
290                            tool_calls.insert(index, ToolCall {
291                                id: id.unwrap_or_default(),
292                                call_id: None,
293                                function: ToolFunction {
294                                    name: name.unwrap_or_default(),
295                                    arguments,
296                                },
297                            });
298                        }
299                    }
300
301                    if !message.content.is_empty() {
302                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
303                    }
304                }
305            }
306        }
307
308        for (_, tool_call) in tool_calls.into_iter() {
309
310            yield Ok(streaming::RawStreamingChoice::ToolCall{
311                name: tool_call.function.name,
312                id: tool_call.id,
313                arguments: tool_call.function.arguments,
314                call_id: None
315            });
316        }
317
318        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
319            usage: final_usage.unwrap_or_default()
320        }))
321
322    });
323
324    Ok(streaming::StreamingCompletionResponse::stream(stream))
325}
326
327pub async fn send_streaming_request1(
328    request_builder: RequestBuilder,
329) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
330    let mut event_source = request_builder
331        .eventsource()
332        .expect("Cloning request must always succeed");
333
334    let stream = Box::pin(stream! {
335        // Accumulate tool calls by index while streaming
336        let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
337        let mut final_usage = None;
338
339        while let Some(event_result) = event_source.next().await {
340            match event_result {
341                Ok(Event::Open) => {
342                    tracing::trace!("SSE connection opened");
343                    continue;
344                }
345
346                Ok(Event::Message(event_message)) => {
347                    let raw = event_message.data;
348
349                    let parsed = serde_json::from_str::<StreamingCompletionResponse>(&raw);
350                    let Ok(data) = parsed else {
351                        tracing::debug!("Couldn't parse OpenRouter payload as StreamingCompletionResponse; skipping chunk");
352                        continue;
353                    };
354
355                    // Expect at least one choice (keeps original behavior)
356                    let choice = match data.choices.first() {
357                        Some(c) => c,
358                        None => continue,
359                    };
360
361                    // --- Handle delta (streaming updates) ---
362                    if let Some(delta) = &choice.delta {
363                        if !delta.tool_calls.is_empty() {
364                            for tc in &delta.tool_calls {
365                                let index = tc.index;
366
367                                // Ensure entry exists
368                                let existing = tool_calls.entry(index).or_insert_with(|| ToolCall {
369                                    id: String::new(),
370                                    call_id: None,
371                                    function: ToolFunction {
372                                        name: String::new(),
373                                        arguments: Value::Null,
374                                    },
375                                });
376
377                                // Update id if present and non-empty
378                                if let Some(id) = &tc.id && !id.is_empty() {
379                                        existing.id = id.clone();
380                                }
381
382                                // Update name if present and non-empty
383                                if let Some(name) = &tc.function.name && !name.is_empty() {
384                                    existing.function.name = name.clone();
385                                }
386
387                                // Append argument chunk if present
388                                if let Some(chunk) = &tc.function.arguments {
389                                    // Current arguments as string (or empty)
390                                    let current_args = match &existing.function.arguments {
391                                        Value::Null => String::new(),
392                                        Value::String(s) => s.clone(),
393                                        v => v.to_string(),
394                                    };
395
396                                    let combined = format!("{}{}", current_args, chunk);
397
398                                    // If it looks like complete JSON object, try parse
399                                    if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
400                                        match serde_json::from_str::<Value>(&combined) {
401                                            Ok(parsed_value) => existing.function.arguments = parsed_value,
402                                            Err(_) => existing.function.arguments = Value::String(combined),
403                                        }
404                                    } else {
405                                        existing.function.arguments = Value::String(combined);
406                                    }
407                                }
408                            }
409                        }
410
411                        // Streamed text content
412                        if let Some(content) = &delta.content && !content.is_empty() {
413                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
414                        }
415
416                        // usage update (if present)
417                        if let Some(usage) = data.usage {
418                            final_usage = Some(usage);
419                        }
420                    }
421
422                    // --- Handle message (final/other message structure) ---
423                    if let Some(message) = &choice.message {
424                        if !message.tool_calls.is_empty() {
425                            for tc in &message.tool_calls {
426                                let idx = tc.index;
427                                let name = tc.function.name.clone().unwrap_or_default();
428                                let id = tc.id.clone().unwrap_or_default();
429
430                                let args_value = if let Some(args_str) = &tc.function.arguments {
431                                    match serde_json::from_str::<Value>(args_str) {
432                                        Ok(v) => v,
433                                        Err(_) => Value::String(args_str.clone()),
434                                    }
435                                } else {
436                                    Value::Null
437                                };
438
439                                tool_calls.insert(idx, ToolCall {
440                                    id,
441                                    call_id: None,
442                                    function: ToolFunction {
443                                        name,
444                                        arguments: args_value,
445                                    },
446                                });
447                            }
448                        }
449
450                        if !message.content.is_empty() {
451                            yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()));
452                        }
453                    }
454                }
455
456                Err(reqwest_eventsource::Error::StreamEnded) => {
457                    break;
458                }
459
460                Err(error) => {
461                    tracing::error!(?error, "SSE error from OpenRouter event source");
462                    yield Err(CompletionError::ResponseError(error.to_string()));
463                    break;
464                }
465            }
466        }
467
468        // Ensure event source is closed when stream ends
469        event_source.close();
470
471        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
472        for (_idx, tool_call) in tool_calls.into_iter() {
473            yield Ok(streaming::RawStreamingChoice::ToolCall {
474                name: tool_call.function.name,
475                id: tool_call.id,
476                arguments: tool_call.function.arguments,
477                call_id: None,
478            });
479        }
480
481        // Final response with usage
482        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
483            usage: final_usage.unwrap_or_default(),
484        }));
485    });
486
487    Ok(streaming::StreamingCompletionResponse::stream(stream))
488}