1use http::Request;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use tracing::info_span;
5
6use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
7use crate::http_client::HttpClientExt;
8use crate::json_utils;
9use crate::providers::internal::openai_chat_completions_compatible::{
10 self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
11 CompatibleToolCallChunk,
12};
13use crate::providers::openrouter::{
14 OpenRouterRequestParams, OpenrouterCompletionRequest, ReasoningDetails,
15};
16use crate::streaming;
17
/// Terminal payload yielded once an OpenRouter SSE stream completes,
/// carrying the accumulated token usage reported by the provider.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token counts for the whole streamed completion.
    pub usage: Usage,
}
22
23impl GetTokenUsage for StreamingCompletionResponse {
24 fn token_usage(&self) -> Option<crate::completion::Usage> {
25 self.usage.token_usage()
26 }
27}
28
/// Reason OpenRouter reports for terminating a streamed choice.
///
/// Values not covered by the named variants fall through to [`Self::Other`]
/// via the untagged variant instead of failing deserialization.
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    ToolCalls,
    Stop,
    Error,
    ContentFilter,
    Length,
    /// Catch-all for provider-specific finish reasons.
    #[serde(untagged)]
    Other(String),
}
40
/// One choice entry inside a streaming chunk.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingChoice {
    pub finish_reason: Option<FinishReason>,
    /// Upstream provider's native finish-reason string, when forwarded.
    pub native_finish_reason: Option<String>,
    pub logprobs: Option<Value>,
    pub index: usize,
    /// Incremental content / tool-call delta for this choice.
    pub delta: StreamingDelta,
}
50
/// Partial function-call payload inside a streamed tool call; both fields
/// may arrive incrementally across several chunks.
#[derive(Deserialize, Debug)]
struct StreamingFunction {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
56
/// One tool-call fragment inside a streaming delta. Only `index` is always
/// present; `id`/`type`/function fields fill in across successive chunks.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingToolCall {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: StreamingFunction,
}
65
66impl From<&StreamingToolCall> for CompatibleToolCallChunk {
67 fn from(value: &StreamingToolCall) -> Self {
68 Self {
69 index: value.index,
70 id: value.id.clone(),
71 name: value.function.name.clone(),
72 arguments: value.function.arguments.clone(),
73 }
74 }
75}
76
/// Token accounting exactly as reported by OpenRouter.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
83
84impl GetTokenUsage for Usage {
85 fn token_usage(&self) -> Option<crate::completion::Usage> {
86 Some(crate::providers::internal::completion_usage(
87 self.prompt_tokens as u64,
88 self.completion_tokens as u64,
89 self.total_tokens as u64,
90 0,
91 ))
92 }
93}
94
/// Provider error embedded in a streaming chunk (e.g. mid-stream failures).
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct ErrorResponse {
    pub code: i32,
    pub message: String,
}
101
/// Incremental message delta inside a streaming choice.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingDelta {
    pub role: Option<String>,
    pub content: Option<String>,
    // Missing field or explicit JSON `null` becomes an empty Vec
    // (per `json_utils::null_or_vec` plus `default`).
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<StreamingToolCall>,
    pub reasoning: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub reasoning_details: Vec<ReasoningDetails>,
}
113
/// One parsed SSE `data:` payload from OpenRouter's streaming endpoint.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingCompletionChunk {
    id: String,
    model: String,
    choices: Vec<StreamingChoice>,
    /// Usually only present on the final chunk of a stream.
    usage: Option<Usage>,
    /// Mid-stream provider error, if the chunk carries one.
    error: Option<ErrorResponse>,
}
123
124impl<T> super::CompletionModel<T>
125where
126 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
127{
128 pub(crate) async fn stream(
129 &self,
130 completion_request: CompletionRequest,
131 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
132 {
133 let request_model = completion_request
134 .model
135 .clone()
136 .unwrap_or_else(|| self.model.clone());
137 let preamble = completion_request.preamble.clone();
138 let mut request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
139 model: request_model.as_ref(),
140 request: completion_request,
141 strict_tools: self.strict_tools,
142 })?;
143
144 let params = json_utils::merge(
145 request.additional_params.unwrap_or(serde_json::json!({})),
146 serde_json::json!({"stream": true }),
147 );
148
149 request.additional_params = Some(params);
150
151 let body = serde_json::to_vec(&request)?;
152
153 let req = self
154 .client
155 .post("/chat/completions")?
156 .body(body)
157 .map_err(|x| CompletionError::HttpError(x.into()))?;
158
159 let span = if tracing::Span::current().is_disabled() {
160 info_span!(
161 target: "rig::completions",
162 "chat_streaming",
163 gen_ai.operation.name = "chat_streaming",
164 gen_ai.provider.name = "openrouter",
165 gen_ai.request.model = &request_model,
166 gen_ai.system_instructions = preamble,
167 gen_ai.response.id = tracing::field::Empty,
168 gen_ai.response.model = tracing::field::Empty,
169 gen_ai.usage.output_tokens = tracing::field::Empty,
170 gen_ai.usage.input_tokens = tracing::field::Empty,
171 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
172 )
173 } else {
174 tracing::Span::current()
175 };
176
177 tracing::Instrument::instrument(
178 send_compatible_streaming_request(self.client.clone(), req),
179 span,
180 )
181 .await
182 }
183}
184
/// Zero-sized adapter plugging OpenRouter's chunk format into the shared
/// OpenAI-compatible streaming pipeline.
#[derive(Clone, Copy)]
struct OpenRouterCompatibleProfile;
187
impl CompatibleStreamProfile for OpenRouterCompatibleProfile {
    type Usage = Usage;
    type Detail = ReasoningDetails;
    type FinalResponse = StreamingCompletionResponse;

