rig/providers/xai/
client.rs

1use super::completion::CompletionModel;
2use crate::client::{CompletionClient, ProviderClient, impl_conversion_traits};
3
4// ================================================================
5// xAI Client
6// ================================================================
/// Base URL that [`Client::new`] passes to `Client::from_url` when constructing a client.
const XAI_BASE_URL: &str = "https://api.x.ai";
8
/// HTTP client for the xAI API.
///
/// Holds everything `post` needs to build an authenticated request. The
/// hand-written `Debug` impl below prints a placeholder instead of `api_key`.
#[derive(Clone)]
pub struct Client {
    /// Root URL that request paths are appended to in `post`.
    base_url: String,
    /// Secret sent as the bearer token on every request built by `post`.
    api_key: String,
    /// Headers attached to every request; `from_url` seeds it with `Content-Type: application/json`.
    default_headers: reqwest::header::HeaderMap,
    /// Underlying `reqwest` client; replaceable through `with_custom_client`.
    http_client: reqwest::Client,
}
16
17impl std::fmt::Debug for Client {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("Client")
20            .field("base_url", &self.base_url)
21            .field("http_client", &self.http_client)
22            .field("default_headers", &self.default_headers)
23            .field("api_key", &"<REDACTED>")
24            .finish()
25    }
26}
27
28impl Client {
29    pub fn new(api_key: &str) -> Self {
30        Self::from_url(api_key, XAI_BASE_URL)
31    }
32
33    fn from_url(api_key: &str, base_url: &str) -> Self {
34        let mut default_headers = reqwest::header::HeaderMap::new();
35        default_headers.insert(
36            reqwest::header::CONTENT_TYPE,
37            "application/json".parse().unwrap(),
38        );
39
40        Self {
41            base_url: base_url.to_string(),
42            api_key: api_key.to_string(),
43            default_headers,
44            http_client: reqwest::Client::builder()
45                .build()
46                .expect("xAI reqwest client should build"),
47        }
48    }
49
50    /// Use your own `reqwest::Client`.
51    /// The required headers will be automatically attached upon trying to make a request.
52    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
53        self.http_client = client;
54
55        self
56    }
57
58    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
59        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
60
61        tracing::debug!("POST {}", url);
62        self.http_client
63            .post(url)
64            .bearer_auth(&self.api_key)
65            .headers(self.default_headers.clone())
66    }
67}
68
69impl ProviderClient for Client {
70    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
71    /// Panics if the environment variable is not set.
72    fn from_env() -> Self {
73        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
74        Self::new(&api_key)
75    }
76
77    fn from_val(input: crate::client::ProviderValue) -> Self {
78        let crate::client::ProviderValue::Simple(api_key) = input else {
79            panic!("Incorrect provider value type")
80        };
81        Self::new(&api_key)
82    }
83}
84
85impl CompletionClient for Client {
86    type CompletionModel = CompletionModel;
87
88    /// Create a completion model with the given name.
89    fn completion_model(&self, model: &str) -> CompletionModel {
90        CompletionModel::new(self.clone(), model)
91    }
92}
93
// NOTE(review): presumably generates the `As*` conversion-trait impls for the
// capabilities this client does not implement itself (embeddings, transcription,
// image and audio generation) — confirm against the macro in `crate::client`.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
100
101pub mod xai_api_types {
102    use serde::Deserialize;
103
104    impl ApiErrorResponse {
105        pub fn message(&self) -> String {
106            format!("Code `{}`: {}", self.code, self.error)
107        }
108    }
109
110    #[derive(Debug, Deserialize)]
111    pub struct ApiErrorResponse {
112        pub error: String,
113        pub code: String,
114    }
115
116    #[derive(Debug, Deserialize)]
117    #[serde(untagged)]
118    pub enum ApiResponse<T> {
119        Ok(T),
120        Error(ApiErrorResponse),
121    }
122}