1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::info_span;
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils;
15use crate::providers::openrouter::{
16 OpenRouterRequestParams, OpenrouterCompletionRequest, ReasoningDetails,
17};
18use crate::streaming;
19
/// Final payload of a streamed completion.
///
/// It is emitted as the last item of the stream and carries the token usage collected
/// while reading the chunks (all-zero defaults when no chunk reported usage).
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token counts reported for the request.
    pub usage: Usage,
}
24
25impl GetTokenUsage for StreamingCompletionResponse {
26 fn token_usage(&self) -> Option<crate::completion::Usage> {
27 let mut usage = crate::completion::Usage::new();
28
29 usage.input_tokens = self.usage.prompt_tokens as u64;
30 usage.output_tokens = self.usage.completion_tokens as u64;
31 usage.total_tokens = self.usage.total_tokens as u64;
32
33 Some(usage)
34 }
35}
36
/// Reason a streamed choice stopped, deserialized from its `snake_case` string form
/// (e.g. `"tool_calls"`, `"content_filter"`).
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// `"tool_calls"`: the stream loop emits the accumulated tool calls when it sees this.
    ToolCalls,
    /// `"stop"`.
    Stop,
    /// `"error"`.
    Error,
    /// `"content_filter"`.
    ContentFilter,
    /// `"length"`.
    Length,
    /// Fallback for any string that matches none of the named variants; the raw value is kept.
    #[serde(untagged)]
    Other(String),
}
48
/// One entry of a chunk's `choices` array.
///
/// Only `finish_reason` and `delta` are read by the stream loop in this file; the remaining
/// fields are deserialized but unused here (hence `dead_code` is allowed).
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingChoice {
    /// Set when the choice has finished; `None` while it is still streaming.
    pub finish_reason: Option<FinishReason>,
    /// Raw finish-reason string (presumably as reported by the upstream provider; confirm
    /// against the OpenRouter API reference).
    pub native_finish_reason: Option<String>,
    /// Log-probability payload, kept as untyped JSON.
    pub logprobs: Option<Value>,
    /// Position of this choice within the response.
    pub index: usize,
    /// Incremental content carried by this chunk.
    pub delta: StreamingDelta,
}
58
/// Function part of a streamed tool-call fragment; both fields are optional.
#[derive(Deserialize, Debug)]
struct StreamingFunction {
    /// Function name, when the fragment carries it.
    pub name: Option<String>,
    /// Fragment of the JSON-encoded arguments string; fragments are concatenated by the
    /// stream loop.
    pub arguments: Option<String>,
}
64
/// One tool-call fragment inside a streamed delta.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingToolCall {
    /// Slot of the tool call; the stream loop uses it as the key for merging fragments.
    pub index: usize,
    /// Tool-call identifier, when the fragment carries it.
    pub id: Option<String>,
    /// Tool-call type string (the tests in this file use `"function"`).
    pub r#type: Option<String>,
    /// Name and/or argument fragment of the called function.
    pub function: StreamingFunction,
}
73
/// Token counters as reported in a chunk's `usage` object.
///
/// `Default` yields all zeros, which is what the final response carries when no chunk
/// reported usage.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced by the completion.
    pub completion_tokens: u32,
    /// Total reported by the provider (the tests in this file use prompt + completion).
    pub total_tokens: u32,
}
80
/// Error object deserialized from a chunk's `error` field.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct ErrorResponse {
    /// Numeric error code (the test in this file uses `500`).
    pub code: i32,
    /// Human-readable error description.
    pub message: String,
}
87
/// Incremental payload of a streamed choice. Every field may be absent from a given chunk.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingDelta {
    /// Role string, when the chunk includes one.
    pub role: Option<String>,
    /// Text fragment of the assistant message.
    pub content: Option<String>,
    /// Tool-call fragments; empty when the field is absent (serde `default`). Decoding goes
    /// through `json_utils::null_or_vec` (presumably mapping an explicit `null` to an empty
    /// vector; the helper is not visible from this file).
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<StreamingToolCall>,
    /// Reasoning text fragment.
    pub reasoning: Option<String>,
    /// Structured reasoning entries; same absent/`null` handling as `tool_calls`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub reasoning_details: Vec<ReasoningDetails>,
}
99
/// One SSE `data:` payload of a streamed chat completion.
///
/// `id`, `model` and `choices` are required; a payload missing any of them fails to
/// deserialize. Unknown keys (e.g. `created`, `object`) are ignored.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct StreamingCompletionChunk {
    /// Generation identifier.
    id: String,
    /// Model name reported in the chunk.
    model: String,
    /// Choices carried by the chunk; the stream loop only reads the first one.
    choices: Vec<StreamingChoice>,
    /// Token usage, when the chunk reports it.
    usage: Option<Usage>,
    /// Error object, when the chunk carries one.
    error: Option<ErrorResponse>,
}
109
110impl<T> super::CompletionModel<T>
111where
112 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
113{
114 pub(crate) async fn stream(
115 &self,
116 completion_request: CompletionRequest,
117 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
118 {
119 let request_model = completion_request
120 .model
121 .clone()
122 .unwrap_or_else(|| self.model.clone());
123 let preamble = completion_request.preamble.clone();
124 let mut request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
125 model: request_model.as_ref(),
126 request: completion_request,
127 strict_tools: self.strict_tools,
128 })?;
129
130 let params = json_utils::merge(
131 request.additional_params.unwrap_or(serde_json::json!({})),
132 serde_json::json!({"stream": true }),
133 );
134
135 request.additional_params = Some(params);
136
137 let body = serde_json::to_vec(&request)?;
138
139 let req = self
140 .client
141 .post("/chat/completions")?
142 .body(body)
143 .map_err(|x| CompletionError::HttpError(x.into()))?;
144
145 let span = if tracing::Span::current().is_disabled() {
146 info_span!(
147 target: "rig::completions",
148 "chat_streaming",
149 gen_ai.operation.name = "chat_streaming",
150 gen_ai.provider.name = "openrouter",
151 gen_ai.request.model = &request_model,
152 gen_ai.system_instructions = preamble,
153 gen_ai.response.id = tracing::field::Empty,
154 gen_ai.response.model = tracing::field::Empty,
155 gen_ai.usage.output_tokens = tracing::field::Empty,
156 gen_ai.usage.input_tokens = tracing::field::Empty,
157 gen_ai.usage.cached_tokens = tracing::field::Empty,
158 )
159 } else {
160 tracing::Span::current()
161 };
162
163 tracing::Instrument::instrument(
164 send_compatible_streaming_request(self.client.clone(), req),
165 span,
166 )
167 .await
168 }
169}
170
171pub async fn send_compatible_streaming_request<T>(
172 http_client: T,
173 req: Request<Vec<u8>>,
174) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
175where
176 T: HttpClientExt + Clone + 'static,
177{
178 let span = tracing::Span::current();
179 let mut event_source = GenericEventSource::new(http_client, req);
181
182 let stream = stream! {
183 let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
185 let mut final_usage = None;
186
187 while let Some(event_result) = event_source.next().await {
188 match event_result {
189 Ok(Event::Open) => {
190 tracing::trace!("SSE connection opened");
191 continue;
192 }
193
194 Ok(Event::Message(message)) => {
195 if message.data.trim().is_empty() || message.data == "[DONE]" {
196 continue;
197 }
198
199 let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
200 Ok(data) => data,
201 Err(error) => {
202 tracing::error!(?error, message = message.data, "Failed to parse SSE message");
203 continue;
204 }
205 };
206
207 let Some(choice) = data.choices.first() else {
209 tracing::debug!("There is no choice");
210 continue;
211 };
212 let delta = &choice.delta;
213
214 if !delta.tool_calls.is_empty() {
215 for tool_call in &delta.tool_calls {
216 let index = tool_call.index;
217
218 let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);
220
221 if let Some(id) = &tool_call.id && !id.is_empty() {
223 existing_tool_call.id = id.clone();
224 }
225
226 if let Some(name) = &tool_call.function.name && !name.is_empty() {
227 existing_tool_call.name = name.clone();
228 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
229 id: existing_tool_call.id.clone(),
230 internal_call_id: existing_tool_call.internal_call_id.clone(),
231 content: streaming::ToolCallDeltaContent::Name(name.clone()),
232 });
233 }
234
235 if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
237 let current_args = match &existing_tool_call.arguments {
238 serde_json::Value::Null => String::new(),
239 serde_json::Value::String(s) => s.clone(),
240 v => v.to_string(),
241 };
242
243 let combined = format!("{current_args}{chunk}");
245
246 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
248 match serde_json::from_str(&combined) {
249 Ok(parsed) => existing_tool_call.arguments = parsed,
250 Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
251 }
252 } else {
253 existing_tool_call.arguments = serde_json::Value::String(combined);
254 }
255
256 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
258 id: existing_tool_call.id.clone(),
259 internal_call_id: existing_tool_call.internal_call_id.clone(),
260 content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
261 });
262 }
263 }
264
265 for reasoning_detail in &delta.reasoning_details {
267 if let ReasoningDetails::Encrypted { id, data, .. } = reasoning_detail
268 && let Some(id) = id
269 && let Some(tool_call) = tool_calls.values_mut().find(|tool_call| tool_call.id.eq(id))
270 && let Ok(additional_params) = serde_json::to_value(reasoning_detail) {
271 tool_call.signature = Some(data.clone());
272 tool_call.additional_params = Some(additional_params);
273 }
274 }
275 }
276
277 if let Some(reasoning) = &delta.reasoning && !reasoning.is_empty() {
279 yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
280 reasoning: reasoning.clone(),
281 id: None,
282 });
283 }
284
285 if let Some(content) = &delta.content && !content.is_empty() {
287 yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
288 }
289
290 if let Some(usage) = data.usage {
292 final_usage = Some(usage);
293 }
294
295 if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
297 for (_idx, tool_call) in tool_calls.into_iter() {
298 yield Ok(streaming::RawStreamingChoice::ToolCall(
299 finalize_completed_streaming_tool_call(tool_call),
300 ));
301 }
302 tool_calls = HashMap::new();
303 }
304 }
305 Err(crate::http_client::Error::StreamEnded) => {
306 break;
307 }
308 Err(error) => {
309 tracing::error!(?error, "SSE error");
310 yield Err(CompletionError::ProviderError(error.to_string()));
311 break;
312 }
313 }
314 }
315
316 event_source.close();
318
319 for (_idx, tool_call) in tool_calls.into_iter() {
321 yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
322 }
323
324 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
326 usage: final_usage.unwrap_or_default(),
327 }));
328 }.instrument(span);
329
330 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
331 stream,
332 )))
333}
334
335fn finalize_completed_streaming_tool_call(
336 mut tool_call: streaming::RawStreamingToolCall,
337) -> streaming::RawStreamingToolCall {
338 if tool_call.arguments.is_null() {
339 tool_call.arguments = Value::Object(serde_json::Map::new());
340 }
341
342 tool_call
343}
344
// Deserialization tests for the streaming wire types defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // A minimal content chunk deserializes; keys without a matching field are ignored.
    #[test]
    fn test_streaming_completion_response_deserialization() {
        let json = json!({
            "id": "gen-abc123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-3.5-turbo",
            "object": "chat.completion.chunk"
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert_eq!(response.id, "gen-abc123");
        assert_eq!(response.model, "gpt-3.5-turbo");
        assert_eq!(response.choices.len(), 1);
    }

    // A delta with only `role` and `content` deserializes; the vector fields are absent.
    #[test]
    fn test_delta_with_content() {
        let json = json!({
            "role": "assistant",
            "content": "Hello, world!"
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.role, Some("assistant".to_string()));
        assert_eq!(delta.content, Some("Hello, world!".to_string()));
    }

    // A delta carrying the first fragment of a tool call (id, name, partial arguments).
    #[test]
    fn test_delta_with_tool_call() {
        let json = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\":"
                }
            }]
        });

        let delta: StreamingDelta = serde_json::from_value(json).unwrap();
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].index, 0);
        assert_eq!(delta.tool_calls[0].id, Some("call_abc".to_string()));
    }

    // A continuation fragment: explicit `null` id/type/name and an arguments fragment.
    #[test]
    fn test_tool_call_with_partial_arguments() {
        let json = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        });

        let tool_call: StreamingToolCall = serde_json::from_value(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
    }

    // A chunk that carries a `usage` object exposes all three counters.
    #[test]
    fn test_streaming_with_usage() {
        let json = json!({
            "id": "gen-xyz",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": null
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    // Three consecutive chunks of one tool call: the opening fragment with id and name,
    // followed by two argument fragments that omit id, type and name.
    #[test]
    fn test_multiple_tool_call_deltas() {
        // Opening fragment: id and name present, arguments still empty.
        let start_json = json!({
            "id": "gen-1",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        // First half of the arguments JSON.
        let delta1_json = json!({
            "id": "gen-2",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"query\":"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        // Second half of the arguments JSON.
        let delta2_json = json!({
            "id": "gen-3",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "\"Rust programming\"}"
                        }
                    }]
                }
            }],
            "created": 1234567890u64,
            "model": "gpt-4",
            "object": "chat.completion.chunk"
        });

        let start: StreamingCompletionChunk = serde_json::from_value(start_json).unwrap();
        assert_eq!(
            start.choices[0].delta.tool_calls[0].id,
            Some("call_123".to_string())
        );

        let delta1: StreamingCompletionChunk = serde_json::from_value(delta1_json).unwrap();
        assert_eq!(
            delta1.choices[0].delta.tool_calls[0].function.arguments,
            Some("{\"query\":".to_string())
        );

        let delta2: StreamingCompletionChunk = serde_json::from_value(delta2_json).unwrap();
        assert_eq!(
            delta2.choices[0].delta.tool_calls[0].function.arguments,
            Some("\"Rust programming\"}".to_string())
        );
    }

    // A chunk with an `error` object and an `"error"` finish reason deserializes, and the
    // error code and message are available.
    #[test]
    fn test_response_with_error() {
        let json = json!({
            "id": "cmpl-abc123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-3.5-turbo",
            "provider": "openai",
            "error": { "code": 500, "message": "Provider disconnected" },
            "choices": [
                { "index": 0, "delta": { "content": "" }, "finish_reason": "error" }
            ]
        });

        let response: StreamingCompletionChunk = serde_json::from_value(json).unwrap();
        assert!(response.error.is_some());
        let error = response.error.as_ref().unwrap();
        assert_eq!(error.code, 500);
        assert_eq!(error.message, "Provider disconnected");
    }
}
551}