rig/providers/openrouter/
client.rs1use crate::{
2 client::{CompletionClient, ProviderClient, VerifyClient, VerifyError},
3 completion::GetTokenUsage,
4 http_client::{self, HttpClientExt},
5 impl_conversion_traits,
6};
7use http::Method;
8use serde::{Deserialize, Serialize};
9use std::fmt::Debug;
10
11use super::completion::CompletionModel;
12
/// Default root URL of the OpenRouter REST API; request paths are joined onto it.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
17
18pub struct ClientBuilder<'a, T = reqwest::Client> {
19 api_key: &'a str,
20 base_url: &'a str,
21 http_client: T,
22}
23
24impl<'a, T> ClientBuilder<'a, T>
25where
26 T: Default,
27{
28 pub fn new(api_key: &'a str) -> Self {
29 Self {
30 api_key,
31 base_url: OPENROUTER_API_BASE_URL,
32 http_client: Default::default(),
33 }
34 }
35}
36
37impl<'a, T> ClientBuilder<'a, T> {
38 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
39 Self {
40 api_key,
41 base_url: OPENROUTER_API_BASE_URL,
42 http_client,
43 }
44 }
45 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
46 ClientBuilder {
47 api_key: self.api_key,
48 base_url: self.base_url,
49 http_client,
50 }
51 }
52
53 pub fn base_url(mut self, base_url: &'a str) -> Self {
54 self.base_url = base_url;
55 self
56 }
57
58 pub fn build(self) -> Client<T> {
59 Client {
60 base_url: self.base_url.to_string(),
61 api_key: self.api_key.to_string(),
62 http_client: self.http_client,
63 }
64 }
65}
66
67#[derive(Clone)]
68pub struct Client<T = reqwest::Client> {
69 base_url: String,
70 api_key: String,
71 pub http_client: T,
72}
73
74impl<T> Debug for Client<T>
75where
76 T: Debug,
77{
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("Client")
80 .field("base_url", &self.base_url)
81 .field("http_client", &self.http_client)
82 .field("api_key", &"<REDACTED>")
83 .finish()
84 }
85}
86
87impl<T> Client<T> {
88 pub(crate) fn req(
89 &self,
90 method: Method,
91 path: &str,
92 ) -> http_client::Result<http_client::Builder> {
93 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
94 let req = http_client::Request::builder().uri(url).method(method);
95
96 http_client::with_bearer_auth(req, &self.api_key)
97 }
98
99 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
100 self.req(Method::GET, path)
101 }
102
103 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
104 self.req(Method::POST, path)
105 }
106}
107
108impl Client<reqwest::Client> {
109 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
120 ClientBuilder::new(api_key)
121 }
122
123 pub fn new(api_key: &str) -> Self {
126 Self::builder(api_key).build()
127 }
128
129 pub fn from_env() -> Self {
130 <Self as ProviderClient>::from_env()
131 }
132}
133
134impl<T> ProviderClient for Client<T>
135where
136 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
137{
138 fn from_env() -> Self {
141 let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
142 ClientBuilder::<T>::new(&api_key).build()
143 }
144
145 fn from_val(input: crate::client::ProviderValue) -> Self {
146 let crate::client::ProviderValue::Simple(api_key) = input else {
147 panic!("Incorrect provider value type")
148 };
149 ClientBuilder::<T>::new(&api_key).build()
150 }
151}
152
153impl<T> CompletionClient for Client<T>
154where
155 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
156{
157 type CompletionModel = CompletionModel<T>;
158
159 fn completion_model(&self, model: &str) -> CompletionModel<T> {
171 CompletionModel::new(self.clone(), model)
172 }
173}
174
175impl<T> VerifyClient for Client<T>
176where
177 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
178{
179 #[cfg_attr(feature = "worker", worker::send)]
180 async fn verify(&self) -> Result<(), VerifyError> {
181 let req = self
182 .get("/key")?
183 .body(http_client::NoBody)
184 .map_err(|e| VerifyError::HttpError(e.into()))?;
185
186 let response = HttpClientExt::send(&self.http_client, req).await?;
187
188 match response.status() {
189 reqwest::StatusCode::OK => Ok(()),
190 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
191 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
192 let text = http_client::text(response).await?;
193 Err(VerifyError::ProviderError(text))
194 }
195 _ => {
196 Ok(())
198 }
199 }
200 }
201}
202
// Generates the `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` conversion-trait impls for `Client<T>` through the
// crate-level `impl_conversion_traits!` macro.
// NOTE(review): presumably these are "capability not supported" impls, since this
// file only wires up completions and key verification — confirm against the
// macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
209
210#[derive(Debug, Deserialize)]
211pub(crate) struct ApiErrorResponse {
212 pub message: String,
213}
214
215#[derive(Debug, Deserialize)]
216#[serde(untagged)]
217pub(crate) enum ApiResponse<T> {
218 Ok(T),
219 Err(ApiErrorResponse),
220}
221
222#[derive(Clone, Debug, Deserialize, Serialize)]
223pub struct Usage {
224 pub prompt_tokens: usize,
225 pub completion_tokens: usize,
226 pub total_tokens: usize,
227}
228
229impl std::fmt::Display for Usage {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 write!(
232 f,
233 "Prompt tokens: {} Total tokens: {}",
234 self.prompt_tokens, self.total_tokens
235 )
236 }
237}
238
239impl GetTokenUsage for Usage {
240 fn token_usage(&self) -> Option<crate::completion::Usage> {
241 let mut usage = crate::completion::Usage::new();
242
243 usage.input_tokens = self.prompt_tokens as u64;
244 usage.output_tokens = self.completion_tokens as u64;
245 usage.total_tokens = self.total_tokens as u64;
246
247 Some(usage)
248 }
249}