bep/providers/xai/client.rs
1use crate::{
2 agent::AgentBuilder,
3 embeddings::{self},
4 extractor::ExtractorBuilder,
5 Embed,
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9
10use super::{completion::CompletionModel, embedding::EmbeddingModel, EMBEDDING_V1};
11
12// ================================================================
13// xAI Client
14// ================================================================
15const XAI_BASE_URL: &str = "https://api.x.ai";
16
17#[derive(Clone)]
18pub struct Client {
19 base_url: String,
20 http_client: reqwest::Client,
21}
22
23impl Client {
24 pub fn new(api_key: &str) -> Self {
25 Self::from_url(api_key, XAI_BASE_URL)
26 }
27 fn from_url(api_key: &str, base_url: &str) -> Self {
28 Self {
29 base_url: base_url.to_string(),
30 http_client: reqwest::Client::builder()
31 .default_headers({
32 let mut headers = reqwest::header::HeaderMap::new();
33 headers.insert(
34 reqwest::header::CONTENT_TYPE,
35 "application/json".parse().unwrap(),
36 );
37 headers.insert(
38 "Authorization",
39 format!("Bearer {}", api_key)
40 .parse()
41 .expect("Bearer token should parse"),
42 );
43 headers
44 })
45 .build()
46 .expect("xAI reqwest client should build"),
47 }
48 }
49
50 /// Create a new xAI client from the `XAI_API_KEY` environment variable.
51 /// Panics if the environment variable is not set.
52 pub fn from_env() -> Self {
53 let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
54 Self::new(&api_key)
55 }
56
57 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
58 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
59
60 tracing::debug!("POST {}", url);
61 self.http_client.post(url)
62 }
63
64 /// Create an embedding model with the given name.
65 /// Note: default embedding dimension of 0 will be used if model is not known.
66 /// If this is the case, it's better to use function `embedding_model_with_ndims`
67 ///
68 /// # Example
69 /// ```
70 /// use bep::providers::xai::{Client, self};
71 ///
72 /// // Initialize the xAI client
73 /// let xai = Client::new("your-xai-api-key");
74 ///
75 /// let embedding_model = xai.embedding_model(xai::embedding::EMBEDDING_V1);
76 /// ```
77 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
78 let ndims = match model {
79 EMBEDDING_V1 => 3072,
80 _ => 0,
81 };
82 EmbeddingModel::new(self.clone(), model, ndims)
83 }
84
85 /// Create an embedding model with the given name and the number of dimensions in the embedding
86 /// generated by the model.
87 ///
88 /// # Example
89 /// ```
90 /// use bep::providers::xai::{Client, self};
91 ///
92 /// // Initialize the xAI client
93 /// let xai = Client::new("your-xai-api-key");
94 ///
95 /// let embedding_model = xai.embedding_model_with_ndims("model-unknown-to-bep", 1024);
96 /// ```
97 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
98 EmbeddingModel::new(self.clone(), model, ndims)
99 }
100
101 /// Create an embedding builder with the given embedding model.
102 ///
103 /// # Example
104 /// ```
105 /// use bep::providers::xai::{Client, self};
106 ///
107 /// // Initialize the xAI client
108 /// let xai = Client::new("your-xai-api-key");
109 ///
110 /// let embeddings = xai.embeddings(xai::embedding::EMBEDDING_V1)
111 /// .simple_document("doc0", "Hello, world!")
112 /// .simple_document("doc1", "Goodbye, world!")
113 /// .build()
114 /// .await
115 /// .expect("Failed to embed documents");
116 /// ```
117 pub fn embeddings<D: Embed>(
118 &self,
119 model: &str,
120 ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
121 embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
122 }
123
124 /// Create a completion model with the given name.
125 pub fn completion_model(&self, model: &str) -> CompletionModel {
126 CompletionModel::new(self.clone(), model)
127 }
128
129 /// Create an agent builder with the given completion model.
130 /// # Example
131 /// ```
132 /// use bep::providers::xai::{Client, self};
133 ///
134 /// // Initialize the xAI client
135 /// let xai = Client::new("your-xai-api-key");
136 ///
137 /// let agent = xai.agent(xai::completion::GROK_BETA)
138 /// .preamble("You are comedian AI with a mission to make people laugh.")
139 /// .temperature(0.0)
140 /// .build();
141 /// ```
142 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
143 AgentBuilder::new(self.completion_model(model))
144 }
145
146 /// Create an extractor builder with the given completion model.
147 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
148 &self,
149 model: &str,
150 ) -> ExtractorBuilder<T, CompletionModel> {
151 ExtractorBuilder::new(self.completion_model(model))
152 }
153}
154
155pub mod xai_api_types {
156 use serde::Deserialize;
157
158 impl ApiErrorResponse {
159 pub fn message(&self) -> String {
160 format!("Code `{}`: {}", self.code, self.error)
161 }
162 }
163
164 #[derive(Debug, Deserialize)]
165 pub struct ApiErrorResponse {
166 pub error: String,
167 pub code: String,
168 }
169
170 #[derive(Debug, Deserialize)]
171 #[serde(untagged)]
172 pub enum ApiResponse<T> {
173 Ok(T),
174 Error(ApiErrorResponse),
175 }
176}