rig/providers/xai/
client.rs1use super::{completion::CompletionModel, embedding::EmbeddingModel, EMBEDDING_V1};
2use crate::client::{impl_conversion_traits, CompletionClient, EmbeddingsClient, ProviderClient};
3
/// Default root URL of the xAI REST API; request paths are appended to it.
const XAI_BASE_URL: &str = "https://api.x.ai";
8
9#[derive(Clone, Debug)]
10pub struct Client {
11 base_url: String,
12 http_client: reqwest::Client,
13}
14
15impl Client {
16 pub fn new(api_key: &str) -> Self {
17 Self::from_url(api_key, XAI_BASE_URL)
18 }
19 fn from_url(api_key: &str, base_url: &str) -> Self {
20 Self {
21 base_url: base_url.to_string(),
22 http_client: reqwest::Client::builder()
23 .default_headers({
24 let mut headers = reqwest::header::HeaderMap::new();
25 headers.insert(
26 reqwest::header::CONTENT_TYPE,
27 "application/json".parse().unwrap(),
28 );
29 headers.insert(
30 "Authorization",
31 format!("Bearer {api_key}")
32 .parse()
33 .expect("Bearer token should parse"),
34 );
35 headers
36 })
37 .build()
38 .expect("xAI reqwest client should build"),
39 }
40 }
41
42 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
43 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
44
45 tracing::debug!("POST {}", url);
46 self.http_client.post(url)
47 }
48}
49
50impl ProviderClient for Client {
51 fn from_env() -> Self {
54 let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
55 Self::new(&api_key)
56 }
57}
58
59impl CompletionClient for Client {
60 type CompletionModel = CompletionModel;
61
62 fn completion_model(&self, model: &str) -> CompletionModel {
64 CompletionModel::new(self.clone(), model)
65 }
66}
67
68impl EmbeddingsClient for Client {
69 type EmbeddingModel = EmbeddingModel;
70 fn embedding_model(&self, model: &str) -> EmbeddingModel {
84 let ndims = match model {
85 EMBEDDING_V1 => 3072,
86 _ => 0,
87 };
88 EmbeddingModel::new(self.clone(), model, ndims)
89 }
90
91 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
104 EmbeddingModel::new(self.clone(), model, ndims)
105 }
106}
107
108impl_conversion_traits!(
109 AsTranscription,
110 AsImageGeneration,
111 AsAudioGeneration for Client
112);
113
114pub mod xai_api_types {
115 use serde::Deserialize;
116
117 impl ApiErrorResponse {
118 pub fn message(&self) -> String {
119 format!("Code `{}`: {}", self.code, self.error)
120 }
121 }
122
123 #[derive(Debug, Deserialize)]
124 pub struct ApiErrorResponse {
125 pub error: String,
126 pub code: String,
127 }
128
129 #[derive(Debug, Deserialize)]
130 #[serde(untagged)]
131 pub enum ApiResponse<T> {
132 Ok(T),
133 Error(ApiErrorResponse),
134 }
135}