rig/providers/openrouter/
streaming.rs

1use std::collections::HashMap;
2
3use crate::{
4    json_utils,
5    message::{ToolCall, ToolFunction},
6    streaming::{self},
7};
8use async_stream::stream;
9use futures::StreamExt;
10use reqwest::RequestBuilder;
11use serde_json::{json, Value};
12
13use crate::completion::{CompletionError, CompletionRequest};
14use serde::{Deserialize, Serialize};
15
16#[derive(Serialize, Deserialize, Debug)]
17pub struct StreamingCompletionResponse {
18    pub id: String,
19    pub choices: Vec<StreamingChoice>,
20    pub created: u64,
21    pub model: String,
22    pub object: String,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub system_fingerprint: Option<String>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub usage: Option<ResponseUsage>,
27}
28
29#[derive(Serialize, Deserialize, Debug)]
30pub struct StreamingChoice {
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub finish_reason: Option<String>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub native_finish_reason: Option<String>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub logprobs: Option<Value>,
37    pub index: usize,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub message: Option<MessageResponse>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub delta: Option<DeltaResponse>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub error: Option<ErrorResponse>,
44}
45
46#[derive(Serialize, Deserialize, Debug)]
47pub struct MessageResponse {
48    pub role: String,
49    pub content: String,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub refusal: Option<Value>,
52    #[serde(default)]
53    pub tool_calls: Vec<OpenRouterToolCall>,
54}
55
56#[derive(Serialize, Deserialize, Debug)]
57pub struct OpenRouterToolFunction {
58    pub name: Option<String>,
59    pub arguments: Option<String>,
60}
61
62#[derive(Serialize, Deserialize, Debug)]
63pub struct OpenRouterToolCall {
64    pub index: usize,
65    pub id: Option<String>,
66    pub r#type: Option<String>,
67    pub function: OpenRouterToolFunction,
68}
69
70#[derive(Serialize, Deserialize, Debug, Clone, Default)]
71pub struct ResponseUsage {
72    pub prompt_tokens: u32,
73    pub completion_tokens: u32,
74    pub total_tokens: u32,
75}
76
77#[derive(Serialize, Deserialize, Debug)]
78pub struct ErrorResponse {
79    pub code: i32,
80    pub message: String,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub metadata: Option<HashMap<String, Value>>,
83}
84
85#[derive(Serialize, Deserialize, Debug)]
86pub struct DeltaResponse {
87    pub role: Option<String>,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub content: Option<String>,
90    #[serde(default)]
91    pub tool_calls: Vec<OpenRouterToolCall>,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub native_finish_reason: Option<String>,
94}
95
96#[derive(Clone)]
97pub struct FinalCompletionResponse {
98    pub usage: ResponseUsage,
99}
100
101impl super::CompletionModel {
102    pub(crate) async fn stream(
103        &self,
104        completion_request: CompletionRequest,
105    ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
106    {
107        let request = self.create_completion_request(completion_request)?;
108
109        let request = json_utils::merge(request, json!({"stream": true}));
110
111        let builder = self.client.post("/chat/completions").json(&request);
112
113        send_streaming_request(builder).await
114    }
115}
116
117pub async fn send_streaming_request(
118    request_builder: RequestBuilder,
119) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
120    let response = request_builder.send().await?;
121
122    if !response.status().is_success() {
123        return Err(CompletionError::ProviderError(format!(
124            "{}: {}",
125            response.status(),
126            response.text().await?
127        )));
128    }
129
130    // Handle OpenAI Compatible SSE chunks
131    let stream = Box::pin(stream! {
132        let mut stream = response.bytes_stream();
133        let mut tool_calls = HashMap::new();
134        let mut partial_line = String::new();
135        let mut final_usage = None;
136
137        while let Some(chunk_result) = stream.next().await {
138            let chunk = match chunk_result {
139                Ok(c) => c,
140                Err(e) => {
141                    yield Err(CompletionError::from(e));
142                    break;
143                }
144            };
145
146            let text = match String::from_utf8(chunk.to_vec()) {
147                Ok(t) => t,
148                Err(e) => {
149                    yield Err(CompletionError::ResponseError(e.to_string()));
150                    break;
151                }
152            };
153
154            for line in text.lines() {
155                let mut line = line.to_string();
156
157                // Skip empty lines and processing messages, as well as [DONE] (might be useful though)
158                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
159                    continue;
160                }
161
162                // Handle data: prefix
163                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
164
165                // If line starts with { but doesn't end with }, it's a partial JSON
166                if line.starts_with('{') && !line.ends_with('}') {
167                    partial_line = line;
168                    continue;
169                }
170
171                // If we have a partial line and this line ends with }, complete it
172                if !partial_line.is_empty() {
173                    if line.ends_with('}') {
174                        partial_line.push_str(&line);
175                        line = partial_line;
176                        partial_line = String::new();
177                    } else {
178                        partial_line.push_str(&line);
179                        continue;
180                    }
181                }
182
183                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
184                    Ok(data) => data,
185                    Err(_) => {
186                        continue;
187                    }
188                };
189
190
191                let choice = data.choices.first().expect("Should have at least one choice");
192
193                // TODO this has to handle outputs like this:
194                // [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
195                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
196                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
197                // [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
198                // [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
199                // [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
200                // [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
201                // [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
202                if let Some(delta) = &choice.delta {
203                    if !delta.tool_calls.is_empty() {
204                        for tool_call in &delta.tool_calls {
205                            let index = tool_call.index;
206
207                            // Get or create tool call entry
208                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
209                                id: String::new(),
210                                function: ToolFunction {
211                                    name: String::new(),
212                                    arguments: serde_json::Value::Null,
213                                },
214                            });
215
216                            // Update fields if present
217                            if let Some(id) = &tool_call.id {
218                                if !id.is_empty() {
219                                    existing_tool_call.id = id.clone();
220                                }
221                            }
222                            if let Some(name) = &tool_call.function.name {
223                                if !name.is_empty() {
224                                    existing_tool_call.function.name = name.clone();
225                                }
226                            }
227                            if let Some(chunk) = &tool_call.function.arguments {
228                                // Convert current arguments to string if needed
229                                let current_args = match &existing_tool_call.function.arguments {
230                                    serde_json::Value::Null => String::new(),
231                                    serde_json::Value::String(s) => s.clone(),
232                                    v => v.to_string(),
233                                };
234
235                                // Concatenate the new chunk
236                                let combined = format!("{current_args}{chunk}");
237
238                                // Try to parse as JSON if it looks complete
239                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
240                                    match serde_json::from_str(&combined) {
241                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
242                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
243                                    }
244                                } else {
245                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
246                                }
247                            }
248                        }
249                    }
250
251                    if let Some(content) = &delta.content {
252                        if !content.is_empty() {
253                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
254                        }
255                    }
256
257                    if let Some(usage) = data.usage {
258                        final_usage = Some(usage);
259                    }
260                }
261
262                // Handle message format
263                if let Some(message) = &choice.message {
264                    if !message.tool_calls.is_empty() {
265                        for tool_call in &message.tool_calls {
266                            let name = tool_call.function.name.clone();
267                            let id = tool_call.id.clone();
268                            let arguments = if let Some(args) = &tool_call.function.arguments {
269                                // Try to parse the string as JSON, fallback to string value
270                                match serde_json::from_str(args) {
271                                    Ok(v) => v,
272                                    Err(_) => serde_json::Value::String(args.to_string()),
273                                }
274                            } else {
275                                serde_json::Value::Null
276                            };
277                            let index = tool_call.index;
278
279                            tool_calls.insert(index, ToolCall{
280                                id: id.unwrap_or_default(),
281                                function: ToolFunction {
282                                    name: name.unwrap_or_default(),
283                                    arguments,
284                                },
285                            });
286                        }
287                    }
288
289                    if !message.content.is_empty() {
290                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
291                    }
292                }
293            }
294        }
295
296        for (_, tool_call) in tool_calls.into_iter() {
297
298            yield Ok(streaming::RawStreamingChoice::ToolCall{
299                name: tool_call.function.name,
300                id: tool_call.id,
301                arguments: tool_call.function.arguments
302            });
303        }
304
305        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
306            usage: final_usage.unwrap_or_default()
307        }))
308
309    });
310
311    Ok(streaming::StreamingCompletionResponse::stream(stream))
312}