neomemx 0.1.2

A high-performance memory library for AI agents with semantic search
//! Groq LLM implementation

use async_trait::async_trait;
use tracing::info;

use super::base::{LlmBase, LlmResponse, Message, ResponseFormat, Tool};
use super::openai_compat::{OpenAICompatClient, OpenAICompatConfig};
use crate::config::GroqLlmConfig;
use crate::error::{NeomemxError, Result};

/// Groq LLM client using their OpenAI-compatible API.
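///
/// # Example
///
/// A minimal construction sketch; the `use` paths below are assumptions about
/// the crate's public re-exports, and the default config is assumed to read
/// `GROQ_API_KEY` from the environment.
///
/// ```no_run
/// use neomemx::config::GroqLlmConfig; // assumed public path
/// use neomemx::llm::GroqLlm;          // assumed public path
///
/// // Build a client from the default config (picks up GROQ_API_KEY).
/// let llm = GroqLlm::new(GroqLlmConfig::default())
///     .expect("Groq API key available");
/// ```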
pub struct GroqLlm {
    client: OpenAICompatClient,
}

impl GroqLlm {
    /// Creates a new Groq client with the given configuration.
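    ///
    /// # Errors
    ///
    /// Returns an error if no API key is found in the config or the
    /// `GROQ_API_KEY` environment variable.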
    pub fn new(config: GroqLlmConfig) -> Result<Self> {
        let api_key = config.get_api_key().ok_or_else(|| {
            NeomemxError::LlmError(
                "Groq API key not found. Set GROQ_API_KEY or provide in config.".to_string(),
            )
        })?;

        info!("Creating Groq LLM with model: {}", config.model);

        Ok(Self {
            client: OpenAICompatClient::new(OpenAICompatConfig {
                base_url: config.base_url,
                api_key,
                model: config.model,
                temperature: config.temperature,
                max_tokens: config.max_tokens,
                top_p: config.top_p,
                provider_name: "Groq",
            }),
        })
    }
}

#[async_trait]
impl LlmBase for GroqLlm {
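    /// Generates a chat completion by delegating to the shared
    /// OpenAI-compatible client; Groq exposes an OpenAI-style API, so the
    /// request is forwarded without translation.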
    async fn generate_response(
        &self,
        messages: Vec<Message>,
        response_format: Option<ResponseFormat>,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<String>,
    ) -> Result<LlmResponse> {
        self.client
            .chat_completion(messages, response_format, tools, tool_choice)
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = GroqLlmConfig::default();
        assert_eq!(config.model, "llama-3.3-70b-versatile");
        assert_eq!(config.base_url, "https://api.groq.com/openai/v1");
    }

    #[test]
    fn test_model_presets() {
        assert_eq!(
            GroqLlmConfig::llama_3_3_70b().model,
            "llama-3.3-70b-versatile"
        );
        assert_eq!(GroqLlmConfig::mixtral_8x7b().model, "mixtral-8x7b-32768");
    }
}