// ai_client/clients/grok.rs

use std::env;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use backoff::{future::retry, ExponentialBackoff};
use log::{debug, error, info};
use reqwest::{Client, ClientBuilder, Response};
use serde::Serialize;
use tokio::sync::RwLock;

use super::super::cache::ResponseCache;
use super::super::entities::{ChatCompletionResponse, Message};
use super::super::error::LlmClientError;
use super::super::metrics::Metrics;
use super::ChatCompletionClient;
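
/// Connection settings for the xAI Grok API, resolved once at construction time.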
#[derive(Debug)]
struct GrokApiConfig {
    api_key: String,
    endpoint: String,
    model: String,
}
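
/// JSON request body for the Grok chat-completions endpoint.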
#[derive(Debug, Serialize)]
struct GrokChatCompletionRequest {
    messages: Vec<Message>,
    reasoning_effort: String,
    model: String,
}
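
/// Client for the xAI Grok chat-completions API, layering an in-memory
/// response cache, exponential-backoff retries, and request metrics on top
/// of `reqwest`.
///
/// A minimal usage sketch (marked `ignore` because the `ai_client` crate
/// paths and the public `role`/`content` fields on `Message` are assumptions
/// inferred from this module's imports, not a confirmed public API):
///
/// ```ignore
/// use ai_client::clients::{grok::GrokClient, ChatCompletionClient};
/// use ai_client::entities::Message;
///
/// async fn demo() -> Result<(), Box<dyn std::error::Error>> {
///     let client = GrokClient::new()?; // reads GROK_API_KEY and friends
///     let messages = vec![Message {
///         role: "user".to_string(),
///         content: "Summarise the borrow checker in one sentence.".to_string(),
///     }];
///     let response = client.send_chat_completion(messages, "low").await?;
///     println!("response id: {}", response.id);
///     Ok(())
/// }
/// ```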
pub struct GrokClient {
    client: Client,
    config: GrokApiConfig,
    cache: ResponseCache,
    metrics: Arc<RwLock<Metrics>>,
}

impl GrokClient {
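    /// Builds a client from environment variables: `GROK_API_KEY` is required;
    /// `GROK_API_ENDPOINT`, `GROK_MODEL`, and `GROK_CACHE_SIZE` fall back to
    /// defaults when unset.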
    pub fn new() -> Result<Self, LlmClientError> {
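        // Load a .env file if one exists; a missing file is fine.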
        let _ = dotenv::dotenv();

        let api_key = env::var("GROK_API_KEY")
            .map_err(|_| LlmClientError::EnvVarMissing("GROK_API_KEY".to_string()))?;
        let endpoint = env::var("GROK_API_ENDPOINT")
            .unwrap_or_else(|_| "https://api.x.ai/v1/chat/completions".to_string());
        let model = env::var("GROK_MODEL")
            .unwrap_or_else(|_| "grok-3-mini-fast-latest".to_string());
        let cache_size = env::var("GROK_CACHE_SIZE")
            .unwrap_or_else(|_| "100".to_string())
            .parse::<usize>()
            .map_err(|_| LlmClientError::ValidationError("Invalid GROK_CACHE_SIZE".to_string()))?;
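
        // Bound both connection setup and the full request so calls cannot hang.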
        let client = ClientBuilder::new()
            .timeout(Duration::from_secs(30))
            .connect_timeout(Duration::from_secs(5))
            .build()?;

        Ok(GrokClient {
            client,
            config: GrokApiConfig {
                api_key,
                endpoint,
                model,
            },
            cache: ResponseCache::new(cache_size),
            metrics: Arc::new(RwLock::new(Metrics::default())),
        })
    }
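
    /// Rejects empty conversations, empty roles or content, unknown roles, and
    /// unsupported `reasoning_effort` values before any network traffic occurs.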
    fn validate_input(&self, messages: &[Message], reasoning_effort: &str) -> Result<(), LlmClientError> {
        if messages.is_empty() {
            return Err(LlmClientError::ValidationError("Messages cannot be empty".to_string()));
        }
        for msg in messages {
            if msg.role.is_empty() || msg.content.is_empty() {
                return Err(LlmClientError::ValidationError(
                    "Message role and content cannot be empty".to_string(),
                ));
            }
            if !["system", "user", "assistant"].contains(&msg.role.as_str()) {
                return Err(LlmClientError::ValidationError(
                    format!("Invalid role: {}", msg.role),
                ));
            }
        }
        if !["low", "medium", "high"].contains(&reasoning_effort) {
            return Err(LlmClientError::ValidationError(
                format!("Invalid reasoning_effort: {}", reasoning_effort),
            ));
        }
        Ok(())
    }
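
    /// Derives a cache key from the conversation, reasoning effort, and model.
    /// MD5 serves only as a fast fingerprint here, not as a security measure.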
    fn generate_cache_key(&self, messages: &[Message], reasoning_effort: &str) -> String {
        let mut key = String::new();
        for msg in messages {
            key.push_str(&msg.role);
            key.push_str(&msg.content);
        }
        key.push_str(reasoning_effort);
        key.push_str(&self.config.model);
        format!("{:x}", md5::compute(key))
    }
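
    /// Maps an HTTP 429 response to a transient backoff error that honours the
    /// server's Retry-After header, defaulting to one second when absent.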
    fn handle_rate_limit(&self, response: &Response) -> Result<(), backoff::Error<LlmClientError>> {
        if response.status().as_u16() == 429 {
            let retry_after = response
                .headers()
                .get("Retry-After")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u64>().ok())
                .unwrap_or(1);
            let error_msg = format!("Rate limit exceeded, retry after {} seconds", retry_after);
            error!("{}", error_msg);
            return Err(backoff::Error::Transient {
                err: LlmClientError::RateLimitExceeded(error_msg),
                retry_after: Some(Duration::from_secs(retry_after)),
            });
        }
        Ok(())
    }
}

#[async_trait]
impl ChatCompletionClient for GrokClient {
    async fn send_chat_completion(
        &self,
        messages: Vec<Message>,
        reasoning_effort: &str,
    ) -> Result<ChatCompletionResponse, LlmClientError> {
        self.validate_input(&messages, reasoning_effort)?;
        // Hold the metrics lock only briefly; keeping it across the network
        // round-trip below would block every other caller.
        self.metrics.write().await.increment_request();

        let cache_key = self.generate_cache_key(&messages, reasoning_effort);
        if let Some(cached_response) = self.cache.get(&cache_key).await {
            info!("Cache hit for key: {}", cache_key);
            let mut metrics = self.metrics.write().await;
            metrics.increment_cache_hit();
            metrics.increment_success();
            return Ok(cached_response);
        }
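
        // Cache miss: build the request payload for the Grok endpoint.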
        let payload = GrokChatCompletionRequest {
            messages,
            reasoning_effort: reasoning_effort.to_string(),
            model: self.config.model.clone(),
        };
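
        // Retry transient failures with exponential backoff, giving up after
        // 60 seconds of total elapsed time.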
        let backoff = ExponentialBackoff {
            max_elapsed_time: Some(Duration::from_secs(60)),
            ..Default::default()
        };

        info!("Sending request to Grok API with model: {}", self.config.model);

        let response = retry(backoff, || async {
            debug!("Attempting API request to {}", self.config.endpoint);
            let result = self
                .client
                .post(&self.config.endpoint)
                .header("Authorization", format!("Bearer {}", self.config.api_key))
                .json(&payload)
                .send()
                .await;

            match result {
                Ok(resp) => {
                    if resp.status().is_success() {
                        Ok(resp)
                    } else {
                        // A 429 becomes a transient error so the backoff retries it.
                        self.handle_rate_limit(&resp)?;
                        let status = resp.status();
                        let error_msg = resp.text().await.unwrap_or_default();
                        error!("API request failed with status {}: {}", status, error_msg);
                        Err(backoff::Error::Permanent(LlmClientError::ApiError(
                            format!("API request failed with status {}: {}", status, error_msg),
                        )))
                    }
                }
                // Timeouts and connection failures are worth retrying.
                Err(err) if err.is_timeout() || err.is_connect() => {
                    debug!("Retrying due to transient error: {}", err);
                    Err(backoff::Error::Transient {
                        err: LlmClientError::HttpError(err),
                        retry_after: None,
                    })
                }
                // Anything else will not improve on retry.
                Err(err) => {
                    error!("Permanent HTTP error: {}", err);
                    Err(backoff::Error::Permanent(LlmClientError::HttpError(err)))
                }
            }
        })
        .await?;

        let chat_response = response.json::<ChatCompletionResponse>().await?;
        info!("Received successful response with ID: {}", chat_response.id);

        // Caching is best-effort; a cache failure should not fail the request.
        if let Err(e) = self.cache.put(cache_key, chat_response.clone()).await {
            error!("Failed to cache response: {}", e);
        }

        self.metrics.write().await.increment_success();
        Ok(chat_response)
    }
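
    /// Returns a point-in-time copy of the accumulated request metrics.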
    async fn get_metrics(&self) -> Metrics {
        self.metrics.read().await.clone()
    }

    async fn stream_chat_completion(
        &self,
        _messages: Vec<Message>,
        _reasoning_effort: &str,
    ) -> Result<(), LlmClientError> {
        // Streaming is not implemented yet; fail loudly rather than silently no-op.
        Err(LlmClientError::ApiError("Streaming not yet supported".to_string()))
    }
}