Skip to main content

rig_core/providers/openrouter/
streaming.rs

1use http::Request;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use tracing::info_span;
5
6use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
7use crate::http_client::HttpClientExt;
8use crate::json_utils;
9use crate::providers::internal::openai_chat_completions_compatible::{
10    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
11    CompatibleToolCallChunk,
12};
13use crate::providers::openrouter::{
14    OpenRouterRequestParams, OpenrouterCompletionRequest, ReasoningDetails,
15};
16use crate::streaming;
17
/// Final response payload for an OpenRouter streaming completion.
///
/// It carries only the token usage; `build_final_response` constructs it from
/// the usage value handed to the stream profile.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token accounting reported for the stream.
    pub usage: Usage,
}
22
23impl GetTokenUsage for StreamingCompletionResponse {
24    fn token_usage(&self) -> Option<crate::completion::Usage> {
25        self.usage.token_usage()
26    }
27}
28
/// Reason a streamed choice stopped generating.
///
/// The named variants deserialize from their snake_case spelling
/// (`"tool_calls"`, `"stop"`, `"error"`, `"content_filter"`, `"length"`).
/// Any other string is captured verbatim by [`FinishReason::Other`].
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// The model stopped in order to call one or more tools.
    ToolCalls,
    /// `"stop"`.
    Stop,
    /// `"error"`.
    Error,
    /// `"content_filter"`.
    ContentFilter,
    /// `"length"`.
    Length,
    /// Fallback for values not listed above; untagged so the raw string is
    /// kept instead of failing deserialization.
    #[serde(untagged)]
    Other(String),
}
40
/// One entry of a streamed chunk's `choices` array.
// `dead_code` is allowed because only `finish_reason` and `delta` are read in
// this module; the remaining fields are deserialized but unused here.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingChoice {
    /// Finish reason for this choice, if the chunk carries one.
    pub finish_reason: Option<FinishReason>,
    /// Presumably the upstream provider's own finish reason string — confirm
    /// against the OpenRouter API reference.
    pub native_finish_reason: Option<String>,
    /// Log-probability payload, kept as raw JSON.
    pub logprobs: Option<Value>,
    /// Position of this choice within `choices`.
    pub index: usize,
    /// Incremental content for this choice.
    pub delta: StreamingDelta,
}
50
/// Function portion of a streamed tool-call delta.
///
/// Both fields are optional: the tests in this file show continuation deltas
/// that omit (or null out) `name` and carry only an `arguments` fragment.
#[derive(Deserialize, Debug)]
struct StreamingFunction {
    /// Function name, when present in this delta.
    pub name: Option<String>,
    /// Fragment of the JSON-encoded arguments string, when present.
    pub arguments: Option<String>,
}
56
/// A single tool-call delta inside a streamed choice.
// `dead_code` is allowed because `r#type` is deserialized but never read in
// this module.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingToolCall {
    /// Index of the tool call this delta belongs to.
    pub index: usize,
    /// Tool-call id, when present in this delta.
    pub id: Option<String>,
    /// Tool-call type string (e.g. `"function"` in the tests below).
    pub r#type: Option<String>,
    /// Function name / argument fragment carried by this delta.
    pub function: StreamingFunction,
}
65
66impl From<&StreamingToolCall> for CompatibleToolCallChunk {
67    fn from(value: &StreamingToolCall) -> Self {
68        Self {
69            index: value.index,
70            id: value.id.clone(),
71            name: value.function.name.clone(),
72            arguments: value.function.arguments.clone(),
73        }
74    }
75}
76
/// Token usage block as reported in an OpenRouter streaming chunk.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage {
    /// Number of prompt (input) tokens.
    pub prompt_tokens: u32,
    /// Number of completion (output) tokens.
    pub completion_tokens: u32,
    /// Total token count as reported by the API.
    pub total_tokens: u32,
    /// OpenAI-compatible prompt-token details, returned by OpenRouter when a
    /// provider reports cache activity (Anthropic with cache_control, OpenAI
    /// with server-side automatic caching).
    ///
    /// Absent in the payload deserializes to `None`; `None` is omitted again
    /// on serialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
