use async_trait::async_trait;
use futures::stream::BoxStream;
use tracing::debug;
use crate::error::{LlmError, Result};
use crate::model_config::{
ModelCapabilities, ModelCard, ModelType, ProviderConfig, ProviderType as ConfigProviderType,
};
use crate::providers::openai_compatible::OpenAICompatibleProvider;
use crate::traits::{
    ChatMessage, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse, StreamChunk,
};
const XAI_BASE_URL: &str = "https://api.x.ai/v1";
const XAI_DEFAULT_MODEL: &str = "grok-4";
const XAI_PROVIDER_NAME: &str = "xai";
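/// Known Grok models as `(model name, display name, context window in tokens)`.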
const XAI_MODELS: &[(&str, &str, usize)] = &[
("grok-4", "Grok 4 (Flagship, 256K)", 262144),
("grok-4-0709", "Grok 4 (July 2025)", 262144),
("grok-4-latest", "Grok 4 Latest", 262144),
("grok-4-1-fast", "Grok 4.1 Fast (2M context)", 2000000),
(
"grok-4-1-fast-reasoning",
"Grok 4.1 Fast Reasoning",
2000000,
),
(
"grok-4-1-fast-non-reasoning",
"Grok 4.1 Fast Non-Reasoning",
2000000,
),
("grok-3", "Grok 3", 131072),
("grok-3-latest", "Grok 3 Latest", 131072),
("grok-3-mini", "Grok 3 Mini", 131072),
("grok-3-mini-latest", "Grok 3 Mini Latest", 131072),
("grok-2-vision-1212", "Grok 2 Vision", 32768),
("grok-code-fast-1", "Grok Code Fast", 131072),
];
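/// Provider for xAI's Grok models.
///
/// A thin wrapper around [`OpenAICompatibleProvider`] that pins the xAI base
/// URL, the static model catalog above, and sensible defaults, while exposing
/// the usual [`LLMProvider`] interface.
///
/// # Example
///
/// A minimal sketch (the import path is omitted and illustrative; `XAI_API_KEY`
/// must be set in the environment):
///
/// ```ignore
/// let provider = XAIProvider::from_env()?.with_model("grok-4-1-fast");
/// let response = provider.complete("Say hello").await?;
/// ```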
#[derive(Debug)]
pub struct XAIProvider {
inner: OpenAICompatibleProvider,
model: String,
}
impl XAIProvider {
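    /// Builds a provider from environment variables.
    ///
    /// Reads `XAI_API_KEY` (required), `XAI_MODEL` (optional, defaults to
    /// `grok-4`), and `XAI_BASE_URL` (optional, defaults to the public
    /// `https://api.x.ai/v1` endpoint).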
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("XAI_API_KEY").map_err(|_| {
LlmError::ConfigError(
"XAI_API_KEY environment variable not set. \
Get your API key from https://console.x.ai"
.to_string(),
)
})?;
if api_key.is_empty() {
return Err(LlmError::ConfigError(
"XAI_API_KEY is empty. Please set a valid API key.".to_string(),
));
}
let model = std::env::var("XAI_MODEL").unwrap_or_else(|_| XAI_DEFAULT_MODEL.to_string());
let base_url = std::env::var("XAI_BASE_URL").unwrap_or_else(|_| XAI_BASE_URL.to_string());
Self::new(api_key, model, Some(base_url))
}
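    /// Creates a provider with an explicit API key, model name, and optional
    /// base URL override.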
pub fn new(api_key: String, model: String, base_url: Option<String>) -> Result<Self> {
let config = Self::build_config(&api_key, &model, base_url.as_deref());
let inner = OpenAICompatibleProvider::from_config(config)?;
debug!(
provider = XAI_PROVIDER_NAME,
model = %model,
"Created xAI provider"
);
Ok(Self { inner, model })
}
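    /// Switches the active model, updating both this wrapper and the inner
    /// OpenAI-compatible client.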
pub fn with_model(mut self, model: &str) -> Self {
self.model = model.to_string();
self.inner = self.inner.with_model(model);
self
}
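    /// Assembles the [`ProviderConfig`] handed to [`OpenAICompatibleProvider`],
    /// including a [`ModelCard`] for every entry in [`XAI_MODELS`].
    ///
    /// Note: the `_api_key` argument is currently unused here; key resolution is
    /// delegated to the `api_key_env` field of the config.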
fn build_config(_api_key: &str, model: &str, base_url: Option<&str>) -> ProviderConfig {
let models: Vec<ModelCard> = XAI_MODELS
.iter()
.map(|(name, display, context)| {
                // Grok 4.1 models are published as `grok-4-1-*`, which the
                // `grok-4` prefix already matches, so no separate check is needed.
                let supports_vision = name.starts_with("grok-4") || name.contains("vision");
let supports_thinking =
name.starts_with("grok-4") && !name.contains("non-reasoning");
ModelCard {
name: name.to_string(),
display_name: display.to_string(),
model_type: ModelType::Llm,
capabilities: ModelCapabilities {
context_length: *context,
supports_function_calling: true,
supports_json_mode: true,
supports_streaming: true,
supports_system_message: true,
supports_vision,
supports_thinking,
..Default::default()
},
..Default::default()
}
})
.collect();
ProviderConfig {
name: XAI_PROVIDER_NAME.to_string(),
display_name: "xAI Grok".to_string(),
provider_type: ConfigProviderType::OpenAICompatible,
api_key_env: Some("XAI_API_KEY".to_string()),
base_url: Some(base_url.unwrap_or(XAI_BASE_URL).to_string()),
base_url_env: Some("XAI_BASE_URL".to_string()),
default_llm_model: Some(model.to_string()),
            default_embedding_model: None,
            models,
headers: std::collections::HashMap::new(),
enabled: true,
timeout_seconds: 600,
..Default::default()
}
}
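    /// Returns the context window (in tokens) for a known model name, falling
    /// back to 262144 (the Grok 4 window) for unknown models.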
pub fn context_length(model: &str) -> usize {
XAI_MODELS
.iter()
.find(|(name, _, _)| *name == model)
.map(|(_, _, ctx)| *ctx)
            .unwrap_or(262144)
    }
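    /// Returns the static model catalog as `(name, display name, context length)` tuples.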
pub fn available_models() -> Vec<(&'static str, &'static str, usize)> {
XAI_MODELS.to_vec()
}
}
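// All completion, chat, and streaming calls delegate to the inner
// OpenAI-compatible client; only the provider name, model, and context length
// are answered locally.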
#[async_trait]
impl LLMProvider for XAIProvider {
fn name(&self) -> &str {
XAI_PROVIDER_NAME
}
fn model(&self) -> &str {
&self.model
}
fn max_context_length(&self) -> usize {
Self::context_length(&self.model)
}
async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
self.inner.complete(prompt).await
}
async fn complete_with_options(
&self,
prompt: &str,
options: &CompletionOptions,
) -> Result<LLMResponse> {
self.inner.complete_with_options(prompt, options).await
}
async fn chat(
&self,
messages: &[ChatMessage],
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
self.inner.chat(messages, options).await
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: &[crate::traits::ToolDefinition],
tool_choice: Option<crate::traits::ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
self.inner
.chat_with_tools(messages, tools, tool_choice, options)
.await
}
async fn chat_with_tools_stream(
&self,
messages: &[ChatMessage],
tools: &[crate::traits::ToolDefinition],
tool_choice: Option<crate::traits::ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<BoxStream<'static, Result<StreamChunk>>> {
self.inner
.chat_with_tools_stream(messages, tools, tool_choice, options)
.await
}
async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
self.inner.stream(prompt).await
}
fn supports_function_calling(&self) -> bool {
self.inner.supports_function_calling()
}
fn supports_tool_streaming(&self) -> bool {
self.inner.supports_tool_streaming()
}
}
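// xAI does not expose an embeddings endpoint, so this impl is a stub:
// `embed` always returns an error and dimension/max_tokens report 0.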
#[async_trait]
impl EmbeddingProvider for XAIProvider {
fn name(&self) -> &str {
XAI_PROVIDER_NAME
}
fn model(&self) -> &str {
"none"
}
    fn dimension(&self) -> usize {
        0
    }
    fn max_tokens(&self) -> usize {
        0
    }
async fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
Err(LlmError::ConfigError(
"xAI does not provide an embeddings API. \
Use OpenAI or another provider for embeddings."
.to_string(),
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_context_length_known_model() {
assert_eq!(XAIProvider::context_length("grok-4"), 262144); assert_eq!(XAIProvider::context_length("grok-4-1-fast"), 2000000); assert_eq!(XAIProvider::context_length("grok-2-vision-1212"), 32768); }
#[test]
fn test_context_length_unknown_model() {
assert_eq!(XAIProvider::context_length("grok-unknown"), 262144);
}
#[test]
fn test_available_models() {
let models = XAIProvider::available_models();
assert!(!models.is_empty());
assert!(models.iter().any(|(name, _, _)| *name == "grok-4"));
assert!(models.iter().any(|(name, _, _)| *name == "grok-4-1-fast"));
}
#[test]
fn test_build_config() {
let config = XAIProvider::build_config("test-key", "grok-4", None);
assert_eq!(config.name, "xai");
assert_eq!(config.base_url, Some("https://api.x.ai/v1".to_string()));
assert_eq!(config.default_llm_model, Some("grok-4".to_string()));
}
#[test]
fn test_build_config_custom_url() {
let config = XAIProvider::build_config("test-key", "grok-3", Some("https://custom.api"));
assert_eq!(config.base_url, Some("https://custom.api".to_string()));
}
#[test]
fn test_context_length_grok3_series() {
assert_eq!(XAIProvider::context_length("grok-3"), 131072); assert_eq!(XAIProvider::context_length("grok-3-latest"), 131072);
assert_eq!(XAIProvider::context_length("grok-3-mini"), 131072);
assert_eq!(XAIProvider::context_length("grok-3-mini-latest"), 131072);
}
#[test]
fn test_context_length_grok4_series() {
assert_eq!(XAIProvider::context_length("grok-4"), 262144); assert_eq!(XAIProvider::context_length("grok-4-0709"), 262144);
assert_eq!(XAIProvider::context_length("grok-4-latest"), 262144);
}
#[test]
fn test_context_length_grok41_fast_series() {
assert_eq!(XAIProvider::context_length("grok-4-1-fast"), 2000000); assert_eq!(
XAIProvider::context_length("grok-4-1-fast-reasoning"),
2000000
);
assert_eq!(
XAIProvider::context_length("grok-4-1-fast-non-reasoning"),
2000000
);
}
#[test]
fn test_context_length_specialized_models() {
assert_eq!(XAIProvider::context_length("grok-2-vision-1212"), 32768); assert_eq!(XAIProvider::context_length("grok-code-fast-1"), 131072); }
#[test]
fn test_build_config_model_cards() {
let config = XAIProvider::build_config("test-key", "grok-4", None);
assert!(!config.models.is_empty());
assert!(config.models.iter().any(|m| m.name == "grok-4"));
}
#[test]
fn test_build_config_api_key_env() {
let config = XAIProvider::build_config("my-api-key", "grok-4", None);
assert_eq!(config.api_key_env, Some("XAI_API_KEY".to_string()));
}
#[test]
fn test_available_models_contains_all_series() {
let models = XAIProvider::available_models();
assert!(models.iter().any(|(name, _, _)| *name == "grok-4"));
assert!(models.iter().any(|(name, _, _)| *name == "grok-4-latest"));
assert!(models.iter().any(|(name, _, _)| *name == "grok-4-1-fast"));
assert!(models.iter().any(|(name, _, _)| *name == "grok-3"));
assert!(models.iter().any(|(name, _, _)| *name == "grok-3-mini"));
assert!(models
.iter()
.any(|(name, _, _)| *name == "grok-2-vision-1212"));
assert!(models
.iter()
.any(|(name, _, _)| *name == "grok-code-fast-1"));
}
#[test]
fn test_available_models_has_context_lengths() {
let models = XAIProvider::available_models();
for (name, _desc, context_len) in models {
assert!(
context_len > 0,
"Model {} should have positive context length",
name
);
}
}
#[test]
fn test_from_env_missing_api_key() {
std::env::remove_var("XAI_API_KEY");
std::env::remove_var("XAI_MODEL");
std::env::remove_var("XAI_BASE_URL");
let result = XAIProvider::from_env();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("XAI_API_KEY"));
}
#[test]
fn test_provider_name_constant() {
assert_eq!(XAI_PROVIDER_NAME, "xai");
}
#[test]
fn test_default_model_constant() {
assert_eq!(XAI_DEFAULT_MODEL, "grok-4");
}
#[test]
fn test_default_base_url_constant() {
assert_eq!(XAI_BASE_URL, "https://api.x.ai/v1");
}
}