use async_trait::async_trait;
use serde::de::DeserializeOwned;
use std::str::FromStr;
use std::time::Duration;
use tracing::{debug, error, info, instrument, trace, warn};
use crate::backend::{
ChatMessage, GenerateResult, LLMClient, MaterializeInternalOutput, MaterializeResult,
ModelInfo, OpenAICompatibleChatCompletionRequest, OpenAICompatibleChatCompletionResponse,
OpenAICompatibleChatMessage, OpenAICompatibleMessageContent, ResponseFormat, TokenUsage,
ValidationFailureContext, check_response_status, convert_openai_compatible_chat_messages,
generate_with_retry_with_history, handle_http_error, materialize_with_media_with_retry,
parse_validate_and_create_output, prepare_strict_schema,
};
use crate::error::{ApiErrorKind, RStructorError, Result};
use crate::model::Instructor;
/// A Grok model identifier.
///
/// Well-known xAI model IDs map to dedicated variants; any other string is
/// kept verbatim in [`Model::Custom`], so models without a variant can still
/// be requested.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Model {
    /// `"grok-4.3"`.
    Grok43,
    /// `"grok-4-0709"`.
    Grok4,
    /// `"grok-4-fast-reasoning"`.
    Grok4FastReasoning,
    /// `"grok-4-fast-non-reasoning"`.
    Grok4FastNonReasoning,
    /// `"grok-4-1-fast-reasoning"`.
    Grok41FastReasoning,
    /// `"grok-4-1-fast-non-reasoning"`.
    Grok41FastNonReasoning,
    /// `"grok-3"`.
    Grok3,
    /// `"grok-3-mini"`.
    Grok3Mini,
    /// `"grok-code-fast-1"`.
    GrokCodeFast1,
    /// `"grok-2-vision-1212"`.
    Grok2Vision,
    /// Any other model identifier, passed through unchanged.
    Custom(String),
}

impl Model {
    /// Returns the model ID string sent to the API.
    ///
    /// For [`Model::Custom`] this borrows the stored identifier.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Grok43 => "grok-4.3",
            Self::Grok4 => "grok-4-0709",
            Self::Grok4FastReasoning => "grok-4-fast-reasoning",
            Self::Grok4FastNonReasoning => "grok-4-fast-non-reasoning",
            Self::Grok41FastReasoning => "grok-4-1-fast-reasoning",
            Self::Grok41FastNonReasoning => "grok-4-1-fast-non-reasoning",
            Self::Grok3 => "grok-3",
            Self::Grok3Mini => "grok-3-mini",
            Self::GrokCodeFast1 => "grok-code-fast-1",
            Self::Grok2Vision => "grok-2-vision-1212",
            Self::Custom(custom) => custom.as_str(),
        }
    }

    /// Maps a model ID string to its variant.
    ///
    /// Unrecognized IDs become [`Model::Custom`] holding the original string.
    /// For every dedicated variant `m`, `Model::from_string(m.as_str()) == m`.
    pub fn from_string(name: impl Into<String>) -> Self {
        let id: String = name.into();
        match &*id {
            "grok-4.3" => Self::Grok43,
            "grok-4-0709" => Self::Grok4,
            "grok-4-fast-reasoning" => Self::Grok4FastReasoning,
            "grok-4-fast-non-reasoning" => Self::Grok4FastNonReasoning,
            "grok-4-1-fast-reasoning" => Self::Grok41FastReasoning,
            "grok-4-1-fast-non-reasoning" => Self::Grok41FastNonReasoning,
            "grok-3" => Self::Grok3,
            "grok-3-mini" => Self::Grok3Mini,
            "grok-code-fast-1" => Self::GrokCodeFast1,
            "grok-2-vision-1212" => Self::Grok2Vision,
            _ => Self::Custom(id),
        }
    }
}

/// Parsing never fails: unknown IDs become [`Model::Custom`].
impl FromStr for Model {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        Ok(Self::from_string(s))
    }
}

impl From<&str> for Model {
    fn from(s: &str) -> Self {
        Self::from_string(s)
    }
}

