rig/providers/openrouter/
completion.rs

1use serde::{Deserialize, Serialize};
2
3use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
4
5use crate::{
6    OneOrMany,
7    completion::{self, CompletionError, CompletionRequest},
8    json_utils,
9    providers::openai::Message,
10};
11use serde_json::{Value, json};
12
13use crate::providers::openai::AssistantContent;
14use crate::providers::openrouter::streaming::FinalCompletionResponse;
15use crate::streaming::StreamingCompletionResponse;
16
// ================================================================
// OpenRouter Completion API
// ================================================================
// Convenience constants for commonly used OpenRouter model identifiers.
/// The `qwen/qwq-32b` model. Find more models at <https://openrouter.ai/models>.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// The `anthropic/claude-3.7-sonnet` model. Find more models at <https://openrouter.ai/models>.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// The `perplexity/sonar-pro` model. Find more models at <https://openrouter.ai/models>.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// The `google/gemini-2.0-flash-001` model. Find more models at <https://openrouter.ai/models>.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
28
/// An OpenRouter completion response object.
///
/// For more information, see the OpenRouter API reference:
/// <https://openrouter.ai/docs/api-reference/overview>
/// (NOTE(review): the previous link pointed at `docs.openrouter.xyz`, which
/// does not resolve — verify the exact docs page.)
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    // Unique identifier of this completion.
    pub id: String,
    // Object type tag returned by the API (presumably "chat.completion" — confirm).
    pub object: String,
    // Creation time reported by the API (presumably a Unix timestamp — confirm).
    pub created: u64,
    // The model that served the request.
    pub model: String,
    // Candidate completions; the `TryFrom` conversion below only consumes the first.
    pub choices: Vec<Choice>,
    pub system_fingerprint: Option<String>,
    // Token accounting; absent from some responses, hence `Option`.
    pub usage: Option<Usage>,
}
42
43impl From<ApiErrorResponse> for CompletionError {
44    fn from(err: ApiErrorResponse) -> Self {
45        CompletionError::ProviderError(err.message)
46    }
47}
48
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Converts a raw OpenRouter response into rig's provider-agnostic
    /// completion response, retaining the raw payload in `raw_response`.
    ///
    /// # Errors
    /// Returns [`CompletionError::ResponseError`] when the response has no
    /// choices, when the first choice is not an assistant message, or when
    /// that message carries neither text content nor tool calls.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        // Only the first choice is used; any additional candidates are dropped.
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Both plain text and refusals are surfaced as text content.
                let mut content = content
                    .iter()
                    .map(|c| match c {
                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
                        AssistantContent::Refusal { refusal } => {
                            completion::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                // Tool calls are appended after the text segments.
                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        // `OneOrMany::many` rejects an empty vector, so a message with no text
        // and no tool calls is reported as an error here.
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        // Usage is optional; default (zeroed) accounting when absent.
        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                // Output tokens are derived by subtraction. NOTE(review): this
                // assumes total_tokens >= prompt_tokens; if the Usage fields are
                // unsigned this could underflow — confirm field types in client.rs.
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
        })
    }
}
115
/// A single completion candidate within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    // Position of this choice in the response's `choices` array.
    pub index: usize,
    // Finish reason as reported by the underlying (native) provider, if any.
    pub native_finish_reason: Option<String>,
    // The assistant message produced for this choice.
    pub message: Message,
    // Normalized finish reason (presumably e.g. "stop"/"tool_calls" — confirm), if any.
    pub finish_reason: Option<String>,
}
123
/// An OpenRouter completion model: an API [`Client`] paired with a model name.
#[derive(Clone)]
pub struct CompletionModel {
    pub(crate) client: Client,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
130
131impl CompletionModel {
132    pub fn new(client: Client, model: &str) -> Self {
133        Self {
134            client,
135            model: model.to_string(),
136        }
137    }
138
139    pub(crate) fn create_completion_request(
140        &self,
141        completion_request: CompletionRequest,
142    ) -> Result<Value, CompletionError> {
143        // Add preamble to chat history (if available)
144        let mut full_history: Vec<Message> = match &completion_request.preamble {
145            Some(preamble) => vec![Message::system(preamble)],
146            None => vec![],
147        };
148
149        // Gather docs
150        if let Some(docs) = completion_request.normalized_documents() {
151            let docs: Vec<Message> = docs.try_into()?;
152            full_history.extend(docs);
153        }
154
155        // Convert existing chat history
156        let chat_history: Vec<Message> = completion_request
157            .chat_history
158            .into_iter()
159            .map(|message| message.try_into())
160            .collect::<Result<Vec<Vec<Message>>, _>>()?
161            .into_iter()
162            .flatten()
163            .collect();
164
165        // Combine all messages into a single history
166        full_history.extend(chat_history);
167
168        let request = json!({
169            "model": self.model,
170            "messages": full_history,
171            "temperature": completion_request.temperature,
172            "tools": completion_request.tools.into_iter().map(crate::providers::openai::completion::ToolDefinition::from).collect::<Vec<_>>()
173        });
174
175        let request = if let Some(params) = completion_request.additional_params {
176            json_utils::merge(request, params)
177        } else {
178            request
179        };
180
181        Ok(request)
182    }
183}
184
185impl completion::CompletionModel for CompletionModel {
186    type Response = CompletionResponse;
187    type StreamingResponse = FinalCompletionResponse;
188
189    #[cfg_attr(feature = "worker", worker::send)]
190    async fn completion(
191        &self,
192        completion_request: CompletionRequest,
193    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
194        let request = self.create_completion_request(completion_request)?;
195
196        let response = self
197            .client
198            .post("/chat/completions")
199            .json(&request)
200            .send()
201            .await?;
202
203        if response.status().is_success() {
204            match response.json::<ApiResponse<CompletionResponse>>().await? {
205                ApiResponse::Ok(response) => {
206                    tracing::info!(target: "rig",
207                        "OpenRouter completion token usage: {:?}",
208                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
209                    );
210                    tracing::debug!(target: "rig",
211                        "OpenRouter response: {response:?}");
212                    response.try_into()
213                }
214                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
215            }
216        } else {
217            Err(CompletionError::ProviderError(response.text().await?))
218        }
219    }
220
    /// Streams a chat completion from OpenRouter.
    ///
    /// NOTE(review): this delegates to an inherent `CompletionModel::stream`
    /// (presumably defined in the `streaming` module imported above); if no
    /// inherent method existed, this call would recurse into this trait
    /// method — confirm the inherent impl exists.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        CompletionModel::stream(self, completion_request).await
    }
228}