rig/providers/openrouter/client.rs

1use crate::{
2    client::{CompletionClient, ProviderClient, VerifyClient, VerifyError},
3    completion::GetTokenUsage,
4    http_client::{self, HttpClientExt},
5    impl_conversion_traits,
6};
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9
10use super::completion::CompletionModel;
11
12// ================================================================
13// Main openrouter Client
14// ================================================================
15const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
16
17pub struct ClientBuilder<'a, T = reqwest::Client> {
18    api_key: &'a str,
19    base_url: &'a str,
20    http_client: T,
21}
22
23impl<'a, T> ClientBuilder<'a, T>
24where
25    T: Default,
26{
27    pub fn new(api_key: &'a str) -> Self {
28        Self {
29            api_key,
30            base_url: OPENROUTER_API_BASE_URL,
31            http_client: Default::default(),
32        }
33    }
34}
35
36impl<'a, T> ClientBuilder<'a, T> {
37    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
38        ClientBuilder {
39            api_key: self.api_key,
40            base_url: self.base_url,
41            http_client,
42        }
43    }
44
45    pub fn base_url(mut self, base_url: &'a str) -> Self {
46        self.base_url = base_url;
47        self
48    }
49
50    pub fn build(self) -> Client<T> {
51        Client {
52            base_url: self.base_url.to_string(),
53            api_key: self.api_key.to_string(),
54            http_client: self.http_client,
55        }
56    }
57}
58
59#[derive(Clone)]
60pub struct Client<T = reqwest::Client> {
61    base_url: String,
62    api_key: String,
63    http_client: T,
64}
65
66impl<T> Debug for Client<T>
67where
68    T: Debug,
69{
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("Client")
72            .field("base_url", &self.base_url)
73            .field("http_client", &self.http_client)
74            .field("api_key", &"<REDACTED>")
75            .finish()
76    }
77}
78
79impl Client<reqwest::Client> {
80    pub(crate) fn reqwest_client(&self) -> &reqwest::Client {
81        &self.http_client
82    }
83
84    pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
85        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
86
87        self.http_client.post(url).bearer_auth(&self.api_key)
88    }
89}
90
91impl<T> Client<T>
92where
93    T: Default,
94{
95    /// Create a new OpenRouter client builder.
96    ///
97    /// # Example
98    /// ```
99    /// use rig::providers::openrouter::{ClientBuilder, self};
100    ///
101    /// // Initialize the OpenRouter client
102    /// let openrouter = Client::builder("your-openrouter-api-key")
103    ///    .build()
104    /// ```
105    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
106        ClientBuilder::new(api_key)
107    }
108
109    /// Create a new OpenRouter client. For more control, use the `builder` method.
110    ///
111    /// # Panics
112    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
113    pub fn new(api_key: &str) -> Self {
114        Self::builder(api_key).build()
115    }
116}
117
118impl<T> Client<T> {
119    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
120        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
121
122        http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
123    }
124}
125
126impl ProviderClient for Client<reqwest::Client> {
127    /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable.
128    /// Panics if the environment variable is not set.
129    fn from_env() -> Self {
130        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
131        Self::new(&api_key)
132    }
133
134    fn from_val(input: crate::client::ProviderValue) -> Self {
135        let crate::client::ProviderValue::Simple(api_key) = input else {
136            panic!("Incorrect provider value type")
137        };
138        Self::new(&api_key)
139    }
140}
141
142impl CompletionClient for Client<reqwest::Client> {
143    type CompletionModel = CompletionModel<reqwest::Client>;
144
145    /// Create a completion model with the given name.
146    ///
147    /// # Example
148    /// ```
149    /// use rig::providers::openrouter::{Client, self};
150    ///
151    /// // Initialize the openrouter client
152    /// let openrouter = Client::new("your-openrouter-api-key");
153    ///
154    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
155    /// ```
156    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
157        CompletionModel::new(self.clone(), model)
158    }
159}
160
161impl VerifyClient for Client<reqwest::Client> {
162    #[cfg_attr(feature = "worker", worker::send)]
163    async fn verify(&self) -> Result<(), VerifyError> {
164        let req = self
165            .get("/key")?
166            .body(http_client::NoBody)
167            .map_err(|e| VerifyError::HttpError(e.into()))?;
168
169        let response = HttpClientExt::send(&self.http_client, req).await?;
170
171        match response.status() {
172            reqwest::StatusCode::OK => Ok(()),
173            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
174            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
175                let text = http_client::text(response).await?;
176                Err(VerifyError::ProviderError(text))
177            }
178            _ => {
179                //response.error_for_status()?;
180                Ok(())
181            }
182        }
183    }
184}
185
186impl_conversion_traits!(
187    AsEmbeddings,
188    AsTranscription,
189    AsImageGeneration,
190    AsAudioGeneration for Client<T>
191);
192
193#[derive(Debug, Deserialize)]
194pub(crate) struct ApiErrorResponse {
195    pub message: String,
196}
197
198#[derive(Debug, Deserialize)]
199#[serde(untagged)]
200pub(crate) enum ApiResponse<T> {
201    Ok(T),
202    Err(ApiErrorResponse),
203}
204
205#[derive(Clone, Debug, Deserialize, Serialize)]
206pub struct Usage {
207    pub prompt_tokens: usize,
208    pub completion_tokens: usize,
209    pub total_tokens: usize,
210}
211
212impl std::fmt::Display for Usage {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        write!(
215            f,
216            "Prompt tokens: {} Total tokens: {}",
217            self.prompt_tokens, self.total_tokens
218        )
219    }
220}
221
222impl GetTokenUsage for Usage {
223    fn token_usage(&self) -> Option<crate::completion::Usage> {
224        let mut usage = crate::completion::Usage::new();
225
226        usage.input_tokens = self.prompt_tokens as u64;
227        usage.output_tokens = self.completion_tokens as u64;
228        usage.total_tokens = self.total_tokens as u64;
229
230        Some(usage)
231    }
232}