rig/providers/openrouter/
completion.rs

1use super::{
2    client::{ApiErrorResponse, ApiResponse, Client, Usage},
3    streaming::StreamingCompletionResponse,
4};
5use crate::message;
6use crate::telemetry::SpanCombinator;
7use crate::{
8    OneOrMany,
9    completion::{self, CompletionError, CompletionRequest},
10    http_client::HttpClientExt,
11    json_utils,
12    one_or_many::string_or_one_or_many,
13    providers::openai,
14};
15use bytes::Bytes;
16use serde::{Deserialize, Serialize};
17use tracing::{Instrument, info_span};
18
19// ================================================================
20// OpenRouter Completion API
21// ================================================================
22
// Commonly used model identifiers; pass any of these as the `model` argument
// when constructing a `CompletionModel`.
/// The `qwen/qwq-32b` model. Find more models at <https://openrouter.ai/models>.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// The `anthropic/claude-3.7-sonnet` model. Find more models at <https://openrouter.ai/models>.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// The `perplexity/sonar-pro` model. Find more models at <https://openrouter.ai/models>.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// The `google/gemini-2.0-flash-001` model. Find more models at <https://openrouter.ai/models>.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
31
/// An OpenRouter completion object.
///
/// For more information, see this link: <https://docs.openrouter.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Response identifier assigned by OpenRouter.
    pub id: String,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// Creation timestamp (Unix seconds, per OpenAI-compatible APIs — confirm against provider docs).
    pub created: u64,
    /// The model that actually served the request.
    pub model: String,
    /// One or more completion choices; conversion below only reads the first.
    pub choices: Vec<Choice>,
    pub system_fingerprint: Option<String>,
    /// Token usage; absent for some providers/responses.
    pub usage: Option<Usage>,
}
45
46impl From<ApiErrorResponse> for CompletionError {
47    fn from(err: ApiErrorResponse) -> Self {
48        CompletionError::ProviderError(err.message)
49    }
50}
51
52impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
53    type Error = CompletionError;
54
55    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
56        let choice = response.choices.first().ok_or_else(|| {
57            CompletionError::ResponseError("Response contained no choices".to_owned())
58        })?;
59
60        let content = match &choice.message {
61            Message::Assistant {
62                content,
63                tool_calls,
64                reasoning,
65                ..
66            } => {
67                let mut content = content
68                    .iter()
69                    .map(|c| match c {
70                        openai::AssistantContent::Text { text } => {
71                            completion::AssistantContent::text(text)
72                        }
73                        openai::AssistantContent::Refusal { refusal } => {
74                            completion::AssistantContent::text(refusal)
75                        }
76                    })
77                    .collect::<Vec<_>>();
78
79                content.extend(
80                    tool_calls
81                        .iter()
82                        .map(|call| {
83                            completion::AssistantContent::tool_call(
84                                &call.id,
85                                &call.function.name,
86                                call.function.arguments.clone(),
87                            )
88                        })
89                        .collect::<Vec<_>>(),
90                );
91
92                if let Some(reasoning) = reasoning {
93                    content.push(completion::AssistantContent::reasoning(reasoning));
94                }
95
96                Ok(content)
97            }
98            _ => Err(CompletionError::ResponseError(
99                "Response did not contain a valid message or tool call".into(),
100            )),
101        }?;
102
103        let choice = OneOrMany::many(content).map_err(|_| {
104            CompletionError::ResponseError(
105                "Response contained no message or tool call (empty)".to_owned(),
106            )
107        })?;
108
109        let usage = response
110            .usage
111            .as_ref()
112            .map(|usage| completion::Usage {
113                input_tokens: usage.prompt_tokens as u64,
114                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
115                total_tokens: usage.total_tokens as u64,
116            })
117            .unwrap_or_default();
118
119        Ok(completion::CompletionResponse {
120            choice,
121            usage,
122            raw_response: response,
123        })
124    }
125}
126
/// A single completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice in the response's `choices` array.
    pub index: usize,
    /// Finish reason as reported by the underlying provider
    /// (OpenRouter-specific field — presumably the provider's raw value).
    pub native_finish_reason: Option<String>,
    /// The message produced for this choice.
    pub message: Message,
    /// Normalized finish reason, if provided.
    pub finish_reason: Option<String>,
}
134
/// OpenRouter message.
///
/// Almost identical to OpenAI's Message, but supports more parameters
/// for some providers like `reasoning`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System prompt; the `developer` role is accepted as an alias on input.
    #[serde(alias = "developer")]
    System {
        // Accepts either a bare string or a list of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user turn.
    User {
        // Accepts either a bare string or a list of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model turn; may combine text, tool calls and provider reasoning.
    Assistant {
        // Accepts a bare string or a list of parts; missing field → empty vec.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<openai::AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<openai::AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // JSON `null` deserializes to an empty vec; empty vecs are omitted
        // when serializing.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<openai::ToolCall>,
        /// Provider-specific reasoning text (the extension over OpenAI).
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning: Option<String>,
    },
    /// Result of a previous tool call, keyed by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<openai::ToolResultContent>,
    },
}
179
180impl Message {
181    pub fn system(content: &str) -> Self {
182        Message::System {
183            content: OneOrMany::one(content.to_owned().into()),
184            name: None,
185        }
186    }
187}
188
189impl From<openai::Message> for Message {
190    fn from(value: openai::Message) -> Self {
191        match value {
192            openai::Message::System { content, name } => Self::System { content, name },
193            openai::Message::User { content, name } => Self::User { content, name },
194            openai::Message::Assistant {
195                content,
196                refusal,
197                audio,
198                name,
199                tool_calls,
200            } => Self::Assistant {
201                content,
202                refusal,
203                audio,
204                name,
205                tool_calls,
206                reasoning: None,
207            },
208            openai::Message::ToolResult {
209                tool_call_id,
210                content,
211            } => Self::ToolResult {
212                tool_call_id,
213                content,
214            },
215        }
216    }
217}
218
219impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
220    type Error = message::MessageError;
221
222    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
223        let mut text_content = Vec::new();
224        let mut tool_calls = Vec::new();
225        let mut reasoning = None;
226
227        for content in value.into_iter() {
228            match content {
229                message::AssistantContent::Text(text) => text_content.push(text),
230                message::AssistantContent::ToolCall(tool_call) => tool_calls.push(tool_call),
231                message::AssistantContent::Reasoning(r) => {
232                    reasoning = r.reasoning.into_iter().next();
233                }
234                message::AssistantContent::Image(_) => {
235                    return Err(Self::Error::ConversionError(
236                        "OpenRouter currently doesn't support images.".into(),
237                    ));
238                }
239            }
240        }
241
242        // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
243        //  so either `content` or `tool_calls` will have some content.
244        Ok(vec![Message::Assistant {
245            content: text_content
246                .into_iter()
247                .map(|content| content.text.into())
248                .collect::<Vec<_>>(),
249            refusal: None,
250            audio: None,
251            name: None,
252            tool_calls: tool_calls
253                .into_iter()
254                .map(|tool_call| tool_call.into())
255                .collect::<Vec<_>>(),
256            reasoning,
257        }])
258    }
259}
260
261// We re-use most of the openai implementation when we can and we re-implement
262// only the part that differentate for openrouter (like reasoning support).
263impl TryFrom<message::Message> for Vec<Message> {
264    type Error = message::MessageError;
265
266    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
267        match message {
268            message::Message::User { content } => {
269                let messages: Vec<openai::Message> = content.try_into()?;
270                Ok(messages.into_iter().map(Message::from).collect::<Vec<_>>())
271            }
272            message::Message::Assistant { content, .. } => content.try_into(),
273        }
274    }
275}
276
/// Tool-choice directive for an OpenRouter request.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants `None`, `Auto`
/// and `Required` serialize to JSON `null`, not the strings
/// "none"/"auto"/"required" — `rename_all` has no effect on untagged
/// variants (serde documented behavior). The request struct in this module
/// uses `openai::completion::ToolChoice` instead, so this may be latent;
/// confirm before wiring this type into a request body.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged, rename_all = "snake_case")]
pub enum ToolChoice {
    None,
    Auto,
    Required,
    /// Restrict the model to one or more named functions.
    Function(Vec<ToolChoiceFunctionKind>),
}
285
286impl TryFrom<crate::message::ToolChoice> for ToolChoice {
287    type Error = CompletionError;
288
289    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
290        let res = match value {
291            crate::message::ToolChoice::None => Self::None,
292            crate::message::ToolChoice::Auto => Self::Auto,
293            crate::message::ToolChoice::Required => Self::Required,
294            crate::message::ToolChoice::Specific { function_names } => {
295                let vec: Vec<ToolChoiceFunctionKind> = function_names
296                    .into_iter()
297                    .map(|name| ToolChoiceFunctionKind::Function { name })
298                    .collect();
299
300                Self::Function(vec)
301            }
302        };
303
304        Ok(res)
305    }
306}
307
308#[derive(Debug, Serialize, Deserialize)]
309#[serde(tag = "type", content = "function")]
310pub enum ToolChoiceFunctionKind {
311    Function { name: String },
312}
313
314#[derive(Debug, Serialize, Deserialize)]
315pub(super) struct OpenrouterCompletionRequest {
316    model: String,
317    pub messages: Vec<Message>,
318    #[serde(flatten, skip_serializing_if = "Option::is_none")]
319    temperature: Option<f64>,
320    #[serde(skip_serializing_if = "Vec::is_empty")]
321    tools: Vec<crate::providers::openai::completion::ToolDefinition>,
322    #[serde(flatten, skip_serializing_if = "Option::is_none")]
323    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
324    #[serde(flatten, skip_serializing_if = "Option::is_none")]
325    pub additional_params: Option<serde_json::Value>,
326}
327
328impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
329    type Error = CompletionError;
330
331    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
332        let mut full_history: Vec<Message> = match &req.preamble {
333            Some(preamble) => vec![Message::system(preamble)],
334            None => vec![],
335        };
336        if let Some(docs) = req.normalized_documents() {
337            let docs: Vec<Message> = docs.try_into()?;
338            full_history.extend(docs);
339        }
340
341        let chat_history: Vec<Message> = req
342            .chat_history
343            .clone()
344            .into_iter()
345            .map(|message| message.try_into())
346            .collect::<Result<Vec<Vec<Message>>, _>>()?
347            .into_iter()
348            .flatten()
349            .collect();
350
351        full_history.extend(chat_history);
352
353        let tool_choice = req
354            .tool_choice
355            .clone()
356            .map(crate::providers::openai::completion::ToolChoice::try_from)
357            .transpose()?;
358
359        Ok(Self {
360            model: model.to_string(),
361            messages: full_history,
362            temperature: req.temperature,
363            tools: req
364                .tools
365                .clone()
366                .into_iter()
367                .map(crate::providers::openai::completion::ToolDefinition::from)
368                .collect::<Vec<_>>(),
369            tool_choice,
370            additional_params: req.additional_params,
371        })
372    }
373}
374
/// An OpenRouter completion model, generic over the HTTP client `T`
/// (defaults to `reqwest::Client`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Shared OpenRouter API client.
    pub(crate) client: Client<T>,
    /// Model identifier, e.g. [`CLAUDE_3_7_SONNET`].
    pub model: String,
}
380
381impl<T> CompletionModel<T> {
382    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
383        Self {
384            client,
385            model: model.into(),
386        }
387    }
388}
389
390impl<T> completion::CompletionModel for CompletionModel<T>
391where
392    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
393{
394    type Response = CompletionResponse;
395    type StreamingResponse = StreamingCompletionResponse;
396
397    type Client = Client<T>;
398
399    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
400        Self::new(client.clone(), model)
401    }
402
403    #[cfg_attr(feature = "worker", worker::send)]
404    async fn completion(
405        &self,
406        completion_request: CompletionRequest,
407    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
408        let preamble = completion_request.preamble.clone();
409        let request =
410            OpenrouterCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
411        let span = if tracing::Span::current().is_disabled() {
412            info_span!(
413                target: "rig::completions",
414                "chat",
415                gen_ai.operation.name = "chat",
416                gen_ai.provider.name = "openrouter",
417                gen_ai.request.model = self.model,
418                gen_ai.system_instructions = preamble,
419                gen_ai.response.id = tracing::field::Empty,
420                gen_ai.response.model = tracing::field::Empty,
421                gen_ai.usage.output_tokens = tracing::field::Empty,
422                gen_ai.usage.input_tokens = tracing::field::Empty,
423                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
424                gen_ai.output.messages = tracing::field::Empty,
425            )
426        } else {
427            tracing::Span::current()
428        };
429
430        let body = serde_json::to_vec(&request)?;
431
432        let req = self
433            .client
434            .post("/chat/completions")?
435            .body(body)
436            .map_err(|x| CompletionError::HttpError(x.into()))?;
437
438        async move {
439            let response = self.client.send::<_, Bytes>(req).await?;
440            let status = response.status();
441            let response_body = response.into_body().into_future().await?.to_vec();
442
443            if status.is_success() {
444                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
445                    ApiResponse::Ok(response) => {
446                        let span = tracing::Span::current();
447                        span.record_token_usage(&response.usage);
448                        span.record_model_output(&response.choices);
449                        span.record("gen_ai.response.id", &response.id);
450                        span.record("gen_ai.response.model_name", &response.model);
451
452                        tracing::debug!(target: "rig::completions",
453                            "OpenRouter response: {response:?}");
454                        response.try_into()
455                    }
456                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
457                }
458            } else {
459                Err(CompletionError::ProviderError(
460                    String::from_utf8_lossy(&response_body).to_string(),
461                ))
462            }
463        }
464        .instrument(span)
465        .await
466    }
467
468    #[cfg_attr(feature = "worker", worker::send)]
469    async fn stream(
470        &self,
471        completion_request: CompletionRequest,
472    ) -> Result<
473        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
474        CompletionError,
475    > {
476        CompletionModel::stream(self, completion_request).await
477    }
478}