rig/providers/xai/
client.rs

1use super::completion::CompletionModel;
2use crate::client::{
3    ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
4    impl_conversion_traits,
5};
6
// ================================================================
// xAI Client
// ================================================================

/// Default root URL of the xAI REST API.
///
/// Request paths such as `/v1/api-key` are appended to it by `Client::post` / `Client::get`.
/// It can be overridden with `ClientBuilder::base_url`.
const XAI_BASE_URL: &str = "https://api.x.ai";
11
12pub struct ClientBuilder<'a> {
13    api_key: &'a str,
14    base_url: &'a str,
15    http_client: Option<reqwest::Client>,
16}
17
18impl<'a> ClientBuilder<'a> {
19    pub fn new(api_key: &'a str) -> Self {
20        Self {
21            api_key,
22            base_url: XAI_BASE_URL,
23            http_client: None,
24        }
25    }
26
27    pub fn base_url(mut self, base_url: &'a str) -> Self {
28        self.base_url = base_url;
29        self
30    }
31
32    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
33        self.http_client = Some(client);
34        self
35    }
36
37    pub fn build(self) -> Result<Client, ClientBuilderError> {
38        let mut default_headers = reqwest::header::HeaderMap::new();
39        default_headers.insert(
40            reqwest::header::CONTENT_TYPE,
41            "application/json".parse().unwrap(),
42        );
43
44        let http_client = if let Some(http_client) = self.http_client {
45            http_client
46        } else {
47            reqwest::Client::builder().build()?
48        };
49
50        Ok(Client {
51            base_url: self.base_url.to_string(),
52            api_key: self.api_key.to_string(),
53            default_headers,
54            http_client,
55        })
56    }
57}
58
59#[derive(Clone)]
60pub struct Client {
61    base_url: String,
62    api_key: String,
63    default_headers: reqwest::header::HeaderMap,
64    http_client: reqwest::Client,
65}
66
67impl std::fmt::Debug for Client {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.debug_struct("Client")
70            .field("base_url", &self.base_url)
71            .field("http_client", &self.http_client)
72            .field("default_headers", &self.default_headers)
73            .field("api_key", &"<REDACTED>")
74            .finish()
75    }
76}
77
78impl Client {
79    /// Create a new xAI client builder.
80    ///
81    /// # Example
82    /// ```
83    /// use rig::providers::xai::{ClientBuilder, self};
84    ///
85    /// // Initialize the xAI client
86    /// let xai = Client::builder("your-xai-api-key")
87    ///    .build()
88    /// ```
89    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
90        ClientBuilder::new(api_key)
91    }
92
93    /// Create a new xAI client. For more control, use the `builder` method.
94    ///
95    /// # Panics
96    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
97    pub fn new(api_key: &str) -> Self {
98        Self::builder(api_key)
99            .build()
100            .expect("xAI client should build")
101    }
102
103    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
104        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
105
106        tracing::debug!("POST {}", url);
107        self.http_client
108            .post(url)
109            .bearer_auth(&self.api_key)
110            .headers(self.default_headers.clone())
111    }
112
113    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
114        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
115
116        tracing::debug!("GET {}", url);
117        self.http_client
118            .get(url)
119            .bearer_auth(&self.api_key)
120            .headers(self.default_headers.clone())
121    }
122}
123
124impl ProviderClient for Client {
125    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
126    /// Panics if the environment variable is not set.
127    fn from_env() -> Self {
128        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
129        Self::new(&api_key)
130    }
131
132    fn from_val(input: crate::client::ProviderValue) -> Self {
133        let crate::client::ProviderValue::Simple(api_key) = input else {
134            panic!("Incorrect provider value type")
135        };
136        Self::new(&api_key)
137    }
138}
139
140impl CompletionClient for Client {
141    type CompletionModel = CompletionModel;
142
143    /// Create a completion model with the given name.
144    fn completion_model(&self, model: &str) -> CompletionModel {
145        CompletionModel::new(self.clone(), model)
146    }
147}
148
149impl VerifyClient for Client {
150    #[cfg_attr(feature = "worker", worker::send)]
151    async fn verify(&self) -> Result<(), VerifyError> {
152        let response = self.get("/v1/api-key").send().await?;
153        match response.status() {
154            reqwest::StatusCode::OK => Ok(()),
155            reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
156                Err(VerifyError::InvalidAuthentication)
157            }
158            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
159                Err(VerifyError::ProviderError(response.text().await?))
160            }
161            _ => {
162                response.error_for_status()?;
163                Ok(())
164            }
165        }
166    }
167}
168
169impl_conversion_traits!(
170    AsEmbeddings,
171    AsTranscription,
172    AsImageGeneration,
173    AsAudioGeneration for Client
174);
175
176pub mod xai_api_types {
177    use serde::Deserialize;
178
179    impl ApiErrorResponse {
180        pub fn message(&self) -> String {
181            format!("Code `{}`: {}", self.code, self.error)
182        }
183    }
184
185    #[derive(Debug, Deserialize)]
186    pub struct ApiErrorResponse {
187        pub error: String,
188        pub code: String,
189    }
190
191    #[derive(Debug, Deserialize)]
192    #[serde(untagged)]
193    pub enum ApiResponse<T> {
194        Ok(T),
195        Error(ApiErrorResponse),
196    }
197}