1use super::{
2 client::{ApiErrorResponse, ApiResponse, Client, Usage},
3 streaming::StreamingCompletionResponse,
4};
5use crate::message;
6use crate::telemetry::SpanCombinator;
7use crate::{
8 OneOrMany,
9 completion::{self, CompletionError, CompletionRequest},
10 http_client::HttpClientExt,
11 json_utils,
12 one_or_many::string_or_one_or_many,
13 providers::openai,
14};
15use bytes::Bytes;
16use serde::{Deserialize, Serialize};
17use tracing::{Instrument, Level, enabled, info_span};
18
/// Qwen QwQ 32B, addressed through OpenRouter's model namespace.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// Anthropic Claude 3.7 Sonnet via OpenRouter.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// Perplexity Sonar Pro via OpenRouter.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// Google Gemini 2.0 Flash (001) via OpenRouter.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
31
/// Raw OpenRouter chat-completion response body (OpenAI-compatible schema).
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Provider-assigned response id.
    pub id: String,
    /// Object type marker (presumably `"chat.completion"` — confirm against API).
    pub object: String,
    /// Creation timestamp (presumably Unix seconds — confirm against API).
    pub created: u64,
    /// Fully qualified id of the model that produced the response.
    pub model: String,
    /// Candidate completions; only the first is used by the conversion below.
    pub choices: Vec<Choice>,
    pub system_fingerprint: Option<String>,
    /// Token accounting; absent for some upstream providers.
    pub usage: Option<Usage>,
}
45
46impl From<ApiErrorResponse> for CompletionError {
47 fn from(err: ApiErrorResponse) -> Self {
48 CompletionError::ProviderError(err.message)
49 }
50}
51
52impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
53 type Error = CompletionError;
54
55 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
56 let choice = response.choices.first().ok_or_else(|| {
57 CompletionError::ResponseError("Response contained no choices".to_owned())
58 })?;
59
60 let content = match &choice.message {
61 Message::Assistant {
62 content,
63 tool_calls,
64 reasoning,
65 ..
66 } => {
67 let mut content = content
68 .iter()
69 .map(|c| match c {
70 openai::AssistantContent::Text { text } => {
71 completion::AssistantContent::text(text)
72 }
73 openai::AssistantContent::Refusal { refusal } => {
74 completion::AssistantContent::text(refusal)
75 }
76 })
77 .collect::<Vec<_>>();
78
79 content.extend(
80 tool_calls
81 .iter()
82 .map(|call| {
83 completion::AssistantContent::tool_call(
84 &call.id,
85 &call.function.name,
86 call.function.arguments.clone(),
87 )
88 })
89 .collect::<Vec<_>>(),
90 );
91
92 if let Some(reasoning) = reasoning {
93 content.push(completion::AssistantContent::reasoning(reasoning));
94 }
95
96 Ok(content)
97 }
98 _ => Err(CompletionError::ResponseError(
99 "Response did not contain a valid message or tool call".into(),
100 )),
101 }?;
102
103 let choice = OneOrMany::many(content).map_err(|_| {
104 CompletionError::ResponseError(
105 "Response contained no message or tool call (empty)".to_owned(),
106 )
107 })?;
108
109 let usage = response
110 .usage
111 .as_ref()
112 .map(|usage| completion::Usage {
113 input_tokens: usage.prompt_tokens as u64,
114 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
115 total_tokens: usage.total_tokens as u64,
116 })
117 .unwrap_or_default();
118
119 Ok(completion::CompletionResponse {
120 choice,
121 usage,
122 raw_response: response,
123 })
124 }
125}
126
/// A single completion candidate within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice in the response's choice list.
    pub index: usize,
    /// Finish reason as reported by the underlying upstream model, if any.
    pub native_finish_reason: Option<String>,
    /// The message produced for this choice.
    pub message: Message,
    /// OpenRouter-normalized finish reason, if any.
    pub finish_reason: Option<String>,
}
134
135#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
140#[serde(tag = "role", rename_all = "lowercase")]
141pub enum Message {
142 #[serde(alias = "developer")]
143 System {
144 #[serde(deserialize_with = "string_or_one_or_many")]
145 content: OneOrMany<openai::SystemContent>,
146 #[serde(skip_serializing_if = "Option::is_none")]
147 name: Option<String>,
148 },
149 User {
150 #[serde(deserialize_with = "string_or_one_or_many")]
151 content: OneOrMany<openai::UserContent>,
152 #[serde(skip_serializing_if = "Option::is_none")]
153 name: Option<String>,
154 },
155 Assistant {
156 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
157 content: Vec<openai::AssistantContent>,
158 #[serde(skip_serializing_if = "Option::is_none")]
159 refusal: Option<String>,
160 #[serde(skip_serializing_if = "Option::is_none")]
161 audio: Option<openai::AudioAssistant>,
162 #[serde(skip_serializing_if = "Option::is_none")]
163 name: Option<String>,
164 #[serde(
165 default,
166 deserialize_with = "json_utils::null_or_vec",
167 skip_serializing_if = "Vec::is_empty"
168 )]
169 tool_calls: Vec<openai::ToolCall>,
170 #[serde(skip_serializing_if = "Option::is_none")]
171 reasoning: Option<String>,
172 #[serde(skip_serializing_if = "Vec::is_empty")]
173 reasoning_details: Vec<ReasoningDetails>,
174 },
175 #[serde(rename = "tool")]
176 ToolResult {
177 tool_call_id: String,
178 content: String,
179 },
180}
181
182impl Message {
183 pub fn system(content: &str) -> Self {
184 Message::System {
185 content: OneOrMany::one(content.to_owned().into()),
186 name: None,
187 }
188 }
189}
190
/// A structured reasoning block attached to an assistant message, tagged by
/// the JSON `type` field (`reasoning.summary` / `reasoning.encrypted` /
/// `reasoning.text`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningDetails {
    /// A human-readable summary of the model's reasoning.
    #[serde(rename = "reasoning.summary")]
    Summary {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        summary: String,
    },
    /// Opaque, provider-encrypted reasoning payload (e.g. a signature to be
    /// echoed back on subsequent turns).
    #[serde(rename = "reasoning.encrypted")]
    Encrypted {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        data: String,
    },
    /// Raw reasoning text, optionally accompanied by a signature.
    #[serde(rename = "reasoning.text")]
    Text {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        text: Option<String>,
        signature: Option<String>,
    },
}
217
/// Shape of the `additional_params` payload carried on a rig tool call.
///
/// Untagged: variants are tried in declaration order, so a full
/// `ReasoningDetails` object is matched first. `Minimal` (all-optional
/// fields) matches almost any remaining JSON object and acts as the
/// catch-all fallback.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(untagged)]
enum ToolCallAdditionalParams {
    ReasoningDetails(ReasoningDetails),
    Minimal {
        id: Option<String>,
        format: Option<String>,
    },
}
227
228impl From<openai::Message> for Message {
229 fn from(value: openai::Message) -> Self {
230 match value {
231 openai::Message::System { content, name } => Self::System { content, name },
232 openai::Message::User { content, name } => Self::User { content, name },
233 openai::Message::Assistant {
234 content,
235 refusal,
236 audio,
237 name,
238 tool_calls,
239 } => Self::Assistant {
240 content,
241 refusal,
242 audio,
243 name,
244 tool_calls,
245 reasoning: None,
246 reasoning_details: Vec::new(),
247 },
248 openai::Message::ToolResult {
249 tool_call_id,
250 content,
251 } => Self::ToolResult {
252 tool_call_id,
253 content: content.as_text(),
254 },
255 }
256 }
257}
258
impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Collapses rig assistant content items into a single assistant message.
    ///
    /// Text parts, tool calls, reasoning text and encrypted reasoning
    /// signatures are gathered separately and emitted together as one
    /// `Message::Assistant` (hence the single-element `Vec`).
    ///
    /// # Errors
    /// Fails with `ConversionError` if the content contains an image, which
    /// OpenRouter does not support.
    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
        let mut text_content = Vec::new();
        let mut tool_calls = Vec::new();
        let mut reasoning = None;
        let mut reasoning_details = Vec::new();

        for content in value.into_iter() {
            match content {
                message::AssistantContent::Text(text) => text_content.push(text),
                message::AssistantContent::ToolCall(tool_call) => {
                    // Tool calls may carry OpenRouter reasoning metadata in
                    // `additional_params`; try to decode that first.
                    if let Some(additional_params) = &tool_call.additional_params
                        && let Ok(additional_params) =
                            serde_json::from_value::<ToolCallAdditionalParams>(
                                additional_params.clone(),
                            )
                    {
                        match additional_params {
                            ToolCallAdditionalParams::ReasoningDetails(full) => {
                                reasoning_details.push(full);
                            }
                            ToolCallAdditionalParams::Minimal { id, format } => {
                                // Fall back to the call id when the params
                                // carry no id of their own.
                                let id = id.or_else(|| tool_call.call_id.clone());
                                // A signature AND an id are both needed to
                                // reconstruct an encrypted reasoning block.
                                if let Some(signature) = &tool_call.signature
                                    && let Some(id) = id
                                {
                                    reasoning_details.push(ReasoningDetails::Encrypted {
                                        id: Some(id),
                                        format,
                                        index: None,
                                        data: signature.clone(),
                                    })
                                }
                            }
                        }
                    } else if let Some(signature) = &tool_call.signature {
                        // No decodable params: preserve the bare signature.
                        reasoning_details.push(ReasoningDetails::Encrypted {
                            id: tool_call.call_id.clone(),
                            format: None,
                            index: None,
                            data: signature.clone(),
                        });
                    }
                    tool_calls.push(tool_call.into())
                }
                message::AssistantContent::Reasoning(r) => {
                    // NOTE(review): only the first fragment of this item is
                    // kept, and a later Reasoning item overwrites an earlier
                    // one — confirm this is intended.
                    reasoning = r.reasoning.into_iter().next();
                }
                message::AssistantContent::Image(_) => {
                    return Err(Self::Error::ConversionError(
                        "OpenRouter currently doesn't support images.".into(),
                    ));
                }
            }
        }

        // Everything collapses into a single assistant message.
        Ok(vec![Message::Assistant {
            content: text_content
                .into_iter()
                .map(|content| content.text.into())
                .collect::<Vec<_>>(),
            refusal: None,
            audio: None,
            name: None,
            tool_calls,
            reasoning,
            reasoning_details,
        }])
    }
}
338
339impl TryFrom<message::Message> for Vec<Message> {
342 type Error = message::MessageError;
343
344 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
345 match message {
346 message::Message::User { content } => {
347 let messages: Vec<openai::Message> = content.try_into()?;
348 Ok(messages.into_iter().map(Message::from).collect::<Vec<_>>())
349 }
350 message::Message::Assistant { content, .. } => content.try_into(),
351 }
352 }
353}
354
355#[derive(Debug, Serialize, Deserialize)]
356#[serde(untagged, rename_all = "snake_case")]
357pub enum ToolChoice {
358 None,
359 Auto,
360 Required,
361 Function(Vec<ToolChoiceFunctionKind>),
362}
363
364impl TryFrom<crate::message::ToolChoice> for ToolChoice {
365 type Error = CompletionError;
366
367 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
368 let res = match value {
369 crate::message::ToolChoice::None => Self::None,
370 crate::message::ToolChoice::Auto => Self::Auto,
371 crate::message::ToolChoice::Required => Self::Required,
372 crate::message::ToolChoice::Specific { function_names } => {
373 let vec: Vec<ToolChoiceFunctionKind> = function_names
374 .into_iter()
375 .map(|name| ToolChoiceFunctionKind::Function { name })
376 .collect();
377
378 Self::Function(vec)
379 }
380 };
381
382 Ok(res)
383 }
384}
385
386#[derive(Debug, Serialize, Deserialize)]
387#[serde(tag = "type", content = "function")]
388pub enum ToolChoiceFunctionKind {
389 Function { name: String },
390}
391
/// JSON body for the OpenRouter `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OpenrouterCompletionRequest {
    /// Fully qualified model id (e.g. `"anthropic/claude-3.7-sonnet"`).
    model: String,
    /// Full conversation: system preamble, documents, then chat history.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions; the field is omitted entirely when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<crate::providers::openai::completion::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra provider parameters, flattened into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
