rig/providers/openrouter/
streaming.rs

1use std::collections::HashMap;
2
3use crate::{
4    json_utils,
5    message::{ToolCall, ToolFunction},
6    streaming::{self},
7};
8use async_stream::stream;
9use futures::StreamExt;
10use reqwest::RequestBuilder;
11use serde_json::{Value, json};
12
13use crate::completion::{CompletionError, CompletionRequest};
14use serde::{Deserialize, Serialize};
15
16#[derive(Serialize, Deserialize, Debug)]
17pub struct StreamingCompletionResponse {
18    pub id: String,
19    pub choices: Vec<StreamingChoice>,
20    pub created: u64,
21    pub model: String,
22    pub object: String,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub system_fingerprint: Option<String>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub usage: Option<ResponseUsage>,
27}
28
29#[derive(Serialize, Deserialize, Debug)]
30pub struct StreamingChoice {
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub finish_reason: Option<String>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub native_finish_reason: Option<String>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub logprobs: Option<Value>,
37    pub index: usize,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub message: Option<MessageResponse>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub delta: Option<DeltaResponse>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub error: Option<ErrorResponse>,
44}
45
46#[derive(Serialize, Deserialize, Debug)]
47pub struct MessageResponse {
48    pub role: String,
49    pub content: String,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub refusal: Option<Value>,
52    #[serde(default)]
53    pub tool_calls: Vec<OpenRouterToolCall>,
54}
55
56#[derive(Serialize, Deserialize, Debug)]
57pub struct OpenRouterToolFunction {
58    pub name: Option<String>,
59    pub arguments: Option<String>,
60}
61
62#[derive(Serialize, Deserialize, Debug)]
63pub struct OpenRouterToolCall {
64    pub index: usize,
65    pub id: Option<String>,
66    pub r#type: Option<String>,
67    pub function: OpenRouterToolFunction,
68}
69
70#[derive(Serialize, Deserialize, Debug, Clone, Default)]
71pub struct ResponseUsage {
72    pub prompt_tokens: u32,
73    pub completion_tokens: u32,
74    pub total_tokens: u32,
75}
76
77#[derive(Serialize, Deserialize, Debug)]
78pub struct ErrorResponse {
79    pub code: i32,
80    pub message: String,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub metadata: Option<HashMap<String, Value>>,
83}
84
85#[derive(Serialize, Deserialize, Debug)]
86pub struct DeltaResponse {
87    pub role: Option<String>,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub content: Option<String>,
90    #[serde(default)]
91    pub tool_calls: Vec<OpenRouterToolCall>,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub native_finish_reason: Option<String>,
94}
95
96#[derive(Clone)]
97pub struct FinalCompletionResponse {
98    pub usage: ResponseUsage,
99}
100
101impl super::CompletionModel {
102    pub(crate) async fn stream(
103        &self,
104        completion_request: CompletionRequest,
105    ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
106    {
107        let request = self.create_completion_request(completion_request)?;
108
109        let request = json_utils::merge(request, json!({"stream": true}));
110
111        let builder = self.client.post("/chat/completions").json(&request);
112
113        send_streaming_request(builder).await
114    }
115}
116
117pub async fn send_streaming_request(
118    request_builder: RequestBuilder,
119) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
120    let response = request_builder.send().await?;
121
122    if !response.status().is_success() {
123        return Err(CompletionError::ProviderError(format!(
124            "{}: {}",
125            response.status(),
126            response.text().await?
127        )));
128    }
129
130    // Handle OpenAI Compatible SSE chunks
131    let stream = Box::pin(stream! {
132        let mut stream = response.bytes_stream();
133        let mut tool_calls = HashMap::new();
134        let mut partial_line = String::new();
135        let mut final_usage = None;
136
137        while let Some(chunk_result) = stream.next().await {
138            let chunk = match chunk_result {
139                Ok(c) => c,
140                Err(e) => {
141                    yield Err(CompletionError::from(e));
142                    break;
143                }
144            };
145
146            let text = match String::from_utf8(chunk.to_vec()) {
147                Ok(t) => t,
148                Err(e) => {
149                    yield Err(CompletionError::ResponseError(e.to_string()));
150                    break;
151                }
152            };
153
154            for line in text.lines() {
155                let mut line = line.to_string();
156
157                // Skip empty lines and processing messages, as well as [DONE] (might be useful though)
158                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
159                    continue;
160                }
161
162                // Handle data: prefix
163                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
164
165                // If line starts with { but doesn't end with }, it's a partial JSON
166                if line.starts_with('{') && !line.ends_with('}') {
167                    partial_line = line;
168                    continue;
169                }
170
171                // If we have a partial line and this line ends with }, complete it
172                if !partial_line.is_empty() {
173                    if line.ends_with('}') {
174                        partial_line.push_str(&line);
175                        line = partial_line;
176                        partial_line = String::new();
177                    } else {
178                        partial_line.push_str(&line);
179                        continue;
180                    }
181                }
182
183                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
184                    Ok(data) => data,
185                    Err(_) => {
186                        continue;
187                    }
188                };
189
190
191                let choice = data.choices.first().expect("Should have at least one choice");
192
193                // TODO this has to handle outputs like this:
194                // [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
195                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
196                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
197                // [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
198                // [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
199                // [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
200                // [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
201                // [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
202                if let Some(delta) = &choice.delta {
203                    if !delta.tool_calls.is_empty() {
204                        for tool_call in &delta.tool_calls {
205                            let index = tool_call.index;
206
207                            // Get or create tool call entry
208                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
209                                id: String::new(),
210                                call_id: None,
211                                function: ToolFunction {
212                                    name: String::new(),
213                                    arguments: serde_json::Value::Null,
214                                },
215                            });
216
217                            // Update fields if present
218                            if let Some(id) = &tool_call.id {
219                                if !id.is_empty() {
220                                    existing_tool_call.id = id.clone();
221                                }
222                            }
223                            if let Some(name) = &tool_call.function.name {
224                                if !name.is_empty() {
225                                    existing_tool_call.function.name = name.clone();
226                                }
227                            }
228                            if let Some(chunk) = &tool_call.function.arguments {
229                                // Convert current arguments to string if needed
230                                let current_args = match &existing_tool_call.function.arguments {
231                                    serde_json::Value::Null => String::new(),
232                                    serde_json::Value::String(s) => s.clone(),
233                                    v => v.to_string(),
234                                };
235
236                                // Concatenate the new chunk
237                                let combined = format!("{current_args}{chunk}");
238
239                                // Try to parse as JSON if it looks complete
240                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
241                                    match serde_json::from_str(&combined) {
242                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
243                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
244                                    }
245                                } else {
246                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
247                                }
248                            }
249                        }
250                    }
251
252                    if let Some(content) = &delta.content {
253                        if !content.is_empty() {
254                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
255                        }
256                    }
257
258                    if let Some(usage) = data.usage {
259                        final_usage = Some(usage);
260                    }
261                }
262
263                // Handle message format
264                if let Some(message) = &choice.message {
265                    if !message.tool_calls.is_empty() {
266                        for tool_call in &message.tool_calls {
267                            let name = tool_call.function.name.clone();
268                            let id = tool_call.id.clone();
269                            let arguments = if let Some(args) = &tool_call.function.arguments {
270                                // Try to parse the string as JSON, fallback to string value
271                                match serde_json::from_str(args) {
272                                    Ok(v) => v,
273                                    Err(_) => serde_json::Value::String(args.to_string()),
274                                }
275                            } else {
276                                serde_json::Value::Null
277                            };
278                            let index = tool_call.index;
279
280                            tool_calls.insert(index, ToolCall {
281                                id: id.unwrap_or_default(),
282                                call_id: None,
283                                function: ToolFunction {
284                                    name: name.unwrap_or_default(),
285                                    arguments,
286                                },
287                            });
288                        }
289                    }
290
291                    if !message.content.is_empty() {
292                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
293                    }
294                }
295            }
296        }
297
298        for (_, tool_call) in tool_calls.into_iter() {
299
300            yield Ok(streaming::RawStreamingChoice::ToolCall{
301                name: tool_call.function.name,
302                id: tool_call.id,
303                arguments: tool_call.function.arguments,
304                call_id: None
305            });
306        }
307
308        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
309            usage: final_usage.unwrap_or_default()
310        }))
311
312    });
313
314    Ok(streaming::StreamingCompletionResponse::stream(stream))
315}