rig/providers/openrouter/
streaming.rs

1use std::collections::HashMap;
2
3use crate::{
4    completion::GetTokenUsage,
5    json_utils,
6    message::{ToolCall, ToolFunction},
7    streaming::{self},
8};
9use async_stream::stream;
10use futures::StreamExt;
11use reqwest::RequestBuilder;
12use serde_json::{Value, json};
13
14use crate::completion::{CompletionError, CompletionRequest};
15use serde::{Deserialize, Serialize};
16
17#[derive(Serialize, Deserialize, Debug)]
18pub struct StreamingCompletionResponse {
19    pub id: String,
20    pub choices: Vec<StreamingChoice>,
21    pub created: u64,
22    pub model: String,
23    pub object: String,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub system_fingerprint: Option<String>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub usage: Option<ResponseUsage>,
28}
29
30impl GetTokenUsage for FinalCompletionResponse {
31    fn token_usage(&self) -> Option<crate::completion::Usage> {
32        let mut usage = crate::completion::Usage::new();
33
34        usage.input_tokens = self.usage.prompt_tokens as u64;
35        usage.output_tokens = self.usage.completion_tokens as u64;
36        usage.total_tokens = self.usage.total_tokens as u64;
37
38        Some(usage)
39    }
40}
41
42#[derive(Serialize, Deserialize, Debug)]
43pub struct StreamingChoice {
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub finish_reason: Option<String>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub native_finish_reason: Option<String>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub logprobs: Option<Value>,
50    pub index: usize,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub message: Option<MessageResponse>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub delta: Option<DeltaResponse>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub error: Option<ErrorResponse>,
57}
58
59#[derive(Serialize, Deserialize, Debug)]
60pub struct MessageResponse {
61    pub role: String,
62    pub content: String,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub refusal: Option<Value>,
65    #[serde(default)]
66    pub tool_calls: Vec<OpenRouterToolCall>,
67}
68
69#[derive(Serialize, Deserialize, Debug)]
70pub struct OpenRouterToolFunction {
71    pub name: Option<String>,
72    pub arguments: Option<String>,
73}
74
75#[derive(Serialize, Deserialize, Debug)]
76pub struct OpenRouterToolCall {
77    pub index: usize,
78    pub id: Option<String>,
79    pub r#type: Option<String>,
80    pub function: OpenRouterToolFunction,
81}
82
83#[derive(Serialize, Deserialize, Debug, Clone, Default)]
84pub struct ResponseUsage {
85    pub prompt_tokens: u32,
86    pub completion_tokens: u32,
87    pub total_tokens: u32,
88}
89
90#[derive(Serialize, Deserialize, Debug)]
91pub struct ErrorResponse {
92    pub code: i32,
93    pub message: String,
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub metadata: Option<HashMap<String, Value>>,
96}
97
98#[derive(Serialize, Deserialize, Debug)]
99pub struct DeltaResponse {
100    pub role: Option<String>,
101    #[serde(skip_serializing_if = "Option::is_none")]
102    pub content: Option<String>,
103    #[serde(default)]
104    pub tool_calls: Vec<OpenRouterToolCall>,
105    #[serde(skip_serializing_if = "Option::is_none")]
106    pub native_finish_reason: Option<String>,
107}
108
109#[derive(Clone, Deserialize, Serialize)]
110pub struct FinalCompletionResponse {
111    pub usage: ResponseUsage,
112}
113
114impl super::CompletionModel {
115    pub(crate) async fn stream(
116        &self,
117        completion_request: CompletionRequest,
118    ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
119    {
120        let request = self.create_completion_request(completion_request)?;
121
122        let request = json_utils::merge(request, json!({"stream": true}));
123
124        let builder = self.client.post("/chat/completions").json(&request);
125
126        send_streaming_request(builder).await
127    }
128}
129
130pub async fn send_streaming_request(
131    request_builder: RequestBuilder,
132) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
133    let response = request_builder.send().await?;
134
135    if !response.status().is_success() {
136        return Err(CompletionError::ProviderError(format!(
137            "{}: {}",
138            response.status(),
139            response.text().await?
140        )));
141    }
142
143    // Handle OpenAI Compatible SSE chunks
144    let stream = Box::pin(stream! {
145        let mut stream = response.bytes_stream();
146        let mut tool_calls = HashMap::new();
147        let mut partial_line = String::new();
148        let mut final_usage = None;
149
150        while let Some(chunk_result) = stream.next().await {
151            let chunk = match chunk_result {
152                Ok(c) => c,
153                Err(e) => {
154                    yield Err(CompletionError::from(e));
155                    break;
156                }
157            };
158
159            let text = match String::from_utf8(chunk.to_vec()) {
160                Ok(t) => t,
161                Err(e) => {
162                    yield Err(CompletionError::ResponseError(e.to_string()));
163                    break;
164                }
165            };
166
167            for line in text.lines() {
168                let mut line = line.to_string();
169
170                // Skip empty lines and processing messages, as well as [DONE] (might be useful though)
171                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
172                    continue;
173                }
174
175                // Handle data: prefix
176                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
177
178                // If line starts with { but doesn't end with }, it's a partial JSON
179                if line.starts_with('{') && !line.ends_with('}') {
180                    partial_line = line;
181                    continue;
182                }
183
184                // If we have a partial line and this line ends with }, complete it
185                if !partial_line.is_empty() {
186                    if line.ends_with('}') {
187                        partial_line.push_str(&line);
188                        line = partial_line;
189                        partial_line = String::new();
190                    } else {
191                        partial_line.push_str(&line);
192                        continue;
193                    }
194                }
195
196                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
197                    Ok(data) => data,
198                    Err(_) => {
199                        continue;
200                    }
201                };
202
203
204                let choice = data.choices.first().expect("Should have at least one choice");
205
206                // TODO this has to handle outputs like this:
207                // [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
208                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
209                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
210                // [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
211                // [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
212                // [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
213                // [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
214                // [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
215                if let Some(delta) = &choice.delta {
216                    if !delta.tool_calls.is_empty() {
217                        for tool_call in &delta.tool_calls {
218                            let index = tool_call.index;
219
220                            // Get or create tool call entry
221                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
222                                id: String::new(),
223                                call_id: None,
224                                function: ToolFunction {
225                                    name: String::new(),
226                                    arguments: serde_json::Value::Null,
227                                },
228                            });
229
230                            // Update fields if present
231                            if let Some(id) = &tool_call.id && !id.is_empty() {
232                                    existing_tool_call.id = id.clone();
233                            }
234
235                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
236                                    existing_tool_call.function.name = name.clone();
237                            }
238
239                            if let Some(chunk) = &tool_call.function.arguments {
240                                // Convert current arguments to string if needed
241                                let current_args = match &existing_tool_call.function.arguments {
242                                    serde_json::Value::Null => String::new(),
243                                    serde_json::Value::String(s) => s.clone(),
244                                    v => v.to_string(),
245                                };
246
247                                // Concatenate the new chunk
248                                let combined = format!("{current_args}{chunk}");
249
250                                // Try to parse as JSON if it looks complete
251                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
252                                    match serde_json::from_str(&combined) {
253                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
254                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
255                                    }
256                                } else {
257                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
258                                }
259                            }
260                        }
261                    }
262
263                    if let Some(content) = &delta.content &&!content.is_empty() {
264                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
265                    }
266
267                    if let Some(usage) = data.usage {
268                        final_usage = Some(usage);
269                    }
270                }
271
272                // Handle message format
273                if let Some(message) = &choice.message {
274                    if !message.tool_calls.is_empty() {
275                        for tool_call in &message.tool_calls {
276                            let name = tool_call.function.name.clone();
277                            let id = tool_call.id.clone();
278                            let arguments = if let Some(args) = &tool_call.function.arguments {
279                                // Try to parse the string as JSON, fallback to string value
280                                match serde_json::from_str(args) {
281                                    Ok(v) => v,
282                                    Err(_) => serde_json::Value::String(args.to_string()),
283                                }
284                            } else {
285                                serde_json::Value::Null
286                            };
287                            let index = tool_call.index;
288
289                            tool_calls.insert(index, ToolCall {
290                                id: id.unwrap_or_default(),
291                                call_id: None,
292                                function: ToolFunction {
293                                    name: name.unwrap_or_default(),
294                                    arguments,
295                                },
296                            });
297                        }
298                    }
299
300                    if !message.content.is_empty() {
301                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
302                    }
303                }
304            }
305        }
306
307        for (_, tool_call) in tool_calls.into_iter() {
308
309            yield Ok(streaming::RawStreamingChoice::ToolCall{
310                name: tool_call.function.name,
311                id: tool_call.id,
312                arguments: tool_call.function.arguments,
313                call_id: None
314            });
315        }
316
317        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
318            usage: final_usage.unwrap_or_default()
319        }))
320
321    });
322
323    Ok(streaming::StreamingCompletionResponse::stream(stream))
324}