rig/providers/openrouter/
streaming.rs

1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::info_span;
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils;
15use crate::providers::openrouter::{
16    OpenRouterRequestParams, OpenrouterCompletionRequest, ReasoningDetails,
17};
18use crate::streaming;
19
20#[derive(Clone, Serialize, Deserialize, Debug)]
21pub struct StreamingCompletionResponse {
22    pub usage: Usage,
23}
24
25impl GetTokenUsage for StreamingCompletionResponse {
26    fn token_usage(&self) -> Option<crate::completion::Usage> {
27        let mut usage = crate::completion::Usage::new();
28
29        usage.input_tokens = self.usage.prompt_tokens as u64;
30        usage.output_tokens = self.usage.completion_tokens as u64;
31        usage.total_tokens = self.usage.total_tokens as u64;
32
33        Some(usage)
34    }
35}
36
37#[derive(Deserialize, Debug, PartialEq)]
38#[serde(rename_all = "snake_case")]
39pub enum FinishReason {
40    ToolCalls,
41    Stop,
42    Error,
43    ContentFilter,
44    Length,
45    #[serde(untagged)]
46    Other(String),
47}
48
49#[derive(Deserialize, Debug)]
50#[allow(dead_code)]
51struct StreamingChoice {
52    pub finish_reason: Option<FinishReason>,
53    pub native_finish_reason: Option<String>,
54    pub logprobs: Option<Value>,
55    pub index: usize,
56    pub delta: StreamingDelta,
57}
58
59#[derive(Deserialize, Debug)]
60struct StreamingFunction {
61    pub name: Option<String>,
62    pub arguments: Option<String>,
63}
64
65#[derive(Deserialize, Debug)]
66#[allow(dead_code)]
67struct StreamingToolCall {
68    pub index: usize,
69    pub id: Option<String>,
70    pub r#type: Option<String>,
71    pub function: StreamingFunction,
72}
73
74#[derive(Serialize, Deserialize, Debug, Clone, Default)]
75pub struct Usage {
76    pub prompt_tokens: u32,
77    pub completion_tokens: u32,
78    pub total_tokens: u32,
79}
80
81#[derive(Deserialize, Debug)]
82#[allow(dead_code)]
83struct ErrorResponse {
84    pub code: i32,
85    pub message: String,
86}
87
88#[derive(Deserialize, Debug)]
89#[allow(dead_code)]
90struct StreamingDelta {
91    pub role: Option<String>,
92    pub content: Option<String>,
93    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
94    pub tool_calls: Vec<StreamingToolCall>,
95    pub reasoning: Option<String>,
96    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
97    pub reasoning_details: Vec<ReasoningDetails>,
98}
99
100#[derive(Deserialize, Debug)]
101#[allow(dead_code)]
102struct StreamingCompletionChunk {
103    id: String,
104    model: String,
105    choices: Vec<StreamingChoice>,
106    usage: Option<Usage>,
107    error: Option<ErrorResponse>,
108}
109
110impl<T> super::CompletionModel<T>
111where
112    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
113{
114    pub(crate) async fn stream(
115        &self,
116        completion_request: CompletionRequest,
117    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
118    {
119        let preamble = completion_request.preamble.clone();
120        let mut request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
121            model: self.model.as_ref(),
122            request: completion_request,
123            strict_tools: self.strict_tools,
124        })?;
125
126        let params = json_utils::merge(
127            request.additional_params.unwrap_or(serde_json::json!({})),
128            serde_json::json!({"stream": true }),
129        );
130
131        request.additional_params = Some(params);
132
133        let body = serde_json::to_vec(&request)?;
134
135        let req = self
136            .client
137            .post("/chat/completions")?
138            .body(body)
139            .map_err(|x| CompletionError::HttpError(x.into()))?;
140
141        let span = if tracing::Span::current().is_disabled() {
142            info_span!(
143                target: "rig::completions",
144                "chat_streaming",
145                gen_ai.operation.name = "chat_streaming",
146                gen_ai.provider.name = "openrouter",
147                gen_ai.request.model = self.model,
148                gen_ai.system_instructions = preamble,
149                gen_ai.response.id = tracing::field::Empty,
150                gen_ai.response.model = tracing::field::Empty,
151                gen_ai.usage.output_tokens = tracing::field::Empty,
152                gen_ai.usage.input_tokens = tracing::field::Empty,
153            )
154        } else {
155            tracing::Span::current()
156        };
157
158        tracing::Instrument::instrument(
159            send_compatible_streaming_request(self.client.clone(), req),
160            span,
161        )
162        .await
163    }
164}
165
166pub async fn send_compatible_streaming_request<T>(
167    http_client: T,
168    req: Request<Vec<u8>>,
169) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
170where
171    T: HttpClientExt + Clone + 'static,
172{
173    let span = tracing::Span::current();
174    // Build the request with proper headers for SSE
175    let mut event_source = GenericEventSource::new(http_client, req);
176
177    let stream = stream! {
178        // Accumulate tool calls by index while streaming
179        let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
180        let mut final_usage = None;
181
182        while let Some(event_result) = event_source.next().await {
183            match event_result {
184                Ok(Event::Open) => {
185                    tracing::trace!("SSE connection opened");
186                    continue;
187                }
188
189                Ok(Event::Message(message)) => {
190                    if message.data.trim().is_empty() || message.data == "[DONE]" {
191                        continue;
192                    }
193
194                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
195                        Ok(data) => data,
196                        Err(error) => {
197                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
198                            continue;
199                        }
200                    };
201
202                    // Expect at least one choice
203                     let Some(choice) = data.choices.first() else {
204                        tracing::debug!("There is no choice");
205                        continue;
206                    };
207                    let delta = &choice.delta;
208
209                    if !delta.tool_calls.is_empty() {
210                        for tool_call in &delta.tool_calls {
211                            let index = tool_call.index;
212
213                            // Get or create tool call entry
214                            let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);
215
216                            // Update fields if present
217                            if let Some(id) = &tool_call.id && !id.is_empty() {
218                                    existing_tool_call.id = id.clone();
219                            }
220
221                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
222                                    existing_tool_call.name = name.clone();
223                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
224                                        id: existing_tool_call.id.clone(),
225                                        content: streaming::ToolCallDeltaContent::Name(name.clone()),
226                                    });
227                            }
228
229                                // Convert current arguments to string if needed
230                            if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
231                                let current_args = match &existing_tool_call.arguments {
232                                    serde_json::Value::Null => String::new(),
233                                    serde_json::Value::String(s) => s.clone(),
234                                    v => v.to_string(),
235                                };
236
237                                // Concatenate the new chunk
238                                let combined = format!("{current_args}{chunk}");
239
240                                // Try to parse as JSON if it looks complete
241                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
242                                    match serde_json::from_str(&combined) {
243                                        Ok(parsed) => existing_tool_call.arguments = parsed,
244                                        Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
245                                    }
246                                } else {
247                                    existing_tool_call.arguments = serde_json::Value::String(combined);
248                                }
249
250                                // Emit the delta so UI can show progress
251                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
252                                    id: existing_tool_call.id.clone(),
253                                    content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
254                                });
255                            }
256                        }
257
258                        // Update the signature and the additional params of the tool call if present
259                        for reasoning_detail in &delta.reasoning_details {
260                            if let ReasoningDetails::Encrypted { id, data, .. } = reasoning_detail
261                                && let Some(id) = id
262                                && let Some(tool_call) = tool_calls.values_mut().find(|tool_call| tool_call.id.eq(id))
263                                && let Ok(additional_params) = serde_json::to_value(reasoning_detail) {
264                                tool_call.signature = Some(data.clone());
265                                tool_call.additional_params = Some(additional_params);
266                            }
267                        }
268                    }
269
270                    // Streamed reasoning content
271                    if let Some(reasoning) = &delta.reasoning && !reasoning.is_empty() {
272                        yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
273                            reasoning: reasoning.clone(),
274                            id: None,
275                        });
276                    }
277
278                    // Streamed text content
279                    if let Some(content) = &delta.content && !content.is_empty() {
280                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
281                    }
282
283                    // Usage updates
284                    if let Some(usage) = data.usage {
285                        final_usage = Some(usage);
286                    }
287
288                    // Finish reason
289                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
290                        for (_idx, tool_call) in tool_calls.into_iter() {
291                            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
292                        }
293                        tool_calls = HashMap::new();
294                    }
295                }
296                Err(crate::http_client::Error::StreamEnded) => {
297                    break;
298                }
299                Err(error) => {
300                    tracing::error!(?error, "SSE error");
301                    yield Err(CompletionError::ProviderError(error.to_string()));
302                    break;
303                }
304            }
305        }
306
307        // Ensure event source is closed when stream ends
308        event_source.close();
309
310        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
311        for (_idx, tool_call) in tool_calls.into_iter() {
312            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
313        }
314
315        // Final response with usage
316        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
317            usage: final_usage.unwrap_or_default(),
318        }));
319    }.instrument(span);
320
321    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
322        stream,
323    )))
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_streaming_completion_response_deserialization() {
        let json = json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert_eq!(response.id, "gen-abc123");
        assert_eq!(response.model, "gpt-3.5-turbo");
        assert_eq!(response.choices.len(), 1);
    }

    #[test]
    fn test_delta_with_content() {
        let json = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.role, Some("assistant".to_string()));
        assert_eq!(delta.content, Some("Hello, world!".to_string()));
    }

    #[test]
    fn test_delta_with_tool_call() {
        let json = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].index, 0);
        assert_eq!(delta.tool_calls[0].id, Some("call_abc".to_string()));
    }

    #[test]
    fn test_delta_with_null_or_missing_tool_calls() {
        // Both an explicit `null` and an absent key should yield empty lists.
        let delta: StreamingDelta =
            serde_json::from_value(json!({ "content": "hi", "tool_calls": null })).unwrap();
        assert!(delta.tool_calls.is_empty());

        let delta: StreamingDelta = serde_json::from_value(json!({ "content": "hi" })).unwrap();
        assert!(delta.tool_calls.is_empty());
        assert!(delta.reasoning_details.is_empty());
    }

    #[test]
    fn test_tool_call_with_partial_arguments() {
        let json = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let tool_call: StreamingToolCall = serde_json::from_value(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
    }

    #[test]
    fn test_streaming_with_usage() {
        let json = json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_usage_only_chunk_with_empty_choices() {
        // The final usage report can arrive on a chunk with no choices at all.
        let json = json!({
            "id": "gen-usage",
            "choices": [],
            "created": 1234567890u64,
            "model": "openai/gpt-4o",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 12,
                "completion_tokens": 34,
                "total_tokens": 46
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.choices.is_empty());
        let usage = response.usage.expect("usage should be present");
        assert_eq!(usage.prompt_tokens, 12);
        assert_eq!(usage.completion_tokens, 34);
        assert_eq!(usage.total_tokens, 46);
    }

    #[test]
    fn test_finish_reason_deserialization() {
        let parse = |value| serde_json::from_value::<FinishReason>(value).unwrap();
        assert_eq!(parse(json!("tool_calls")), FinishReason::ToolCalls);
        assert_eq!(parse(json!("stop")), FinishReason::Stop);
        assert_eq!(parse(json!("content_filter")), FinishReason::ContentFilter);
        assert_eq!(
            parse(json!("something_new")),
            FinishReason::Other("something_new".to_string())
        );
    }

    #[test]
    fn test_multiple_tool_call_deltas() {
        // Simulates the sequence of deltas for a tool call with arguments
        let start_json = json!({
            "id": "gen-1",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta1_json = json!({
            "id": "gen-2",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"query\":"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta2_json = json!({
            "id": "gen-3",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "\"Rust programming\"}"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        // Verify all chunks deserialize
        let start: StreamingCompletionChunk = serde_json::from_value(start_json).unwrap();
        assert_eq!(
            start.choices[0].delta.tool_calls[0].id,
            Some("call_123".to_string())
        );

        let delta1: StreamingCompletionChunk = serde_json::from_value(delta1_json).unwrap();
        assert_eq!(
            delta1.choices[0].delta.tool_calls[0].function.arguments,
            Some("{\"query\":".to_string())
        );

        let delta2: StreamingCompletionChunk = serde_json::from_value(delta2_json).unwrap();
        assert_eq!(
            delta2.choices[0].delta.tool_calls[0].function.arguments,
            Some("\"Rust programming\"}".to_string())
        );
    }

    #[test]
    fn test_response_with_error() {
        let json = json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.error.is_some());
        let error = response.error.as_ref().unwrap();
        assert_eq!(error.code, 500);
        assert_eq!(error.message, "Provider disconnected");
        assert_eq!(response.choices[0].finish_reason, Some(FinishReason::Error));
    }
}
532}