use llm_optimizer_types::models::ModelConfig;
use serde::{Deserialize, Serialize};
use crate::errors::{DecisionError, Result};
/// Describes which single dimension of a `ModelConfig` an experiment varies,
/// together with the candidate values for that dimension.
///
/// The `rename_all = "snake_case"` serde attribute means variant names are
/// (de)serialized in snake case, e.g. `top_p`, `max_tokens`, `system_prompt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum VariantStrategy {
    /// Candidate sampling temperatures; one variant is produced per entry.
    Temperature(Vec<f64>),
    /// Candidate nucleus-sampling (`top_p`) values; one variant per entry.
    TopP(Vec<f64>),
    /// Candidate `max_tokens` limits; one variant per entry.
    MaxTokens(Vec<u32>),
    /// Candidate system prompts; one variant per entry.
    SystemPrompt(Vec<String>),
    /// Candidate model identifiers; one variant per entry.
    Model(Vec<String>),
    /// Fully specified configurations supplied by the caller.
    Custom(Vec<ModelConfig>),
}
/// Stateless helper that expands a base [`ModelConfig`] into experiment
/// variants and validates the sampling parameters of a configuration.
#[derive(Debug, Clone, Copy, Default)]
pub struct VariantGenerator;

impl VariantGenerator {
    /// Builds one configuration per value carried by `strategy`.
    ///
    /// For every strategy except [`VariantStrategy::Custom`], each variant is a
    /// clone of `base` with exactly one field overridden. `Custom` ignores
    /// `base` and returns a copy of the supplied configurations as-is; no
    /// emptiness or range check is applied to them.
    ///
    /// # Errors
    ///
    /// Returns [`DecisionError::InvalidConfig`] when the value list of a
    /// non-custom strategy is empty, or when a temperature / top-p value is
    /// NaN (NaN cannot be clamped into a valid range).
    pub fn generate(base: &ModelConfig, strategy: &VariantStrategy) -> Result<Vec<ModelConfig>> {
        match strategy {
            VariantStrategy::Temperature(temps) => {
                Self::generate_temperature_variants(base, temps)
            }
            VariantStrategy::TopP(top_ps) => Self::generate_top_p_variants(base, top_ps),
            VariantStrategy::MaxTokens(max_tokens) => {
                Self::generate_max_tokens_variants(base, max_tokens)
            }
            VariantStrategy::SystemPrompt(prompts) => {
                Self::generate_system_prompt_variants(base, prompts)
            }
            VariantStrategy::Model(models) => Self::generate_model_variants(base, models),
            VariantStrategy::Custom(configs) => Ok(configs.clone()),
        }
    }

    /// Shared expansion loop: rejects an empty `values` slice, then clones
    /// `base` once per value and lets `apply` override the relevant field.
    ///
    /// `label` is the human-readable dimension name used as the prefix of the
    /// "... variants list is empty" error message. The first error returned by
    /// `apply` aborts the expansion.
    fn expand_variants<T>(
        base: &ModelConfig,
        values: &[T],
        label: &str,
        apply: impl Fn(&mut ModelConfig, &T) -> Result<()>,
    ) -> Result<Vec<ModelConfig>> {
        if values.is_empty() {
            return Err(DecisionError::InvalidConfig(format!(
                "{} variants list is empty",
                label
            )));
        }
        values
            .iter()
            .map(|value| -> Result<ModelConfig> {
                let mut config = base.clone();
                apply(&mut config, value)?;
                Ok(config)
            })
            .collect()
    }

    /// Clamps `value` into `[0.0, 1.0]`.
    ///
    /// `f64::clamp` passes NaN through unchanged, which would yield a variant
    /// outside the documented range, so NaN is reported as an error instead.
    fn clamp_unit_interval(label: &str, value: f64) -> Result<f64> {
        if value.is_nan() {
            return Err(DecisionError::InvalidConfig(format!(
                "{} variant value is NaN",
                label
            )));
        }
        Ok(value.clamp(0.0, 1.0))
    }

    /// One variant per temperature, each clamped into `[0.0, 1.0]`.
    fn generate_temperature_variants(
        base: &ModelConfig,
        temps: &[f64],
    ) -> Result<Vec<ModelConfig>> {
        Self::expand_variants(base, temps, "Temperature", |config, &temp| {
            config.temperature = Self::clamp_unit_interval("Temperature", temp)?;
            Ok(())
        })
    }

