use std::collections::HashMap;

use async_stream::stream;
use futures::StreamExt;
use http::Request;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::completion::GetTokenUsage;
use crate::completion::{CompletionError, CompletionRequest};
use crate::http_client::HttpClientExt;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::json_utils;
use crate::message::{ToolCall, ToolFunction};
use crate::providers::openrouter::OpenrouterCompletionRequest;
use crate::streaming;

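/// Final payload attached to an OpenRouter streaming completion, carrying the
/// aggregated token usage reported by the provider.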
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.output_tokens = self.usage.completion_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;

        Some(usage)
    }
}

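/// Finish reason reported by OpenRouter for a streaming choice. Unknown values
/// are captured by the untagged `Other` variant instead of failing deserialization.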
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    ToolCalls,
    Stop,
    Error,
    ContentFilter,
    Length,
    #[serde(untagged)]
    Other(String),
}

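/// A single streamed choice: the incremental delta plus optional finish metadata.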
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingChoice {
    pub finish_reason: Option<FinishReason>,
    pub native_finish_reason: Option<String>,
    pub logprobs: Option<Value>,
    pub index: usize,
    pub delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingFunction {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingToolCall {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: StreamingFunction,
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct ErrorResponse {
    pub code: i32,
    pub message: String,
}

#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingDelta {
    pub role: Option<String>,
    pub content: Option<String>,
    pub reasoning: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<StreamingToolCall>,
}

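/// One server-sent chunk of a streaming completion, optionally carrying usage
/// totals or a provider error.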
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingCompletionChunk {
    id: String,
    model: String,
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
    error: Option<ErrorResponse>,
}

impl<T> super::CompletionModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
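    /// Builds the OpenRouter request with `stream: true` enabled, attaches a
    /// tracing span, and delegates to [`send_compatible_streaming_request`].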
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let preamble = completion_request.preamble.clone();
        let mut request =
            OpenrouterCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true}),
        );

        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|x| CompletionError::HttpError(x.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "openrouter",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.http_client().clone(), req),
            span,
        )
        .await
    }
}

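/// Sends an OpenAI-compatible streaming request and converts the SSE chunks into
/// [`streaming::RawStreamingChoice`] items, accumulating partial tool calls and
/// token usage along the way.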
pub async fn send_compatible_streaming_request<T>(
    http_client: T,
    req: Request<Vec<u8>>,
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();
    let mut event_source = GenericEventSource::new(http_client, req);

    let stream = stream! {
        let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
        let mut final_usage = None;

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    if message.data.trim().is_empty() || message.data == "[DONE]" {
                        continue;
                    }

                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
                        Ok(data) => data,
                        Err(error) => {
                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                            continue;
                        }
                    };

                    let Some(choice) = data.choices.first() else {
                        tracing::debug!("Streaming chunk contained no choices");
                        continue;
                    };
                    let delta = &choice.delta;

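                    // Tool calls stream in as fragments keyed by `index`; accumulate
                    // them here until a finish reason marks them complete.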
                    if !delta.tool_calls.is_empty() {
                        for tool_call in &delta.tool_calls {
                            let index = tool_call.index;

                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
                                id: String::new(),
                                call_id: None,
                                function: ToolFunction {
                                    name: String::new(),
                                    arguments: serde_json::Value::Null,
                                },
                            });

                            if let Some(id) = &tool_call.id && !id.is_empty() {
                                existing_tool_call.id = id.clone();
                            }

                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
                                existing_tool_call.function.name = name.clone();
                            }

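                            // Argument text arrives in pieces; append each chunk and
                            // attempt a JSON parse once the buffer looks complete.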
                            if let Some(chunk) = &tool_call.function.arguments {
                                let current_args = match &existing_tool_call.function.arguments {
                                    serde_json::Value::Null => String::new(),
                                    serde_json::Value::String(s) => s.clone(),
                                    v => v.to_string(),
                                };

                                let combined = format!("{current_args}{chunk}");

                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
                                    match serde_json::from_str(&combined) {
                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
                                    }
                                } else {
                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
                                }

                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                    id: existing_tool_call.id.clone(),
                                    delta: chunk.clone(),
                                });
                            }
                        }
                    }

                    if let Some(reasoning) = &delta.reasoning && !reasoning.is_empty() {
                        yield Ok(streaming::RawStreamingChoice::Reasoning {
                            reasoning: reasoning.clone(),
                            id: None,
                            signature: None,
                        });
                    }

                    if let Some(content) = &delta.content && !content.is_empty() {
                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
                    }

                    if let Some(usage) = data.usage {
                        final_usage = Some(usage);
                    }

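                    // On a `tool_calls` finish reason, emit every accumulated call
                    // and reset the buffer for any later calls in the stream.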
                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
                        for (_idx, tool_call) in tool_calls.into_iter() {
                            yield Ok(streaming::RawStreamingChoice::ToolCall {
                                name: tool_call.function.name,
                                id: tool_call.id,
                                arguments: tool_call.function.arguments,
                                call_id: None,
                            });
                        }
                        tool_calls = HashMap::new();
                    }
                }
                Err(crate::http_client::Error::StreamEnded) => {
                    break;
                }
                Err(error) => {
                    tracing::error!(?error, "SSE error");
                    yield Err(CompletionError::ProviderError(error.to_string()));
                    break;
                }
            }
        }

        event_source.close();

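        // Flush any tool calls still buffered when the stream ended without an
        // explicit `tool_calls` finish reason.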
        for (_idx, tool_call) in tool_calls.into_iter() {
            yield Ok(streaming::RawStreamingChoice::ToolCall {
                name: tool_call.function.name,
                id: tool_call.id,
                arguments: tool_call.function.arguments,
                call_id: None,
            });
        }

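        // Emit the final response carrying whatever usage the provider reported.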
        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
            usage: final_usage.unwrap_or_default(),
        }));
    }.instrument(span);

    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
        stream,
    )))
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_streaming_completion_response_deserialization() {
        let json = json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert_eq!(response.id, "gen-abc123");
        assert_eq!(response.model, "gpt-3.5-turbo");
        assert_eq!(response.choices.len(), 1);
    }

    #[test]
    fn test_delta_with_content() {
        let json = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.role, Some("assistant".to_string()));
        assert_eq!(delta.content, Some("Hello, world!".to_string()));
    }

    #[test]
    fn test_delta_with_tool_call() {
        let json = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].index, 0);
        assert_eq!(delta.tool_calls[0].id, Some("call_abc".to_string()));
    }

    #[test]
    fn test_tool_call_with_partial_arguments() {
        let json = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let tool_call: StreamingToolCall = serde_json::from_value(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
    }

    #[test]
    fn test_streaming_with_usage() {
        let json = json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_multiple_tool_call_deltas() {
        let start_json = json!({
            "id": "gen-1",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta1_json = json!({
            "id": "gen-2",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"query\":"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta2_json = json!({
            "id": "gen-3",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "\"Rust programming\"}"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let start: StreamingCompletionChunk = serde_json::from_value(start_json).unwrap();
        assert_eq!(
            start.choices[0].delta.tool_calls[0].id,
            Some("call_123".to_string())
        );

        let delta1: StreamingCompletionChunk = serde_json::from_value(delta1_json).unwrap();
        assert_eq!(
            delta1.choices[0].delta.tool_calls[0].function.arguments,
            Some("{\"query\":".to_string())
        );

        let delta2: StreamingCompletionChunk = serde_json::from_value(delta2_json).unwrap();
        assert_eq!(
            delta2.choices[0].delta.tool_calls[0].function.arguments,
            Some("\"Rust programming\"}".to_string())
        );
    }

    #[test]
    fn test_response_with_error() {
        let json = json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.error.is_some());
        let error = response.error.as_ref().unwrap();
        assert_eq!(error.code, 500);
        assert_eq!(error.message, "Provider disconnected");
    }
}