use std::collections::HashMap;
use crate::types::{CommonParams, HttpConfig, WebSearchConfig};
/// Connection and request settings for the xAI API client.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written out by `{:?}` formatting (logs, panic messages, test output).
#[derive(Clone)]
pub struct XaiConfig {
    /// Secret key, sent as a bearer token in the `Authorization` header.
    pub api_key: String,
    /// Root URL of the API; must start with `http://` or `https://`.
    pub base_url: String,
    /// Model name and sampling parameters (temperature, max tokens, top-p).
    pub common_params: CommonParams,
    /// HTTP client settings.
    pub http_config: HttpConfig,
    /// Web-search tool settings.
    pub web_search_config: WebSearchConfig,
}

impl std::fmt::Debug for XaiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a key is configured, never its contents.
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("XaiConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("common_params", &self.common_params)
            .field("http_config", &self.http_config)
            .field("web_search_config", &self.web_search_config)
            .finish()
    }
}
impl XaiConfig {
    /// Creates a configuration with the given API key.
    ///
    /// Every other field takes its value from [`XaiConfig::default`], so the
    /// default endpoint and parameters are defined in exactly one place.
    pub fn new<S: Into<String>>(api_key: S) -> Self {
        Self {
            api_key: api_key.into(),
            ..Self::default()
        }
    }

    /// Overrides the API root URL (e.g. for a proxy or a local mock server).
    #[must_use]
    pub fn with_base_url<S: Into<String>>(mut self, url: S) -> Self {
        self.base_url = url.into();
        self
    }

