use std::sync::Arc;
use crate::ToolCall;
#[cfg(feature = "xai")]
use crate::{
builder::LLMBackend,
chat::{
ChatMessage, ChatProvider, ChatResponse, ChatRole, StructuredOutputFormat, Tool, Usage,
},
completion::{CompletionProvider, CompletionRequest, CompletionResponse},
embedding::EmbeddingProvider,
error::LLMError,
models::{ModelListRequest, ModelListResponse, ModelsProvider, StandardModelListResponse},
stt::SpeechToTextProvider,
tts::TextToSpeechProvider,
LLMProvider,
};
use async_trait::async_trait;
use futures::stream::Stream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// Configuration for the X.AI (Grok) provider.
///
/// Holds the credentials, model selection, sampling parameters, embedding
/// options, and the live-search settings forwarded as `search_parameters`.
#[derive(Debug)]
pub struct XAIConfig {
    /// Bearer token used against `https://api.x.ai/v1`.
    pub api_key: String,
    /// Model identifier used for chat and embedding requests.
    pub model: String,
    /// Maximum number of tokens to generate, when set.
    pub max_tokens: Option<u32>,
    /// Sampling temperature, when set.
    pub temperature: Option<f32>,
    /// System prompt prepended as the first chat message, when set.
    pub system: Option<String>,
    /// Per-request timeout in seconds; also applied when `new` builds the
    /// default HTTP client.
    pub timeout_seconds: Option<u64>,
    /// Nucleus-sampling parameter, when set.
    pub top_p: Option<f32>,
    /// Top-k sampling parameter, when set.
    pub top_k: Option<u32>,
    /// Embedding encoding format; defaults to "float" at request time.
    pub embedding_encoding_format: Option<String>,
    /// Requested embedding dimensionality, when set.
    pub embedding_dimensions: Option<u32>,
    /// Structured-output (JSON schema) configuration, when set.
    pub json_schema: Option<StructuredOutputFormat>,
    /// Live-search mode string, forwarded verbatim to the API.
    pub xai_search_mode: Option<String>,
    /// Live-search source type; defaults to "web" at request time.
    pub xai_search_source_type: Option<String>,
    /// Websites excluded from live search, when set.
    pub xai_search_excluded_websites: Option<Vec<String>>,
    /// Maximum number of live-search results, when set.
    pub xai_search_max_results: Option<u32>,
    /// Live-search start date string, forwarded verbatim.
    pub xai_search_from_date: Option<String>,
    /// Live-search end date string, forwarded verbatim.
    pub xai_search_to_date: Option<String>,
}
/// Client for the X.AI HTTP API (chat, embeddings, model listing).
///
/// Cheap to clone: the configuration is shared behind an `Arc`, and
/// `reqwest::Client` is internally reference-counted.
#[derive(Debug, Clone)]
pub struct XAI {
    /// Shared provider configuration.
    pub config: Arc<XAIConfig>,
    /// HTTP client used for all requests.
    pub client: Client,
}
/// Error message used when a chat message contains audio content, which this
/// provider rejects up front (see `ensure_no_audio` calls below).
const AUDIO_UNSUPPORTED: &str = "XAI does not support audio chat messages";
/// One entry of `search_parameters.sources` in a chat request.
#[derive(Debug, Clone, serde::Serialize)]
pub struct XaiSearchSource {
    /// Source kind (e.g. "web"), serialized as `type`.
    #[serde(rename = "type")]
    pub source_type: String,
    /// Websites to exclude from this source. Omitted from the payload when
    /// `None` instead of being serialized as `null`, consistent with the
    /// optional fields of the other request structs in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub excluded_websites: Option<Vec<String>>,
}
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct XaiSearchParameters {
pub mode: Option<String>,
pub sources: Option<Vec<XaiSearchSource>>,
pub max_search_results: Option<u32>,
pub from_date: Option<String>,
pub to_date: Option<String>,
}
/// A single chat message in X.AI wire format; borrows from the caller's
/// messages and the provider config to avoid copies.
#[derive(Serialize)]
struct XAIChatMessage<'a> {
    /// "system", "user", or "assistant".
    role: &'a str,
    content: &'a str,
}
/// JSON body for `POST /v1/chat/completions`.
///
/// `None` optionals are omitted from the payload.
#[derive(Serialize)]
struct XAIChatRequest<'a> {
    /// Model identifier, e.g. "grok-2-latest".
    model: &'a str,
    /// Conversation in order; the system prompt (if any) comes first.
    messages: Vec<XAIChatMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// `true` requests an SSE stream instead of a single JSON response.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    /// Structured-output (JSON schema) request, when configured.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<XAIResponseFormat>,
    /// Live-search settings; `None` for streaming requests.
    #[serde(skip_serializing_if = "Option::is_none")]
    search_parameters: Option<&'a XaiSearchParameters>,
}
/// Deserialized body of a non-streaming chat completion response.
#[derive(Deserialize, Debug)]
struct XAIChatResponse {
    choices: Vec<XAIChatChoice>,
    /// Token accounting; may be absent from the response.
    usage: Option<Usage>,
}
impl std::fmt::Display for XAIChatResponse {
    /// Renders the first choice's text, or nothing when there is no choice.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = self.text().unwrap_or_default();
        f.write_str(&text)
    }
}
impl ChatResponse for XAIChatResponse {
    /// Text content of the first returned choice, if any.
    fn text(&self) -> Option<String> {
        let first = self.choices.first()?;
        Some(first.message.content.clone())
    }
    /// X.AI tool calls are not surfaced by this provider.
    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
        None
    }
    /// Token usage as reported by the API, if present.
    fn usage(&self) -> Option<Usage> {
        self.usage.as_ref().cloned()
    }
}
/// One completion choice in a non-streaming response.
#[derive(Deserialize, Debug)]
struct XAIChatChoice {
    message: XAIChatMsg,
}
/// The assistant message inside a choice.
#[derive(Deserialize, Debug)]
struct XAIChatMsg {
    content: String,
}
/// JSON body for `POST /v1/embeddings`.
#[derive(Debug, Serialize)]
struct XAIEmbeddingRequest<'a> {
    model: &'a str,
    /// Texts to embed, one vector per entry in the response.
    input: Vec<String>,
    /// e.g. "float"; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<&'a str>,
    /// Requested vector dimensionality; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
