1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::info_span;
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils;
15use crate::providers::openrouter::{
16 OpenRouterRequestParams, OpenrouterCompletionRequest, ReasoningDetails,
17};
18use crate::streaming;
19
/// Final payload yielded at the end of an OpenRouter streaming completion,
/// carrying the token usage reported by the provider for the whole stream.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Aggregate token counts from the last `usage` chunk seen in the stream.
    pub usage: Usage,
}
24
25impl GetTokenUsage for StreamingCompletionResponse {
26 fn token_usage(&self) -> Option<crate::completion::Usage> {
27 let mut usage = crate::completion::Usage::new();
28
29 usage.input_tokens = self.usage.prompt_tokens as u64;
30 usage.output_tokens = self.usage.completion_tokens as u64;
31 usage.total_tokens = self.usage.total_tokens as u64;
32
33 Some(usage)
34 }
35}
36
/// Reason the provider reports for terminating a streamed choice.
///
/// Known values deserialize from their snake_case wire form (e.g.
/// `"tool_calls"`); any unrecognized string falls through to `Other` via the
/// untagged variant instead of failing deserialization.
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    ToolCalls,
    Stop,
    Error,
    ContentFilter,
    Length,
    /// Catch-all for finish reasons not covered by the named variants.
    #[serde(untagged)]
    Other(String),
}
48
/// One choice entry inside a streaming chunk; only the first choice is
/// consumed by `send_compatible_streaming_request`.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingChoice {
    /// Present only on the chunk that terminates this choice.
    pub finish_reason: Option<FinishReason>,
    /// Raw finish reason string from the upstream model provider, if any.
    pub native_finish_reason: Option<String>,
    /// Log-probability payload; kept as opaque JSON since it is not consumed here.
    pub logprobs: Option<Value>,
    pub index: usize,
    /// Incremental content/tool-call delta carried by this chunk.
    pub delta: StreamingDelta,
}
58
/// Partial function-call data inside a streamed tool call; either field may
/// be absent on any given chunk, with `arguments` arriving as string fragments.
#[derive(Deserialize, Debug)]
struct StreamingFunction {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
64
/// One incremental tool-call fragment from a streaming delta. Fragments for
/// the same call share an `index` and are accumulated across chunks.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingToolCall {
    /// Stable per-call index used to correlate fragments across chunks.
    pub index: usize,
    /// Provider-assigned call id; typically only present on the first fragment.
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: StreamingFunction,
}
73
/// Token-usage counters as reported by OpenRouter. `Default` (all zeros) is
/// used when the stream ends without a usage chunk.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
80
/// Provider-reported error embedded in a streaming chunk (e.g. an upstream
/// disconnect surfaced mid-stream rather than as an HTTP failure).
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct ErrorResponse {
    pub code: i32,
    pub message: String,
}
87
/// Incremental message delta within a streaming choice. Any combination of
/// fields may be populated on a given chunk.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingDelta {
    pub role: Option<String>,
    pub content: Option<String>,
    // `null_or_vec` + `default` so both a missing field and an explicit JSON
    // null deserialize to an empty Vec instead of erroring.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<StreamingToolCall>,
    /// Plain-text reasoning tokens, when the model streams them.
    pub reasoning: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub reasoning_details: Vec<ReasoningDetails>,
}
99
/// A single SSE data payload from the OpenRouter streaming endpoint.
/// Unknown top-level fields (`created`, `object`, `provider`, ...) are ignored
/// by serde's default behavior.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingCompletionChunk {
    id: String,
    model: String,
    choices: Vec<StreamingChoice>,
    /// Usually only present on the final chunk of the stream.
    usage: Option<Usage>,
    /// Present when the provider reports an in-band error.
    error: Option<ErrorResponse>,
}
109
110impl<T> super::CompletionModel<T>
111where
112 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
113{
114 pub(crate) async fn stream(
115 &self,
116 completion_request: CompletionRequest,
117 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
118 {
119 let preamble = completion_request.preamble.clone();
120 let mut request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
121 model: self.model.as_ref(),
122 request: completion_request,
123 strict_tools: self.strict_tools,
124 })?;
125
126 let params = json_utils::merge(
127 request.additional_params.unwrap_or(serde_json::json!({})),
128 serde_json::json!({"stream": true }),
129 );
130
131 request.additional_params = Some(params);
132
133 let body = serde_json::to_vec(&request)?;
134
135 let req = self
136 .client
137 .post("/chat/completions")?
138 .body(body)
139 .map_err(|x| CompletionError::HttpError(x.into()))?;
140
141 let span = if tracing::Span::current().is_disabled() {
142 info_span!(
143 target: "rig::completions",
144 "chat_streaming",
145 gen_ai.operation.name = "chat_streaming",
146 gen_ai.provider.name = "openrouter",
147 gen_ai.request.model = self.model,
148 gen_ai.system_instructions = preamble,
149 gen_ai.response.id = tracing::field::Empty,
150 gen_ai.response.model = tracing::field::Empty,
151 gen_ai.usage.output_tokens = tracing::field::Empty,
152 gen_ai.usage.input_tokens = tracing::field::Empty,
153 )
154 } else {
155 tracing::Span::current()
156 };
157
158 tracing::Instrument::instrument(
159 send_compatible_streaming_request(self.client.clone(), req),
160 span,
161 )
162 .await
163 }
164}
165
/// Consumes an SSE response from an OpenRouter-compatible `/chat/completions`
/// endpoint and adapts it into the crate's streaming-completion format.
///
/// Text, reasoning, and incremental tool-call fragments are yielded as they
/// arrive. Tool-call fragments are also accumulated per `index` so a complete
/// `ToolCall` can be emitted once the provider signals
/// `finish_reason: "tool_calls"`, or when the stream ends with calls still
/// pending. The final yielded item is always a `FinalResponse` carrying the
/// last usage block seen (zeroed if none arrived).
///
/// # Errors
/// Yields `CompletionError::ProviderError` and stops on a transport-level SSE
/// error; chunks that fail to parse are logged and skipped.
pub async fn send_compatible_streaming_request<T>(
    http_client: T,
    req: Request<Vec<u8>>,
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();
    let mut event_source = GenericEventSource::new(http_client, req);

    let stream = stream! {
        // Partial tool calls accumulated across chunks, keyed by provider index.
        let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
        // Most recent usage block; providers typically send it on the last chunk.
        let mut final_usage = None;

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    // Skip keep-alives and the terminal "[DONE]" sentinel.
                    if message.data.trim().is_empty() || message.data == "[DONE]" {
                        continue;
                    }

                    // Malformed chunks are logged and skipped rather than
                    // aborting the whole stream.
                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
                        Ok(data) => data,
                        Err(error) => {
                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                            continue;
                        }
                    };

                    // Only the first choice is consumed; additional choices,
                    // if any, are ignored.
                    let Some(choice) = data.choices.first() else {
                        tracing::debug!("There is no choice");
                        continue;
                    };
                    let delta = &choice.delta;

                    if !delta.tool_calls.is_empty() {
                        for tool_call in &delta.tool_calls {
                            let index = tool_call.index;

                            // Fetch (or start) the accumulator for this call index.
                            let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);

                            if let Some(id) = &tool_call.id && !id.is_empty() {
                                existing_tool_call.id = id.clone();
                            }

                            // A non-empty name marks the start of a call; surface
                            // it immediately as a Name delta.
                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
                                existing_tool_call.name = name.clone();
                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                    id: existing_tool_call.id.clone(),
                                    internal_call_id: existing_tool_call.internal_call_id.clone(),
                                    content: streaming::ToolCallDeltaContent::Name(name.clone()),
                                });
                            }

                            if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
                                // Recover the argument text accumulated so far,
                                // whatever shape it was last stored in.
                                let current_args = match &existing_tool_call.arguments {
                                    serde_json::Value::Null => String::new(),
                                    serde_json::Value::String(s) => s.clone(),
                                    v => v.to_string(),
                                };

                                let combined = format!("{current_args}{chunk}");

                                // Heuristic: once the buffer looks like a complete
                                // JSON object, try to parse it; otherwise keep the
                                // raw string and wait for more fragments.
                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
                                    match serde_json::from_str(&combined) {
                                        Ok(parsed) => existing_tool_call.arguments = parsed,
                                        Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
                                    }
                                } else {
                                    existing_tool_call.arguments = serde_json::Value::String(combined);
                                }

                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                    id: existing_tool_call.id.clone(),
                                    internal_call_id: existing_tool_call.internal_call_id.clone(),
                                    content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
                                });
                            }
                        }

                        // Attach encrypted reasoning blobs to the matching pending
                        // tool call by id. NOTE(review): this loop only runs when
                        // the delta also carries tool calls — reasoning_details on
                        // tool-call-free chunks are dropped; confirm that is intended.
                        for reasoning_detail in &delta.reasoning_details {
                            if let ReasoningDetails::Encrypted { id, data, .. } = reasoning_detail
                                && let Some(id) = id
                                && let Some(tool_call) = tool_calls.values_mut().find(|tool_call| tool_call.id.eq(id))
                                && let Ok(additional_params) = serde_json::to_value(reasoning_detail) {
                                tool_call.signature = Some(data.clone());
                                tool_call.additional_params = Some(additional_params);
                            }
                        }
                    }

                    if let Some(reasoning) = &delta.reasoning && !reasoning.is_empty() {
                        yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                            reasoning: reasoning.clone(),
                            id: None,
                        });
                    }

                    if let Some(content) = &delta.content && !content.is_empty() {
                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
                    }

                    if let Some(usage) = data.usage {
                        final_usage = Some(usage);
                    }

                    // On a tool_calls finish, flush every accumulated call and
                    // reset the map (the stream may continue with new calls).
                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
                        for (_idx, tool_call) in tool_calls.into_iter() {
                            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
                        }
                        tool_calls = HashMap::new();
                    }
                }
                Err(crate::http_client::Error::StreamEnded) => {
                    // Normal end of stream.
                    break;
                }
                Err(error) => {
                    tracing::error!(?error, "SSE error");
                    yield Err(CompletionError::ProviderError(error.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        // Flush any tool calls that never saw a tool_calls finish reason.
        for (_idx, tool_call) in tool_calls.into_iter() {
            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
        }

        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
            usage: final_usage.unwrap_or_default(),
        }));
    }.instrument(span);

    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
        stream,
    )))
}
327
#[cfg(test)]
mod tests {
    //! Deserialization tests for the streaming chunk types against
    //! representative OpenRouter SSE payloads.
    use super::*;
    use serde_json::json;

