use crate::LlmBuilder;
use crate::error::LlmError;
use crate::retry_api::RetryOptions;
use crate::types::{CommonParams, HttpConfig, WebSearchConfig};
use super::client::XaiClient;
use super::config::XaiConfig;
/// Builder for [`XaiClient`].
///
/// Collects an [`XaiConfig`] plus optional tracing and retry settings. It then finalizes
/// the configuration and constructs the client in [`build`](Self::build) or
/// [`build_with_client`](Self::build_with_client).
#[derive(Debug, Clone)]
pub struct XaiBuilder {
    /// Provider configuration handed to the client.
    config: XaiConfig,
    /// Tracing configuration; when set, `init_tracing` is called during `build`.
    tracing_config: Option<crate::tracing::TracingConfig>,
    /// Retry options attached to the client after construction.
    retry_options: Option<RetryOptions>,
}

impl XaiBuilder {
    /// Creates a builder with [`XaiConfig::default`] and no tracing or retry settings.
    pub fn new() -> Self {
        Self {
            config: XaiConfig::default(),
            tracing_config: None,
            retry_options: None,
        }
    }

    /// Sets the API key.
    pub fn api_key<S: Into<String>>(mut self, api_key: S) -> Self {
        self.config.api_key = api_key.into();
        self
    }

    /// Overrides the API base URL.
    pub fn base_url<S: Into<String>>(mut self, base_url: S) -> Self {
        self.config.base_url = base_url.into();
        self
    }

    /// Sets the model name.
    ///
    /// When the model is left empty, `build` falls back to
    /// `crate::providers::xai::models::popular::LATEST`.
    pub fn model<S: Into<String>>(mut self, model: S) -> Self {
        self.config.common_params.model = model.into();
        self
    }

    /// Sets the `temperature` common parameter.
    pub const fn temperature(mut self, temperature: f32) -> Self {
        self.config.common_params.temperature = Some(temperature);
        self
    }

    /// Sets the `max_tokens` common parameter.
    pub const fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.config.common_params.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the `top_p` common parameter.
    pub const fn top_p(mut self, top_p: f32) -> Self {
        self.config.common_params.top_p = Some(top_p);
        self
    }

    /// Sets the `stop_sequences` common parameter.
    pub fn stop_sequences(mut self, stop: Vec<String>) -> Self {
        self.config.common_params.stop_sequences = Some(stop);
        self
    }

    /// Sets the `seed` common parameter.
    pub const fn seed(mut self, seed: u64) -> Self {
        self.config.common_params.seed = Some(seed);
        self
    }

    /// Replaces all common parameters (including the model) at once.
    pub fn common_params(mut self, params: CommonParams) -> Self {
        self.config.common_params = params;
        self
    }

    /// Replaces the HTTP configuration.
    pub fn http_config(mut self, config: HttpConfig) -> Self {
        self.config.http_config = config;
        self
    }

    /// Replaces the web search configuration.
    pub fn web_search_config(mut self, config: WebSearchConfig) -> Self {
        self.config.web_search_config = config;
        self
    }

    /// Sets `web_search_config.enabled` to `true`, leaving its other fields untouched.
    pub const fn enable_web_search(mut self) -> Self {
        self.config.web_search_config.enabled = true;
        self
    }

