rig/providers/xai/
client.rs

1use super::completion::CompletionModel;
2use crate::client::{ClientBuilderError, CompletionClient, ProviderClient, impl_conversion_traits};
3
4// ================================================================
5// xAI Client
6// ================================================================
/// Default xAI API base URL; `ClientBuilder::new` starts from this value and
/// `ClientBuilder::base_url` can replace it.
7const XAI_BASE_URL: &str = "https://api.x.ai";
8
/// Builder for an xAI [`Client`].
///
/// Borrows its string configuration for `'a`; `build` copies it into owned `String`s.
9pub struct ClientBuilder<'a> {
    /// API key that the built client sends as a bearer token.
10    api_key: &'a str,
    /// Base URL for requests; starts as `XAI_BASE_URL`.
11    base_url: &'a str,
    /// Optional caller-supplied HTTP client; when `None`, `build` creates a default one.
12    http_client: Option<reqwest::Client>,
13}
14
15impl<'a> ClientBuilder<'a> {
16    pub fn new(api_key: &'a str) -> Self {
17        Self {
18            api_key,
19            base_url: XAI_BASE_URL,
20            http_client: None,
21        }
22    }
23
24    pub fn base_url(mut self, base_url: &'a str) -> Self {
25        self.base_url = base_url;
26        self
27    }
28
29    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
30        self.http_client = Some(client);
31        self
32    }
33
34    pub fn build(self) -> Result<Client, ClientBuilderError> {
35        let mut default_headers = reqwest::header::HeaderMap::new();
36        default_headers.insert(
37            reqwest::header::CONTENT_TYPE,
38            "application/json".parse().unwrap(),
39        );
40
41        let http_client = if let Some(http_client) = self.http_client {
42            http_client
43        } else {
44            reqwest::Client::builder().build()?
45        };
46
47        Ok(Client {
48            base_url: self.base_url.to_string(),
49            api_key: self.api_key.to_string(),
50            default_headers,
51            http_client,
52        })
53    }
54}
55
/// Client for the xAI API.
///
/// Cloning copies every field; `completion_model` clones the client into each model it creates.
56#[derive(Clone)]
57pub struct Client {
    /// Base URL that request paths are appended to in `post`.
58    base_url: String,
    /// API key sent as a bearer token; the manual `Debug` impl prints a placeholder instead.
59    api_key: String,
    /// Headers attached to every request built by `post`.
60    default_headers: reqwest::header::HeaderMap,
    /// HTTP client used to issue requests.
61    http_client: reqwest::Client,
62}
63
64impl std::fmt::Debug for Client {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("Client")
67            .field("base_url", &self.base_url)
68            .field("http_client", &self.http_client)
69            .field("default_headers", &self.default_headers)
70            .field("api_key", &"<REDACTED>")
71            .finish()
72    }
73}
74
75impl Client {
76    /// Create a new xAI client builder.
77    ///
78    /// # Example
79    /// ```
80    /// use rig::providers::xai::{ClientBuilder, self};
81    ///
82    /// // Initialize the xAI client
83    /// let xai = Client::builder("your-xai-api-key")
84    ///    .build()
85    /// ```
86    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
87        ClientBuilder::new(api_key)
88    }
89
90    /// Create a new xAI client. For more control, use the `builder` method.
91    ///
92    /// # Panics
93    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
94    pub fn new(api_key: &str) -> Self {
95        Self::builder(api_key)
96            .build()
97            .expect("xAI client should build")
98    }
99
100    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
101        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
102
103        tracing::debug!("POST {}", url);
104        self.http_client
105            .post(url)
106            .bearer_auth(&self.api_key)
107            .headers(self.default_headers.clone())
108    }
109}
110
111impl ProviderClient for Client {
112    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
113    /// Panics if the environment variable is not set.
114    fn from_env() -> Self {
115        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
116        Self::new(&api_key)
117    }
118
119    fn from_val(input: crate::client::ProviderValue) -> Self {
120        let crate::client::ProviderValue::Simple(api_key) = input else {
121            panic!("Incorrect provider value type")
122        };
123        Self::new(&api_key)
124    }
125}
126
127impl CompletionClient for Client {
128    type CompletionModel = CompletionModel;
129
130    /// Create a completion model with the given name.
131    fn completion_model(&self, model: &str) -> CompletionModel {
132        CompletionModel::new(self.clone(), model)
133    }
134}
135
// NOTE(review): presumably generates the `AsEmbeddings`, `AsTranscription`,
// `AsImageGeneration` and `AsAudioGeneration` conversion-trait impls for `Client`
// (this file only implements `CompletionClient` by hand) — confirm against the
// definition of `impl_conversion_traits!` in `crate::client`.
136impl_conversion_traits!(
137    AsEmbeddings,
138    AsTranscription,
139    AsImageGeneration,
140    AsAudioGeneration for Client
141);
142
143pub mod xai_api_types {
144    use serde::Deserialize;
145
146    impl ApiErrorResponse {
147        pub fn message(&self) -> String {
148            format!("Code `{}`: {}", self.code, self.error)
149        }
150    }
151
152    #[derive(Debug, Deserialize)]
153    pub struct ApiErrorResponse {
154        pub error: String,
155        pub code: String,
156    }
157
158    #[derive(Debug, Deserialize)]
159    #[serde(untagged)]
160    pub enum ApiResponse<T> {
161        Ok(T),
162        Error(ApiErrorResponse),
163    }
164}