    /// One variant per top-p value, each clamped into `[0.0, 1.0]`.
    fn generate_top_p_variants(base: &ModelConfig, top_ps: &[f64]) -> Result<Vec<ModelConfig>> {
        Self::expand_variants(base, top_ps, "Top-p", |config, &top_p| {
            config.top_p = Self::clamp_unit_interval("Top-p", top_p)?;
            Ok(())
        })
    }

    /// One variant per `max_tokens` limit; values are stored unmodified.
    fn generate_max_tokens_variants(
        base: &ModelConfig,
        max_tokens: &[u32],
    ) -> Result<Vec<ModelConfig>> {
        Self::expand_variants(base, max_tokens, "Max tokens", |config, &limit| {
            config.max_tokens = limit;
            Ok(())
        })
    }

    /// One variant per system prompt; each prompt replaces the base prompt.
    fn generate_system_prompt_variants(
        base: &ModelConfig,
        prompts: &[String],
    ) -> Result<Vec<ModelConfig>> {
        Self::expand_variants(base, prompts, "System prompt", |config, prompt| {
            config.system_prompt = Some(prompt.clone());
            Ok(())
        })
    }

    /// One variant per model identifier.
    fn generate_model_variants(base: &ModelConfig, models: &[String]) -> Result<Vec<ModelConfig>> {
        Self::expand_variants(base, models, "Model", |config, model| {
            config.model = model.clone();
            Ok(())
        })
    }

    /// Variants of `base` at the preset temperatures 0.0, 0.3, 0.7 and 1.0.
    pub fn standard_temperature_variants(base: &ModelConfig) -> Result<Vec<ModelConfig>> {
        Self::generate_temperature_variants(base, &[0.0, 0.3, 0.7, 1.0])
    }

    /// Variants of `base` at the preset top-p values 0.8, 0.9, 0.95 and 1.0.
    pub fn standard_top_p_variants(base: &ModelConfig) -> Result<Vec<ModelConfig>> {
        Self::generate_top_p_variants(base, &[0.8, 0.9, 0.95, 1.0])
    }

