Skip to main content

rig/providers/openrouter/
streaming.rs

1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::info_span;
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils;
15use crate::providers::openrouter::{
16    OpenRouterRequestParams, OpenrouterCompletionRequest, ReasoningDetails,
17};
18use crate::streaming;
19
20#[derive(Clone, Serialize, Deserialize, Debug)]
21pub struct StreamingCompletionResponse {
22    pub usage: Usage,
23}
24
25impl GetTokenUsage for StreamingCompletionResponse {
26    fn token_usage(&self) -> Option<crate::completion::Usage> {
27        let mut usage = crate::completion::Usage::new();
28
29        usage.input_tokens = self.usage.prompt_tokens as u64;
30        usage.output_tokens = self.usage.completion_tokens as u64;
31        usage.total_tokens = self.usage.total_tokens as u64;
32
33        Some(usage)
34    }
35}
36
37#[derive(Deserialize, Debug, PartialEq)]
38#[serde(rename_all = "snake_case")]
39pub enum FinishReason {
40    ToolCalls,
41    Stop,
42    Error,
43    ContentFilter,
44    Length,
45    #[serde(untagged)]
46    Other(String),
47}
48
49#[derive(Deserialize, Debug)]
50#[allow(dead_code)]
51struct StreamingChoice {
52    pub finish_reason: Option<FinishReason>,
53    pub native_finish_reason: Option<String>,
54    pub logprobs: Option<Value>,
55    pub index: usize,
56    pub delta: StreamingDelta,
57}
58
59#[derive(Deserialize, Debug)]
60struct StreamingFunction {
61    pub name: Option<String>,
62    pub arguments: Option<String>,
63}
64
65#[derive(Deserialize, Debug)]
66#[allow(dead_code)]
67struct StreamingToolCall {
68    pub index: usize,
69    pub id: Option<String>,
70    pub r#type: Option<String>,
71    pub function: StreamingFunction,
72}
73
74#[derive(Serialize, Deserialize, Debug, Clone, Default)]
75pub struct Usage {
76    pub prompt_tokens: u32,
77    pub completion_tokens: u32,
78    pub total_tokens: u32,
79}
80
81#[derive(Deserialize, Debug)]
82#[allow(dead_code)]
83struct ErrorResponse {
84    pub code: i32,
85    pub message: String,
86}
87
88#[derive(Deserialize, Debug)]
89#[allow(dead_code)]
90struct StreamingDelta {
91    pub role: Option<String>,
92    pub content: Option<String>,
93    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
94    pub tool_calls: Vec<StreamingToolCall>,
95    pub reasoning: Option<String>,
96    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
97    pub reasoning_details: Vec<ReasoningDetails>,
98}
99
100#[derive(Deserialize, Debug)]
101#[allow(dead_code)]
102struct StreamingCompletionChunk {
103    id: String,
104    model: String,
105    choices: Vec<StreamingChoice>,
106    usage: Option<Usage>,
107    error: Option<ErrorResponse>,
108}
109
110impl<T> super::CompletionModel<T>
111where
112    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
113{
114    pub(crate) async fn stream(
115        &self,
116        completion_request: CompletionRequest,
117    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
118    {
119        let request_model = completion_request
120            .model
121            .clone()
122            .unwrap_or_else(|| self.model.clone());
123        let preamble = completion_request.preamble.clone();
124        let mut request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
125            model: request_model.as_ref(),
126            request: completion_request,
127            strict_tools: self.strict_tools,
128        })?;
129
130        let params = json_utils::merge(
131            request.additional_params.unwrap_or(serde_json::json!({})),
132            serde_json::json!({"stream": true }),
133        );
134
135        request.additional_params = Some(params);
136
137        let body = serde_json::to_vec(&request)?;
138
139        let req = self
140            .client
141            .post("/chat/completions")?
142            .body(body)
143            .map_err(|x| CompletionError::HttpError(x.into()))?;
144
145        let span = if tracing::Span::current().is_disabled() {
146            info_span!(
147                target: "rig::completions",
148                "chat_streaming",
149                gen_ai.operation.name = "chat_streaming",
150                gen_ai.provider.name = "openrouter",
151                gen_ai.request.model = &request_model,
152                gen_ai.system_instructions = preamble,
153                gen_ai.response.id = tracing::field::Empty,
154                gen_ai.response.model = tracing::field::Empty,
155                gen_ai.usage.output_tokens = tracing::field::Empty,
156                gen_ai.usage.input_tokens = tracing::field::Empty,
157                gen_ai.usage.cached_tokens = tracing::field::Empty,
158            )
159        } else {
160            tracing::Span::current()
161        };
162
163        tracing::Instrument::instrument(
164            send_compatible_streaming_request(self.client.clone(), req),
165            span,
166        )
167        .await
168    }
169}
170
171pub async fn send_compatible_streaming_request<T>(
172    http_client: T,
173    req: Request<Vec<u8>>,
174) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
175where
176    T: HttpClientExt + Clone + 'static,
177{
178    let span = tracing::Span::current();
179    // Build the request with proper headers for SSE
180    let mut event_source = GenericEventSource::new(http_client, req);
181
182    let stream = stream! {
183        // Accumulate tool calls by index while streaming
184        let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
185        let mut final_usage = None;
186
187        while let Some(event_result) = event_source.next().await {
188            match event_result {
189                Ok(Event::Open) => {
190                    tracing::trace!("SSE connection opened");
191                    continue;
192                }
193
194                Ok(Event::Message(message)) => {
195                    if message.data.trim().is_empty() || message.data == "[DONE]" {
196                        continue;
197                    }
198
199                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
200                        Ok(data) => data,
201                        Err(error) => {
202                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
203                            continue;
204                        }
205                    };
206
207                    // Expect at least one choice
208                     let Some(choice) = data.choices.first() else {
209                        tracing::debug!("There is no choice");
210                        continue;
211                    };
212                    let delta = &choice.delta;
213
214                    if !delta.tool_calls.is_empty() {
215                        for tool_call in &delta.tool_calls {
216                            let index = tool_call.index;
217
218                            // Get or create tool call entry
219                            let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);
220
221                            // Update fields if present
222                            if let Some(id) = &tool_call.id && !id.is_empty() {
223                                    existing_tool_call.id = id.clone();
224                            }
225
226                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
227                                    existing_tool_call.name = name.clone();
228                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
229                                        id: existing_tool_call.id.clone(),
230                                        internal_call_id: existing_tool_call.internal_call_id.clone(),
231                                        content: streaming::ToolCallDeltaContent::Name(name.clone()),
232                                    });
233                            }
234
235                                // Convert current arguments to string if needed
236                            if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
237                                let current_args = match &existing_tool_call.arguments {
238                                    serde_json::Value::Null => String::new(),
239                                    serde_json::Value::String(s) => s.clone(),
240                                    v => v.to_string(),
241                                };
242
243                                // Concatenate the new chunk
244                                let combined = format!("{current_args}{chunk}");
245
246                                // Try to parse as JSON if it looks complete
247                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
248                                    match serde_json::from_str(&combined) {
249                                        Ok(parsed) => existing_tool_call.arguments = parsed,
250                                        Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
251                                    }
252                                } else {
253                                    existing_tool_call.arguments = serde_json::Value::String(combined);
254                                }
255
256                                // Emit the delta so UI can show progress
257                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
258                                    id: existing_tool_call.id.clone(),
259                                    internal_call_id: existing_tool_call.internal_call_id.clone(),
260                                    content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
261                                });
262                            }
263                        }
264
265                        // Update the signature and the additional params of the tool call if present
266                        for reasoning_detail in &delta.reasoning_details {
267                            if let ReasoningDetails::Encrypted { id, data, .. } = reasoning_detail
268                                && let Some(id) = id
269                                && let Some(tool_call) = tool_calls.values_mut().find(|tool_call| tool_call.id.eq(id))
270                                && let Ok(additional_params) = serde_json::to_value(reasoning_detail) {
271                                tool_call.signature = Some(data.clone());
272                                tool_call.additional_params = Some(additional_params);
273                            }
274                        }
275                    }
276
277                    // Streamed reasoning content
278                    if let Some(reasoning) = &delta.reasoning && !reasoning.is_empty() {
279                        yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
280                            reasoning: reasoning.clone(),
281                            id: None,
282                        });
283                    }
284
285                    // Streamed text content
286                    if let Some(content) = &delta.content && !content.is_empty() {
287                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
288                    }
289
290                    // Usage updates
291                    if let Some(usage) = data.usage {
292                        final_usage = Some(usage);
293                    }
294
295                    // Finish reason
296                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
297                        for (_idx, tool_call) in tool_calls.into_iter() {
298                            yield Ok(streaming::RawStreamingChoice::ToolCall(
299                                finalize_completed_streaming_tool_call(tool_call),
300                            ));
301                        }
302                        tool_calls = HashMap::new();
303                    }
304                }
305                Err(crate::http_client::Error::StreamEnded) => {
306                    break;
307                }
308                Err(error) => {
309                    tracing::error!(?error, "SSE error");
310                    yield Err(CompletionError::ProviderError(error.to_string()));
311                    break;
312                }
313            }
314        }
315
316        // Ensure event source is closed when stream ends
317        event_source.close();
318
319        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
320        for (_idx, tool_call) in tool_calls.into_iter() {
321            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
322        }
323
324        // Final response with usage
325        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
326            usage: final_usage.unwrap_or_default(),
327        }));
328    }.instrument(span);
329
330    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
331        stream,
332    )))
333}
334
335fn finalize_completed_streaming_tool_call(
336    mut tool_call: streaming::RawStreamingToolCall,
337) -> streaming::RawStreamingToolCall {
338    if tool_call.arguments.is_null() {
339        tool_call.arguments = Value::Object(serde_json::Map::new());
340    }
341
342    tool_call
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a JSON value into a streaming chunk, panicking on mismatch.
    fn parse_chunk(raw: serde_json::Value) -> StreamingCompletionChunk {
        serde_json::from_value(raw).unwrap()
    }

    /// Wraps one tool-call fragment in an otherwise minimal `gpt-4` chunk.
    fn tool_call_chunk(id: &str, fragment: serde_json::Value) -> serde_json::Value {
        json!({
            "id": id,
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [fragment]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        })
    }

    #[test]
    fn test_streaming_completion_response_deserialization() {
        let chunk = parse_chunk(json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        }));

        assert_eq!(chunk.id, "gen-abc123");
        assert_eq!(chunk.model, "gpt-3.5-turbo");
        assert_eq!(chunk.choices.len(), 1);
    }

    #[test]
    fn test_delta_with_content() {
        let raw = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(raw).unwrap();

        assert_eq!(delta.role.as_deref(), Some("assistant"));
        assert_eq!(delta.content.as_deref(), Some("Hello, world!"));
    }

    #[test]
    fn test_delta_with_tool_call() {
        let raw = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(raw).unwrap();

        assert_eq!(delta.tool_calls.len(), 1);
        let fragment = &delta.tool_calls[0];
        assert_eq!(fragment.index, 0);
        assert_eq!(fragment.id.as_deref(), Some("call_abc"));
    }

    #[test]
    fn test_tool_call_with_partial_arguments() {
        // Follow-up fragments carry explicit nulls for everything but the arguments.
        let raw = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let fragment: StreamingToolCall = serde_json::from_value(raw).unwrap();

        assert_eq!(fragment.index, 0);
        assert!(fragment.id.is_none());
        assert_eq!(fragment.function.arguments.as_deref(), Some("Paris"));
    }

    #[test]
    fn test_streaming_with_usage() {
        let chunk = parse_chunk(json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        }));

        assert!(chunk.usage.is_some());
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_multiple_tool_call_deltas() {
        // Simulates the sequence of deltas for a tool call with arguments:
        // an opening fragment with id and name, then two argument slices.
        let opening = parse_chunk(tool_call_chunk(
            "gen-1",
            json!({
                "index": 0,
                "id": "call_123",
                "type": "function",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }),
        ));
        assert_eq!(
            opening.choices[0].delta.tool_calls[0].id.as_deref(),
            Some("call_123")
        );

        let first_slice = parse_chunk(tool_call_chunk(
            "gen-2",
            json!({
                "index": 0,
                "function": {
                    "arguments": "{\"query\":"
                }
            }),
        ));
        assert_eq!(
            first_slice.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_deref(),
            Some("{\"query\":")
        );

        let second_slice = parse_chunk(tool_call_chunk(
            "gen-3",
            json!({
                "index": 0,
                "function": {
                    "arguments": "\"Rust programming\"}"
                }
            }),
        ));
        assert_eq!(
            second_slice.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_deref(),
            Some("\"Rust programming\"}")
        );
    }

    #[test]
    fn test_response_with_error() {
        let chunk = parse_chunk(json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        }));

        assert!(chunk.error.is_some());
        let error = chunk.error.as_ref().unwrap();
        assert_eq!(error.code, 500);
        assert_eq!(error.message, "Provider disconnected");
    }
}
551}