rig/providers/openrouter/
client.rs1use crate::{
2 client::{CompletionClient, ProviderClient, VerifyClient, VerifyError},
3 completion::GetTokenUsage,
4 http_client::{self, HttpClientExt},
5 impl_conversion_traits,
6};
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9
10use super::completion::CompletionModel;
11
/// Default root of the OpenRouter REST API; request paths are appended to it.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";

17pub struct ClientBuilder<'a, T = reqwest::Client> {
18 api_key: &'a str,
19 base_url: &'a str,
20 http_client: T,
21}
22
23impl<'a, T> ClientBuilder<'a, T>
24where
25 T: Default,
26{
27 pub fn new(api_key: &'a str) -> Self {
28 Self {
29 api_key,
30 base_url: OPENROUTER_API_BASE_URL,
31 http_client: Default::default(),
32 }
33 }
34}
35
36impl<'a, T> ClientBuilder<'a, T> {
37 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
38 ClientBuilder {
39 api_key: self.api_key,
40 base_url: self.base_url,
41 http_client,
42 }
43 }
44
45 pub fn base_url(mut self, base_url: &'a str) -> Self {
46 self.base_url = base_url;
47 self
48 }
49
50 pub fn build(self) -> Client<T> {
51 Client {
52 base_url: self.base_url.to_string(),
53 api_key: self.api_key.to_string(),
54 http_client: self.http_client,
55 }
56 }
57}
58
59#[derive(Clone)]
60pub struct Client<T = reqwest::Client> {
61 base_url: String,
62 api_key: String,
63 http_client: T,
64}
65
66impl<T> Debug for Client<T>
67where
68 T: Debug,
69{
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("Client")
72 .field("base_url", &self.base_url)
73 .field("http_client", &self.http_client)
74 .field("api_key", &"<REDACTED>")
75 .finish()
76 }
77}
78
79impl Client<reqwest::Client> {
80 pub(crate) fn reqwest_client(&self) -> &reqwest::Client {
81 &self.http_client
82 }
83
84 pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
85 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
86
87 self.http_client.post(url).bearer_auth(&self.api_key)
88 }
89}
90
91impl<T> Client<T>
92where
93 T: Default,
94{
95 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
106 ClientBuilder::new(api_key)
107 }
108
109 pub fn new(api_key: &str) -> Self {
114 Self::builder(api_key).build()
115 }
116}
117
118impl<T> Client<T> {
119 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
120 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
121
122 http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
123 }
124}
125
126impl ProviderClient for Client<reqwest::Client> {
127 fn from_env() -> Self {
130 let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
131 Self::new(&api_key)
132 }
133
134 fn from_val(input: crate::client::ProviderValue) -> Self {
135 let crate::client::ProviderValue::Simple(api_key) = input else {
136 panic!("Incorrect provider value type")
137 };
138 Self::new(&api_key)
139 }
140}
141
142impl CompletionClient for Client<reqwest::Client> {
143 type CompletionModel = CompletionModel<reqwest::Client>;
144
145 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
157 CompletionModel::new(self.clone(), model)
158 }
159}
160
161impl VerifyClient for Client<reqwest::Client> {
162 #[cfg_attr(feature = "worker", worker::send)]
163 async fn verify(&self) -> Result<(), VerifyError> {
164 let req = self
165 .get("/key")?
166 .body(http_client::NoBody)
167 .map_err(|e| VerifyError::HttpError(e.into()))?;
168
169 let response = HttpClientExt::send(&self.http_client, req).await?;
170
171 match response.status() {
172 reqwest::StatusCode::OK => Ok(()),
173 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
174 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
175 let text = http_client::text(response).await?;
176 Err(VerifyError::ProviderError(text))
177 }
178 _ => {
179 Ok(())
181 }
182 }
183 }
184}
185
186impl_conversion_traits!(
187 AsEmbeddings,
188 AsTranscription,
189 AsImageGeneration,
190 AsAudioGeneration for Client<T>
191);
192
193#[derive(Debug, Deserialize)]
194pub(crate) struct ApiErrorResponse {
195 pub message: String,
196}
197
198#[derive(Debug, Deserialize)]
199#[serde(untagged)]
200pub(crate) enum ApiResponse<T> {
201 Ok(T),
202 Err(ApiErrorResponse),
203}
204
205#[derive(Clone, Debug, Deserialize, Serialize)]
206pub struct Usage {
207 pub prompt_tokens: usize,
208 pub completion_tokens: usize,
209 pub total_tokens: usize,
210}
211
212impl std::fmt::Display for Usage {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 write!(
215 f,
216 "Prompt tokens: {} Total tokens: {}",
217 self.prompt_tokens, self.total_tokens
218 )
219 }
220}
221
222impl GetTokenUsage for Usage {
223 fn token_usage(&self) -> Option<crate::completion::Usage> {
224 let mut usage = crate::completion::Usage::new();
225
226 usage.input_tokens = self.prompt_tokens as u64;
227 usage.output_tokens = self.completion_tokens as u64;
228 usage.total_tokens = self.total_tokens as u64;
229
230 Some(usage)
231 }
232}