rig/providers/openrouter/
streaming.rs1use reqwest_eventsource::{Event, RequestBuilderExt};
2use std::collections::HashMap;
3use tracing::info_span;
4
5use crate::{
6 completion::GetTokenUsage,
7 http_client, json_utils,
8 message::{ToolCall, ToolFunction},
9 streaming::{self},
10};
11use async_stream::stream;
12use futures::StreamExt;
13use reqwest::RequestBuilder;
14use serde_json::{Value, json};
15
16use crate::completion::{CompletionError, CompletionRequest};
17use serde::{Deserialize, Serialize};
18
/// One chunk of an OpenRouter streaming chat-completion response, decoded from
/// a single SSE payload by the stream handlers in this file.
///
/// Unknown JSON fields are ignored by serde; a payload that does not match this
/// shape is skipped by the handlers rather than surfaced as an error.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Response identifier reported by the provider.
    pub id: String,
    /// Choices carried by this chunk; the handlers in this file only read the first one.
    pub choices: Vec<StreamingChoice>,
    /// Creation timestamp (presumably Unix seconds — TODO confirm against OpenRouter docs).
    pub created: u64,
    /// Model name reported for this chunk.
    pub model: String,
    /// Object type tag sent by the provider.
    pub object: String,
    /// Optional backend fingerprint; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Token usage, when the provider attaches it to this chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<ResponseUsage>,
}
31
32impl GetTokenUsage for FinalCompletionResponse {
33 fn token_usage(&self) -> Option<crate::completion::Usage> {
34 let mut usage = crate::completion::Usage::new();
35
36 usage.input_tokens = self.usage.prompt_tokens as u64;
37 usage.output_tokens = self.usage.completion_tokens as u64;
38 usage.total_tokens = self.usage.total_tokens as u64;
39
40 Some(usage)
41 }
42}
43
/// One entry of a chunk's `choices` array.
///
/// A choice may carry incremental data in `delta`, a complete `message`, or
/// both; the stream handlers in this file inspect both, and only for the first
/// choice of each chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamingChoice {
    /// Reason generation stopped, if reported on this chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Provider-native finish reason (name suggests the upstream model's own
    /// value — confirm against OpenRouter docs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub native_finish_reason: Option<String>,
    /// Log-probability data, kept as untyped JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<Value>,
    /// Position of this choice in the `choices` array.
    pub index: usize,
    /// Complete message, if the provider sent one for this choice.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<MessageResponse>,
    /// Incremental text and tool-call fragments for this chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delta: Option<DeltaResponse>,
    /// Error attached to this choice.
    /// NOTE(review): not inspected by the stream handlers in this file, so such
    /// errors are currently ignored.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ErrorResponse>,
}
60
/// A complete (non-incremental) message inside a streaming choice.
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageResponse {
    /// Author role of the message.
    pub role: String,
    /// Message text.
    /// NOTE(review): declared as a non-optional `String`, so a payload with
    /// `"content": null` fails to deserialize and the stream handlers skip the
    /// whole chunk (including its tool calls) — confirm whether OpenRouter can
    /// send `null` here.
    pub content: String,
    /// Refusal payload, kept as untyped JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<Value>,
    /// Complete tool calls; empty when the field is absent.
    #[serde(default)]
    pub tool_calls: Vec<OpenRouterToolCall>,
}
70
/// Function part of a tool call.
///
/// Both fields are optional because a streamed delta may carry only a fragment.
/// When merging deltas, the stream handlers in this file keep the last
/// non-empty `name` and concatenate `arguments` fragments in arrival order.
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolFunction {
    /// Function name, or `None` on continuation fragments.
    pub name: Option<String>,
    /// Raw JSON argument text (complete, or a fragment when streamed).
    pub arguments: Option<String>,
}
76
/// A tool call, either complete (inside a `message`) or a streamed fragment
/// (inside a `delta`).
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenRouterToolCall {
    /// Key the stream handlers use to merge fragments of the same call.
    /// NOTE(review): this field is required, so a tool-call object without
    /// `index` makes its whole chunk fail to deserialize — confirm that
    /// OpenRouter always sends it, including in `message.tool_calls`.
    pub index: usize,
    /// Call identifier; `None` and empty values are ignored when merging fragments.
    pub id: Option<String>,
    /// Tool type tag (`r#type` because `type` is a keyword); not read by the
    /// handlers in this file.
    pub r#type: Option<String>,
    /// Function name and (possibly partial) arguments.
    pub function: OpenRouterToolFunction,
}
84
/// Token usage reported by OpenRouter.
///
/// `Default` (all zeros) stands in when a stream finishes without reporting usage.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ResponseUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total as reported by the provider; not recomputed locally.
    pub total_tokens: u32,
}
91
/// Error object that can be attached to a streaming choice.
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
    /// Numeric error code from the provider (presumably HTTP-style — confirm).
    pub code: i32,
    /// Human-readable error message.
    pub message: String,
    /// Additional provider-specific details, kept as untyped JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, Value>>,
}
99
/// Incremental data carried by one streaming chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct DeltaResponse {
    /// Author role, if present. NOTE(review): not read by the handlers in this file.
    pub role: Option<String>,
    /// Next text fragment; the handlers skip `None` and empty strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool-call fragments, merged by their `index`; empty when the field is absent.
    #[serde(default)]
    pub tool_calls: Vec<OpenRouterToolCall>,
    /// Provider-native finish reason, if included in the delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub native_finish_reason: Option<String>,
}
110
111#[derive(Clone, Deserialize, Serialize)]
112pub struct FinalCompletionResponse {
113 pub usage: ResponseUsage,
114}
115
116impl super::CompletionModel<reqwest::Client> {
117 pub(crate) async fn stream(
118 &self,
119 completion_request: CompletionRequest,
120 ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
121 {
122 let preamble = completion_request.preamble.clone();
123 let request = self.create_completion_request(completion_request)?;
124
125 let request = json_utils::merge(request, json!({"stream": true}));
126
127 let builder = self
128 .client
129 .reqwest_post("/chat/completions")
130 .header("Content-Type", "application/json")
131 .json(&request);
132
133 let span = if tracing::Span::current().is_disabled() {
134 info_span!(
135 target: "rig::completions",
136 "chat_streaming",
137 gen_ai.operation.name = "chat_streaming",
138 gen_ai.provider.name = "openrouter",
139 gen_ai.request.model = self.model,
140 gen_ai.system_instructions = preamble,
141 gen_ai.response.id = tracing::field::Empty,
142 gen_ai.response.model = tracing::field::Empty,
143 gen_ai.usage.output_tokens = tracing::field::Empty,
144 gen_ai.usage.input_tokens = tracing::field::Empty,
145 gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
146 gen_ai.output.messages = tracing::field::Empty,
147 )
148 } else {
149 tracing::Span::current()
150 };
151
152 tracing::Instrument::instrument(send_streaming_request(builder), span).await
153 }
154}
155
156pub async fn send_streaming_request(
157 request_builder: RequestBuilder,
158) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
159 let response = request_builder
160 .send()
161 .await
162 .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?;
163
164 if !response.status().is_success() {
165 return Err(CompletionError::ProviderError(format!(
166 "{}: {}",
167 response.status(),
168 response
169 .text()
170 .await
171 .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?
172 )));
173 }
174
175 let stream = stream! {
177 let mut stream = response.bytes_stream();
178 let mut tool_calls = HashMap::new();
179 let mut partial_line = String::new();
180 let mut final_usage = None;
181
182 while let Some(chunk_result) = stream.next().await {
183 let chunk = match chunk_result {
184 Ok(c) => c,
185 Err(e) => {
186 yield Err(CompletionError::from(http_client::Error::Instance(e.into())));
187 break;
188 }
189 };
190
191 let text = match String::from_utf8(chunk.to_vec()) {
192 Ok(t) => t,
193 Err(e) => {
194 yield Err(CompletionError::ResponseError(e.to_string()));
195 break;
196 }
197 };
198
199 for line in text.lines() {
200 let mut line = line.to_string();
201
202 if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
204 continue;
205 }
206
207 line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
209
210 if line.starts_with('{') && !line.ends_with('}') {
212 partial_line = line;
213 continue;
214 }
215
216 if !partial_line.is_empty() {
218 if line.ends_with('}') {
219 partial_line.push_str(&line);
220 line = partial_line;
221 partial_line = String::new();
222 } else {
223 partial_line.push_str(&line);
224 continue;
225 }
226 }
227
228 let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
229 Ok(data) => data,
230 Err(_) => {
231 continue;
232 }
233 };
234
235
236 let choice = data.choices.first().expect("Should have at least one choice");
237
238 if let Some(delta) = &choice.delta {
248 if !delta.tool_calls.is_empty() {
249 for tool_call in &delta.tool_calls {
250 let index = tool_call.index;
251
252 let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
254 id: String::new(),
255 call_id: None,
256 function: ToolFunction {
257 name: String::new(),
258 arguments: serde_json::Value::Null,
259 },
260 });
261
262 if let Some(id) = &tool_call.id && !id.is_empty() {
264 existing_tool_call.id = id.clone();
265 }
266
267 if let Some(name) = &tool_call.function.name && !name.is_empty() {
268 existing_tool_call.function.name = name.clone();
269 }
270
271 if let Some(chunk) = &tool_call.function.arguments {
272 let current_args = match &existing_tool_call.function.arguments {
274 serde_json::Value::Null => String::new(),
275 serde_json::Value::String(s) => s.clone(),
276 v => v.to_string(),
277 };
278
279 let combined = format!("{current_args}{chunk}");
281
282 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
284 match serde_json::from_str(&combined) {
285 Ok(parsed) => existing_tool_call.function.arguments = parsed,
286 Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
287 }
288 } else {
289 existing_tool_call.function.arguments = serde_json::Value::String(combined);
290 }
291 }
292 }
293 }
294
295 if let Some(content) = &delta.content &&!content.is_empty() {
296 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
297 }
298
299 if let Some(usage) = data.usage {
300 final_usage = Some(usage);
301 }
302 }
303
304 if let Some(message) = &choice.message {
306 if !message.tool_calls.is_empty() {
307 for tool_call in &message.tool_calls {
308 let name = tool_call.function.name.clone();
309 let id = tool_call.id.clone();
310 let arguments = if let Some(args) = &tool_call.function.arguments {
311 match serde_json::from_str(args) {
313 Ok(v) => v,
314 Err(_) => serde_json::Value::String(args.to_string()),
315 }
316 } else {
317 serde_json::Value::Null
318 };
319 let index = tool_call.index;
320
321 tool_calls.insert(index, ToolCall {
322 id: id.unwrap_or_default(),
323 call_id: None,
324 function: ToolFunction {
325 name: name.unwrap_or_default(),
326 arguments,
327 },
328 });
329 }
330 }
331
332 if !message.content.is_empty() {
333 yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
334 }
335 }
336 }
337 }
338
339 for (_, tool_call) in tool_calls.into_iter() {
340
341 yield Ok(streaming::RawStreamingChoice::ToolCall{
342 name: tool_call.function.name,
343 id: tool_call.id,
344 arguments: tool_call.function.arguments,
345 call_id: None
346 });
347 }
348
349 yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
350 usage: final_usage.unwrap_or_default()
351 }))
352
353 };
354
355 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
356 stream,
357 )))
358}
359
360pub async fn send_streaming_request1(
361 request_builder: RequestBuilder,
362) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError> {
363 let mut event_source = request_builder
364 .eventsource()
365 .expect("Cloning request must always succeed");
366
367 let stream = stream! {
368 let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
370 let mut final_usage = None;
371
372 while let Some(event_result) = event_source.next().await {
373 match event_result {
374 Ok(Event::Open) => {
375 tracing::trace!("SSE connection opened");
376 continue;
377 }
378
379 Ok(Event::Message(event_message)) => {
380 let raw = event_message.data;
381
382 let parsed = serde_json::from_str::<StreamingCompletionResponse>(&raw);
383 let Ok(data) = parsed else {
384 tracing::debug!("Couldn't parse OpenRouter payload as StreamingCompletionResponse; skipping chunk");
385 continue;
386 };
387
388 let choice = match data.choices.first() {
390 Some(c) => c,
391 None => continue,
392 };
393
394 if let Some(delta) = &choice.delta {
396 if !delta.tool_calls.is_empty() {
397 for tc in &delta.tool_calls {
398 let index = tc.index;
399
400 let existing = tool_calls.entry(index).or_insert_with(|| ToolCall {
402 id: String::new(),
403 call_id: None,
404 function: ToolFunction {
405 name: String::new(),
406 arguments: Value::Null,
407 },
408 });
409
410 if let Some(id) = &tc.id && !id.is_empty() {
412 existing.id = id.clone();
413 }
414
415 if let Some(name) = &tc.function.name && !name.is_empty() {
417 existing.function.name = name.clone();
418 }
419
420 if let Some(chunk) = &tc.function.arguments {
422 let current_args = match &existing.function.arguments {
424 Value::Null => String::new(),
425 Value::String(s) => s.clone(),
426 v => v.to_string(),
427 };
428
429 let combined = format!("{}{}", current_args, chunk);
430
431 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
433 match serde_json::from_str::<Value>(&combined) {
434 Ok(parsed_value) => existing.function.arguments = parsed_value,
435 Err(_) => existing.function.arguments = Value::String(combined),
436 }
437 } else {
438 existing.function.arguments = Value::String(combined);
439 }
440 }
441 }
442 }
443
444 if let Some(content) = &delta.content && !content.is_empty() {
446 yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
447 }
448
449 if let Some(usage) = data.usage {
451 final_usage = Some(usage);
452 }
453 }
454
455 if let Some(message) = &choice.message {
457 if !message.tool_calls.is_empty() {
458 for tc in &message.tool_calls {
459 let idx = tc.index;
460 let name = tc.function.name.clone().unwrap_or_default();
461 let id = tc.id.clone().unwrap_or_default();
462
463 let args_value = if let Some(args_str) = &tc.function.arguments {
464 match serde_json::from_str::<Value>(args_str) {
465 Ok(v) => v,
466 Err(_) => Value::String(args_str.clone()),
467 }
468 } else {
469 Value::Null
470 };
471
472 tool_calls.insert(idx, ToolCall {
473 id,
474 call_id: None,
475 function: ToolFunction {
476 name,
477 arguments: args_value,
478 },
479 });
480 }
481 }
482
483 if !message.content.is_empty() {
484 yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()));
485 }
486 }
487 }
488
489 Err(reqwest_eventsource::Error::StreamEnded) => {
490 break;
491 }
492
493 Err(error) => {
494 tracing::error!(?error, "SSE error from OpenRouter event source");
495 yield Err(CompletionError::ResponseError(error.to_string()));
496 break;
497 }
498 }
499 }
500
501 event_source.close();
503
504 for (_idx, tool_call) in tool_calls.into_iter() {
506 yield Ok(streaming::RawStreamingChoice::ToolCall {
507 name: tool_call.function.name,
508 id: tool_call.id,
509 arguments: tool_call.function.arguments,
510 call_id: None,
511 });
512 }
513
514 yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
516 usage: final_usage.unwrap_or_default(),
517 }));
518 };
519
520 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
521 stream,
522 )))
523}