bep/providers/xai/client.rs

1use crate::{
2    agent::AgentBuilder,
3    embeddings::{self},
4    extractor::ExtractorBuilder,
5    Embed,
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9
10use super::{completion::CompletionModel, embedding::EmbeddingModel, EMBEDDING_V1};
11
// ================================================================
// xAI Client
// ================================================================

/// Root URL of the public xAI HTTP API; `Client::new` targets this endpoint.
const XAI_BASE_URL: &str = "https://api.x.ai";
16
17#[derive(Clone)]
18pub struct Client {
19    base_url: String,
20    http_client: reqwest::Client,
21}
22
23impl Client {
24    pub fn new(api_key: &str) -> Self {
25        Self::from_url(api_key, XAI_BASE_URL)
26    }
27    fn from_url(api_key: &str, base_url: &str) -> Self {
28        Self {
29            base_url: base_url.to_string(),
30            http_client: reqwest::Client::builder()
31                .default_headers({
32                    let mut headers = reqwest::header::HeaderMap::new();
33                    headers.insert(
34                        reqwest::header::CONTENT_TYPE,
35                        "application/json".parse().unwrap(),
36                    );
37                    headers.insert(
38                        "Authorization",
39                        format!("Bearer {}", api_key)
40                            .parse()
41                            .expect("Bearer token should parse"),
42                    );
43                    headers
44                })
45                .build()
46                .expect("xAI reqwest client should build"),
47        }
48    }
49
50    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
51    /// Panics if the environment variable is not set.
52    pub fn from_env() -> Self {
53        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
54        Self::new(&api_key)
55    }
56
57    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
58        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
59
60        tracing::debug!("POST {}", url);
61        self.http_client.post(url)
62    }
63
64    /// Create an embedding model with the given name.
65    /// Note: default embedding dimension of 0 will be used if model is not known.
66    /// If this is the case, it's better to use function `embedding_model_with_ndims`
67    ///
68    /// # Example
69    /// ```
70    /// use bep::providers::xai::{Client, self};
71    ///
72    /// // Initialize the xAI client
73    /// let xai = Client::new("your-xai-api-key");
74    ///
75    /// let embedding_model = xai.embedding_model(xai::embedding::EMBEDDING_V1);
76    /// ```
77    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
78        let ndims = match model {
79            EMBEDDING_V1 => 3072,
80            _ => 0,
81        };
82        EmbeddingModel::new(self.clone(), model, ndims)
83    }
84
85    /// Create an embedding model with the given name and the number of dimensions in the embedding
86    ///  generated by the model.
87    ///
88    /// # Example
89    /// ```
90    /// use bep::providers::xai::{Client, self};
91    ///
92    /// // Initialize the xAI client
93    /// let xai = Client::new("your-xai-api-key");
94    ///
95    /// let embedding_model = xai.embedding_model_with_ndims("model-unknown-to-bep", 1024);
96    /// ```
97    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
98        EmbeddingModel::new(self.clone(), model, ndims)
99    }
100
101    /// Create an embedding builder with the given embedding model.
102    ///
103    /// # Example
104    /// ```
105    /// use bep::providers::xai::{Client, self};
106    ///
107    /// // Initialize the xAI client
108    /// let xai = Client::new("your-xai-api-key");
109    ///
110    /// let embeddings = xai.embeddings(xai::embedding::EMBEDDING_V1)
111    ///     .simple_document("doc0", "Hello, world!")
112    ///     .simple_document("doc1", "Goodbye, world!")
113    ///     .build()
114    ///     .await
115    ///     .expect("Failed to embed documents");
116    /// ```
117    pub fn embeddings<D: Embed>(
118        &self,
119        model: &str,
120    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
121        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
122    }
123
124    /// Create a completion model with the given name.
125    pub fn completion_model(&self, model: &str) -> CompletionModel {
126        CompletionModel::new(self.clone(), model)
127    }
128
129    /// Create an agent builder with the given completion model.
130    /// # Example
131    /// ```
132    /// use bep::providers::xai::{Client, self};
133    ///
134    /// // Initialize the xAI client
135    /// let xai = Client::new("your-xai-api-key");
136    ///
137    /// let agent = xai.agent(xai::completion::GROK_BETA)
138    ///    .preamble("You are comedian AI with a mission to make people laugh.")
139    ///    .temperature(0.0)
140    ///    .build();
141    /// ```
142    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
143        AgentBuilder::new(self.completion_model(model))
144    }
145
146    /// Create an extractor builder with the given completion model.
147    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
148        &self,
149        model: &str,
150    ) -> ExtractorBuilder<T, CompletionModel> {
151        ExtractorBuilder::new(self.completion_model(model))
152    }
153}
154
155pub mod xai_api_types {
156    use serde::Deserialize;
157
158    impl ApiErrorResponse {
159        pub fn message(&self) -> String {
160            format!("Code `{}`: {}", self.code, self.error)
161        }
162    }
163
164    #[derive(Debug, Deserialize)]
165    pub struct ApiErrorResponse {
166        pub error: String,
167        pub code: String,
168    }
169
170    #[derive(Debug, Deserialize)]
171    #[serde(untagged)]
172    pub enum ApiResponse<T> {
173        Ok(T),
174        Error(ApiErrorResponse),
175    }
176}