rig/providers/openrouter/
completion.rs1use serde::Deserialize;
2
3use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
4
5use crate::{
6 OneOrMany,
7 completion::{self, CompletionError, CompletionRequest},
8 json_utils,
9 providers::openai::Message,
10};
11use serde_json::{Value, json};
12
13use crate::providers::openai::AssistantContent;
14use crate::providers::openrouter::streaming::FinalCompletionResponse;
15use crate::streaming::StreamingCompletionResponse;
16
/// Model identifier string `qwen/qwq-32b`, usable as the `model` argument of
/// [`CompletionModel::new`].
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// Model identifier string `anthropic/claude-3.7-sonnet`.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// Model identifier string `perplexity/sonar-pro`.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// Model identifier string `google/gemini-2.0-flash-001`.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
28
29#[derive(Debug, Deserialize)]
33pub struct CompletionResponse {
34 pub id: String,
35 pub object: String,
36 pub created: u64,
37 pub model: String,
38 pub choices: Vec<Choice>,
39 pub system_fingerprint: Option<String>,
40 pub usage: Option<Usage>,
41}
42
43impl From<ApiErrorResponse> for CompletionError {
44 fn from(err: ApiErrorResponse) -> Self {
45 CompletionError::ProviderError(err.message)
46 }
47}
48
49impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
50 type Error = CompletionError;
51
52 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
53 let choice = response.choices.first().ok_or_else(|| {
54 CompletionError::ResponseError("Response contained no choices".to_owned())
55 })?;
56
57 let content = match &choice.message {
58 Message::Assistant {
59 content,
60 tool_calls,
61 ..
62 } => {
63 let mut content = content
64 .iter()
65 .map(|c| match c {
66 AssistantContent::Text { text } => completion::AssistantContent::text(text),
67 AssistantContent::Refusal { refusal } => {
68 completion::AssistantContent::text(refusal)
69 }
70 })
71 .collect::<Vec<_>>();
72
73 content.extend(
74 tool_calls
75 .iter()
76 .map(|call| {
77 completion::AssistantContent::tool_call(
78 &call.id,
79 &call.function.name,
80 call.function.arguments.clone(),
81 )
82 })
83 .collect::<Vec<_>>(),
84 );
85 Ok(content)
86 }
87 _ => Err(CompletionError::ResponseError(
88 "Response did not contain a valid message or tool call".into(),
89 )),
90 }?;
91
92 let choice = OneOrMany::many(content).map_err(|_| {
93 CompletionError::ResponseError(
94 "Response contained no message or tool call (empty)".to_owned(),
95 )
96 })?;
97
98 Ok(completion::CompletionResponse {
99 choice,
100 raw_response: response,
101 })
102 }
103}
104
105#[derive(Debug, Deserialize)]
106pub struct Choice {
107 pub index: usize,
108 pub native_finish_reason: Option<String>,
109 pub message: Message,
110 pub finish_reason: Option<String>,
111}
112
113#[derive(Clone)]
114pub struct CompletionModel {
115 pub(crate) client: Client,
116 pub model: String,
118}
119
120impl CompletionModel {
121 pub fn new(client: Client, model: &str) -> Self {
122 Self {
123 client,
124 model: model.to_string(),
125 }
126 }
127
128 pub(crate) fn create_completion_request(
129 &self,
130 completion_request: CompletionRequest,
131 ) -> Result<Value, CompletionError> {
132 let mut full_history: Vec<Message> = match &completion_request.preamble {
134 Some(preamble) => vec![Message::system(preamble)],
135 None => vec![],
136 };
137
138 if let Some(docs) = completion_request.normalized_documents() {
140 let docs: Vec<Message> = docs.try_into()?;
141 full_history.extend(docs);
142 }
143
144 let chat_history: Vec<Message> = completion_request
146 .chat_history
147 .into_iter()
148 .map(|message| message.try_into())
149 .collect::<Result<Vec<Vec<Message>>, _>>()?
150 .into_iter()
151 .flatten()
152 .collect();
153
154 full_history.extend(chat_history);
156
157 let request = json!({
158 "model": self.model,
159 "messages": full_history,
160 "temperature": completion_request.temperature,
161 "tools": completion_request.tools.into_iter().map(crate::providers::openai::completion::ToolDefinition::from).collect::<Vec<_>>()
162 });
163
164 let request = if let Some(params) = completion_request.additional_params {
165 json_utils::merge(request, params)
166 } else {
167 request
168 };
169
170 Ok(request)
171 }
172}
173
174impl completion::CompletionModel for CompletionModel {
175 type Response = CompletionResponse;
176 type StreamingResponse = FinalCompletionResponse;
177
178 #[cfg_attr(feature = "worker", worker::send)]
179 async fn completion(
180 &self,
181 completion_request: CompletionRequest,
182 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
183 let request = self.create_completion_request(completion_request)?;
184
185 let response = self
186 .client
187 .post("/chat/completions")
188 .json(&request)
189 .send()
190 .await?;
191
192 if response.status().is_success() {
193 match response.json::<ApiResponse<CompletionResponse>>().await? {
194 ApiResponse::Ok(response) => {
195 tracing::info!(target: "rig",
196 "OpenRouter completion token usage: {:?}",
197 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
198 );
199 tracing::debug!(target: "rig",
200 "OpenRouter response: {response:?}");
201 response.try_into()
202 }
203 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
204 }
205 } else {
206 Err(CompletionError::ProviderError(response.text().await?))
207 }
208 }
209
210 #[cfg_attr(feature = "worker", worker::send)]
211 async fn stream(
212 &self,
213 completion_request: CompletionRequest,
214 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
215 CompletionModel::stream(self, completion_request).await
216 }
217}