//! Streaming completion support for the OpenRouter provider
//! (`rig/providers/openrouter/streaming.rs`).
1use http::Request;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use tracing::info_span;
5
6use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
7use crate::http_client::HttpClientExt;
8use crate::json_utils;
9use crate::providers::internal::openai_chat_completions_compatible::{
10    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
11    CompatibleToolCallChunk,
12};
13use crate::providers::openrouter::{
14    OpenRouterRequestParams, OpenrouterCompletionRequest, ReasoningDetails,
15};
16use crate::streaming;
17
/// Final value yielded once an OpenRouter stream finishes; carries the
/// aggregated token usage reported by the provider.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token usage totals for the whole streamed completion.
    pub usage: Usage,
}
22
23impl GetTokenUsage for StreamingCompletionResponse {
24    fn token_usage(&self) -> Option<crate::completion::Usage> {
25        self.usage.token_usage()
26    }
27}
28
/// Reason OpenRouter reports for ending a streamed choice.
///
/// Known values deserialize from their `snake_case` names; any other string
/// is captured verbatim by the untagged `Other` variant instead of failing
/// deserialization.
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    ToolCalls,
    Stop,
    Error,
    ContentFilter,
    Length,
    // Catch-all for finish reasons this enum does not model explicitly.
    #[serde(untagged)]
    Other(String),
}
40
/// One entry of the `choices` array inside a streamed chunk.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingChoice {
    /// Why this choice stopped, if the chunk is terminal.
    pub finish_reason: Option<FinishReason>,
    /// The upstream provider's own finish-reason string, when supplied.
    pub native_finish_reason: Option<String>,
    /// Log-probability payload, kept as opaque JSON.
    pub logprobs: Option<Value>,
    /// Position of this choice in the response.
    pub index: usize,
    /// Incremental content carried by this chunk.
    pub delta: StreamingDelta,
}
50
/// Partial function-call payload inside a tool-call delta.
/// Both fields may arrive incrementally across several chunks.
#[derive(Deserialize, Debug)]
struct StreamingFunction {
    /// Function name; typically only present on the first delta for a call.
    pub name: Option<String>,
    /// A fragment of the JSON arguments string.
    pub arguments: Option<String>,
}
56
/// A tool-call delta within a streamed choice.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingToolCall {
    /// Correlates fragments of the same call across chunks.
    pub index: usize,
    /// Call id; present on the first fragment, absent on continuations
    /// (see `test_tool_call_with_partial_arguments`).
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: StreamingFunction,
}
65
66impl From<&StreamingToolCall> for CompatibleToolCallChunk {
67    fn from(value: &StreamingToolCall) -> Self {
68        Self {
69            index: value.index,
70            id: value.id.clone(),
71            name: value.function.name.clone(),
72            arguments: value.function.arguments.clone(),
73        }
74    }
75}
76
/// Token accounting reported by OpenRouter (typically on the final chunk).
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
83
84impl GetTokenUsage for Usage {
85    fn token_usage(&self) -> Option<crate::completion::Usage> {
86        Some(crate::providers::internal::completion_usage(
87            self.prompt_tokens as u64,
88            self.completion_tokens as u64,
89            self.total_tokens as u64,
90            0,
91        ))
92    }
93}
94
/// Error payload OpenRouter may embed in a chunk when the upstream
/// provider fails mid-stream.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct ErrorResponse {
    pub code: i32,
    pub message: String,
}
101
/// Incremental message content within a streamed choice.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingDelta {
    pub role: Option<String>,
    pub content: Option<String>,
    // `null` and a missing field both deserialize to an empty Vec.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<StreamingToolCall>,
    /// Plain-text reasoning fragment, if the model emits one.
    pub reasoning: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub reasoning_details: Vec<ReasoningDetails>,
}
113
/// Top-level shape of each SSE `data:` payload from OpenRouter.
/// Unmodeled keys such as `created` and `object` are ignored.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingCompletionChunk {
    id: String,
    model: String,
    choices: Vec<StreamingChoice>,
    /// Present when the provider reports usage (typically the final chunk).
    usage: Option<Usage>,
    /// Present when the upstream provider errored mid-stream.
    error: Option<ErrorResponse>,
}
123
124impl<T> super::CompletionModel<T>
125where
126    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
127{
128    pub(crate) async fn stream(
129        &self,
130        completion_request: CompletionRequest,
131    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
132    {
133        let request_model = completion_request
134            .model
135            .clone()
136            .unwrap_or_else(|| self.model.clone());
137        let preamble = completion_request.preamble.clone();
138        let mut request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
139            model: request_model.as_ref(),
140            request: completion_request,
141            strict_tools: self.strict_tools,
142        })?;
143
144        let params = json_utils::merge(
145            request.additional_params.unwrap_or(serde_json::json!({})),
146            serde_json::json!({"stream": true }),
147        );
148
149        request.additional_params = Some(params);
150
151        let body = serde_json::to_vec(&request)?;
152
153        let req = self
154            .client
155            .post("/chat/completions")?
156            .body(body)
157            .map_err(|x| CompletionError::HttpError(x.into()))?;
158
159        let span = if tracing::Span::current().is_disabled() {
160            info_span!(
161                target: "rig::completions",
162                "chat_streaming",
163                gen_ai.operation.name = "chat_streaming",
164                gen_ai.provider.name = "openrouter",
165                gen_ai.request.model = &request_model,
166                gen_ai.system_instructions = preamble,
167                gen_ai.response.id = tracing::field::Empty,
168                gen_ai.response.model = tracing::field::Empty,
169                gen_ai.usage.output_tokens = tracing::field::Empty,
170                gen_ai.usage.input_tokens = tracing::field::Empty,
171                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
172            )
173        } else {
174            tracing::Span::current()
175        };
176
177        tracing::Instrument::instrument(
178            send_compatible_streaming_request(self.client.clone(), req),
179            span,
180        )
181        .await
182    }
183}
184
/// Marker type that wires OpenRouter's chunk format into the shared
/// OpenAI-compatible streaming pipeline.
#[derive(Clone, Copy)]
struct OpenRouterCompatibleProfile;
187
impl CompatibleStreamProfile for OpenRouterCompatibleProfile {
    type Usage = Usage;
    type Detail = ReasoningDetails;
    type FinalResponse = StreamingCompletionResponse;

