use async_trait::async_trait;
use serde::de::DeserializeOwned;
use std::time::Duration;
use tracing::{debug, error, info, instrument, trace, warn};
use crate::backend::model_macro::define_model_enum;
use crate::backend::{
ChatMessage, DEFAULT_REQUEST_TIMEOUT, GenerateResult, LLMClient, MaterializeInternalOutput,
MaterializeResult, ModelInfo, OpenAICompatibleChatCompletionRequest,
OpenAICompatibleChatCompletionResponse, ResponseFormat, TokenUsage, ValidationFailureContext,
build_http_client, check_response_status, convert_openai_compatible_chat_messages,
generate_with_retry_with_history, handle_http_error, materialize_with_media_with_retry,
parse_validate_and_create_output, prepare_strict_schema,
};
#[cfg(feature = "streaming")]
use crate::backend::{OpenAICompatibleChatMessage, OpenAICompatibleMessageContent};
use crate::error::{ApiErrorKind, RStructorError, Result};
use crate::model::Instructor;
// Grok model identifiers supported by this client.
//
// Each variant is paired with a model ID string. Requests built in this file send
// `Model::as_str()` as the `model` field, which is presumably the paired string.
// NOTE(review): confirm this against `define_model_enum!`.
// `Model::Grok43` is the default model chosen by `GrokClient::new` and `GrokClient::from_env`.
define_model_enum! {
pub enum Model {
Grok43 => "grok-4.3",
Grok420Reasoning => "grok-4.20-0309-reasoning",
Grok420NonReasoning => "grok-4.20-0309-non-reasoning",
Grok420MultiAgent => "grok-4.20-multi-agent-0309",
GrokBuild01 => "grok-build-0.1",
}
}
/// Configuration for a [`GrokClient`].
///
/// The constructors fill in defaults. The client's builder methods adjust individual fields.
///
/// `Debug` is implemented by hand so that `api_key` is never printed. A derived `Debug`
/// would leak the secret into logs or panic messages whenever a config is formatted with `{:?}`.
#[derive(Clone)]
pub struct GrokConfig {
    /// xAI API key, sent as a bearer token on every request.
    pub api_key: String,
    /// Model whose ID is sent in the request's `model` field.
    pub model: Model,
    /// Sampling temperature forwarded with each request.
    pub temperature: f32,
    /// Optional limit forwarded as the request's `max_tokens`.
    pub max_tokens: Option<u32>,
    /// Request timeout recorded in the configuration.
    pub timeout: Option<Duration>,
    /// Retry budget handed to the shared retry helpers for structured-output calls.
    pub max_retries: Option<usize>,
    /// Custom API base URL; `None` means `https://api.x.ai/v1`.
    pub base_url: Option<String>,
}

impl std::fmt::Debug for GrokConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrokConfig")
            // Never print the secret itself; only signal that a value is present.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .field("base_url", &self.base_url)
            .finish()
    }
}
/// Client for the xAI Grok chat-completions API.
///
/// Requests use the OpenAI-compatible request/response types from `crate::backend`.
/// They are sent to `https://api.x.ai/v1` unless a custom base URL has been configured.
#[derive(Clone)]
pub struct GrokClient {
/// Model, sampling, retry, timeout and endpoint settings.
config: GrokConfig,
/// HTTP client used for every request made by this `GrokClient`.
client: reqwest::Client,
}
impl GrokClient {
    /// Creates a Grok client from an explicit xAI API key, using the default configuration.
    ///
    /// The defaults are:
    /// - model [`Model::Grok43`] and temperature `0.0`;
    /// - no `max_tokens`, `max_retries` of `Some(3)` and no custom base URL;
    /// - timeout [`DEFAULT_REQUEST_TIMEOUT`], which is also applied to the HTTP client.
    ///
    /// # Errors
    ///
    /// Returns an `AuthenticationFailed` API error if `api_key` is empty.
    #[instrument(name = "grok_client_new", skip(api_key), fields(model = ?Model::Grok43))]
    pub fn new(api_key: impl Into<String>) -> Result<Self> {
        let api_key = api_key.into();
        if api_key.is_empty() {
            return Err(RStructorError::api_error(
                "Grok",
                ApiErrorKind::AuthenticationFailed,
            ));
        }
        info!("Creating new Grok client");
        trace!("API key length: {}", api_key.len());
        Ok(Self::with_default_config(api_key))
    }

