//! OpenRouter completion provider (`rig/providers/openrouter/completion.rs`).

1use super::{
2    client::{ApiErrorResponse, ApiResponse, Client, Usage},
3    streaming::StreamingCompletionResponse,
4};
5use crate::message;
6use crate::telemetry::SpanCombinator;
7use crate::{
8    OneOrMany,
9    completion::{self, CompletionError, CompletionRequest},
10    http_client::HttpClientExt,
11    json_utils,
12    one_or_many::string_or_one_or_many,
13    providers::openai,
14};
15use bytes::Bytes;
16use serde::{Deserialize, Serialize};
17use tracing::{Instrument, Level, enabled, info_span};
18
19// ================================================================
20// OpenRouter Completion API
21// ================================================================
22
// Convenience identifiers for a few popular OpenRouter-hosted models; any
// model id from <https://openrouter.ai/models> may be passed as a string.
/// The `qwen/qwq-32b` model. Find more models at <https://openrouter.ai/models>.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// The `anthropic/claude-3.7-sonnet` model. Find more models at <https://openrouter.ai/models>.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// The `perplexity/sonar-pro` model. Find more models at <https://openrouter.ai/models>.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// The `google/gemini-2.0-flash-001` model. Find more models at <https://openrouter.ai/models>.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
31
/// A openrouter completion object.
///
/// For more information, see this link: <https://docs.openrouter.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Unique identifier assigned to this completion by OpenRouter.
    pub id: String,
    /// Object type tag returned by the API.
    pub object: String,
    /// Creation time reported by the API (presumably Unix epoch seconds — not consumed by rig).
    pub created: u64,
    /// The model that actually served the request.
    pub model: String,
    /// Generated completions; the `TryFrom` conversion below only consumes the first.
    pub choices: Vec<Choice>,
    /// Backend configuration fingerprint, when the provider reports one.
    pub system_fingerprint: Option<String>,
    /// Token accounting; absent for some providers/responses.
    pub usage: Option<Usage>,
}
45
46impl From<ApiErrorResponse> for CompletionError {
47    fn from(err: ApiErrorResponse) -> Self {
48        CompletionError::ProviderError(err.message)
49    }
50}
51
52impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
53    type Error = CompletionError;
54
55    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
56        let choice = response.choices.first().ok_or_else(|| {
57            CompletionError::ResponseError("Response contained no choices".to_owned())
58        })?;
59
60        let content = match &choice.message {
61            Message::Assistant {
62                content,
63                tool_calls,
64                reasoning,
65                ..
66            } => {
67                let mut content = content
68                    .iter()
69                    .map(|c| match c {
70                        openai::AssistantContent::Text { text } => {
71                            completion::AssistantContent::text(text)
72                        }
73                        openai::AssistantContent::Refusal { refusal } => {
74                            completion::AssistantContent::text(refusal)
75                        }
76                    })
77                    .collect::<Vec<_>>();
78
79                content.extend(
80                    tool_calls
81                        .iter()
82                        .map(|call| {
83                            completion::AssistantContent::tool_call(
84                                &call.id,
85                                &call.function.name,
86                                call.function.arguments.clone(),
87                            )
88                        })
89                        .collect::<Vec<_>>(),
90                );
91
92                if let Some(reasoning) = reasoning {
93                    content.push(completion::AssistantContent::reasoning(reasoning));
94                }
95
96                Ok(content)
97            }
98            _ => Err(CompletionError::ResponseError(
99                "Response did not contain a valid message or tool call".into(),
100            )),
101        }?;
102
103        let choice = OneOrMany::many(content).map_err(|_| {
104            CompletionError::ResponseError(
105                "Response contained no message or tool call (empty)".to_owned(),
106            )
107        })?;
108
109        let usage = response
110            .usage
111            .as_ref()
112            .map(|usage| completion::Usage {
113                input_tokens: usage.prompt_tokens as u64,
114                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
115                total_tokens: usage.total_tokens as u64,
116            })
117            .unwrap_or_default();
118
119        Ok(completion::CompletionResponse {
120            choice,
121            usage,
122            raw_response: response,
123        })
124    }
125}
126
/// A single completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice within the response's `choices` array.
    pub index: usize,
    /// Finish reason as reported by the underlying (native) provider, if any.
    pub native_finish_reason: Option<String>,
    /// The generated message for this choice.
    pub message: Message,
    /// Normalized finish reason reported by OpenRouter, if any.
    pub finish_reason: Option<String>,
}
134
135/// OpenRouter message.
136///
137/// Almost identical to OpenAI's Message, but supports more parameters
138/// for some providers like `reasoning`.
139#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
140#[serde(tag = "role", rename_all = "lowercase")]
141pub enum Message {
142    #[serde(alias = "developer")]
143    System {
144        #[serde(deserialize_with = "string_or_one_or_many")]
145        content: OneOrMany<openai::SystemContent>,
146        #[serde(skip_serializing_if = "Option::is_none")]
147        name: Option<String>,
148    },
149    User {
150        #[serde(deserialize_with = "string_or_one_or_many")]
151        content: OneOrMany<openai::UserContent>,
152        #[serde(skip_serializing_if = "Option::is_none")]
153        name: Option<String>,
154    },
155    Assistant {
156        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
157        content: Vec<openai::AssistantContent>,
158        #[serde(skip_serializing_if = "Option::is_none")]
159        refusal: Option<String>,
160        #[serde(skip_serializing_if = "Option::is_none")]
161        audio: Option<openai::AudioAssistant>,
162        #[serde(skip_serializing_if = "Option::is_none")]
163        name: Option<String>,
164        #[serde(
165            default,
166            deserialize_with = "json_utils::null_or_vec",
167            skip_serializing_if = "Vec::is_empty"
168        )]
169        tool_calls: Vec<openai::ToolCall>,
170        #[serde(skip_serializing_if = "Option::is_none")]
171        reasoning: Option<String>,
172        #[serde(skip_serializing_if = "Vec::is_empty")]
173        reasoning_details: Vec<ReasoningDetails>,
174    },
175    #[serde(rename = "tool")]
176    ToolResult {
177        tool_call_id: String,
178        content: String,
179    },
180}
181
182impl Message {
183    pub fn system(content: &str) -> Self {
184        Message::System {
185            content: OneOrMany::one(content.to_owned().into()),
186            name: None,
187        }
188    }
189}
190
/// A structured reasoning block attached to an assistant message.
///
/// Tagged on the wire by a `type` field using dotted names
/// (`reasoning.summary`, `reasoning.encrypted`, `reasoning.text`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningDetails {
    /// A human-readable summary of the model's reasoning.
    #[serde(rename = "reasoning.summary")]
    Summary {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        summary: String,
    },
    /// Opaque encrypted reasoning data (e.g. a provider signature) that must
    /// be echoed back to the provider on later turns.
    #[serde(rename = "reasoning.encrypted")]
    Encrypted {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        data: String,
    },
    /// Raw reasoning text, optionally accompanied by a signature.
    #[serde(rename = "reasoning.text")]
    Text {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        text: Option<String>,
        signature: Option<String>,
    },
}
217
/// Shape of the `additional_params` JSON attached to a tool call.
///
/// Untagged: variants are tried in declaration order, so the more specific
/// [`ReasoningDetails`] form is attempted first; `Minimal` (both fields
/// optional) acts as the permissive fallback and matches most objects.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(untagged)]
enum ToolCallAdditionalParams {
    ReasoningDetails(ReasoningDetails),
    Minimal {
        id: Option<String>,
        format: Option<String>,
    },
}
227
228impl From<openai::Message> for Message {
229    fn from(value: openai::Message) -> Self {
230        match value {
231            openai::Message::System { content, name } => Self::System { content, name },
232            openai::Message::User { content, name } => Self::User { content, name },
233            openai::Message::Assistant {
234                content,
235                refusal,
236                audio,
237                name,
238                tool_calls,
239            } => Self::Assistant {
240                content,
241                refusal,
242                audio,
243                name,
244                tool_calls,
245                reasoning: None,
246                reasoning_details: Vec::new(),
247            },
248            openai::Message::ToolResult {
249                tool_call_id,
250                content,
251            } => Self::ToolResult {
252                tool_call_id,
253                content: content.as_text(),
254            },
255        }
256    }
257}
258
impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Collapse rig assistant content (text, tool calls, reasoning) into a
    /// single OpenRouter assistant [`Message`].
    ///
    /// Reasoning metadata attached to tool calls is re-encoded into
    /// `reasoning_details` because some providers reached through OpenRouter
    /// require it to be echoed back on subsequent turns.
    ///
    /// # Errors
    /// Returns `ConversionError` if any content item is an image, which
    /// OpenRouter assistant messages do not support.
    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
        let mut text_content = Vec::new();
        let mut tool_calls = Vec::new();
        // Single reasoning string; a later Reasoning item overwrites an earlier one.
        let mut reasoning = None;
        let mut reasoning_details = Vec::new();

        for content in value.into_iter() {
            match content {
                message::AssistantContent::Text(text) => text_content.push(text),
                message::AssistantContent::ToolCall(tool_call) => {
                    // We usually want to provide back the reasoning to OpenRouter since some
                    // providers require it.
                    // 1. Full reasoning details passed back the user
                    // 2. The signature, an id and a format if present
                    // 3. The signature and the call_id if present
                    if let Some(additional_params) = &tool_call.additional_params
                        && let Ok(additional_params) =
                            serde_json::from_value::<ToolCallAdditionalParams>(
                                additional_params.clone(),
                            )
                    {
                        match additional_params {
                            ToolCallAdditionalParams::ReasoningDetails(full) => {
                                // Case 1: the caller round-tripped the full details.
                                reasoning_details.push(full);
                            }
                            ToolCallAdditionalParams::Minimal { id, format } => {
                                // Case 2: rebuild an encrypted entry from the signature,
                                // preferring the explicit id over the tool call's id.
                                let id = id.or_else(|| tool_call.call_id.clone());
                                if let Some(signature) = &tool_call.signature
                                    && let Some(id) = id
                                {
                                    reasoning_details.push(ReasoningDetails::Encrypted {
                                        id: Some(id),
                                        format,
                                        index: None,
                                        data: signature.clone(),
                                    })
                                }
                            }
                        }
                    } else if let Some(signature) = &tool_call.signature {
                        // Case 3: no usable additional params — fall back to
                        // signature plus the call id alone.
                        reasoning_details.push(ReasoningDetails::Encrypted {
                            id: tool_call.call_id.clone(),
                            format: None,
                            index: None,
                            data: signature.clone(),
                        });
                    }
                    tool_calls.push(tool_call.into())
                }
                message::AssistantContent::Reasoning(r) => {
                    // Keep only the first chunk of this reasoning item.
                    reasoning = r.reasoning.into_iter().next();
                }
                message::AssistantContent::Image(_) => {
                    return Err(Self::Error::ConversionError(
                        "OpenRouter currently doesn't support images.".into(),
                    ));
                }
            }
        }

        // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
        //  so either `content` or `tool_calls` will have some content.
        Ok(vec![Message::Assistant {
            content: text_content
                .into_iter()
                .map(|content| content.text.into())
                .collect::<Vec<_>>(),
            refusal: None,
            audio: None,
            name: None,
            tool_calls,
            reasoning,
            reasoning_details,
        }])
    }
}
338
339// We re-use most of the openai implementation when we can and we re-implement
340// only the part that differentate for openrouter (like reasoning support).
341impl TryFrom<message::Message> for Vec<Message> {
342    type Error = message::MessageError;
343
344    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
345        match message {
346            message::Message::User { content } => {
347                let messages: Vec<openai::Message> = content.try_into()?;
348                Ok(messages.into_iter().map(Message::from).collect::<Vec<_>>())
349            }
350            message::Message::Assistant { content, .. } => content.try_into(),
351        }
352    }
353}
354
/// Local OpenRouter tool-choice representation.
///
/// NOTE(review): with `#[serde(untagged)]`, unit variants serialize as JSON
/// `null` rather than the strings "none"/"auto"/"required", and `rename_all`
/// has no effect on an untagged enum — confirm the intended wire format.
/// The request builder below currently sends the OpenAI `ToolChoice` type
/// instead, so this enum appears unused on the request path.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged, rename_all = "snake_case")]
pub enum ToolChoice {
    /// Disallow tool calls.
    None,
    /// Let the model decide whether to call tools.
    Auto,
    /// Force the model to call at least one tool.
    Required,
    /// Restrict the model to the named functions.
    Function(Vec<ToolChoiceFunctionKind>),
}
363
364impl TryFrom<crate::message::ToolChoice> for ToolChoice {
365    type Error = CompletionError;
366
367    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
368        let res = match value {
369            crate::message::ToolChoice::None => Self::None,
370            crate::message::ToolChoice::Auto => Self::Auto,
371            crate::message::ToolChoice::Required => Self::Required,
372            crate::message::ToolChoice::Specific { function_names } => {
373                let vec: Vec<ToolChoiceFunctionKind> = function_names
374                    .into_iter()
375                    .map(|name| ToolChoiceFunctionKind::Function { name })
376                    .collect();
377
378                Self::Function(vec)
379            }
380        };
381
382        Ok(res)
383    }
384}
385
/// A named function reference used by [`ToolChoice::Function`].
///
/// Serializes as `{"type": "Function", "function": {"name": ...}}` via the
/// adjacently-tagged representation.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "function")]
pub enum ToolChoiceFunctionKind {
    Function { name: String },
}
391
/// Wire-format body for OpenRouter's `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct OpenrouterCompletionRequest {
    /// Target model identifier, e.g. "qwen/qwq-32b".
    model: String,
    /// Full conversation: preamble, documents, then chat history (in that order).
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions (OpenAI-compatible); omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<crate::providers::openai::completion::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra provider parameters, flattened into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
405
/// Parameters for building an OpenRouter CompletionRequest
pub struct OpenRouterRequestParams<'a> {
    /// Model identifier to send with the request.
    pub model: &'a str,
    /// The provider-agnostic request to translate.
    pub request: CompletionRequest,
    /// Whether tool schemas should be sanitized for OpenAI strict mode.
    pub strict_tools: bool,
}
412
413impl TryFrom<OpenRouterRequestParams<'_>> for OpenrouterCompletionRequest {
414    type Error = CompletionError;
415
416    fn try_from(params: OpenRouterRequestParams) -> Result<Self, Self::Error> {
417        let OpenRouterRequestParams {
418            model,
419            request: req,
420            strict_tools,
421        } = params;
422
423        let mut full_history: Vec<Message> = match &req.preamble {
424            Some(preamble) => vec![Message::system(preamble)],
425            None => vec![],
426        };
427        if let Some(docs) = req.normalized_documents() {
428            let docs: Vec<Message> = docs.try_into()?;
429            full_history.extend(docs);
430        }
431
432        let chat_history: Vec<Message> = req
433            .chat_history
434            .clone()
435            .into_iter()
436            .map(|message| message.try_into())
437            .collect::<Result<Vec<Vec<Message>>, _>>()?
438            .into_iter()
439            .flatten()
440            .collect();
441
442        full_history.extend(chat_history);
443
444        let tool_choice = req
445            .tool_choice
446            .clone()
447            .map(crate::providers::openai::completion::ToolChoice::try_from)
448            .transpose()?;
449
450        let tools: Vec<crate::providers::openai::completion::ToolDefinition> = req
451            .tools
452            .clone()
453            .into_iter()
454            .map(|tool| {
455                let def = crate::providers::openai::completion::ToolDefinition::from(tool);
456                if strict_tools { def.with_strict() } else { def }
457            })
458            .collect();
459
460        Ok(Self {
461            model: model.to_string(),
462            messages: full_history,
463            temperature: req.temperature,
464            tools,
465            tool_choice,
466            additional_params: req.additional_params,
467        })
468    }
469}
470
471impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
472    type Error = CompletionError;
473
474    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
475        OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
476            model,
477            request: req,
478            strict_tools: false,
479        })
480    }
481}
482
/// An OpenRouter completion model handle: a client paired with a model id.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    // The model identifier sent with every request (e.g. "qwen/qwq-32b").
    pub model: String,
    /// Enable strict mode for tool schemas.
    /// When enabled, tool schemas are sanitized to meet OpenAI's strict mode requirements.
    pub strict_tools: bool,
}
491
492impl<T> CompletionModel<T> {
493    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
494        Self {
495            client,
496            model: model.into(),
497            strict_tools: false,
498        }
499    }
500
501    /// Enable strict mode for tool schemas.
502    ///
503    /// When enabled, tool schemas are automatically sanitized to meet OpenAI's strict mode requirements:
504    /// - `additionalProperties: false` is added to all objects
505    /// - All properties are marked as required
506    /// - `strict: true` is set on each function definition
507    ///
508    /// Note: Not all models on OpenRouter support strict mode. This works best with OpenAI models.
509    pub fn with_strict_tools(mut self) -> Self {
510        self.strict_tools = true;
511        self
512    }
513}
514
515impl<T> completion::CompletionModel for CompletionModel<T>
516where
517    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
518{
519    type Response = CompletionResponse;
520    type StreamingResponse = StreamingCompletionResponse;
521
522    type Client = Client<T>;
523
524    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
525        Self::new(client.clone(), model)
526    }
527
528    async fn completion(
529        &self,
530        completion_request: CompletionRequest,
531    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
532        let preamble = completion_request.preamble.clone();
533        let request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
534            model: self.model.as_ref(),
535            request: completion_request,
536            strict_tools: self.strict_tools,
537        })?;
538
539        if enabled!(Level::TRACE) {
540            tracing::trace!(
541                target: "rig::completions",
542                "OpenRouter completion request: {}",
543                serde_json::to_string_pretty(&request)?
544            );
545        }
546
547        let span = if tracing::Span::current().is_disabled() {
548            info_span!(
549                target: "rig::completions",
550                "chat",
551                gen_ai.operation.name = "chat",
552                gen_ai.provider.name = "openrouter",
553                gen_ai.request.model = self.model,
554                gen_ai.system_instructions = preamble,
555                gen_ai.response.id = tracing::field::Empty,
556                gen_ai.response.model = tracing::field::Empty,
557                gen_ai.usage.output_tokens = tracing::field::Empty,
558                gen_ai.usage.input_tokens = tracing::field::Empty,
559            )
560        } else {
561            tracing::Span::current()
562        };
563
564        let body = serde_json::to_vec(&request)?;
565
566        let req = self
567            .client
568            .post("/chat/completions")?
569            .body(body)
570            .map_err(|x| CompletionError::HttpError(x.into()))?;
571
572        async move {
573            let response = self.client.send::<_, Bytes>(req).await?;
574            let status = response.status();
575            let response_body = response.into_body().into_future().await?.to_vec();
576
577            if status.is_success() {
578                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
579                    ApiResponse::Ok(response) => {
580                        let span = tracing::Span::current();
581                        span.record_token_usage(&response.usage);
582                        span.record("gen_ai.response.id", &response.id);
583                        span.record("gen_ai.response.model_name", &response.model);
584
585                        tracing::debug!(target: "rig::completions",
586                            "OpenRouter response: {response:?}");
587                        response.try_into()
588                    }
589                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
590                }
591            } else {
592                Err(CompletionError::ProviderError(
593                    String::from_utf8_lossy(&response_body).to_string(),
594                ))
595            }
596        }
597        .instrument(span)
598        .await
599    }
600
601    async fn stream(
602        &self,
603        completion_request: CompletionRequest,
604    ) -> Result<
605        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
606        CompletionError,
607    > {
608        CompletionModel::stream(self, completion_request).await
609    }
610}