1use super::{
2 client::{ApiErrorResponse, ApiResponse, Client, Usage},
3 streaming::StreamingCompletionResponse,
4};
5use crate::message;
6use crate::telemetry::SpanCombinator;
7use crate::{
8 OneOrMany,
9 completion::{self, CompletionError, CompletionRequest},
10 http_client::HttpClientExt,
11 json_utils,
12 one_or_many::string_or_one_or_many,
13 providers::openai,
14};
15use bytes::Bytes;
16use serde::{Deserialize, Serialize};
17use tracing::{Instrument, Level, enabled, info_span};
18
/// Qwen QwQ 32B (reasoning-focused model).
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// Anthropic Claude 3.7 Sonnet.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// Perplexity Sonar Pro (search-augmented model).
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// Google Gemini 2.0 Flash.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
31
/// Raw chat-completion payload returned by OpenRouter (OpenAI-compatible shape).
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Provider-assigned generation id (e.g. `gen-...`).
    pub id: String,
    /// Object type, typically `"chat.completion"`.
    pub object: String,
    /// Unix timestamp (seconds) of creation.
    pub created: u64,
    /// Fully-qualified model slug that produced the response.
    pub model: String,
    pub choices: Vec<Choice>,
    pub system_fingerprint: Option<String>,
    /// Token accounting; absent on some providers/responses.
    pub usage: Option<Usage>,
}
45
46impl From<ApiErrorResponse> for CompletionError {
47 fn from(err: ApiErrorResponse) -> Self {
48 CompletionError::ProviderError(err.message)
49 }
50}
51
52impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
53 type Error = CompletionError;
54
55 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
56 let choice = response.choices.first().ok_or_else(|| {
57 CompletionError::ResponseError("Response contained no choices".to_owned())
58 })?;
59
60 let content = match &choice.message {
61 Message::Assistant {
62 content,
63 tool_calls,
64 reasoning,
65 ..
66 } => {
67 let mut content = content
68 .iter()
69 .map(|c| match c {
70 openai::AssistantContent::Text { text } => {
71 completion::AssistantContent::text(text)
72 }
73 openai::AssistantContent::Refusal { refusal } => {
74 completion::AssistantContent::text(refusal)
75 }
76 })
77 .collect::<Vec<_>>();
78
79 content.extend(
80 tool_calls
81 .iter()
82 .map(|call| {
83 completion::AssistantContent::tool_call(
84 &call.id,
85 &call.function.name,
86 call.function.arguments.clone(),
87 )
88 })
89 .collect::<Vec<_>>(),
90 );
91
92 if let Some(reasoning) = reasoning {
93 content.push(completion::AssistantContent::reasoning(reasoning));
94 }
95
96 Ok(content)
97 }
98 _ => Err(CompletionError::ResponseError(
99 "Response did not contain a valid message or tool call".into(),
100 )),
101 }?;
102
103 let choice = OneOrMany::many(content).map_err(|_| {
104 CompletionError::ResponseError(
105 "Response contained no message or tool call (empty)".to_owned(),
106 )
107 })?;
108
109 let usage = response
110 .usage
111 .as_ref()
112 .map(|usage| completion::Usage {
113 input_tokens: usage.prompt_tokens as u64,
114 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
115 total_tokens: usage.total_tokens as u64,
116 cached_input_tokens: 0,
117 })
118 .unwrap_or_default();
119
120 Ok(completion::CompletionResponse {
121 choice,
122 usage,
123 raw_response: response,
124 })
125 }
126}
127
/// One completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    /// Finish reason as reported by the upstream provider, pre-normalization
    /// (e.g. `"STOP"`).
    pub native_finish_reason: Option<String>,
    pub message: Message,
    /// Normalized finish reason (e.g. `"stop"`).
    pub finish_reason: Option<String>,
}
135
/// A chat message in OpenRouter's (OpenAI-compatible) wire format, internally
/// tagged by the `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System instructions (`role: "system"`; `"developer"` accepted on input).
    #[serde(alias = "developer")]
    System {
        // Accepts either a bare string or a list of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user input (`role: "user"`).
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model output (`role: "assistant"`), possibly carrying tool calls and
    /// OpenRouter-specific reasoning fields.
    Assistant {
        // Tolerates `content` being a string, a list, or absent.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<openai::AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<openai::AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `tool_calls` may be JSON `null` on the wire; treated as empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<openai::ToolCall>,
        /// Plain-text reasoning trace, when available.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning: Option<String>,
        /// Structured reasoning blocks (summary / encrypted / text).
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        reasoning_details: Vec<ReasoningDetails>,
    },
    /// Result of a prior tool call (`role: "tool"`).
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
182
183impl Message {
184 pub fn system(content: &str) -> Self {
185 Message::System {
186 content: OneOrMany::one(content.to_owned().into()),
187 name: None,
188 }
189 }
190}
191
/// Structured reasoning block attached to assistant messages.
///
/// The wire `type` tag uses dotted names (e.g. `"reasoning.summary"`), so each
/// variant carries an explicit rename that overrides the `snake_case` default.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningDetails {
    /// Human-readable summary of the model's reasoning.
    #[serde(rename = "reasoning.summary")]
    Summary {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        summary: String,
    },
    /// Provider-encrypted reasoning payload, intended to be echoed back
    /// verbatim on subsequent turns.
    #[serde(rename = "reasoning.encrypted")]
    Encrypted {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        data: String,
    },
    /// Raw reasoning text, optionally accompanied by a signature.
    #[serde(rename = "reasoning.text")]
    Text {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        text: Option<String>,
        signature: Option<String>,
    },
}
218
/// Shapes that a tool call's `additional_params` JSON may take.
///
/// Untagged: serde tries variants top to bottom, so a full
/// [`ReasoningDetails`] object wins over the minimal id/format form.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(untagged)]
enum ToolCallAdditionalParams {
    /// A complete reasoning-details object.
    ReasoningDetails(ReasoningDetails),
    /// Bare id/format pair; combined with the tool call's signature to
    /// reconstruct an encrypted reasoning detail.
    Minimal {
        id: Option<String>,
        format: Option<String>,
    },
}
228
impl From<openai::Message> for Message {
    /// Lift an OpenAI-format message into the OpenRouter format.
    ///
    /// System/User map field-for-field; Assistant gains empty
    /// OpenRouter-specific reasoning fields; ToolResult content is flattened
    /// to plain text.
    fn from(value: openai::Message) -> Self {
        match value {
            openai::Message::System { content, name } => Self::System { content, name },
            openai::Message::User { content, name } => Self::User { content, name },
            openai::Message::Assistant {
                content,
                refusal,
                audio,
                name,
                tool_calls,
            } => Self::Assistant {
                content,
                refusal,
                audio,
                name,
                tool_calls,
                // OpenAI messages have no reasoning; start empty.
                reasoning: None,
                reasoning_details: Vec::new(),
            },
            openai::Message::ToolResult {
                tool_call_id,
                content,
            } => Self::ToolResult {
                tool_call_id,
                // OpenRouter tool results carry a plain string.
                content: content.as_text(),
            },
        }
    }
}
259
impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Collapse rig assistant content (text, tool calls, reasoning) into a
    /// single OpenRouter assistant message.
    ///
    /// Tool-call `additional_params` are probed for OpenRouter reasoning
    /// details so they can be round-tripped back to the provider.
    ///
    /// # Errors
    /// Fails with `ConversionError` if the content contains an image.
    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
        let mut text_content = Vec::new();
        let mut tool_calls = Vec::new();
        let mut reasoning = None;
        let mut reasoning_details = Vec::new();

        for content in value.into_iter() {
            match content {
                message::AssistantContent::Text(text) => text_content.push(text),
                message::AssistantContent::ToolCall(tool_call) => {
                    // If the tool call carries extra params, try to recover
                    // reasoning details stored there: either a full object or
                    // a minimal id/format pair.
                    if let Some(additional_params) = &tool_call.additional_params
                        && let Ok(additional_params) =
                            serde_json::from_value::<ToolCallAdditionalParams>(
                                additional_params.clone(),
                            )
                    {
                        match additional_params {
                            ToolCallAdditionalParams::ReasoningDetails(full) => {
                                reasoning_details.push(full);
                            }
                            ToolCallAdditionalParams::Minimal { id, format } => {
                                // Fall back to the call id; an encrypted
                                // detail is emitted only when both an id and a
                                // signature are present.
                                let id = id.or_else(|| tool_call.call_id.clone());
                                if let Some(signature) = &tool_call.signature
                                    && let Some(id) = id
                                {
                                    reasoning_details.push(ReasoningDetails::Encrypted {
                                        id: Some(id),
                                        format,
                                        index: None,
                                        data: signature.clone(),
                                    })
                                }
                            }
                        }
                    } else if let Some(signature) = &tool_call.signature {
                        // No parseable additional params: preserve the raw
                        // signature as encrypted reasoning data.
                        reasoning_details.push(ReasoningDetails::Encrypted {
                            id: tool_call.call_id.clone(),
                            format: None,
                            index: None,
                            data: signature.clone(),
                        });
                    }
                    tool_calls.push(tool_call.into())
                }
                message::AssistantContent::Reasoning(r) => {
                    // NOTE(review): only the first entry of `r.reasoning` is
                    // kept, and a later Reasoning item overwrites an earlier
                    // one — confirm this lossy behavior is intended.
                    reasoning = r.reasoning.into_iter().next();
                }
                message::AssistantContent::Image(_) => {
                    return Err(Self::Error::ConversionError(
                        "OpenRouter currently doesn't support images.".into(),
                    ));
                }
            }
        }

        // Everything folds into exactly one assistant message.
        Ok(vec![Message::Assistant {
            content: text_content
                .into_iter()
                .map(|content| content.text.into())
                .collect::<Vec<_>>(),
            refusal: None,
            audio: None,
            name: None,
            tool_calls,
            reasoning,
            reasoning_details,
        }])
    }
}
339
340impl TryFrom<message::Message> for Vec<Message> {
343 type Error = message::MessageError;
344
345 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
346 match message {
347 message::Message::User { content } => {
348 let messages: Vec<openai::Message> = content.try_into()?;
349 Ok(messages.into_iter().map(Message::from).collect::<Vec<_>>())
350 }
351 message::Message::Assistant { content, .. } => content.try_into(),
352 }
353 }
354}
355
/// Tool-choice setting in OpenRouter's format.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants (`None`, `Auto`,
/// `Required`) serialize to JSON `null` rather than the strings
/// `"none"`/`"auto"`/`"required"`, and `rename_all` has no effect on untagged
/// unit variants — verify against the intended wire format. The request struct
/// itself uses the `openai` `ToolChoice` type instead of this one.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged, rename_all = "snake_case")]
pub enum ToolChoice {
    None,
    Auto,
    Required,
    /// Restrict the model to a specific set of functions.
    Function(Vec<ToolChoiceFunctionKind>),
}
364
365impl TryFrom<crate::message::ToolChoice> for ToolChoice {
366 type Error = CompletionError;
367
368 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
369 let res = match value {
370 crate::message::ToolChoice::None => Self::None,
371 crate::message::ToolChoice::Auto => Self::Auto,
372 crate::message::ToolChoice::Required => Self::Required,
373 crate::message::ToolChoice::Specific { function_names } => {
374 let vec: Vec<ToolChoiceFunctionKind> = function_names
375 .into_iter()
376 .map(|name| ToolChoiceFunctionKind::Function { name })
377 .collect();
378
379 Self::Function(vec)
380 }
381 };
382
383 Ok(res)
384 }
385}
386
/// Function selector used by [`ToolChoice::Function`]; serializes as
/// `{"type": "Function", "function": {"name": ...}}`.
///
/// NOTE(review): the tag is the variant name `Function` (capitalized, no
/// `rename`) — confirm the API expects this casing.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "function")]
pub enum ToolChoiceFunctionKind {
    Function { name: String },
}
392
/// JSON body sent to OpenRouter's `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OpenrouterCompletionRequest {
    model: String,
    /// Full conversation: system preamble, documents, then chat history.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<crate::providers::openai::completion::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra provider parameters, flattened into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
