rig/providers/openrouter/
streaming.rs

1use http::Request;
2use std::collections::HashMap;
3use tracing::info_span;
4
5use crate::{
6    completion::GetTokenUsage,
7    http_client::{self, HttpClientExt},
8    json_utils,
9    message::{ToolCall, ToolFunction},
10    streaming::{self},
11};
12use async_stream::stream;
13use futures::StreamExt;
14use serde_json::{Value, json};
15
16use crate::completion::{CompletionError, CompletionRequest};
17use serde::{Deserialize, Serialize};
18
/// One JSON payload from OpenRouter's streaming chat-completions response (the
/// body of a single SSE `data:` line), in the OpenAI-compatible chunk layout.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Response id reported by the provider.
    pub id: String,
    /// Choices carried by this chunk; the stream decoder only reads the first one.
    pub choices: Vec<StreamingChoice>,
    /// Creation timestamp as sent by the provider (presumably Unix seconds — confirm).
    pub created: u64,
    /// Model name reported by the provider for this chunk.
    pub model: String,
    /// Object type tag reported by the provider.
    pub object: String,
    /// Optional backend fingerprint; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Token usage, when this chunk carries it; the stream decoder stores it and
    /// reports it in the final `FinalCompletionResponse`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<ResponseUsage>,
}
31
32impl GetTokenUsage for FinalCompletionResponse {
33    fn token_usage(&self) -> Option<crate::completion::Usage> {
34        let mut usage = crate::completion::Usage::new();
35
36        usage.input_tokens = self.usage.prompt_tokens as u64;
37        usage.output_tokens = self.usage.completion_tokens as u64;
38        usage.total_tokens = self.usage.total_tokens as u64;
39
40        Some(usage)
41    }
42}
43
/// One choice inside a streamed chunk.
///
/// A chunk carries either an incremental `delta` or a complete `message`; the
/// stream decoder handles both when they are present.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingChoice {
    /// Normalized finish reason, set on the chunk that ends the choice.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Finish reason as reported by the upstream provider (presumably unnormalized — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub native_finish_reason: Option<String>,
    /// Log-probability data, kept as opaque JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<Value>,
    /// Position of this choice within the response.
    pub index: usize,
    /// Complete message form of the choice, when the provider sends one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<MessageResponse>,
    /// Incremental update (text and/or tool-call fragments).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delta: Option<DeltaResponse>,
    /// Error attached to this choice, if the provider reported one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ErrorResponse>,
}
60
61#[derive(Serialize, Deserialize, Debug)]
62pub struct MessageResponse {
63    pub role: String,
64    pub content: String,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub refusal: Option<Value>,
67    #[serde(default)]
68    pub tool_calls: Vec<OpenRouterToolCall>,
69}
70
/// Function part of a (possibly partial) streamed tool call.
///
/// Both fields are optional because streamed fragments after the first one
/// typically carry only a slice of `arguments`.
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolFunction {
    /// Function name; usually present only on the first fragment.
    pub name: Option<String>,
    /// Raw JSON argument text, or a fragment of it that must be concatenated.
    pub arguments: Option<String>,
}
76
/// A tool call, or one fragment of it, as sent in a delta or message.
///
/// Fragments that share the same `index` belong to the same tool call and are
/// merged by the stream decoder.
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolCall {
    /// Identifies which tool call this fragment belongs to.
    pub index: usize,
    /// Tool-call id; `None` on continuation fragments.
    pub id: Option<String>,
    /// Tool type tag (e.g. `"function"`); `None` on continuation fragments.
    pub r#type: Option<String>,
    /// Name and argument text (or fragments thereof).
    pub function: OpenRouterToolFunction,
}
84
/// Token counts reported by OpenRouter.
///
/// `Default` yields all zeros, which is what the stream reports when no usage
/// arrives.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ResponseUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported by the provider.
    pub total_tokens: u32,
}
91
/// Error object attached to a streamed choice.
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
    /// Numeric error code (presumably an HTTP-style status — confirm against OpenRouter docs).
    pub code: i32,
    /// Human-readable error description.
    pub message: String,
    /// Extra provider-specific details; the shape is not fixed here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, Value>>,
}
99
/// Incremental update for a streamed choice.
#[derive(Serialize, Deserialize, Debug)]
pub struct DeltaResponse {
    /// Author role; typically only on the first delta.
    pub role: Option<String>,
    /// Next slice of generated text, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool-call fragments; a missing field means none.
    #[serde(default)]
    pub tool_calls: Vec<OpenRouterToolCall>,
    /// Finish reason as reported by the upstream provider, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub native_finish_reason: Option<String>,
}
110
/// Payload of the final item emitted by the OpenRouter stream.
///
/// Holds the usage recorded during the stream, or all-zero usage if none arrived.
#[derive(Clone, Deserialize, Serialize)]
pub struct FinalCompletionResponse {
    /// Token usage for the whole streamed completion.
    pub usage: ResponseUsage,
}
115
116impl<T> super::CompletionModel<T>
117where
118    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
119{
120    pub(crate) async fn stream(
121        &self,
122        completion_request: CompletionRequest,
123    ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
124    {
125        let preamble = completion_request.preamble.clone();
126        let request = self.create_completion_request(completion_request)?;
127
128        let request = json_utils::merge(request, json!({"stream": true}));
129
130        let body = serde_json::to_vec(&request)?;
131
132        let req = self
133            .client
134            .post("/chat/completions")?
135            .header("Content-Type", "application/json")
136            .body(body)
137            .map_err(|x| CompletionError::HttpError(x.into()))?;
138
139        let span = if tracing::Span::current().is_disabled() {
140            info_span!(
141                target: "rig::completions",
142                "chat_streaming",
143                gen_ai.operation.name = "chat_streaming",
144                gen_ai.provider.name = "openrouter",
145                gen_ai.request.model = self.model,
146                gen_ai.system_instructions = preamble,
147                gen_ai.response.id = tracing::field::Empty,
148                gen_ai.response.model = tracing::field::Empty,
149                gen_ai.usage.output_tokens = tracing::field::Empty,
150                gen_ai.usage.input_tokens = tracing::field::Empty,
151                gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
152                gen_ai.output.messages = tracing::field::Empty,
153            )
154        } else {
155            tracing::Span::current()
156        };
157
158        tracing::Instrument::instrument(
159            send_streaming_request(self.client.http_client.clone(), req),
160            span,
161        )
162        .await
163    }
164}
165
166pub async fn send_streaming_request<T>(
167    client: T,
168    req: Request<Vec<u8>>,
169) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
170where
171    T: HttpClientExt + Clone + 'static,
172{
173    let response = client.send_streaming(req).await?;
174    let status = response.status();
175
176    if !status.is_success() {
177        return Err(CompletionError::ProviderError(format!(
178            "Got response error trying to send a completion request to OpenRouter: {status}"
179        )));
180    }
181
182    let mut stream = response.into_body();
183
184    // Handle OpenAI Compatible SSE chunks
185    let stream = stream! {
186        let mut tool_calls = HashMap::new();
187        let mut partial_line = String::new();
188        let mut final_usage = None;
189
190        while let Some(chunk_result) = stream.next().await {
191            let chunk = match chunk_result {
192                Ok(c) => c,
193                Err(e) => {
194                    yield Err(CompletionError::from(http_client::Error::Instance(e.into())));
195                    break;
196                }
197            };
198
199            let text = match String::from_utf8(chunk.to_vec()) {
200                Ok(t) => t,
201                Err(e) => {
202                    yield Err(CompletionError::ResponseError(e.to_string()));
203                    break;
204                }
205            };
206
207            for line in text.lines() {
208                let mut line = line.to_string();
209
210                // Skip empty lines and processing messages, as well as [DONE] (might be useful though)
211                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
212                    continue;
213                }
214
215                // Handle data: prefix
216                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
217
218                // If line starts with { but doesn't end with }, it's a partial JSON
219                if line.starts_with('{') && !line.ends_with('}') {
220                    partial_line = line;
221                    continue;
222                }
223
224                // If we have a partial line and this line ends with }, complete it
225                if !partial_line.is_empty() {
226                    if line.ends_with('}') {
227                        partial_line.push_str(&line);
228                        line = partial_line;
229                        partial_line = String::new();
230                    } else {
231                        partial_line.push_str(&line);
232                        continue;
233                    }
234                }
235
236                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
237                    Ok(data) => data,
238                    Err(_) => {
239                        continue;
240                    }
241                };
242
243
244                let choice = data.choices.first().expect("Should have at least one choice");
245
246                // TODO this has to handle outputs like this:
247                // [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
248                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
249                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
250                // [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
251                // [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
252                // [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
253                // [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
254                // [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
255                if let Some(delta) = &choice.delta {
256                    if !delta.tool_calls.is_empty() {
257                        for tool_call in &delta.tool_calls {
258                            let index = tool_call.index;
259
260                            // Get or create tool call entry
261                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
262                                id: String::new(),
263                                call_id: None,
264                                function: ToolFunction {
265                                    name: String::new(),
266                                    arguments: serde_json::Value::Null,
267                                },
268                            });
269
270                            // Update fields if present
271                            if let Some(id) = &tool_call.id && !id.is_empty() {
272                                    existing_tool_call.id = id.clone();
273                            }
274
275                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
276                                    existing_tool_call.function.name = name.clone();
277                            }
278
279                            if let Some(chunk) = &tool_call.function.arguments {
280                                // Convert current arguments to string if needed
281                                let current_args = match &existing_tool_call.function.arguments {
282                                    serde_json::Value::Null => String::new(),
283                                    serde_json::Value::String(s) => s.clone(),
284                                    v => v.to_string(),
285                                };
286
287                                // Concatenate the new chunk
288                                let combined = format!("{current_args}{chunk}");
289
290                                // Try to parse as JSON if it looks complete
291                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
292                                    match serde_json::from_str(&combined) {
293                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
294                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
295                                    }
296                                } else {
297                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
298                                }
299                            }
300                        }
301                    }
302
303                    if let Some(content) = &delta.content &&!content.is_empty() {
304                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
305                    }
306
307                    if let Some(usage) = data.usage {
308                        final_usage = Some(usage);
309                    }
310                }
311
312                // Handle message format
313                if let Some(message) = &choice.message {
314                    if !message.tool_calls.is_empty() {
315                        for tool_call in &message.tool_calls {
316                            let name = tool_call.function.name.clone();
317                            let id = tool_call.id.clone();
318                            let arguments = if let Some(args) = &tool_call.function.arguments {
319                                // Try to parse the string as JSON, fallback to string value
320                                match serde_json::from_str(args) {
321                                    Ok(v) => v,
322                                    Err(_) => serde_json::Value::String(args.to_string()),
323                                }
324                            } else {
325                                serde_json::Value::Null
326                            };
327                            let index = tool_call.index;
328
329                            tool_calls.insert(index, ToolCall {
330                                id: id.unwrap_or_default(),
331                                call_id: None,
332                                function: ToolFunction {
333                                    name: name.unwrap_or_default(),
334                                    arguments,
335                                },
336                            });
337                        }
338                    }
339
340                    if !message.content.is_empty() {
341                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
342                    }
343                }
344            }
345        }
346
347        for (_, tool_call) in tool_calls.into_iter() {
348
349            yield Ok(streaming::RawStreamingChoice::ToolCall{
350                name: tool_call.function.name,
351                id: tool_call.id,
352                arguments: tool_call.function.arguments,
353                call_id: None
354            });
355        }
356
357        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
358            usage: final_usage.unwrap_or_default()
359        }))
360
361    };
362
363    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
364        stream,
365    )))
366}