use async_trait::async_trait;
use std::time::Duration;
use crate::client::LlmClient;
use crate::error::LlmError;
use crate::retry_api::RetryOptions;
use crate::stream::ChatStream;
use crate::traits::{ChatCapability, ModelListingCapability, ProviderCapabilities};
use crate::types::*;
use super::api::XaiModels;
use super::chat::XaiChatCapability;
use super::config::XaiConfig;
#[derive(Debug)]
pub struct XaiClient {
pub chat_capability: XaiChatCapability,
pub models_capability: XaiModels,
pub common_params: CommonParams,
pub http_client: reqwest::Client,
tracing_config: Option<crate::tracing::TracingConfig>,
_tracing_guard: Option<tracing_appender::non_blocking::WorkerGuard>,
retry_options: Option<RetryOptions>,
}
impl Clone for XaiClient {
fn clone(&self) -> Self {
Self {
chat_capability: self.chat_capability.clone(),
models_capability: self.models_capability.clone(),
common_params: self.common_params.clone(),
http_client: self.http_client.clone(),
tracing_config: self.tracing_config.clone(),
_tracing_guard: None, retry_options: self.retry_options.clone(),
}
}
}
impl XaiClient {
pub async fn new(config: XaiConfig) -> Result<Self, LlmError> {
config
.validate()
.map_err(|e| LlmError::InvalidInput(format!("Invalid xAI configuration: {e}")))?;
let http_client = reqwest::Client::builder()
.timeout(
config
.http_config
.timeout
.unwrap_or(Duration::from_secs(30)),
)
.build()
.map_err(|e| LlmError::HttpError(format!("Failed to create HTTP client: {e}")))?;
Self::with_http_client(config, http_client).await
}
pub async fn with_http_client(
config: XaiConfig,
http_client: reqwest::Client,
) -> Result<Self, LlmError> {
config
.validate()
.map_err(|e| LlmError::InvalidInput(format!("Invalid xAI configuration: {e}")))?;
let chat_capability = XaiChatCapability::new(
config.api_key.clone(),
config.base_url.clone(),
http_client.clone(),
config.http_config.clone(),
config.common_params.clone(),
);
let models_capability = XaiModels::new(
config.api_key.clone(),
config.base_url.clone(),
http_client.clone(),
config.http_config.clone(),
);
Ok(Self {
chat_capability,
models_capability,
common_params: config.common_params,
http_client,
tracing_config: None,
_tracing_guard: None,
retry_options: None,
})
}
pub fn config(&self) -> XaiConfig {
XaiConfig {
api_key: self.chat_capability.api_key.clone(),
base_url: self.chat_capability.base_url.clone(),
common_params: self.common_params.clone(),
http_config: self.chat_capability.http_config.clone(),
web_search_config: WebSearchConfig::default(),
}
}
pub fn with_common_params(mut self, params: CommonParams) -> Self {
self.common_params = params;
self
}
pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
self.common_params.model = model.into();
self
}
pub const fn with_temperature(mut self, temperature: f32) -> Self {
self.common_params.temperature = Some(temperature);
self
}
pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.common_params.max_tokens = Some(max_tokens);
self
}
}
#[async_trait]
impl LlmClient for XaiClient {
    /// Provider identifier reported to callers.
    fn provider_name(&self) -> &'static str {
        "xai"
    }

    /// Every model id from the xAI model catalogue, as owned strings.
    fn supported_models(&self) -> Vec<String> {
        let catalogue = crate::providers::xai::models::all_models();
        catalogue.into_iter().map(|id| id.to_string()).collect()
    }

    /// Core features plus xAI-specific custom feature flags.
    ///
    /// NOTE(review): `deferred_completion` is advertised, but this client's
    /// deferred-completion methods currently return `UnsupportedOperation`.
    /// Confirm whether the flag describes the provider or this client.
    fn capabilities(&self) -> ProviderCapabilities {
        let core = ProviderCapabilities::new()
            .with_chat()
            .with_streaming()
            .with_tools()
            .with_vision();
        core.with_custom_feature("reasoning", true)
            .with_custom_feature("deferred_completion", true)
            .with_custom_feature("structured_outputs", true)
    }

    /// Exposes the concrete client for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Boxed clone for use behind `dyn LlmClient`.
    fn clone_box(&self) -> Box<dyn LlmClient> {
        let copy: Self = Clone::clone(self);
        Box::new(copy)
    }
}
impl XaiClient {
    /// Sends one non-streaming chat request built from `self.common_params`.
    ///
    /// This is the single-attempt path. Retry handling, if any, wraps this call.
    async fn chat_with_tools_inner(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatResponse, LlmError> {
        let common_params = self.common_params.clone();
        let request = ChatRequest {
            stream: false,
            web_search: None,
            http_config: None,
            provider_params: None,
            common_params,
            tools,
            messages,
        };
        self.chat_capability.chat(request).await
    }
}
#[async_trait]
impl ChatCapability for XaiClient {
    /// Sends a non-streaming chat request.
    ///
    /// When `retry_options` is set, the call goes through
    /// `crate::retry_api::retry_with`; otherwise it is a single attempt.
    async fn chat_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatResponse, LlmError> {
        match &self.retry_options {
            None => self.chat_with_tools_inner(messages, tools).await,
            Some(retry) => {
                // The inner call consumes its inputs, so every attempt gets fresh copies.
                let attempt = || {
                    let attempt_messages = messages.clone();
                    let attempt_tools = tools.clone();
                    async move {
                        self.chat_with_tools_inner(attempt_messages, attempt_tools)
                            .await
                    }
                };
                crate::retry_api::retry_with(attempt, retry.clone()).await
            }
        }
    }

