rig/providers/openrouter/client.rs

1use crate::{
2    client::{ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError},
3    impl_conversion_traits,
4};
5use serde::{Deserialize, Serialize};
6
7use super::completion::CompletionModel;
8
9// ================================================================
10// Main openrouter Client
11// ================================================================
12const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
13
14pub struct ClientBuilder<'a> {
15    api_key: &'a str,
16    base_url: &'a str,
17    http_client: Option<reqwest::Client>,
18}
19
20impl<'a> ClientBuilder<'a> {
21    pub fn new(api_key: &'a str) -> Self {
22        Self {
23            api_key,
24            base_url: OPENROUTER_API_BASE_URL,
25            http_client: None,
26        }
27    }
28
29    pub fn base_url(mut self, base_url: &'a str) -> Self {
30        self.base_url = base_url;
31        self
32    }
33
34    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
35        self.http_client = Some(client);
36        self
37    }
38
39    pub fn build(self) -> Result<Client, ClientBuilderError> {
40        let http_client = if let Some(http_client) = self.http_client {
41            http_client
42        } else {
43            reqwest::Client::builder().build()?
44        };
45
46        Ok(Client {
47            base_url: self.base_url.to_string(),
48            api_key: self.api_key.to_string(),
49            http_client,
50        })
51    }
52}
53
54#[derive(Clone)]
55pub struct Client {
56    base_url: String,
57    api_key: String,
58    http_client: reqwest::Client,
59}
60
61impl std::fmt::Debug for Client {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("Client")
64            .field("base_url", &self.base_url)
65            .field("http_client", &self.http_client)
66            .field("api_key", &"<REDACTED>")
67            .finish()
68    }
69}
70
71impl Client {
72    /// Create a new OpenRouter client builder.
73    ///
74    /// # Example
75    /// ```
76    /// use rig::providers::openrouter::{ClientBuilder, self};
77    ///
78    /// // Initialize the OpenRouter client
79    /// let openrouter = Client::builder("your-openrouter-api-key")
80    ///    .build()
81    /// ```
82    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
83        ClientBuilder::new(api_key)
84    }
85
86    /// Create a new OpenRouter client. For more control, use the `builder` method.
87    ///
88    /// # Panics
89    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
90    pub fn new(api_key: &str) -> Self {
91        Self::builder(api_key)
92            .build()
93            .expect("OpenRouter client should build")
94    }
95
96    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
97        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
98        self.http_client.post(url).bearer_auth(&self.api_key)
99    }
100
101    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
102        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
103        self.http_client.get(url).bearer_auth(&self.api_key)
104    }
105}
106
107impl ProviderClient for Client {
108    /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable.
109    /// Panics if the environment variable is not set.
110    fn from_env() -> Self {
111        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
112        Self::new(&api_key)
113    }
114
115    fn from_val(input: crate::client::ProviderValue) -> Self {
116        let crate::client::ProviderValue::Simple(api_key) = input else {
117            panic!("Incorrect provider value type")
118        };
119        Self::new(&api_key)
120    }
121}
122
123impl CompletionClient for Client {
124    type CompletionModel = CompletionModel;
125
126    /// Create a completion model with the given name.
127    ///
128    /// # Example
129    /// ```
130    /// use rig::providers::openrouter::{Client, self};
131    ///
132    /// // Initialize the openrouter client
133    /// let openrouter = Client::new("your-openrouter-api-key");
134    ///
135    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
136    /// ```
137    fn completion_model(&self, model: &str) -> CompletionModel {
138        CompletionModel::new(self.clone(), model)
139    }
140}
141
142impl VerifyClient for Client {
143    #[cfg_attr(feature = "worker", worker::send)]
144    async fn verify(&self) -> Result<(), VerifyError> {
145        let response = self.get("/key").send().await?;
146        match response.status() {
147            reqwest::StatusCode::OK => Ok(()),
148            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
149            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
150                Err(VerifyError::ProviderError(response.text().await?))
151            }
152            _ => {
153                response.error_for_status()?;
154                Ok(())
155            }
156        }
157    }
158}
159
160impl_conversion_traits!(
161    AsEmbeddings,
162    AsTranscription,
163    AsImageGeneration,
164    AsAudioGeneration for Client
165);
166
167#[derive(Debug, Deserialize)]
168pub(crate) struct ApiErrorResponse {
169    pub message: String,
170}
171
172#[derive(Debug, Deserialize)]
173#[serde(untagged)]
174pub(crate) enum ApiResponse<T> {
175    Ok(T),
176    Err(ApiErrorResponse),
177}
178
179#[derive(Clone, Debug, Deserialize, Serialize)]
180pub struct Usage {
181    pub prompt_tokens: usize,
182    pub completion_tokens: usize,
183    pub total_tokens: usize,
184}
185
186impl std::fmt::Display for Usage {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        write!(
189            f,
190            "Prompt tokens: {} Total tokens: {}",
191            self.prompt_tokens, self.total_tokens
192        )
193    }
194}