rig/providers/xai/client.rs

1use http::Method;
2
3use super::completion::CompletionModel;
4use crate::{
5    client::{CompletionClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits},
6    http_client::{self, HttpClientExt, NoBody, Result as HttpResult, with_bearer_auth},
7};
8
// ================================================================
// xAI Client
// ================================================================
/// Default base URL of the xAI API.
///
/// Used by [`ClientBuilder::new`] and [`ClientBuilder::new_with_client`] unless it is
/// overridden with [`ClientBuilder::base_url`]; request paths are appended to this URL.
const XAI_BASE_URL: &str = "https://api.x.ai";
13
14pub struct ClientBuilder<'a, T = reqwest::Client> {
15    api_key: &'a str,
16    base_url: &'a str,
17    http_client: T,
18}
19
20impl<'a, T> ClientBuilder<'a, T>
21where
22    T: Default,
23{
24    pub fn new(api_key: &'a str) -> Self {
25        Self {
26            api_key,
27            base_url: XAI_BASE_URL,
28            http_client: Default::default(),
29        }
30    }
31}
32
33impl<'a, T> ClientBuilder<'a, T> {
34    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
35        Self {
36            api_key,
37            base_url: XAI_BASE_URL,
38            http_client,
39        }
40    }
41
42    pub fn base_url(mut self, base_url: &'a str) -> Self {
43        self.base_url = base_url;
44        self
45    }
46
47    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
48        ClientBuilder {
49            api_key: self.api_key,
50            base_url: self.base_url,
51            http_client,
52        }
53    }
54
55    pub fn build(self) -> Client<T> {
56        let mut default_headers = reqwest::header::HeaderMap::new();
57        default_headers.insert(
58            reqwest::header::CONTENT_TYPE,
59            "application/json".parse().unwrap(),
60        );
61
62        Client {
63            base_url: self.base_url.to_string(),
64            api_key: self.api_key.to_string(),
65            default_headers,
66            http_client: self.http_client,
67        }
68    }
69}
70
71#[derive(Clone)]
72pub struct Client<T = reqwest::Client> {
73    base_url: String,
74    api_key: String,
75    default_headers: http_client::HeaderMap,
76    pub http_client: T,
77}
78
79impl<T> std::fmt::Debug for Client<T>
80where
81    T: std::fmt::Debug,
82{
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.debug_struct("Client")
85            .field("base_url", &self.base_url)
86            .field("http_client", &self.http_client)
87            .field("default_headers", &self.default_headers)
88            .field("api_key", &"<REDACTED>")
89            .finish()
90    }
91}
92
93impl Client<reqwest::Client> {
94    /// Create a new xAI client builder.
95    ///
96    /// # Example
97    /// ```
98    /// use rig::providers::xai::{ClientBuilder, self};
99    ///
100    /// // Initialize the xAI client
101    /// let xai = Client::builder("your-xai-api-key")
102    ///    .build()
103    /// ```
104    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
105        ClientBuilder::new(api_key)
106    }
107
108    /// Create a new xAI client. For more control, use the `builder` method.
109    ///
110    /// # Panics
111    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
112    pub fn new(api_key: &str) -> Self {
113        Self::builder(api_key).build()
114    }
115
116    pub fn from_env() -> Self {
117        <Self as ProviderClient>::from_env()
118    }
119}
120
121impl<T> Client<T>
122where
123    T: HttpClientExt,
124{
125    pub(crate) fn req(&self, method: Method, path: &str) -> HttpResult<http_client::Builder> {
126        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
127
128        let mut builder = http_client::Builder::new().uri(url).method(method);
129        for (header, value) in &self.default_headers {
130            builder = builder.header(header, value);
131        }
132
133        with_bearer_auth(builder, &self.api_key)
134    }
135
136    pub(crate) fn post(&self, path: &str) -> HttpResult<http_client::Builder> {
137        self.req(Method::POST, path)
138    }
139
140    pub(crate) fn get(&self, path: &str) -> HttpResult<http_client::Builder> {
141        self.req(Method::GET, path)
142    }
143}
144
145impl<T> ProviderClient for Client<T>
146where
147    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
148{
149    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
150    /// Panics if the environment variable is not set.
151    fn from_env() -> Self {
152        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
153        ClientBuilder::<T>::new(&api_key).build()
154    }
155
156    fn from_val(input: crate::client::ProviderValue) -> Self {
157        let crate::client::ProviderValue::Simple(api_key) = input else {
158            panic!("Incorrect provider value type")
159        };
160        ClientBuilder::<T>::new(&api_key).build()
161    }
162}
163
164impl<T> CompletionClient for Client<T>
165where
166    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
167{
168    type CompletionModel = CompletionModel<T>;
169
170    /// Create a completion model with the given name.
171    fn completion_model(&self, model: &str) -> CompletionModel<T> {
172        CompletionModel::new(self.clone(), model)
173    }
174}
175
176impl<T> VerifyClient for Client<T>
177where
178    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
179{
180    #[cfg_attr(feature = "worker", worker::send)]
181    async fn verify(&self) -> Result<(), VerifyError> {
182        let req = self.get("/v1/api-key").unwrap().body(NoBody).unwrap();
183
184        let response = self.http_client.send::<_, Vec<u8>>(req).await.unwrap();
185        let status = response.status();
186
187        match status {
188            reqwest::StatusCode::OK => Ok(()),
189            reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
190                Err(VerifyError::InvalidAuthentication)
191            }
192            reqwest::StatusCode::INTERNAL_SERVER_ERROR => Err(VerifyError::ProviderError(
193                http_client::text(response).await?,
194            )),
195            _ => Err(VerifyError::HttpError(http_client::Error::Instance(
196                http_client::text(response).await?.into(),
197            ))),
198        }
199    }
200}
201
// Generate the `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` conversion-trait impls for `Client<T>`.
// NOTE(review): presumably these are the "capability not supported" stubs, since this file
// implements `CompletionClient` but no embedding/transcription/image/audio client traits —
// confirm against the `impl_conversion_traits!` macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
208
209pub mod xai_api_types {
210    use serde::Deserialize;
211
212    impl ApiErrorResponse {
213        pub fn message(&self) -> String {
214            format!("Code `{}`: {}", self.code, self.error)
215        }
216    }
217
218    #[derive(Debug, Deserialize)]
219    pub struct ApiErrorResponse {
220        pub error: String,
221        pub code: String,
222    }
223
224    #[derive(Debug, Deserialize)]
225    #[serde(untagged)]
226    pub enum ApiResponse<T> {
227        Ok(T),
228        Error(ApiErrorResponse),
229    }
230}