impl From<String> for Model {
    /// Moves the string into [`Model::Custom`] when the ID is not recognized.
    fn from(s: String) -> Self {
        Self::from_string(s)
    }
}
/// Request configuration for a [`GrokClient`].
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// config (in a log line, `{:?}`, or a panic message) never reveals `api_key`.
#[derive(Clone)]
pub struct GrokConfig {
    /// xAI API key, sent as a `Bearer` token in the `Authorization` header.
    pub api_key: String,
    /// Model requested for every completion.
    pub model: Model,
    /// Sampling temperature forwarded with each completion request.
    pub temperature: f32,
    /// Optional `max_tokens` limit forwarded with each completion request.
    pub max_tokens: Option<u32>,
    /// Optional request timeout.
    // NOTE(review): the request code shown alongside this type never reads this
    // field; confirm where (e.g. the builder macro) it is applied.
    pub timeout: Option<Duration>,
    /// Retry budget handed to the retrying `materialize*` helpers.
    pub max_retries: Option<usize>,
    /// Override for the API base URL; `None` means `https://api.x.ai/v1`.
    pub base_url: Option<String>,
}

impl std::fmt::Debug for GrokConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrokConfig")
            // Never print the secret itself.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .field("base_url", &self.base_url)
            .finish()
    }
}
/// Client for the xAI Grok chat-completions API.
///
/// Bundles the request settings ([`GrokConfig`]) with the `reqwest` HTTP
/// client used to send them; cloning clones both.
#[derive(Clone)]
pub struct GrokClient {
    /// Settings applied to every request.
    config: GrokConfig,
    /// HTTP client used for all API calls.
    client: reqwest::Client,
}
impl GrokClient {
    /// Creates a Grok client with the given API key and default settings.
    ///
    /// Defaults: model [`Model::Grok43`], temperature `0.0`, no `max_tokens`,
    /// no timeout, `max_retries = Some(3)`, and no base-URL override.
    ///
    /// # Errors
    ///
    /// Returns an `ApiErrorKind::AuthenticationFailed` error if `api_key` is empty.
    #[instrument(name = "grok_client_new", skip(api_key), fields(model = ?Model::Grok43))]
    pub fn new(api_key: impl Into<String>) -> Result<Self> {
        let api_key = api_key.into();
        if api_key.is_empty() {
            return Err(RStructorError::api_error(
                "Grok",
                ApiErrorKind::AuthenticationFailed,
            ));
        }
        info!("Creating new Grok client");
        Ok(Self::build_with_default_config(api_key))
    }