    /// Checks that the sampling parameters of `config` are within range.
    ///
    /// Checks run in this order and the first failure is returned:
    /// temperature in `[0, 1]`, top-p in `[0, 1]`, `max_tokens > 0`, then the
    /// optional presence and frequency penalties in `[-2, 2]` (absent
    /// penalties are accepted).
    ///
    /// # Errors
    ///
    /// Returns [`DecisionError::InvalidConfig`] describing the first
    /// out-of-range field. NaN is treated as out of range.
    pub fn validate_config(config: &ModelConfig) -> Result<()> {
        // Range `contains` is used instead of `x < lo || x > hi` because both
        // comparisons are false for NaN, which would let NaN slip through.
        if !(0.0..=1.0).contains(&config.temperature) {
            return Err(DecisionError::InvalidConfig(format!(
                "Temperature {} is out of range [0, 1]",
                config.temperature
            )));
        }
        if !(0.0..=1.0).contains(&config.top_p) {
            return Err(DecisionError::InvalidConfig(format!(
                "Top-p {} is out of range [0, 1]",
                config.top_p
            )));
        }
        if config.max_tokens == 0 {
            return Err(DecisionError::InvalidConfig(
                "Max tokens must be greater than 0".to_string(),
            ));
        }
        if let Some(presence) = config.presence_penalty {
            if !(-2.0..=2.0).contains(&presence) {
                return Err(DecisionError::InvalidConfig(format!(
                    "Presence penalty {} is out of range [-2, 2]",
                    presence
                )));
            }
        }
        if let Some(frequency) = config.frequency_penalty {
            if !(-2.0..=2.0).contains(&frequency) {
                return Err(DecisionError::InvalidConfig(format!(
                    "Frequency penalty {} is out of range [-2, 2]",
                    frequency
                )));
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration every test starts from.
    fn base_config() -> ModelConfig {
        ModelConfig {
            model: String::from("test-model"),
            temperature: 0.7,
            top_p: 0.9,
            top_k: None,
            max_tokens: 1024,
            presence_penalty: None,
            frequency_penalty: None,
            system_prompt: None,
            extra_params: Default::default(),
        }
    }

    // --- generation through `VariantGenerator::generate` -------------------

    #[test]
    fn test_generate_temperature_variants() {
        let requested = VariantStrategy::Temperature(vec![0.0, 0.5, 1.0]);
        let generated = VariantGenerator::generate(&base_config(), &requested).unwrap();
        let temperatures: Vec<f64> = generated.iter().map(|v| v.temperature).collect();
        assert_eq!(temperatures, vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_generate_top_p_variants() {
        let requested = VariantStrategy::TopP(vec![0.8, 0.9, 1.0]);
        let generated = VariantGenerator::generate(&base_config(), &requested).unwrap();
        let top_ps: Vec<f64> = generated.iter().map(|v| v.top_p).collect();
        assert_eq!(top_ps, vec![0.8, 0.9, 1.0]);
    }

    #[test]
    fn test_generate_max_tokens_variants() {
        let requested = VariantStrategy::MaxTokens(vec![512, 1024, 2048]);
        let generated = VariantGenerator::generate(&base_config(), &requested).unwrap();
        let limits: Vec<u32> = generated.iter().map(|v| v.max_tokens).collect();
        assert_eq!(limits, vec![512, 1024, 2048]);
    }

    #[test]
    fn test_generate_system_prompt_variants() {
        let prompts = vec![
            String::from("You are a helpful assistant."),
            String::from("You are a coding expert."),
        ];
        let requested = VariantStrategy::SystemPrompt(prompts.clone());
        let generated = VariantGenerator::generate(&base_config(), &requested).unwrap();

        let expected: Vec<Option<String>> = prompts.into_iter().map(Some).collect();
        let actual: Vec<Option<String>> =
            generated.iter().map(|v| v.system_prompt.clone()).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_generate_model_variants() {
        let models = vec![String::from("model-1"), String::from("model-2")];
        let requested = VariantStrategy::Model(models.clone());
        let generated = VariantGenerator::generate(&base_config(), &requested).unwrap();

        let actual: Vec<String> = generated.iter().map(|v| v.model.clone()).collect();
        assert_eq!(actual, models);
    }

    #[test]
    fn test_empty_variants_error() {
        let outcome =
            VariantGenerator::generate(&base_config(), &VariantStrategy::Temperature(Vec::new()));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_temperature_clamping() {
        // Out-of-range inputs must be pulled back to the nearest bound.
        let requested = VariantStrategy::Temperature(vec![-0.5, 0.5, 1.5]);
        let generated = VariantGenerator::generate(&base_config(), &requested).unwrap();
        let expected = [0.0, 0.5, 1.0];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(generated[index].temperature, *want);
        }
    }

    // --- preset helpers ----------------------------------------------------

    #[test]
    fn test_standard_temperature_variants() {
        let generated = VariantGenerator::standard_temperature_variants(&base_config()).unwrap();
        assert_eq!(generated.len(), 4);
        assert_eq!(generated.first().map(|v| v.temperature), Some(0.0));
        assert_eq!(generated.last().map(|v| v.temperature), Some(1.0));
    }

    #[test]
    fn test_standard_top_p_variants() {
        let generated = VariantGenerator::standard_top_p_variants(&base_config()).unwrap();
        assert_eq!(generated.len(), 4);
        assert_eq!(generated.first().map(|v| v.top_p), Some(0.8));
        assert_eq!(generated.last().map(|v| v.top_p), Some(1.0));
    }

    // --- validation --------------------------------------------------------

    #[test]
    fn test_validate_config_valid() {
        let outcome = VariantGenerator::validate_config(&base_config());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_config_invalid_temperature() {
        let too_hot = ModelConfig {
            temperature: 1.5,
            ..base_config()
        };
        assert!(VariantGenerator::validate_config(&too_hot).is_err());
    }

    #[test]
    fn test_validate_config_invalid_top_p() {
        let negative_top_p = ModelConfig {
            top_p: -0.1,
            ..base_config()
        };
        assert!(VariantGenerator::validate_config(&negative_top_p).is_err());
    }

    #[test]
    fn test_validate_config_zero_max_tokens() {
        let no_budget = ModelConfig {
            max_tokens: 0,
            ..base_config()
        };
        assert!(VariantGenerator::validate_config(&no_budget).is_err());
    }

    #[test]
    fn test_validate_config_invalid_penalties() {
        let bad_presence = ModelConfig {
            presence_penalty: Some(3.0),
            ..base_config()
        };
        assert!(VariantGenerator::validate_config(&bad_presence).is_err());

        let bad_frequency = ModelConfig {
            frequency_penalty: Some(-3.0),
            ..base_config()
        };
        assert!(VariantGenerator::validate_config(&bad_frequency).is_err());
    }
}