use std::collections::HashMap;

use async_stream::stream;
use futures::StreamExt;
use http::Request;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::completion::GetTokenUsage;
use crate::completion::{CompletionError, CompletionRequest};
use crate::http_client::HttpClientExt;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::json_utils;
use crate::message::{ToolCall, ToolFunction};
use crate::providers::openrouter::OpenrouterCompletionRequest;
use crate::streaming;

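/// Final payload carried by the stream once it completes, reporting token usage.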
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.output_tokens = self.usage.completion_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;

        Some(usage)
    }
}

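/// Finish reason reported by OpenRouter for a streamed choice; values not listed
/// here fall back to `Other` via the untagged variant.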
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    ToolCalls,
    Stop,
    Error,
    ContentFilter,
    Length,
    #[serde(untagged)]
    Other(String),
}

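/// A single choice within a streamed chunk, holding the delta plus an optional
/// finish reason.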
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingChoice {
    pub finish_reason: Option<FinishReason>,
    pub native_finish_reason: Option<String>,
    pub logprobs: Option<Value>,
    pub index: usize,
    pub delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingFunction {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingToolCall {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: StreamingFunction,
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct ErrorResponse {
    pub code: i32,
    pub message: String,
}

#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingDelta {
    pub role: Option<String>,
    pub content: Option<String>,
    pub reasoning: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<StreamingToolCall>,
}

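/// One parsed `data:` payload from the OpenRouter SSE stream.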
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingCompletionChunk {
    id: String,
    model: String,
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
    error: Option<ErrorResponse>,
}

impl<T> super::CompletionModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
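    /// Builds the streaming request (forcing `"stream": true` in the additional
    /// params), wraps it in a `gen_ai` tracing span when no span is active, and
    /// delegates to [`send_compatible_streaming_request`].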
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let preamble = completion_request.preamble.clone();
        let mut request =
            OpenrouterCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true }),
        );

        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|x| CompletionError::HttpError(x.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "openrouter",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}

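/// Drives an OpenRouter-compatible SSE stream: text, reasoning, and tool-call
/// deltas are yielded as they arrive, partial tool calls are accumulated by index,
/// and a [`StreamingCompletionResponse`] with the final usage is emitted at the end.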
pub async fn send_compatible_streaming_request<T>(
    http_client: T,
    req: Request<Vec<u8>>,
) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();
    let mut event_source = GenericEventSource::new(http_client, req);

    let stream = stream! {
        // Partially received tool calls, keyed by their index in the response.
        let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
        let mut final_usage = None;

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    if message.data.trim().is_empty() || message.data == "[DONE]" {
                        continue;
                    }

                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
                        Ok(data) => data,
                        Err(error) => {
                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                            continue;
                        }
                    };

                    let Some(choice) = data.choices.first() else {
                        tracing::debug!("There is no choice");
                        continue;
                    };
                    let delta = &choice.delta;

                    if !delta.tool_calls.is_empty() {
                        for tool_call in &delta.tool_calls {
                            let index = tool_call.index;

                            // Create or update the partial tool call at this index.
                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
                                id: String::new(),
                                call_id: None,
                                function: ToolFunction {
                                    name: String::new(),
                                    arguments: serde_json::Value::Null,
                                },
                            });

                            if let Some(id) = &tool_call.id && !id.is_empty() {
                                existing_tool_call.id = id.clone();
                            }

                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
                                existing_tool_call.function.name = name.clone();
                            }

                            if let Some(chunk) = &tool_call.function.arguments {
                                // Append the new fragment to whatever arguments have accumulated so far.
                                let current_args = match &existing_tool_call.function.arguments {
                                    serde_json::Value::Null => String::new(),
                                    serde_json::Value::String(s) => s.clone(),
                                    v => v.to_string(),
                                };

                                let combined = format!("{current_args}{chunk}");

                                // Once the accumulated text looks like a complete JSON object,
                                // try to parse it; otherwise keep it as a raw string.
                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
                                    match serde_json::from_str(&combined) {
                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
                                    }
                                } else {
                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
                                }

                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                    id: existing_tool_call.id.clone(),
                                    delta: chunk.clone(),
                                });
                            }
                        }
                    }

                    if let Some(reasoning) = &delta.reasoning && !reasoning.is_empty() {
                        yield Ok(streaming::RawStreamingChoice::Reasoning {
                            reasoning: reasoning.clone(),
                            id: None,
                            signature: None,
                        });
                    }

                    if let Some(content) = &delta.content && !content.is_empty() {
                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
                    }

                    if let Some(usage) = data.usage {
                        final_usage = Some(usage);
                    }

                    // The provider has finished emitting tool calls; flush the accumulated ones.
                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
                        for (_idx, tool_call) in tool_calls.into_iter() {
                            yield Ok(streaming::RawStreamingChoice::ToolCall {
                                name: tool_call.function.name,
                                id: tool_call.id,
                                arguments: tool_call.function.arguments,
                                call_id: None,
                            });
                        }
                        tool_calls = HashMap::new();
                    }
                }
                Err(crate::http_client::Error::StreamEnded) => {
                    break;
                }
                Err(error) => {
                    tracing::error!(?error, "SSE error");
                    yield Err(CompletionError::ProviderError(error.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        // Flush any tool calls that were still pending when the stream ended.
        for (_idx, tool_call) in tool_calls.into_iter() {
            yield Ok(streaming::RawStreamingChoice::ToolCall {
                name: tool_call.function.name,
                id: tool_call.id,
                arguments: tool_call.function.arguments,
                call_id: None,
            });
        }

        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
            usage: final_usage.unwrap_or_default(),
        }));
    }.instrument(span);

    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
        stream,
    )))
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_streaming_completion_response_deserialization() {
        let json = json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert_eq!(response.id, "gen-abc123");
        assert_eq!(response.model, "gpt-3.5-turbo");
        assert_eq!(response.choices.len(), 1);
    }

    #[test]
    fn test_delta_with_content() {
        let json = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.role, Some("assistant".to_string()));
        assert_eq!(delta.content, Some("Hello, world!".to_string()));
    }

    #[test]
    fn test_delta_with_tool_call() {
        let json = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].index, 0);
        assert_eq!(delta.tool_calls[0].id, Some("call_abc".to_string()));
    }

    #[test]
    fn test_tool_call_with_partial_arguments() {
        let json = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let tool_call: StreamingToolCall = serde_json::from_value(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
    }

    #[test]
    fn test_streaming_with_usage() {
        let json = json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_multiple_tool_call_deltas() {
        let start_json = json!({
            "id": "gen-1",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta1_json = json!({
            "id": "gen-2",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"query\":"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let delta2_json = json!({
            "id": "gen-3",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "\"Rust programming\"}"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let start: StreamingCompletionChunk = serde_json::from_value(start_json).unwrap();
        assert_eq!(
            start.choices[0].delta.tool_calls[0].id,
            Some("call_123".to_string())
        );

        let delta1: StreamingCompletionChunk = serde_json::from_value(delta1_json).unwrap();
        assert_eq!(
            delta1.choices[0].delta.tool_calls[0].function.arguments,
            Some("{\"query\":".to_string())
        );

        let delta2: StreamingCompletionChunk = serde_json::from_value(delta2_json).unwrap();
        assert_eq!(
            delta2.choices[0].delta.tool_calls[0].function.arguments,
            Some("\"Rust programming\"}".to_string())
        );
    }

    #[test]
    fn test_response_with_error() {
        let json = json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.error.is_some());
        let error = response.error.as_ref().unwrap();
        assert_eq!(error.code, 500);
        assert_eq!(error.message, "Provider disconnected");
    }
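
    // Added illustrative test (not part of the original suite): checks the
    // `#[serde(untagged)]` fallback on `FinishReason`, assuming serde routes any
    // unrecognized string into `Other`.
    #[test]
    fn test_finish_reason_unknown_value_falls_back_to_other() {
        let known: FinishReason = serde_json::from_value(json!("stop")).unwrap();
        assert_eq!(known, FinishReason::Stop);

        let unknown: FinishReason =
            serde_json::from_value(json!("provider_specific_reason")).unwrap();
        assert_eq!(
            unknown,
            FinishReason::Other("provider_specific_reason".to_string())
        );
    }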
}