    /// Parses one SSE `data:` payload into the shared chunk representation.
    ///
    /// Malformed payloads are logged and skipped (`Ok(None)`) rather than
    /// aborting the stream.
    // NOTE(review): the chunk-level `error` field parsed by
    // `StreamingCompletionChunk` is never surfaced here — confirm whether
    // provider errors should terminate the stream instead of being dropped.
    fn normalize_chunk(
        &self,
        data: &str,
    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
            Ok(data) => data,
            Err(error) => {
                tracing::error!(?error, message = data, "Failed to parse SSE message");
                return Ok(None);
            }
        };

        Ok(Some(
            openai_chat_completions_compatible::normalize_first_choice_chunk(
                Some(data.id),
                Some(data.model),
                data.usage,
                &data.choices,
                |choice| CompatibleChoiceData {
                    // Only `tool_calls` is distinguished; every other finish
                    // reason collapses to `Other` for the shared pipeline.
                    finish_reason: if choice.finish_reason == Some(FinishReason::ToolCalls) {
                        CompatibleFinishReason::ToolCalls
                    } else {
                        CompatibleFinishReason::Other
                    },
                    text: choice.delta.content.clone(),
                    reasoning: choice.delta.reasoning.clone(),
                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
                        &choice.delta.tool_calls,
                    ),
                    details: choice.delta.reasoning_details.clone(),
                },
            ),
        ))
    }

    /// Wraps the final usage totals into the stream's terminal response.
    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
        StreamingCompletionResponse { usage }
    }

    /// Attaches encrypted reasoning metadata to the tool call it belongs to.
    ///
    /// When an encrypted reasoning detail carries an `id` matching an emitted
    /// tool call, its `data` blob becomes the call's signature and the full
    /// detail is stored as the call's additional params. All conditions must
    /// hold (encrypted variant, id present, matching call, serializable
    /// detail) or the detail is ignored.
    fn decorate_tool_call(
        &self,
        detail: &Self::Detail,
        tool_calls: &mut std::collections::HashMap<usize, crate::streaming::RawStreamingToolCall>,
    ) {
        if let ReasoningDetails::Encrypted { id, data, .. } = detail
            && let Some(id) = id
            && let Some(tool_call) = tool_calls
                .values_mut()
                .find(|tool_call| tool_call.id.eq(id))
            && let Ok(additional_params) = serde_json::to_value(detail)
        {
            tool_call.signature = Some(data.clone());
            tool_call.additional_params = Some(additional_params);
        }
    }
}
249
250pub async fn send_compatible_streaming_request<T>(
251 http_client: T,
252 req: Request<Vec<u8>>,
253) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
254where
255 T: HttpClientExt + Clone + 'static,
256{
257 openai_chat_completions_compatible::send_compatible_streaming_request(
258 http_client,
259 req,
260 OpenRouterCompatibleProfile,
261 )
262 .await
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http_client::mock::MockStreamingClient;
    use crate::providers::internal::openai_chat_completions_compatible::test_support::sse_bytes_from_data_lines;
    use crate::streaming::StreamedAssistantContent;
    use futures::StreamExt;
    use serde_json::json;

    // A minimal chunk (no usage, no error) deserializes and exposes
    // id / model / choices.
    #[test]
    fn test_streaming_completion_response_deserialization() {
        let json = json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert_eq!(response.id, "gen-abc123");
        assert_eq!(response.model, "gpt-3.5-turbo");
        assert_eq!(response.choices.len(), 1);
    }

    // A plain text delta carries role and content through deserialization.
    #[test]
    fn test_delta_with_content() {
        let json = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.role, Some("assistant".to_string()));
        assert_eq!(delta.content, Some("Hello, world!".to_string()));
    }

    // A delta carrying a tool call populates the tool_calls vec with
    // index and id intact.
    #[test]
    fn test_delta_with_tool_call() {
        let json = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].index, 0);
        assert_eq!(delta.tool_calls[0].id, Some("call_abc".to_string()));
    }

    // Continuation fragments have null id/type/name but still carry the
    // partial arguments string.
    #[test]
    fn test_tool_call_with_partial_arguments() {
        let json = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let tool_call: StreamingToolCall = serde_json::from_value(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
    }

    // The optional usage block deserializes with all three counters.
    #[test]
    fn test_streaming_with_usage() {
        let json = json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    // A tool call split across three chunks: the first carries id/name,
    // the later two only argument fragments for the same index.
    #[test]
    fn test_multiple_tool_call_deltas() {
        let start_json = json!({
            "id": "gen-1",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta1_json = json!({
            "id": "gen-2",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"query\":"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta2_json = json!({
            "id": "gen-3",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "\"Rust programming\"}"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let start: StreamingCompletionChunk = serde_json::from_value(start_json).unwrap();
        assert_eq!(
            start.choices[0].delta.tool_calls[0].id,
            Some("call_123".to_string())
        );

        let delta1: StreamingCompletionChunk = serde_json::from_value(delta1_json).unwrap();
        assert_eq!(
            delta1.choices[0].delta.tool_calls[0].function.arguments,
            Some("{\"query\":".to_string())
        );

        let delta2: StreamingCompletionChunk = serde_json::from_value(delta2_json).unwrap();
        assert_eq!(
            delta2.choices[0].delta.tool_calls[0].function.arguments,
            Some("\"Rust programming\"}".to_string())
        );
    }

    // A chunk carrying a provider error still deserializes and exposes
    // the error code and message.
    #[test]
    fn test_response_with_error() {
        let json = json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.error.is_some());
        let error = response.error.as_ref().unwrap();
        assert_eq!(error.code, 500);
        assert_eq!(error.message, "Provider disconnected");
    }

    // End-to-end over the mock client: an encrypted reasoning detail whose
    // id matches an emitted tool call must surface as that call's
    // signature and additional_params (exercises `decorate_tool_call`).
    #[tokio::test]
    async fn encrypted_reasoning_details_attach_to_emitted_tool_calls() {
        let client = MockStreamingClient {
            sse_bytes: sse_bytes_from_data_lines([
                "{\"id\":\"gen-1\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"\"}}],\"reasoning_details\":[]},\"finish_reason\":null}],\"usage\":null}",
                "{\"id\":\"gen-2\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[],\"reasoning_details\":[{\"type\":\"reasoning.encrypted\",\"id\":\"call_123\",\"format\":\"opaque\",\"index\":0,\"data\":\"enc_blob\"}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"id\":\"gen-3\",\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[],\"reasoning_details\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}",
                "[DONE]",
            ]),
        };

        let req = Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .expect("request should build");

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .expect("stream should start");

        // Skip non-tool-call items until the tool call is emitted.
        let tool_call = loop {
            match stream.next().await.expect("stream should yield an item") {
                Ok(StreamedAssistantContent::ToolCall { tool_call, .. }) => break tool_call,
                Ok(_) => continue,
                Err(err) => panic!("stream should not error: {err}"),
            }
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.function.name, "search");
        assert_eq!(tool_call.function.arguments, serde_json::json!({}));
        assert_eq!(tool_call.signature.as_deref(), Some("enc_blob"));
        assert_eq!(
            tool_call.additional_params,
            Some(json!({
                "type": "reasoning.encrypted",
                "id": "call_123",
                "format": "opaque",
                "index": 0,
                "data": "enc_blob"
            }))
        );
    }
}