//! rig/providers/openrouter/completion.rs

1use serde::Deserialize;
2
3use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
4
5use crate::{
6    OneOrMany,
7    completion::{self, CompletionError, CompletionRequest},
8    json_utils,
9    providers::openai::Message,
10};
11use serde_json::{Value, json};
12
13use crate::providers::openai::AssistantContent;
14use crate::providers::openrouter::streaming::FinalCompletionResponse;
15use crate::streaming::StreamingCompletionResponse;
16
// ================================================================
// OpenRouter Completion API
// ================================================================

/// The `qwen/qwq-32b` model. Find more models at <https://openrouter.ai/models>.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";

/// The `anthropic/claude-3.7-sonnet` model. Find more models at <https://openrouter.ai/models>.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";

/// The `perplexity/sonar-pro` model. Find more models at <https://openrouter.ai/models>.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";

/// The `google/gemini-2.0-flash-001` model. Find more models at <https://openrouter.ai/models>.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
28
/// A raw OpenRouter completion response, as deserialized from the
/// `/chat/completions` endpoint.
///
/// For more information, see the OpenRouter API docs: <https://openrouter.ai/docs>
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Unique identifier for this completion.
    pub id: String,
    /// Object type tag returned by the API (e.g. a chat-completion marker).
    pub object: String,
    /// Unix timestamp of when the completion was created.
    pub created: u64,
    /// The model that actually served the request.
    pub model: String,
    /// Candidate completions; only the first is consumed by the
    /// `TryFrom` conversion below.
    pub choices: Vec<Choice>,
    pub system_fingerprint: Option<String>,
    /// Token usage; absent on some responses, hence `Option`.
    pub usage: Option<Usage>,
}
42
43impl From<ApiErrorResponse> for CompletionError {
44    fn from(err: ApiErrorResponse) -> Self {
45        CompletionError::ProviderError(err.message)
46    }
47}
48
49impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
50    type Error = CompletionError;
51
52    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
53        let choice = response.choices.first().ok_or_else(|| {
54            CompletionError::ResponseError("Response contained no choices".to_owned())
55        })?;
56
57        let content = match &choice.message {
58            Message::Assistant {
59                content,
60                tool_calls,
61                ..
62            } => {
63                let mut content = content
64                    .iter()
65                    .map(|c| match c {
66                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
67                        AssistantContent::Refusal { refusal } => {
68                            completion::AssistantContent::text(refusal)
69                        }
70                    })
71                    .collect::<Vec<_>>();
72
73                content.extend(
74                    tool_calls
75                        .iter()
76                        .map(|call| {
77                            completion::AssistantContent::tool_call(
78                                &call.id,
79                                &call.function.name,
80                                call.function.arguments.clone(),
81                            )
82                        })
83                        .collect::<Vec<_>>(),
84                );
85                Ok(content)
86            }
87            _ => Err(CompletionError::ResponseError(
88                "Response did not contain a valid message or tool call".into(),
89            )),
90        }?;
91
92        let choice = OneOrMany::many(content).map_err(|_| {
93            CompletionError::ResponseError(
94                "Response contained no message or tool call (empty)".to_owned(),
95            )
96        })?;
97
98        Ok(completion::CompletionResponse {
99            choice,
100            raw_response: response,
101        })
102    }
103}
104
/// A single candidate completion within a [`CompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Position of this choice in the response's `choices` array.
    pub index: usize,
    /// Finish reason as reported by the underlying (native) provider,
    /// when OpenRouter passes it through.
    pub native_finish_reason: Option<String>,
    /// The message produced by the model (shares the OpenAI message shape).
    pub message: Message,
    /// OpenRouter-normalized finish reason, if present.
    pub finish_reason: Option<String>,
}
112
/// An OpenRouter completion model handle: a client plus a model identifier.
#[derive(Clone)]
pub struct CompletionModel {
    // HTTP client pre-configured for the OpenRouter API.
    pub(crate) client: Client,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
119
120impl CompletionModel {
121    pub fn new(client: Client, model: &str) -> Self {
122        Self {
123            client,
124            model: model.to_string(),
125        }
126    }
127
128    pub(crate) fn create_completion_request(
129        &self,
130        completion_request: CompletionRequest,
131    ) -> Result<Value, CompletionError> {
132        // Add preamble to chat history (if available)
133        let mut full_history: Vec<Message> = match &completion_request.preamble {
134            Some(preamble) => vec![Message::system(preamble)],
135            None => vec![],
136        };
137
138        // Gather docs
139        if let Some(docs) = completion_request.normalized_documents() {
140            let docs: Vec<Message> = docs.try_into()?;
141            full_history.extend(docs);
142        }
143
144        // Convert existing chat history
145        let chat_history: Vec<Message> = completion_request
146            .chat_history
147            .into_iter()
148            .map(|message| message.try_into())
149            .collect::<Result<Vec<Vec<Message>>, _>>()?
150            .into_iter()
151            .flatten()
152            .collect();
153
154        // Combine all messages into a single history
155        full_history.extend(chat_history);
156
157        let request = json!({
158            "model": self.model,
159            "messages": full_history,
160            "temperature": completion_request.temperature,
161            "tools": completion_request.tools.into_iter().map(crate::providers::openai::completion::ToolDefinition::from).collect::<Vec<_>>()
162        });
163
164        let request = if let Some(params) = completion_request.additional_params {
165            json_utils::merge(request, params)
166        } else {
167            request
168        };
169
170        Ok(request)
171    }
172}
173
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;
    type StreamingResponse = FinalCompletionResponse;

    /// Send a (non-streaming) completion request to OpenRouter and convert
    /// the result into rig's provider-agnostic response type.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Build the provider-specific JSON body from the generic request.
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("/chat/completions")
            .json(&request)
            .send()
            .await?;

        if response.status().is_success() {
            // A 2xx status can still carry an API-level error envelope,
            // hence the ApiResponse::Ok/Err match below.
            match response.json::<ApiResponse<CompletionResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "OpenRouter completion token usage: {:?}",
                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                    );
                    tracing::debug!(target: "rig",
                        "OpenRouter response: {response:?}");
                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: forward the raw body text as a provider error.
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // NOTE(review): this must resolve to an inherent
        // `CompletionModel::stream` (presumably defined in the sibling
        // `streaming` module, whose FinalCompletionResponse is imported
        // above) — not to this trait method, which would recurse
        // infinitely. Confirm the inherent fn exists in streaming.rs.
        CompletionModel::stream(self, completion_request).await
    }
}
217}