    /// Opens a streaming chat by delegating to `chat_capability`.
    ///
    /// NOTE(review): no retry is applied here, and `self.common_params` is not
    /// passed along. The stream presumably uses the parameters
    /// `chat_capability` was constructed with. Confirm that this is intended.
    async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatStream, LlmError> {
        let capability = &self.chat_capability;
        capability.chat_stream(messages, tools).await
    }
}
impl XaiClient {
    /// Sends a non-streaming chat request with a `reasoning_effort` provider parameter.
    ///
    /// `reasoning_effort` is forwarded verbatim as a JSON string; no validation
    /// of the value is done here.
    ///
    /// # Errors
    ///
    /// Propagates any error from `chat_capability.chat`.
    pub async fn chat_with_reasoning(
        &self,
        messages: Vec<ChatMessage>,
        reasoning_effort: &str,
    ) -> Result<ChatResponse, LlmError> {
        let mut params = std::collections::HashMap::new();
        params.insert(
            "reasoning_effort".to_string(),
            serde_json::Value::String(reasoning_effort.to_string()),
        );
        let request = ChatRequest {
            messages,
            tools: None,
            common_params: self.common_params.clone(),
            provider_params: Some(ProviderParams { params }),
            http_config: None,
            web_search: None,
            stream: false,
        };
        self.chat_capability.chat(request).await
    }

    /// Submits `messages` as a deferred completion and returns its request id.
    ///
    /// # Errors
    ///
    /// Always returns [`LlmError::UnsupportedOperation`]: extracting the
    /// request id is not implemented yet. No request is sent to the API.
    pub async fn create_deferred_completion(
        &self,
        messages: Vec<ChatMessage>,
    ) -> Result<String, LlmError> {
        // Sending the request would start a billable completion whose request id
        // cannot be returned to the caller, so fail without contacting the API.
        drop(messages);
        Err(LlmError::UnsupportedOperation(
            "Deferred completion not implemented yet".to_string(),
        ))
    }

    /// Polls `GET {base_url}/chat/deferred-completion/{request_id}`.
    ///
    /// # Errors
    ///
    /// - [`LlmError::InvalidInput`] if `request_id` is empty or whitespace.
    /// - [`LlmError::ApiError`] with code 202 while the result is not ready.
    /// - [`LlmError::ApiError`] with the HTTP status for any other non-200
    ///   response; `details` holds the body if it parses as JSON.
    /// - On 200, the body is decoded as `XaiChatResponse`, but conversion to
    ///   [`ChatResponse`] is not implemented, so [`LlmError::UnsupportedOperation`]
    ///   is returned.
    pub async fn get_deferred_completion(
        &self,
        request_id: &str,
    ) -> Result<ChatResponse, LlmError> {
        if request_id.trim().is_empty() {
            return Err(LlmError::InvalidInput(
                "Deferred completion request id must not be empty".to_string(),
            ));
        }
        // Avoid a `//` in the path when the configured base URL ends with a slash.
        let base_url = self.chat_capability.base_url.trim_end_matches('/');
        let url = format!("{base_url}/chat/deferred-completion/{request_id}");
        let headers = super::utils::build_headers(
            &self.chat_capability.api_key,
            &self.chat_capability.http_config.headers,
        )?;
        let response = self.http_client.get(&url).headers(headers).send().await?;
        let status = response.status();
        match status.as_u16() {
            200 => {
                let _xai_response: super::types::XaiChatResponse = response.json().await?;
                Err(LlmError::UnsupportedOperation(
                    "Get deferred completion not implemented yet".to_string(),
                ))
            }
            202 => Err(LlmError::ApiError {
                code: 202,
                message: "Deferred completion not ready yet".to_string(),
                details: None,
            }),
            _ => {
                let error_text = response.text().await.unwrap_or_default();
                Err(LlmError::ApiError {
                    code: status.as_u16(),
                    message: format!("xAI API error: {error_text}"),
                    details: serde_json::from_str(&error_text).ok(),
                })
            }
        }
    }

    /// Stores (or clears) the tracing worker guard owned by this client.
    pub(crate) fn set_tracing_guard(
        &mut self,
        guard: Option<tracing_appender::non_blocking::WorkerGuard>,
    ) {
        self._tracing_guard = guard;
    }

    /// Stores (or clears) the tracing configuration.
    pub(crate) fn set_tracing_config(&mut self, config: Option<crate::tracing::TracingConfig>) {
        self.tracing_config = config;
    }

    /// Sets the retry policy used by `chat_with_tools`; `None` disables retries.
    pub fn set_retry_options(&mut self, options: Option<RetryOptions>) {
        self.retry_options = options;
    }
}
#[async_trait]
impl ModelListingCapability for XaiClient {
    /// Lists available models by delegating to `models_capability`.
    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
        let models = &self.models_capability;
        models.list_models().await
    }

    /// Looks up one model by id by delegating to `models_capability`.
    async fn get_model(&self, model_id: String) -> Result<ModelInfo, LlmError> {
        let models = &self.models_capability;
        models.get_model(model_id).await
    }
}