406
/// Inputs needed to build an [`OpenrouterCompletionRequest`].
pub struct OpenRouterRequestParams<'a> {
    /// Model slug, e.g. `"google/gemini-2.0-flash-001"`.
    pub model: &'a str,
    /// The provider-agnostic completion request to translate.
    pub request: CompletionRequest,
    /// When true, tool definitions are emitted in strict JSON-schema mode.
    pub strict_tools: bool,
}
413
414impl TryFrom<OpenRouterRequestParams<'_>> for OpenrouterCompletionRequest {
415 type Error = CompletionError;
416
417 fn try_from(params: OpenRouterRequestParams) -> Result<Self, Self::Error> {
418 let OpenRouterRequestParams {
419 model,
420 request: req,
421 strict_tools,
422 } = params;
423
424 let mut full_history: Vec<Message> = match &req.preamble {
425 Some(preamble) => vec![Message::system(preamble)],
426 None => vec![],
427 };
428 if let Some(docs) = req.normalized_documents() {
429 let docs: Vec<Message> = docs.try_into()?;
430 full_history.extend(docs);
431 }
432
433 let chat_history: Vec<Message> = req
434 .chat_history
435 .clone()
436 .into_iter()
437 .map(|message| message.try_into())
438 .collect::<Result<Vec<Vec<Message>>, _>>()?
439 .into_iter()
440 .flatten()
441 .collect();
442
443 full_history.extend(chat_history);
444
445 let tool_choice = req
446 .tool_choice
447 .clone()
448 .map(crate::providers::openai::completion::ToolChoice::try_from)
449 .transpose()?;
450
451 let tools: Vec<crate::providers::openai::completion::ToolDefinition> = req
452 .tools
453 .clone()
454 .into_iter()
455 .map(|tool| {
456 let def = crate::providers::openai::completion::ToolDefinition::from(tool);
457 if strict_tools { def.with_strict() } else { def }
458 })
459 .collect();
460
461 Ok(Self {
462 model: model.to_string(),
463 messages: full_history,
464 temperature: req.temperature,
465 tools,
466 tool_choice,
467 additional_params: req.additional_params,
468 })
469 }
470}
471
472impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
473 type Error = CompletionError;
474
475 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
476 OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
477 model,
478 request: req,
479 strict_tools: false,
480 })
481 }
482}
483
/// OpenRouter completion model handle, generic over the HTTP client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model slug sent as the `model` field of each request.
    pub model: String,
    /// Whether tool definitions are emitted with strict JSON schemas.
    pub strict_tools: bool,
}
492
493impl<T> CompletionModel<T> {
494 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
495 Self {
496 client,
497 model: model.into(),
498 strict_tools: false,
499 }
500 }
501
502 pub fn with_strict_tools(mut self) -> Self {
511 self.strict_tools = true;
512 self
513 }
514}
515
516impl<T> completion::CompletionModel for CompletionModel<T>
517where
518 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
519{
520 type Response = CompletionResponse;
521 type StreamingResponse = StreamingCompletionResponse;
522
523 type Client = Client<T>;
524
525 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
526 Self::new(client.clone(), model)
527 }
528
529 async fn completion(
530 &self,
531 completion_request: CompletionRequest,
532 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
533 let preamble = completion_request.preamble.clone();
534 let request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
535 model: self.model.as_ref(),
536 request: completion_request,
537 strict_tools: self.strict_tools,
538 })?;
539
540 if enabled!(Level::TRACE) {
541 tracing::trace!(
542 target: "rig::completions",
543 "OpenRouter completion request: {}",
544 serde_json::to_string_pretty(&request)?
545 );
546 }
547
548 let span = if tracing::Span::current().is_disabled() {
549 info_span!(
550 target: "rig::completions",
551 "chat",
552 gen_ai.operation.name = "chat",
553 gen_ai.provider.name = "openrouter",
554 gen_ai.request.model = self.model,
555 gen_ai.system_instructions = preamble,
556 gen_ai.response.id = tracing::field::Empty,
557 gen_ai.response.model = tracing::field::Empty,
558 gen_ai.usage.output_tokens = tracing::field::Empty,
559 gen_ai.usage.input_tokens = tracing::field::Empty,
560 )
561 } else {
562 tracing::Span::current()
563 };
564
565 let body = serde_json::to_vec(&request)?;
566
567 let req = self
568 .client
569 .post("/chat/completions")?
570 .body(body)
571 .map_err(|x| CompletionError::HttpError(x.into()))?;
572
573 async move {
574 let response = self.client.send::<_, Bytes>(req).await?;
575 let status = response.status();
576 let response_body = response.into_body().into_future().await?.to_vec();
577
578 if status.is_success() {
579 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
580 ApiResponse::Ok(response) => {
581 let span = tracing::Span::current();
582 span.record_token_usage(&response.usage);
583 span.record("gen_ai.response.id", &response.id);
584 span.record("gen_ai.response.model_name", &response.model);
585
586 tracing::debug!(target: "rig::completions",
587 "OpenRouter response: {response:?}");
588 response.try_into()
589 }
590 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
591 }
592 } else {
593 Err(CompletionError::ProviderError(
594 String::from_utf8_lossy(&response_body).to_string(),
595 ))
596 }
597 }
598 .instrument(span)
599 .await
600 }
601
602 async fn stream(
603 &self,
604 completion_request: CompletionRequest,
605 ) -> Result<
606 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
607 CompletionError,
608 > {
609 CompletionModel::stream(self, completion_request).await
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A real-shaped Gemini Flash response deserializes into
    /// `CompletionResponse`; unknown fields (`provider`, `logprobs`) and a
    /// string-valued `content` must be accepted.
    #[test]
    fn test_completion_response_deserialization_gemini_flash() {
        let json = json!({
            "id": "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA",
            "provider": "Google",
            "model": "google/gemini-2.5-flash",
            "object": "chat.completion",
            "created": 1765971703u64,
            "choices": [{
                "logprobs": null,
                "finish_reason": "stop",
                "native_finish_reason": "STOP",
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "CONTENT",
                    "refusal": null,
                    "reasoning": null
                }
            }],
            "usage": {
                "prompt_tokens": 669,
                "completion_tokens": 5,
                "total_tokens": 674
            }
        });

        let response: CompletionResponse = serde_json::from_value(json).unwrap();
        assert_eq!(response.id, "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA");
        assert_eq!(response.model, "google/gemini-2.5-flash");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].finish_reason, Some("stop".to_string()));
    }

    /// An assistant message without `reasoning_details` defaults that field
    /// to empty, and a string `content` becomes a single content part.
    #[test]
    fn test_message_assistant_without_reasoning_details() {
        let json = json!({
            "role": "assistant",
            "content": "Hello world",
            "refusal": null,
            "reasoning": null
        });

        let message: Message = serde_json::from_value(json).unwrap();
        match message {
            Message::Assistant {
                content,
                reasoning_details,
                ..
            } => {
                assert_eq!(content.len(), 1);
                assert!(reasoning_details.is_empty());
            }
            _ => panic!("Expected Assistant message"),
        }
    }
}