use async_trait::async_trait;
use tracing::info;
use super::base::{LlmBase, LlmResponse, Message, ResponseFormat, Tool};
use super::openai_compat::{OpenAICompatClient, OpenAICompatConfig};
use crate::config::GroqLlmConfig;
use crate::error::{NeomemxError, Result};
/// LLM provider for Groq.
///
/// A thin wrapper around [`OpenAICompatClient`]; every request made through
/// this type is delegated to the wrapped client.
pub struct GroqLlm {
    /// Client configured with the Groq settings passed to [`GroqLlm::new`].
    client: OpenAICompatClient,
}
impl GroqLlm {
    /// Builds a Groq-backed LLM from `config`.
    ///
    /// The key returned by `config.get_api_key()` and the remaining settings
    /// (base URL, model, temperature, max tokens, top-p) are forwarded unchanged
    /// to an [`OpenAICompatClient`] tagged with the provider name `"Groq"`.
    ///
    /// # Errors
    ///
    /// Returns [`NeomemxError::LlmError`] when `config.get_api_key()` yields
    /// `None`.
    pub fn new(config: GroqLlmConfig) -> Result<Self> {
        // Built lazily so the message is only allocated when the key is missing.
        let missing_key = || {
            NeomemxError::LlmError(
                "Groq API key not found. Set GROQ_API_KEY or provide in config.".to_string(),
            )
        };
        let api_key = config.get_api_key().ok_or_else(missing_key)?;

        info!("Creating Groq LLM with model: {}", config.model);

        let compat_config = OpenAICompatConfig {
            provider_name: "Groq",
            api_key,
            base_url: config.base_url,
            model: config.model,
            temperature: config.temperature,
            max_tokens: config.max_tokens,
            top_p: config.top_p,
        };

        Ok(Self {
            client: OpenAICompatClient::new(compat_config),
        })
    }
}
#[async_trait]
impl LlmBase for GroqLlm {
    /// Generates a response by delegating to the wrapped client.
    ///
    /// All four arguments are passed through, unmodified and in the same
    /// order, to [`OpenAICompatClient::chat_completion`]; its result is
    /// returned as-is.
    async fn generate_response(
        &self,
        messages: Vec<Message>,
        response_format: Option<ResponseFormat>,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<String>,
    ) -> Result<LlmResponse> {
        let pending = self
            .client
            .chat_completion(messages, response_format, tools, tool_choice);
        pending.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the expected model and base URL.
    #[test]
    fn test_default_config() {
        let defaults = GroqLlmConfig::default();

        assert_eq!(defaults.model, "llama-3.3-70b-versatile");
        assert_eq!(defaults.base_url, "https://api.groq.com/openai/v1");
    }

    /// Each preset constructor selects its corresponding model identifier.
    #[test]
    fn test_model_presets() {
        let llama = GroqLlmConfig::llama_3_3_70b();
        assert_eq!(llama.model, "llama-3.3-70b-versatile");

        let mixtral = GroqLlmConfig::mixtral_8x7b();
        assert_eq!(mixtral.model, "mixtral-8x7b-32768");
    }
}