//! ai_client 0.1.0
//!
//! A Rust crate for interacting with AI language model APIs.
use async_trait::async_trait;
use backoff::{ExponentialBackoff, future::retry};
use log::{debug, error, info};
use reqwest::{Client, ClientBuilder, Response};
use serde::Serialize;
use std::env;
use std::time::Duration;
use tokio::sync::RwLock;
use std::sync::Arc;
use md5;
use super::super::entities::{Message, ChatCompletionResponse};
use super::super::metrics::Metrics;
use super::super::cache::ResponseCache;
use super::super::error::LlmClientError;
use super::ChatCompletionClient;

/// Configuration for the Grok API client.
///
/// `Debug` is implemented by hand (instead of derived) so the API key is never
/// written to logs, panic messages, or other `{:?}` output.
struct GrokApiConfig {
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// Full URL of the chat-completions endpoint.
    endpoint: String,
    /// Model identifier sent with each request.
    model: String,
}

impl std::fmt::Debug for GrokApiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a key is configured, never its value.
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("GrokApiConfig")
            .field("api_key", &api_key)
            .field("endpoint", &self.endpoint)
            .field("model", &self.model)
            .finish()
    }
}

/// Request structure for Grok chat completions.
///
/// Serialized with serde and sent as the JSON body of each POST made by
/// `send_chat_completion`.
#[derive(Debug, Serialize)]
struct GrokChatCompletionRequest {
    /// Conversation history; checked by `validate_input` before being sent.
    messages: Vec<Message>,
    /// One of `"low"`, `"medium"` or `"high"` (enforced by `validate_input`).
    reasoning_effort: String,
    /// Model identifier, copied from the client's configured model.
    model: String,
}

/// Grok client implementing the ChatCompletionClient trait.
///
/// Constructed from environment variables by `GrokClient::new`.
pub struct GrokClient {
    /// Shared HTTP client (30 s request timeout, 5 s connect timeout; set in `new`).
    client: Client,
    /// API key, endpoint and model used for every request.
    config: GrokApiConfig,
    /// Cache of completed responses, keyed by `generate_cache_key`.
    cache: ResponseCache,
    /// Request / cache-hit / success counters; `get_metrics` returns a snapshot.
    metrics: Arc<RwLock<Metrics>>,
}

impl GrokClient {
    /// Creates a new `GrokClient` configured from environment variables.
    ///
    /// A `.env` file is loaded first if present; failure to load it is ignored.
    ///
    /// Variables read:
    /// - `GROK_API_KEY` (required)
    /// - `GROK_API_ENDPOINT` (default `https://api.x.ai/v1/chat/completions`)
    /// - `GROK_MODEL` (default `grok-3-mini-fast-latest`)
    /// - `GROK_CACHE_SIZE` (default `100`), passed to `ResponseCache::new`
    ///
    /// The HTTP client uses a 30 s request timeout and a 5 s connect timeout.
    ///
    /// # Errors
    ///
    /// - `LlmClientError::EnvVarMissing` if `GROK_API_KEY` is unset or not valid Unicode.
    /// - `LlmClientError::ValidationError` if `GROK_CACHE_SIZE` does not parse as `usize`.
    /// - The converted `reqwest` error if the HTTP client fails to build.
    pub fn new() -> Result<Self, LlmClientError> {
        let _ = dotenv::dotenv();

        let api_key = env::var("GROK_API_KEY")
            .map_err(|_| LlmClientError::EnvVarMissing("GROK_API_KEY".to_string()))?;
        let endpoint = env::var("GROK_API_ENDPOINT")
            .unwrap_or_else(|_| "https://api.x.ai/v1/chat/completions".to_string());
        let model = env::var("GROK_MODEL")
            .unwrap_or_else(|_| "grok-3-mini-fast-latest".to_string());
        let cache_size = env::var("GROK_CACHE_SIZE")
            .unwrap_or_else(|_| "100".to_string())
            .parse::<usize>()
            .map_err(|_| LlmClientError::ValidationError("Invalid GROK_CACHE_SIZE".to_string()))?;

        let client = ClientBuilder::new()
            .timeout(Duration::from_secs(30))
            .connect_timeout(Duration::from_secs(5))
            .build()?;

        Ok(GrokClient {
            client,
            config: GrokApiConfig {
                api_key,
                endpoint,
                model,
            },
            cache: ResponseCache::new(cache_size),
            metrics: Arc::new(RwLock::new(Metrics::default())),
        })
    }

    /// Checks a request before it is sent.
    ///
    /// # Errors
    ///
    /// Returns `LlmClientError::ValidationError` if `messages` is empty, if any
    /// message has an empty role or content, if a role is not one of
    /// `system`/`user`/`assistant`, or if `reasoning_effort` is not one of
    /// `low`/`medium`/`high`.
    fn validate_input(&self, messages: &[Message], reasoning_effort: &str) -> Result<(), LlmClientError> {
        if messages.is_empty() {
            return Err(LlmClientError::ValidationError("Messages cannot be empty".to_string()));
        }
        for msg in messages {
            if msg.role.is_empty() || msg.content.is_empty() {
                return Err(LlmClientError::ValidationError(
                    "Message role and content cannot be empty".to_string(),
                ));
            }
            if !["system", "user", "assistant"].contains(&msg.role.as_str()) {
                return Err(LlmClientError::ValidationError(
                    format!("Invalid role: {}", msg.role),
                ));
            }
        }
        if !["low", "medium", "high"].contains(&reasoning_effort) {
            return Err(LlmClientError::ValidationError(
                format!("Invalid reasoning_effort: {}", reasoning_effort),
            ));
        }
        Ok(())
    }

