1use http::Request;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use tracing::info_span;
5
6use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
7use crate::http_client::HttpClientExt;
8use crate::json_utils;
9use crate::providers::internal::openai_chat_completions_compatible::{
10 self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
11 CompatibleToolCallChunk,
12};
13use crate::providers::openrouter::{
14 OpenRouterRequestParams, OpenrouterCompletionRequest, ReasoningDetails,
15};
16use crate::streaming;
17
18#[derive(Clone, Serialize, Deserialize, Debug)]
19pub struct StreamingCompletionResponse {
20 pub usage: Usage,
21}
22
23impl GetTokenUsage for StreamingCompletionResponse {
24 fn token_usage(&self) -> Option<crate::completion::Usage> {
25 self.usage.token_usage()
26 }
27}
28
29#[derive(Deserialize, Debug, PartialEq)]
30#[serde(rename_all = "snake_case")]
31pub enum FinishReason {
32 ToolCalls,
33 Stop,
34 Error,
35 ContentFilter,
36 Length,
37 #[serde(untagged)]
38 Other(String),
39}
40
41#[derive(Deserialize, Debug)]
42#[allow(dead_code)]
43struct StreamingChoice {
44 pub finish_reason: Option<FinishReason>,
45 pub native_finish_reason: Option<String>,
46 pub logprobs: Option<Value>,
47 pub index: usize,
48 pub delta: StreamingDelta,
49}
50
51#[derive(Deserialize, Debug)]
52struct StreamingFunction {
53 pub name: Option<String>,
54 pub arguments: Option<String>,
55}
56
57#[derive(Deserialize, Debug)]
58#[allow(dead_code)]
59struct StreamingToolCall {
60 pub index: usize,
61 pub id: Option<String>,
62 pub r#type: Option<String>,
63 pub function: StreamingFunction,
64}
65
66impl From<&StreamingToolCall> for CompatibleToolCallChunk {
67 fn from(value: &StreamingToolCall) -> Self {
68 Self {
69 index: value.index,
70 id: value.id.clone(),
71 name: value.function.name.clone(),
72 arguments: value.function.arguments.clone(),
73 }
74 }
75}
76
77#[derive(Serialize, Deserialize, Debug, Clone, Default)]
78pub struct Usage {
79 pub prompt_tokens: u32,
80 pub completion_tokens: u32,
81 pub total_tokens: u32,
82 #[serde(default, skip_serializing_if = "Option::is_none")]
86 pub prompt_tokens_details: Option<PromptTokensDetails>,
87}
88
89#[derive(Serialize, Deserialize, Debug, Clone, Default)]
93pub struct PromptTokensDetails {
94 #[serde(default)]
96 pub cached_tokens: u32,
97 #[serde(default)]
99 pub cache_write_tokens: u32,
100}
101
102impl GetTokenUsage for Usage {
103 fn token_usage(&self) -> Option<crate::completion::Usage> {
104 let (cached_input, cache_creation) = self
105 .prompt_tokens_details
106 .as_ref()
107 .map(|d| (d.cached_tokens as u64, d.cache_write_tokens as u64))
108 .unwrap_or((0, 0));
109 Some(crate::completion::Usage {
110 input_tokens: self.prompt_tokens as u64,
111 output_tokens: self.completion_tokens as u64,
112 total_tokens: self.total_tokens as u64,
113 cached_input_tokens: cached_input,
114 cache_creation_input_tokens: cache_creation,
115 tool_use_prompt_tokens: 0,
116 reasoning_tokens: 0,
117 })
118 }
119}
120
121#[derive(Deserialize, Debug)]
122#[allow(dead_code)]
123struct ErrorResponse {
124 pub code: i32,
125 pub message: String,
126}
127
128#[derive(Deserialize, Debug)]
129#[allow(dead_code)]
130struct StreamingDelta {
131 pub role: Option<String>,
132 pub content: Option<String>,
133 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
134 pub tool_calls: Vec<StreamingToolCall>,
135 pub reasoning: Option<String>,
136 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
137 pub reasoning_details: Vec<ReasoningDetails>,
138}
139
140#[derive(Deserialize, Debug)]
141#[allow(dead_code)]
142struct StreamingCompletionChunk {
143 id: String,
144 model: String,
145 choices: Vec<StreamingChoice>,
146 usage: Option<Usage>,
147 error: Option<ErrorResponse>,
148}
149
150impl<T> super::CompletionModel<T>
151where
152 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
153{
154 pub(crate) async fn stream(
155 &self,
156 completion_request: CompletionRequest,
157 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
158 {
159 let request_model = completion_request
160 .model
161 .clone()
162 .unwrap_or_else(|| self.model.clone());
163 let preamble = completion_request.preamble.clone();
164 let mut request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
165 model: request_model.as_ref(),
166 request: completion_request,
167 strict_tools: self.strict_tools,
168 })?;
169
170 let params = json_utils::merge(
171 request.additional_params.unwrap_or(serde_json::json!({})),
172 serde_json::json!({"stream": true }),
173 );
174
175 request.additional_params = Some(params);
176
177 let body = serde_json::to_vec(&super::completion::final_request_body(
178 &request,
179 self.prompt_caching,
180 )?)?;
181
182 let req = self
183 .client
184 .post("/chat/completions")?
185 .body(body)
186 .map_err(|x| CompletionError::HttpError(x.into()))?;
187
188 let span = if tracing::Span::current().is_disabled() {
189 info_span!(
190 target: "rig::completions",
191 "chat_streaming",
192 gen_ai.operation.name = "chat_streaming",
193 gen_ai.provider.name = "openrouter",
194 gen_ai.request.model = &request_model,
195 gen_ai.system_instructions = preamble,
196 gen_ai.response.id = tracing::field::Empty,
197 gen_ai.response.model = tracing::field::Empty,
198 gen_ai.usage.output_tokens = tracing::field::Empty,
199 gen_ai.usage.input_tokens = tracing::field::Empty,
200 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
201 )
202 } else {
203 tracing::Span::current()
204 };
205
206 tracing::Instrument::instrument(
207 send_compatible_streaming_request(self.client.clone(), req),
208 span,
209 )
210 .await
211 }
212}
213
214#[derive(Clone, Copy)]
215struct OpenRouterCompatibleProfile;
216
217impl CompatibleStreamProfile for OpenRouterCompatibleProfile {
218 type Usage = Usage;
219 type Detail = ReasoningDetails;
220 type FinalResponse = StreamingCompletionResponse;
221
222 fn normalize_chunk(
223 &self,
224 data: &str,
225 ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
226 let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
227 Ok(data) => data,
228 Err(error) => {
229 tracing::error!(?error, message = data, "Failed to parse SSE message");
230 return Ok(None);
231 }
232 };
233
234 Ok(Some(
235 openai_chat_completions_compatible::normalize_first_choice_chunk(
236 Some(data.id),
237 Some(data.model),
238 data.usage,
239 &data.choices,
240 |choice| CompatibleChoiceData {
241 finish_reason: if choice.finish_reason == Some(FinishReason::ToolCalls) {
242 CompatibleFinishReason::ToolCalls
243 } else {
244 CompatibleFinishReason::Other
245 },
246 text: choice.delta.content.clone(),
247 reasoning: choice.delta.reasoning.clone(),
248 tool_calls: openai_chat_completions_compatible::tool_call_chunks(
249 &choice.delta.tool_calls,
250 ),
251 details: choice.delta.reasoning_details.clone(),
252 },
253 ),
254 ))
255 }
256
257 fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
258 StreamingCompletionResponse { usage }
259 }
260
261 fn decorate_tool_call(
262 &self,
263 detail: &Self::Detail,
264 tool_calls: &mut std::collections::HashMap<usize, crate::streaming::RawStreamingToolCall>,
265 ) {
266 if let ReasoningDetails::Encrypted { id, data, .. } = detail
267 && let Some(id) = id
268 && let Some(tool_call) = tool_calls
269 .values_mut()
270 .find(|tool_call| tool_call.id.eq(id))
271 && let Ok(additional_params) = serde_json::to_value(detail)
272 {
273 tool_call.signature = Some(data.clone());
274 tool_call.additional_params = Some(additional_params);
275 }
276 }
277}
278
279pub async fn send_compatible_streaming_request<T>(
280 http_client: T,
281 req: Request<Vec<u8>>,
282) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
283where
284 T: HttpClientExt + Clone + 'static,
285{
286 openai_chat_completions_compatible::send_compatible_streaming_request(
287 http_client,
288 req,
289 OpenRouterCompatibleProfile,
290 )
291 .await
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::internal::openai_chat_completions_compatible::test_support::sse_bytes_from_data_lines;
    use crate::streaming::StreamedAssistantContent;
    use crate::test_utils::MockStreamingClient;
    use futures::StreamExt;
    use serde_json::json;

    /// Deserializes a chunk fixture, panicking if the fixture is malformed.
    fn parse_chunk(fixture: serde_json::Value) -> StreamingCompletionChunk {
        serde_json::from_value(fixture).unwrap()
    }

    /// A minimal chunk exposes its id, model and choices.
    #[test]
    fn test_streaming_completion_response_deserialization() {
        let chunk = parse_chunk(json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        }));

        assert_eq!(chunk.id, "gen-abc123");
        assert_eq!(chunk.model, "gpt-3.5-turbo");
        assert_eq!(chunk.choices.len(), 1);
    }

    /// A delta holding only role and content deserializes both.
    #[test]
    fn test_delta_with_content() {
        let fixture = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(fixture).unwrap();
        assert_eq!(delta.role.as_deref(), Some("assistant"));
        assert_eq!(delta.content.as_deref(), Some("Hello, world!"));
    }

    /// A delta with one tool call keeps that call's index and id.
    #[test]
    fn test_delta_with_tool_call() {
        let fixture = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(fixture).unwrap();
        assert_eq!(delta.tool_calls.len(), 1);

        let first_call = &delta.tool_calls[0];
        assert_eq!(first_call.index, 0);
        assert_eq!(first_call.id.as_deref(), Some("call_abc"));
    }

    /// Explicit `null` id, type and name are accepted next to an argument fragment.
    #[test]
    fn test_tool_call_with_partial_arguments() {
        let fixture = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let partial: StreamingToolCall = serde_json::from_value(fixture).unwrap();
        assert_eq!(partial.index, 0);
        assert!(partial.id.is_none());
        assert_eq!(partial.function.arguments.as_deref(), Some("Paris"));
    }

    /// The usage counters of a chunk are deserialized as sent.
    #[test]
    fn test_streaming_with_usage() {
        let chunk = parse_chunk(json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        }));

        assert!(chunk.usage.is_some());
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    /// Cache counters from `prompt_tokens_details` reach the crate-wide usage type.
    #[test]
    fn test_streaming_usage_maps_cache_token_accounting() {
        use crate::completion::GetTokenUsage;

        let chunk = parse_chunk(json!({
            "id": "gen-stream-cache",
            "choices": [],
            "created": 1u64,
            "model": "anthropic/claude-3.5-sonnet",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 500,
                "completion_tokens": 20,
                "total_tokens": 520,
                "prompt_tokens_details": {
                    "cached_tokens": 400,
                    "cache_write_tokens": 60
                }
            }
        }));

        let converted = chunk.usage.unwrap().token_usage().unwrap();

        assert_eq!(converted.input_tokens, 500);
        assert_eq!(converted.output_tokens, 20);
        assert_eq!(converted.cached_input_tokens, 400);
        assert_eq!(converted.cache_creation_input_tokens, 60);
    }

    /// Without `prompt_tokens_details` both cache counters are zero.
    #[test]
    fn test_streaming_usage_cache_tokens_absent_defaults_to_zero() {
        use crate::completion::GetTokenUsage;

        let chunk = parse_chunk(json!({
            "id": "gen-stream-no-cache",
            "choices": [],
            "created": 1u64,
            "model": "openai/gpt-4o",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 10,
                "total_tokens": 110
            }
        }));

        let converted = chunk.usage.unwrap().token_usage().unwrap();

        assert_eq!(converted.cached_input_tokens, 0);
        assert_eq!(converted.cache_creation_input_tokens, 0);
    }

    /// A tool call split over three chunks: the first carries id and name,
    /// the following two carry argument fragments only.
    #[test]
    fn test_multiple_tool_call_deltas() {
        let opening_fixture = json!({
            "id": "gen-1",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let first_fragment_fixture = json!({
            "id": "gen-2",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"query\":"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let second_fragment_fixture = json!({
            "id": "gen-3",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "\"Rust programming\"}"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let opening = parse_chunk(opening_fixture);
        assert_eq!(
            opening.choices[0].delta.tool_calls[0].id.as_deref(),
            Some("call_123")
        );

        let first_fragment = parse_chunk(first_fragment_fixture);
        assert_eq!(
            first_fragment.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_deref(),
            Some("{\"query\":")
        );

        let second_fragment = parse_chunk(second_fragment_fixture);
        assert_eq!(
            second_fragment.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_deref(),
            Some("\"Rust programming\"}")
        );
    }

    /// A chunk with a top-level `error` object deserializes its code and message.
    #[test]
    fn test_response_with_error() {
        let chunk = parse_chunk(json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        }));

        assert!(chunk.error.is_some());
        let reported = chunk.error.as_ref().unwrap();
        assert_eq!(reported.code, 500);
        assert_eq!(reported.message, "Provider disconnected");
    }

    /// An encrypted reasoning detail whose id matches a tool call ends up on
    /// the emitted tool call as signature and additional params.
    #[tokio::test]
    async fn encrypted_reasoning_details_attach_to_emitted_tool_calls() {
        let data_lines = [
            "{\"id\":\"gen-1\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"\"}}],\"reasoning_details\":[]},\"finish_reason\":null}],\"usage\":null}",
            "{\"id\":\"gen-2\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[],\"reasoning_details\":[{\"type\":\"reasoning.encrypted\",\"id\":\"call_123\",\"format\":\"opaque\",\"index\":0,\"data\":\"enc_blob\"}]},\"finish_reason\":null}],\"usage\":null}",
            "{\"id\":\"gen-3\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[],\"reasoning_details\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}",
            "[DONE]",
        ];
        let client = MockStreamingClient {
            sse_bytes: sse_bytes_from_data_lines(data_lines),
        };

        let req = Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .expect("request should build");

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .expect("stream should start");

        // Skip every item until the first complete tool call is emitted.
        let tool_call = loop {
            let item = stream.next().await.expect("stream should yield an item");
            match item {
                Err(err) => panic!("stream should not error: {err}"),
                Ok(StreamedAssistantContent::ToolCall { tool_call, .. }) => break tool_call,
                Ok(_) => {}
            }
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.function.name, "search");
        assert_eq!(tool_call.function.arguments, serde_json::json!({}));
        assert_eq!(tool_call.signature.as_deref(), Some("enc_blob"));

        let expected_params = json!({
            "type": "reasoning.encrypted",
            "id": "call_123",
            "format": "opaque",
            "index": 0,
            "data": "enc_blob"
        });
        assert_eq!(tool_call.additional_params, Some(expected_params));
    }
}