    /// Parses one SSE `data:` payload into the shared chunk representation.
    ///
    /// Malformed payloads are logged and skipped (`Ok(None)`) so a single
    /// bad event does not abort the whole stream.
    fn normalize_chunk(
        &self,
        data: &str,
    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
            Ok(data) => data,
            Err(error) => {
                tracing::error!(?error, message = data, "Failed to parse SSE message");
                return Ok(None);
            }
        };

        Ok(Some(
            openai_chat_completions_compatible::normalize_first_choice_chunk(
                Some(data.id),
                Some(data.model),
                data.usage,
                &data.choices,
                |choice| CompatibleChoiceData {
                    // Only `tool_calls` needs dedicated handling downstream;
                    // every other finish reason collapses to `Other`.
                    finish_reason: if choice.finish_reason == Some(FinishReason::ToolCalls) {
                        CompatibleFinishReason::ToolCalls
                    } else {
                        CompatibleFinishReason::Other
                    },
                    text: choice.delta.content.clone(),
                    reasoning: choice.delta.reasoning.clone(),
                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
                        &choice.delta.tool_calls,
                    ),
                    details: choice.delta.reasoning_details.clone(),
                },
            ),
        ))
    }

    /// Wraps the accumulated usage into the stream's final response value.
    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
        StreamingCompletionResponse { usage }
    }

    /// Attaches encrypted reasoning details to the tool call they belong to.
    ///
    /// When an `Encrypted` detail carries an `id` matching an emitted tool
    /// call, the encrypted blob becomes the call's `signature` and the full
    /// detail is serialized into its `additional_params`. Details with no
    /// id, no matching call, or that fail serialization are ignored.
    fn decorate_tool_call(
        &self,
        detail: &Self::Detail,
        tool_calls: &mut std::collections::HashMap<usize, crate::streaming::RawStreamingToolCall>,
    ) {
        if let ReasoningDetails::Encrypted { id, data, .. } = detail
            && let Some(id) = id
            && let Some(tool_call) = tool_calls
                .values_mut()
                .find(|tool_call| tool_call.id.eq(id))
            && let Ok(additional_params) = serde_json::to_value(detail)
        {
            tool_call.signature = Some(data.clone());
            tool_call.additional_params = Some(additional_params);
        }
    }
}
249
250pub async fn send_compatible_streaming_request<T>(
251    http_client: T,
252    req: Request<Vec<u8>>,
253) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
254where
255    T: HttpClientExt + Clone + 'static,
256{
257    openai_chat_completions_compatible::send_compatible_streaming_request(
258        http_client,
259        req,
260        OpenRouterCompatibleProfile,
261    )
262    .await
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http_client::mock::MockStreamingClient;
    use crate::providers::internal::openai_chat_completions_compatible::test_support::sse_bytes_from_data_lines;
    use crate::streaming::StreamedAssistantContent;
    use futures::StreamExt;
    use serde_json::json;

    // A plain content chunk deserializes; unmodeled top-level keys
    // (`created`, `object`) are ignored.
    #[test]
    fn test_streaming_completion_response_deserialization() {
        let json = json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert_eq!(response.id, "gen-abc123");
        assert_eq!(response.model, "gpt-3.5-turbo");
        assert_eq!(response.choices.len(), 1);
    }

    // Delta carrying only role and text content.
    #[test]
    fn test_delta_with_content() {
        let json = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.role, Some("assistant".to_string()));
        assert_eq!(delta.content, Some("Hello, world!".to_string()));
    }

    // Delta carrying the opening fragment of a tool call (id present).
    #[test]
    fn test_delta_with_tool_call() {
        let json = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].index, 0);
        assert_eq!(delta.tool_calls[0].id, Some("call_abc".to_string()));
    }

    // Continuation fragments have null id/type/name but still parse.
    #[test]
    fn test_tool_call_with_partial_arguments() {
        let json = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let tool_call: StreamingToolCall = serde_json::from_value(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
    }

    // The optional `usage` object deserializes with all three counters.
    #[test]
    fn test_streaming_with_usage() {
        let json = json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    // A tool call split across three chunks: the opener carries the id and
    // name, the continuations carry only argument fragments.
    #[test]
    fn test_multiple_tool_call_deltas() {
        // Simulates the sequence of deltas for a tool call with arguments
        let start_json = json!({
            "id": "gen-1",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta1_json = json!({
            "id": "gen-2",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"query\":"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta2_json = json!({
            "id": "gen-3",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "\"Rust programming\"}"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        // Verify all chunks deserialize
        let start: StreamingCompletionChunk = serde_json::from_value(start_json).unwrap();
        assert_eq!(
            start.choices[0].delta.tool_calls[0].id,
            Some("call_123".to_string())
        );

        let delta1: StreamingCompletionChunk = serde_json::from_value(delta1_json).unwrap();
        assert_eq!(
            delta1.choices[0].delta.tool_calls[0].function.arguments,
            Some("{\"query\":".to_string())
        );

        let delta2: StreamingCompletionChunk = serde_json::from_value(delta2_json).unwrap();
        assert_eq!(
            delta2.choices[0].delta.tool_calls[0].function.arguments,
            Some("\"Rust programming\"}".to_string())
        );
    }

    // A mid-stream provider error surfaces through the `error` field.
    #[test]
    fn test_response_with_error() {
        let json = json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.error.is_some());
        let error = response.error.as_ref().unwrap();
        assert_eq!(error.code, 500);
        assert_eq!(error.message, "Provider disconnected");
    }

    // End-to-end over a mock SSE stream: an encrypted reasoning detail whose
    // id matches a tool call must land on that call as `signature` plus the
    // full detail in `additional_params` (exercises `decorate_tool_call`).
    #[tokio::test]
    async fn encrypted_reasoning_details_attach_to_emitted_tool_calls() {
        let client = MockStreamingClient {
            sse_bytes: sse_bytes_from_data_lines([
                "{\"id\":\"gen-1\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"\"}}],\"reasoning_details\":[]},\"finish_reason\":null}],\"usage\":null}",
                "{\"id\":\"gen-2\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[],\"reasoning_details\":[{\"type\":\"reasoning.encrypted\",\"id\":\"call_123\",\"format\":\"opaque\",\"index\":0,\"data\":\"enc_blob\"}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"id\":\"gen-3\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[],\"reasoning_details\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}",
                "[DONE]",
            ]),
        };

        let req = Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .expect("request should build");

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .expect("stream should start");

        // Drain the stream until the first tool-call item appears.
        let tool_call = loop {
            match stream.next().await.expect("stream should yield an item") {
                Ok(StreamedAssistantContent::ToolCall { tool_call, .. }) => break tool_call,
                Ok(_) => continue,
                Err(err) => panic!("stream should not error: {err}"),
            }
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.function.name, "search");
        assert_eq!(tool_call.function.arguments, serde_json::json!({}));
        assert_eq!(tool_call.signature.as_deref(), Some("enc_blob"));
        assert_eq!(
            tool_call.additional_params,
            Some(json!({
                "type": "reasoning.encrypted",
                "id": "call_123",
                "format": "opaque",
                "index": 0,
                "data": "enc_blob"
            }))
        );
    }
}
520}