1use super::{
2 client::{ApiErrorResponse, ApiResponse, Client, Usage},
3 streaming::StreamingCompletionResponse,
4};
5use crate::message;
6use crate::telemetry::SpanCombinator;
7use crate::{
8 OneOrMany,
9 completion::{self, CompletionError, CompletionRequest},
10 http_client::HttpClientExt,
11 json_utils,
12 one_or_many::string_or_one_or_many,
13 providers::openai,
14};
15use bytes::Bytes;
16use serde::{Deserialize, Serialize};
17use tracing::{Instrument, info_span};
18
/// OpenRouter model identifier for Qwen QwQ 32B.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// OpenRouter model identifier for Anthropic Claude 3.7 Sonnet.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// OpenRouter model identifier for Perplexity Sonar Pro.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// OpenRouter model identifier for Google Gemini 2.0 Flash.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
31
/// Raw response payload from OpenRouter's `/chat/completions` endpoint
/// (OpenAI-compatible shape). Kept verbatim as `raw_response` after
/// conversion into rig's generic completion response.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Provider-assigned response id.
    pub id: String,
    /// Object type marker (e.g. `"chat.completion"` — TODO confirm against API).
    pub object: String,
    /// Creation time as a Unix timestamp.
    pub created: u64,
    /// Model that actually served the request.
    pub model: String,
    /// Candidate completions; conversion below uses only the first.
    pub choices: Vec<Choice>,
    pub system_fingerprint: Option<String>,
    /// Token accounting; absent on some responses.
    pub usage: Option<Usage>,
}
45
46impl From<ApiErrorResponse> for CompletionError {
47 fn from(err: ApiErrorResponse) -> Self {
48 CompletionError::ProviderError(err.message)
49 }
50}
51
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Convert the raw OpenRouter payload into rig's provider-agnostic
    /// completion response.
    ///
    /// Only the first choice is used. The assembled content preserves a
    /// fixed order: text/refusal items first, then tool calls, then an
    /// optional reasoning item.
    ///
    /// # Errors
    /// Returns `CompletionError::ResponseError` when the response has no
    /// choices, the first choice is not an assistant message, or the
    /// assembled content is empty.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        // An empty `choices` array is a malformed provider response.
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                reasoning,
                ..
            } => {
                // Text and refusal parts are both surfaced as plain text.
                let mut content = content
                    .iter()
                    .map(|c| match c {
                        openai::AssistantContent::Text { text } => {
                            completion::AssistantContent::text(text)
                        }
                        openai::AssistantContent::Refusal { refusal } => {
                            completion::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                // Tool calls follow the textual content.
                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );

                // OpenRouter-specific reasoning trace, appended last if present.
                if let Some(reasoning) = reasoning {
                    content.push(completion::AssistantContent::reasoning(reasoning));
                }

                Ok(content)
            }
            // System/User/ToolResult in the first choice is unexpected here.
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        // OneOrMany::many rejects an empty vec, guaranteeing non-empty content.
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        // Output tokens are derived as total - prompt.
        // NOTE(review): assumes total_tokens >= prompt_tokens — confirm the
        // provider guarantees this, otherwise this can wrap/underflow.
        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
        })
    }
}
126
/// One candidate completion inside a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice within the response's `choices` array.
    pub index: usize,
    /// Finish reason as reported by the upstream (native) provider.
    pub native_finish_reason: Option<String>,
    /// The generated message for this choice.
    pub message: Message,
    /// OpenRouter's normalized finish reason, if any.
    pub finish_reason: Option<String>,
}
134
/// Chat message in OpenRouter's wire format, tagged by the `role` field.
///
/// Content types are shared with the OpenAI provider; the assistant
/// variant additionally carries an OpenRouter-specific `reasoning` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System prompt. Also accepted under the `"developer"` role on input.
    #[serde(alias = "developer")]
    System {
        // Accepts either a bare string or a list of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// User turn.
    User {
        // Accepts either a bare string or a list of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Assistant turn: text content plus optional tool calls and reasoning.
    Assistant {
        // Accepts a bare string or a list; defaults to empty when absent.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<openai::AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<openai::AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Tolerates an explicit JSON `null`; omitted from output when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<openai::ToolCall>,
        /// OpenRouter-specific reasoning trace (not part of the OpenAI schema).
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning: Option<String>,
    },
    /// Result of a prior tool call, keyed by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<openai::ToolResultContent>,
    },
}
179
180impl Message {
181 pub fn system(content: &str) -> Self {
182 Message::System {
183 content: OneOrMany::one(content.to_owned().into()),
184 name: None,
185 }
186 }
187}
188
189impl From<openai::Message> for Message {
190 fn from(value: openai::Message) -> Self {
191 match value {
192 openai::Message::System { content, name } => Self::System { content, name },
193 openai::Message::User { content, name } => Self::User { content, name },
194 openai::Message::Assistant {
195 content,
196 refusal,
197 audio,
198 name,
199 tool_calls,
200 } => Self::Assistant {
201 content,
202 refusal,
203 audio,
204 name,
205 tool_calls,
206 reasoning: None,
207 },
208 openai::Message::ToolResult {
209 tool_call_id,
210 content,
211 } => Self::ToolResult {
212 tool_call_id,
213 content,
214 },
215 }
216 }
217}
218
219impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
220 type Error = message::MessageError;
221
222 fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
223 let mut text_content = Vec::new();
224 let mut tool_calls = Vec::new();
225 let mut reasoning = None;
226
227 for content in value.into_iter() {
228 match content {
229 message::AssistantContent::Text(text) => text_content.push(text),
230 message::AssistantContent::ToolCall(tool_call) => tool_calls.push(tool_call),
231 message::AssistantContent::Reasoning(r) => {
232 reasoning = r.reasoning.into_iter().next();
233 }
234 message::AssistantContent::Image(_) => {
235 return Err(Self::Error::ConversionError(
236 "OpenRouter currently doesn't support images.".into(),
237 ));
238 }
239 }
240 }
241
242 Ok(vec![Message::Assistant {
245 content: text_content
246 .into_iter()
247 .map(|content| content.text.into())
248 .collect::<Vec<_>>(),
249 refusal: None,
250 audio: None,
251 name: None,
252 tool_calls: tool_calls
253 .into_iter()
254 .map(|tool_call| tool_call.into())
255 .collect::<Vec<_>>(),
256 reasoning,
257 }])
258 }
259}
260
261impl TryFrom<message::Message> for Vec<Message> {
264 type Error = message::MessageError;
265
266 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
267 match message {
268 message::Message::User { content } => {
269 let messages: Vec<openai::Message> = content.try_into()?;
270 Ok(messages.into_iter().map(Message::from).collect::<Vec<_>>())
271 }
272 message::Message::Assistant { content, .. } => content.try_into(),
273 }
274 }
275}
276
277#[derive(Debug, Serialize, Deserialize)]
278#[serde(untagged, rename_all = "snake_case")]
279pub enum ToolChoice {
280 None,
281 Auto,
282 Required,
283 Function(Vec<ToolChoiceFunctionKind>),
284}
285
286impl TryFrom<crate::message::ToolChoice> for ToolChoice {
287 type Error = CompletionError;
288
289 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
290 let res = match value {
291 crate::message::ToolChoice::None => Self::None,
292 crate::message::ToolChoice::Auto => Self::Auto,
293 crate::message::ToolChoice::Required => Self::Required,
294 crate::message::ToolChoice::Specific { function_names } => {
295 let vec: Vec<ToolChoiceFunctionKind> = function_names
296 .into_iter()
297 .map(|name| ToolChoiceFunctionKind::Function { name })
298 .collect();
299
300 Self::Function(vec)
301 }
302 };
303
304 Ok(res)
305 }
306}
307
/// A single forced-function selector, serialized as
/// `{"type": "Function", "function": {"name": ...}}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "function")]
pub enum ToolChoiceFunctionKind {
    Function { name: String },
}
313
314#[derive(Debug, Serialize, Deserialize)]
315pub(super) struct OpenrouterCompletionRequest {
316 model: String,
317 pub messages: Vec<Message>,
318 #[serde(flatten, skip_serializing_if = "Option::is_none")]
319 temperature: Option<f64>,
320 #[serde(skip_serializing_if = "Vec::is_empty")]
321 tools: Vec<crate::providers::openai::completion::ToolDefinition>,
322 #[serde(flatten, skip_serializing_if = "Option::is_none")]
323 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
324 #[serde(flatten, skip_serializing_if = "Option::is_none")]
325 pub additional_params: Option<serde_json::Value>,
326}
327
328impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
329 type Error = CompletionError;
330
331 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
332 let mut full_history: Vec<Message> = match &req.preamble {
333 Some(preamble) => vec![Message::system(preamble)],
334 None => vec![],
335 };
336 if let Some(docs) = req.normalized_documents() {
337 let docs: Vec<Message> = docs.try_into()?;
338 full_history.extend(docs);
339 }
340
341 let chat_history: Vec<Message> = req
342 .chat_history
343 .clone()
344 .into_iter()
345 .map(|message| message.try_into())
346 .collect::<Result<Vec<Vec<Message>>, _>>()?
347 .into_iter()
348 .flatten()
349 .collect();
350
351 full_history.extend(chat_history);
352
353 let tool_choice = req
354 .tool_choice
355 .clone()
356 .map(crate::providers::openai::completion::ToolChoice::try_from)
357 .transpose()?;
358
359 Ok(Self {
360 model: model.to_string(),
361 messages: full_history,
362 temperature: req.temperature,
363 tools: req
364 .tools
365 .clone()
366 .into_iter()
367 .map(crate::providers::openai::completion::ToolDefinition::from)
368 .collect::<Vec<_>>(),
369 tool_choice,
370 additional_params: req.additional_params,
371 })
372 }
373}
374
/// Handle to a single OpenRouter model: the shared API client plus a
/// model identifier.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier in OpenRouter's `vendor/model` format
    /// (e.g. `"anthropic/claude-3.7-sonnet"`).
    pub model: String,
}
380
381impl<T> CompletionModel<T> {
382 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
383 Self {
384 client,
385 model: model.into(),
386 }
387 }
388}
389
390impl<T> completion::CompletionModel for CompletionModel<T>
391where
392 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
393{
394 type Response = CompletionResponse;
395 type StreamingResponse = StreamingCompletionResponse;
396
397 type Client = Client<T>;
398
399 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
400 Self::new(client.clone(), model)
401 }
402
403 #[cfg_attr(feature = "worker", worker::send)]
404 async fn completion(
405 &self,
406 completion_request: CompletionRequest,
407 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
408 let preamble = completion_request.preamble.clone();
409 let request =
410 OpenrouterCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
411 let span = if tracing::Span::current().is_disabled() {
412 info_span!(
413 target: "rig::completions",
414 "chat",
415 gen_ai.operation.name = "chat",
416 gen_ai.provider.name = "openrouter",
417 gen_ai.request.model = self.model,
418 gen_ai.system_instructions = preamble,
419 gen_ai.response.id = tracing::field::Empty,
420 gen_ai.response.model = tracing::field::Empty,
421 gen_ai.usage.output_tokens = tracing::field::Empty,
422 gen_ai.usage.input_tokens = tracing::field::Empty,
423 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
424 gen_ai.output.messages = tracing::field::Empty,
425 )
426 } else {
427 tracing::Span::current()
428 };
429
430 let body = serde_json::to_vec(&request)?;
431
432 let req = self
433 .client
434 .post("/chat/completions")?
435 .body(body)
436 .map_err(|x| CompletionError::HttpError(x.into()))?;
437
438 async move {
439 let response = self.client.send::<_, Bytes>(req).await?;
440 let status = response.status();
441 let response_body = response.into_body().into_future().await?.to_vec();
442
443 if status.is_success() {
444 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
445 ApiResponse::Ok(response) => {
446 let span = tracing::Span::current();
447 span.record_token_usage(&response.usage);
448 span.record_model_output(&response.choices);
449 span.record("gen_ai.response.id", &response.id);
450 span.record("gen_ai.response.model_name", &response.model);
451
452 tracing::debug!(target: "rig::completions",
453 "OpenRouter response: {response:?}");
454 response.try_into()
455 }
456 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
457 }
458 } else {
459 Err(CompletionError::ProviderError(
460 String::from_utf8_lossy(&response_body).to_string(),
461 ))
462 }
463 }
464 .instrument(span)
465 .await
466 }
467
468 #[cfg_attr(feature = "worker", worker::send)]
469 async fn stream(
470 &self,
471 completion_request: CompletionRequest,
472 ) -> Result<
473 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
474 CompletionError,
475 > {
476 CompletionModel::stream(self, completion_request).await
477 }
478}