rig/providers/openrouter/
completion.rs1use serde::{Deserialize, Serialize};
2use tracing::{Instrument, info_span};
3
4use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
5
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 http_client, json_utils,
10 providers::openai::Message,
11};
12use serde_json::{Value, json};
13
14use crate::providers::openai::AssistantContent;
15use crate::providers::openrouter::streaming::FinalCompletionResponse;
16use crate::streaming::StreamingCompletionResponse;
17use crate::telemetry::SpanCombinator;
18
/// OpenRouter model identifier `anthropic/claude-3.7-sonnet`.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// OpenRouter model identifier `google/gemini-2.0-flash-001`.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
/// OpenRouter model identifier `perplexity/sonar-pro`.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// OpenRouter model identifier `qwen/qwq-32b`.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
30
31#[derive(Debug, Serialize, Deserialize)]
35pub struct CompletionResponse {
36 pub id: String,
37 pub object: String,
38 pub created: u64,
39 pub model: String,
40 pub choices: Vec<Choice>,
41 pub system_fingerprint: Option<String>,
42 pub usage: Option<Usage>,
43}
44
45impl From<ApiErrorResponse> for CompletionError {
46 fn from(err: ApiErrorResponse) -> Self {
47 CompletionError::ProviderError(err.message)
48 }
49}
50
51impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
52 type Error = CompletionError;
53
54 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
55 let choice = response.choices.first().ok_or_else(|| {
56 CompletionError::ResponseError("Response contained no choices".to_owned())
57 })?;
58
59 let content = match &choice.message {
60 Message::Assistant {
61 content,
62 tool_calls,
63 ..
64 } => {
65 let mut content = content
66 .iter()
67 .map(|c| match c {
68 AssistantContent::Text { text } => completion::AssistantContent::text(text),
69 AssistantContent::Refusal { refusal } => {
70 completion::AssistantContent::text(refusal)
71 }
72 })
73 .collect::<Vec<_>>();
74
75 content.extend(
76 tool_calls
77 .iter()
78 .map(|call| {
79 completion::AssistantContent::tool_call(
80 &call.id,
81 &call.function.name,
82 call.function.arguments.clone(),
83 )
84 })
85 .collect::<Vec<_>>(),
86 );
87 Ok(content)
88 }
89 _ => Err(CompletionError::ResponseError(
90 "Response did not contain a valid message or tool call".into(),
91 )),
92 }?;
93
94 let choice = OneOrMany::many(content).map_err(|_| {
95 CompletionError::ResponseError(
96 "Response contained no message or tool call (empty)".to_owned(),
97 )
98 })?;
99
100 let usage = response
101 .usage
102 .as_ref()
103 .map(|usage| completion::Usage {
104 input_tokens: usage.prompt_tokens as u64,
105 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
106 total_tokens: usage.total_tokens as u64,
107 })
108 .unwrap_or_default();
109
110 Ok(completion::CompletionResponse {
111 choice,
112 usage,
113 raw_response: response,
114 })
115 }
116}
117
118#[derive(Debug, Deserialize, Serialize)]
119pub struct Choice {
120 pub index: usize,
121 pub native_finish_reason: Option<String>,
122 pub message: Message,
123 pub finish_reason: Option<String>,
124}
125
126#[derive(Debug, Serialize, Deserialize)]
127#[serde(untagged, rename_all = "snake_case")]
128pub enum ToolChoice {
129 None,
130 Auto,
131 Required,
132 Function(Vec<ToolChoiceFunctionKind>),
133}
134
135impl TryFrom<crate::message::ToolChoice> for ToolChoice {
136 type Error = CompletionError;
137
138 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
139 let res = match value {
140 crate::message::ToolChoice::None => Self::None,
141 crate::message::ToolChoice::Auto => Self::Auto,
142 crate::message::ToolChoice::Required => Self::Required,
143 crate::message::ToolChoice::Specific { function_names } => {
144 let vec: Vec<ToolChoiceFunctionKind> = function_names
145 .into_iter()
146 .map(|name| ToolChoiceFunctionKind::Function { name })
147 .collect();
148
149 Self::Function(vec)
150 }
151 };
152
153 Ok(res)
154 }
155}
156
157#[derive(Debug, Serialize, Deserialize)]
158#[serde(tag = "type", content = "function")]
159pub enum ToolChoiceFunctionKind {
160 Function { name: String },
161}
162
163#[derive(Clone)]
164pub struct CompletionModel<T = reqwest::Client> {
165 pub(crate) client: Client<T>,
166 pub model: String,
168}
169
170impl<T> CompletionModel<T> {
171 pub fn new(client: Client<T>, model: &str) -> Self {
172 Self {
173 client,
174 model: model.to_string(),
175 }
176 }
177
178 pub(crate) fn create_completion_request(
179 &self,
180 completion_request: CompletionRequest,
181 ) -> Result<Value, CompletionError> {
182 let mut full_history: Vec<Message> = match &completion_request.preamble {
184 Some(preamble) => vec![Message::system(preamble)],
185 None => vec![],
186 };
187
188 if let Some(docs) = completion_request.normalized_documents() {
190 let docs: Vec<Message> = docs.try_into()?;
191 full_history.extend(docs);
192 }
193
194 let chat_history: Vec<Message> = completion_request
196 .chat_history
197 .into_iter()
198 .map(|message| message.try_into())
199 .collect::<Result<Vec<Vec<Message>>, _>>()?
200 .into_iter()
201 .flatten()
202 .collect();
203
204 full_history.extend(chat_history);
206
207 let tool_choice = completion_request
208 .tool_choice
209 .map(ToolChoice::try_from)
210 .transpose()?;
211
212 let request = json!({
213 "model": self.model,
214 "messages": full_history,
215 "temperature": completion_request.temperature,
216 "tools": completion_request.tools.into_iter().map(crate::providers::openai::completion::ToolDefinition::from).collect::<Vec<_>>(),
217 "tool_choice": tool_choice,
218 });
219
220 let request = if let Some(params) = completion_request.additional_params {
221 json_utils::merge(request, params)
222 } else {
223 request
224 };
225
226 Ok(request)
227 }
228}
229
230impl completion::CompletionModel for CompletionModel<reqwest::Client> {
231 type Response = CompletionResponse;
232 type StreamingResponse = FinalCompletionResponse;
233
234 #[cfg_attr(feature = "worker", worker::send)]
235 async fn completion(
236 &self,
237 completion_request: CompletionRequest,
238 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
239 let preamble = completion_request.preamble.clone();
240 let request = self.create_completion_request(completion_request)?;
241 let span = if tracing::Span::current().is_disabled() {
242 info_span!(
243 target: "rig::completion",
244 "chat",
245 gen_ai.operation.name = "chat",
246 gen_ai.provider.name = "openrouter",
247 gen_ai.request.model = self.model,
248 gen_ai.system_instructions = preamble,
249 gen_ai.response.id = tracing::field::Empty,
250 gen_ai.response.model = tracing::field::Empty,
251 gen_ai.usage.output_tokens = tracing::field::Empty,
252 gen_ai.usage.input_tokens = tracing::field::Empty,
253 gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
254 gen_ai.output.messages = tracing::field::Empty,
255 )
256 } else {
257 tracing::Span::current()
258 };
259
260 async move {
261 let response = self
262 .client
263 .reqwest_client()
264 .post("/chat/completions")
265 .json(&request)
266 .send()
267 .await
268 .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?;
269
270 if response.status().is_success() {
271 match response
272 .json::<ApiResponse<CompletionResponse>>()
273 .await
274 .map_err(|e| {
275 CompletionError::HttpError(http_client::Error::Instance(e.into()))
276 })? {
277 ApiResponse::Ok(response) => {
278 let span = tracing::Span::current();
279 span.record_token_usage(&response.usage);
280 span.record_model_output(&response.choices);
281 span.record("gen_ai.response.id", &response.id);
282 span.record("gen_ai.response.model_name", &response.model);
283
284 tracing::debug!(target: "rig::completion",
285 "OpenRouter response: {response:?}");
286 response.try_into()
287 }
288 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
289 }
290 } else {
291 Err(CompletionError::ProviderError(
292 response.text().await.map_err(|e| {
293 CompletionError::HttpError(http_client::Error::Instance(e.into()))
294 })?,
295 ))
296 }
297 }
298 .instrument(span)
299 .await
300 }
301
302 #[cfg_attr(feature = "worker", worker::send)]
303 async fn stream(
304 &self,
305 completion_request: CompletionRequest,
306 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
307 CompletionModel::stream(self, completion_request).await
308 }
309}