rig/providers/openrouter/
completion.rs1use serde::Deserialize;
2
3use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
4
5use crate::{
6 OneOrMany,
7 completion::{self, CompletionError, CompletionRequest},
8 json_utils,
9 providers::openai::Message,
10};
11use serde_json::{Value, json};
12
13use crate::providers::openai::AssistantContent;
14use crate::providers::openrouter::streaming::FinalCompletionResponse;
15use crate::streaming::StreamingCompletionResponse;
16
/// OpenRouter model identifier `qwen/qwq-32b`.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// OpenRouter model identifier `anthropic/claude-3.7-sonnet`.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// OpenRouter model identifier `perplexity/sonar-pro`.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// OpenRouter model identifier `google/gemini-2.0-flash-001`.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
28
29#[derive(Debug, Deserialize)]
33pub struct CompletionResponse {
34 pub id: String,
35 pub object: String,
36 pub created: u64,
37 pub model: String,
38 pub choices: Vec<Choice>,
39 pub system_fingerprint: Option<String>,
40 pub usage: Option<Usage>,
41}
42
43impl From<ApiErrorResponse> for CompletionError {
44 fn from(err: ApiErrorResponse) -> Self {
45 CompletionError::ProviderError(err.message)
46 }
47}
48
49impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
50 type Error = CompletionError;
51
52 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
53 let choice = response.choices.first().ok_or_else(|| {
54 CompletionError::ResponseError("Response contained no choices".to_owned())
55 })?;
56
57 let content = match &choice.message {
58 Message::Assistant {
59 content,
60 tool_calls,
61 ..
62 } => {
63 let mut content = content
64 .iter()
65 .map(|c| match c {
66 AssistantContent::Text { text } => completion::AssistantContent::text(text),
67 AssistantContent::Refusal { refusal } => {
68 completion::AssistantContent::text(refusal)
69 }
70 })
71 .collect::<Vec<_>>();
72
73 content.extend(
74 tool_calls
75 .iter()
76 .map(|call| {
77 completion::AssistantContent::tool_call(
78 &call.id,
79 &call.function.name,
80 call.function.arguments.clone(),
81 )
82 })
83 .collect::<Vec<_>>(),
84 );
85 Ok(content)
86 }
87 _ => Err(CompletionError::ResponseError(
88 "Response did not contain a valid message or tool call".into(),
89 )),
90 }?;
91
92 let choice = OneOrMany::many(content).map_err(|_| {
93 CompletionError::ResponseError(
94 "Response contained no message or tool call (empty)".to_owned(),
95 )
96 })?;
97
98 let usage = response
99 .usage
100 .as_ref()
101 .map(|usage| completion::Usage {
102 input_tokens: usage.prompt_tokens as u64,
103 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
104 total_tokens: usage.total_tokens as u64,
105 })
106 .unwrap_or_default();
107
108 Ok(completion::CompletionResponse {
109 choice,
110 usage,
111 raw_response: response,
112 })
113 }
114}
115
116#[derive(Debug, Deserialize)]
117pub struct Choice {
118 pub index: usize,
119 pub native_finish_reason: Option<String>,
120 pub message: Message,
121 pub finish_reason: Option<String>,
122}
123
124#[derive(Clone)]
125pub struct CompletionModel {
126 pub(crate) client: Client,
127 pub model: String,
129}
130
131impl CompletionModel {
132 pub fn new(client: Client, model: &str) -> Self {
133 Self {
134 client,
135 model: model.to_string(),
136 }
137 }
138
139 pub(crate) fn create_completion_request(
140 &self,
141 completion_request: CompletionRequest,
142 ) -> Result<Value, CompletionError> {
143 let mut full_history: Vec<Message> = match &completion_request.preamble {
145 Some(preamble) => vec![Message::system(preamble)],
146 None => vec![],
147 };
148
149 if let Some(docs) = completion_request.normalized_documents() {
151 let docs: Vec<Message> = docs.try_into()?;
152 full_history.extend(docs);
153 }
154
155 let chat_history: Vec<Message> = completion_request
157 .chat_history
158 .into_iter()
159 .map(|message| message.try_into())
160 .collect::<Result<Vec<Vec<Message>>, _>>()?
161 .into_iter()
162 .flatten()
163 .collect();
164
165 full_history.extend(chat_history);
167
168 let request = json!({
169 "model": self.model,
170 "messages": full_history,
171 "temperature": completion_request.temperature,
172 "tools": completion_request.tools.into_iter().map(crate::providers::openai::completion::ToolDefinition::from).collect::<Vec<_>>()
173 });
174
175 let request = if let Some(params) = completion_request.additional_params {
176 json_utils::merge(request, params)
177 } else {
178 request
179 };
180
181 Ok(request)
182 }
183}
184
185impl completion::CompletionModel for CompletionModel {
186 type Response = CompletionResponse;
187 type StreamingResponse = FinalCompletionResponse;
188
189 #[cfg_attr(feature = "worker", worker::send)]
190 async fn completion(
191 &self,
192 completion_request: CompletionRequest,
193 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
194 let request = self.create_completion_request(completion_request)?;
195
196 let response = self
197 .client
198 .post("/chat/completions")
199 .json(&request)
200 .send()
201 .await?;
202
203 if response.status().is_success() {
204 match response.json::<ApiResponse<CompletionResponse>>().await? {
205 ApiResponse::Ok(response) => {
206 tracing::info!(target: "rig",
207 "OpenRouter completion token usage: {:?}",
208 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
209 );
210 tracing::debug!(target: "rig",
211 "OpenRouter response: {response:?}");
212 response.try_into()
213 }
214 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
215 }
216 } else {
217 Err(CompletionError::ProviderError(response.text().await?))
218 }
219 }
220
221 #[cfg_attr(feature = "worker", worker::send)]
222 async fn stream(
223 &self,
224 completion_request: CompletionRequest,
225 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
226 CompletionModel::stream(self, completion_request).await
227 }
228}