rig/providers/xai/
client.rs1use http::Method;
2
3use super::completion::CompletionModel;
4use crate::{
5 client::{CompletionClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits},
6 http_client::{self, HttpClientExt, NoBody, Result as HttpResult, with_bearer_auth},
7};
8
9const XAI_BASE_URL: &str = "https://api.x.ai";
13
14pub struct ClientBuilder<'a, T = reqwest::Client> {
15 api_key: &'a str,
16 base_url: &'a str,
17 http_client: T,
18}
19
20impl<'a, T> ClientBuilder<'a, T>
21where
22 T: Default,
23{
24 pub fn new(api_key: &'a str) -> Self {
25 Self {
26 api_key,
27 base_url: XAI_BASE_URL,
28 http_client: Default::default(),
29 }
30 }
31}
32
33impl<'a, T> ClientBuilder<'a, T> {
34 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
35 Self {
36 api_key,
37 base_url: XAI_BASE_URL,
38 http_client,
39 }
40 }
41
42 pub fn base_url(mut self, base_url: &'a str) -> Self {
43 self.base_url = base_url;
44 self
45 }
46
47 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
48 ClientBuilder {
49 api_key: self.api_key,
50 base_url: self.base_url,
51 http_client,
52 }
53 }
54
55 pub fn build(self) -> Client<T> {
56 let mut default_headers = reqwest::header::HeaderMap::new();
57 default_headers.insert(
58 reqwest::header::CONTENT_TYPE,
59 "application/json".parse().unwrap(),
60 );
61
62 Client {
63 base_url: self.base_url.to_string(),
64 api_key: self.api_key.to_string(),
65 default_headers,
66 http_client: self.http_client,
67 }
68 }
69}
70
71#[derive(Clone)]
72pub struct Client<T = reqwest::Client> {
73 base_url: String,
74 api_key: String,
75 default_headers: http_client::HeaderMap,
76 pub http_client: T,
77}
78
79impl<T> std::fmt::Debug for Client<T>
80where
81 T: std::fmt::Debug,
82{
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.debug_struct("Client")
85 .field("base_url", &self.base_url)
86 .field("http_client", &self.http_client)
87 .field("default_headers", &self.default_headers)
88 .field("api_key", &"<REDACTED>")
89 .finish()
90 }
91}
92
93impl Client<reqwest::Client> {
94 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
105 ClientBuilder::new(api_key)
106 }
107
108 pub fn new(api_key: &str) -> Self {
113 Self::builder(api_key).build()
114 }
115
116 pub fn from_env() -> Self {
117 <Self as ProviderClient>::from_env()
118 }
119}
120
121impl<T> Client<T>
122where
123 T: HttpClientExt,
124{
125 pub(crate) fn req(&self, method: Method, path: &str) -> HttpResult<http_client::Builder> {
126 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
127
128 let mut builder = http_client::Builder::new().uri(url).method(method);
129 for (header, value) in &self.default_headers {
130 builder = builder.header(header, value);
131 }
132
133 with_bearer_auth(builder, &self.api_key)
134 }
135
136 pub(crate) fn post(&self, path: &str) -> HttpResult<http_client::Builder> {
137 self.req(Method::POST, path)
138 }
139
140 pub(crate) fn get(&self, path: &str) -> HttpResult<http_client::Builder> {
141 self.req(Method::GET, path)
142 }
143}
144
145impl<T> ProviderClient for Client<T>
146where
147 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
148{
149 fn from_env() -> Self {
152 let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
153 ClientBuilder::<T>::new(&api_key).build()
154 }
155
156 fn from_val(input: crate::client::ProviderValue) -> Self {
157 let crate::client::ProviderValue::Simple(api_key) = input else {
158 panic!("Incorrect provider value type")
159 };
160 ClientBuilder::<T>::new(&api_key).build()
161 }
162}
163
164impl<T> CompletionClient for Client<T>
165where
166 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
167{
168 type CompletionModel = CompletionModel<T>;
169
170 fn completion_model(&self, model: &str) -> CompletionModel<T> {
172 CompletionModel::new(self.clone(), model)
173 }
174}
175
176impl<T> VerifyClient for Client<T>
177where
178 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
179{
180 #[cfg_attr(feature = "worker", worker::send)]
181 async fn verify(&self) -> Result<(), VerifyError> {
182 let req = self.get("/v1/api-key").unwrap().body(NoBody).unwrap();
183
184 let response = self.http_client.send::<_, Vec<u8>>(req).await.unwrap();
185 let status = response.status();
186
187 match status {
188 reqwest::StatusCode::OK => Ok(()),
189 reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
190 Err(VerifyError::InvalidAuthentication)
191 }
192 reqwest::StatusCode::INTERNAL_SERVER_ERROR => Err(VerifyError::ProviderError(
193 http_client::text(response).await?,
194 )),
195 _ => Err(VerifyError::HttpError(http_client::Error::Instance(
196 http_client::text(response).await?.into(),
197 ))),
198 }
199 }
200}
201
202impl_conversion_traits!(
203 AsEmbeddings,
204 AsTranscription,
205 AsImageGeneration,
206 AsAudioGeneration for Client<T>
207);
208
209pub mod xai_api_types {
210 use serde::Deserialize;
211
212 impl ApiErrorResponse {
213 pub fn message(&self) -> String {
214 format!("Code `{}`: {}", self.code, self.error)
215 }
216 }
217
218 #[derive(Debug, Deserialize)]
219 pub struct ApiErrorResponse {
220 pub error: String,
221 pub code: String,
222 }
223
224 #[derive(Debug, Deserialize)]
225 #[serde(untagged)]
226 pub enum ApiResponse<T> {
227 Ok(T),
228 Error(ApiErrorResponse),
229 }
230}