    /// Creates a Grok client from the `XAI_API_KEY` environment variable,
    /// using the same defaults as [`GrokClient::new`].
    ///
    /// # Errors
    ///
    /// Returns an `ApiErrorKind::AuthenticationFailed` error if the variable is
    /// unset, not valid Unicode, or empty. The empty-key check matches
    /// [`GrokClient::new`].
    #[instrument(name = "grok_client_from_env", fields(model = ?Model::Grok43))]
    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("XAI_API_KEY")
            .ok()
            .filter(|key| !key.is_empty())
            .ok_or_else(|| RStructorError::api_error("Grok", ApiErrorKind::AuthenticationFailed))?;
        info!("Creating new Grok client from environment variable");
        Ok(Self::build_with_default_config(api_key))
    }

    /// Builds a client around an already-validated `api_key` with the default
    /// configuration shared by both public constructors.
    fn build_with_default_config(api_key: String) -> Self {
        trace!("API key length: {}", api_key.len());
        let config = GrokConfig {
            api_key,
            model: Model::Grok43,
            temperature: 0.0,
            max_tokens: None,
            timeout: None,
            max_retries: Some(3),
            base_url: None,
        };
        debug!("Grok client created with default configuration");
        Self {
            config,
            client: reqwest::Client::new(),
        }
    }
}
impl GrokClient {
    /// Performs one structured-output chat completion against Grok and parses
    /// the reply as `T`.
    ///
    /// `messages` is converted to the OpenAI-compatible wire format and sent
    /// with a strict JSON-schema `response_format` built from `T::schema()`
    /// (named by `T::schema_name()`, falling back to `"output"`).
    ///
    /// # Errors
    ///
    /// Every failure is returned as `(error, context)`. Message conversion,
    /// transport, HTTP status, response decoding, an empty `choices` array,
    /// and missing message content all yield `None` context. Any
    /// [`ValidationFailureContext`] comes from `parse_validate_and_create_output`.
    async fn materialize_internal<T>(
        &self,
        messages: &[ChatMessage],
    ) -> std::result::Result<
        MaterializeInternalOutput<T>,
        (RStructorError, Option<ValidationFailureContext>),
    >
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        info!("Generating structured response with Grok (native structured outputs)");
        let schema = T::schema();
        let schema_name = T::schema_name().unwrap_or_else(|| "output".to_string());
        trace!(schema_name = schema_name, "Retrieved JSON schema for type");
        let schema_json = prepare_strict_schema(&schema);
        let api_messages =
            convert_openai_compatible_chat_messages(messages, "Grok").map_err(|e| (e, None))?;
        let response_format = ResponseFormat::json_schema(schema_name, schema_json, None);
        debug!(
            "Building Grok API request with structured outputs (history_len={})",
            api_messages.len()
        );
        let request = OpenAICompatibleChatCompletionRequest {
            model: self.config.model.as_str().to_string(),
            messages: api_messages,
            response_format: Some(response_format),
            temperature: self.config.temperature,
            max_tokens: self.config.max_tokens,
            reasoning_effort: None,
        };
        // Strip trailing slashes so a base URL such as "https://api.x.ai/v1/"
        // does not produce ".../v1//chat/completions".
        let base_url = self
            .config
            .base_url
            .as_deref()
            .unwrap_or("https://api.x.ai/v1")
            .trim_end_matches('/');
        let url = format!("{}/chat/completions", base_url);
        debug!(url = %url, "Sending request to Grok API with structured outputs");
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| (handle_http_error(e, "Grok"), None))?;
        let response = check_response_status(response, "Grok")
            .await
            .map_err(|e| (e, None))?;
        debug!("Successfully received response from Grok API");
        let completion: OpenAICompatibleChatCompletionResponse =
            response.json().await.map_err(|e| {
                error!(error = %e, "Failed to parse JSON response from Grok API");
                (RStructorError::from(e), None)
            })?;
        let choice = completion.choices.first().ok_or_else(|| {
            error!("Grok API returned empty choices array");
            (
                RStructorError::api_error(
                    "Grok",
                    ApiErrorKind::UnexpectedResponse {
                        details: "No completion choices returned".to_string(),
                    },
                ),
                None,
            )
        })?;
        // Prefer the model name reported by the API; fall back to the configured one.
        let model_name = completion
            .model
            .clone()
            .unwrap_or_else(|| self.config.model.as_str().to_string());
        let usage = completion
            .usage
            .as_ref()
            .map(|u| TokenUsage::new(model_name, u.prompt_tokens, u.completion_tokens));
        trace!(finish_reason = %choice.finish_reason, "Completion finish reason");
        if let Some(content) = &choice.message.content {
            let raw_response = content.clone();
            debug!(
                content_len = raw_response.len(),
                "Received structured output response"
            );
            trace!(json = %raw_response, "Parsing structured output response");
            parse_validate_and_create_output(raw_response, usage)
        } else {
            error!("No content in Grok API response");
            Err((
                RStructorError::api_error(
                    "Grok",
                    ApiErrorKind::UnexpectedResponse {
                        details: "No content in response".to_string(),
                    },
                ),
                None,
            ))
        }
    }
}
// Generates the crate's shared builder-style configuration methods for
// `GrokClient`, operating on its `GrokConfig` and `Model` types; "Grok" is the
// provider name passed to the macro.
// NOTE(review): the exact set of generated methods (and whether it covers
// `timeout`) is defined by `impl_client_builder_methods!`; confirm there.
crate::impl_client_builder_methods! {
client_type: GrokClient,
config_type: GrokConfig,
model_type: Model,
provider_name: "Grok"
}
impl GrokClient {
    /// Overrides the API base URL that request paths (`/chat/completions`,
    /// `/models`) are appended to. Without an override, `https://api.x.ai/v1`
    /// is used.
    ///
    /// Consumes and returns `self` for builder-style chaining.
    #[tracing::instrument(skip(self, base_url))]
    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        let new_url: String = base_url.into();
        tracing::debug!(
            previous_base_url = ?self.config.base_url,
            new_base_url = %new_url,
            "Setting custom base URL"
        );
        self.config.base_url = Some(new_url);
        self
    }
}
#[async_trait]
impl LLMClient for GrokClient {
fn from_env() -> Result<Self> {
Self::from_env()
}
#[instrument(
name = "grok_materialize",
skip(self, prompt),
fields(
type_name = std::any::type_name::<T>(),
model = %self.config.model.as_str(),
prompt_len = prompt.len()
)
)]
async fn materialize<T>(&self, prompt: &str) -> Result<T>
where
T: Instructor + DeserializeOwned + Send + 'static,
{
let output = generate_with_retry_with_history(
|messages: Vec<ChatMessage>| {
let this = self;
async move { this.materialize_internal::<T>(&messages).await }
},
prompt,
self.config.max_retries,
)
.await?;
Ok(output.data)
}
#[instrument(
name = "grok_materialize_with_media",
skip(self, prompt, media),
fields(
type_name = std::any::type_name::<T>(),
model = %self.config.model.as_str(),
prompt_len = prompt.len(),
media_len = media.len()
)
)]
async fn materialize_with_media<T>(&self, prompt: &str, media: &[super::MediaFile]) -> Result<T>
where
T: Instructor + DeserializeOwned + Send + 'static,
{
materialize_with_media_with_retry(
|messages: Vec<ChatMessage>| {
let this = self;
async move { this.materialize_internal::<T>(&messages).await }
},
prompt,
media,
self.config.max_retries,
)
.await
}
#[instrument(
name = "grok_materialize_with_metadata",
skip(self, prompt),
fields(
type_name = std::any::type_name::<T>(),
model = %self.config.model.as_str(),
prompt_len = prompt.len()
)
)]
async fn materialize_with_metadata<T>(&self, prompt: &str) -> Result<MaterializeResult<T>>
where
T: Instructor + DeserializeOwned + Send + 'static,
{
let output = generate_with_retry_with_history(
|messages: Vec<ChatMessage>| {
let this = self;
async move { this.materialize_internal::<T>(&messages).await }
},
prompt,
self.config.max_retries,
)
.await?;
Ok(MaterializeResult::new(output.data, output.usage))
}
#[instrument(
name = "grok_generate",
skip(self, prompt),
fields(
model = %self.config.model.as_str(),
prompt_len = prompt.len()
)
)]
async fn generate(&self, prompt: &str) -> Result<String> {
let result = self.generate_with_metadata(prompt).await?;
Ok(result.text)
}
#[instrument(
name = "grok_generate_with_metadata",
skip(self, prompt),
fields(
model = %self.config.model.as_str(),
prompt_len = prompt.len()
)
)]
async fn generate_with_metadata(&self, prompt: &str) -> Result<GenerateResult> {
info!("Generating raw text response with Grok");
debug!("Building Grok API request for text generation");
let request = OpenAICompatibleChatCompletionRequest {
model: self.config.model.as_str().to_string(),
messages: vec![OpenAICompatibleChatMessage {
role: "user".to_string(),
content: OpenAICompatibleMessageContent::Text(prompt.to_string()),
}],
response_format: None,
temperature: self.config.temperature,
max_tokens: self.config.max_tokens,
reasoning_effort: None,
};
let base_url = self
.config
.base_url
.as_deref()
.unwrap_or("https://api.x.ai/v1");
let url = format!("{}/chat/completions", base_url);
debug!(url = %url, "Sending request to Grok API");
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| handle_http_error(e, "Grok"))?;
let response = check_response_status(response, "Grok").await?;
debug!("Successfully received response from Grok API");
let completion: OpenAICompatibleChatCompletionResponse =
response.json().await.map_err(|e| {
error!(error = %e, "Failed to parse JSON response from Grok API");
e
})?;
if completion.choices.is_empty() {
error!("Grok API returned empty choices array");
return Err(RStructorError::api_error(
"Grok",
ApiErrorKind::UnexpectedResponse {
details: "No completion choices returned".to_string(),
},
));
}
let model_name = completion
.model
.clone()
.unwrap_or_else(|| self.config.model.as_str().to_string());
let usage = completion
.usage
.as_ref()
.map(|u| TokenUsage::new(model_name, u.prompt_tokens, u.completion_tokens));
let message = &completion.choices[0].message;
trace!(finish_reason = %completion.choices[0].finish_reason, "Completion finish reason");
if let Some(content) = &message.content {
debug!(
content_len = content.len(),
"Successfully extracted content from response"
);
Ok(GenerateResult::new(content.clone(), usage))
} else {
error!("No content in Grok API response");
Err(RStructorError::api_error(
"Grok",
ApiErrorKind::UnexpectedResponse {
details: "No content in response".to_string(),
},
))
}
}
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
let base_url = self
.config
.base_url
.as_deref()
.unwrap_or("https://api.x.ai/v1");
let url = format!("{}/models", base_url);
debug!(url = %url, "Fetching available models from Grok");
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.send()
.await
.map_err(|e| handle_http_error(e, "Grok"))?;
let response = check_response_status(response, "Grok").await?;
let json: serde_json::Value = response.json().await.map_err(|e| {
error!(error = %e, "Failed to parse models response from Grok");
e
})?;
let models = json
.get("data")
.and_then(|data| data.as_array())
.map(|models_array| {
models_array
.iter()
.filter_map(|model| {
let id = model.get("id").and_then(|id| id.as_str())?;
if id.starts_with("grok-") {
Some(ModelInfo {
id: id.to_string(),
name: None,
description: None,
})
} else {
None
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
debug!(count = models.len(), "Fetched Grok models");
Ok(models)
}
}