rig/providers/xai/
client.rs1use super::completion::CompletionModel;
2use crate::{
3 client::{CompletionClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits},
4 http_client,
5};
6
/// Default base URL for the xAI API.
///
/// Used by [`ClientBuilder::new`] unless overridden with [`ClientBuilder::base_url`].
const XAI_BASE_URL: &str = "https://api.x.ai";
11
/// Builder for an xAI [`Client`].
///
/// The API key and base URL are borrowed for `'a` and copied into owned
/// `String`s by [`ClientBuilder::build`]. `T` is the HTTP client type and
/// defaults to [`reqwest::Client`].
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// API key copied into the built client (sent there as a bearer token).
    api_key: &'a str,
    /// Base URL request paths are joined onto; `XAI_BASE_URL` unless overridden.
    base_url: &'a str,
    /// HTTP client value moved into the built client.
    http_client: T,
}
17
18impl<'a, T> ClientBuilder<'a, T>
19where
20 T: Default,
21{
22 pub fn new(api_key: &'a str) -> Self {
23 Self {
24 api_key,
25 base_url: XAI_BASE_URL,
26 http_client: Default::default(),
27 }
28 }
29}
30
31impl<'a, T> ClientBuilder<'a, T> {
32 pub fn base_url(mut self, base_url: &'a str) -> Self {
33 self.base_url = base_url;
34 self
35 }
36
37 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
38 ClientBuilder {
39 api_key: self.api_key,
40 base_url: self.base_url,
41 http_client,
42 }
43 }
44
45 pub fn build(self) -> Client<T> {
46 let mut default_headers = reqwest::header::HeaderMap::new();
47 default_headers.insert(
48 reqwest::header::CONTENT_TYPE,
49 "application/json".parse().unwrap(),
50 );
51
52 Client {
53 base_url: self.base_url.to_string(),
54 api_key: self.api_key.to_string(),
55 default_headers,
56 http_client: self.http_client,
57 }
58 }
59}
60
/// Client for the xAI API.
///
/// `T` is the underlying HTTP client and defaults to [`reqwest::Client`].
/// Create one with [`Client::new`] or [`Client::builder`]. For the reqwest
/// client, [`ProviderClient::from_env`] reads the key from `XAI_API_KEY`.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that request paths are joined onto.
    base_url: String,
    /// API key, sent as a bearer token; redacted in the `Debug` output.
    api_key: String,
    /// Headers attached to every request built by this client.
    default_headers: http_client::HeaderMap,
    /// Underlying HTTP client.
    http_client: T,
}
68
69impl<T> std::fmt::Debug for Client<T>
70where
71 T: std::fmt::Debug,
72{
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("Client")
75 .field("base_url", &self.base_url)
76 .field("http_client", &self.http_client)
77 .field("default_headers", &self.default_headers)
78 .field("api_key", &"<REDACTED>")
79 .finish()
80 }
81}
82
83impl<T> Client<T>
84where
85 T: Default,
86{
87 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
98 ClientBuilder::new(api_key)
99 }
100
101 pub fn new(api_key: &str) -> Self {
106 Self::builder(api_key).build()
107 }
108}
109
110impl Client<reqwest::Client> {
111 pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
112 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
113
114 tracing::debug!("POST {}", url);
115
116 self.http_client
117 .post(url)
118 .bearer_auth(&self.api_key)
119 .headers(self.default_headers.clone())
120 }
121
122 pub(crate) fn reqwest_get(&self, path: &str) -> reqwest::RequestBuilder {
123 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
124
125 tracing::debug!("GET {}", url);
126
127 self.http_client
128 .get(url)
129 .bearer_auth(&self.api_key)
130 .headers(self.default_headers.clone())
131 }
132}
133
134impl ProviderClient for Client<reqwest::Client> {
135 fn from_env() -> Self {
138 let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
139 Self::new(&api_key)
140 }
141
142 fn from_val(input: crate::client::ProviderValue) -> Self {
143 let crate::client::ProviderValue::Simple(api_key) = input else {
144 panic!("Incorrect provider value type")
145 };
146 Self::new(&api_key)
147 }
148}
149
150impl CompletionClient for Client<reqwest::Client> {
151 type CompletionModel = CompletionModel<reqwest::Client>;
152
153 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
155 CompletionModel::new(self.clone(), model)
156 }
157}
158
159impl VerifyClient for Client<reqwest::Client> {
160 #[cfg_attr(feature = "worker", worker::send)]
161 async fn verify(&self) -> Result<(), VerifyError> {
162 let response = self
163 .reqwest_get("/v1/api-key")
164 .send()
165 .await
166 .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
167
168 match response.status() {
169 reqwest::StatusCode::OK => Ok(()),
170 reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
171 Err(VerifyError::InvalidAuthentication)
172 }
173 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
174 Err(VerifyError::ProviderError(response.text().await.map_err(
175 |e| VerifyError::HttpError(http_client::Error::Instance(e.into())),
176 )?))
177 }
178 _ => {
179 response
180 .error_for_status()
181 .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
182 Ok(())
183 }
184 }
185 }
186}
187
// Generates the listed conversion-trait impls for `Client<T>` via the
// crate-level `impl_conversion_traits!` macro.
// NOTE(review): this file only implements `CompletionClient` for the xAI client,
// so these are presumably the macro's default "capability not supported"
// conversions. Confirm against the macro definition in `crate::client`.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
194
195pub mod xai_api_types {
196 use serde::Deserialize;
197
198 impl ApiErrorResponse {
199 pub fn message(&self) -> String {
200 format!("Code `{}`: {}", self.code, self.error)
201 }
202 }
203
204 #[derive(Debug, Deserialize)]
205 pub struct ApiErrorResponse {
206 pub error: String,
207 pub code: String,
208 }
209
210 #[derive(Debug, Deserialize)]
211 #[serde(untagged)]
212 pub enum ApiResponse<T> {
213 Ok(T),
214 Error(ApiErrorResponse),
215 }
216}