// rig/providers/openrouter/streaming.rs

use reqwest_eventsource::{Event, RequestBuilderExt};
use std::collections::{BTreeMap, HashMap};
use tracing::info_span;

use crate::{
    completion::GetTokenUsage,
    http_client, json_utils,
    message::{ToolCall, ToolFunction},
    streaming::{self},
};
use async_stream::stream;
use futures::StreamExt;
use reqwest::RequestBuilder;
use serde_json::{Value, json};

use crate::completion::{CompletionError, CompletionRequest};
use serde::{Deserialize, Serialize};

19#[derive(Serialize, Deserialize, Debug)]
20pub struct StreamingCompletionResponse {
21    pub id: String,
22    pub choices: Vec<StreamingChoice>,
23    pub created: u64,
24    pub model: String,
25    pub object: String,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub system_fingerprint: Option<String>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub usage: Option<ResponseUsage>,
30}
31
32impl GetTokenUsage for FinalCompletionResponse {
33    fn token_usage(&self) -> Option<crate::completion::Usage> {
34        let mut usage = crate::completion::Usage::new();
35
36        usage.input_tokens = self.usage.prompt_tokens as u64;
37        usage.output_tokens = self.usage.completion_tokens as u64;
38        usage.total_tokens = self.usage.total_tokens as u64;
39
40        Some(usage)
41    }
42}
43
44#[derive(Serialize, Deserialize, Debug)]
45pub struct StreamingChoice {
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub finish_reason: Option<String>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub native_finish_reason: Option<String>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub logprobs: Option<Value>,
52    pub index: usize,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub message: Option<MessageResponse>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub delta: Option<DeltaResponse>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub error: Option<ErrorResponse>,
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62pub struct MessageResponse {
63    pub role: String,
64    pub content: String,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub refusal: Option<Value>,
67    #[serde(default)]
68    pub tool_calls: Vec<OpenRouterToolCall>,
69}
70
71#[derive(Serialize, Deserialize, Debug)]
72pub struct OpenRouterToolFunction {
73    pub name: Option<String>,
74    pub arguments: Option<String>,
75}
76
77#[derive(Serialize, Deserialize, Debug)]
78pub struct OpenRouterToolCall {
79    pub index: usize,
80    pub id: Option<String>,
81    pub r#type: Option<String>,
82    pub function: OpenRouterToolFunction,
83}
84
85#[derive(Serialize, Deserialize, Debug, Clone, Default)]
86pub struct ResponseUsage {
87    pub prompt_tokens: u32,
88    pub completion_tokens: u32,
89    pub total_tokens: u32,
90}
91
92#[derive(Serialize, Deserialize, Debug)]
93pub struct ErrorResponse {
94    pub code: i32,
95    pub message: String,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub metadata: Option<HashMap<String, Value>>,
98}
99
100#[derive(Serialize, Deserialize, Debug)]
101pub struct DeltaResponse {
102    pub role: Option<String>,
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub content: Option<String>,
105    #[serde(default)]
106    pub tool_calls: Vec<OpenRouterToolCall>,
107    #[serde(skip_serializing_if = "Option::is_none")]
108    pub native_finish_reason: Option<String>,
109}
110
111#[derive(Clone, Deserialize, Serialize)]
112pub struct FinalCompletionResponse {
113    pub usage: ResponseUsage,
114}
115
116impl super::CompletionModel<reqwest::Client> {
117    pub(crate) async fn stream(
118        &self,
119        completion_request: CompletionRequest,
120    ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
121    {
122        let preamble = completion_request.preamble.clone();
123        let request = self.create_completion_request(completion_request)?;
124
125        let request = json_utils::merge(request, json!({"stream": true}));
126
127        let builder = self
128            .client
129            .reqwest_post("/chat/completions")
130            .header("Content-Type", "application/json")
131            .json(&request);
132
133        let span = if tracing::Span::current().is_disabled() {
134            info_span!(
135                target: "rig::completions",
136                "chat_streaming",
137                gen_ai.operation.name = "chat_streaming",
138                gen_ai.provider.name = "openrouter",
139                gen_ai.request.model = self.model,
140                gen_ai.system_instructions = preamble,
141                gen_ai.response.id = tracing::field::Empty,
142                gen_ai.response.model = tracing::field::Empty,
143                gen_ai.usage.output_tokens = tracing::field::Empty,
144                gen_ai.usage.input_tokens = tracing::field::Empty,
145                gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
146                gen_ai.output.messages = tracing::field::Empty,
147            )
148        } else {
149            tracing::Span::current()
150        };
151
152        tracing::Instrument::instrument(send_streaming_request(builder), span).await
153    }
154}
155
156pub async fn send_streaming_request(
157    request_builder: RequestBuilder,
158) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
159    let response = request_builder
160        .send()
161        .await
162        .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?;
163
164    if !response.status().is_success() {
165        return Err(CompletionError::ProviderError(format!(
166            "{}: {}",
167            response.status(),
168            response
169                .text()
170                .await
171                .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?
172        )));
173    }
174
175    // Handle OpenAI Compatible SSE chunks
176    let stream = stream! {
177        let mut stream = response.bytes_stream();
178        let mut tool_calls = HashMap::new();
179        let mut partial_line = String::new();
180        let mut final_usage = None;
181
182        while let Some(chunk_result) = stream.next().await {
183            let chunk = match chunk_result {
184                Ok(c) => c,
185                Err(e) => {
186                    yield Err(CompletionError::from(http_client::Error::Instance(e.into())));
187                    break;
188                }
189            };
190
191            let text = match String::from_utf8(chunk.to_vec()) {
192                Ok(t) => t,
193                Err(e) => {
194                    yield Err(CompletionError::ResponseError(e.to_string()));
195                    break;
196                }
197            };
198
199            for line in text.lines() {
200                let mut line = line.to_string();
201
202                // Skip empty lines and processing messages, as well as [DONE] (might be useful though)
203                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
204                    continue;
205                }
206
207                // Handle data: prefix
208                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
209
210                // If line starts with { but doesn't end with }, it's a partial JSON
211                if line.starts_with('{') && !line.ends_with('}') {
212                    partial_line = line;
213                    continue;
214                }
215
216                // If we have a partial line and this line ends with }, complete it
217                if !partial_line.is_empty() {
218                    if line.ends_with('}') {
219                        partial_line.push_str(&line);
220                        line = partial_line;
221                        partial_line = String::new();
222                    } else {
223                        partial_line.push_str(&line);
224                        continue;
225                    }
226                }
227
228                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
229                    Ok(data) => data,
230                    Err(_) => {
231                        continue;
232                    }
233                };
234
235
236                let choice = data.choices.first().expect("Should have at least one choice");
237
238                // TODO this has to handle outputs like this:
239                // [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
240                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
241                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
242                // [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
243                // [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
244                // [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
245                // [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
246                // [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
247                if let Some(delta) = &choice.delta {
248                    if !delta.tool_calls.is_empty() {
249                        for tool_call in &delta.tool_calls {
250                            let index = tool_call.index;
251
252                            // Get or create tool call entry
253                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
254                                id: String::new(),
255                                call_id: None,
256                                function: ToolFunction {
257                                    name: String::new(),
258                                    arguments: serde_json::Value::Null,
259                                },
260                            });
261
262                            // Update fields if present
263                            if let Some(id) = &tool_call.id && !id.is_empty() {
264                                    existing_tool_call.id = id.clone();
265                            }
266
267                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
268                                    existing_tool_call.function.name = name.clone();
269                            }
270
271                            if let Some(chunk) = &tool_call.function.arguments {
272                                // Convert current arguments to string if needed
273                                let current_args = match &existing_tool_call.function.arguments {
274                                    serde_json::Value::Null => String::new(),
275                                    serde_json::Value::String(s) => s.clone(),
276                                    v => v.to_string(),
277                                };
278
279                                // Concatenate the new chunk
280                                let combined = format!("{current_args}{chunk}");
281
282                                // Try to parse as JSON if it looks complete
283                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
284                                    match serde_json::from_str(&combined) {
285                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
286                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
287                                    }
288                                } else {
289                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
290                                }
291                            }
292                        }
293                    }
294
295                    if let Some(content) = &delta.content &&!content.is_empty() {
296                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
297                    }
298
299                    if let Some(usage) = data.usage {
300                        final_usage = Some(usage);
301                    }
302                }
303
304                // Handle message format
305                if let Some(message) = &choice.message {
306                    if !message.tool_calls.is_empty() {
307                        for tool_call in &message.tool_calls {
308                            let name = tool_call.function.name.clone();
309                            let id = tool_call.id.clone();
310                            let arguments = if let Some(args) = &tool_call.function.arguments {
311                                // Try to parse the string as JSON, fallback to string value
312                                match serde_json::from_str(args) {
313                                    Ok(v) => v,
314                                    Err(_) => serde_json::Value::String(args.to_string()),
315                                }
316                            } else {
317                                serde_json::Value::Null
318                            };
319                            let index = tool_call.index;
320
321                            tool_calls.insert(index, ToolCall {
322                                id: id.unwrap_or_default(),
323                                call_id: None,
324                                function: ToolFunction {
325                                    name: name.unwrap_or_default(),
326                                    arguments,
327                                },
328                            });
329                        }
330                    }
331
332                    if !message.content.is_empty() {
333                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
334                    }
335                }
336            }
337        }
338
339        for (_, tool_call) in tool_calls.into_iter() {
340
341            yield Ok(streaming::RawStreamingChoice::ToolCall{
342                name: tool_call.function.name,
343                id: tool_call.id,
344                arguments: tool_call.function.arguments,
345                call_id: None
346            });
347        }
348
349        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
350            usage: final_usage.unwrap_or_default()
351        }))
352
353    };
354
355    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
356        stream,
357    )))
358}
359
360pub async fn send_streaming_request1(
361    request_builder: RequestBuilder,
362) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
363    let mut event_source = request_builder
364        .eventsource()
365        .expect("Cloning request must always succeed");
366
367    let stream = stream! {
368        // Accumulate tool calls by index while streaming
369        let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
370        let mut final_usage = None;
371
372        while let Some(event_result) = event_source.next().await {
373            match event_result {
374                Ok(Event::Open) => {
375                    tracing::trace!("SSE connection opened");
376                    continue;
377                }
378
379                Ok(Event::Message(event_message)) => {
380                    let raw = event_message.data;
381
382                    let parsed = serde_json::from_str::<StreamingCompletionResponse>(&raw);
383                    let Ok(data) = parsed else {
384                        tracing::debug!("Couldn't parse OpenRouter payload as StreamingCompletionResponse; skipping chunk");
385                        continue;
386                    };
387
388                    // Expect at least one choice (keeps original behavior)
389                    let choice = match data.choices.first() {
390                        Some(c) => c,
391                        None => continue,
392                    };
393
394                    // --- Handle delta (streaming updates) ---
395                    if let Some(delta) = &choice.delta {
396                        if !delta.tool_calls.is_empty() {
397                            for tc in &delta.tool_calls {
398                                let index = tc.index;
399
400                                // Ensure entry exists
401                                let existing = tool_calls.entry(index).or_insert_with(|| ToolCall {
402                                    id: String::new(),
403                                    call_id: None,
404                                    function: ToolFunction {
405                                        name: String::new(),
406                                        arguments: Value::Null,
407                                    },
408                                });
409
410                                // Update id if present and non-empty
411                                if let Some(id) = &tc.id && !id.is_empty() {
412                                        existing.id = id.clone();
413                                }
414
415                                // Update name if present and non-empty
416                                if let Some(name) = &tc.function.name && !name.is_empty() {
417                                    existing.function.name = name.clone();
418                                }
419
420                                // Append argument chunk if present
421                                if let Some(chunk) = &tc.function.arguments {
422                                    // Current arguments as string (or empty)
423                                    let current_args = match &existing.function.arguments {
424                                        Value::Null => String::new(),
425                                        Value::String(s) => s.clone(),
426                                        v => v.to_string(),
427                                    };
428
429                                    let combined = format!("{}{}", current_args, chunk);
430
431                                    // If it looks like complete JSON object, try parse
432                                    if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
433                                        match serde_json::from_str::<Value>(&combined) {
434                                            Ok(parsed_value) => existing.function.arguments = parsed_value,
435                                            Err(_) => existing.function.arguments = Value::String(combined),
436                                        }
437                                    } else {
438                                        existing.function.arguments = Value::String(combined);
439                                    }
440                                }
441                            }
442                        }
443
444                        // Streamed text content
445                        if let Some(content) = &delta.content && !content.is_empty() {
446                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
447                        }
448
449                        // usage update (if present)
450                        if let Some(usage) = data.usage {
451                            final_usage = Some(usage);
452                        }
453                    }
454
455                    // --- Handle message (final/other message structure) ---
456                    if let Some(message) = &choice.message {
457                        if !message.tool_calls.is_empty() {
458                            for tc in &message.tool_calls {
459                                let idx = tc.index;
460                                let name = tc.function.name.clone().unwrap_or_default();
461                                let id = tc.id.clone().unwrap_or_default();
462
463                                let args_value = if let Some(args_str) = &tc.function.arguments {
464                                    match serde_json::from_str::<Value>(args_str) {
465                                        Ok(v) => v,
466                                        Err(_) => Value::String(args_str.clone()),
467                                    }
468                                } else {
469                                    Value::Null
470                                };
471
472                                tool_calls.insert(idx, ToolCall {
473                                    id,
474                                    call_id: None,
475                                    function: ToolFunction {
476                                        name,
477                                        arguments: args_value,
478                                    },
479                                });
480                            }
481                        }
482
483                        if !message.content.is_empty() {
484                            yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()));
485                        }
486                    }
487                }
488
489                Err(reqwest_eventsource::Error::StreamEnded) => {
490                    break;
491                }
492
493                Err(error) => {
494                    tracing::error!(?error, "SSE error from OpenRouter event source");
495                    yield Err(CompletionError::ResponseError(error.to_string()));
496                    break;
497                }
498            }
499        }
500
501        // Ensure event source is closed when stream ends
502        event_source.close();
503
504        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
505        for (_idx, tool_call) in tool_calls.into_iter() {
506            yield Ok(streaming::RawStreamingChoice::ToolCall {
507                name: tool_call.function.name,
508                id: tool_call.id,
509                arguments: tool_call.function.arguments,
510                call_id: None,
511            });
512        }
513
514        // Final response with usage
515        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
516            usage: final_usage.unwrap_or_default(),
517        }));
518    };
519
520    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
521        stream,
522    )))
523}