rig/providers/openrouter/
client.rs

1use crate::{
2    client::{ClientBuilderError, CompletionClient, ProviderClient},
3    impl_conversion_traits,
4};
5use serde::{Deserialize, Serialize};
6
7use super::completion::CompletionModel;
8
9// ================================================================
10// Main openrouter Client
11// ================================================================
/// Base URL of the OpenRouter HTTP API.
///
/// `ClientBuilder::new` uses this as the endpoint unless a different base URL is supplied.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
13
14pub struct ClientBuilder<'a> {
15    api_key: &'a str,
16    base_url: &'a str,
17    http_client: Option<reqwest::Client>,
18}
19
20impl<'a> ClientBuilder<'a> {
21    pub fn new(api_key: &'a str) -> Self {
22        Self {
23            api_key,
24            base_url: OPENROUTER_API_BASE_URL,
25            http_client: None,
26        }
27    }
28
29    pub fn base_url(mut self, base_url: &'a str) -> Self {
30        self.base_url = base_url;
31        self
32    }
33
34    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
35        self.http_client = Some(client);
36        self
37    }
38
39    pub fn build(self) -> Result<Client, ClientBuilderError> {
40        let http_client = if let Some(http_client) = self.http_client {
41            http_client
42        } else {
43            reqwest::Client::builder().build()?
44        };
45
46        Ok(Client {
47            base_url: self.base_url.to_string(),
48            api_key: self.api_key.to_string(),
49            http_client,
50        })
51    }
52}
53
54#[derive(Clone)]
55pub struct Client {
56    base_url: String,
57    api_key: String,
58    http_client: reqwest::Client,
59}
60
61impl std::fmt::Debug for Client {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("Client")
64            .field("base_url", &self.base_url)
65            .field("http_client", &self.http_client)
66            .field("api_key", &"<REDACTED>")
67            .finish()
68    }
69}
70
71impl Client {
72    /// Create a new OpenRouter client builder.
73    ///
74    /// # Example
75    /// ```
76    /// use rig::providers::openrouter::{ClientBuilder, self};
77    ///
78    /// // Initialize the OpenRouter client
79    /// let openrouter = Client::builder("your-openrouter-api-key")
80    ///    .build()
81    /// ```
82    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
83        ClientBuilder::new(api_key)
84    }
85
86    /// Create a new OpenRouter client. For more control, use the `builder` method.
87    ///
88    /// # Panics
89    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
90    pub fn new(api_key: &str) -> Self {
91        Self::builder(api_key)
92            .build()
93            .expect("OpenRouter client should build")
94    }
95
96    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
97        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
98        self.http_client.post(url).bearer_auth(&self.api_key)
99    }
100}
101
102impl ProviderClient for Client {
103    /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable.
104    /// Panics if the environment variable is not set.
105    fn from_env() -> Self {
106        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
107        Self::new(&api_key)
108    }
109
110    fn from_val(input: crate::client::ProviderValue) -> Self {
111        let crate::client::ProviderValue::Simple(api_key) = input else {
112            panic!("Incorrect provider value type")
113        };
114        Self::new(&api_key)
115    }
116}
117
118impl CompletionClient for Client {
119    type CompletionModel = CompletionModel;
120
121    /// Create a completion model with the given name.
122    ///
123    /// # Example
124    /// ```
125    /// use rig::providers::openrouter::{Client, self};
126    ///
127    /// // Initialize the openrouter client
128    /// let openrouter = Client::new("your-openrouter-api-key");
129    ///
130    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
131    /// ```
132    fn completion_model(&self, model: &str) -> CompletionModel {
133        CompletionModel::new(self.clone(), model)
134    }
135}
136
137impl_conversion_traits!(
138    AsEmbeddings,
139    AsTranscription,
140    AsImageGeneration,
141    AsAudioGeneration for Client
142);
143
144#[derive(Debug, Deserialize)]
145pub(crate) struct ApiErrorResponse {
146    pub message: String,
147}
148
149#[derive(Debug, Deserialize)]
150#[serde(untagged)]
151pub(crate) enum ApiResponse<T> {
152    Ok(T),
153    Err(ApiErrorResponse),
154}
155
156#[derive(Clone, Debug, Deserialize, Serialize)]
157pub struct Usage {
158    pub prompt_tokens: usize,
159    pub completion_tokens: usize,
160    pub total_tokens: usize,
161}
162
163impl std::fmt::Display for Usage {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        write!(
166            f,
167            "Prompt tokens: {} Total tokens: {}",
168            self.prompt_tokens, self.total_tokens
169        )
170    }
171}