use crate::error::LlmError;
use crate::traits::ModelListingCapability;
use crate::types::{HttpConfig, ModelInfo};
use async_trait::async_trait;
#[derive(Debug, Clone, PartialEq)]
pub enum ModelType {
Chat,
Audio,
Embedding,
Image,
}
use super::types::*;
use super::utils::*;
#[derive(Clone)]
pub struct GroqModels {
pub api_key: String,
pub base_url: String,
pub http_client: reqwest::Client,
pub http_config: HttpConfig,
}
impl GroqModels {
pub const fn new(
api_key: String,
base_url: String,
http_client: reqwest::Client,
http_config: HttpConfig,
) -> Self {
Self {
api_key,
base_url,
http_client,
http_config,
}
}
#[allow(dead_code)]
fn convert_groq_model(&self, groq_model: GroqModel) -> ModelInfo {
let model_id = groq_model.id.clone();
ModelInfo {
id: model_id.clone(),
name: Some(model_id),
description: None, owned_by: groq_model.owned_by.clone(),
created: Some(groq_model.created),
capabilities: self.get_model_capabilities(&groq_model),
context_window: Some(groq_model.context_window),
max_output_tokens: groq_model.max_completion_tokens,
input_cost_per_token: None, output_cost_per_token: None,
}
}
#[allow(dead_code)]
fn get_model_capabilities(&self, model: &GroqModel) -> Vec<String> {
let mut capabilities = Vec::new();
if model.id.contains("whisper") {
capabilities.push("transcription".to_string());
if !model.id.contains("-en") {
capabilities.push("translation".to_string());
}
} else if model.id.contains("tts") || model.id.contains("playai") {
capabilities.push("speech_synthesis".to_string());
} else {
capabilities.push("chat".to_string());
capabilities.push("streaming".to_string());
if !model.id.contains("guard") {
capabilities.push("function_calling".to_string());
}
if model.id.contains("qwen") {
capabilities.push("reasoning".to_string());
}
}
capabilities
}
}
#[allow(dead_code)]
impl GroqModels {
async fn list_models_internal(&self) -> Result<Vec<ModelInfo>, LlmError> {
let url = format!("{}/models", self.base_url);
let headers = build_headers(&self.api_key, &self.http_config.headers)?;
let response = self.http_client.get(&url).headers(headers).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
let error_message = extract_error_message(&error_text);
return Err(LlmError::ApiError {
code: status.as_u16(),
message: format!("Groq list models error: {error_message}"),
details: serde_json::from_str(&error_text).ok(),
});
}
let groq_response: GroqModelsResponse = response.json().await?;
let models = groq_response
.data
.into_iter()
.map(|m| self.convert_groq_model(m))
.collect();
Ok(models)
}
async fn get_model_internal(&self, model_id: String) -> Result<ModelInfo, LlmError> {
let url = format!("{}/models/{}", self.base_url, model_id);
let headers = build_headers(&self.api_key, &self.http_config.headers)?;
let response = self.http_client.get(&url).headers(headers).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
let error_message = extract_error_message(&error_text);
return Err(LlmError::ApiError {
code: status.as_u16(),
message: format!("Groq get model error: {error_message}"),
details: serde_json::from_str(&error_text).ok(),
});
}
let groq_model: GroqModel = response.json().await?;
Ok(self.convert_groq_model(groq_model))
}
fn supports_model_listing(&self) -> bool {
true
}
}
#[async_trait]
impl ModelListingCapability for GroqModels {
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
self.list_models_internal().await
}
async fn get_model(&self, model_id: String) -> Result<ModelInfo, LlmError> {
self.get_model_internal(model_id).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::HttpConfig;
fn create_test_models() -> GroqModels {
GroqModels::new(
"test-api-key".to_string(),
"https://api.groq.com/openai/v1".to_string(),
reqwest::Client::new(),
HttpConfig::default(),
)
}
#[test]
fn test_convert_groq_model_chat() {
let models = create_test_models();
let groq_model = GroqModel {
id: "llama-3.3-70b-versatile".to_string(),
object: "model".to_string(),
created: 1640995200,
owned_by: "Meta".to_string(),
active: true,
context_window: 32768,
public_apps: None,
max_completion_tokens: Some(8192),
};
let model_info = models.convert_groq_model(groq_model);
assert_eq!(model_info.id, "llama-3.3-70b-versatile");
assert_eq!(model_info.context_window, Some(32768));
assert_eq!(model_info.max_output_tokens, Some(8192));
assert!(model_info.capabilities.contains(&"chat".to_string()));
assert!(model_info.capabilities.contains(&"streaming".to_string()));
assert!(
model_info
.capabilities
.contains(&"function_calling".to_string())
);
}
#[test]
fn test_convert_groq_model_whisper() {
let models = create_test_models();
let groq_model = GroqModel {
id: "whisper-large-v3".to_string(),
object: "model".to_string(),
created: 1640995200,
owned_by: "OpenAI".to_string(),
active: true,
context_window: 448,
public_apps: None,
max_completion_tokens: None,
};
let model_info = models.convert_groq_model(groq_model);
assert_eq!(model_info.id, "whisper-large-v3");
assert!(
model_info
.capabilities
.contains(&"transcription".to_string())
);
assert!(model_info.capabilities.contains(&"translation".to_string()));
}
#[test]
fn test_convert_groq_model_qwen() {
let models = create_test_models();
let groq_model = GroqModel {
id: "qwen3-8b-instruct".to_string(),
object: "model".to_string(),
created: 1640995200,
owned_by: "Alibaba".to_string(),
active: true,
context_window: 8192,
public_apps: None,
max_completion_tokens: Some(4096),
};
let model_info = models.convert_groq_model(groq_model);
assert_eq!(model_info.id, "qwen3-8b-instruct");
assert!(model_info.capabilities.contains(&"reasoning".to_string()));
}
#[test]
fn test_capability_support() {
let models = create_test_models();
assert!(models.supports_model_listing());
}
}