use crate::error::{Error, Result};
use crate::models::{ChatMessage, ChatRequest, ChatResponse, Model};
use crate::retry::{RetryConfig, retry_async};
use reqwest::{Client, ClientBuilder};
use serde_json::Value;
use std::time::Duration;
use tracing::{debug, info, warn};
/// Base URL used when the builder is not given one; API paths such as
/// `/v1/models` are appended to it.
const DEFAULT_BASE_URL: &str = "https://api.x.ai";
/// Model used whenever a request does not name one explicitly.
const DEFAULT_MODEL: &str = "grok-4-1-fast-reasoning";
/// Whole-request timeout, in seconds, applied when the builder is not given one.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Async client for the xAI Grok HTTP API.
///
/// Holds a reqwest HTTP client, the bearer API key, the base URL that request
/// paths are appended to, and the retry policy applied to every request.
#[derive(Clone)]
pub struct GrokClient {
    client: Client,
    api_key: String,
    base_url: String,
    retry_config: RetryConfig,
}

// Hand-written instead of derived so that `{:?}` (logs, panics, `dbg!`) never
// prints the API key; every other field is shown as before.
impl std::fmt::Debug for GrokClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrokClient")
            .field("client", &self.client)
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("retry_config", &self.retry_config)
            .finish()
    }
}
impl GrokClient {
pub fn new(api_key: impl Into<String>) -> Result<Self> {
Self::builder().api_key(api_key).build()
}
pub fn builder() -> GrokClientBuilder {
GrokClientBuilder::default()
}
pub async fn chat(&self, message: impl Into<String>, model: Option<&str>) -> Result<String> {
let message_text: String = message.into();
let messages = vec![ChatMessage::user(message_text)];
let response: ChatResponse = self
.chat_with_history(&messages)
.model(model.unwrap_or(DEFAULT_MODEL))
.send()
.await?;
Ok(response.content().unwrap_or("").to_string())
}
pub fn chat_with_history(&self, messages: &[ChatMessage]) -> ChatRequestBuilder {
ChatRequestBuilder::new(self.clone(), messages.to_vec())
}
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/v1/models", self.base_url);
let result = retry_async(&self.retry_config, || async {
let response = self
.client
.get(&url)
.bearer_auth(&self.api_key)
.send()
.await
.map_err(Error::from_reqwest)?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().await.unwrap_or_default();
return Err(Error::Http {
status,
message: body,
});
}
let json: Value = response.json().await.map_err(Error::from_reqwest)?;
if let Some(data) = json.get("data").and_then(|d| d.as_array()) {
let models = data
.iter()
.filter_map(|m| m.get("id")?.as_str())
.map(String::from)
.collect();
Ok(models)
} else {
Ok(Model::all().iter().map(|m: &Model| m.to_string()).collect())
}
})
.await?;
Ok(result)
}
pub async fn test_connection(&self) -> Result<()> {
info!("Testing connection to Grok API");
let _ = self.chat("Hello", Some("grok-4-1-fast-reasoning")).await?;
info!("Connection test successful");
Ok(())
}
async fn send_chat_request(&self, request: ChatRequest) -> Result<ChatResponse> {
let url = format!("{}/v1/chat/completions", self.base_url);
debug!("Sending chat request to {}", url);
debug!(
"Model: {}, Messages: {}",
request.model,
request.messages.len()
);
let result = retry_async(&self.retry_config, || async {
let response = self
.client
.post(&url)
.bearer_auth(&self.api_key)
.json(&request)
.send()
.await
.map_err(Error::from_reqwest)?;
let status = response.status();
if !status.is_success() {
let status_code = status.as_u16();
let body = response.text().await.unwrap_or_default();
warn!("API request failed: {} - {}", status_code, body);
return Err(match status_code {
401 => Error::Authentication,
429 => Error::RateLimit,
404 => Error::ModelNotFound(request.model.clone()),
400 => Error::InvalidRequest(body),
500..=599 => Error::ServerError(body),
_ => Error::Http {
status: status_code,
message: body,
},
});
}
let response_text = response.text().await.map_err(Error::from_reqwest)?;
debug!("Response received: {} bytes", response_text.len());
let chat_response: ChatResponse =
serde_json::from_str(&response_text).map_err(|e| {
warn!("Failed to parse response: {}", e);
Error::Json(e)
})?;
info!(
"Chat completion successful - Model: {}, Tokens: {}",
chat_response.model, chat_response.usage.total_tokens
);
Ok(chat_response)
})
.await?;
Ok(result)
}
}
/// Builder for [`GrokClient`]; obtain one via `GrokClient::builder()`.
///
/// Only the API key is required; every other option falls back to a default
/// in `build`. An explicit `retry_config` takes precedence over `max_retries`.
#[derive(Default)]
#[must_use = "a GrokClientBuilder does nothing until `build` is called"]
pub struct GrokClientBuilder {
    api_key: Option<String>,
    base_url: Option<String>,
    timeout_secs: Option<u64>,
    max_retries: Option<u32>,
    retry_config: Option<RetryConfig>,
}
impl GrokClientBuilder {
pub fn api_key(mut self, key: impl Into<String>) -> Self {
self.api_key = Some(key.into());
self
}
pub fn base_url(mut self, url: impl Into<String>) -> Self {
self.base_url = Some(url.into());
self
}
pub fn timeout_secs(mut self, timeout: u64) -> Self {
self.timeout_secs = Some(timeout);
self
}
pub fn max_retries(mut self, retries: u32) -> Self {
self.max_retries = Some(retries);
self
}
pub fn retry_config(mut self, config: RetryConfig) -> Self {
self.retry_config = Some(config);
self
}
pub fn build(self) -> Result<GrokClient> {
let api_key = self.api_key.ok_or(Error::EmptyApiKey)?;
if api_key.is_empty() {
return Err(Error::EmptyApiKey);
}
let timeout = self.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let client = ClientBuilder::new()
.timeout(Duration::from_secs(timeout))
.connect_timeout(Duration::from_secs(10))
.tcp_keepalive(Duration::from_secs(30))
.user_agent(format!("grok_api/{}", env!("CARGO_PKG_VERSION")))
.build()
.map_err(Error::from_reqwest)?;
let retry_config = self
.retry_config
.unwrap_or_else(|| RetryConfig::new(self.max_retries.unwrap_or(3)));
Ok(GrokClient {
client,
api_key,
base_url: self
.base_url
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string()),
retry_config,
})
}
}
/// Per-request options for a chat completion, created by
/// `GrokClient::chat_with_history`.
///
/// Options left unset are passed to `ChatRequest` as `None`. An unset model
/// falls back to `DEFAULT_MODEL`. Nothing is sent until `send` is awaited.
#[must_use = "a ChatRequestBuilder does nothing until `send` is awaited"]
pub struct ChatRequestBuilder {
    client: GrokClient,
    messages: Vec<ChatMessage>,
    model: Option<String>,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    tools: Option<Vec<Value>>,
    top_p: Option<f32>,
}
impl ChatRequestBuilder {
    /// Creates a builder for `messages` with every option unset.
    fn new(client: GrokClient, messages: Vec<ChatMessage>) -> Self {
        Self {
            client,
            messages,
            model: None,
            temperature: None,
            max_tokens: None,
            tools: None,
            top_p: None,
        }
    }