/// One embedding vector in the embeddings response.
#[derive(Deserialize)]
struct XAIEmbeddingData {
    embedding: Vec<f32>,
}
/// One SSE `data:` payload from the streaming chat endpoint.
#[derive(Deserialize, Debug)]
struct XAIStreamResponse {
    choices: Vec<XAIStreamChoice>,
}
/// A streamed choice carrying an incremental delta.
#[derive(Deserialize, Debug)]
struct XAIStreamChoice {
    delta: XAIStreamDelta,
}
/// Incremental message fragment; `content` may be absent on some frames.
#[derive(Deserialize, Debug)]
struct XAIStreamDelta {
    content: Option<String>,
}
/// Deserialized body of an embeddings response.
#[derive(Deserialize)]
struct XAIEmbeddingResponse {
    data: Vec<XAIEmbeddingData>,
}
/// Wire value for `response_format.type`.
#[derive(Deserialize, Debug, Serialize)]
enum XAIResponseType {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_schema")]
    JsonSchema,
    #[serde(rename = "json_object")]
    JsonObject,
}
/// The `response_format` object of a chat request (structured output).
#[derive(Deserialize, Debug, Serialize)]
struct XAIResponseFormat {
    /// Output kind, serialized as `type`.
    #[serde(rename = "type")]
    response_type: XAIResponseType,
    /// Schema payload; only populated alongside `JsonSchema` in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}
impl XAI {
    /// Creates a provider with a freshly built `reqwest::Client`.
    ///
    /// When `timeout_seconds` is set it is applied to the client as a whole
    /// (and also re-applied per request by the trait methods).
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be constructed.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        api_key: impl Into<String>,
        model: Option<String>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        timeout_seconds: Option<u64>,
        system: Option<String>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        embedding_encoding_format: Option<String>,
        embedding_dimensions: Option<u32>,
        json_schema: Option<StructuredOutputFormat>,
        xai_search_mode: Option<String>,
        xai_search_source_type: Option<String>,
        xai_search_excluded_websites: Option<Vec<String>>,
        xai_search_max_results: Option<u32>,
        xai_search_from_date: Option<String>,
        xai_search_to_date: Option<String>,
    ) -> Self {
        let mut builder = Client::builder();
        if let Some(sec) = timeout_seconds {
            builder = builder.timeout(std::time::Duration::from_secs(sec));
        }
        Self::with_client(
            builder.build().expect("Failed to build reqwest Client"),
            api_key,
            model,
            max_tokens,
            temperature,
            timeout_seconds,
            system,
            top_p,
            top_k,
            embedding_encoding_format,
            embedding_dimensions,
            json_schema,
            xai_search_mode,
            xai_search_source_type,
            xai_search_excluded_websites,
            xai_search_max_results,
            xai_search_from_date,
            xai_search_to_date,
        )
    }
    /// Like [`XAI::new`] but reuses a caller-supplied `reqwest::Client`
    /// (e.g. to share connection pools or custom TLS settings).
    #[allow(clippy::too_many_arguments)]
    pub fn with_client(
        client: Client,
        api_key: impl Into<String>,
        model: Option<String>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        timeout_seconds: Option<u64>,
        system: Option<String>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        embedding_encoding_format: Option<String>,
        embedding_dimensions: Option<u32>,
        json_schema: Option<StructuredOutputFormat>,
        xai_search_mode: Option<String>,
        xai_search_source_type: Option<String>,
        xai_search_excluded_websites: Option<Vec<String>>,
        xai_search_max_results: Option<u32>,
        xai_search_from_date: Option<String>,
        xai_search_to_date: Option<String>,
    ) -> Self {
        Self {
            config: Arc::new(XAIConfig {
                api_key: api_key.into(),
                // unwrap_or_else so the default String is only allocated when
                // no model was supplied (clippy::or_fun_call).
                model: model.unwrap_or_else(|| "grok-2-latest".to_string()),
                max_tokens,
                temperature,
                system,
                timeout_seconds,
                top_p,
                top_k,
                embedding_encoding_format,
                embedding_dimensions,
                json_schema,
                xai_search_mode,
                xai_search_source_type,
                xai_search_excluded_websites,
                xai_search_max_results,
                xai_search_from_date,
                xai_search_to_date,
            }),
            client,
        }
    }
    /// API key used for bearer authentication.
    pub fn api_key(&self) -> &str {
        &self.config.api_key
    }
    /// Configured model identifier.
    pub fn model(&self) -> &str {
        &self.config.model
    }
    /// Configured generation token cap, if any.
    pub fn max_tokens(&self) -> Option<u32> {
        self.config.max_tokens
    }
    /// Configured sampling temperature, if any.
    pub fn temperature(&self) -> Option<f32> {
        self.config.temperature
    }
    /// Configured per-request timeout in seconds, if any.
    pub fn timeout_seconds(&self) -> Option<u64> {
        self.config.timeout_seconds
    }
    /// Configured system prompt, if any.
    pub fn system(&self) -> Option<&str> {
        self.config.system.as_deref()
    }
    /// Configured nucleus-sampling parameter, if any.
    pub fn top_p(&self) -> Option<f32> {
        self.config.top_p
    }
    /// Configured top-k sampling parameter, if any.
    pub fn top_k(&self) -> Option<u32> {
        self.config.top_k
    }
    /// Configured embedding encoding format, if any.
    pub fn embedding_encoding_format(&self) -> Option<&str> {
        self.config.embedding_encoding_format.as_deref()
    }
    /// Configured embedding dimensionality, if any.
    pub fn embedding_dimensions(&self) -> Option<u32> {
        self.config.embedding_dimensions
    }
    /// Configured structured-output schema, if any.
    pub fn json_schema(&self) -> Option<&StructuredOutputFormat> {
        self.config.json_schema.as_ref()
    }
    /// Underlying HTTP client.
    pub fn client(&self) -> &Client {
        &self.client
    }
}
/// Converts generic chat messages into X.AI wire format, prepending the
/// configured system prompt (if any) as the first message. Shared by
/// `chat` and `chat_stream`, which previously duplicated this logic.
fn build_xai_messages<'a>(
    messages: &'a [ChatMessage],
    system: Option<&'a str>,
) -> Vec<XAIChatMessage<'a>> {
    let mut out = Vec::with_capacity(messages.len() + 1);
    if let Some(system) = system {
        out.push(XAIChatMessage {
            role: "system",
            content: system,
        });
    }
    out.extend(messages.iter().map(|m| XAIChatMessage {
        role: match m.role {
            ChatRole::User => "user",
            ChatRole::Assistant => "assistant",
        },
        content: &m.content,
    }));
    out
}
#[async_trait]
impl ChatProvider for XAI {
    /// Sends a non-streaming chat completion request to
    /// `https://api.x.ai/v1/chat/completions`.
    ///
    /// Rejects audio messages up front, applies the configured system
    /// prompt, structured-output format, and live-search parameters, and
    /// honors `timeout_seconds` per request.
    ///
    /// # Errors
    /// `LLMError::AuthError` when the API key is empty; HTTP and
    /// deserialization errors otherwise.
    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
        crate::chat::ensure_no_audio(messages, AUDIO_UNSUPPORTED)?;
        if self.config.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
        }
        let xai_msgs = build_xai_messages(messages, self.config.system.as_deref());
        let response_format: Option<XAIResponseFormat> =
            self.config.json_schema.as_ref().map(|s| XAIResponseFormat {
                response_type: XAIResponseType::JsonSchema,
                json_schema: Some(s.clone()),
            });
        let search_parameters = XaiSearchParameters {
            mode: self.config.xai_search_mode.clone(),
            sources: Some(vec![XaiSearchSource {
                // unwrap_or_else so "web" is only allocated when no source
                // type was configured (clippy::or_fun_call).
                source_type: self
                    .config
                    .xai_search_source_type
                    .clone()
                    .unwrap_or_else(|| "web".to_string()),
                excluded_websites: self.config.xai_search_excluded_websites.clone(),
            }]),
            max_search_results: self.config.xai_search_max_results,
            from_date: self.config.xai_search_from_date.clone(),
            to_date: self.config.xai_search_to_date.clone(),
        };
        let body = XAIChatRequest {
            model: &self.config.model,
            messages: xai_msgs,
            max_tokens: self.config.max_tokens,
            temperature: self.config.temperature,
            stream: false,
            top_p: self.config.top_p,
            top_k: self.config.top_k,
            response_format,
            search_parameters: Some(&search_parameters),
        };
        if log::log_enabled!(log::Level::Trace) {
            if let Ok(json) = serde_json::to_string(&body) {
                log::trace!("XAI request payload: {}", json);
            }
        }
        let mut request = self
            .client
            .post("https://api.x.ai/v1/chat/completions")
            .bearer_auth(&self.config.api_key)
            .json(&body);
        if let Some(timeout) = self.config.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }
        let resp = request.send().await?;
        log::debug!("XAI HTTP status: {}", resp.status());
        let resp = resp.error_for_status()?;
        let json_resp: XAIChatResponse = resp.json().await?;
        Ok(Box::new(json_resp))
    }
    /// Tools are not supported by this provider; delegates to [`Self::chat`]
    /// and ignores `_tools`.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        _tools: Option<&[Tool]>,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        self.chat(messages).await
    }
    /// Sends a streaming chat request and returns a stream of text deltas
    /// parsed from the SSE response.
    ///
    /// Note: unlike [`Self::chat`], no `response_format` or
    /// `search_parameters` are sent on the streaming path.
    ///
    /// # Errors
    /// `LLMError::AuthError` when the API key is empty;
    /// `LLMError::ResponseFormatError` on a non-success HTTP status.
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
    {
        crate::chat::ensure_no_audio(messages, AUDIO_UNSUPPORTED)?;
        if self.config.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
        }
        let xai_msgs = build_xai_messages(messages, self.config.system.as_deref());
        let body = XAIChatRequest {
            model: &self.config.model,
            messages: xai_msgs,
            max_tokens: self.config.max_tokens,
            temperature: self.config.temperature,
            stream: true,
            top_p: self.config.top_p,
            top_k: self.config.top_k,
            response_format: None,
            search_parameters: None,
        };
        let mut request = self
            .client
            .post("https://api.x.ai/v1/chat/completions")
            .bearer_auth(&self.config.api_key)
            .json(&body);
        if let Some(timeout) = self.config.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }
        let response = request.send().await?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("X.AI API returned error status: {status}"),
                raw_response: error_text,
            });
        }
        Ok(crate::chat::create_sse_stream(
            response,
            parse_xai_sse_chunk,
        ))
    }
}
#[async_trait]
impl CompletionProvider for XAI {
    /// Text completion is not implemented for X.AI; returns a fixed
    /// placeholder without contacting the API.
    async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
        let text = String::from("X.AI completion not implemented.");
        Ok(CompletionResponse { text })
    }
}
#[async_trait]
impl EmbeddingProvider for XAI {
    /// Requests one embedding vector per input string from
    /// `https://api.x.ai/v1/embeddings`.
    ///
    /// Uses the configured model, encoding format (defaulting to "float"),
    /// and optional dimensionality. Now also applies the configured
    /// `timeout_seconds` per request, matching the chat and model-listing
    /// endpoints (previously only the client-level timeout from `new`
    /// applied, so `with_client` callers had none).
    ///
    /// # Errors
    /// `LLMError::AuthError` when the API key is empty; HTTP and
    /// deserialization errors otherwise.
    async fn embed(&self, text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        if self.config.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing X.AI API key".into()));
        }
        let emb_format = self
            .config
            .embedding_encoding_format
            .clone()
            .unwrap_or_else(|| "float".to_string());
        let body = XAIEmbeddingRequest {
            model: &self.config.model,
            input: text,
            encoding_format: Some(&emb_format),
            dimensions: self.config.embedding_dimensions,
        };
        let mut request = self
            .client
            .post("https://api.x.ai/v1/embeddings")
            .bearer_auth(&self.config.api_key)
            .json(&body);
        if let Some(timeout) = self.config.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }
        let resp = request.send().await?.error_for_status()?;
        let json_resp: XAIEmbeddingResponse = resp.json().await?;
        Ok(json_resp.data.into_iter().map(|d| d.embedding).collect())
    }
}
#[async_trait]
impl SpeechToTextProvider for XAI {
    /// Speech-to-text is not implemented; always returns a `ProviderError`.
    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
        Err(LLMError::ProviderError(
            "XAI does not implement speech to text endpoint yet.".into(),
        ))
    }
}
// Text-to-speech is unsupported; the trait's default (error-returning)
// methods are used as-is.
#[async_trait]
impl TextToSpeechProvider for XAI {}
#[async_trait]
impl ModelsProvider for XAI {
    /// Fetches the available model list from `https://api.x.ai/v1/models`,
    /// honoring the configured per-request timeout. The `_request` filter
    /// argument is ignored by this provider.
    ///
    /// # Errors
    /// `LLMError::AuthError` when the API key is empty; HTTP and
    /// deserialization errors otherwise.
    async fn list_models(
        &self,
        _request: Option<&ModelListRequest>,
    ) -> Result<Box<dyn ModelListResponse>, LLMError> {
        if self.config.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing X.AI API key".to_string()));
        }
        let mut req = self
            .client
            .get("https://api.x.ai/v1/models")
            .bearer_auth(&self.config.api_key);
        if let Some(secs) = self.config.timeout_seconds {
            req = req.timeout(std::time::Duration::from_secs(secs));
        }
        let response = req.send().await?.error_for_status()?;
        let inner = response.json().await?;
        Ok(Box::new(StandardModelListResponse {
            inner,
            backend: LLMBackend::XAI,
        }))
    }
}
// Marker impl: XAI satisfies the aggregate `LLMProvider` trait via the
// sub-provider impls above.
impl LLMProvider for XAI {}
/// Parses one SSE chunk from the X.AI streaming chat endpoint into a text
/// delta.
///
/// A single network chunk may carry several `data: {...}` events; this
/// concatenates the content deltas of all of them so none are silently
/// dropped (the previous version returned after the first event that had
/// content). Lines that fail to deserialize are skipped. Returns `Ok(None)`
/// when the chunk contributes no text (keep-alives, `[DONE]`, role/finish
/// frames).
fn parse_xai_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
    let mut collected = String::new();
    for line in chunk.lines() {
        let line = line.trim();
        if let Some(data) = line.strip_prefix("data: ") {
            if data == "[DONE]" {
                // End-of-stream marker; nothing follows within this chunk.
                break;
            }
            if let Ok(response) = serde_json::from_str::<XAIStreamResponse>(data) {
                if let Some(choice) = response.choices.first() {
                    if let Some(content) = &choice.delta.content {
                        collected.push_str(content);
                    }
                }
            }
        }
    }
    if collected.is_empty() {
        Ok(None)
    } else {
        Ok(Some(collected))
    }
}