    /// Replaces the whole [`XaiConfig`].
    ///
    /// This overwrites values set by earlier config setters. Tracing and retry
    /// settings are kept.
    pub fn config(mut self, config: XaiConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the tracing configuration used when building the client.
    pub fn tracing(mut self, config: crate::tracing::TracingConfig) -> Self {
        self.tracing_config = Some(config);
        self
    }

    /// Uses the `TracingConfig::development` tracing preset.
    pub fn debug_tracing(self) -> Self {
        self.tracing(crate::tracing::TracingConfig::development())
    }

    /// Uses the `TracingConfig::minimal` tracing preset.
    pub fn minimal_tracing(self) -> Self {
        self.tracing(crate::tracing::TracingConfig::minimal())
    }

    /// Uses the `TracingConfig::json_production` tracing preset.
    pub fn json_tracing(self) -> Self {
        self.tracing(crate::tracing::TracingConfig::json_production())
    }

    /// Toggles pretty-printed JSON on the tracing configuration.
    ///
    /// If no tracing configuration was set yet, this starts from the
    /// `TracingConfig::development` preset.
    pub fn pretty_json(self, pretty: bool) -> Self {
        self.update_tracing(|config| config.with_pretty_json(pretty))
    }

    /// Toggles masking of sensitive values on the tracing configuration.
    ///
    /// If no tracing configuration was set yet, this starts from the
    /// `TracingConfig::development` preset.
    pub fn mask_sensitive_values(self, mask: bool) -> Self {
        self.update_tracing(|config| config.with_mask_sensitive_values(mask))
    }

    /// Sets the retry options attached to the built client.
    pub fn with_retry(mut self, options: RetryOptions) -> Self {
        self.retry_options = Some(options);
        self
    }

    /// Builds the client with its own HTTP client.
    ///
    /// # Errors
    ///
    /// - Returns `LlmError::InvalidInput` if the finalized configuration fails validation.
    /// - Propagates errors from tracing initialization and from `XaiClient::new`.
    pub async fn build(self) -> Result<XaiClient, LlmError> {
        self.build_using(XaiClient::new).await
    }

    /// Builds the client on top of a caller-supplied `reqwest::Client`.
    ///
    /// # Errors
    ///
    /// - Returns `LlmError::InvalidInput` if the finalized configuration fails validation.
    /// - Propagates errors from tracing initialization and from `XaiClient::with_http_client`.
    pub async fn build_with_client(
        self,
        http_client: reqwest::Client,
    ) -> Result<XaiClient, LlmError> {
        self.build_using(|config| XaiClient::with_http_client(config, http_client))
            .await
    }

    /// Applies `update` to the current tracing configuration.
    ///
    /// Starts from `TracingConfig::development` when no tracing configuration has been
    /// set yet.
    fn update_tracing(
        mut self,
        update: impl FnOnce(crate::tracing::TracingConfig) -> crate::tracing::TracingConfig,
    ) -> Self {
        let current = self
            .tracing_config
            .take()
            .unwrap_or_else(crate::tracing::TracingConfig::development);
        self.tracing_config = Some(update(current));
        self
    }

    /// Fills in the default model when none was chosen, then validates the result.
    ///
    /// Defaulting happens first so that validation checks exactly the configuration the
    /// client will receive.
    fn finalized_config(mut config: XaiConfig) -> Result<XaiConfig, LlmError> {
        if config.common_params.model.is_empty() {
            config.common_params.model =
                crate::providers::xai::models::popular::LATEST.to_string();
        }
        config
            .validate()
            .map_err(|e| LlmError::InvalidInput(format!("Invalid xAI configuration: {e}")))?;
        Ok(config)
    }

    /// Shared body of `build` and `build_with_client`.
    ///
    /// Steps, in order:
    /// 1. Finalizes the config.
    /// 2. Initializes tracing (if configured).
    /// 3. Constructs the client via `connect`.
    /// 4. Attaches the tracing guard, tracing config and retry options to the client.
    async fn build_using<F, Fut>(self, connect: F) -> Result<XaiClient, LlmError>
    where
        F: FnOnce(XaiConfig) -> Fut,
        Fut: std::future::Future<Output = Result<XaiClient, LlmError>>,
    {
        let config = Self::finalized_config(self.config)?;
        // The guard is stored on the client so it lives as long as the client does
        // (presumably it keeps the tracing subscriber/writer alive — confirm in
        // `crate::tracing`).
        let tracing_guard = match &self.tracing_config {
            Some(tracing_config) => crate::tracing::init_tracing(tracing_config.clone())?,
            None => None,
        };
        let mut client = connect(config).await?;
        client.set_tracing_guard(tracing_guard);
        client.set_tracing_config(self.tracing_config);
        client.set_retry_options(self.retry_options);
        Ok(client)
    }
}
/// xAI-specific builder handed out by the unified [`LlmBuilder`].
///
/// Setters are forwarded to an inner xAI builder. [`build`](Self::build) creates the
/// HTTP client from the shared `LlmBuilder` and passes it to the inner builder's
/// `build_with_client`.
#[cfg(feature = "xai")]
pub struct XaiBuilderWrapper {
    /// Shared builder whose `build_http_client` is used at build time.
    pub(crate) base: LlmBuilder,
    /// Inner provider builder receiving every forwarded setting.
    xai_builder: crate::providers::xai::XaiBuilder,
}

#[cfg(feature = "xai")]
impl XaiBuilderWrapper {
    /// Wraps `base` together with a fresh xAI builder.
    pub fn new(base: LlmBuilder) -> Self {
        Self {
            base,
            xai_builder: crate::providers::xai::XaiBuilder::new(),
        }
    }

    /// Sets the API key.
    pub fn api_key<S: Into<String>>(mut self, api_key: S) -> Self {
        self.xai_builder = self.xai_builder.api_key(api_key);
        self
    }

    /// Overrides the API base URL.
    pub fn base_url<S: Into<String>>(mut self, base_url: S) -> Self {
        self.xai_builder = self.xai_builder.base_url(base_url);
        self
    }

    /// Sets the model name.
    pub fn model<S: Into<String>>(mut self, model: S) -> Self {
        self.xai_builder = self.xai_builder.model(model);
        self
    }

    /// Sets the `temperature` common parameter.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.xai_builder = self.xai_builder.temperature(temperature);
        self
    }

