use async_trait::async_trait;
use backoff::{ExponentialBackoff, future::retry};
use log::{debug, error, info};
use reqwest::{Client, ClientBuilder, Response};
use serde::Serialize;
use std::env;
use std::time::Duration;
use tokio::sync::RwLock;
use std::sync::Arc;
use md5;
use super::super::entities::{Message, ChatCompletionResponse};
use super::super::metrics::Metrics;
use super::super::cache::ResponseCache;
use super::super::error::LlmClientError;
use super::ChatCompletionClient;
/// Connection settings for the Grok chat-completions API, resolved once in `GrokClient::new`.
///
/// `Debug` is implemented by hand (instead of derived) so the API key can never end up in
/// log lines or panic messages that format this struct with `{:?}`.
struct GrokApiConfig {
    /// Bearer token sent in the `Authorization` header of every request.
    api_key: String,
    /// URL the chat-completion requests are POSTed to.
    endpoint: String,
    /// Model identifier included in every request body.
    model: String,
}

impl std::fmt::Debug for GrokApiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrokApiConfig")
            // Never print the secret itself.
            .field("api_key", &"<redacted>")
            .field("endpoint", &self.endpoint)
            .field("model", &self.model)
            .finish()
    }
}
/// JSON body POSTed to the chat-completions endpoint by `send_chat_completion`.
///
/// serde's derive emits fields in declaration order, so the field order here is the
/// key order of the serialized request.
#[derive(Debug, Serialize)]
struct GrokChatCompletionRequest {
/// Conversation to complete; `send_chat_completion` runs `validate_input` on it first.
messages: Vec<Message>,
/// Requested reasoning effort; `validate_input` restricts it to "low", "medium" or "high".
reasoning_effort: String,
/// Model name, copied from the client's configuration.
model: String,
}
/// Client for the Grok chat-completions API.
///
/// Bundles the resolved configuration, a reusable HTTP client, an in-process response
/// cache and request metrics shared behind an async `RwLock`.
pub struct GrokClient {
    /// Endpoint, model and API key resolved in `GrokClient::new`.
    config: GrokApiConfig,
    /// HTTP client built once in `new` (30s request timeout, 5s connect timeout).
    client: Client,
    /// Cached responses, looked up by the hex MD5 key from `generate_cache_key`.
    cache: ResponseCache,
    /// Request / cache-hit / success counters updated by `send_chat_completion`.
    metrics: Arc<RwLock<Metrics>>,
}
impl GrokClient {
    /// Builds a client from environment variables.
    ///
    /// It first tries to load a `.env` file and ignores any failure. It then reads:
    /// - `GROK_API_KEY` (required): bearer token for the API.
    /// - `GROK_API_ENDPOINT`: defaults to `https://api.x.ai/v1/chat/completions`.
    /// - `GROK_MODEL`: defaults to `grok-3-mini-fast-latest`.
    /// - `GROK_CACHE_SIZE`: passed to `ResponseCache::new`, defaults to `100`.
    ///
    /// The HTTP client uses a 30s request timeout and a 5s connect timeout.
    ///
    /// # Errors
    /// - `EnvVarMissing` if `GROK_API_KEY` is unset or not valid Unicode.
    /// - `ValidationError` if `GROK_CACHE_SIZE` does not parse as a `usize`.
    /// - Any error from building the HTTP client, converted via `?`.
    pub fn new() -> Result<Self, LlmClientError> {
        // Best effort: a dotenv failure (e.g. no `.env` file) is deliberately ignored and
        // the process environment is used as-is.
        let _ = dotenv::dotenv();
        let api_key = env::var("GROK_API_KEY")
            .map_err(|_| LlmClientError::EnvVarMissing("GROK_API_KEY".to_string()))?;
        let endpoint = env::var("GROK_API_ENDPOINT")
            .unwrap_or_else(|_| "https://api.x.ai/v1/chat/completions".to_string());
        let model = env::var("GROK_MODEL")
            .unwrap_or_else(|_| "grok-3-mini-fast-latest".to_string());
        let cache_size = env::var("GROK_CACHE_SIZE")
            .unwrap_or_else(|_| "100".to_string())
            .parse::<usize>()
            .map_err(|_| LlmClientError::ValidationError("Invalid GROK_CACHE_SIZE".to_string()))?;
        let client = ClientBuilder::new()
            .timeout(Duration::from_secs(30))
            .connect_timeout(Duration::from_secs(5))
            .build()?;
        Ok(GrokClient {
            client,
            config: GrokApiConfig {
                api_key,
                endpoint,
                model,
            },
            cache: ResponseCache::new(cache_size),
            metrics: Arc::new(RwLock::new(Metrics::default())),
        })
    }

