rig/providers/openrouter/
completion.rs1use bytes::Bytes;
2use serde::{Deserialize, Serialize};
3use tracing::{Instrument, info_span};
4
5use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
6
7use crate::{
8 OneOrMany,
9 completion::{self, CompletionError, CompletionRequest},
10 http_client::HttpClientExt,
11 json_utils,
12 providers::openai::Message,
13};
14use serde_json::{Value, json};
15
16use crate::providers::openai::AssistantContent;
17use crate::providers::openrouter::streaming::FinalCompletionResponse;
18use crate::streaming::StreamingCompletionResponse;
19use crate::telemetry::SpanCombinator;
20
/// Model identifier string `qwen/qwq-32b`.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// Model identifier string `anthropic/claude-3.7-sonnet`.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// Model identifier string `perplexity/sonar-pro`.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// Model identifier string `google/gemini-2.0-flash-001`.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
32
33#[derive(Debug, Serialize, Deserialize)]
37pub struct CompletionResponse {
38 pub id: String,
39 pub object: String,
40 pub created: u64,
41 pub model: String,
42 pub choices: Vec<Choice>,
43 pub system_fingerprint: Option<String>,
44 pub usage: Option<Usage>,
45}
46
47impl From<ApiErrorResponse> for CompletionError {
48 fn from(err: ApiErrorResponse) -> Self {
49 CompletionError::ProviderError(err.message)
50 }
51}
52
53impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
54 type Error = CompletionError;
55
56 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
57 let choice = response.choices.first().ok_or_else(|| {
58 CompletionError::ResponseError("Response contained no choices".to_owned())
59 })?;
60
61 let content = match &choice.message {
62 Message::Assistant {
63 content,
64 tool_calls,
65 ..
66 } => {
67 let mut content = content
68 .iter()
69 .map(|c| match c {
70 AssistantContent::Text { text } => completion::AssistantContent::text(text),
71 AssistantContent::Refusal { refusal } => {
72 completion::AssistantContent::text(refusal)
73 }
74 })
75 .collect::<Vec<_>>();
76
77 content.extend(
78 tool_calls
79 .iter()
80 .map(|call| {
81 completion::AssistantContent::tool_call(
82 &call.id,
83 &call.function.name,
84 call.function.arguments.clone(),
85 )
86 })
87 .collect::<Vec<_>>(),
88 );
89 Ok(content)
90 }
91 _ => Err(CompletionError::ResponseError(
92 "Response did not contain a valid message or tool call".into(),
93 )),
94 }?;
95
96 let choice = OneOrMany::many(content).map_err(|_| {
97 CompletionError::ResponseError(
98 "Response contained no message or tool call (empty)".to_owned(),
99 )
100 })?;
101
102 let usage = response
103 .usage
104 .as_ref()
105 .map(|usage| completion::Usage {
106 input_tokens: usage.prompt_tokens as u64,
107 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
108 total_tokens: usage.total_tokens as u64,
109 })
110 .unwrap_or_default();
111
112 Ok(completion::CompletionResponse {
113 choice,
114 usage,
115 raw_response: response,
116 })
117 }
118}
119
120#[derive(Debug, Deserialize, Serialize)]
121pub struct Choice {
122 pub index: usize,
123 pub native_finish_reason: Option<String>,
124 pub message: Message,
125 pub finish_reason: Option<String>,
126}
127
128#[derive(Debug, Serialize, Deserialize)]
129#[serde(untagged, rename_all = "snake_case")]
130pub enum ToolChoice {
131 None,
132 Auto,
133 Required,
134 Function(Vec<ToolChoiceFunctionKind>),
135}
136
137impl TryFrom<crate::message::ToolChoice> for ToolChoice {
138 type Error = CompletionError;
139
140 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
141 let res = match value {
142 crate::message::ToolChoice::None => Self::None,
143 crate::message::ToolChoice::Auto => Self::Auto,
144 crate::message::ToolChoice::Required => Self::Required,
145 crate::message::ToolChoice::Specific { function_names } => {
146 let vec: Vec<ToolChoiceFunctionKind> = function_names
147 .into_iter()
148 .map(|name| ToolChoiceFunctionKind::Function { name })
149 .collect();
150
151 Self::Function(vec)
152 }
153 };
154
155 Ok(res)
156 }
157}
158
159#[derive(Debug, Serialize, Deserialize)]
160#[serde(tag = "type", content = "function")]
161pub enum ToolChoiceFunctionKind {
162 Function { name: String },
163}
164
165#[derive(Clone)]
166pub struct CompletionModel<T = reqwest::Client> {
167 pub(crate) client: Client<T>,
168 pub model: String,
170}
171
172impl<T> CompletionModel<T> {
173 pub fn new(client: Client<T>, model: &str) -> Self {
174 Self {
175 client,
176 model: model.to_string(),
177 }
178 }
179
180 pub(crate) fn create_completion_request(
181 &self,
182 completion_request: CompletionRequest,
183 ) -> Result<Value, CompletionError> {
184 let mut full_history: Vec<Message> = match &completion_request.preamble {
186 Some(preamble) => vec![Message::system(preamble)],
187 None => vec![],
188 };
189
190 if let Some(docs) = completion_request.normalized_documents() {
192 let docs: Vec<Message> = docs.try_into()?;
193 full_history.extend(docs);
194 }
195
196 let chat_history: Vec<Message> = completion_request
198 .chat_history
199 .into_iter()
200 .map(|message| message.try_into())
201 .collect::<Result<Vec<Vec<Message>>, _>>()?
202 .into_iter()
203 .flatten()
204 .collect();
205
206 full_history.extend(chat_history);
208
209 let tool_choice = completion_request
210 .tool_choice
211 .map(ToolChoice::try_from)
212 .transpose()?;
213
214 let mut request = json!({
215 "model": self.model,
216 "messages": full_history,
217 });
218
219 if let Some(temperature) = completion_request.temperature {
220 request["temperature"] = json!(temperature);
221 }
222
223 if !completion_request.tools.is_empty() {
224 request["tools"] = json!(
225 completion_request
226 .tools
227 .into_iter()
228 .map(crate::providers::openai::completion::ToolDefinition::from)
229 .collect::<Vec<_>>()
230 );
231 request["tool_choice"] = json!(tool_choice);
232 }
233
234 let request = if let Some(params) = completion_request.additional_params {
235 json_utils::merge(request, params)
236 } else {
237 request
238 };
239
240 Ok(request)
241 }
242}
243
244impl<T> completion::CompletionModel for CompletionModel<T>
245where
246 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
247{
248 type Response = CompletionResponse;
249 type StreamingResponse = FinalCompletionResponse;
250
251 #[cfg_attr(feature = "worker", worker::send)]
252 async fn completion(
253 &self,
254 completion_request: CompletionRequest,
255 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
256 let preamble = completion_request.preamble.clone();
257 let request = self.create_completion_request(completion_request)?;
258 let span = if tracing::Span::current().is_disabled() {
259 info_span!(
260 target: "rig::completion",
261 "chat",
262 gen_ai.operation.name = "chat",
263 gen_ai.provider.name = "openrouter",
264 gen_ai.request.model = self.model,
265 gen_ai.system_instructions = preamble,
266 gen_ai.response.id = tracing::field::Empty,
267 gen_ai.response.model = tracing::field::Empty,
268 gen_ai.usage.output_tokens = tracing::field::Empty,
269 gen_ai.usage.input_tokens = tracing::field::Empty,
270 gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
271 gen_ai.output.messages = tracing::field::Empty,
272 )
273 } else {
274 tracing::Span::current()
275 };
276
277 let body = serde_json::to_vec(&request)?;
278
279 let req = self
280 .client
281 .post("/chat/completions")?
282 .header("Content-Type", "application/json")
283 .body(body)
284 .map_err(|x| CompletionError::HttpError(x.into()))?;
285
286 async move {
287 let response = self.client.http_client.send::<_, Bytes>(req).await?;
288 let status = response.status();
289 let response_body = response.into_body().into_future().await?.to_vec();
290
291 if status.is_success() {
292 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
293 ApiResponse::Ok(response) => {
294 let span = tracing::Span::current();
295 span.record_token_usage(&response.usage);
296 span.record_model_output(&response.choices);
297 span.record("gen_ai.response.id", &response.id);
298 span.record("gen_ai.response.model_name", &response.model);
299
300 tracing::debug!(target: "rig::completion",
301 "OpenRouter response: {response:?}");
302 response.try_into()
303 }
304 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
305 }
306 } else {
307 Err(CompletionError::ProviderError(
308 String::from_utf8_lossy(&response_body).to_string(),
309 ))
310 }
311 }
312 .instrument(span)
313 .await
314 }
315
316 #[cfg_attr(feature = "worker", worker::send)]
317 async fn stream(
318 &self,
319 completion_request: CompletionRequest,
320 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
321 CompletionModel::stream(self, completion_request).await
322 }
323}