rig/providers/openrouter/
streaming.rs

1use http::Request;
2use std::collections::HashMap;
3use tracing::info_span;
4
5use crate::{
6    completion::GetTokenUsage,
7    http_client::{self, HttpClientExt},
8    json_utils,
9    message::{ToolCall, ToolFunction},
10    streaming::{self},
11};
12use async_stream::stream;
13use futures::StreamExt;
14use serde_json::{Value, json};
15
16use crate::completion::{CompletionError, CompletionRequest};
17use serde::{Deserialize, Serialize};
18
19#[derive(Serialize, Deserialize, Debug)]
20pub struct StreamingCompletionResponse {
21    pub id: String,
22    pub choices: Vec<StreamingChoice>,
23    pub created: u64,
24    pub model: String,
25    pub object: String,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub system_fingerprint: Option<String>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub usage: Option<ResponseUsage>,
30}
31
32impl GetTokenUsage for FinalCompletionResponse {
33    fn token_usage(&self) -> Option<crate::completion::Usage> {
34        let mut usage = crate::completion::Usage::new();
35
36        usage.input_tokens = self.usage.prompt_tokens as u64;
37        usage.output_tokens = self.usage.completion_tokens as u64;
38        usage.total_tokens = self.usage.total_tokens as u64;
39
40        Some(usage)
41    }
42}
43
44#[derive(Serialize, Deserialize, Debug)]
45pub struct StreamingChoice {
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub finish_reason: Option<String>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub native_finish_reason: Option<String>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub logprobs: Option<Value>,
52    pub index: usize,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub message: Option<MessageResponse>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub delta: Option<DeltaResponse>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub error: Option<ErrorResponse>,
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62pub struct MessageResponse {
63    pub role: String,
64    pub content: String,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub refusal: Option<Value>,
67    #[serde(default)]
68    pub tool_calls: Vec<OpenRouterToolCall>,
69}
70
71#[derive(Serialize, Deserialize, Debug)]
72pub struct OpenRouterToolFunction {
73    pub name: Option<String>,
74    pub arguments: Option<String>,
75}
76
77#[derive(Serialize, Deserialize, Debug)]
78pub struct OpenRouterToolCall {
79    pub index: usize,
80    pub id: Option<String>,
81    pub r#type: Option<String>,
82    pub function: OpenRouterToolFunction,
83}
84
85#[derive(Serialize, Deserialize, Debug, Clone, Default)]
86pub struct ResponseUsage {
87    pub prompt_tokens: u32,
88    pub completion_tokens: u32,
89    pub total_tokens: u32,
90}
91
92#[derive(Serialize, Deserialize, Debug)]
93pub struct ErrorResponse {
94    pub code: i32,
95    pub message: String,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub metadata: Option<HashMap<String, Value>>,
98}
99
100#[derive(Serialize, Deserialize, Debug)]
101pub struct DeltaResponse {
102    pub role: Option<String>,
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub content: Option<String>,
105    #[serde(default)]
106    pub tool_calls: Vec<OpenRouterToolCall>,
107    #[serde(skip_serializing_if = "Option::is_none")]
108    pub native_finish_reason: Option<String>,
109}
110
111#[derive(Clone, Deserialize, Serialize)]
112pub struct FinalCompletionResponse {
113    pub usage: ResponseUsage,
114}
115
116impl<T> super::CompletionModel<T>
117where
118    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
119{
120    pub(crate) async fn stream(
121        &self,
122        completion_request: CompletionRequest,
123    ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
124    {
125        let preamble = completion_request.preamble.clone();
126        let request = self.create_completion_request(completion_request)?;
127
128        let request = json_utils::merge(request, json!({"stream": true}));
129
130        let body = serde_json::to_vec(&request)?;
131
132        let req = self
133            .client
134            .post("/chat/completions")?
135            .header("Content-Type", "application/json")
136            .body(body)
137            .map_err(|x| CompletionError::HttpError(x.into()))?;
138
139        let span = if tracing::Span::current().is_disabled() {
140            info_span!(
141                target: "rig::completions",
142                "chat_streaming",
143                gen_ai.operation.name = "chat_streaming",
144                gen_ai.provider.name = "openrouter",
145                gen_ai.request.model = self.model,
146                gen_ai.system_instructions = preamble,
147                gen_ai.response.id = tracing::field::Empty,
148                gen_ai.response.model = tracing::field::Empty,
149                gen_ai.usage.output_tokens = tracing::field::Empty,
150                gen_ai.usage.input_tokens = tracing::field::Empty,
151                gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
152                gen_ai.output.messages = tracing::field::Empty,
153            )
154        } else {
155            tracing::Span::current()
156        };
157
158        tracing::Instrument::instrument(
159            send_streaming_request(self.client.http_client.clone(), req),
160            span,
161        )
162        .await
163    }
164}
165
166pub async fn send_streaming_request<T>(
167    client: T,
168    req: Request<Vec<u8>>,
169) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
170where
171    T: HttpClientExt + Clone + 'static,
172{
173    let response = client.send_streaming(req).await?;
174    let status = response.status();
175
176    if !status.is_success() {
177        return Err(CompletionError::ProviderError(format!(
178            "Got response error trying to send a completion request to OpenRouter: {status}"
179        )));
180    }
181
182    let mut stream = response.into_body();
183
184    // Handle OpenAI Compatible SSE chunks
185    let stream = stream! {
186        let mut tool_calls = HashMap::new();
187        let mut partial_line = String::new();
188        let mut final_usage = None;
189
190        while let Some(chunk_result) = stream.next().await {
191            let chunk = match chunk_result {
192                Ok(c) => c,
193                Err(e) => {
194                    yield Err(CompletionError::from(http_client::Error::Instance(e.into())));
195                    break;
196                }
197            };
198
199            let text = match String::from_utf8(chunk.to_vec()) {
200                Ok(t) => t,
201                Err(e) => {
202                    yield Err(CompletionError::ResponseError(e.to_string()));
203                    break;
204                }
205            };
206
207            for line in text.lines() {
208                let mut line = line.to_string();
209
210                // Skip empty lines and processing messages, as well as [DONE] (might be useful though)
211                if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
212                    continue;
213                }
214
215                // Handle data: prefix
216                line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
217
218                // If line starts with { but doesn't end with }, it's a partial JSON
219                if line.starts_with('{') && !line.ends_with('}') {
220                    partial_line = line;
221                    continue;
222                }
223
224                // If we have a partial line and this line ends with }, complete it
225                if !partial_line.is_empty() {
226                    if line.ends_with('}') {
227                        partial_line.push_str(&line);
228                        line = partial_line;
229                        partial_line = String::new();
230                    } else {
231                        partial_line.push_str(&line);
232                        continue;
233                    }
234                }
235
236                let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
237                    Ok(data) => data,
238                    Err(_) => {
239                        continue;
240                    }
241                };
242
243
244                let choice = data.choices.first().expect("Should have at least one choice");
245
246                // TODO this has to handle outputs like this:
247                // [{"index": 0, "id": "call_DdmO9pD3xa9XTPNJ32zg2hcA", "function": {"arguments": "", "name": "get_weather"}, "type": "function"}]
248                // [{"index": 0, "id": null, "function": {"arguments": "{\"", "name": null}, "type": null}]
249                // [{"index": 0, "id": null, "function": {"arguments": "location", "name": null}, "type": null}]
250                // [{"index": 0, "id": null, "function": {"arguments": "\":\"", "name": null}, "type": null}]
251                // [{"index": 0, "id": null, "function": {"arguments": "Paris", "name": null}, "type": null}]
252                // [{"index": 0, "id": null, "function": {"arguments": ",", "name": null}, "type": null}]
253                // [{"index": 0, "id": null, "function": {"arguments": " France", "name": null}, "type": null}]
254                // [{"index": 0, "id": null, "function": {"arguments": "\"}", "name": null}, "type": null}]
255                if let Some(delta) = &choice.delta {
256                    if !delta.tool_calls.is_empty() {
257                        for tool_call in &delta.tool_calls {
258                            let index = tool_call.index;
259
260                            // Get or create tool call entry
261                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
262                                id: String::new(),
263                                call_id: None,
264                                function: ToolFunction {
265                                    name: String::new(),
266                                    arguments: serde_json::Value::Null,
267                                },
268                            });
269
270                            // Update fields if present
271                            if let Some(id) = &tool_call.id && !id.is_empty() {
272                                    existing_tool_call.id = id.clone();
273                            }
274
275                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
276                                    existing_tool_call.function.name = name.clone();
277                            }
278
279                            if let Some(chunk) = &tool_call.function.arguments {
280                                // Convert current arguments to string if needed
281                                let current_args = match &existing_tool_call.function.arguments {
282                                    serde_json::Value::Null => String::new(),
283                                    serde_json::Value::String(s) => s.clone(),
284                                    v => v.to_string(),
285                                };
286
287                                // Concatenate the new chunk
288                                let combined = format!("{current_args}{chunk}");
289
290                                // Try to parse as JSON if it looks complete
291                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
292                                    match serde_json::from_str(&combined) {
293                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
294                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
295                                    }
296                                } else {
297                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
298                                }
299
300                                // Emit the delta so UI can show progress
301                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
302                                    id: existing_tool_call.id.clone(),
303                                    delta: chunk.clone(),
304                                });
305                            }
306                        }
307                    }
308
309                    if let Some(content) = &delta.content &&!content.is_empty() {
310                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
311                    }
312
313                    if let Some(usage) = data.usage {
314                        final_usage = Some(usage);
315                    }
316                }
317
318                // Handle message format
319                if let Some(message) = &choice.message {
320                    if !message.tool_calls.is_empty() {
321                        for tool_call in &message.tool_calls {
322                            let name = tool_call.function.name.clone();
323                            let id = tool_call.id.clone();
324                            let arguments = if let Some(args) = &tool_call.function.arguments {
325                                // Try to parse the string as JSON, fallback to string value
326                                match serde_json::from_str(args) {
327                                    Ok(v) => v,
328                                    Err(_) => serde_json::Value::String(args.to_string()),
329                                }
330                            } else {
331                                serde_json::Value::Null
332                            };
333                            let index = tool_call.index;
334
335                            tool_calls.insert(index, ToolCall {
336                                id: id.unwrap_or_default(),
337                                call_id: None,
338                                function: ToolFunction {
339                                    name: name.unwrap_or_default(),
340                                    arguments,
341                                },
342                            });
343                        }
344                    }
345
346                    if !message.content.is_empty() {
347                        yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
348                    }
349                }
350            }
351        }
352
353        for (_, tool_call) in tool_calls.into_iter() {
354
355            yield Ok(streaming::RawStreamingChoice::ToolCall{
356                name: tool_call.function.name,
357                id: tool_call.id,
358                arguments: tool_call.function.arguments,
359                call_id: None
360            });
361        }
362
363        yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
364            usage: final_usage.unwrap_or_default()
365        }))
366
367    };
368
369    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
370        stream,
371    )))
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `value` into `T`, failing the test when the shape does not match.
    fn parse<T: serde::de::DeserializeOwned>(value: serde_json::Value) -> T {
        serde_json::from_value(value).unwrap()
    }

    /// Wraps `delta` in a one-choice `chat.completion.chunk` payload for model `gpt-4`.
    fn gpt4_chunk(id: &str, delta: serde_json::Value) -> serde_json::Value {
        json!({
            "id": id,
            "choices": [{
                "index": 0,
                "delta": delta
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        })
    }

    /// Argument fragment of the first tool call in the first choice's delta.
    fn first_argument_fragment(chunk: &StreamingCompletionResponse) -> Option<&str> {
        chunk.choices[0].delta.as_ref().unwrap().tool_calls[0]
            .function
            .arguments
            .as_deref()
    }

    /// A minimal text chunk deserializes with its top-level fields intact.
    #[test]
    fn test_streaming_completion_response_deserialization() {
        let chunk: StreamingCompletionResponse = parse(json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        }));

        assert_eq!(chunk.id, "gen-abc123");
        assert_eq!(chunk.model, "gpt-3.5-turbo");
        assert_eq!(chunk.choices.len(), 1);
    }

    /// A delta carrying only role and text exposes both.
    #[test]
    fn test_delta_with_content() {
        let delta: DeltaResponse = parse(json!({
            "role": "assistant",
            "content": "Hello, world!"
        }));

        assert_eq!(delta.role.as_deref(), Some("assistant"));
        assert_eq!(delta.content.as_deref(), Some("Hello, world!"));
    }

    /// A delta with one tool call keeps the call's index and id.
    #[test]
    fn test_delta_with_tool_call() {
        let delta: DeltaResponse = parse(json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        }));

        assert_eq!(delta.tool_calls.len(), 1);
        let call = &delta.tool_calls[0];
        assert_eq!(call.index, 0);
        assert_eq!(call.id.as_deref(), Some("call_abc"));
    }

    /// Follow-up fragments may null out everything except the argument text.
    #[test]
    fn test_tool_call_with_partial_arguments() {
        let fragment: OpenRouterToolCall = parse(json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }));

        assert_eq!(fragment.index, 0);
        assert_eq!(fragment.id, None);
        assert_eq!(fragment.function.arguments.as_deref(), Some("Paris"));
    }

    /// Usage attached to a chunk is deserialized with all three counters.
    #[test]
    fn test_streaming_with_usage() {
        let chunk: StreamingCompletionResponse = parse(json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        }));

        let usage = chunk.usage.expect("usage should be present");
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    /// Every chunk in a typical tool-call sequence (opening chunk plus two
    /// argument fragments) deserializes on its own.
    #[test]
    fn test_multiple_tool_call_deltas() {
        let opening: StreamingCompletionResponse = parse(gpt4_chunk(
            "gen-1",
            json!({
                "tool_calls": [{
                    "index": 0,
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "search",
                        "arguments": ""
                    }
                }]
            }),
        ));
        assert_eq!(
            opening.choices[0].delta.as_ref().unwrap().tool_calls[0]
                .id
                .as_deref(),
            Some("call_123")
        );

        let first_fragment: StreamingCompletionResponse = parse(gpt4_chunk(
            "gen-2",
            json!({
                "tool_calls": [{
                    "index": 0,
                    "function": {
                        "arguments": "{\"query\":"
                    }
                }]
            }),
        ));
        assert_eq!(
            first_argument_fragment(&first_fragment),
            Some("{\"query\":")
        );

        let second_fragment: StreamingCompletionResponse = parse(gpt4_chunk(
            "gen-3",
            json!({
                "tool_calls": [{
                    "index": 0,
                    "function": {
                        "arguments": "\"Rust programming\"}"
                    }
                }]
            }),
        ));
        assert_eq!(
            first_argument_fragment(&second_fragment),
            Some("\"Rust programming\"}")
        );
    }

    /// An error attached to a choice is deserialized with its code and message.
    #[test]
    fn test_response_with_error() {
        let chunk: StreamingCompletionResponse = parse(json!({
            "id": "gen-err",
            "choices": [{
                "index": 0,
                "error": {
                    "code": 400,
                    "message": "Invalid request"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        }));

        let error = chunk.choices[0]
            .error
            .as_ref()
            .expect("choice should carry an error");
        assert_eq!(error.code, 400);
        assert_eq!(error.message, "Invalid request");
    }
}
584        assert_eq!(error.message, "Invalid request");
585    }
586}