rig/providers/openrouter/completion.rs

use bytes::Bytes;
use serde::{Deserialize, Serialize};
use tracing::{Instrument, info_span};

use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};

use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    http_client::HttpClientExt,
    json_utils,
    providers::openai::Message,
};
use serde_json::{Value, json};

use crate::providers::openai::AssistantContent;
use crate::providers::openrouter::streaming::FinalCompletionResponse;
use crate::streaming::StreamingCompletionResponse;
use crate::telemetry::SpanCombinator;

// ================================================================
// OpenRouter Completion API
// ================================================================
/// The `qwen/qwq-32b` model. Find more models at <https://openrouter.ai/models>.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// The `anthropic/claude-3.7-sonnet` model. Find more models at <https://openrouter.ai/models>.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// The `perplexity/sonar-pro` model. Find more models at <https://openrouter.ai/models>.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// The `google/gemini-2.0-flash-001` model. Find more models at <https://openrouter.ai/models>.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
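
// A minimal sketch of selecting one of these models when constructing a completion
// model by hand (the `Client::new` call below is an assumption; the exact client
// constructor exposed by this crate may differ):
//
//     let client = Client::new("OPENROUTER_API_KEY");
//     let model = CompletionModel::new(client, GEMINI_FLASH_2_0);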

/// An OpenRouter completion object.
///
/// For more information, see <https://docs.openrouter.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub system_fingerprint: Option<String>,
    pub usage: Option<Usage>,
}
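
// A sketch of the JSON this struct deserializes from (illustrative values only; the
// exact message and usage payloads depend on the model and on what OpenRouter returns):
//
//     {
//         "id": "gen-abc123",
//         "object": "chat.completion",
//         "created": 1700000000,
//         "model": "qwen/qwq-32b",
//         "choices": [
//             {
//                 "index": 0,
//                 "message": { "role": "assistant", "content": "Hello!" },
//                 "finish_reason": "stop",
//                 "native_finish_reason": "stop"
//             }
//         ],
//         "system_fingerprint": null,
//         "usage": { "prompt_tokens": 12, "total_tokens": 30 }
//     }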

impl From<ApiErrorResponse> for CompletionError {
    fn from(err: ApiErrorResponse) -> Self {
        CompletionError::ProviderError(err.message)
    }
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .iter()
                    .map(|c| match c {
                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
                        AssistantContent::Refusal { refusal } => {
                            completion::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
            })
            .unwrap_or_default();
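        // Worked example of the mapping above (sketch): prompt_tokens = 12 and
        // total_tokens = 30 yield input_tokens = 12, output_tokens = 18,
        // total_tokens = 30. A missing `usage` falls back to `completion::Usage::default()`.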

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
        })
    }
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub native_finish_reason: Option<String>,
    pub message: Message,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged, rename_all = "snake_case")]
pub enum ToolChoice {
    None,
    Auto,
    Required,
    Function(Vec<ToolChoiceFunctionKind>),
}

impl TryFrom<crate::message::ToolChoice> for ToolChoice {
    type Error = CompletionError;

    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
        let res = match value {
            crate::message::ToolChoice::None => Self::None,
            crate::message::ToolChoice::Auto => Self::Auto,
            crate::message::ToolChoice::Required => Self::Required,
            crate::message::ToolChoice::Specific { function_names } => {
                let vec: Vec<ToolChoiceFunctionKind> = function_names
                    .into_iter()
                    .map(|name| ToolChoiceFunctionKind::Function { name })
                    .collect();

                Self::Function(vec)
            }
        };

        Ok(res)
    }
}
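
// Conversion sketch (hypothetical tool name): a request-level
// `crate::message::ToolChoice::Specific { function_names: vec!["get_weather".into()] }`
// maps to `ToolChoice::Function(vec![ToolChoiceFunctionKind::Function { name: "get_weather".into() }])`.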

#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "function")]
pub enum ToolChoiceFunctionKind {
    Function { name: String },
}

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    pub(crate) fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Gather docs
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        // Combine all messages into a single history
        full_history.extend(chat_history);

        let tool_choice = completion_request
            .tool_choice
            .map(ToolChoice::try_from)
            .transpose()?;

        let mut request = json!({
            "model": self.model,
            "messages": full_history,
        });

        if let Some(temperature) = completion_request.temperature {
            request["temperature"] = json!(temperature);
        }

        if !completion_request.tools.is_empty() {
            request["tools"] = json!(
                completion_request
                    .tools
                    .into_iter()
                    .map(crate::providers::openai::completion::ToolDefinition::from)
                    .collect::<Vec<_>>()
            );
            request["tool_choice"] = json!(tool_choice);
        }

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}
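
// For orientation, a body built by `create_completion_request` has roughly this shape
// (sketch only; `messages` follow the OpenAI-compatible format, the optional fields are
// present only when set on the incoming `CompletionRequest`, and any `additional_params`
// are merged on top):
//
//     {
//         "model": "google/gemini-2.0-flash-001",
//         "messages": [ /* preamble, documents, then chat history */ ],
//         "temperature": 0.7,
//         "tools": [ /* OpenAI-style tool definitions */ ],
//         "tool_choice": /* serialized ToolChoice, when tools are supplied */
//     }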

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = FinalCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let request = self.create_completion_request(completion_request)?;
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completion",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openrouter",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|x| CompletionError::HttpError(x.into()))?;

        async move {
            let response = self.client.http_client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_token_usage(&response.usage);
                        span.record_model_output(&response.choices);
                        span.record("gen_ai.response.id", &response.id);
                        span.record("gen_ai.response.model", &response.model);

                        tracing::debug!(target: "rig::completion",
                            "OpenRouter response: {response:?}");
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        CompletionModel::stream(self, completion_request).await
    }
}
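
// A hedged sketch of calling the trait implementation above (building `model` and
// `completion_request` is elided; see `CompletionModel::new` and
// `completion::CompletionRequest` for the real construction paths):
//
//     let response = model.completion(completion_request).await?;
//     println!("choice: {:?}, usage: {:?}", response.choice, response.usage);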