use serde::{Deserialize, Serialize};
/// Configuration for a Groq-hosted LLM reached through Groq's OpenAI-compatible base URL.
///
/// Every field except `api_key` has a `#[serde(default = ...)]` fallback, so it may be
/// omitted from the deserialized input. `Debug` is implemented by hand (below) rather
/// than derived so that the API key is never printed in logs or panic messages.
#[derive(Clone, Serialize, Deserialize)]
pub struct GroqLlmConfig {
    /// Model identifier; defaults to `llama-3.3-70b-versatile`.
    #[serde(default = "default_model")]
    pub model: String,
    /// Explicit API key. When `None`, `get_api_key` falls back to the
    /// `GROQ_API_KEY` environment variable.
    pub api_key: Option<String>,
    /// Temperature setting; defaults to `0.7`.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Top-p setting; defaults to `1.0`.
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    /// Maximum token count setting; defaults to `2048`.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Base URL of the API; defaults to `https://api.groq.com/openai/v1`.
    #[serde(default = "default_base_url")]
    pub base_url: String,
}

impl std::fmt::Debug for GroqLlmConfig {
    /// Formats every field except the API key, which is shown only as present/absent.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroqLlmConfig")
            .field("model", &self.model)
            // Reveal only whether a key is set, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("temperature", &self.temperature)
            .field("top_p", &self.top_p)
            .field("max_tokens", &self.max_tokens)
            .field("base_url", &self.base_url)
            .finish()
    }
}
/// Default value for `GroqLlmConfig::model`, used by serde and by `Default`.
fn default_model() -> String {
    const MODEL: &str = "llama-3.3-70b-versatile";
    String::from(MODEL)
}
/// Default value for `GroqLlmConfig::temperature`, used by serde and by `Default`.
fn default_temperature() -> f32 {
    const TEMPERATURE: f32 = 0.7;
    TEMPERATURE
}
/// Default value for `GroqLlmConfig::top_p`, used by serde and by `Default`.
fn default_top_p() -> f32 {
    1_f32
}
/// Default value for `GroqLlmConfig::max_tokens`, used by serde and by `Default`.
fn default_max_tokens() -> u32 {
    const MAX_TOKENS: u32 = 2_048;
    MAX_TOKENS
}
/// Default value for `GroqLlmConfig::base_url`, used by serde and by `Default`.
fn default_base_url() -> String {
    String::from("https://api.groq.com/openai/v1")
}
impl Default for GroqLlmConfig {
fn default() -> Self {
Self {
model: default_model(),
api_key: None,
temperature: default_temperature(),
top_p: default_top_p(),
max_tokens: default_max_tokens(),
base_url: default_base_url(),
}
}
}
impl GroqLlmConfig {
    /// Preset for the `llama-3.3-70b-versatile` model; every other field takes its
    /// `Default` value.
    #[must_use]
    pub fn llama_3_3_70b() -> Self {
        Self {
            model: "llama-3.3-70b-versatile".to_string(),
            ..Default::default()
        }
    }

    /// Preset for the `mixtral-8x7b-32768` model; every other field takes its
    /// `Default` value.
    ///
    /// NOTE(review): confirm this model id is still served by Groq; it may have been
    /// retired, in which case requests using this preset would be rejected.
    #[must_use]
    pub fn mixtral_8x7b() -> Self {
        Self {
            model: "mixtral-8x7b-32768".to_string(),
            ..Default::default()
        }
    }

    /// Resolves the API key to use.
    ///
    /// Returns the explicitly configured `api_key` when it is set and not blank.
    /// Otherwise returns the `GROQ_API_KEY` environment variable when it is set,
    /// valid Unicode, and not blank. Returns `None` when neither source yields a key.
    /// A blank key is treated as absent, so an empty value does not mask the
    /// environment fallback or get sent as a credential.
    #[must_use]
    pub fn get_api_key(&self) -> Option<String> {
        self.api_key
            .as_deref()
            .filter(|key| !key.trim().is_empty())
            .map(str::to_owned)
            .or_else(|| {
                std::env::var("GROQ_API_KEY")
                    .ok()
                    .filter(|key| !key.trim().is_empty())
            })
    }
}