    /// Builds the response-cache key for a request.
    ///
    /// Returns the lowercase-hex MD5 digest of every message's role and content,
    /// followed by `reasoning_effort` and the configured model.
    ///
    /// Each field is length-prefixed (`<len>:<field>`) before hashing. With plain
    /// concatenation, different conversations could produce the same key: for
    /// example `[user:"a", assistant:"b"]` and `[user:"aassistantb"]` would both
    /// yield `"useraassistantb"` and share one cached response.
    ///
    /// NOTE(review): only `role` and `content` are hashed. If `Message` carries
    /// other fields that influence the completion, they should be added here.
    fn generate_cache_key(&self, messages: &[Message], reasoning_effort: &str) -> String {
        /// Appends `field` to `key` in an unambiguous, length-prefixed form.
        fn push_field(key: &mut String, field: &str) {
            key.push_str(&field.len().to_string());
            key.push(':');
            key.push_str(field);
        }

        let mut key = String::new();
        for msg in messages {
            push_field(&mut key, &msg.role);
            push_field(&mut key, &msg.content);
        }
        push_field(&mut key, reasoning_effort);
        push_field(&mut key, &self.config.model);
        format!("{:x}", md5::compute(key))
    }

    /// Detects an HTTP 429 (rate limited) response.
    ///
    /// Returns `Ok(())` for any other status. For a 429, it logs at error level and
    /// returns `LlmClientError::RateLimitExceeded`. The message includes the
    /// `Retry-After` header value in seconds, or 1 if the header is missing or
    /// not an integer. It does not sleep or retry itself.
    async fn handle_rate_limit(&self, response: &Response) -> Result<(), LlmClientError> {
        if response.status().as_u16() == 429 {
            let retry_after = response
                .headers()
                .get("Retry-After")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u64>().ok())
                .unwrap_or(1);
            let error_msg = format!("Rate limit exceeded, retry after {} seconds", retry_after);
            error!("{}", error_msg);
            return Err(LlmClientError::RateLimitExceeded(error_msg));
        }
        Ok(())
    }
}

#[async_trait]
impl ChatCompletionClient for GrokClient {
    async fn send_chat_completion(
        &self,
        messages: Vec<Message>,
        reasoning_effort: &str,
    ) -> Result<ChatCompletionResponse, LlmClientError> {
        self.validate_input(&messages, reasoning_effort)?;
        let mut metrics = self.metrics.write().await;
        metrics.increment_request();

        let cache_key = self.generate_cache_key(&messages, reasoning_effort);
        if let Some(cached_response) = self.cache.get(&cache_key).await {
            metrics.increment_cache_hit();
            info!("Cache hit for key: {}", cache_key);
            metrics.increment_success();
            return Ok(cached_response);
        }

        let payload = GrokChatCompletionRequest {
            messages,
            reasoning_effort: reasoning_effort.to_string(),
            model: self.config.model.clone(),
        };

        let backoff = ExponentialBackoff {
            max_elapsed_time: Some(Duration::from_secs(60)),
            ..Default::default()
        };

        info!("Sending request to Grok API with model: {}", self.config.model);

        let response = retry(backoff, || async {
            debug!("Attempting API request to {}", self.config.endpoint);
            let result = self
                .client
                .post(&self.config.endpoint)
                .header("Authorization", format!("Bearer {}", self.config.api_key))
                .json(&payload)
                .send()
                .await;

            match result {
                Ok(resp) => {
                    if resp.status().is_success() {
                        Ok(resp)
                    } else {
                        self.handle_rate_limit(&resp).await?;
                        let status = resp.status();
                        let error_msg = resp.text().await.unwrap_or_default();
                        error!("API request failed with status {}: {}", status, error_msg);
                        Err(backoff::Error::Permanent(LlmClientError::ApiError(
                            format!("API request failed with status: {}", status),
                        )))
                    }
                }
                Err(err) if err.is_timeout() || err.is_connect() => {
                    debug!("Retrying due to transient error: {}", err);
                    Err(backoff::Error::Transient {
                        err: LlmClientError::HttpError(err),
                        retry_after: None,
                    })
                }
                Err(err) => {
                    error!("Permanent HTTP error: {}", err);
                    Err(backoff::Error::Permanent(LlmClientError::HttpError(err)))
                }
            }
        })
        .await?;

        let chat_response = response.json::<ChatCompletionResponse>().await?;
        info!("Received successful response with ID: {}", chat_response.id);

        if let Err(e) = self.cache.put(cache_key, chat_response.clone()).await {
            error!("Failed to cache response: {}", e);
        }

        metrics.increment_success();
        Ok(chat_response)
    }

    async fn get_metrics(&self) -> Metrics {
        let metrics = self.metrics.read().await;
        metrics.clone()
    }

    async fn stream_chat_completion(
        &self,
        _messages: Vec<Message>,
        _reasoning_effort: &str,
    ) -> Result<(), LlmClientError> {
        Err(LlmClientError::ApiError("Streaming not yet supported".to_string()))
    }
}