rig/providers/openrouter/
streaming.rs

1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::info_span;
9use tracing_futures::Instrument;
10
11use crate::completion::GetTokenUsage;
12use crate::completion::{CompletionError, CompletionRequest};
13use crate::http_client::HttpClientExt;
14use crate::http_client::sse::{Event, GenericEventSource};
15use crate::json_utils;
16use crate::message::{ToolCall, ToolFunction};
17use crate::providers::openrouter::OpenrouterCompletionRequest;
18use crate::streaming;
19
20#[derive(Clone, Serialize, Deserialize, Debug)]
21pub struct StreamingCompletionResponse {
22    pub usage: Usage,
23}
24
25impl GetTokenUsage for StreamingCompletionResponse {
26    fn token_usage(&self) -> Option<crate::completion::Usage> {
27        let mut usage = crate::completion::Usage::new();
28
29        usage.input_tokens = self.usage.prompt_tokens as u64;
30        usage.output_tokens = self.usage.completion_tokens as u64;
31        usage.total_tokens = self.usage.total_tokens as u64;
32
33        Some(usage)
34    }
35}
36
37#[derive(Deserialize, Debug, PartialEq)]
38#[serde(rename_all = "snake_case")]
39pub enum FinishReason {
40    ToolCalls,
41    Stop,
42    Error,
43    ContentFilter,
44    Length,
45    #[serde(untagged)]
46    Other(String),
47}
48
49#[derive(Deserialize, Debug)]
50#[allow(dead_code)]
51struct StreamingChoice {
52    pub finish_reason: Option<FinishReason>,
53    pub native_finish_reason: Option<String>,
54    pub logprobs: Option<Value>,
55    pub index: usize,
56    pub delta: StreamingDelta,
57}
58
59#[derive(Deserialize, Debug)]
60struct StreamingFunction {
61    pub name: Option<String>,
62    pub arguments: Option<String>,
63}
64
65#[derive(Deserialize, Debug)]
66#[allow(dead_code)]
67struct StreamingToolCall {
68    pub index: usize,
69    pub id: Option<String>,
70    pub r#type: Option<String>,
71    pub function: StreamingFunction,
72}
73
74#[derive(Serialize, Deserialize, Debug, Clone, Default)]
75pub struct Usage {
76    pub prompt_tokens: u32,
77    pub completion_tokens: u32,
78    pub total_tokens: u32,
79}
80
81#[derive(Deserialize, Debug)]
82#[allow(dead_code)]
83struct ErrorResponse {
84    pub code: i32,
85    pub message: String,
86}
87
88#[derive(Deserialize, Debug)]
89#[allow(dead_code)]
90struct StreamingDelta {
91    pub role: Option<String>,
92    pub content: Option<String>,
93    pub reasoning: Option<String>,
94    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
95    pub tool_calls: Vec<StreamingToolCall>,
96}
97
98#[derive(Deserialize, Debug)]
99#[allow(dead_code)]
100struct StreamingCompletionChunk {
101    id: String,
102    model: String,
103    choices: Vec<StreamingChoice>,
104    usage: Option<Usage>,
105    error: Option<ErrorResponse>,
106}
107
108impl<T> super::CompletionModel<T>
109where
110    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
111{
112    pub(crate) async fn stream(
113        &self,
114        completion_request: CompletionRequest,
115    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
116    {
117        let preamble = completion_request.preamble.clone();
118        let mut request =
119            OpenrouterCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
120
121        let params = json_utils::merge(
122            request.additional_params.unwrap_or(serde_json::json!({})),
123            serde_json::json!({"stream": true }),
124        );
125
126        request.additional_params = Some(params);
127
128        let body = serde_json::to_vec(&request)?;
129
130        let req = self
131            .client
132            .post("/chat/completions")?
133            .body(body)
134            .map_err(|x| CompletionError::HttpError(x.into()))?;
135
136        let span = if tracing::Span::current().is_disabled() {
137            info_span!(
138                target: "rig::completions",
139                "chat_streaming",
140                gen_ai.operation.name = "chat_streaming",
141                gen_ai.provider.name = "openrouter",
142                gen_ai.request.model = self.model,
143                gen_ai.system_instructions = preamble,
144                gen_ai.response.id = tracing::field::Empty,
145                gen_ai.response.model = tracing::field::Empty,
146                gen_ai.usage.output_tokens = tracing::field::Empty,
147                gen_ai.usage.input_tokens = tracing::field::Empty,
148            )
149        } else {
150            tracing::Span::current()
151        };
152
153        tracing::Instrument::instrument(
154            send_compatible_streaming_request(self.client.clone(), req),
155            span,
156        )
157        .await
158    }
159}
160
161pub async fn send_compatible_streaming_request<T>(
162    http_client: T,
163    req: Request<Vec<u8>>,
164) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
165where
166    T: HttpClientExt + Clone + 'static,
167{
168    let span = tracing::Span::current();
169    // Build the request with proper headers for SSE
170    let mut event_source = GenericEventSource::new(http_client, req);
171
172    let stream = stream! {
173        // Accumulate tool calls by index while streaming
174        let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
175        let mut final_usage = None;
176
177        while let Some(event_result) = event_source.next().await {
178            match event_result {
179                Ok(Event::Open) => {
180                    tracing::trace!("SSE connection opened");
181                    continue;
182                }
183
184                Ok(Event::Message(message)) => {
185                    if message.data.trim().is_empty() || message.data == "[DONE]" {
186                        continue;
187                    }
188
189                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
190                        Ok(data) => data,
191                        Err(error) => {
192                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
193                            continue;
194                        }
195                    };
196
197                    // Expect at least one choice
198                     let Some(choice) = data.choices.first() else {
199                        tracing::debug!("There is no choice");
200                        continue;
201                    };
202                    let delta = &choice.delta;
203
204                    if !delta.tool_calls.is_empty() {
205                        for tool_call in &delta.tool_calls {
206                            let index = tool_call.index;
207
208                            // Get or create tool call entry
209                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
210                                id: String::new(),
211                                call_id: None,
212                                function: ToolFunction {
213                                    name: String::new(),
214                                    arguments: serde_json::Value::Null,
215                                },
216                            });
217
218                            // Update fields if present
219                            if let Some(id) = &tool_call.id && !id.is_empty() {
220                                    existing_tool_call.id = id.clone();
221                            }
222
223                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
224                                    existing_tool_call.function.name = name.clone();
225                            }
226
227                            if let Some(chunk) = &tool_call.function.arguments {
228                                // Convert current arguments to string if needed
229                                let current_args = match &existing_tool_call.function.arguments {
230                                    serde_json::Value::Null => String::new(),
231                                    serde_json::Value::String(s) => s.clone(),
232                                    v => v.to_string(),
233                                };
234
235                                // Concatenate the new chunk
236                                let combined = format!("{current_args}{chunk}");
237
238                                // Try to parse as JSON if it looks complete
239                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
240                                    match serde_json::from_str(&combined) {
241                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
242                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
243                                    }
244                                } else {
245                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
246                                }
247
248                                // Emit the delta so UI can show progress
249                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
250                                    id: existing_tool_call.id.clone(),
251                                    delta: chunk.clone(),
252                                });
253                            }
254                        }
255                    }
256
257                    // Streamed reasoning content
258                    if let Some(reasoning) = &delta.reasoning && !reasoning.is_empty() {
259                        yield Ok(streaming::RawStreamingChoice::Reasoning {
260                            reasoning: reasoning.clone(),
261                            id: None,
262                            signature: None,
263                        });
264                    }
265
266                    // Streamed text content
267                    if let Some(content) = &delta.content && !content.is_empty() {
268                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
269                    }
270
271                    // Usage updates
272                    if let Some(usage) = data.usage {
273                        final_usage = Some(usage);
274                    }
275
276                    // Finish reason
277                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
278                        for (_idx, tool_call) in tool_calls.into_iter() {
279                            yield Ok(streaming::RawStreamingChoice::ToolCall {
280                                name: tool_call.function.name,
281                                id: tool_call.id,
282                                arguments: tool_call.function.arguments,
283                                call_id: None,
284                            });
285                        }
286                        tool_calls = HashMap::new();
287                    }
288                }
289                Err(crate::http_client::Error::StreamEnded) => {
290                    break;
291                }
292                Err(error) => {
293                    tracing::error!(?error, "SSE error");
294                    yield Err(CompletionError::ProviderError(error.to_string()));
295                    break;
296                }
297            }
298        }
299
300        // Ensure event source is closed when stream ends
301        event_source.close();
302
303        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
304        for (_idx, tool_call) in tool_calls.into_iter() {
305            yield Ok(streaming::RawStreamingChoice::ToolCall {
306                name: tool_call.function.name,
307                id: tool_call.id,
308                arguments: tool_call.function.arguments,
309                call_id: None,
310            });
311        }
312
313        // Final response with usage
314        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
315            usage: final_usage.unwrap_or_default(),
316        }));
317    }.instrument(span);
318
319    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
320        stream,
321    )))
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_streaming_completion_response_deserialization() {
        let json = json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert_eq!(response.id, "gen-abc123");
        assert_eq!(response.model, "gpt-3.5-turbo");
        assert_eq!(response.choices.len(), 1);
    }

    #[test]
    fn test_delta_with_content() {
        let json = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.role, Some("assistant".to_string()));
        assert_eq!(delta.content, Some("Hello, world!".to_string()));
    }

    #[test]
    fn test_delta_with_tool_call() {
        let json = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].index, 0);
        assert_eq!(delta.tool_calls[0].id, Some("call_abc".to_string()));
    }

    #[test]
    fn test_tool_call_with_partial_arguments() {
        let json = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let tool_call: StreamingToolCall = serde_json::from_value(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
    }

    #[test]
    fn test_streaming_with_usage() {
        let json = json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    /// A usage report may arrive on a chunk with an empty `choices` array; it must still parse.
    #[test]
    fn test_usage_only_chunk_with_empty_choices() {
        let json = json!({
            "id": "gen-usage",
            "choices": [],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 7,
                "completion_tokens": 3,
                "total_tokens": 10
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.choices.is_empty());
        assert_eq!(response.usage.map(|usage| usage.total_tokens), Some(10));
    }

    #[test]
    fn test_multiple_tool_call_deltas() {
        // Simulates the sequence of deltas for a tool call with arguments
        let start_json = json!({
            "id": "gen-1",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta1_json = json!({
            "id": "gen-2",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"query\":"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta2_json = json!({
            "id": "gen-3",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "\"Rust programming\"}"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        // Verify all chunks deserialize
        let start: StreamingCompletionChunk = serde_json::from_value(start_json).unwrap();
        assert_eq!(
            start.choices[0].delta.tool_calls[0].id,
            Some("call_123".to_string())
        );

        let delta1: StreamingCompletionChunk = serde_json::from_value(delta1_json).unwrap();
        assert_eq!(
            delta1.choices[0].delta.tool_calls[0].function.arguments,
            Some("{\"query\":".to_string())
        );

        let delta2: StreamingCompletionChunk = serde_json::from_value(delta2_json).unwrap();
        assert_eq!(
            delta2.choices[0].delta.tool_calls[0].function.arguments,
            Some("\"Rust programming\"}".to_string())
        );
    }

    #[test]
    fn test_response_with_error() {
        let json = json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.error.is_some());
        let error = response.error.as_ref().unwrap();
        assert_eq!(error.code, 500);
        assert_eq!(error.message, "Provider disconnected");
    }

    /// Known finish reasons map to their variants; unknown ones fall back to `Other`.
    #[test]
    fn test_finish_reason_deserialization() {
        let parse = |raw: &str| serde_json::from_value::<FinishReason>(json!(raw)).unwrap();

        assert_eq!(parse("tool_calls"), FinishReason::ToolCalls);
        assert_eq!(parse("stop"), FinishReason::Stop);
        assert_eq!(parse("error"), FinishReason::Error);
        assert_eq!(parse("content_filter"), FinishReason::ContentFilter);
        assert_eq!(parse("length"), FinishReason::Length);
        assert_eq!(
            parse("something_new"),
            FinishReason::Other("something_new".to_string())
        );
    }

    /// OpenRouter's prompt/completion counts map onto rig's input/output counts.
    #[test]
    fn test_token_usage_conversion() {
        let response = StreamingCompletionResponse {
            usage: Usage {
                prompt_tokens: 12,
                completion_tokens: 8,
                total_tokens: 20,
            },
        };

        let usage = response
            .token_usage()
            .expect("token usage is always reported");
        assert_eq!(usage.input_tokens, 12);
        assert_eq!(usage.output_tokens, 8);
        assert_eq!(usage.total_tokens, 20);
    }
}