rig/providers/openrouter/client.rs

1use crate::{
2    client::{CompletionClient, ProviderClient, VerifyClient, VerifyError},
3    completion::GetTokenUsage,
4    http_client::{self, HttpClientExt},
5    impl_conversion_traits,
6};
7use http::Method;
8use serde::{Deserialize, Serialize};
9use std::fmt::Debug;
10
11use super::completion::CompletionModel;
12
// ----------------------------------------------------------------
// OpenRouter client
// ----------------------------------------------------------------

/// Default base URL of the OpenRouter REST API (v1).
///
/// Used by every `ClientBuilder` constructor; override it with `ClientBuilder::base_url`.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
17
18pub struct ClientBuilder<'a, T = reqwest::Client> {
19    api_key: &'a str,
20    base_url: &'a str,
21    http_client: T,
22}
23
24impl<'a, T> ClientBuilder<'a, T>
25where
26    T: Default,
27{
28    pub fn new(api_key: &'a str) -> Self {
29        Self {
30            api_key,
31            base_url: OPENROUTER_API_BASE_URL,
32            http_client: Default::default(),
33        }
34    }
35}
36
37impl<'a, T> ClientBuilder<'a, T> {
38    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
39        Self {
40            api_key,
41            base_url: OPENROUTER_API_BASE_URL,
42            http_client,
43        }
44    }
45    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
46        ClientBuilder {
47            api_key: self.api_key,
48            base_url: self.base_url,
49            http_client,
50        }
51    }
52
53    pub fn base_url(mut self, base_url: &'a str) -> Self {
54        self.base_url = base_url;
55        self
56    }
57
58    pub fn build(self) -> Client<T> {
59        Client {
60            base_url: self.base_url.to_string(),
61            api_key: self.api_key.to_string(),
62            http_client: self.http_client,
63        }
64    }
65}
66
67#[derive(Clone)]
68pub struct Client<T = reqwest::Client> {
69    base_url: String,
70    api_key: String,
71    pub http_client: T,
72}
73
74impl<T> Debug for Client<T>
75where
76    T: Debug,
77{
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("Client")
80            .field("base_url", &self.base_url)
81            .field("http_client", &self.http_client)
82            .field("api_key", &"<REDACTED>")
83            .finish()
84    }
85}
86
87impl<T> Client<T> {
88    pub(crate) fn req(
89        &self,
90        method: Method,
91        path: &str,
92    ) -> http_client::Result<http_client::Builder> {
93        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
94        let req = http_client::Request::builder().uri(url).method(method);
95
96        http_client::with_bearer_auth(req, &self.api_key)
97    }
98
99    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
100        self.req(Method::GET, path)
101    }
102
103    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
104        self.req(Method::POST, path)
105    }
106}
107
108impl Client<reqwest::Client> {
109    /// Create a new OpenRouter client builder.
110    ///
111    /// # Example
112    /// ```
113    /// use rig::providers::openrouter::{ClientBuilder, self};
114    ///
115    /// // Initialize the OpenAI client
116    /// let openai_client = Client::builder("your-open-ai-api-key")
117    ///    .build()
118    /// ```
119    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
120        ClientBuilder::new(api_key)
121    }
122
123    /// Create a new OpenAI client. For more control, use the `builder` method.
124    ///
125    pub fn new(api_key: &str) -> Self {
126        Self::builder(api_key).build()
127    }
128
129    pub fn from_env() -> Self {
130        <Self as ProviderClient>::from_env()
131    }
132}
133
134impl<T> ProviderClient for Client<T>
135where
136    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
137{
138    /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable.
139    /// Panics if the environment variable is not set.
140    fn from_env() -> Self {
141        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
142        ClientBuilder::<T>::new(&api_key).build()
143    }
144
145    fn from_val(input: crate::client::ProviderValue) -> Self {
146        let crate::client::ProviderValue::Simple(api_key) = input else {
147            panic!("Incorrect provider value type")
148        };
149        ClientBuilder::<T>::new(&api_key).build()
150    }
151}
152
153impl<T> CompletionClient for Client<T>
154where
155    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
156{
157    type CompletionModel = CompletionModel<T>;
158
159    /// Create a completion model with the given name.
160    ///
161    /// # Example
162    /// ```
163    /// use rig::providers::openrouter::{Client, self};
164    ///
165    /// // Initialize the openrouter client
166    /// let openrouter = Client::new("your-openrouter-api-key");
167    ///
168    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
169    /// ```
170    fn completion_model(&self, model: &str) -> CompletionModel<T> {
171        CompletionModel::new(self.clone(), model)
172    }
173}
174
175impl<T> VerifyClient for Client<T>
176where
177    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
178{
179    #[cfg_attr(feature = "worker", worker::send)]
180    async fn verify(&self) -> Result<(), VerifyError> {
181        let req = self
182            .get("/key")?
183            .body(http_client::NoBody)
184            .map_err(|e| VerifyError::HttpError(e.into()))?;
185
186        let response = HttpClientExt::send(&self.http_client, req).await?;
187
188        match response.status() {
189            reqwest::StatusCode::OK => Ok(()),
190            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
191            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
192                let text = http_client::text(response).await?;
193                Err(VerifyError::ProviderError(text))
194            }
195            _ => {
196                //response.error_for_status()?;
197                Ok(())
198            }
199        }
200    }
201}
202
203impl_conversion_traits!(
204    AsEmbeddings,
205    AsTranscription,
206    AsImageGeneration,
207    AsAudioGeneration for Client<T>
208);
209
210#[derive(Debug, Deserialize)]
211pub(crate) struct ApiErrorResponse {
212    pub message: String,
213}
214
215#[derive(Debug, Deserialize)]
216#[serde(untagged)]
217pub(crate) enum ApiResponse<T> {
218    Ok(T),
219    Err(ApiErrorResponse),
220}
221
222#[derive(Clone, Debug, Deserialize, Serialize)]
223pub struct Usage {
224    pub prompt_tokens: usize,
225    pub completion_tokens: usize,
226    pub total_tokens: usize,
227}
228
229impl std::fmt::Display for Usage {
230    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        write!(
232            f,
233            "Prompt tokens: {} Total tokens: {}",
234            self.prompt_tokens, self.total_tokens
235        )
236    }
237}
238
239impl GetTokenUsage for Usage {
240    fn token_usage(&self) -> Option<crate::completion::Usage> {
241        let mut usage = crate::completion::Usage::new();
242
243        usage.input_tokens = self.prompt_tokens as u64;
244        usage.output_tokens = self.completion_tokens as u64;
245        usage.total_tokens = self.total_tokens as u64;
246
247        Some(usage)
248    }
249}