rig/providers/xai/
client.rs

1use super::completion::CompletionModel;
2use crate::{
3    client::{CompletionClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits},
4    http_client,
5};
6
7// ================================================================
8// xAI Client
9// ================================================================
10const XAI_BASE_URL: &str = "https://api.x.ai";
11
12pub struct ClientBuilder<'a, T = reqwest::Client> {
13    api_key: &'a str,
14    base_url: &'a str,
15    http_client: T,
16}
17
18impl<'a, T> ClientBuilder<'a, T>
19where
20    T: Default,
21{
22    pub fn new(api_key: &'a str) -> Self {
23        Self {
24            api_key,
25            base_url: XAI_BASE_URL,
26            http_client: Default::default(),
27        }
28    }
29}
30
31impl<'a, T> ClientBuilder<'a, T> {
32    pub fn base_url(mut self, base_url: &'a str) -> Self {
33        self.base_url = base_url;
34        self
35    }
36
37    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
38        ClientBuilder {
39            api_key: self.api_key,
40            base_url: self.base_url,
41            http_client,
42        }
43    }
44
45    pub fn build(self) -> Client<T> {
46        let mut default_headers = reqwest::header::HeaderMap::new();
47        default_headers.insert(
48            reqwest::header::CONTENT_TYPE,
49            "application/json".parse().unwrap(),
50        );
51
52        Client {
53            base_url: self.base_url.to_string(),
54            api_key: self.api_key.to_string(),
55            default_headers,
56            http_client: self.http_client,
57        }
58    }
59}
60
61#[derive(Clone)]
62pub struct Client<T = reqwest::Client> {
63    base_url: String,
64    api_key: String,
65    default_headers: http_client::HeaderMap,
66    http_client: T,
67}
68
69impl<T> std::fmt::Debug for Client<T>
70where
71    T: std::fmt::Debug,
72{
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("Client")
75            .field("base_url", &self.base_url)
76            .field("http_client", &self.http_client)
77            .field("default_headers", &self.default_headers)
78            .field("api_key", &"<REDACTED>")
79            .finish()
80    }
81}
82
83impl<T> Client<T>
84where
85    T: Default,
86{
87    /// Create a new xAI client builder.
88    ///
89    /// # Example
90    /// ```
91    /// use rig::providers::xai::{ClientBuilder, self};
92    ///
93    /// // Initialize the xAI client
94    /// let xai = Client::builder("your-xai-api-key")
95    ///    .build()
96    /// ```
97    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
98        ClientBuilder::new(api_key)
99    }
100
101    /// Create a new xAI client. For more control, use the `builder` method.
102    ///
103    /// # Panics
104    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
105    pub fn new(api_key: &str) -> Self {
106        Self::builder(api_key).build()
107    }
108}
109
110impl Client<reqwest::Client> {
111    pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
112        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
113
114        tracing::debug!("POST {}", url);
115
116        self.http_client
117            .post(url)
118            .bearer_auth(&self.api_key)
119            .headers(self.default_headers.clone())
120    }
121
122    pub(crate) fn reqwest_get(&self, path: &str) -> reqwest::RequestBuilder {
123        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
124
125        tracing::debug!("GET {}", url);
126
127        self.http_client
128            .get(url)
129            .bearer_auth(&self.api_key)
130            .headers(self.default_headers.clone())
131    }
132}
133
134impl ProviderClient for Client<reqwest::Client> {
135    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
136    /// Panics if the environment variable is not set.
137    fn from_env() -> Self {
138        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
139        Self::new(&api_key)
140    }
141
142    fn from_val(input: crate::client::ProviderValue) -> Self {
143        let crate::client::ProviderValue::Simple(api_key) = input else {
144            panic!("Incorrect provider value type")
145        };
146        Self::new(&api_key)
147    }
148}
149
150impl CompletionClient for Client<reqwest::Client> {
151    type CompletionModel = CompletionModel<reqwest::Client>;
152
153    /// Create a completion model with the given name.
154    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
155        CompletionModel::new(self.clone(), model)
156    }
157}
158
159impl VerifyClient for Client<reqwest::Client> {
160    #[cfg_attr(feature = "worker", worker::send)]
161    async fn verify(&self) -> Result<(), VerifyError> {
162        let response = self
163            .reqwest_get("/v1/api-key")
164            .send()
165            .await
166            .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
167
168        match response.status() {
169            reqwest::StatusCode::OK => Ok(()),
170            reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
171                Err(VerifyError::InvalidAuthentication)
172            }
173            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
174                Err(VerifyError::ProviderError(response.text().await.map_err(
175                    |e| VerifyError::HttpError(http_client::Error::Instance(e.into())),
176                )?))
177            }
178            _ => {
179                response
180                    .error_for_status()
181                    .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
182                Ok(())
183            }
184        }
185    }
186}
187
// Invoke the shared `impl_conversion_traits!` macro for the embedding,
// transcription, image-generation and audio-generation conversion traits on
// `Client<T>`.
// NOTE(review): the generated impls are defined by the macro in
// `crate::client`; presumably these mark the capabilities as unsupported for
// xAI — confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
194
195pub mod xai_api_types {
196    use serde::Deserialize;
197
198    impl ApiErrorResponse {
199        pub fn message(&self) -> String {
200            format!("Code `{}`: {}", self.code, self.error)
201        }
202    }
203
204    #[derive(Debug, Deserialize)]
205    pub struct ApiErrorResponse {
206        pub error: String,
207        pub code: String,
208    }
209
210    #[derive(Debug, Deserialize)]
211    #[serde(untagged)]
212    pub enum ApiResponse<T> {
213        Ok(T),
214        Error(ApiErrorResponse),
215    }
216}