use super::LlmProvider;
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde_json::json;
/// LLM provider backed by Groq's OpenAI-compatible chat-completions API.
pub struct GroqProvider {
// HTTP client reused across requests (enables connection pooling).
client: reqwest::Client,
// Bearer token sent in the Authorization header of every request.
api_key: String,
// Model identifier, e.g. "llama-3.3-70b-versatile".
model: String,
}
/// Reports whether `model` is one of Groq's hosted OpenAI gpt-oss models
/// (e.g. "openai/gpt-oss-20b"), which need extra steering to emit valid JSON.
fn is_gpt_oss_model(model: &str) -> bool {
    model.strip_prefix("openai/gpt-oss-").is_some()
}
impl GroqProvider {
    /// Builds a provider with the given API key, falling back to
    /// "llama-3.3-70b-versatile" when no model is supplied.
    pub fn new(api_key: String, model: Option<String>) -> Result<Self> {
        let model = model.unwrap_or_else(|| "llama-3.3-70b-versatile".to_string());
        Ok(Self {
            client: reqwest::Client::new(),
            api_key,
            model,
        })
    }
}
#[async_trait]
impl LlmProvider for GroqProvider {
    /// Sends `prompt` to Groq's chat-completions endpoint and returns the
    /// first choice's message content.
    ///
    /// When `json_mode` is set, the request asks for a `json_object` response
    /// format; for gpt-oss models a system message is additionally prepended,
    /// since those models need explicit steering to emit valid JSON.
    ///
    /// # Errors
    /// Fails when the HTTP request cannot be sent (timeout, connect, or
    /// request-build errors are logged with a reason), when the API returns a
    /// non-success status (429/5xx/401 get tailored messages), when the body
    /// is not valid JSON, or when no message content is present.
    async fn complete(&self, prompt: &str, json_mode: bool) -> Result<String> {
        // Hard cap on completion length for every request.
        const MAX_TOKENS: u32 = 4000;

        let mut messages = Vec::new();
        if json_mode && is_gpt_oss_model(&self.model) {
            messages.push(json!({
                "role": "system",
                "content": "You are a JSON generation assistant. You MUST ALWAYS return valid JSON that matches the schema provided in the user prompt. Never return free-form text. If you cannot answer the question, return a minimal valid JSON object that conforms to the schema. This is critical - only valid JSON is acceptable."
            }));
        }
        messages.push(json!({
            "role": "user",
            "content": prompt
        }));

        let mut request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": 0.1,
            "max_tokens": MAX_TOKENS,
        });
        if json_mode {
            request_body["response_format"] = json!({
                "type": "json_object"
            });
        }

        let response = self
            .client
            .post("https://api.groq.com/openai/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .timeout(std::time::Duration::from_secs(30))
            .send()
            .await
            .map_err(|e| {
                // Log a categorized reason before surfacing a single error.
                log::error!("Groq API request failed: {}", e);
                if e.is_timeout() {
                    log::error!(" Reason: Request timeout (>30s)");
                } else if e.is_connect() {
                    log::error!(" Reason: Connection failed");
                } else if e.is_request() {
                    log::error!(" Reason: Invalid request");
                }
                anyhow::anyhow!("Failed to send request to Groq API: {}", e)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            // Body text is best-effort; the status alone is still actionable.
            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
            let error_msg = match status.as_u16() {
                429 => {
                    log::warn!("Groq rate limit exceeded: {}", error_text);
                    "Rate limit exceeded (try again in a few seconds)".to_string()
                }
                502 | 503 | 504 => {
                    log::warn!("Groq service unavailable ({}): {}", status, error_text);
                    format!("Groq service temporarily unavailable ({})", status)
                }
                401 => {
                    log::error!("Groq authentication failed: {}", error_text);
                    "Authentication failed - check API key".to_string()
                }
                _ => {
                    log::error!("Groq API error ({}): {}", status, error_text);
                    format!("API error ({}): {}", status, error_text)
                }
            };
            anyhow::bail!("{}", error_msg);
        }

        let data: serde_json::Value = response
            .json()
            .await
            .context("Failed to parse Groq response as JSON")?;
        // Value indexing yields Null (not a panic) on missing paths, so a
        // malformed payload falls through to the context error below.
        let content = data["choices"][0]["message"]["content"]
            .as_str()
            .context("No content in Groq response")?;
        Ok(content.to_string())
    }

    /// Provider identifier used for registry lookup and logging.
    fn name(&self) -> &str {
        "groq"
    }

    /// Model used when the caller does not specify one; must stay in sync
    /// with the fallback in [`GroqProvider::new`].
    fn default_model(&self) -> &str {
        "llama-3.3-70b-versatile"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_with_default_model() {
        let provider = GroqProvider::new("test-key".to_string(), None).unwrap();
        assert_eq!(provider.name(), "groq");
        assert_eq!(provider.model, "llama-3.3-70b-versatile");
    }

    #[test]
    fn test_new_with_custom_model() {
        let provider = GroqProvider::new(
            "test-key".to_string(),
            Some("mixtral-8x7b-32768".to_string())
        ).unwrap();
        assert_eq!(provider.model, "mixtral-8x7b-32768");
    }

    // Guards the duplicated fallback string: new(None) and default_model()
    // must always agree.
    #[test]
    fn test_default_model_matches_new_fallback() {
        let provider = GroqProvider::new("test-key".to_string(), None).unwrap();
        assert_eq!(provider.model, provider.default_model());
    }

    #[test]
    fn test_is_gpt_oss_model() {
        assert!(is_gpt_oss_model("openai/gpt-oss-20b"));
        assert!(!is_gpt_oss_model("llama-3.3-70b-versatile"));
        assert!(!is_gpt_oss_model("gpt-oss-20b"));
    }
}