    /// Creates a Grok client using the `XAI_API_KEY` environment variable and the
    /// same defaults as [`GrokClient::new`].
    ///
    /// # Errors
    ///
    /// Returns an `AuthenticationFailed` API error in these cases:
    /// - `XAI_API_KEY` is unset or not valid Unicode;
    /// - `XAI_API_KEY` is empty. This matches the check in `new`, so a blank variable
    ///   fails here instead of on the first request.
    #[instrument(name = "grok_client_from_env", fields(model = ?Model::Grok43))]
    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("XAI_API_KEY")
            .ok()
            .filter(|key| !key.is_empty())
            .ok_or_else(|| RStructorError::api_error("Grok", ApiErrorKind::AuthenticationFailed))?;
        info!("Creating new Grok client from environment variable");
        trace!("API key length: {}", api_key.len());
        Ok(Self::with_default_config(api_key))
    }

    /// Builds a client around an already-validated API key using the default settings.
    fn with_default_config(api_key: String) -> Self {
        let config = GrokConfig {
            api_key,
            model: Model::Grok43,
            temperature: 0.0,
            max_tokens: None,
            timeout: Some(DEFAULT_REQUEST_TIMEOUT),
            max_retries: Some(3),
            base_url: None,
        };
        debug!("Grok client created with default configuration");
        Self {
            config,
            client: build_http_client(DEFAULT_REQUEST_TIMEOUT),
        }
    }
}
impl GrokClient {
async fn materialize_internal<T>(
&self,
messages: &[ChatMessage],
) -> std::result::Result<
MaterializeInternalOutput<T>,
(RStructorError, Option<ValidationFailureContext>),
>
where
T: Instructor + DeserializeOwned + Send + 'static,
{
info!("Generating structured response with Grok (native structured outputs)");
let schema = T::schema();
let schema_name = T::schema_name().unwrap_or_else(|| "output".to_string());
trace!(schema_name = schema_name, "Retrieved JSON schema for type");
let schema_json = prepare_strict_schema(&schema);
let api_messages =
convert_openai_compatible_chat_messages(messages, "Grok").map_err(|e| (e, None))?;
let response_format = ResponseFormat::json_schema(schema_name.clone(), schema_json, None);
debug!(
"Building Grok API request with structured outputs (history_len={})",
api_messages.len()
);
let request = OpenAICompatibleChatCompletionRequest {
model: self.config.model.as_str().to_string(),
messages: api_messages,
response_format: Some(response_format),
temperature: self.config.temperature,
max_tokens: self.config.max_tokens,
reasoning_effort: None,
};
let base_url = self
.config
.base_url
.as_deref()
.unwrap_or("https://api.x.ai/v1");
let url = format!("{}/chat/completions", base_url);
debug!(url = %url, "Sending request to Grok API with structured outputs");
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| (handle_http_error(e, "Grok"), None))?;
let response = check_response_status(response, "Grok")
.await
.map_err(|e| (e, None))?;
debug!("Successfully received response from Grok API");
let completion: OpenAICompatibleChatCompletionResponse =
response.json().await.map_err(|e| {
error!(error = %e, "Failed to parse JSON response from Grok API");
(RStructorError::from(e), None)
})?;
if completion.choices.is_empty() {
error!("Grok API returned empty choices array");
return Err((
RStructorError::api_error(
"Grok",
ApiErrorKind::UnexpectedResponse {
details: "No completion choices returned".to_string(),
},
),
None,
));
}
let model_name = completion
.model
.clone()
.unwrap_or_else(|| self.config.model.as_str().to_string());
let usage = completion
.usage
.as_ref()
.map(|u| TokenUsage::new(model_name.clone(), u.prompt_tokens, u.completion_tokens));
let message = &completion.choices[0].message;
trace!(finish_reason = %completion.choices[0].finish_reason, "Completion finish reason");
if let Some(content) = &message.content {
let raw_response = content.clone();
debug!(
content_len = raw_response.len(),
"Received structured output response"
);
trace!(json = %raw_response, "Parsing structured output response");
parse_validate_and_create_output(raw_response, usage)
} else {
error!("No content in Grok API response");
Err((
RStructorError::api_error(
"Grok",
ApiErrorKind::UnexpectedResponse {
details: "No content in response".to_string(),
},
),
None,
))
}
}
async fn generate_internal(&self, messages: &[ChatMessage]) -> Result<GenerateResult> {
info!("Generating raw text response with Grok");
debug!("Building Grok API request for text generation");
let request = OpenAICompatibleChatCompletionRequest {
model: self.config.model.as_str().to_string(),
messages: convert_openai_compatible_chat_messages(messages, "Grok")?,
response_format: None,
temperature: self.config.temperature,
max_tokens: self.config.max_tokens,
reasoning_effort: None,
};
let base_url = self
.config
.base_url
.as_deref()
.unwrap_or("https://api.x.ai/v1");
let url = format!("{}/chat/completions", base_url);
debug!(url = %url, "Sending request to Grok API");
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| handle_http_error(e, "Grok"))?;
let response = check_response_status(response, "Grok").await?;
debug!("Successfully received response from Grok API");
let completion: OpenAICompatibleChatCompletionResponse =
response.json().await.map_err(|e| {
error!(error = %e, "Failed to parse JSON response from Grok API");
e
})?;
if completion.choices.is_empty() {
error!("Grok API returned empty choices array");
return Err(RStructorError::api_error(
"Grok",
ApiErrorKind::UnexpectedResponse {
details: "No completion choices returned".to_string(),
},
));
}
let model_name = completion
.model
.clone()
.unwrap_or_else(|| self.config.model.as_str().to_string());
let usage = completion
.usage
.as_ref()
.map(|u| TokenUsage::new(model_name, u.prompt_tokens, u.completion_tokens));
let message = &completion.choices[0].message;
trace!(finish_reason = %completion.choices[0].finish_reason, "Completion finish reason");
if let Some(content) = &message.content {
debug!(
content_len = content.len(),
"Successfully extracted content from response"
);
Ok(GenerateResult::new(content.clone(), usage))
} else {
error!("No content in Grok API response");
Err(RStructorError::api_error(
"Grok",
ApiErrorKind::UnexpectedResponse {
details: "No content in response".to_string(),
},
))
}
}
}
// Expands the builder methods shared by provider clients onto `GrokClient`, wired to
// its `GrokConfig` and `Model` types. One example is the `timeout` setter exercised in
// this file's tests.
// NOTE(review): see `impl_client_builder_methods!` for the exact set of setters, and
// for whether they rebuild the underlying HTTP client.
crate::impl_client_builder_methods! {
client_type: GrokClient,
config_type: GrokConfig,
model_type: Model,
provider_name: "Grok"
}
impl GrokClient {
    /// Overrides the API base URL (default `https://api.x.ai/v1`), e.g. to route
    /// through a proxy or a compatible gateway.
    ///
    /// Endpoint paths such as `/chat/completions` and `/models` are appended directly
    /// to this value. Trailing `/` characters are therefore stripped here, so that
    /// `"https://host/v1/"` does not yield `https://host/v1//chat/completions`.
    #[tracing::instrument(skip(self, base_url))]
    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        let base_url_str = base_url.into();
        let normalized = base_url_str.trim_end_matches('/').to_string();
        tracing::debug!(
            previous_base_url = ?self.config.base_url,
            new_base_url = %normalized,
            "Setting custom base URL"
        );
        self.config.base_url = Some(normalized);
        self
    }
}
#[cfg(feature = "streaming")]
impl GrokClient {
    /// Builds the JSON body for a single-turn streaming chat-completion request.
    ///
    /// `prompt` becomes the only message (role `"user"`), and `response_format` is
    /// passed through unchanged. `"stream": true` is added on top of the serialized
    /// request. If serialization ever fails, the body starts from an empty JSON object
    /// instead.
    fn stream_body(
        &self,
        prompt: &str,
        response_format: Option<ResponseFormat>,
    ) -> serde_json::Value {
        let only_message = OpenAICompatibleChatMessage {
            role: "user".to_string(),
            content: OpenAICompatibleMessageContent::Text(prompt.to_string()),
        };
        let payload = OpenAICompatibleChatCompletionRequest {
            model: self.config.model.as_str().to_string(),
            messages: vec![only_message],
            response_format,
            temperature: self.config.temperature,
            max_tokens: self.config.max_tokens,
            reasoning_effort: None,
        };
        let mut body = match serde_json::to_value(&payload) {
            Ok(value) => value,
            Err(_) => serde_json::json!({}),
        };
        body["stream"] = serde_json::Value::from(true);
        body
    }

