1use http::Request;
2use std::collections::HashMap;
3use tracing::info_span;
4
5use crate::{
6 completion::GetTokenUsage,
7 http_client::{self, HttpClientExt},
8 json_utils,
9 message::{ToolCall, ToolFunction},
10 streaming::{self},
11};
12use async_stream::stream;
13use futures::StreamExt;
14use serde_json::{Value, json};
15
16use crate::completion::{CompletionError, CompletionRequest};
17use serde::{Deserialize, Serialize};
18
19#[derive(Serialize, Deserialize, Debug)]
20pub struct StreamingCompletionResponse {
21 pub id: String,
22 pub choices: Vec<StreamingChoice>,
23 pub created: u64,
24 pub model: String,
25 pub object: String,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub system_fingerprint: Option<String>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub usage: Option<ResponseUsage>,
30}
31
32impl GetTokenUsage for FinalCompletionResponse {
33 fn token_usage(&self) -> Option<crate::completion::Usage> {
34 let mut usage = crate::completion::Usage::new();
35
36 usage.input_tokens = self.usage.prompt_tokens as u64;
37 usage.output_tokens = self.usage.completion_tokens as u64;
38 usage.total_tokens = self.usage.total_tokens as u64;
39
40 Some(usage)
41 }
42}
43
44#[derive(Serialize, Deserialize, Debug)]
45pub struct StreamingChoice {
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub finish_reason: Option<String>,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 pub native_finish_reason: Option<String>,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 pub logprobs: Option<Value>,
52 pub index: usize,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub message: Option<MessageResponse>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub delta: Option<DeltaResponse>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub error: Option<ErrorResponse>,
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62pub struct MessageResponse {
63 pub role: String,
64 pub content: String,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 pub refusal: Option<Value>,
67 #[serde(default)]
68 pub tool_calls: Vec<OpenRouterToolCall>,
69}
70
71#[derive(Serialize, Deserialize, Debug)]
72pub struct OpenRouterToolFunction {
73 pub name: Option<String>,
74 pub arguments: Option<String>,
75}
76
77#[derive(Serialize, Deserialize, Debug)]
78pub struct OpenRouterToolCall {
79 pub index: usize,
80 pub id: Option<String>,
81 pub r#type: Option<String>,
82 pub function: OpenRouterToolFunction,
83}
84
85#[derive(Serialize, Deserialize, Debug, Clone, Default)]
86pub struct ResponseUsage {
87 pub prompt_tokens: u32,
88 pub completion_tokens: u32,
89 pub total_tokens: u32,
90}
91
92#[derive(Serialize, Deserialize, Debug)]
93pub struct ErrorResponse {
94 pub code: i32,
95 pub message: String,
96 #[serde(skip_serializing_if = "Option::is_none")]
97 pub metadata: Option<HashMap<String, Value>>,
98}
99
100#[derive(Serialize, Deserialize, Debug)]
101pub struct DeltaResponse {
102 pub role: Option<String>,
103 #[serde(skip_serializing_if = "Option::is_none")]
104 pub content: Option<String>,
105 #[serde(default)]
106 pub tool_calls: Vec<OpenRouterToolCall>,
107 #[serde(skip_serializing_if = "Option::is_none")]
108 pub native_finish_reason: Option<String>,
109}
110
111#[derive(Clone, Deserialize, Serialize)]
112pub struct FinalCompletionResponse {
113 pub usage: ResponseUsage,
114}
115
116impl<T> super::CompletionModel<T>
117where
118 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
119{
120 pub(crate) async fn stream(
121 &self,
122 completion_request: CompletionRequest,
123 ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
124 {
125 let preamble = completion_request.preamble.clone();
126 let request = self.create_completion_request(completion_request)?;
127
128 let request = json_utils::merge(request, json!({"stream": true}));
129
130 let body = serde_json::to_vec(&request)?;
131
132 let req = self
133 .client
134 .post("/chat/completions")?
135 .header("Content-Type", "application/json")
136 .body(body)
137 .map_err(|x| CompletionError::HttpError(x.into()))?;
138
139 let span = if tracing::Span::current().is_disabled() {
140 info_span!(
141 target: "rig::completions",
142 "chat_streaming",
143 gen_ai.operation.name = "chat_streaming",
144 gen_ai.provider.name = "openrouter",
145 gen_ai.request.model = self.model,
146 gen_ai.system_instructions = preamble,
147 gen_ai.response.id = tracing::field::Empty,
148 gen_ai.response.model = tracing::field::Empty,
149 gen_ai.usage.output_tokens = tracing::field::Empty,
150 gen_ai.usage.input_tokens = tracing::field::Empty,
151 gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
152 gen_ai.output.messages = tracing::field::Empty,
153 )
154 } else {
155 tracing::Span::current()
156 };
157
158 tracing::Instrument::instrument(
159 send_streaming_request(self.client.http_client.clone(), req),
160 span,
161 )
162 .await
163 }
164}
165
166pub async fn send_streaming_request<T>(
167 client: T,
168 req: Request<Vec<u8>>,
169) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
170where
171 T: HttpClientExt + Clone + 'static,
172{
173 let response = client.send_streaming(req).await?;
174 let status = response.status();
175
176 if !status.is_success() {
177 return Err(CompletionError::ProviderError(format!(
178 "Got response error trying to send a completion request to OpenRouter: {status}"
179 )));
180 }
181
182 let mut stream = response.into_body();
183
184 let stream = stream! {
186 let mut tool_calls = HashMap::new();
187 let mut partial_line = String::new();
188 let mut final_usage = None;
189
190 while let Some(chunk_result) = stream.next().await {
191 let chunk = match chunk_result {
192 Ok(c) => c,
193 Err(e) => {
194 yield Err(CompletionError::from(http_client::Error::Instance(e.into())));
195 break;
196 }
197 };
198
199 let text = match String::from_utf8(chunk.to_vec()) {
200 Ok(t) => t,
201 Err(e) => {
202 yield Err(CompletionError::ResponseError(e.to_string()));
203 break;
204 }
205 };
206
207 for line in text.lines() {
208 let mut line = line.to_string();
209
210 if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
212 continue;
213 }
214
215 line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
217
218 if line.starts_with('{') && !line.ends_with('}') {
220 partial_line = line;
221 continue;
222 }
223
224 if !partial_line.is_empty() {
226 if line.ends_with('}') {
227 partial_line.push_str(&line);
228 line = partial_line;
229 partial_line = String::new();
230 } else {
231 partial_line.push_str(&line);
232 continue;
233 }
234 }
235
236 let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
237 Ok(data) => data,
238 Err(_) => {
239 continue;
240 }
241 };
242
243
244 let choice = data.choices.first().expect("Should have at least one choice");
245
246 if let Some(delta) = &choice.delta {
256 if !delta.tool_calls.is_empty() {
257 for tool_call in &delta.tool_calls {
258 let index = tool_call.index;
259
260 let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
262 id: String::new(),
263 call_id: None,
264 function: ToolFunction {
265 name: String::new(),
266 arguments: serde_json::Value::Null,
267 },
268 });
269
270 if let Some(id) = &tool_call.id && !id.is_empty() {
272 existing_tool_call.id = id.clone();
273 }
274
275 if let Some(name) = &tool_call.function.name && !name.is_empty() {
276 existing_tool_call.function.name = name.clone();
277 }
278
279 if let Some(chunk) = &tool_call.function.arguments {
280 let current_args = match &existing_tool_call.function.arguments {
282 serde_json::Value::Null => String::new(),
283 serde_json::Value::String(s) => s.clone(),
284 v => v.to_string(),
285 };
286
287 let combined = format!("{current_args}{chunk}");
289
290 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
292 match serde_json::from_str(&combined) {
293 Ok(parsed) => existing_tool_call.function.arguments = parsed,
294 Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
295 }
296 } else {
297 existing_tool_call.function.arguments = serde_json::Value::String(combined);
298 }
299
300 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
302 id: existing_tool_call.id.clone(),
303 delta: chunk.clone(),
304 });
305 }
306 }
307 }
308
309 if let Some(content) = &delta.content &&!content.is_empty() {
310 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
311 }
312
313 if let Some(usage) = data.usage {
314 final_usage = Some(usage);
315 }
316 }
317
318 if let Some(message) = &choice.message {
320 if !message.tool_calls.is_empty() {
321 for tool_call in &message.tool_calls {
322 let name = tool_call.function.name.clone();
323 let id = tool_call.id.clone();
324 let arguments = if let Some(args) = &tool_call.function.arguments {
325 match serde_json::from_str(args) {
327 Ok(v) => v,
328 Err(_) => serde_json::Value::String(args.to_string()),
329 }
330 } else {
331 serde_json::Value::Null
332 };
333 let index = tool_call.index;
334
335 tool_calls.insert(index, ToolCall {
336 id: id.unwrap_or_default(),
337 call_id: None,
338 function: ToolFunction {
339 name: name.unwrap_or_default(),
340 arguments,
341 },
342 });
343 }
344 }
345
346 if !message.content.is_empty() {
347 yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
348 }
349 }
350 }
351 }
352
353 for (_, tool_call) in tool_calls.into_iter() {
354
355 yield Ok(streaming::RawStreamingChoice::ToolCall{
356 name: tool_call.function.name,
357 id: tool_call.id,
358 arguments: tool_call.function.arguments,
359 call_id: None
360 });
361 }
362
363 yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
364 usage: final_usage.unwrap_or_default()
365 }))
366
367 };
368
369 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
370 stream,
371 )))
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wraps one `choice` object in the envelope shared by all streamed chunks.
    fn chunk_json(id: &str, model: &str, choice: serde_json::Value) -> serde_json::Value {
        json!({
            "id": id,
            "choices": [choice],
            "created": 1234567890u64,
            "model": model,
            "object": "chat.completion.chunk"
        })
    }

    /// Deserializes a chunk payload, panicking if it does not match the schema.
    fn parse_chunk(payload: serde_json::Value) -> StreamingCompletionResponse {
        serde_json::from_value(payload).unwrap()
    }

    /// Builds and parses a `gpt-4` chunk whose delta holds exactly one tool call.
    fn tool_call_chunk(id: &str, tool_call: serde_json::Value) -> StreamingCompletionResponse {
        parse_chunk(chunk_json(
            id,
            "gpt-4",
            json!({ "index": 0, "delta": { "tool_calls": [tool_call] } }),
        ))
    }

    /// Returns the first tool call of the first choice's delta.
    fn first_tool_call(chunk: &StreamingCompletionResponse) -> &OpenRouterToolCall {
        &chunk.choices[0].delta.as_ref().unwrap().tool_calls[0]
    }

    #[test]
    fn test_streaming_completion_response_deserialization() {
        let response = parse_chunk(chunk_json(
            "gen-abc123",
            "gpt-3.5-turbo",
            json!({
                "index": 0,
                "delta": { "role": "assistant", "content": "Hello" }
            }),
        ));

        assert_eq!(response.id, "gen-abc123");
        assert_eq!(response.model, "gpt-3.5-turbo");
        assert_eq!(response.choices.len(), 1);
    }

    #[test]
    fn test_delta_with_content() {
        let payload = json!({ "role": "assistant", "content": "Hello, world!" });

        let delta: DeltaResponse = serde_json::from_value(payload).unwrap();

        assert_eq!(delta.role, Some("assistant".to_string()));
        assert_eq!(delta.content, Some("Hello, world!".to_string()));
    }

    #[test]
    fn test_delta_with_tool_call() {
        let payload = json!({
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_abc",
                "type": "function",
                "function": { "name": "get_weather", "arguments": "{\"location\":" }
            }]
        });

        let delta: DeltaResponse = serde_json::from_value(payload).unwrap();

        assert_eq!(delta.tool_calls.len(), 1);
        let call = &delta.tool_calls[0];
        assert_eq!(call.index, 0);
        assert_eq!(call.id, Some("call_abc".to_string()));
    }

    #[test]
    fn test_tool_call_with_partial_arguments() {
        let payload = json!({
            "index": 0,
            "id": null,
            "type": null,
            "function": { "name": null, "arguments": "Paris" }
        });

        let tool_call: OpenRouterToolCall = serde_json::from_value(payload).unwrap();

        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert_eq!(tool_call.function.arguments, Some("Paris".to_string()));
    }

    #[test]
    fn test_streaming_with_usage() {
        let mut payload = chunk_json(
            "gen-xyz",
            "gpt-4",
            json!({ "index": 0, "delta": { "content": null } }),
        );
        payload["usage"] = json!({
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150
        });

        let response = parse_chunk(payload);

        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_multiple_tool_call_deltas() {
        // First chunk announces the call; the next two stream its arguments.
        let start = tool_call_chunk(
            "gen-1",
            json!({
                "index": 0,
                "id": "call_123",
                "type": "function",
                "function": { "name": "search", "arguments": "" }
            }),
        );
        assert_eq!(first_tool_call(&start).id, Some("call_123".to_string()));

        let delta1 = tool_call_chunk(
            "gen-2",
            json!({ "index": 0, "function": { "arguments": "{\"query\":" } }),
        );
        assert_eq!(
            first_tool_call(&delta1).function.arguments,
            Some("{\"query\":".to_string())
        );

        let delta2 = tool_call_chunk(
            "gen-3",
            json!({ "index": 0, "function": { "arguments": "\"Rust programming\"}" } }),
        );
        assert_eq!(
            first_tool_call(&delta2).function.arguments,
            Some("\"Rust programming\"}".to_string())
        );
    }

    #[test]
    fn test_response_with_error() {
        let response = parse_chunk(chunk_json(
            "gen-err",
            "gpt-4",
            json!({
                "index": 0,
                "error": { "code": 400, "message": "Invalid request" }
            }),
        ));

        assert!(response.choices[0].error.is_some());
        let error = response.choices[0].error.as_ref().unwrap();
        assert_eq!(error.code, 400);
        assert_eq!(error.message, "Invalid request");
    }
}