rig/providers/openrouter/
streaming.rs1use http::Request;
2use std::collections::HashMap;
3use tracing::info_span;
4
5use crate::{
6 completion::GetTokenUsage,
7 http_client::{self, HttpClientExt},
8 json_utils,
9 message::{ToolCall, ToolFunction},
10 streaming::{self},
11};
12use async_stream::stream;
13use futures::StreamExt;
14use serde_json::{Value, json};
15
16use crate::completion::{CompletionError, CompletionRequest};
17use serde::{Deserialize, Serialize};
18
19#[derive(Serialize, Deserialize, Debug)]
20pub struct StreamingCompletionResponse {
21 pub id: String,
22 pub choices: Vec<StreamingChoice>,
23 pub created: u64,
24 pub model: String,
25 pub object: String,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub system_fingerprint: Option<String>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub usage: Option<ResponseUsage>,
30}
31
32impl GetTokenUsage for FinalCompletionResponse {
33 fn token_usage(&self) -> Option<crate::completion::Usage> {
34 let mut usage = crate::completion::Usage::new();
35
36 usage.input_tokens = self.usage.prompt_tokens as u64;
37 usage.output_tokens = self.usage.completion_tokens as u64;
38 usage.total_tokens = self.usage.total_tokens as u64;
39
40 Some(usage)
41 }
42}
43
44#[derive(Serialize, Deserialize, Debug)]
45pub struct StreamingChoice {
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub finish_reason: Option<String>,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 pub native_finish_reason: Option<String>,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 pub logprobs: Option<Value>,
52 pub index: usize,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub message: Option<MessageResponse>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub delta: Option<DeltaResponse>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub error: Option<ErrorResponse>,
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62pub struct MessageResponse {
63 pub role: String,
64 pub content: String,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 pub refusal: Option<Value>,
67 #[serde(default)]
68 pub tool_calls: Vec<OpenRouterToolCall>,
69}
70
71#[derive(Serialize, Deserialize, Debug)]
72pub struct OpenRouterToolFunction {
73 pub name: Option<String>,
74 pub arguments: Option<String>,
75}
76
77#[derive(Serialize, Deserialize, Debug)]
78pub struct OpenRouterToolCall {
79 pub index: usize,
80 pub id: Option<String>,
81 pub r#type: Option<String>,
82 pub function: OpenRouterToolFunction,
83}
84
85#[derive(Serialize, Deserialize, Debug, Clone, Default)]
86pub struct ResponseUsage {
87 pub prompt_tokens: u32,
88 pub completion_tokens: u32,
89 pub total_tokens: u32,
90}
91
92#[derive(Serialize, Deserialize, Debug)]
93pub struct ErrorResponse {
94 pub code: i32,
95 pub message: String,
96 #[serde(skip_serializing_if = "Option::is_none")]
97 pub metadata: Option<HashMap<String, Value>>,
98}
99
100#[derive(Serialize, Deserialize, Debug)]
101pub struct DeltaResponse {
102 pub role: Option<String>,
103 #[serde(skip_serializing_if = "Option::is_none")]
104 pub content: Option<String>,
105 #[serde(default)]
106 pub tool_calls: Vec<OpenRouterToolCall>,
107 #[serde(skip_serializing_if = "Option::is_none")]
108 pub native_finish_reason: Option<String>,
109}
110
111#[derive(Clone, Deserialize, Serialize)]
112pub struct FinalCompletionResponse {
113 pub usage: ResponseUsage,
114}
115
116impl<T> super::CompletionModel<T>
117where
118 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
119{
120 pub(crate) async fn stream(
121 &self,
122 completion_request: CompletionRequest,
123 ) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
124 {
125 let preamble = completion_request.preamble.clone();
126 let request = self.create_completion_request(completion_request)?;
127
128 let request = json_utils::merge(request, json!({"stream": true}));
129
130 let body = serde_json::to_vec(&request)?;
131
132 let req = self
133 .client
134 .post("/chat/completions")?
135 .header("Content-Type", "application/json")
136 .body(body)
137 .map_err(|x| CompletionError::HttpError(x.into()))?;
138
139 let span = if tracing::Span::current().is_disabled() {
140 info_span!(
141 target: "rig::completions",
142 "chat_streaming",
143 gen_ai.operation.name = "chat_streaming",
144 gen_ai.provider.name = "openrouter",
145 gen_ai.request.model = self.model,
146 gen_ai.system_instructions = preamble,
147 gen_ai.response.id = tracing::field::Empty,
148 gen_ai.response.model = tracing::field::Empty,
149 gen_ai.usage.output_tokens = tracing::field::Empty,
150 gen_ai.usage.input_tokens = tracing::field::Empty,
151 gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
152 gen_ai.output.messages = tracing::field::Empty,
153 )
154 } else {
155 tracing::Span::current()
156 };
157
158 tracing::Instrument::instrument(
159 send_streaming_request(self.client.http_client.clone(), req),
160 span,
161 )
162 .await
163 }
164}
165
166pub async fn send_streaming_request<T>(
167 client: T,
168 req: Request<Vec<u8>>,
169) -> Result<streaming::StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>
170where
171 T: HttpClientExt + Clone + 'static,
172{
173 let response = client.send_streaming(req).await?;
174 let status = response.status();
175
176 if !status.is_success() {
177 return Err(CompletionError::ProviderError(format!(
178 "Got response error trying to send a completion request to OpenRouter: {status}"
179 )));
180 }
181
182 let mut stream = response.into_body();
183
184 let stream = stream! {
186 let mut tool_calls = HashMap::new();
187 let mut partial_line = String::new();
188 let mut final_usage = None;
189
190 while let Some(chunk_result) = stream.next().await {
191 let chunk = match chunk_result {
192 Ok(c) => c,
193 Err(e) => {
194 yield Err(CompletionError::from(http_client::Error::Instance(e.into())));
195 break;
196 }
197 };
198
199 let text = match String::from_utf8(chunk.to_vec()) {
200 Ok(t) => t,
201 Err(e) => {
202 yield Err(CompletionError::ResponseError(e.to_string()));
203 break;
204 }
205 };
206
207 for line in text.lines() {
208 let mut line = line.to_string();
209
210 if line.trim().is_empty() || line.trim() == ": OPENROUTER PROCESSING" || line.trim() == "data: [DONE]" {
212 continue;
213 }
214
215 line = line.strip_prefix("data: ").unwrap_or(&line).to_string();
217
218 if line.starts_with('{') && !line.ends_with('}') {
220 partial_line = line;
221 continue;
222 }
223
224 if !partial_line.is_empty() {
226 if line.ends_with('}') {
227 partial_line.push_str(&line);
228 line = partial_line;
229 partial_line = String::new();
230 } else {
231 partial_line.push_str(&line);
232 continue;
233 }
234 }
235
236 let data = match serde_json::from_str::<StreamingCompletionResponse>(&line) {
237 Ok(data) => data,
238 Err(_) => {
239 continue;
240 }
241 };
242
243
244 let choice = data.choices.first().expect("Should have at least one choice");
245
246 if let Some(delta) = &choice.delta {
256 if !delta.tool_calls.is_empty() {
257 for tool_call in &delta.tool_calls {
258 let index = tool_call.index;
259
260 let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
262 id: String::new(),
263 call_id: None,
264 function: ToolFunction {
265 name: String::new(),
266 arguments: serde_json::Value::Null,
267 },
268 });
269
270 if let Some(id) = &tool_call.id && !id.is_empty() {
272 existing_tool_call.id = id.clone();
273 }
274
275 if let Some(name) = &tool_call.function.name && !name.is_empty() {
276 existing_tool_call.function.name = name.clone();
277 }
278
279 if let Some(chunk) = &tool_call.function.arguments {
280 let current_args = match &existing_tool_call.function.arguments {
282 serde_json::Value::Null => String::new(),
283 serde_json::Value::String(s) => s.clone(),
284 v => v.to_string(),
285 };
286
287 let combined = format!("{current_args}{chunk}");
289
290 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
292 match serde_json::from_str(&combined) {
293 Ok(parsed) => existing_tool_call.function.arguments = parsed,
294 Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
295 }
296 } else {
297 existing_tool_call.function.arguments = serde_json::Value::String(combined);
298 }
299 }
300 }
301 }
302
303 if let Some(content) = &delta.content &&!content.is_empty() {
304 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
305 }
306
307 if let Some(usage) = data.usage {
308 final_usage = Some(usage);
309 }
310 }
311
312 if let Some(message) = &choice.message {
314 if !message.tool_calls.is_empty() {
315 for tool_call in &message.tool_calls {
316 let name = tool_call.function.name.clone();
317 let id = tool_call.id.clone();
318 let arguments = if let Some(args) = &tool_call.function.arguments {
319 match serde_json::from_str(args) {
321 Ok(v) => v,
322 Err(_) => serde_json::Value::String(args.to_string()),
323 }
324 } else {
325 serde_json::Value::Null
326 };
327 let index = tool_call.index;
328
329 tool_calls.insert(index, ToolCall {
330 id: id.unwrap_or_default(),
331 call_id: None,
332 function: ToolFunction {
333 name: name.unwrap_or_default(),
334 arguments,
335 },
336 });
337 }
338 }
339
340 if !message.content.is_empty() {
341 yield Ok(streaming::RawStreamingChoice::Message(message.content.clone()))
342 }
343 }
344 }
345 }
346
347 for (_, tool_call) in tool_calls.into_iter() {
348
349 yield Ok(streaming::RawStreamingChoice::ToolCall{
350 name: tool_call.function.name,
351 id: tool_call.id,
352 arguments: tool_call.function.arguments,
353 call_id: None
354 });
355 }
356
357 yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse {
358 usage: final_usage.unwrap_or_default()
359 }))
360
361 };
362
363 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
364 stream,
365 )))
366}