88
/// Prompt-token breakdown reported by OpenRouter for cached streaming requests.
///
/// Each field is marked `#[serde(default)]`, so a field missing from the
/// payload deserializes as `0`.
// `u32` matches the parent `Usage` struct in this module; the non-streaming counterpart
// in `client.rs` uses `usize` to match its own parent.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct PromptTokensDetails {
    /// Tokens served from cache (cache hit).
    #[serde(default)]
    pub cached_tokens: u32,
    /// Tokens written to cache on this call (cache miss that populated the cache).
    #[serde(default)]
    pub cache_write_tokens: u32,
}
101
102impl GetTokenUsage for Usage {
103    fn token_usage(&self) -> Option<crate::completion::Usage> {
104        let (cached_input, cache_creation) = self
105            .prompt_tokens_details
106            .as_ref()
107            .map(|d| (d.cached_tokens as u64, d.cache_write_tokens as u64))
108            .unwrap_or((0, 0));
109        Some(crate::completion::Usage {
110            input_tokens: self.prompt_tokens as u64,
111            output_tokens: self.completion_tokens as u64,
112            total_tokens: self.total_tokens as u64,
113            cached_input_tokens: cached_input,
114            cache_creation_input_tokens: cache_creation,
115            tool_use_prompt_tokens: 0,
116            reasoning_tokens: 0,
117        })
118    }
119}
120
/// Error object that may appear at the top level of a streamed chunk.
// NOTE(review): `code` is typed as an integer here. If OpenRouter ever sends a
// string code, the whole chunk would fail to deserialize — confirm against the
// OpenRouter streaming error documentation.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct ErrorResponse {
    /// Numeric error code (e.g. `500` in the test below).
    pub code: i32,
    /// Human-readable error description.
    pub message: String,
}
127
/// Incremental payload of a streamed choice.
// `dead_code` is allowed because `role` is deserialized but never read in
// this module.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingDelta {
    /// Message role, when present in this delta.
    pub role: Option<String>,
    /// Text fragment, when present.
    pub content: Option<String>,
    /// Tool-call deltas; defaults to empty when the key is missing and is
    /// deserialized through `json_utils::null_or_vec`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<StreamingToolCall>,
    /// Reasoning text fragment, when present.
    pub reasoning: Option<String>,
    /// Structured reasoning entries; same missing-key handling as `tool_calls`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub reasoning_details: Vec<ReasoningDetails>,
}
139
/// One server-sent chunk of an OpenRouter streaming chat completion.
///
/// `id`, `model` and `choices` are required; a payload missing any of them
/// fails to deserialize. Unknown keys (e.g. `created`, `object`) are ignored.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingCompletionChunk {
    /// Generation id of the chunk.
    id: String,
    /// Model name reported in the chunk.
    model: String,
    /// Streamed choices; may be empty (see the usage-only tests below).
    choices: Vec<StreamingChoice>,
    /// Token usage, when the chunk carries it.
    usage: Option<Usage>,
    /// Top-level error object, when the chunk carries one.
    error: Option<ErrorResponse>,
}
149
150impl<T> super::CompletionModel<T>
151where
152    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
153{
154    pub(crate) async fn stream(
155        &self,
156        completion_request: CompletionRequest,
157    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
158    {
159        let request_model = completion_request
160            .model
161            .clone()
162            .unwrap_or_else(|| self.model.clone());
163        let preamble = completion_request.preamble.clone();
164        let mut request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
165            model: request_model.as_ref(),
166            request: completion_request,
167            strict_tools: self.strict_tools,
168        })?;
169
170        let params = json_utils::merge(
171            request.additional_params.unwrap_or(serde_json::json!({})),
172            serde_json::json!({"stream": true }),
173        );
174
175        request.additional_params = Some(params);
176
177        let body = serde_json::to_vec(&super::completion::final_request_body(
178            &request,
179            self.prompt_caching,
180        )?)?;
181
182        let req = self
183            .client
184            .post("/chat/completions")?
185            .body(body)
186            .map_err(|x| CompletionError::HttpError(x.into()))?;
187
188        let span = if tracing::Span::current().is_disabled() {
189            info_span!(
190                target: "rig::completions",
191                "chat_streaming",
192                gen_ai.operation.name = "chat_streaming",
193                gen_ai.provider.name = "openrouter",
194                gen_ai.request.model = &request_model,
195                gen_ai.system_instructions = preamble,
196                gen_ai.response.id = tracing::field::Empty,
197                gen_ai.response.model = tracing::field::Empty,
198                gen_ai.usage.output_tokens = tracing::field::Empty,
199                gen_ai.usage.input_tokens = tracing::field::Empty,
200                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
201            )
202        } else {
203            tracing::Span::current()
204        };
205
206        tracing::Instrument::instrument(
207            send_compatible_streaming_request(self.client.clone(), req),
208            span,
209        )
210        .await
211    }
212}
213
/// Zero-sized marker type that carries the OpenRouter implementation of
/// `CompatibleStreamProfile` (chunk parsing, final-response construction and
/// tool-call decoration).
#[derive(Clone, Copy)]
struct OpenRouterCompatibleProfile;
216
217impl CompatibleStreamProfile for OpenRouterCompatibleProfile {
218    type Usage = Usage;
219    type Detail = ReasoningDetails;
220    type FinalResponse = StreamingCompletionResponse;
221
222    fn normalize_chunk(
223        &self,
224        data: &str,
225    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
226        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
227            Ok(data) => data,
228            Err(error) => {
229                tracing::error!(?error, message = data, "Failed to parse SSE message");
230                return Ok(None);
231            }
232        };
233
234        Ok(Some(
235            openai_chat_completions_compatible::normalize_first_choice_chunk(
236                Some(data.id),
237                Some(data.model),
238                data.usage,
239                &data.choices,
240                |choice| CompatibleChoiceData {
241                    finish_reason: if choice.finish_reason == Some(FinishReason::ToolCalls) {
242                        CompatibleFinishReason::ToolCalls
243                    } else {
244                        CompatibleFinishReason::Other
245                    },
246                    text: choice.delta.content.clone(),
247                    reasoning: choice.delta.reasoning.clone(),
248                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
249                        &choice.delta.tool_calls,
250                    ),
251                    details: choice.delta.reasoning_details.clone(),
252                },
253            ),
254        ))
255    }
256
257    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
258        StreamingCompletionResponse { usage }
259    }
260
261    fn decorate_tool_call(
262        &self,
263        detail: &Self::Detail,
264        tool_calls: &mut std::collections::HashMap<usize, crate::streaming::RawStreamingToolCall>,
265    ) {
266        if let ReasoningDetails::Encrypted { id, data, .. } = detail
267            && let Some(id) = id
268            && let Some(tool_call) = tool_calls
269                .values_mut()
270                .find(|tool_call| tool_call.id.eq(id))
271            && let Ok(additional_params) = serde_json::to_value(detail)
272        {
273            tool_call.signature = Some(data.clone());
274            tool_call.additional_params = Some(additional_params);
275        }
276    }
277}
278
/// Sends `req` with `http_client` and returns the OpenRouter response stream.
///
/// This is a thin wrapper: it delegates to the shared
/// `openai_chat_completions_compatible::send_compatible_streaming_request`
/// helper, supplying `OpenRouterCompatibleProfile` so that chunks are parsed
/// with the OpenRouter-specific rules defined in this module.
///
/// # Errors
///
/// Propagates whatever [`CompletionError`] the shared helper returns.
pub async fn send_compatible_streaming_request<T>(
    http_client: T,
    req: Request<Vec<u8>>,
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
where
    T: HttpClientExt + Clone + 'static,
{
    openai_chat_completions_compatible::send_compatible_streaming_request(
        http_client,
        req,
        OpenRouterCompatibleProfile,
    )
    .await
}
293
294#[cfg(test)]
295mod tests {
296    use super::*;
297    use crate::providers::internal::openai_chat_completions_compatible::test_support::sse_bytes_from_data_lines;
298    use crate::streaming::StreamedAssistantContent;
299    use crate::test_utils::MockStreamingClient;
300    use futures::StreamExt;
301    use serde_json::json;
302
303    #[test]
304    fn test_streaming_completion_response_deserialization() {
305        let json = json!({
306            "id": "gen-abc123",
307            "choices": [{
308                "index": 0,
309                "delta": {
310                    "role": "assistant",
311                    "content": "Hello"
312                }
313            }],
314            "created": 1234567890u64,
315            "model": "gpt-3.5-turbo",
316            "object": "chat.completion.chunk"
317        });
318
319        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
320        assert_eq!(response.id, "gen-abc123");
321        assert_eq!(response.model, "gpt-3.5-turbo");
322        assert_eq!(response.choices.len(), 1);
323    }
324
325    #[test]
326    fn test_delta_with_content() {
327        let json = json!({
328            "role": "assistant",
329            "content": "Hello, world!"
330        });
331
332        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
333        assert_eq!(delta.role, Some("assistant".to_string()));
334        assert_eq!(delta.content, Some("Hello, world!".to_string()));
335    }
336
337    #[test]
338    fn test_delta_with_tool_call() {
339        let json = json!({
340            "role": "assistant",
341            "tool_calls": [{
342                "index": 0,
343                "id": "call_abc",
344                "type": "function",
345                "function": {
346                    "name": "get_weather",
347                    "arguments": "{\"location\":"
348                }
349            }]
350        });
351
352        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
353        assert_eq!(delta.tool_calls.len(), 1);
354        assert_eq!(delta.tool_calls[0].index, 0);
355        assert_eq!(delta.tool_calls[0].id, Some("call_abc".to_string()));
356    }
357
358    #[test]
359    fn test_tool_call_with_partial_arguments() {
360        let json = json!({
361            "index": 0,
362            "id": null,
363            "type": null,
364            "function": {
365                "name": null,
366                "arguments": "Paris"
367            }
368        });
369
370        let tool_call: StreamingToolCall = serde_json::from_value(json).unwrap();
371        assert_eq!(tool_call.index, 0);
372        assert!(tool_call.id.is_none());
373        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
374    }
375
376    #[test]
377    fn test_streaming_with_usage() {
378        let json = json!({
379            "id": "gen-xyz",
380            "choices": [{
381                "index": 0,
382                "delta": {
383                    "content": null
384                }
385            }],
386            "created": 1234567890u64,
387            "model": "gpt-4",
388            "object": "chat.completion.chunk",
389            "usage": {
390                "prompt_tokens": 100,
391                "completion_tokens": 50,
392                "total_tokens": 150
393            }
394        });
395
396        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
397        assert!(response.usage.is_some());
398        let usage = response.usage.unwrap();
399        assert_eq!(usage.prompt_tokens, 100);
400        assert_eq!(usage.completion_tokens, 50);
401        assert_eq!(usage.total_tokens, 150);
402    }
403
404    #[test]
405    fn test_streaming_usage_maps_cache_token_accounting() {
406        use crate::completion::GetTokenUsage;
407
408        let json = json!({
409            "id": "gen-stream-cache",
410            "choices": [],
411            "created": 1u64,
412            "model": "anthropic/claude-3.5-sonnet",
413            "object": "chat.completion.chunk",
414            "usage": {
415                "prompt_tokens": 500,
416                "completion_tokens": 20,
417                "total_tokens": 520,
418                "prompt_tokens_details": {
419                    "cached_tokens": 400,
420                    "cache_write_tokens": 60
421                }
422            }
423        });
424
425        let chunk: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
426        let usage = chunk.usage.unwrap();
427        let token_usage = usage.token_usage().unwrap();
428
429        assert_eq!(token_usage.input_tokens, 500);
430        assert_eq!(token_usage.output_tokens, 20);
431        assert_eq!(token_usage.cached_input_tokens, 400);
432        assert_eq!(token_usage.cache_creation_input_tokens, 60);
433    }
434
435    #[test]
436    fn test_streaming_usage_cache_tokens_absent_defaults_to_zero() {
437        use crate::completion::GetTokenUsage;
438
439        let json = json!({
440            "id": "gen-stream-no-cache",
441            "choices": [],
442            "created": 1u64,
443            "model": "openai/gpt-4o",
444            "object": "chat.completion.chunk",
445            "usage": {
446                "prompt_tokens": 100,
447                "completion_tokens": 10,
448                "total_tokens": 110
449            }
450        });
451
452        let chunk: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
453        let usage = chunk.usage.unwrap();
454        let token_usage = usage.token_usage().unwrap();
455
456        assert_eq!(token_usage.cached_input_tokens, 0);
457        assert_eq!(token_usage.cache_creation_input_tokens, 0);
458    }
459
460    #[test]
461    fn test_multiple_tool_call_deltas() {
462        // Simulates the sequence of deltas for a tool call with arguments
463        let start_json = json!({
464            "id": "gen-1",
465            "choices": [{
466                "index": 0,
467                "delta": {
468                    "tool_calls": [{
469                        "index": 0,
470                        "id": "call_123",
471                        "type": "function",
472                        "function": {
473                            "name": "search",
474                            "arguments": ""
475                        }
476                    }]
477                }
478            }],
479            "created": 1234567890u64,
480            "model": "gpt-4",
481            "object": "chat.completion.chunk"
482        });
483
484        let delta1_json = json!({
485            "id": "gen-2",
486            "choices": [{
487                "index": 0,
488                "delta": {
489                    "tool_calls": [{
490                        "index": 0,
491                        "function": {
492                            "arguments": "{\"query\":"
493                        }
494                    }]
495                }
496            }],
497            "created": 1234567890u64,
498            "model": "gpt-4",
499            "object": "chat.completion.chunk"
500        });
501
502        let delta2_json = json!({
503            "id": "gen-3",
504            "choices": [{
505                "index": 0,
506                "delta": {
507                    "tool_calls": [{
508                        "index": 0,
509                        "function": {
510                            "arguments": "\"Rust programming\"}"
511                        }
512                    }]
513                }
514            }],
515            "created": 1234567890u64,
516            "model": "gpt-4",
517            "object": "chat.completion.chunk"
518        });
519
520        // Verify all chunks deserialize
521        let start: StreamingCompletionChunk = serde_json::from_value(start_json).unwrap();
522        assert_eq!(
523            start.choices[0].delta.tool_calls[0].id,
524            Some("call_123".to_string())
525        );
526
527        let delta1: StreamingCompletionChunk = serde_json::from_value(delta1_json).unwrap();
528        assert_eq!(
529            delta1.choices[0].delta.tool_calls[0].function.arguments,
530            Some("{\"query\":".to_string())
531        );
532
533        let delta2: StreamingCompletionChunk = serde_json::from_value(delta2_json).unwrap();
534        assert_eq!(
535            delta2.choices[0].delta.tool_calls[0].function.arguments,
536            Some("\"Rust programming\"}".to_string())
537        );
538    }
539
540    #[test]
541    fn test_response_with_error() {
542        let json = json!({
543            "id": "cmpl-abc123",
544            "object": "chat.completion.chunk",
545            "created": 1234567890,
546            "model": "gpt-3.5-turbo",
547            "provider": "openai",
548            "error": { "code": 500, "message": "Provider disconnected" },
549            "choices": [
550                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
551            ]
552        });
553
554        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
555        assert!(response.error.is_some());
556        let error = response.error.as_ref().unwrap();
557        assert_eq!(error.code, 500);
558        assert_eq!(error.message, "Provider disconnected");
559    }
560
561    #[tokio::test]
562    async fn encrypted_reasoning_details_attach_to_emitted_tool_calls() {
563        let client = MockStreamingClient {
564            sse_bytes: sse_bytes_from_data_lines([
565                "{\"id\":\"gen-1\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"\"}}],\"reasoning_details\":[]},\"finish_reason\":null}],\"usage\":null}",
566                "{\"id\":\"gen-2\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[],\"reasoning_details\":[{\"type\":\"reasoning.encrypted\",\"id\":\"call_123\",\"format\":\"opaque\",\"index\":0,\"data\":\"enc_blob\"}]},\"finish_reason\":null}],\"usage\":null}",
567                "{\"id\":\"gen-3\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[],\"reasoning_details\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}",
568                "[DONE]",
569            ]),
570        };
571
572        let req = Request::builder()
573            .method("POST")
574            .uri("http://localhost/v1/chat/completions")
575            .body(Vec::new())
576            .expect("request should build");
577
578        let mut stream = send_compatible_streaming_request(client, req)
579            .await
580            .expect("stream should start");
581
582        let tool_call = loop {
583            match stream.next().await.expect("stream should yield an item") {
584                Ok(StreamedAssistantContent::ToolCall { tool_call, .. }) => break tool_call,
585                Ok(_) => continue,
586                Err(err) => panic!("stream should not error: {err}"),
587            }
588        };
589
590        assert_eq!(tool_call.id, "call_123");
591        assert_eq!(tool_call.function.name, "search");
592        assert_eq!(tool_call.function.arguments, serde_json::json!({}));
593        assert_eq!(tool_call.signature.as_deref(), Some("enc_blob"));
594        assert_eq!(
595            tool_call.additional_params,
596            Some(json!({
597                "type": "reasoning.encrypted",
598                "id": "call_123",
599                "format": "opaque",
600                "index": 0,
601                "data": "enc_blob"
602            }))
603        );
604    }
605}