use async_trait::async_trait;
use futures::stream::BoxStream;
use tracing::{debug, warn};
use crate::error::{LlmError, Result};
use crate::model_config::{
ModelCapabilities, ModelCard, ModelType, ProviderConfig, ProviderType as ConfigProviderType,
};
use crate::providers::openai_compatible::OpenAICompatibleProvider;
use crate::traits::StreamChunk;
use crate::traits::{
ChatMessage, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse, ToolChoice,
ToolDefinition,
};
/// Default xAI API endpoint (OpenAI-compatible REST surface).
const XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Model used when the `XAI_MODEL` environment variable is not set.
const XAI_DEFAULT_MODEL: &str = "grok-4.20";
/// Provider identifier reported by the `name()` trait methods.
const XAI_PROVIDER_NAME: &str = "xai";
/// Known xAI models as `(api_name, display_name, context_length_tokens)`.
///
/// Used to build the provider's model cards and to answer
/// `context_length` / `available_models` queries. Unknown models fall back
/// to a default window in `XAIProvider::context_length`.
const XAI_MODELS: &[(&str, &str, usize)] = &[
    // Grok 4.20 family — 2M-token context.
    ("grok-4.20", "Grok 4.20 (Latest Flagship, 2M)", 2_000_000),
    ("grok-4.20-latest", "Grok 4.20 Latest (2M)", 2_000_000),
    ("grok-4.20-reasoning", "Grok 4.20 Reasoning (2M)", 2_000_000),
    (
        "grok-4.20-non-reasoning",
        "Grok 4.20 Non-Reasoning (2M)",
        2_000_000,
    ),
    ("grok-4.20-0309", "Grok 4.20 (0309 dated, 2M)", 2_000_000),
    (
        "grok-4.20-0309-reasoning",
        "Grok 4.20 Reasoning (0309 dated, 2M)",
        2_000_000,
    ),
    (
        "grok-4.20-0309-non-reasoning",
        "Grok 4.20 Non-Reasoning (0309 dated, 2M)",
        2_000_000,
    ),
    (
        "grok-4.20-multi-agent",
        "Grok 4.20 Multi-Agent (2M)",
        2_000_000,
    ),
    (
        "grok-4.20-multi-agent-0309",
        "Grok 4.20 Multi-Agent (0309 dated, 2M)",
        2_000_000,
    ),
    // Grok 4 family — 256K context.
    ("grok-4", "Grok 4 (256K, reasoning)", 262_144),
    ("grok-4-0709", "Grok 4 (July 2025, 256K)", 262_144),
    ("grok-4-latest", "Grok 4 Latest (256K)", 262_144),
    // Grok 4.1 Fast family — 2M context.
    ("grok-4-1-fast", "Grok 4.1 Fast (2M, reasoning)", 2_000_000),
    (
        "grok-4-1-fast-reasoning",
        "Grok 4.1 Fast Reasoning (2M)",
        2_000_000,
    ),
    (
        "grok-4-1-fast-non-reasoning",
        "Grok 4.1 Fast Non-Reasoning (2M)",
        2_000_000,
    ),
    // Grok 3 family — 128K context, no reasoning.
    ("grok-3", "Grok 3 (128K)", 131_072),
    ("grok-3-latest", "Grok 3 Latest (128K)", 131_072),
    ("grok-3-mini", "Grok 3 Mini (128K)", 131_072),
    ("grok-3-mini-latest", "Grok 3 Mini Latest (128K)", 131_072),
    // Specialized models.
    ("grok-2-vision-1212", "Grok 2 Vision (32K)", 32_768),
    ("grok-code-fast-1", "Grok Code Fast (128K)", 131_072),
];
/// LLM provider for xAI's Grok models.
///
/// Thin wrapper over [`OpenAICompatibleProvider`] that supplies xAI-specific
/// configuration (endpoint, model catalog) and strips request options that
/// xAI reasoning models reject.
#[derive(Debug)]
pub struct XAIProvider {
    /// Underlying OpenAI-compatible transport configured for the xAI API.
    inner: OpenAICompatibleProvider,
    /// Currently selected model name (e.g. "grok-4.20").
    model: String,
}
impl XAIProvider {
    /// Creates a provider from environment variables.
    ///
    /// Reads `XAI_API_KEY` (required), `XAI_MODEL` (defaults to
    /// [`XAI_DEFAULT_MODEL`]) and `XAI_BASE_URL` (defaults to
    /// [`XAI_BASE_URL`]).
    ///
    /// # Errors
    /// Returns [`LlmError::ConfigError`] when `XAI_API_KEY` is unset or empty.
    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("XAI_API_KEY").map_err(|_| {
            LlmError::ConfigError(
                "XAI_API_KEY environment variable not set. \
                 Get your API key from https://console.x.ai"
                    .to_string(),
            )
        })?;
        if api_key.is_empty() {
            return Err(LlmError::ConfigError(
                "XAI_API_KEY is empty. Please set a valid API key.".to_string(),
            ));
        }
        let model = std::env::var("XAI_MODEL").unwrap_or_else(|_| XAI_DEFAULT_MODEL.to_string());
        let base_url = std::env::var("XAI_BASE_URL").unwrap_or_else(|_| XAI_BASE_URL.to_string());
        Self::new(api_key, model, Some(base_url))
    }

    /// Creates a provider with an explicit API key, model, and optional
    /// base-URL override (`None` uses the default xAI endpoint).
    ///
    /// # Errors
    /// Propagates configuration errors from the underlying
    /// OpenAI-compatible transport.
    pub fn new(api_key: String, model: String, base_url: Option<String>) -> Result<Self> {
        let config = Self::build_config(&api_key, &model, base_url.as_deref());
        let inner = OpenAICompatibleProvider::from_config(config)?;
        debug!(
            provider = XAI_PROVIDER_NAME,
            model = %model,
            "Created xAI provider"
        );
        Ok(Self { inner, model })
    }

    /// Switches this provider to `model` (builder style), keeping all other
    /// configuration.
    pub fn with_model(mut self, model: &str) -> Self {
        self.model = model.to_string();
        self.inner = self.inner.with_model(model);
        self
    }

    /// Assembles the [`ProviderConfig`] handed to the OpenAI-compatible
    /// transport, including a model card per entry in [`XAI_MODELS`].
    ///
    /// `_api_key` is accepted for call-site symmetry but unused here: the
    /// transport resolves the key itself via `api_key_env`.
    fn build_config(_api_key: &str, model: &str, base_url: Option<&str>) -> ProviderConfig {
        let models: Vec<ModelCard> = XAI_MODELS
            .iter()
            .map(|(name, display, context)| {
                // Grok 4 family and explicit vision models accept image input.
                let supports_vision = name.starts_with("grok-4") || name.contains("vision");
                let supports_thinking = Self::is_reasoning_model(name);
                ModelCard {
                    name: name.to_string(),
                    display_name: display.to_string(),
                    model_type: ModelType::Llm,
                    capabilities: ModelCapabilities {
                        context_length: *context,
                        supports_function_calling: true,
                        supports_json_mode: true,
                        supports_streaming: true,
                        supports_system_message: true,
                        supports_vision,
                        supports_thinking,
                        ..Default::default()
                    },
                    ..Default::default()
                }
            })
            .collect();
        ProviderConfig {
            name: XAI_PROVIDER_NAME.to_string(),
            display_name: "xAI Grok".to_string(),
            provider_type: ConfigProviderType::OpenAICompatible,
            api_key_env: Some("XAI_API_KEY".to_string()),
            base_url: Some(base_url.unwrap_or(XAI_BASE_URL).to_string()),
            base_url_env: Some("XAI_BASE_URL".to_string()),
            default_llm_model: Some(model.to_string()),
            default_embedding_model: None,
            models,
            headers: std::collections::HashMap::new(),
            enabled: true,
            timeout_seconds: 600,
            ..Default::default()
        }
    }

    /// Context window (in tokens) for `model`; unknown models fall back to
    /// 262_144 (the Grok 4 window).
    pub fn context_length(model: &str) -> usize {
        XAI_MODELS
            .iter()
            .find(|(name, _, _)| *name == model)
            .map(|(_, _, ctx)| *ctx)
            .unwrap_or(262_144)
    }

    /// All known models as `(name, display name, context length)` tuples.
    pub fn available_models() -> Vec<(&'static str, &'static str, usize)> {
        XAI_MODELS.to_vec()
    }

    /// Whether `model` runs with reasoning enabled.
    ///
    /// Every Grok 4 family model reasons by default unless it carries the
    /// explicit `-non-reasoning` suffix; Grok 3 and earlier never reason.
    pub fn is_reasoning_model(model: &str) -> bool {
        model.starts_with("grok-4") && !model.ends_with("-non-reasoning")
    }

    /// Strips options that xAI reasoning models reject.
    ///
    /// xAI returns HTTP 400 when `presence_penalty`, `frequency_penalty`,
    /// `stop`, or `reasoning_effort` is sent to a reasoning model, so these
    /// fields are cleared (with a warning) before the request is built.
    /// All other fields are copied through unchanged.
    fn filter_for_reasoning(options: &CompletionOptions) -> CompletionOptions {
        // Fix: the original duplicated this check as a second `if` that was
        // always true after the early return below — collapsed to one test.
        let has_prohibited = options.presence_penalty.is_some()
            || options.frequency_penalty.is_some()
            || options.stop.is_some()
            || options.reasoning_effort.is_some();
        if !has_prohibited {
            // Nothing to strip; hand back an identical copy.
            return options.clone();
        }
        // Fix: the field was mislabeled `model = %"xai"` — "xai" is the
        // provider name, not a model, so log it under `provider`.
        warn!(
            provider = XAI_PROVIDER_NAME,
            "Stripping presence_penalty / frequency_penalty / stop / reasoning_effort \
             from options — these are not supported by xAI reasoning models and would \
             cause a HTTP 400 error. Use a *-non-reasoning model variant to keep them."
        );
        CompletionOptions {
            presence_penalty: None,
            frequency_penalty: None,
            stop: None,
            reasoning_effort: None,
            ..options.clone()
        }
    }

    /// Resolves caller-supplied options into a request-ready value:
    /// defaults when `None`, filtered when the current model reasons,
    /// borrowed as-is otherwise (avoiding a clone on the common path).
    fn resolve_options<'o>(
        &self,
        options: Option<&'o CompletionOptions>,
    ) -> std::borrow::Cow<'o, CompletionOptions> {
        match options {
            None => std::borrow::Cow::Owned(CompletionOptions::default()),
            Some(opts) if Self::is_reasoning_model(&self.model) => {
                std::borrow::Cow::Owned(Self::filter_for_reasoning(opts))
            }
            Some(opts) => std::borrow::Cow::Borrowed(opts),
        }
    }
}
#[async_trait]
impl LLMProvider for XAIProvider {
    /// Static provider identifier: `"xai"`.
    fn name(&self) -> &str {
        XAI_PROVIDER_NAME
    }