    /// Clamps `value` into `lo..=hi`, or returns `None` for NaN.
    ///
    /// `f32::clamp` returns NaN unchanged, so NaN must be filtered explicitly
    /// to keep it out of the outgoing request.
    fn clamp_or_unset(value: f32, lo: f32, hi: f32) -> Option<f32> {
        if value.is_nan() {
            None
        } else {
            Some(value.clamp(lo, hi))
        }
    }

    /// Selects the model; `DEFAULT_MODEL` is used if this is never called.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the sampling temperature, clamped to `0.0..=2.0`.
    ///
    /// A NaN argument leaves the temperature unset (`None`).
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = Self::clamp_or_unset(temp, 0.0, 2.0);
        self
    }

    /// Caps the number of tokens to generate.
    pub fn max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = Some(tokens);
        self
    }

    /// Attaches tool definitions as raw JSON values.
    pub fn tools(mut self, tools: Vec<Value>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets nucleus-sampling `top_p`, clamped to `0.0..=1.0`.
    ///
    /// A NaN argument leaves `top_p` unset (`None`).
    pub fn top_p(mut self, p: f32) -> Self {
        self.top_p = Self::clamp_or_unset(p, 0.0, 1.0);
        self
    }

    /// Assembles a non-streaming `ChatRequest` and sends it through the client.
    ///
    /// `frequency_penalty` and `presence_penalty` are always sent as `None`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying HTTP request or response parsing.
    pub async fn send(self) -> Result<ChatResponse> {
        let request = ChatRequest {
            model: self.model.unwrap_or_else(|| DEFAULT_MODEL.to_string()),
            messages: self.messages,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            stream: Some(false),
            tools: self.tools,
            top_p: self.top_p,
            frequency_penalty: None,
            presence_penalty: None,
        };
        self.client.send_chat_request(request).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `GrokClient::new` with an empty key must fail with `EmptyApiKey`.
    #[test]
    fn test_empty_api_key_error() {
        let outcome = GrokClient::new("");
        assert!(matches!(outcome, Err(Error::EmptyApiKey)));
    }

    /// Building without ever supplying a key must fail.
    #[test]
    fn test_builder_missing_api_key() {
        let outcome = GrokClient::builder().timeout_secs(30).build();
        assert!(outcome.is_err());
    }

    /// A fully configured builder succeeds and stores the key it was given.
    #[test]
    fn test_client_builder() {
        let outcome = GrokClient::builder()
            .api_key("test-key")
            .timeout_secs(60)
            .max_retries(5)
            .build();
        assert!(outcome.is_ok());
        let built = outcome.unwrap();
        assert_eq!(built.api_key, "test-key");
    }
}