rig/providers/xai/client.rs

1use super::completion::CompletionModel;
2use crate::client::{CompletionClient, ProviderClient, impl_conversion_traits};
3
// ================================================================
// xAI Client
// ================================================================

/// Root URL of the public xAI REST API.
///
/// `Client::new` uses this as the base URL; request paths passed to `Client::post`
/// are joined onto it.
const XAI_BASE_URL: &str = "https://api.x.ai";
8
9#[derive(Clone)]
10pub struct Client {
11    base_url: String,
12    api_key: String,
13    default_headers: reqwest::header::HeaderMap,
14    http_client: reqwest::Client,
15}
16
17impl std::fmt::Debug for Client {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("Client")
20            .field("base_url", &self.base_url)
21            .field("http_client", &self.http_client)
22            .field("default_headers", &self.default_headers)
23            .field("api_key", &"<REDACTED>")
24            .finish()
25    }
26}
27
28impl Client {
29    pub fn new(api_key: &str) -> Self {
30        Self::from_url(api_key, XAI_BASE_URL)
31    }
32
33    fn from_url(api_key: &str, base_url: &str) -> Self {
34        let mut default_headers = reqwest::header::HeaderMap::new();
35        default_headers.insert(
36            reqwest::header::CONTENT_TYPE,
37            "application/json".parse().unwrap(),
38        );
39
40        Self {
41            base_url: base_url.to_string(),
42            api_key: api_key.to_string(),
43            default_headers,
44            http_client: reqwest::Client::builder()
45                .build()
46                .expect("xAI reqwest client should build"),
47        }
48    }
49
50    /// Use your own `reqwest::Client`.
51    /// The required headers will be automatically attached upon trying to make a request.
52    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
53        self.http_client = client;
54
55        self
56    }
57
58    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
59        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
60
61        tracing::debug!("POST {}", url);
62        self.http_client
63            .post(url)
64            .bearer_auth(&self.api_key)
65            .headers(self.default_headers.clone())
66    }
67}
68
69impl ProviderClient for Client {
70    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
71    /// Panics if the environment variable is not set.
72    fn from_env() -> Self {
73        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
74        Self::new(&api_key)
75    }
76}
77
78impl CompletionClient for Client {
79    type CompletionModel = CompletionModel;
80
81    /// Create a completion model with the given name.
82    fn completion_model(&self, model: &str) -> CompletionModel {
83        CompletionModel::new(self.clone(), model)
84    }
85}
86
// In this file `Client` only implements `CompletionClient` directly; the remaining
// capability conversion traits listed below are handled by this macro.
// NOTE(review): presumably the macro emits default ("capability not supported") impls of
// these traits for `Client` — confirm against its definition in `crate::client`.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
93
94pub mod xai_api_types {
95    use serde::Deserialize;
96
97    impl ApiErrorResponse {
98        pub fn message(&self) -> String {
99            format!("Code `{}`: {}", self.code, self.error)
100        }
101    }
102
103    #[derive(Debug, Deserialize)]
104    pub struct ApiErrorResponse {
105        pub error: String,
106        pub code: String,
107    }
108
109    #[derive(Debug, Deserialize)]
110    #[serde(untagged)]
111    pub enum ApiResponse<T> {
112        Ok(T),
113        Error(ApiErrorResponse),
114    }
115}