    /// Sets the model name used for requests.
    #[must_use]
    pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
        self.common_params.model = model.into();
        self
    }

    /// Sets the sampling temperature; [`validate`](Self::validate) requires `0.0..=2.0`.
    #[must_use]
    pub const fn with_temperature(mut self, temperature: f32) -> Self {
        self.common_params.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens to generate.
    #[must_use]
    pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.common_params.max_tokens = Some(max_tokens);
        self
    }

    /// Replaces the web-search configuration.
    ///
    /// With `Some(config)` the given config is used verbatim, including its
    /// `enabled` flag. With `None`, web search is enabled and all other
    /// settings take their defaults.
    #[must_use]
    pub fn with_web_search(mut self, config: Option<WebSearchConfig>) -> Self {
        self.web_search_config = config.unwrap_or_else(|| WebSearchConfig {
            enabled: true,
            ..Default::default()
        });
        self
    }

    /// Turns web search on, leaving the rest of the web-search config untouched.
    #[must_use]
    pub const fn enable_web_search(mut self) -> Self {
        self.web_search_config.enabled = true;
        self
    }

    /// Returns the value for the `Authorization` header (`"Bearer <api_key>"`).
    pub fn auth_header(&self) -> String {
        format!("Bearer {}", self.api_key)
    }

    /// Returns the headers sent with every JSON request: `Authorization` and
    /// `Content-Type: application/json`.
    pub fn get_headers(&self) -> HashMap<String, String> {
        HashMap::from([
            ("Authorization".to_string(), self.auth_header()),
            ("Content-Type".to_string(), "application/json".to_string()),
        ])
    }

    /// Checks that the configuration is usable before any request is made.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - the API key is empty or whitespace-only;
    /// - the base URL is empty, has a scheme other than `http://`/`https://`,
    ///   or has no host after the scheme;
    /// - `temperature` is set and outside `0.0..=2.0` (NaN included);
    /// - `top_p` is set and outside `0.0..=1.0` (NaN included).
    pub fn validate(&self) -> Result<(), String> {
        // A whitespace-only key would produce an `Authorization` header that
        // carries no usable credential, so treat it exactly like an empty key.
        if self.api_key.trim().is_empty() {
            return Err("API key cannot be empty".to_string());
        }
        if self.base_url.is_empty() {
            return Err("Base URL cannot be empty".to_string());
        }
        let Some(after_scheme) = self
            .base_url
            .strip_prefix("https://")
            .or_else(|| self.base_url.strip_prefix("http://"))
        else {
            return Err("Base URL must start with http:// or https://".to_string());
        };
        // Reject "https://" and "https:///path", which have a scheme but no host.
        if after_scheme.is_empty() || after_scheme.starts_with('/') {
            return Err("Base URL must include a host after the scheme".to_string());
        }
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected too.
        if self
            .common_params
            .temperature
            .is_some_and(|temp| !(0.0..=2.0).contains(&temp))
        {
            return Err("Temperature must be between 0.0 and 2.0".to_string());
        }
        if self
            .common_params
            .top_p
            .is_some_and(|top_p| !(0.0..=1.0).contains(&top_p))
        {
            return Err("Top-p must be between 0.0 and 1.0".to_string());
        }
        Ok(())
    }
}
impl Default for XaiConfig {
fn default() -> Self {
Self {
api_key: String::new(),
base_url: "https://api.x.ai/v1".to_string(),
common_params: CommonParams::default(),
http_config: HttpConfig::default(),
web_search_config: WebSearchConfig::default(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_creation() {
        let config = XaiConfig::new("test-key");
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.base_url, "https://api.x.ai/v1");
    }

    #[test]
    fn test_config_validation() {
        let mut config = XaiConfig::new("test-key");
        assert!(config.validate().is_ok());
        config.api_key = String::new();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_headers() {
        let config = XaiConfig::new("test-key");
        let headers = config.get_headers();
        assert_eq!(
            headers.get("Authorization"),
            Some(&"Bearer test-key".to_string())
        );
        assert_eq!(
            headers.get("Content-Type"),
            Some(&"application/json".to_string())
        );
    }

    /// Each builder writes exactly the field it names, and the result validates.
    #[test]
    fn test_builder_methods() {
        let config = XaiConfig::new("test-key")
            .with_base_url("http://localhost:8080/v1")
            .with_model("grok-test")
            .with_temperature(0.5)
            .with_max_tokens(256);
        assert_eq!(config.base_url, "http://localhost:8080/v1");
        assert_eq!(config.common_params.model, "grok-test");
        assert_eq!(config.common_params.temperature, Some(0.5));
        assert_eq!(config.common_params.max_tokens, Some(256));
        assert!(config.validate().is_ok());
    }

    /// `with_web_search(None)` and `enable_web_search()` both turn web search on.
    #[test]
    fn test_web_search_builders() {
        assert!(
            XaiConfig::new("test-key")
                .with_web_search(None)
                .web_search_config
                .enabled
        );
        assert!(
            XaiConfig::new("test-key")
                .enable_web_search()
                .web_search_config
                .enabled
        );
    }

    #[test]
    fn test_validation_rejects_bad_base_url() {
        let config = XaiConfig::new("test-key").with_base_url("");
        assert_eq!(
            config.validate(),
            Err("Base URL cannot be empty".to_string())
        );
        let config = XaiConfig::new("test-key").with_base_url("ftp://api.x.ai/v1");
        assert_eq!(
            config.validate(),
            Err("Base URL must start with http:// or https://".to_string())
        );
    }

    /// Range bounds are inclusive; values just outside them and NaN are rejected.
    #[test]
    fn test_validation_parameter_ranges() {
        assert!(XaiConfig::new("k").with_temperature(0.0).validate().is_ok());
        assert!(XaiConfig::new("k").with_temperature(2.0).validate().is_ok());
        assert_eq!(
            XaiConfig::new("k").with_temperature(2.1).validate(),
            Err("Temperature must be between 0.0 and 2.0".to_string())
        );
        assert!(XaiConfig::new("k").with_temperature(-0.1).validate().is_err());
        assert!(XaiConfig::new("k")
            .with_temperature(f32::NAN)
            .validate()
            .is_err());

        let mut config = XaiConfig::new("k");
        config.common_params.top_p = Some(1.5);
        assert_eq!(
            config.validate(),
            Err("Top-p must be between 0.0 and 1.0".to_string())
        );
        config.common_params.top_p = Some(1.0);
        assert!(config.validate().is_ok());
    }
}