    /// Sets the `max_tokens` common parameter.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.xai_builder = self.xai_builder.max_tokens(max_tokens);
        self
    }

    /// Sets the `top_p` common parameter.
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.xai_builder = self.xai_builder.top_p(top_p);
        self
    }

    /// Sets the `stop_sequences` common parameter.
    pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
        self.xai_builder = self.xai_builder.stop_sequences(sequences);
        self
    }

    /// Sets the `seed` common parameter.
    pub fn seed(mut self, seed: u64) -> Self {
        self.xai_builder = self.xai_builder.seed(seed);
        self
    }

    /// Replaces the web search configuration.
    pub fn web_search_config(mut self, config: WebSearchConfig) -> Self {
        self.xai_builder = self.xai_builder.web_search_config(config);
        self
    }

    /// Turns on web search, keeping the rest of the web search configuration.
    pub fn enable_web_search(mut self) -> Self {
        self.xai_builder = self.xai_builder.enable_web_search();
        self
    }

    /// Sets the tracing configuration used when building the client.
    pub fn tracing(mut self, config: crate::tracing::TracingConfig) -> Self {
        self.xai_builder = self.xai_builder.tracing(config);
        self
    }

    /// Uses the development tracing preset.
    pub fn debug_tracing(mut self) -> Self {
        self.xai_builder = self.xai_builder.debug_tracing();
        self
    }

    /// Uses the minimal tracing preset.
    pub fn minimal_tracing(mut self) -> Self {
        self.xai_builder = self.xai_builder.minimal_tracing();
        self
    }

    /// Uses the JSON production tracing preset.
    pub fn json_tracing(mut self) -> Self {
        self.xai_builder = self.xai_builder.json_tracing();
        self
    }

    /// Toggles pretty-printed JSON on the tracing configuration.
    pub fn pretty_json(mut self, pretty: bool) -> Self {
        self.xai_builder = self.xai_builder.pretty_json(pretty);
        self
    }

    /// Toggles masking of sensitive values on the tracing configuration.
    pub fn mask_sensitive_values(mut self, mask: bool) -> Self {
        self.xai_builder = self.xai_builder.mask_sensitive_values(mask);
        self
    }

    /// Sets the retry options attached to the built client.
    pub fn with_retry(mut self, options: RetryOptions) -> Self {
        self.xai_builder = self.xai_builder.with_retry(options);
        self
    }

    /// Builds the xAI client using the HTTP client produced by the shared `LlmBuilder`.
    ///
    /// # Errors
    ///
    /// - Propagates errors from `LlmBuilder::build_http_client`.
    /// - Propagates errors from the inner builder's `build_with_client`.
    pub async fn build(self) -> Result<crate::providers::xai::XaiClient, LlmError> {
        let http_client = self.base.build_http_client()?;
        self.xai_builder.build_with_client(http_client).await
    }
}
impl Default for XaiBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_creation() {
        let builder = XaiBuilder::new();
        assert_eq!(builder.config.base_url, "https://api.x.ai/v1");
        assert!(builder.config.api_key.is_empty());
        assert!(builder.tracing_config.is_none());
        assert!(builder.retry_options.is_none());
    }

    #[test]
    fn test_builder_configuration() {
        let builder = XaiBuilder::new()
            .api_key("test-key")
            .model("grok-3-latest")
            .temperature(0.7)
            .max_tokens(1000);
        assert_eq!(builder.config.api_key, "test-key");
        assert_eq!(builder.config.common_params.model, "grok-3-latest");
        assert_eq!(builder.config.common_params.temperature, Some(0.7));
        assert_eq!(builder.config.common_params.max_tokens, Some(1000));
    }

    #[test]
    fn test_remaining_common_param_setters() {
        let builder = XaiBuilder::new()
            .top_p(0.9)
            .seed(42)
            .stop_sequences(vec!["END".to_string()]);
        assert_eq!(builder.config.common_params.top_p, Some(0.9));
        assert_eq!(builder.config.common_params.seed, Some(42));
        assert_eq!(
            builder.config.common_params.stop_sequences,
            Some(vec!["END".to_string()])
        );
    }

    #[test]
    fn test_tracing_knobs_create_config_when_missing() {
        assert!(XaiBuilder::new().pretty_json(true).tracing_config.is_some());
        assert!(XaiBuilder::new()
            .mask_sensitive_values(false)
            .tracing_config
            .is_some());
    }

    #[test]
    fn test_enable_web_search_sets_flag() {
        let builder = XaiBuilder::new().enable_web_search();
        assert!(builder.config.web_search_config.enabled);
    }

    #[test]
    fn test_config_replaces_previous_settings() {
        let builder = XaiBuilder::new()
            .api_key("old-key")
            .config(XaiConfig::default());
        assert!(builder.config.api_key.is_empty());
    }

    #[test]
    fn test_default_matches_new() {
        let from_default = XaiBuilder::default();
        let from_new = XaiBuilder::new();
        assert_eq!(from_default.config.base_url, from_new.config.base_url);
        assert_eq!(from_default.config.api_key, from_new.config.api_key);
    }

    #[tokio::test]
    async fn test_builder_validation() {
        let builder = XaiBuilder::new();
        let result = builder.build().await;
        assert!(result.is_err());
        let builder = XaiBuilder::new().api_key("test-key");
        let result = builder.build().await;
        assert!(result.is_ok());
    }
}