rig/providers/xai/
client.rs

1use super::{completion::CompletionModel, embedding::EmbeddingModel, EMBEDDING_V1};
2use crate::client::{impl_conversion_traits, CompletionClient, EmbeddingsClient, ProviderClient};
3
// ----------------------------------------------------------------
// xAI client
// ----------------------------------------------------------------

/// Root of the public xAI HTTP API; request paths are joined onto it by `Client::post`.
const XAI_BASE_URL: &str = "https://api.x.ai";
8
9#[derive(Clone, Debug)]
10pub struct Client {
11    base_url: String,
12    http_client: reqwest::Client,
13}
14
15impl Client {
16    pub fn new(api_key: &str) -> Self {
17        Self::from_url(api_key, XAI_BASE_URL)
18    }
19    fn from_url(api_key: &str, base_url: &str) -> Self {
20        Self {
21            base_url: base_url.to_string(),
22            http_client: reqwest::Client::builder()
23                .default_headers({
24                    let mut headers = reqwest::header::HeaderMap::new();
25                    headers.insert(
26                        reqwest::header::CONTENT_TYPE,
27                        "application/json".parse().unwrap(),
28                    );
29                    headers.insert(
30                        "Authorization",
31                        format!("Bearer {api_key}")
32                            .parse()
33                            .expect("Bearer token should parse"),
34                    );
35                    headers
36                })
37                .build()
38                .expect("xAI reqwest client should build"),
39        }
40    }
41
42    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
43        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
44
45        tracing::debug!("POST {}", url);
46        self.http_client.post(url)
47    }
48}
49
50impl ProviderClient for Client {
51    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
52    /// Panics if the environment variable is not set.
53    fn from_env() -> Self {
54        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
55        Self::new(&api_key)
56    }
57}
58
59impl CompletionClient for Client {
60    type CompletionModel = CompletionModel;
61
62    /// Create a completion model with the given name.
63    fn completion_model(&self, model: &str) -> CompletionModel {
64        CompletionModel::new(self.clone(), model)
65    }
66}
67
68impl EmbeddingsClient for Client {
69    type EmbeddingModel = EmbeddingModel;
70    /// Create an embedding model with the given name.
71    /// Note: default embedding dimension of 0 will be used if model is not known.
72    /// If this is the case, it's better to use function `embedding_model_with_ndims`
73    ///
74    /// # Example
75    /// ```
76    /// use rig::providers::xai::{Client, self};
77    ///
78    /// // Initialize the xAI client
79    /// let xai = Client::new("your-xai-api-key");
80    ///
81    /// let embedding_model = xai.embedding_model(xai::embedding::EMBEDDING_V1);
82    /// ```
83    fn embedding_model(&self, model: &str) -> EmbeddingModel {
84        let ndims = match model {
85            EMBEDDING_V1 => 3072,
86            _ => 0,
87        };
88        EmbeddingModel::new(self.clone(), model, ndims)
89    }
90
91    /// Create an embedding model with the given name and the number of dimensions in the embedding
92    ///  generated by the model.
93    ///
94    /// # Example
95    /// ```
96    /// use rig::providers::xai::{Client, self};
97    ///
98    /// // Initialize the xAI client
99    /// let xai = Client::new("your-xai-api-key");
100    ///
101    /// let embedding_model = xai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
102    /// ```
103    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
104        EmbeddingModel::new(self.clone(), model, ndims)
105    }
106}
107
108impl_conversion_traits!(
109    AsTranscription,
110    AsImageGeneration,
111    AsAudioGeneration for Client
112);
113
114pub mod xai_api_types {
115    use serde::Deserialize;
116
117    impl ApiErrorResponse {
118        pub fn message(&self) -> String {
119            format!("Code `{}`: {}", self.code, self.error)
120        }
121    }
122
123    #[derive(Debug, Deserialize)]
124    pub struct ApiErrorResponse {
125        pub error: String,
126        pub code: String,
127    }
128
129    #[derive(Debug, Deserialize)]
130    #[serde(untagged)]
131    pub enum ApiResponse<T> {
132        Ok(T),
133        Error(ApiErrorResponse),
134    }
135}