ai_client/clients/
grok.rs

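//! Grok chat-completion client: environment-driven configuration, response
//! caching, request metrics, and retries with exponential backoff.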
use std::env;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use backoff::{future::retry, ExponentialBackoff};
use log::{debug, error, info};
use md5;
use reqwest::{Client, ClientBuilder, Response};
use serde::Serialize;
use tokio::sync::RwLock;

use super::super::cache::ResponseCache;
use super::super::entities::{ChatCompletionResponse, Message};
use super::super::error::LlmClientError;
use super::super::metrics::Metrics;
use super::ChatCompletionClient;

/// Configuration for the Grok API client
#[derive(Debug)]
struct GrokApiConfig {
    api_key: String,
    endpoint: String,
    model: String,
}

/// Request structure for Grok chat completions
#[derive(Debug, Serialize)]
struct GrokChatCompletionRequest {
    messages: Vec<Message>,
    /// One of "low", "medium", or "high"; validated before sending.
    reasoning_effort: String,
    model: String,
}

/// Grok client implementing the ChatCompletionClient trait
pub struct GrokClient {
    client: Client,
    config: GrokApiConfig,
    cache: ResponseCache,
    metrics: Arc<RwLock<Metrics>>,
}

impl GrokClient {
    /// Creates a new GrokClient, loading configuration from the environment.
    /// A `.env` file is read if present; only GROK_API_KEY is required.
    pub fn new() -> Result<Self, LlmClientError> {
        let _ = dotenv::dotenv();

        let api_key = env::var("GROK_API_KEY")
            .map_err(|_| LlmClientError::EnvVarMissing("GROK_API_KEY".to_string()))?;
        let endpoint = env::var("GROK_API_ENDPOINT")
            .unwrap_or_else(|_| "https://api.x.ai/v1/chat/completions".to_string());
        let model = env::var("GROK_MODEL")
            .unwrap_or_else(|_| "grok-3-mini-fast-latest".to_string());
        let cache_size = env::var("GROK_CACHE_SIZE")
            .unwrap_or_else(|_| "100".to_string())
            .parse::<usize>()
            .map_err(|_| LlmClientError::ValidationError("Invalid GROK_CACHE_SIZE".to_string()))?;

        // Bound both connection establishment and the overall request so a
        // stalled API call cannot hang the caller indefinitely.
        let client = ClientBuilder::new()
            .timeout(Duration::from_secs(30))
            .connect_timeout(Duration::from_secs(5))
            .build()?;

        Ok(GrokClient {
            client,
            config: GrokApiConfig {
                api_key,
                endpoint,
                model,
            },
            cache: ResponseCache::new(cache_size),
            metrics: Arc::new(RwLock::new(Metrics::default())),
        })
    }

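    /// Rejects empty message lists, empty roles or content, unknown roles,
    /// and reasoning_effort values outside {"low", "medium", "high"}.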
    fn validate_input(&self, messages: &[Message], reasoning_effort: &str) -> Result<(), LlmClientError> {
        if messages.is_empty() {
            return Err(LlmClientError::ValidationError("Messages cannot be empty".to_string()));
        }
        for msg in messages {
            if msg.role.is_empty() || msg.content.is_empty() {
                return Err(LlmClientError::ValidationError(
                    "Message role and content cannot be empty".to_string(),
                ));
            }
            if !["system", "user", "assistant"].contains(&msg.role.as_str()) {
                return Err(LlmClientError::ValidationError(
                    format!("Invalid role: {}", msg.role),
                ));
            }
        }
        if !["low", "medium", "high"].contains(&reasoning_effort) {
            return Err(LlmClientError::ValidationError(
                format!("Invalid reasoning_effort: {}", reasoning_effort),
            ));
        }
        Ok(())
    }

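    /// Derives a cache key by MD5-hashing the concatenated message roles and
    /// contents together with the reasoning effort and model name.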
    fn generate_cache_key(&self, messages: &[Message], reasoning_effort: &str) -> String {
        let mut key = String::new();
        for msg in messages {
            key.push_str(&msg.role);
            key.push_str(&msg.content);
        }
        key.push_str(reasoning_effort);
        key.push_str(&self.config.model);
        format!("{:x}", md5::compute(key))
    }

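    /// Maps an HTTP 429 response to LlmClientError::RateLimitExceeded,
    /// including the server's Retry-After value (in seconds) when present.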
    async fn handle_rate_limit(&self, response: &Response) -> Result<(), LlmClientError> {
        if response.status().as_u16() == 429 {
            // Parse Retry-After if the server sent one; default to 1 second.
            let retry_after = response
                .headers()
                .get("Retry-After")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u64>().ok())
                .unwrap_or(1);
            let error_msg = format!("Rate limit exceeded, retry after {} seconds", retry_after);
            error!("{}", error_msg);
            return Err(LlmClientError::RateLimitExceeded(error_msg));
        }
        Ok(())
    }
}

#[async_trait]
impl ChatCompletionClient for GrokClient {
    async fn send_chat_completion(
        &self,
        messages: Vec<Message>,
        reasoning_effort: &str,
    ) -> Result<ChatCompletionResponse, LlmClientError> {
        self.validate_input(&messages, reasoning_effort)?;

        // Record the request without holding the write lock across the network
        // call; holding it would serialize all requests and block get_metrics().
        self.metrics.write().await.increment_request();

        let cache_key = self.generate_cache_key(&messages, reasoning_effort);
        if let Some(cached_response) = self.cache.get(&cache_key).await {
            info!("Cache hit for key: {}", cache_key);
            let mut metrics = self.metrics.write().await;
            metrics.increment_cache_hit();
            metrics.increment_success();
            return Ok(cached_response);
        }

        let payload = GrokChatCompletionRequest {
            messages,
            reasoning_effort: reasoning_effort.to_string(),
            model: self.config.model.clone(),
        };

        // Retry transient failures with exponential backoff for up to 60 seconds.
        let backoff = ExponentialBackoff {
            max_elapsed_time: Some(Duration::from_secs(60)),
            ..Default::default()
        };

        info!("Sending request to Grok API with model: {}", self.config.model);

        let response = retry(backoff, || async {
            debug!("Attempting API request to {}", self.config.endpoint);
            let result = self
                .client
                .post(&self.config.endpoint)
                .header("Authorization", format!("Bearer {}", self.config.api_key))
                .json(&payload)
                .send()
                .await;

            match result {
                Ok(resp) => {
                    if resp.status().is_success() {
                        Ok(resp)
                    } else {
                        // A 429 is surfaced as an error here; `?` hands it to
                        // the retry layer via backoff::Error's From impl.
                        self.handle_rate_limit(&resp).await?;
                        let status = resp.status();
                        let error_msg = resp.text().await.unwrap_or_default();
                        error!("API request failed with status {}: {}", status, error_msg);
                        // Other non-success statuses are not worth retrying.
                        Err(backoff::Error::Permanent(LlmClientError::ApiError(
                            format!("API request failed with status: {}", status),
                        )))
                    }
                }
                // Timeouts and connection failures are transient: retry them.
                Err(err) if err.is_timeout() || err.is_connect() => {
                    debug!("Retrying due to transient error: {}", err);
                    Err(backoff::Error::Transient {
                        err: LlmClientError::HttpError(err),
                        retry_after: None,
                    })
                }
                Err(err) => {
                    error!("Permanent HTTP error: {}", err);
                    Err(backoff::Error::Permanent(LlmClientError::HttpError(err)))
                }
            }
        })
        .await?;

        let chat_response = response.json::<ChatCompletionResponse>().await?;
        info!("Received successful response with ID: {}", chat_response.id);

        // A cache write failure is logged but never fails the request itself.
        if let Err(e) = self.cache.put(cache_key, chat_response.clone()).await {
            error!("Failed to cache response: {}", e);
        }

        self.metrics.write().await.increment_success();
        Ok(chat_response)
    }

    async fn get_metrics(&self) -> Metrics {
        self.metrics.read().await.clone()
    }

    async fn stream_chat_completion(
        &self,
        _messages: Vec<Message>,
        _reasoning_effort: &str,
    ) -> Result<(), LlmClientError> {
        Err(LlmClientError::ApiError("Streaming not yet supported".to_string()))
    }
}