    /// Currently configured model name.
    fn model(&self) -> &str {
        &self.model
    }

    /// Context window of the configured model, in tokens.
    fn max_context_length(&self) -> usize {
        Self::context_length(&self.model)
    }

    /// Completes `prompt` using default options.
    async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
        let defaults = CompletionOptions::default();
        self.complete_with_options(prompt, &defaults).await
    }

    /// Completes `prompt`, filtering any options the active model rejects
    /// before delegating to the transport.
    async fn complete_with_options(
        &self,
        prompt: &str,
        options: &CompletionOptions,
    ) -> Result<LLMResponse> {
        let opts = self.resolve_options(Some(options));
        self.inner.complete_with_options(prompt, &opts).await
    }

    /// Multi-turn chat; options are filtered for reasoning models.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        let opts = self.resolve_options(options);
        self.inner.chat(messages, Some(&opts)).await
    }

    /// Chat with tool/function definitions; options filtered as in `chat`.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        let opts = self.resolve_options(options);
        self.inner
            .chat_with_tools(messages, tools, tool_choice, Some(&opts))
            .await
    }

    /// Streaming variant of `chat_with_tools`.
    async fn chat_with_tools_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<BoxStream<'static, Result<StreamChunk>>> {
        // Take ownership of the resolved options for the call.
        let opts = self.resolve_options(options).into_owned();
        self.inner
            .chat_with_tools_stream(messages, tools, tool_choice, Some(&opts))
            .await
    }

    /// Streams a plain-text completion. No filtering needed: this path
    /// carries no completion options.
    async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
        self.inner.stream(prompt).await
    }

    /// Delegates to the underlying transport's capability flag.
    fn supports_function_calling(&self) -> bool {
        self.inner.supports_function_calling()
    }

    /// Delegates to the underlying transport's capability flag.
    fn supports_tool_streaming(&self) -> bool {
        self.inner.supports_tool_streaming()
    }
}
#[async_trait]
impl EmbeddingProvider for XAIProvider {
    /// Static provider identifier: `"xai"`.
    fn name(&self) -> &str {
        XAI_PROVIDER_NAME
    }

