1use llm_optimizer_types::models::ModelConfig;
7use serde::{Deserialize, Serialize};
8
9use crate::errors::{DecisionError, Result};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum VariantStrategy {
15 Temperature(Vec<f64>),
17 TopP(Vec<f64>),
19 MaxTokens(Vec<u32>),
21 SystemPrompt(Vec<String>),
23 Model(Vec<String>),
25 Custom(Vec<ModelConfig>),
27}
28
/// Stateless namespace for building and validating [`ModelConfig`] variants.
///
/// All functionality is exposed as associated functions; the type carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct VariantGenerator;
31
32impl VariantGenerator {
33 pub fn generate(base: &ModelConfig, strategy: &VariantStrategy) -> Result<Vec<ModelConfig>> {
35 match strategy {
36 VariantStrategy::Temperature(temps) => {
37 Self::generate_temperature_variants(base, temps)
38 }
39 VariantStrategy::TopP(top_ps) => {
40 Self::generate_top_p_variants(base, top_ps)
41 }
42 VariantStrategy::MaxTokens(max_tokens) => {
43 Self::generate_max_tokens_variants(base, max_tokens)
44 }
45 VariantStrategy::SystemPrompt(prompts) => {
46 Self::generate_system_prompt_variants(base, prompts)
47 }
48 VariantStrategy::Model(models) => {
49 Self::generate_model_variants(base, models)
50 }
51 VariantStrategy::Custom(configs) => {
52 Ok(configs.clone())
53 }
54 }
55 }
56
57 fn generate_temperature_variants(base: &ModelConfig, temps: &[f64]) -> Result<Vec<ModelConfig>> {
59 if temps.is_empty() {
60 return Err(DecisionError::InvalidConfig(
61 "Temperature variants list is empty".to_string()
62 ));
63 }
64
65 let variants: Vec<ModelConfig> = temps.iter().map(|&temp| {
66 let mut config = base.clone();
67 config.temperature = temp.clamp(0.0, 1.0);
68 config
69 }).collect();
70
71 Ok(variants)
72 }
73
74 fn generate_top_p_variants(base: &ModelConfig, top_ps: &[f64]) -> Result<Vec<ModelConfig>> {
76 if top_ps.is_empty() {
77 return Err(DecisionError::InvalidConfig(
78 "Top-p variants list is empty".to_string()
79 ));
80 }
81
82 let variants: Vec<ModelConfig> = top_ps.iter().map(|&top_p| {
83 let mut config = base.clone();
84 config.top_p = top_p.clamp(0.0, 1.0);
85 config
86 }).collect();
87
88 Ok(variants)
89 }
90
91 fn generate_max_tokens_variants(base: &ModelConfig, max_tokens: &[u32]) -> Result<Vec<ModelConfig>> {
93 if max_tokens.is_empty() {
94 return Err(DecisionError::InvalidConfig(
95 "Max tokens variants list is empty".to_string()
96 ));
97 }
98
99 let variants: Vec<ModelConfig> = max_tokens.iter().map(|&max_tok| {
100 let mut config = base.clone();
101 config.max_tokens = max_tok;
102 config
103 }).collect();
104
105 Ok(variants)
106 }
107
108 fn generate_system_prompt_variants(base: &ModelConfig, prompts: &[String]) -> Result<Vec<ModelConfig>> {
110 if prompts.is_empty() {
111 return Err(DecisionError::InvalidConfig(
112 "System prompt variants list is empty".to_string()
113 ));
114 }
115
116 let variants: Vec<ModelConfig> = prompts.iter().map(|prompt| {
117 let mut config = base.clone();
118 config.system_prompt = Some(prompt.clone());
119 config
120 }).collect();
121
122 Ok(variants)
123 }
124
125 fn generate_model_variants(base: &ModelConfig, models: &[String]) -> Result<Vec<ModelConfig>> {
127 if models.is_empty() {
128 return Err(DecisionError::InvalidConfig(
129 "Model variants list is empty".to_string()
130 ));
131 }
132
133 let variants: Vec<ModelConfig> = models.iter().map(|model| {
134 let mut config = base.clone();
135 config.model = model.clone();
136 config
137 }).collect();
138
139 Ok(variants)
140 }
141
142 pub fn standard_temperature_variants(base: &ModelConfig) -> Result<Vec<ModelConfig>> {
144 Self::generate_temperature_variants(base, &[0.0, 0.3, 0.7, 1.0])
145 }
146
147 pub fn standard_top_p_variants(base: &ModelConfig) -> Result<Vec<ModelConfig>> {
149 Self::generate_top_p_variants(base, &[0.8, 0.9, 0.95, 1.0])
150 }
151
152 pub fn validate_config(config: &ModelConfig) -> Result<()> {
154 if config.temperature < 0.0 || config.temperature > 1.0 {
155 return Err(DecisionError::InvalidConfig(
156 format!("Temperature {} is out of range [0, 1]", config.temperature)
157 ));
158 }
159
160 if config.top_p < 0.0 || config.top_p > 1.0 {
161 return Err(DecisionError::InvalidConfig(
162 format!("Top-p {} is out of range [0, 1]", config.top_p)
163 ));
164 }
165
166 if config.max_tokens == 0 {
167 return Err(DecisionError::InvalidConfig(
168 "Max tokens must be greater than 0".to_string()
169 ));
170 }
171
172 if let Some(presence) = config.presence_penalty {
173 if presence < -2.0 || presence > 2.0 {
174 return Err(DecisionError::InvalidConfig(
175 format!("Presence penalty {} is out of range [-2, 2]", presence)
176 ));
177 }
178 }
179
180 if let Some(frequency) = config.frequency_penalty {
181 if frequency < -2.0 || frequency > 2.0 {
182 return Err(DecisionError::InvalidConfig(
183 format!("Frequency penalty {} is out of range [-2, 2]", frequency)
184 ));
185 }
186 }
187
188 Ok(())
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration every test starts from.
    fn base_config() -> ModelConfig {
        ModelConfig {
            model: "test-model".to_string(),
            temperature: 0.7,
            top_p: 0.9,
            top_k: None,
            max_tokens: 1024,
            presence_penalty: None,
            frequency_penalty: None,
            system_prompt: None,
            extra_params: Default::default(),
        }
    }

    /// Expands `strategy` against the baseline, failing the test if generation errors.
    fn expand(strategy: VariantStrategy) -> Vec<ModelConfig> {
        VariantGenerator::generate(&base_config(), &strategy).unwrap()
    }

    /// Reports whether validation rejects the baseline after `tweak` has been applied.
    fn rejected_after(tweak: impl FnOnce(&mut ModelConfig)) -> bool {
        let mut candidate = base_config();
        tweak(&mut candidate);
        VariantGenerator::validate_config(&candidate).is_err()
    }

    #[test]
    fn test_generate_temperature_variants() {
        let temperatures: Vec<f64> = expand(VariantStrategy::Temperature(vec![0.0, 0.5, 1.0]))
            .iter()
            .map(|variant| variant.temperature)
            .collect();

        assert_eq!(temperatures, vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_generate_top_p_variants() {
        let top_ps: Vec<f64> = expand(VariantStrategy::TopP(vec![0.8, 0.9, 1.0]))
            .iter()
            .map(|variant| variant.top_p)
            .collect();

        assert_eq!(top_ps, vec![0.8, 0.9, 1.0]);
    }

    #[test]
    fn test_generate_max_tokens_variants() {
        let budgets: Vec<_> = expand(VariantStrategy::MaxTokens(vec![512, 1024, 2048]))
            .iter()
            .map(|variant| variant.max_tokens)
            .collect();

        assert_eq!(budgets, vec![512, 1024, 2048]);
    }

    #[test]
    fn test_generate_system_prompt_variants() {
        let prompts = vec![
            "You are a helpful assistant.".to_string(),
            "You are a coding expert.".to_string(),
        ];

        let actual: Vec<Option<String>> = expand(VariantStrategy::SystemPrompt(prompts.clone()))
            .iter()
            .map(|variant| variant.system_prompt.clone())
            .collect();
        let expected: Vec<Option<String>> = prompts.into_iter().map(Some).collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_generate_model_variants() {
        let models = vec!["model-1".to_string(), "model-2".to_string()];

        let actual: Vec<String> = expand(VariantStrategy::Model(models.clone()))
            .iter()
            .map(|variant| variant.model.clone())
            .collect();

        assert_eq!(actual, models);
    }

    #[test]
    fn test_empty_variants_error() {
        let outcome =
            VariantGenerator::generate(&base_config(), &VariantStrategy::Temperature(Vec::new()));

        assert!(outcome.is_err());
    }

    #[test]
    fn test_temperature_clamping() {
        let clamped: Vec<f64> = expand(VariantStrategy::Temperature(vec![-0.5, 0.5, 1.5]))
            .iter()
            .map(|variant| variant.temperature)
            .collect();

        assert_eq!(clamped, vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_standard_temperature_variants() {
        let variants = VariantGenerator::standard_temperature_variants(&base_config()).unwrap();

        assert_eq!(variants.len(), 4);
        assert_eq!(variants.first().map(|variant| variant.temperature), Some(0.0));
        assert_eq!(variants.last().map(|variant| variant.temperature), Some(1.0));
    }

    #[test]
    fn test_standard_top_p_variants() {
        let variants = VariantGenerator::standard_top_p_variants(&base_config()).unwrap();

        assert_eq!(variants.len(), 4);
        assert_eq!(variants.first().map(|variant| variant.top_p), Some(0.8));
        assert_eq!(variants.last().map(|variant| variant.top_p), Some(1.0));
    }

    #[test]
    fn test_validate_config_valid() {
        assert!(VariantGenerator::validate_config(&base_config()).is_ok());
    }

    #[test]
    fn test_validate_config_invalid_temperature() {
        assert!(rejected_after(|config| config.temperature = 1.5));
    }

    #[test]
    fn test_validate_config_invalid_top_p() {
        assert!(rejected_after(|config| config.top_p = -0.1));
    }

    #[test]
    fn test_validate_config_zero_max_tokens() {
        assert!(rejected_after(|config| config.max_tokens = 0));
    }

    #[test]
    fn test_validate_config_invalid_penalties() {
        assert!(rejected_after(|config| config.presence_penalty = Some(3.0)));
        assert!(rejected_after(|config| config.frequency_penalty = Some(-3.0)));
    }
}
361}