    /// Rejects malformed requests before anything is cached or sent.
    ///
    /// # Errors
    /// Returns `ValidationError` in any of these cases:
    /// - `messages` is empty.
    /// - A message has an empty role or empty content.
    /// - A role is not `system`, `user` or `assistant`.
    /// - `reasoning_effort` is not `low`, `medium` or `high`.
    fn validate_input(&self, messages: &[Message], reasoning_effort: &str) -> Result<(), LlmClientError> {
        if messages.is_empty() {
            return Err(LlmClientError::ValidationError("Messages cannot be empty".to_string()));
        }
        for msg in messages {
            if msg.role.is_empty() || msg.content.is_empty() {
                return Err(LlmClientError::ValidationError(
                    "Message role and content cannot be empty".to_string(),
                ));
            }
            if !["system", "user", "assistant"].contains(&msg.role.as_str()) {
                return Err(LlmClientError::ValidationError(
                    format!("Invalid role: {}", msg.role),
                ));
            }
        }
        if !["low", "medium", "high"].contains(&reasoning_effort) {
            return Err(LlmClientError::ValidationError(
                format!("Invalid reasoning_effort: {}", reasoning_effort),
            ));
        }
        Ok(())
    }

    /// Derives the cache key for a request.
    ///
    /// The key is the hex MD5 digest of the messages (role and content), the reasoning
    /// effort and the configured model.
    ///
    /// Each field is written as `<byte length>:<text>` before hashing, which makes the
    /// encoding unambiguous. Plain concatenation was not: `[("user","a"), ("user","b")]`
    /// and `[("user","auserb")]` both produced `"userauserb"`, so one conversation could
    /// be answered with another conversation's cached response.
    fn generate_cache_key(&self, messages: &[Message], reasoning_effort: &str) -> String {
        /// Appends one length-prefixed field to `key`.
        fn push_field(key: &mut String, field: &str) {
            key.push_str(&field.len().to_string());
            key.push(':');
            key.push_str(field);
        }

        let mut key = String::new();
        for msg in messages {
            push_field(&mut key, &msg.role);
            push_field(&mut key, &msg.content);
        }
        push_field(&mut key, reasoning_effort);
        push_field(&mut key, &self.config.model);
        format!("{:x}", md5::compute(key))
    }

    /// Converts an HTTP 429 response into `RateLimitExceeded`.
    ///
    /// The error message includes the `Retry-After` header value in seconds. It falls back
    /// to 1 when the header is missing or is not an integer. The message is also logged at
    /// error level.
    ///
    /// # Errors
    /// Returns `RateLimitExceeded` when the status is 429. Returns `Ok(())` for every other
    /// status.
    async fn handle_rate_limit(&self, response: &Response) -> Result<(), LlmClientError> {
        if response.status().as_u16() == 429 {
            let retry_after = response
                .headers()
                .get("Retry-After")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u64>().ok())
                .unwrap_or(1);
            let error_msg = format!("Rate limit exceeded, retry after {} seconds", retry_after);
            error!("{}", error_msg);
            return Err(LlmClientError::RateLimitExceeded(error_msg));
        }
        Ok(())
    }
}
#[async_trait]
impl ChatCompletionClient for GrokClient {
async fn send_chat_completion(
&self,
messages: Vec<Message>,
reasoning_effort: &str,
) -> Result<ChatCompletionResponse, LlmClientError> {
self.validate_input(&messages, reasoning_effort)?;
let mut metrics = self.metrics.write().await;
metrics.increment_request();
let cache_key = self.generate_cache_key(&messages, reasoning_effort);
if let Some(cached_response) = self.cache.get(&cache_key).await {
metrics.increment_cache_hit();
info!("Cache hit for key: {}", cache_key);
metrics.increment_success();
return Ok(cached_response);
}
let payload = GrokChatCompletionRequest {
messages,
reasoning_effort: reasoning_effort.to_string(),
model: self.config.model.clone(),
};
let backoff = ExponentialBackoff {
max_elapsed_time: Some(Duration::from_secs(60)),
..Default::default()
};
info!("Sending request to Grok API with model: {}", self.config.model);
let response = retry(backoff, || async {
debug!("Attempting API request to {}", self.config.endpoint);
let result = self
.client
.post(&self.config.endpoint)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.json(&payload)
.send()
.await;
match result {
Ok(resp) => {
if resp.status().is_success() {
Ok(resp)
} else {
self.handle_rate_limit(&resp).await?;
let status = resp.status();
let error_msg = resp.text().await.unwrap_or_default();
error!("API request failed with status {}: {}", status, error_msg);
Err(backoff::Error::Permanent(LlmClientError::ApiError(
format!("API request failed with status: {}", status),
)))
}
}
Err(err) if err.is_timeout() || err.is_connect() => {
debug!("Retrying due to transient error: {}", err);
Err(backoff::Error::Transient {
err: LlmClientError::HttpError(err),
retry_after: None,
})
}
Err(err) => {
error!("Permanent HTTP error: {}", err);
Err(backoff::Error::Permanent(LlmClientError::HttpError(err)))
}
}
})
.await?;
let chat_response = response.json::<ChatCompletionResponse>().await?;
info!("Received successful response with ID: {}", chat_response.id);
if let Err(e) = self.cache.put(cache_key, chat_response.clone()).await {
error!("Failed to cache response: {}", e);
}
metrics.increment_success();
Ok(chat_response)
}
async fn get_metrics(&self) -> Metrics {
let metrics = self.metrics.read().await;
metrics.clone()
}
async fn stream_chat_completion(
&self,
_messages: Vec<Message>,
_reasoning_effort: &str,
) -> Result<(), LlmClientError> {
Err(LlmClientError::ApiError("Streaming not yet supported".to_string()))
}
}