    /// xAI exposes no embedding models, so a placeholder name is reported.
    fn model(&self) -> &str {
        "none"
    }

    /// No embedding model is available, hence zero dimensions.
    fn dimension(&self) -> usize {
        0
    }

    /// No embedding model is available, hence zero input capacity.
    fn max_tokens(&self) -> usize {
        0
    }

    /// Always fails: xAI has no embeddings endpoint.
    ///
    /// # Errors
    /// Returns [`LlmError::ConfigError`] unconditionally, pointing callers
    /// at providers that do offer embeddings.
    async fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(LlmError::ConfigError(
            "xAI does not provide an embeddings API. \
             Use OpenAI or another provider for embeddings."
                .to_string(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Serializes the tests that mutate process-global environment variables.
    ///
    /// `cargo test` runs tests on multiple threads by default, and
    /// `std::env::set_var` / `remove_var` affect the whole process, so the
    /// two `from_env` tests below would race with each other (one removes
    /// `XAI_API_KEY` while the other sets it) without this lock. Poisoning
    /// is recovered from so one failing test does not cascade.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn test_provider_name_constant() {
        assert_eq!(XAI_PROVIDER_NAME, "xai");
    }

    #[test]
    fn test_default_model_constant() {
        assert_eq!(XAI_DEFAULT_MODEL, "grok-4.20");
    }

    #[test]
    fn test_default_base_url_constant() {
        assert_eq!(XAI_BASE_URL, "https://api.x.ai/v1");
    }

    #[test]
    fn test_context_length_grok420_series() {
        assert_eq!(XAIProvider::context_length("grok-4.20"), 2_000_000);
        assert_eq!(XAIProvider::context_length("grok-4.20-latest"), 2_000_000);
        assert_eq!(
            XAIProvider::context_length("grok-4.20-reasoning"),
            2_000_000
        );
        assert_eq!(
            XAIProvider::context_length("grok-4.20-non-reasoning"),
            2_000_000
        );
        assert_eq!(XAIProvider::context_length("grok-4.20-0309"), 2_000_000);
        assert_eq!(
            XAIProvider::context_length("grok-4.20-0309-reasoning"),
            2_000_000
        );
        assert_eq!(
            XAIProvider::context_length("grok-4.20-0309-non-reasoning"),
            2_000_000
        );
        assert_eq!(
            XAIProvider::context_length("grok-4.20-multi-agent"),
            2_000_000
        );
        assert_eq!(
            XAIProvider::context_length("grok-4.20-multi-agent-0309"),
            2_000_000
        );
    }

    #[test]
    fn test_context_length_grok4_series() {
        assert_eq!(XAIProvider::context_length("grok-4"), 262_144);
        assert_eq!(XAIProvider::context_length("grok-4-0709"), 262_144);
        assert_eq!(XAIProvider::context_length("grok-4-latest"), 262_144);
    }

    #[test]
    fn test_context_length_grok41_fast_series() {
        assert_eq!(XAIProvider::context_length("grok-4-1-fast"), 2_000_000);
        assert_eq!(
            XAIProvider::context_length("grok-4-1-fast-reasoning"),
            2_000_000
        );
        assert_eq!(
            XAIProvider::context_length("grok-4-1-fast-non-reasoning"),
            2_000_000
        );
    }

    #[test]
    fn test_context_length_grok3_series() {
        assert_eq!(XAIProvider::context_length("grok-3"), 131_072);
        assert_eq!(XAIProvider::context_length("grok-3-latest"), 131_072);
        assert_eq!(XAIProvider::context_length("grok-3-mini"), 131_072);
        assert_eq!(XAIProvider::context_length("grok-3-mini-latest"), 131_072);
    }

    #[test]
    fn test_context_length_specialized_models() {
        assert_eq!(XAIProvider::context_length("grok-2-vision-1212"), 32_768);
        assert_eq!(XAIProvider::context_length("grok-code-fast-1"), 131_072);
    }

    #[test]
    fn test_context_length_unknown_model_defaults_256k() {
        assert_eq!(XAIProvider::context_length("grok-unknown"), 262_144);
        assert_eq!(XAIProvider::context_length("custom-model"), 262_144);
    }

    #[test]
    fn test_is_reasoning_model_grok4_base() {
        assert!(XAIProvider::is_reasoning_model("grok-4"));
        assert!(XAIProvider::is_reasoning_model("grok-4-0709"));
        assert!(XAIProvider::is_reasoning_model("grok-4-latest"));
    }

    #[test]
    fn test_is_reasoning_model_grok420_series() {
        assert!(XAIProvider::is_reasoning_model("grok-4.20"));
        assert!(XAIProvider::is_reasoning_model("grok-4.20-latest"));
        assert!(XAIProvider::is_reasoning_model("grok-4.20-reasoning"));
        assert!(XAIProvider::is_reasoning_model("grok-4.20-0309"));
        assert!(XAIProvider::is_reasoning_model("grok-4.20-0309-reasoning"));
        assert!(XAIProvider::is_reasoning_model("grok-4.20-multi-agent"));
        assert!(XAIProvider::is_reasoning_model(
            "grok-4.20-multi-agent-0309"
        ));
    }

    #[test]
    fn test_is_reasoning_model_grok420_non_reasoning() {
        assert!(!XAIProvider::is_reasoning_model("grok-4.20-non-reasoning"));
        assert!(!XAIProvider::is_reasoning_model(
            "grok-4.20-0309-non-reasoning"
        ));
    }

    #[test]
    fn test_is_reasoning_model_grok41_fast() {
        assert!(XAIProvider::is_reasoning_model("grok-4-1-fast"));
        assert!(XAIProvider::is_reasoning_model("grok-4-1-fast-reasoning"));
        assert!(!XAIProvider::is_reasoning_model(
            "grok-4-1-fast-non-reasoning"
        ));
    }

    #[test]
    fn test_is_reasoning_model_grok3_series_not_reasoning() {
        assert!(!XAIProvider::is_reasoning_model("grok-3"));
        assert!(!XAIProvider::is_reasoning_model("grok-3-latest"));
        assert!(!XAIProvider::is_reasoning_model("grok-3-mini"));
        assert!(!XAIProvider::is_reasoning_model("grok-2-vision-1212"));
        assert!(!XAIProvider::is_reasoning_model("grok-code-fast-1"));
    }

    #[test]
    fn test_filter_for_reasoning_strips_prohibited_fields() {
        let opts = CompletionOptions {
            temperature: Some(0.7),
            max_tokens: Some(1000),
            presence_penalty: Some(0.5),
            frequency_penalty: Some(0.3),
            stop: Some(vec!["END".to_string()]),
            reasoning_effort: Some("high".to_string()),
            ..Default::default()
        };
        let filtered = XAIProvider::filter_for_reasoning(&opts);
        assert!(filtered.presence_penalty.is_none());
        assert!(filtered.frequency_penalty.is_none());
        assert!(filtered.stop.is_none());
        assert!(filtered.reasoning_effort.is_none());
        assert_eq!(filtered.temperature, Some(0.7));
        assert_eq!(filtered.max_tokens, Some(1000));
    }

    #[test]
    fn test_filter_for_reasoning_noop_when_clean() {
        let opts = CompletionOptions {
            temperature: Some(0.5),
            max_tokens: Some(512),
            ..Default::default()
        };
        let filtered = XAIProvider::filter_for_reasoning(&opts);
        assert_eq!(filtered.temperature, Some(0.5));
        assert_eq!(filtered.max_tokens, Some(512));
        assert!(filtered.presence_penalty.is_none());
        assert!(filtered.frequency_penalty.is_none());
        assert!(filtered.stop.is_none());
        assert!(filtered.reasoning_effort.is_none());
    }

    #[test]
    fn test_filter_for_reasoning_preserves_system_prompt_and_format() {
        let opts = CompletionOptions {
            system_prompt: Some("You are helpful.".to_string()),
            response_format: Some("json_object".to_string()),
            temperature: Some(0.0),
            frequency_penalty: Some(1.0),
            ..Default::default()
        };
        let filtered = XAIProvider::filter_for_reasoning(&opts);
        assert_eq!(filtered.system_prompt, Some("You are helpful.".to_string()));
        assert_eq!(filtered.response_format, Some("json_object".to_string()));
        assert_eq!(filtered.temperature, Some(0.0));
        assert!(filtered.frequency_penalty.is_none());
    }

    #[test]
    fn test_available_models_contains_grok420_series() {
        let models = XAIProvider::available_models();
        let names: Vec<&str> = models.iter().map(|(n, _, _)| *n).collect();
        assert!(names.contains(&"grok-4.20"), "missing grok-4.20");
        assert!(
            names.contains(&"grok-4.20-latest"),
            "missing grok-4.20-latest"
        );
        assert!(
            names.contains(&"grok-4.20-reasoning"),
            "missing grok-4.20-reasoning"
        );
        assert!(
            names.contains(&"grok-4.20-non-reasoning"),
            "missing grok-4.20-non-reasoning"
        );
        assert!(names.contains(&"grok-4.20-0309"), "missing grok-4.20-0309");
        assert!(
            names.contains(&"grok-4.20-0309-reasoning"),
            "missing grok-4.20-0309-reasoning"
        );
        assert!(
            names.contains(&"grok-4.20-0309-non-reasoning"),
            "missing grok-4.20-0309-non-reasoning"
        );
        assert!(
            names.contains(&"grok-4.20-multi-agent"),
            "missing grok-4.20-multi-agent"
        );
        assert!(
            names.contains(&"grok-4.20-multi-agent-0309"),
            "missing grok-4.20-multi-agent-0309"
        );
    }

    #[test]
    fn test_available_models_contains_all_legacy_series() {
        let models = XAIProvider::available_models();
        let names: Vec<&str> = models.iter().map(|(n, _, _)| *n).collect();
        assert!(names.contains(&"grok-4"));
        assert!(names.contains(&"grok-4-0709"));
        assert!(names.contains(&"grok-4-latest"));
        assert!(names.contains(&"grok-4-1-fast"));
        assert!(names.contains(&"grok-4-1-fast-reasoning"));
        assert!(names.contains(&"grok-4-1-fast-non-reasoning"));
        assert!(names.contains(&"grok-3"));
        assert!(names.contains(&"grok-3-mini"));
        assert!(names.contains(&"grok-2-vision-1212"));
        assert!(names.contains(&"grok-code-fast-1"));
    }

    #[test]
    fn test_available_models_all_have_positive_context_length() {
        for (name, _desc, ctx) in XAIProvider::available_models() {
            assert!(ctx > 0, "Model '{}' has zero context length", name);
        }
    }

    #[test]
    fn test_build_config_defaults() {
        let config = XAIProvider::build_config("test-key", "grok-4.20", None);
        assert_eq!(config.name, "xai");
        assert_eq!(config.display_name, "xAI Grok");
        assert_eq!(config.base_url, Some("https://api.x.ai/v1".to_string()));
        assert_eq!(config.api_key_env, Some("XAI_API_KEY".to_string()));
        assert_eq!(config.default_llm_model, Some("grok-4.20".to_string()));
        assert!(config.enabled);
        assert_eq!(config.timeout_seconds, 600);
    }

    #[test]
    fn test_build_config_custom_base_url() {
        let config = XAIProvider::build_config("test-key", "grok-3", Some("https://custom.api"));
        assert_eq!(config.base_url, Some("https://custom.api".to_string()));
        assert_eq!(config.default_llm_model, Some("grok-3".to_string()));
    }

    #[test]
    fn test_build_config_model_cards_not_empty() {
        let config = XAIProvider::build_config("test-key", "grok-4.20", None);
        assert!(!config.models.is_empty());
        let card = config.models.iter().find(|m| m.name == "grok-4.20");
        assert!(card.is_some(), "grok-4.20 model card missing");
        let card = card.unwrap();
        assert!(card.capabilities.supports_function_calling);
        assert!(card.capabilities.supports_json_mode);
        assert!(card.capabilities.supports_streaming);
        assert!(card.capabilities.supports_vision);
        assert!(card.capabilities.supports_thinking);
    }

    #[test]
    fn test_build_config_non_reasoning_model_card_no_thinking() {
        let config = XAIProvider::build_config("test-key", "grok-4.20-non-reasoning", None);
        let card = config
            .models
            .iter()
            .find(|m| m.name == "grok-4.20-non-reasoning");
        assert!(card.is_some());
        let card = card.unwrap();
        assert!(!card.capabilities.supports_thinking);
    }

    #[test]
    fn test_build_config_grok3_model_card_no_thinking() {
        let config = XAIProvider::build_config("test-key", "grok-3", None);
        let card = config.models.iter().find(|m| m.name == "grok-3");
        assert!(card.is_some());
        let card = card.unwrap();
        assert!(!card.capabilities.supports_thinking);
        assert!(!card.capabilities.supports_vision);
    }

    #[test]
    fn test_from_env_missing_api_key() {
        // Hold the env lock so the sibling test cannot set XAI_API_KEY
        // between our remove_var and from_env calls.
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        std::env::remove_var("XAI_API_KEY");
        std::env::remove_var("XAI_MODEL");
        std::env::remove_var("XAI_BASE_URL");
        let result = XAIProvider::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("XAI_API_KEY"),
            "Error should mention XAI_API_KEY, got: {}",
            err
        );
    }

    #[test]
    fn test_from_env_empty_api_key_rejected() {
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        std::env::set_var("XAI_API_KEY", "");
        std::env::remove_var("XAI_MODEL");
        std::env::remove_var("XAI_BASE_URL");
        let result = XAIProvider::from_env();
        // Clean up before asserting so a failure does not leak the empty key.
        std::env::remove_var("XAI_API_KEY");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("empty"),
            "Error should mention 'empty', got: {}",
            err
        );
    }
}