    // A plain content chunk deserializes with id/model/choices intact;
    // unknown fields (`created`, `object`) are ignored.
    #[test]
    fn test_streaming_completion_response_deserialization() {
        let json = json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert_eq!(response.id, "gen-abc123");
        assert_eq!(response.model, "gpt-3.5-turbo");
        assert_eq!(response.choices.len(), 1);
    }

    // Delta with role + content; tool_calls defaults to empty when absent.
    #[test]
    fn test_delta_with_content() {
        let json = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.role, Some("assistant".to_string()));
        assert_eq!(delta.content, Some("Hello, world!".to_string()));
    }

    // A tool-call fragment in a delta parses with index, id, and function data.
    #[test]
    fn test_delta_with_tool_call() {
        let json = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].index, 0);
        assert_eq!(delta.tool_calls[0].id, Some("call_abc".to_string()));
    }

    // Continuation fragments carry null id/type/name but still parse,
    // providing only an arguments chunk.
    #[test]
    fn test_tool_call_with_partial_arguments() {
        let json = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let tool_call: StreamingToolCall = serde_json::from_value(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
    }

    // The optional usage block deserializes with all three counters.
    #[test]
    fn test_streaming_with_usage() {
        let json = json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    // A realistic three-chunk tool-call sequence: the first chunk names the
    // call, the following chunks stream argument fragments for the same index.
    #[test]
    fn test_multiple_tool_call_deltas() {
        let start_json = json!({
            "id": "gen-1",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta1_json = json!({
            "id": "gen-2",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"query\":"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta2_json = json!({
            "id": "gen-3",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "\"Rust programming\"}"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let start: StreamingCompletionChunk = serde_json::from_value(start_json).unwrap();
        assert_eq!(
            start.choices[0].delta.tool_calls[0].id,
            Some("call_123".to_string())
        );

        let delta1: StreamingCompletionChunk = serde_json::from_value(delta1_json).unwrap();
        assert_eq!(
            delta1.choices[0].delta.tool_calls[0].function.arguments,
            Some("{\"query\":".to_string())
        );

        let delta2: StreamingCompletionChunk = serde_json::from_value(delta2_json).unwrap();
        assert_eq!(
            delta2.choices[0].delta.tool_calls[0].function.arguments,
            Some("\"Rust programming\"}".to_string())
        );
    }

    // In-band provider errors (error object plus finish_reason "error")
    // parse into the optional `error` field.
    #[test]
    fn test_response_with_error() {
        let json = json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.error.is_some());
        let error = response.error.as_ref().unwrap();
        assert_eq!(error.code, 500);
        assert_eq!(error.message, "Provider disconnected");
    }
}
534}