Skip to main content

llm_optimizer_decision/
variant_generator.rs

1//! Variant generation strategies for A/B testing
2//!
3//! This module provides strategies for generating prompt variants
4//! for A/B testing experiments.
5
6use llm_optimizer_types::models::ModelConfig;
7use serde::{Deserialize, Serialize};
8
9use crate::errors::{DecisionError, Result};
10
11/// Strategy for generating prompt variants
12#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum VariantStrategy {
15    /// Vary temperature parameter
16    Temperature(Vec<f64>),
17    /// Vary top-p parameter
18    TopP(Vec<f64>),
19    /// Vary max tokens
20    MaxTokens(Vec<u32>),
21    /// Vary system prompt
22    SystemPrompt(Vec<String>),
23    /// Vary model
24    Model(Vec<String>),
25    /// Custom configuration variants
26    Custom(Vec<ModelConfig>),
27}
28
29/// Variant generator
30pub struct VariantGenerator;
31
32impl VariantGenerator {
33    /// Generate variants from a base configuration and strategy
34    pub fn generate(base: &ModelConfig, strategy: &VariantStrategy) -> Result<Vec<ModelConfig>> {
35        match strategy {
36            VariantStrategy::Temperature(temps) => {
37                Self::generate_temperature_variants(base, temps)
38            }
39            VariantStrategy::TopP(top_ps) => {
40                Self::generate_top_p_variants(base, top_ps)
41            }
42            VariantStrategy::MaxTokens(max_tokens) => {
43                Self::generate_max_tokens_variants(base, max_tokens)
44            }
45            VariantStrategy::SystemPrompt(prompts) => {
46                Self::generate_system_prompt_variants(base, prompts)
47            }
48            VariantStrategy::Model(models) => {
49                Self::generate_model_variants(base, models)
50            }
51            VariantStrategy::Custom(configs) => {
52                Ok(configs.clone())
53            }
54        }
55    }
56
57    /// Generate variants with different temperature values
58    fn generate_temperature_variants(base: &ModelConfig, temps: &[f64]) -> Result<Vec<ModelConfig>> {
59        if temps.is_empty() {
60            return Err(DecisionError::InvalidConfig(
61                "Temperature variants list is empty".to_string()
62            ));
63        }
64
65        let variants: Vec<ModelConfig> = temps.iter().map(|&temp| {
66            let mut config = base.clone();
67            config.temperature = temp.clamp(0.0, 1.0);
68            config
69        }).collect();
70
71        Ok(variants)
72    }
73
74    /// Generate variants with different top-p values
75    fn generate_top_p_variants(base: &ModelConfig, top_ps: &[f64]) -> Result<Vec<ModelConfig>> {
76        if top_ps.is_empty() {
77            return Err(DecisionError::InvalidConfig(
78                "Top-p variants list is empty".to_string()
79            ));
80        }
81
82        let variants: Vec<ModelConfig> = top_ps.iter().map(|&top_p| {
83            let mut config = base.clone();
84            config.top_p = top_p.clamp(0.0, 1.0);
85            config
86        }).collect();
87
88        Ok(variants)
89    }
90
91    /// Generate variants with different max tokens
92    fn generate_max_tokens_variants(base: &ModelConfig, max_tokens: &[u32]) -> Result<Vec<ModelConfig>> {
93        if max_tokens.is_empty() {
94            return Err(DecisionError::InvalidConfig(
95                "Max tokens variants list is empty".to_string()
96            ));
97        }
98
99        let variants: Vec<ModelConfig> = max_tokens.iter().map(|&max_tok| {
100            let mut config = base.clone();
101            config.max_tokens = max_tok;
102            config
103        }).collect();
104
105        Ok(variants)
106    }
107
108    /// Generate variants with different system prompts
109    fn generate_system_prompt_variants(base: &ModelConfig, prompts: &[String]) -> Result<Vec<ModelConfig>> {
110        if prompts.is_empty() {
111            return Err(DecisionError::InvalidConfig(
112                "System prompt variants list is empty".to_string()
113            ));
114        }
115
116        let variants: Vec<ModelConfig> = prompts.iter().map(|prompt| {
117            let mut config = base.clone();
118            config.system_prompt = Some(prompt.clone());
119            config
120        }).collect();
121
122        Ok(variants)
123    }
124
125    /// Generate variants with different models
126    fn generate_model_variants(base: &ModelConfig, models: &[String]) -> Result<Vec<ModelConfig>> {
127        if models.is_empty() {
128            return Err(DecisionError::InvalidConfig(
129                "Model variants list is empty".to_string()
130            ));
131        }
132
133        let variants: Vec<ModelConfig> = models.iter().map(|model| {
134            let mut config = base.clone();
135            config.model = model.clone();
136            config
137        }).collect();
138
139        Ok(variants)
140    }
141
142    /// Generate standard set of temperature variants
143    pub fn standard_temperature_variants(base: &ModelConfig) -> Result<Vec<ModelConfig>> {
144        Self::generate_temperature_variants(base, &[0.0, 0.3, 0.7, 1.0])
145    }
146
147    /// Generate standard set of top-p variants
148    pub fn standard_top_p_variants(base: &ModelConfig) -> Result<Vec<ModelConfig>> {
149        Self::generate_top_p_variants(base, &[0.8, 0.9, 0.95, 1.0])
150    }
151
152    /// Validate a configuration
153    pub fn validate_config(config: &ModelConfig) -> Result<()> {
154        if config.temperature < 0.0 || config.temperature > 1.0 {
155            return Err(DecisionError::InvalidConfig(
156                format!("Temperature {} is out of range [0, 1]", config.temperature)
157            ));
158        }
159
160        if config.top_p < 0.0 || config.top_p > 1.0 {
161            return Err(DecisionError::InvalidConfig(
162                format!("Top-p {} is out of range [0, 1]", config.top_p)
163            ));
164        }
165
166        if config.max_tokens == 0 {
167            return Err(DecisionError::InvalidConfig(
168                "Max tokens must be greater than 0".to_string()
169            ));
170        }
171
172        if let Some(presence) = config.presence_penalty {
173            if presence < -2.0 || presence > 2.0 {
174                return Err(DecisionError::InvalidConfig(
175                    format!("Presence penalty {} is out of range [-2, 2]", presence)
176                ));
177            }
178        }
179
180        if let Some(frequency) = config.frequency_penalty {
181            if frequency < -2.0 || frequency > 2.0 {
182                return Err(DecisionError::InvalidConfig(
183                    format!("Frequency penalty {} is out of range [-2, 2]", frequency)
184                ));
185            }
186        }
187
188        Ok(())
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration that passes `validate_config`.
    fn base_config() -> ModelConfig {
        ModelConfig {
            model: "test-model".to_string(),
            temperature: 0.7,
            top_p: 0.9,
            top_k: None,
            max_tokens: 1024,
            presence_penalty: None,
            frequency_penalty: None,
            system_prompt: None,
            extra_params: Default::default(),
        }
    }

    #[test]
    fn test_generate_temperature_variants() {
        let base = base_config();
        let strategy = VariantStrategy::Temperature(vec![0.0, 0.5, 1.0]);

        let variants = VariantGenerator::generate(&base, &strategy).unwrap();

        assert_eq!(variants.len(), 3);
        assert_eq!(variants[0].temperature, 0.0);
        assert_eq!(variants[1].temperature, 0.5);
        assert_eq!(variants[2].temperature, 1.0);
    }

    #[test]
    fn test_generate_top_p_variants() {
        let base = base_config();
        let strategy = VariantStrategy::TopP(vec![0.8, 0.9, 1.0]);

        let variants = VariantGenerator::generate(&base, &strategy).unwrap();

        assert_eq!(variants.len(), 3);
        assert_eq!(variants[0].top_p, 0.8);
        assert_eq!(variants[1].top_p, 0.9);
        assert_eq!(variants[2].top_p, 1.0);
    }

    #[test]
    fn test_generate_max_tokens_variants() {
        let base = base_config();
        let strategy = VariantStrategy::MaxTokens(vec![512, 1024, 2048]);

        let variants = VariantGenerator::generate(&base, &strategy).unwrap();

        assert_eq!(variants.len(), 3);
        assert_eq!(variants[0].max_tokens, 512);
        assert_eq!(variants[1].max_tokens, 1024);
        assert_eq!(variants[2].max_tokens, 2048);
    }

    #[test]
    fn test_generate_system_prompt_variants() {
        let base = base_config();
        let prompts = vec![
            "You are a helpful assistant.".to_string(),
            "You are a coding expert.".to_string(),
        ];
        let strategy = VariantStrategy::SystemPrompt(prompts.clone());

        let variants = VariantGenerator::generate(&base, &strategy).unwrap();

        assert_eq!(variants.len(), 2);
        assert_eq!(variants[0].system_prompt, Some(prompts[0].clone()));
        assert_eq!(variants[1].system_prompt, Some(prompts[1].clone()));
    }

    #[test]
    fn test_generate_model_variants() {
        let base = base_config();
        let models = vec!["model-1".to_string(), "model-2".to_string()];
        let strategy = VariantStrategy::Model(models.clone());

        let variants = VariantGenerator::generate(&base, &strategy).unwrap();

        assert_eq!(variants.len(), 2);
        assert_eq!(variants[0].model, models[0]);
        assert_eq!(variants[1].model, models[1]);
    }

    #[test]
    fn test_empty_variants_error() {
        let base = base_config();
        let strategy = VariantStrategy::Temperature(vec![]);

        assert!(VariantGenerator::generate(&base, &strategy).is_err());
    }

    #[test]
    fn test_empty_variants_error_for_every_field_strategy() {
        let base = base_config();
        let strategies = [
            VariantStrategy::TopP(vec![]),
            VariantStrategy::MaxTokens(vec![]),
            VariantStrategy::SystemPrompt(vec![]),
            VariantStrategy::Model(vec![]),
        ];

        for strategy in &strategies {
            assert!(
                VariantGenerator::generate(&base, strategy).is_err(),
                "expected an error for empty {:?}",
                strategy
            );
        }
    }

    #[test]
    fn test_custom_variants_are_returned_unchanged() {
        let base = base_config();
        let mut other = base_config();
        other.model = "other-model".to_string();
        other.temperature = 0.2;

        let strategy = VariantStrategy::Custom(vec![base.clone(), other]);
        let variants = VariantGenerator::generate(&base, &strategy).unwrap();

        assert_eq!(variants.len(), 2);
        assert_eq!(variants[0].model, "test-model");
        assert_eq!(variants[1].model, "other-model");
        assert_eq!(variants[1].temperature, 0.2);
    }

    #[test]
    fn test_variants_preserve_untouched_base_fields() {
        let base = base_config();
        let strategy = VariantStrategy::Temperature(vec![0.1, 0.9]);

        let variants = VariantGenerator::generate(&base, &strategy).unwrap();

        for variant in &variants {
            assert_eq!(variant.model, base.model);
            assert_eq!(variant.top_p, base.top_p);
            assert_eq!(variant.max_tokens, base.max_tokens);
            assert_eq!(variant.system_prompt, base.system_prompt);
        }
    }

    #[test]
    fn test_temperature_clamping() {
        let base = base_config();
        let strategy = VariantStrategy::Temperature(vec![-0.5, 0.5, 1.5]);

        let variants = VariantGenerator::generate(&base, &strategy).unwrap();

        // Should clamp to [0, 1]
        assert_eq!(variants[0].temperature, 0.0);
        assert_eq!(variants[1].temperature, 0.5);
        assert_eq!(variants[2].temperature, 1.0);
    }

    #[test]
    fn test_top_p_clamping() {
        let base = base_config();
        let strategy = VariantStrategy::TopP(vec![-0.2, 1.7]);

        let variants = VariantGenerator::generate(&base, &strategy).unwrap();

        // Should clamp to [0, 1]
        assert_eq!(variants[0].top_p, 0.0);
        assert_eq!(variants[1].top_p, 1.0);
    }

    #[test]
    fn test_standard_temperature_variants() {
        let base = base_config();
        let variants = VariantGenerator::standard_temperature_variants(&base).unwrap();

        assert_eq!(variants.len(), 4);
        assert_eq!(variants[0].temperature, 0.0);
        assert_eq!(variants[3].temperature, 1.0);
    }

    #[test]
    fn test_standard_top_p_variants() {
        let base = base_config();
        let variants = VariantGenerator::standard_top_p_variants(&base).unwrap();

        assert_eq!(variants.len(), 4);
        assert_eq!(variants[0].top_p, 0.8);
        assert_eq!(variants[3].top_p, 1.0);
    }

    #[test]
    fn test_validate_config_valid() {
        let config = base_config();
        assert!(VariantGenerator::validate_config(&config).is_ok());
    }

    #[test]
    fn test_validate_config_inclusive_boundaries() {
        // Every bound is inclusive, so the exact endpoints must be accepted.
        let mut config = base_config();
        config.temperature = 0.0;
        config.top_p = 1.0;
        config.presence_penalty = Some(-2.0);
        config.frequency_penalty = Some(2.0);
        assert!(VariantGenerator::validate_config(&config).is_ok());

        config.temperature = 1.0;
        config.top_p = 0.0;
        config.presence_penalty = Some(2.0);
        config.frequency_penalty = Some(-2.0);
        assert!(VariantGenerator::validate_config(&config).is_ok());
    }

    #[test]
    fn test_validate_config_invalid_temperature() {
        let mut config = base_config();
        config.temperature = 1.5;

        assert!(VariantGenerator::validate_config(&config).is_err());
    }

    #[test]
    fn test_validate_config_invalid_top_p() {
        let mut config = base_config();
        config.top_p = -0.1;

        assert!(VariantGenerator::validate_config(&config).is_err());
    }

    #[test]
    fn test_validate_config_zero_max_tokens() {
        let mut config = base_config();
        config.max_tokens = 0;

        assert!(VariantGenerator::validate_config(&config).is_err());
    }

    #[test]
    fn test_validate_config_invalid_penalties() {
        let mut config = base_config();
        config.presence_penalty = Some(3.0);

        assert!(VariantGenerator::validate_config(&config).is_err());

        config.presence_penalty = None;
        config.frequency_penalty = Some(-3.0);

        assert!(VariantGenerator::validate_config(&config).is_err());
    }
}