rig/providers/openrouter/
completion.rs

1use serde::{Deserialize, Serialize};
2use tracing::{Instrument, info_span};
3
4use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
5
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    http_client, json_utils,
10    providers::openai::Message,
11};
12use serde_json::{Value, json};
13
14use crate::providers::openai::AssistantContent;
15use crate::providers::openrouter::streaming::FinalCompletionResponse;
16use crate::streaming::StreamingCompletionResponse;
17use crate::telemetry::SpanCombinator;
18
19// ================================================================
20// OpenRouter Completion API
21// ================================================================
/// Model identifier string for `qwen/qwq-32b`.
///
/// Find more models at <https://openrouter.ai/models>.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// Model identifier string for `anthropic/claude-3.7-sonnet`.
///
/// Find more models at <https://openrouter.ai/models>.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// Model identifier string for `perplexity/sonar-pro`.
///
/// Find more models at <https://openrouter.ai/models>.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// Model identifier string for `google/gemini-2.0-flash-001`.
///
/// Find more models at <https://openrouter.ai/models>.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
30
/// An OpenRouter completion object.
///
/// This is the payload deserialized from a successful chat-completion call
/// and later kept as the `raw_response` of the generic completion response.
///
/// For more information, see this link: <https://docs.openrouter.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Identifier of the response; recorded on the tracing span as
    /// `gen_ai.response.id`.
    pub id: String,
    /// Object type string as reported by the provider.
    pub object: String,
    /// Creation time as reported by the provider (presumably a Unix
    /// timestamp in seconds — TODO confirm against the API reference).
    pub created: u64,
    /// Name of the model that produced the response.
    pub model: String,
    /// Candidate completions; only the first one is used when converting
    /// into the generic completion response.
    pub choices: Vec<Choice>,
    /// Optional fingerprint string supplied by the provider.
    pub system_fingerprint: Option<String>,
    /// Token usage, when the provider reports it.
    pub usage: Option<Usage>,
}
44
45impl From<ApiErrorResponse> for CompletionError {
46    fn from(err: ApiErrorResponse) -> Self {
47        CompletionError::ProviderError(err.message)
48    }
49}
50
51impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
52    type Error = CompletionError;
53
54    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
55        let choice = response.choices.first().ok_or_else(|| {
56            CompletionError::ResponseError("Response contained no choices".to_owned())
57        })?;
58
59        let content = match &choice.message {
60            Message::Assistant {
61                content,
62                tool_calls,
63                ..
64            } => {
65                let mut content = content
66                    .iter()
67                    .map(|c| match c {
68                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
69                        AssistantContent::Refusal { refusal } => {
70                            completion::AssistantContent::text(refusal)
71                        }
72                    })
73                    .collect::<Vec<_>>();
74
75                content.extend(
76                    tool_calls
77                        .iter()
78                        .map(|call| {
79                            completion::AssistantContent::tool_call(
80                                &call.id,
81                                &call.function.name,
82                                call.function.arguments.clone(),
83                            )
84                        })
85                        .collect::<Vec<_>>(),
86                );
87                Ok(content)
88            }
89            _ => Err(CompletionError::ResponseError(
90                "Response did not contain a valid message or tool call".into(),
91            )),
92        }?;
93
94        let choice = OneOrMany::many(content).map_err(|_| {
95            CompletionError::ResponseError(
96                "Response contained no message or tool call (empty)".to_owned(),
97            )
98        })?;
99
100        let usage = response
101            .usage
102            .as_ref()
103            .map(|usage| completion::Usage {
104                input_tokens: usage.prompt_tokens as u64,
105                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
106                total_tokens: usage.total_tokens as u64,
107            })
108            .unwrap_or_default();
109
110        Ok(completion::CompletionResponse {
111            choice,
112            usage,
113            raw_response: response,
114        })
115    }
116}
117
/// A single completion candidate inside a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice in the response's `choices` list.
    pub index: usize,
    /// Finish reason in a provider-native form, when present (presumably the
    /// upstream model provider's own value — TODO confirm).
    pub native_finish_reason: Option<String>,
    /// The message produced for this choice.
    pub message: Message,
    /// Finish reason string, when present.
    pub finish_reason: Option<String>,
}
125
126#[derive(Debug, Serialize, Deserialize)]
127#[serde(untagged, rename_all = "snake_case")]
128pub enum ToolChoice {
129    None,
130    Auto,
131    Required,
132    Function(Vec<ToolChoiceFunctionKind>),
133}
134
135impl TryFrom<crate::message::ToolChoice> for ToolChoice {
136    type Error = CompletionError;
137
138    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
139        let res = match value {
140            crate::message::ToolChoice::None => Self::None,
141            crate::message::ToolChoice::Auto => Self::Auto,
142            crate::message::ToolChoice::Required => Self::Required,
143            crate::message::ToolChoice::Specific { function_names } => {
144                let vec: Vec<ToolChoiceFunctionKind> = function_names
145                    .into_iter()
146                    .map(|name| ToolChoiceFunctionKind::Function { name })
147                    .collect();
148
149                Self::Function(vec)
150            }
151        };
152
153        Ok(res)
154    }
155}
156
157#[derive(Debug, Serialize, Deserialize)]
158#[serde(tag = "type", content = "function")]
159pub enum ToolChoiceFunctionKind {
160    Function { name: String },
161}
162
/// Completion model handle: a client paired with the name of the model that
/// requests are issued against.
///
/// `T` is the type parameter of the wrapped [`Client`] and defaults to
/// `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    pub(crate) client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
169
170impl<T> CompletionModel<T> {
171    pub fn new(client: Client<T>, model: &str) -> Self {
172        Self {
173            client,
174            model: model.to_string(),
175        }
176    }
177
178    pub(crate) fn create_completion_request(
179        &self,
180        completion_request: CompletionRequest,
181    ) -> Result<Value, CompletionError> {
182        // Add preamble to chat history (if available)
183        let mut full_history: Vec<Message> = match &completion_request.preamble {
184            Some(preamble) => vec![Message::system(preamble)],
185            None => vec![],
186        };
187
188        // Gather docs
189        if let Some(docs) = completion_request.normalized_documents() {
190            let docs: Vec<Message> = docs.try_into()?;
191            full_history.extend(docs);
192        }
193
194        // Convert existing chat history
195        let chat_history: Vec<Message> = completion_request
196            .chat_history
197            .into_iter()
198            .map(|message| message.try_into())
199            .collect::<Result<Vec<Vec<Message>>, _>>()?
200            .into_iter()
201            .flatten()
202            .collect();
203
204        // Combine all messages into a single history
205        full_history.extend(chat_history);
206
207        let tool_choice = completion_request
208            .tool_choice
209            .map(ToolChoice::try_from)
210            .transpose()?;
211
212        let request = json!({
213            "model": self.model,
214            "messages": full_history,
215            "temperature": completion_request.temperature,
216            "tools": completion_request.tools.into_iter().map(crate::providers::openai::completion::ToolDefinition::from).collect::<Vec<_>>(),
217            "tool_choice": tool_choice,
218        });
219
220        let request = if let Some(params) = completion_request.additional_params {
221            json_utils::merge(request, params)
222        } else {
223            request
224        };
225
226        Ok(request)
227    }
228}
229
230impl completion::CompletionModel for CompletionModel<reqwest::Client> {
231    type Response = CompletionResponse;
232    type StreamingResponse = FinalCompletionResponse;
233
234    #[cfg_attr(feature = "worker", worker::send)]
235    async fn completion(
236        &self,
237        completion_request: CompletionRequest,
238    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
239        let preamble = completion_request.preamble.clone();
240        let request = self.create_completion_request(completion_request)?;
241        let span = if tracing::Span::current().is_disabled() {
242            info_span!(
243                target: "rig::completion",
244                "chat",
245                gen_ai.operation.name = "chat",
246                gen_ai.provider.name = "openrouter",
247                gen_ai.request.model = self.model,
248                gen_ai.system_instructions = preamble,
249                gen_ai.response.id = tracing::field::Empty,
250                gen_ai.response.model = tracing::field::Empty,
251                gen_ai.usage.output_tokens = tracing::field::Empty,
252                gen_ai.usage.input_tokens = tracing::field::Empty,
253                gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
254                gen_ai.output.messages = tracing::field::Empty,
255            )
256        } else {
257            tracing::Span::current()
258        };
259
260        async move {
261            let response = self
262                .client
263                .reqwest_client()
264                .post("/chat/completions")
265                .json(&request)
266                .send()
267                .await
268                .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?;
269
270            if response.status().is_success() {
271                match response
272                    .json::<ApiResponse<CompletionResponse>>()
273                    .await
274                    .map_err(|e| {
275                        CompletionError::HttpError(http_client::Error::Instance(e.into()))
276                    })? {
277                    ApiResponse::Ok(response) => {
278                        let span = tracing::Span::current();
279                        span.record_token_usage(&response.usage);
280                        span.record_model_output(&response.choices);
281                        span.record("gen_ai.response.id", &response.id);
282                        span.record("gen_ai.response.model_name", &response.model);
283
284                        tracing::debug!(target: "rig::completion",
285                            "OpenRouter response: {response:?}");
286                        response.try_into()
287                    }
288                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
289                }
290            } else {
291                Err(CompletionError::ProviderError(
292                    response.text().await.map_err(|e| {
293                        CompletionError::HttpError(http_client::Error::Instance(e.into()))
294                    })?,
295                ))
296            }
297        }
298        .instrument(span)
299        .await
300    }
301
302    #[cfg_attr(feature = "worker", worker::send)]
303    async fn stream(
304        &self,
305        completion_request: CompletionRequest,
306    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
307        CompletionModel::stream(self, completion_request).await
308    }
309}