405
/// Inputs needed to build an [`OpenrouterCompletionRequest`].
pub struct OpenRouterRequestParams<'a> {
    /// Model id to address the request to.
    pub model: &'a str,
    /// The provider-agnostic completion request to translate.
    pub request: CompletionRequest,
    /// When true, each tool definition is converted via `with_strict()`.
    pub strict_tools: bool,
}
412
413impl TryFrom<OpenRouterRequestParams<'_>> for OpenrouterCompletionRequest {
414 type Error = CompletionError;
415
416 fn try_from(params: OpenRouterRequestParams) -> Result<Self, Self::Error> {
417 let OpenRouterRequestParams {
418 model,
419 request: req,
420 strict_tools,
421 } = params;
422
423 let mut full_history: Vec<Message> = match &req.preamble {
424 Some(preamble) => vec![Message::system(preamble)],
425 None => vec![],
426 };
427 if let Some(docs) = req.normalized_documents() {
428 let docs: Vec<Message> = docs.try_into()?;
429 full_history.extend(docs);
430 }
431
432 let chat_history: Vec<Message> = req
433 .chat_history
434 .clone()
435 .into_iter()
436 .map(|message| message.try_into())
437 .collect::<Result<Vec<Vec<Message>>, _>>()?
438 .into_iter()
439 .flatten()
440 .collect();
441
442 full_history.extend(chat_history);
443
444 let tool_choice = req
445 .tool_choice
446 .clone()
447 .map(crate::providers::openai::completion::ToolChoice::try_from)
448 .transpose()?;
449
450 let tools: Vec<crate::providers::openai::completion::ToolDefinition> = req
451 .tools
452 .clone()
453 .into_iter()
454 .map(|tool| {
455 let def = crate::providers::openai::completion::ToolDefinition::from(tool);
456 if strict_tools { def.with_strict() } else { def }
457 })
458 .collect();
459
460 Ok(Self {
461 model: model.to_string(),
462 messages: full_history,
463 temperature: req.temperature,
464 tools,
465 tool_choice,
466 additional_params: req.additional_params,
467 })
468 }
469}
470
471impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
472 type Error = CompletionError;
473
474 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
475 OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
476 model,
477 request: req,
478 strict_tools: false,
479 })
480 }
481}
482
/// OpenRouter completion model handle: a client plus a model id.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Fully qualified model id, e.g. one of the constants in this module.
    pub model: String,
    /// When set, tool definitions are converted via `with_strict()`.
    pub strict_tools: bool,
}
491
492impl<T> CompletionModel<T> {
493 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
494 Self {
495 client,
496 model: model.into(),
497 strict_tools: false,
498 }
499 }
500
501 pub fn with_strict_tools(mut self) -> Self {
510 self.strict_tools = true;
511 self
512 }
513}
514
515impl<T> completion::CompletionModel for CompletionModel<T>
516where
517 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
518{
519 type Response = CompletionResponse;
520 type StreamingResponse = StreamingCompletionResponse;
521
522 type Client = Client<T>;
523
524 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
525 Self::new(client.clone(), model)
526 }
527
528 async fn completion(
529 &self,
530 completion_request: CompletionRequest,
531 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
532 let preamble = completion_request.preamble.clone();
533 let request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
534 model: self.model.as_ref(),
535 request: completion_request,
536 strict_tools: self.strict_tools,
537 })?;
538
539 if enabled!(Level::TRACE) {
540 tracing::trace!(
541 target: "rig::completions",
542 "OpenRouter completion request: {}",
543 serde_json::to_string_pretty(&request)?
544 );
545 }
546
547 let span = if tracing::Span::current().is_disabled() {
548 info_span!(
549 target: "rig::completions",
550 "chat",
551 gen_ai.operation.name = "chat",
552 gen_ai.provider.name = "openrouter",
553 gen_ai.request.model = self.model,
554 gen_ai.system_instructions = preamble,
555 gen_ai.response.id = tracing::field::Empty,
556 gen_ai.response.model = tracing::field::Empty,
557 gen_ai.usage.output_tokens = tracing::field::Empty,
558 gen_ai.usage.input_tokens = tracing::field::Empty,
559 )
560 } else {
561 tracing::Span::current()
562 };
563
564 let body = serde_json::to_vec(&request)?;
565
566 let req = self
567 .client
568 .post("/chat/completions")?
569 .body(body)
570 .map_err(|x| CompletionError::HttpError(x.into()))?;
571
572 async move {
573 let response = self.client.send::<_, Bytes>(req).await?;
574 let status = response.status();
575 let response_body = response.into_body().into_future().await?.to_vec();
576
577 if status.is_success() {
578 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
579 ApiResponse::Ok(response) => {
580 let span = tracing::Span::current();
581 span.record_token_usage(&response.usage);
582 span.record("gen_ai.response.id", &response.id);
583 span.record("gen_ai.response.model_name", &response.model);
584
585 tracing::debug!(target: "rig::completions",
586 "OpenRouter response: {response:?}");
587 response.try_into()
588 }
589 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
590 }
591 } else {
592 Err(CompletionError::ProviderError(
593 String::from_utf8_lossy(&response_body).to_string(),
594 ))
595 }
596 }
597 .instrument(span)
598 .await
599 }
600
601 async fn stream(
602 &self,
603 completion_request: CompletionRequest,
604 ) -> Result<
605 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
606 CompletionError,
607 > {
608 CompletionModel::stream(self, completion_request).await
609 }
610}