    /// Returns a `'static` future that POSTs `body` to the chat-completions endpoint.
    ///
    /// The endpoint, authorization header and client handle are all prepared before the
    /// future is created, so the future never borrows `self`. It resolves to the
    /// response once `check_response_status` has accepted it.
    fn send_stream(
        &self,
        body: serde_json::Value,
    ) -> impl std::future::Future<Output = Result<reqwest::Response>> + Send + 'static {
        let endpoint = format!(
            "{}/chat/completions",
            self.config
                .base_url
                .as_deref()
                .unwrap_or("https://api.x.ai/v1")
        );
        let authorization = format!("Bearer {}", self.config.api_key);
        let http = self.client.clone();
        async move {
            let response = http
                .post(&endpoint)
                .header("Authorization", authorization)
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await
                .map_err(|err| handle_http_error(err, "Grok"))?;
            check_response_status(response, "Grok").await
        }
    }
}
#[cfg(feature = "tools")]
#[async_trait]
impl crate::backend::tools::ToolRunner for GrokClient {
    /// Runs the shared OpenAI-compatible tool-calling loop against Grok's
    /// chat-completions endpoint.
    ///
    /// This delegates to `run_openai_compatible_tools` with this client's HTTP handle,
    /// API key, model, temperature and `max_tokens`, and returns whatever it produces.
    async fn run_tool_loop(
        &self,
        system: Option<&str>,
        prompt: &str,
        media: &[super::MediaFile],
        toolbox: &crate::backend::tools::Toolbox,
        max_iterations: usize,
    ) -> Result<String> {
        let cfg = &self.config;
        let endpoint = format!(
            "{}/chat/completions",
            cfg.base_url.as_deref().unwrap_or("https://api.x.ai/v1")
        );
        crate::backend::tools::run_openai_compatible_tools(
            &self.client,
            &endpoint,
            &cfg.api_key,
            "Grok",
            cfg.model.as_str(),
            cfg.temperature,
            cfg.max_tokens,
            // Presumably the reasoning-effort setting. Requests elsewhere in this file
            // also leave `reasoning_effort` unset. TODO confirm against the helper's signature.
            None,
            system,
            prompt,
            media,
            toolbox,
            max_iterations,
        )
        .await
    }
}
#[async_trait]
impl LLMClient for GrokClient {
    /// Builds a client from the `XAI_API_KEY` environment variable.
    ///
    /// `Self::from_env()` resolves to the inherent `GrokClient::from_env`, because
    /// inherent associated functions take precedence over trait ones. This call
    /// therefore does not recurse.
    fn from_env() -> Result<Self> {
        Self::from_env()
    }

