use crate::error::LlmError;
use crate::traits::ModelListingCapability;
use crate::types::{HttpConfig, ModelInfo};
use async_trait::async_trait;
use super::types::*;
use super::utils::build_headers;
#[derive(Debug, Clone)]
pub struct XaiModels {
pub api_key: String,
pub base_url: String,
pub http_client: reqwest::Client,
pub http_config: HttpConfig,
}
impl XaiModels {
pub const fn new(
api_key: String,
base_url: String,
http_client: reqwest::Client,
http_config: HttpConfig,
) -> Self {
Self {
api_key,
base_url,
http_client,
http_config,
}
}
fn convert_xai_model(&self, xai_model: XaiModel) -> ModelInfo {
let model_id = xai_model.id.clone();
ModelInfo {
id: model_id.clone(),
name: Some(model_id),
description: None, owned_by: xai_model.owned_by.clone(),
created: Some(xai_model.created),
capabilities: self.get_model_capabilities(&xai_model),
context_window: None, max_output_tokens: None, input_cost_per_token: None, output_cost_per_token: None,
}
}
fn get_model_capabilities(&self, model: &XaiModel) -> Vec<String> {
let mut capabilities = Vec::new();
capabilities.push("chat".to_string());
capabilities.push("streaming".to_string());
if !model.id.contains("beta") {
capabilities.push("function_calling".to_string());
}
if model.id.contains("grok-4") {
capabilities.push("vision".to_string());
}
if model.id.contains("grok-4") {
capabilities.push("reasoning".to_string());
}
if model.id.contains("image") {
capabilities.push("image_generation".to_string());
}
capabilities
}
async fn list_models_from_api(&self) -> Result<Vec<ModelInfo>, LlmError> {
let url = format!("{}/models", self.base_url);
let headers = build_headers(&self.api_key, &self.http_config.headers)?;
let response = self.http_client.get(&url).headers(headers).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
let error_message = error_text.clone();
return Err(LlmError::ApiError {
code: status.as_u16(),
message: format!("xAI list models error: {error_message}"),
details: serde_json::from_str(&error_text).ok(),
});
}
let xai_response: XaiModelsResponse = response.json().await?;
let models = xai_response
.data
.into_iter()
.map(|m| self.convert_xai_model(m))
.collect();
Ok(models)
}
async fn get_model_from_api(&self, model_id: String) -> Result<ModelInfo, LlmError> {
let url = format!("{}/models/{}", self.base_url, model_id);
let headers = build_headers(&self.api_key, &self.http_config.headers)?;
let response = self.http_client.get(&url).headers(headers).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
let error_message = error_text.clone();
return Err(LlmError::ApiError {
code: status.as_u16(),
message: format!("xAI get model error: {error_message}"),
details: serde_json::from_str(&error_text).ok(),
});
}
let xai_model: XaiModel = response.json().await?;
Ok(self.convert_xai_model(xai_model))
}
fn get_hardcoded_models(&self) -> Vec<ModelInfo> {
crate::providers::xai::models::all_models()
.into_iter()
.map(|model_id| ModelInfo {
id: model_id.to_string(),
name: Some(model_id.to_string()),
description: Some(format!("xAI {} model", model_id)),
owned_by: "xai".to_string(),
created: None,
capabilities: self.get_hardcoded_capabilities(model_id),
context_window: None,
max_output_tokens: None,
input_cost_per_token: None,
output_cost_per_token: None,
})
.collect()
}
fn get_hardcoded_capabilities(&self, model_id: &str) -> Vec<String> {
let mut capabilities = vec!["chat".to_string(), "streaming".to_string()];
if model_id.contains("grok-4") {
capabilities.extend_from_slice(&[
"vision".to_string(),
"reasoning".to_string(),
"function_calling".to_string(),
]);
} else if !model_id.contains("beta") {
capabilities.push("function_calling".to_string());
}
if model_id.contains("image") {
capabilities.push("image_generation".to_string());
}
capabilities
}
}
#[async_trait]
impl ModelListingCapability for XaiModels {
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
match self.list_models_from_api().await {
Ok(models) => Ok(models),
Err(_) => {
Ok(self.get_hardcoded_models())
}
}
}
async fn get_model(&self, model_id: String) -> Result<ModelInfo, LlmError> {
match self.get_model_from_api(model_id.clone()).await {
Ok(model) => Ok(model),
Err(_) => {
let hardcoded_models = self.get_hardcoded_models();
hardcoded_models
.into_iter()
.find(|m| m.id == model_id)
.ok_or_else(|| {
LlmError::InvalidInput(format!("Model '{}' not found", model_id))
})
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::HttpConfig;
fn create_test_models() -> XaiModels {
XaiModels::new(
"test-api-key".to_string(),
"https://api.x.ai/v1".to_string(),
reqwest::Client::new(),
HttpConfig::default(),
)
}
#[test]
fn test_get_hardcoded_models() {
let models = create_test_models();
let hardcoded = models.get_hardcoded_models();
assert!(!hardcoded.is_empty());
assert!(hardcoded.iter().any(|m| m.id.contains("grok")));
}
#[test]
fn test_get_hardcoded_capabilities() {
let models = create_test_models();
let grok4_caps = models.get_hardcoded_capabilities("grok-4");
assert!(grok4_caps.contains(&"vision".to_string()));
assert!(grok4_caps.contains(&"reasoning".to_string()));
let grok3_caps = models.get_hardcoded_capabilities("grok-3");
assert!(grok3_caps.contains(&"chat".to_string()));
assert!(grok3_caps.contains(&"function_calling".to_string()));
}
}