rig/providers/xai/
client.rs1use super::completion::CompletionModel;
2use crate::client::{CompletionClient, ProviderClient, impl_conversion_traits};
3
/// Root URL of the public xAI API; `Client::new` points every client at it.
const XAI_BASE_URL: &str = "https://api.x.ai";

9#[derive(Clone)]
10pub struct Client {
11 base_url: String,
12 api_key: String,
13 default_headers: reqwest::header::HeaderMap,
14 http_client: reqwest::Client,
15}
16
17impl std::fmt::Debug for Client {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("Client")
20 .field("base_url", &self.base_url)
21 .field("http_client", &self.http_client)
22 .field("default_headers", &self.default_headers)
23 .field("api_key", &"<REDACTED>")
24 .finish()
25 }
26}
27
28impl Client {
29 pub fn new(api_key: &str) -> Self {
30 Self::from_url(api_key, XAI_BASE_URL)
31 }
32
33 fn from_url(api_key: &str, base_url: &str) -> Self {
34 let mut default_headers = reqwest::header::HeaderMap::new();
35 default_headers.insert(
36 reqwest::header::CONTENT_TYPE,
37 "application/json".parse().unwrap(),
38 );
39
40 Self {
41 base_url: base_url.to_string(),
42 api_key: api_key.to_string(),
43 default_headers,
44 http_client: reqwest::Client::builder()
45 .build()
46 .expect("xAI reqwest client should build"),
47 }
48 }
49
50 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
53 self.http_client = client;
54
55 self
56 }
57
58 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
59 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
60
61 tracing::debug!("POST {}", url);
62 self.http_client
63 .post(url)
64 .bearer_auth(&self.api_key)
65 .headers(self.default_headers.clone())
66 }
67}
68
69impl ProviderClient for Client {
70 fn from_env() -> Self {
73 let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
74 Self::new(&api_key)
75 }
76
77 fn from_val(input: crate::client::ProviderValue) -> Self {
78 let crate::client::ProviderValue::Simple(api_key) = input else {
79 panic!("Incorrect provider value type")
80 };
81 Self::new(&api_key)
82 }
83}
84
85impl CompletionClient for Client {
86 type CompletionModel = CompletionModel;
87
88 fn completion_model(&self, model: &str) -> CompletionModel {
90 CompletionModel::new(self.clone(), model)
91 }
92}
93
94impl_conversion_traits!(
95 AsEmbeddings,
96 AsTranscription,
97 AsImageGeneration,
98 AsAudioGeneration for Client
99);
100
101pub mod xai_api_types {
102 use serde::Deserialize;
103
104 impl ApiErrorResponse {
105 pub fn message(&self) -> String {
106 format!("Code `{}`: {}", self.code, self.error)
107 }
108 }
109
110 #[derive(Debug, Deserialize)]
111 pub struct ApiErrorResponse {
112 pub error: String,
113 pub code: String,
114 }
115
116 #[derive(Debug, Deserialize)]
117 #[serde(untagged)]
118 pub enum ApiResponse<T> {
119 Ok(T),
120 Error(ApiErrorResponse),
121 }
122}