    /// Produces a validated `T` for `prompt` using Grok's native structured outputs.
    ///
    /// Attempts are driven by `generate_with_retry_with_history` and bounded by
    /// `config.max_retries`. Only the data is returned.
    #[instrument(
        name = "grok_materialize",
        skip(self, prompt),
        fields(
            type_name = std::any::type_name::<T>(),
            model = %self.config.model.as_str(),
            prompt_len = prompt.len()
        )
    )]
    async fn materialize<T>(&self, prompt: &str) -> Result<T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        let output = generate_with_retry_with_history(
            |messages: Vec<ChatMessage>| {
                let this = self;
                async move { this.materialize_internal::<T>(&messages).await }
            },
            prompt,
            self.config.max_retries,
        )
        .await?;
        Ok(output.data)
    }

    /// Like `materialize`, but attaches `media` to the prompt.
    ///
    /// Attempts go through `materialize_with_media_with_retry`.
    #[instrument(
        name = "grok_materialize_with_media",
        skip(self, prompt, media),
        fields(
            type_name = std::any::type_name::<T>(),
            model = %self.config.model.as_str(),
            prompt_len = prompt.len(),
            media_len = media.len()
        )
    )]
    async fn materialize_with_media<T>(&self, prompt: &str, media: &[super::MediaFile]) -> Result<T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        materialize_with_media_with_retry(
            |messages: Vec<ChatMessage>| {
                let this = self;
                async move { this.materialize_internal::<T>(&messages).await }
            },
            prompt,
            media,
            self.config.max_retries,
        )
        .await
    }

    /// Like `materialize`, but also returns the token usage reported for the successful attempt.
    #[instrument(
        name = "grok_materialize_with_metadata",
        skip(self, prompt),
        fields(
            type_name = std::any::type_name::<T>(),
            model = %self.config.model.as_str(),
            prompt_len = prompt.len()
        )
    )]
    async fn materialize_with_metadata<T>(&self, prompt: &str) -> Result<MaterializeResult<T>>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        let output = generate_with_retry_with_history(
            |messages: Vec<ChatMessage>| {
                let this = self;
                async move { this.materialize_internal::<T>(&messages).await }
            },
            prompt,
            self.config.max_retries,
        )
        .await?;
        Ok(MaterializeResult::new(output.data, output.usage))
    }

    /// Returns the raw completion text for `prompt`.
    #[instrument(
        name = "grok_generate",
        skip(self, prompt),
        fields(
            model = %self.config.model.as_str(),
            prompt_len = prompt.len()
        )
    )]
    async fn generate(&self, prompt: &str) -> Result<String> {
        let result = self.generate_with_metadata(prompt).await?;
        Ok(result.text)
    }

    /// Returns the raw completion text for `prompt`, with `media` attached to a single user message.
    #[instrument(
        name = "grok_generate_with_media",
        skip(self, prompt, media),
        fields(
            model = %self.config.model.as_str(),
            prompt_len = prompt.len(),
            media_len = media.len()
        )
    )]
    async fn generate_with_media(
        &self,
        prompt: &str,
        media: &[super::MediaFile],
    ) -> Result<String> {
        let result = self
            .generate_internal(&[ChatMessage::user_with_media(prompt, media.to_vec())])
            .await?;
        Ok(result.text)
    }

    /// Returns the raw completion text for `prompt`, together with the reported token usage.
    #[instrument(
        name = "grok_generate_with_metadata",
        skip(self, prompt),
        fields(
            model = %self.config.model.as_str(),
            prompt_len = prompt.len()
        )
    )]
    async fn generate_with_metadata(&self, prompt: &str) -> Result<GenerateResult> {
        self.generate_internal(&[ChatMessage::user(prompt)]).await
    }

    /// Streams text deltas for a single-turn request on `prompt`.
    #[cfg(feature = "streaming")]
    fn generate_stream<'a>(&'a self, prompt: &'a str) -> crate::backend::streaming::TextStream<'a>
    where
        Self: Sync,
    {
        let body = self.stream_body(prompt, None);
        crate::backend::streaming::sse_text_stream(
            self.send_stream(body),
            crate::backend::streaming::openai_delta,
        )
    }

    /// Streams a `T` constrained by its strict JSON schema.
    #[cfg(feature = "streaming")]
    fn materialize_stream<'a, T>(
        &'a self,
        prompt: &'a str,
    ) -> crate::backend::streaming::ObjectStream<'a, T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
        Self: Sync,
    {
        let schema = T::schema();
        let schema_name = T::schema_name().unwrap_or_else(|| "output".to_string());
        let schema_json = prepare_strict_schema(&schema);
        let response_format = ResponseFormat::json_schema(schema_name, schema_json, None);
        let body = self.stream_body(prompt, Some(response_format));
        crate::backend::streaming::object_stream(
            self.send_stream(body),
            crate::backend::streaming::openai_delta,
        )
    }

    /// Streams items of type `T`; the response is constrained to an `"items"` array-wrapper schema.
    #[cfg(feature = "streaming")]
    fn materialize_iter<'a, T>(
        &'a self,
        prompt: &'a str,
    ) -> crate::backend::streaming::ItemStream<'a, T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
        Self: Sync,
    {
        let item_schema = prepare_strict_schema(&T::schema());
        let wrapper = crate::backend::streaming::array_wrapper_schema(item_schema, true);
        let response_format = ResponseFormat::json_schema("items".to_string(), wrapper, None);
        let body = self.stream_body(prompt, Some(response_format));
        crate::backend::streaming::iter_stream(
            self.send_stream(body),
            crate::backend::streaming::openai_delta,
            crate::backend::streaming::finalize_item::<T>,
        )
    }

    /// Lists the models visible to this API key whose IDs start with `grok-`.
    ///
    /// Sends `GET {base_url}/models`. A response without a `data` array yields an
    /// empty list, and entries without a string `id` are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error on a transport failure, a non-success status, or a body that is not JSON.
    #[instrument(name = "grok_list_models", skip(self))]
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        let base_url = self
            .config
            .base_url
            .as_deref()
            .unwrap_or("https://api.x.ai/v1");
        let url = format!("{}/models", base_url);
        debug!(url = %url, "Fetching available models from Grok");
        // `bearer_auth` marks the key-bearing header as sensitive. No Content-Type is
        // sent because this GET carries no body.
        let response = self
            .client
            .get(&url)
            .bearer_auth(&self.config.api_key)
            .send()
            .await
            .map_err(|e| handle_http_error(e, "Grok"))?;
        let response = check_response_status(response, "Grok").await?;
        let json: serde_json::Value = response.json().await.map_err(|e| {
            error!(error = %e, "Failed to parse models response from Grok");
            e
        })?;
        let models = json
            .get("data")
            .and_then(serde_json::Value::as_array)
            .map(|models_array| {
                models_array
                    .iter()
                    .filter_map(|model| model.get("id").and_then(serde_json::Value::as_str))
                    .filter(|id| id.starts_with("grok-"))
                    .map(|id| ModelInfo {
                        id: id.to_string(),
                        name: None,
                        description: None,
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        debug!(count = models.len(), "Fetched Grok models");
        Ok(models)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `GrokClient` construction and configuration; none of them touch the network.
    use super::*;
    use crate::backend::DEFAULT_REQUEST_TIMEOUT;

    #[test]
    fn default_config_has_default_timeout() {
        let client = GrokClient::new("test-key").unwrap();
        assert_eq!(client.config.timeout, Some(DEFAULT_REQUEST_TIMEOUT));
    }

    #[test]
    fn explicit_timeout_overrides_default() {
        let client = GrokClient::new("test-key")
            .unwrap()
            .timeout(Duration::from_secs(10));
        assert_eq!(client.config.timeout, Some(Duration::from_secs(10)));
    }

    #[test]
    fn empty_api_key_is_rejected() {
        assert!(GrokClient::new("").is_err());
    }

    #[test]
    fn default_config_values() {
        let client = GrokClient::new("test-key").unwrap();
        assert_eq!(client.config.api_key, "test-key");
        assert_eq!(client.config.max_tokens, None);
        assert_eq!(client.config.max_retries, Some(3));
        assert!(client.config.base_url.is_none());
    }

    #[test]
    fn base_url_override_is_stored() {
        let client = GrokClient::new("test-key")
            .unwrap()
            .base_url("https://example.com/v1");
        assert_eq!(
            client.config.base_url.as_deref(),
            Some("https://example.com/v1")
        );
    }
}