Skip to main content

llm_optimizer_decision/
parameter_search.rs

//! Parameter Search Strategies
//!
//! This module provides strategies for exploring the parameter space:
//! grid search, random search, and Latin hypercube sampling. A Sobol
//! sequence strategy is declared but not yet implemented.
5
6use rand::Rng;
7use serde::{Deserialize, Serialize};
8use std::collections::VecDeque;
9
10use crate::adaptive_params::{ParameterConfig, ParameterRange};
11
12/// Search strategy type
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum SearchStrategy {
15    /// Systematic grid search
16    Grid,
17    /// Random sampling
18    Random,
19    /// Latin hypercube sampling
20    LatinHypercube,
21    /// Sobol sequence (quasi-random)
22    Sobol,
23}
24
25/// Grid search configuration
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct GridSearchConfig {
28    /// Number of temperature steps
29    pub temp_steps: usize,
30    /// Number of top-p steps
31    pub top_p_steps: usize,
32    /// Number of max token steps
33    pub max_tokens_steps: usize,
34}
35
36impl Default for GridSearchConfig {
37    fn default() -> Self {
38        Self {
39            temp_steps: 5,
40            top_p_steps: 4,
41            max_tokens_steps: 4,
42        }
43    }
44}
45
46/// Grid search iterator for parameter space exploration
47pub struct GridSearch {
48    /// Parameter range to search
49    range: ParameterRange,
50    /// Grid configuration
51    config: GridSearchConfig,
52    /// Generated configurations
53    configurations: VecDeque<ParameterConfig>,
54}
55
56impl GridSearch {
57    /// Create new grid search
58    pub fn new(range: ParameterRange, config: GridSearchConfig) -> Self {
59        let mut search = Self {
60            range,
61            config,
62            configurations: VecDeque::new(),
63        };
64        search.generate_grid();
65        search
66    }
67
68    /// Create with default configuration
69    pub fn with_defaults(range: ParameterRange) -> Self {
70        Self::new(range, GridSearchConfig::default())
71    }
72
73    /// Generate grid of configurations
74    fn generate_grid(&mut self) {
75        let temp_values = linspace(
76            self.range.temp_min,
77            self.range.temp_max,
78            self.config.temp_steps,
79        );
80
81        let top_p_values = linspace(
82            self.range.top_p_min,
83            self.range.top_p_max,
84            self.config.top_p_steps,
85        );
86
87        let max_tokens_values = linspace_usize(
88            self.range.max_tokens_min,
89            self.range.max_tokens_max,
90            self.config.max_tokens_steps,
91        );
92
93        for &temp in &temp_values {
94            for &top_p in &top_p_values {
95                for &max_tokens in &max_tokens_values {
96                    if let Ok(config) = ParameterConfig::new(temp, top_p, max_tokens) {
97                        self.configurations.push_back(config);
98                    }
99                }
100            }
101        }
102    }
103
104    /// Get next configuration
105    pub fn next(&mut self) -> Option<ParameterConfig> {
106        self.configurations.pop_front()
107    }
108
109    /// Get all configurations
110    pub fn all_configs(&self) -> Vec<ParameterConfig> {
111        self.configurations.iter().cloned().collect()
112    }
113
114    /// Get total number of configurations
115    pub fn total_configs(&self) -> usize {
116        self.config.temp_steps * self.config.top_p_steps * self.config.max_tokens_steps
117    }
118
119    /// Check if search is complete
120    pub fn is_complete(&self) -> bool {
121        self.configurations.is_empty()
122    }
123}
124
125/// Random search for parameter exploration
126pub struct RandomSearch {
127    /// Parameter range to search
128    range: ParameterRange,
129    /// Number of samples to generate
130    num_samples: usize,
131    /// Generated samples
132    samples_generated: usize,
133}
134
135impl RandomSearch {
136    /// Create new random search
137    pub fn new(range: ParameterRange, num_samples: usize) -> Self {
138        Self {
139            range,
140            num_samples,
141            samples_generated: 0,
142        }
143    }
144
145    /// Generate next random configuration
146    pub fn next(&mut self) -> Option<ParameterConfig> {
147        if self.samples_generated >= self.num_samples {
148            return None;
149        }
150
151        let mut rng = rand::thread_rng();
152
153        let temp = rng.gen_range(self.range.temp_min..=self.range.temp_max);
154        let top_p = rng.gen_range(self.range.top_p_min..=self.range.top_p_max);
155        let max_tokens = rng.gen_range(self.range.max_tokens_min..=self.range.max_tokens_max);
156
157        self.samples_generated += 1;
158
159        ParameterConfig::new(temp, top_p, max_tokens).ok()
160    }
161
162    /// Generate all random configurations at once
163    pub fn generate_all(&mut self) -> Vec<ParameterConfig> {
164        let mut configs = Vec::new();
165        while let Some(config) = self.next() {
166            configs.push(config);
167        }
168        configs
169    }
170
171    /// Check if search is complete
172    pub fn is_complete(&self) -> bool {
173        self.samples_generated >= self.num_samples
174    }
175
176    /// Reset the search
177    pub fn reset(&mut self) {
178        self.samples_generated = 0;
179    }
180}
181
182/// Latin Hypercube Sampling for better coverage
183pub struct LatinHypercubeSampling {
184    /// Parameter range
185    range: ParameterRange,
186    /// Number of samples
187    num_samples: usize,
188    /// Generated configurations
189    configurations: VecDeque<ParameterConfig>,
190}
191
192impl LatinHypercubeSampling {
193    /// Create new LHS sampler
194    pub fn new(range: ParameterRange, num_samples: usize) -> Self {
195        let mut lhs = Self {
196            range,
197            num_samples,
198            configurations: VecDeque::new(),
199        };
200        lhs.generate_samples();
201        lhs
202    }
203
204    /// Generate Latin Hypercube samples
205    fn generate_samples(&mut self) {
206        let mut rng = rand::thread_rng();
207
208        // Create permutations for each dimension
209        let mut temp_indices: Vec<usize> = (0..self.num_samples).collect();
210        let mut top_p_indices: Vec<usize> = (0..self.num_samples).collect();
211        let mut tokens_indices: Vec<usize> = (0..self.num_samples).collect();
212
213        // Shuffle each dimension independently
214        shuffle(&mut temp_indices);
215        shuffle(&mut top_p_indices);
216        shuffle(&mut tokens_indices);
217
218        // Generate samples
219        for i in 0..self.num_samples {
220            // Add random jitter within each cell
221            let temp_cell = temp_indices[i] as f64 + rng.gen::<f64>();
222            let top_p_cell = top_p_indices[i] as f64 + rng.gen::<f64>();
223            let tokens_cell = tokens_indices[i] as f64 + rng.gen::<f64>();
224
225            // Scale to parameter range
226            let temp = self.range.temp_min
227                + (temp_cell / self.num_samples as f64)
228                    * (self.range.temp_max - self.range.temp_min);
229
230            let top_p = self.range.top_p_min
231                + (top_p_cell / self.num_samples as f64)
232                    * (self.range.top_p_max - self.range.top_p_min);
233
234            let max_tokens = self.range.max_tokens_min
235                + ((tokens_cell / self.num_samples as f64)
236                    * (self.range.max_tokens_max - self.range.max_tokens_min) as f64)
237                    as usize;
238
239            if let Ok(config) = ParameterConfig::new(temp, top_p, max_tokens) {
240                self.configurations.push_back(config);
241            }
242        }
243    }
244
245    /// Get next configuration
246    pub fn next(&mut self) -> Option<ParameterConfig> {
247        self.configurations.pop_front()
248    }
249
250    /// Get all configurations
251    pub fn all_configs(&self) -> Vec<ParameterConfig> {
252        self.configurations.iter().cloned().collect()
253    }
254
255    /// Check if complete
256    pub fn is_complete(&self) -> bool {
257        self.configurations.is_empty()
258    }
259}
260
261/// Parameter search manager
262pub struct ParameterSearchManager {
263    /// Current search strategy
264    strategy: SearchStrategy,
265    /// Parameter range
266    range: ParameterRange,
267    /// Grid search instance
268    pub grid_search: Option<GridSearch>,
269    /// Random search instance
270    pub random_search: Option<RandomSearch>,
271    /// LHS instance
272    pub lhs_search: Option<LatinHypercubeSampling>,
273}
274
275impl ParameterSearchManager {
276    /// Create new search manager with grid search
277    pub fn with_grid_search(range: ParameterRange, config: GridSearchConfig) -> Self {
278        Self {
279            strategy: SearchStrategy::Grid,
280            range: range.clone(),
281            grid_search: Some(GridSearch::new(range, config)),
282            random_search: None,
283            lhs_search: None,
284        }
285    }
286
287    /// Create new search manager with random search
288    pub fn with_random_search(range: ParameterRange, num_samples: usize) -> Self {
289        Self {
290            strategy: SearchStrategy::Random,
291            range: range.clone(),
292            grid_search: None,
293            random_search: Some(RandomSearch::new(range, num_samples)),
294            lhs_search: None,
295        }
296    }
297
298    /// Create new search manager with Latin Hypercube Sampling
299    pub fn with_lhs(range: ParameterRange, num_samples: usize) -> Self {
300        Self {
301            strategy: SearchStrategy::LatinHypercube,
302            range: range.clone(),
303            grid_search: None,
304            random_search: None,
305            lhs_search: Some(LatinHypercubeSampling::new(range, num_samples)),
306        }
307    }
308
309    /// Get next configuration from current strategy
310    pub fn next(&mut self) -> Option<ParameterConfig> {
311        match self.strategy {
312            SearchStrategy::Grid => self.grid_search.as_mut().and_then(|s| s.next()),
313            SearchStrategy::Random => self.random_search.as_mut().and_then(|s| s.next()),
314            SearchStrategy::LatinHypercube => self.lhs_search.as_mut().and_then(|s| s.next()),
315            SearchStrategy::Sobol => None, // TODO: Implement Sobol sequence
316        }
317    }
318
319    /// Check if search is complete
320    pub fn is_complete(&self) -> bool {
321        match self.strategy {
322            SearchStrategy::Grid => self
323                .grid_search
324                .as_ref()
325                .map(|s| s.is_complete())
326                .unwrap_or(true),
327            SearchStrategy::Random => self
328                .random_search
329                .as_ref()
330                .map(|s| s.is_complete())
331                .unwrap_or(true),
332            SearchStrategy::LatinHypercube => self
333                .lhs_search
334                .as_ref()
335                .map(|s| s.is_complete())
336                .unwrap_or(true),
337            SearchStrategy::Sobol => true,
338        }
339    }
340
341    /// Get current strategy
342    pub fn strategy(&self) -> SearchStrategy {
343        self.strategy
344    }
345}
346
/// Helper: Generate `num` linearly spaced values from `start` to `end`,
/// both inclusive.
///
/// * `num == 0` yields an empty vector.
/// * `num == 1` yields `[start]`.
/// * Otherwise the first element is exactly `start` and the last element is
///   exactly `end`.
///
/// The final point is set to `end` directly instead of being computed as
/// `start + (num - 1) * step`: that product can miss `end` by a rounding
/// error (for example `49.0 * (1.0 / 49.0)` is `0.9999999999999999`), which
/// would break exact boundary comparisons and range checks downstream.
fn linspace(start: f64, end: f64, num: usize) -> Vec<f64> {
    match num {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let intervals = num - 1;
            let step = (end - start) / intervals as f64;

            let mut values = Vec::with_capacity(num);
            values.extend((0..intervals).map(|i| start + i as f64 * step));
            // Pin the upper boundary exactly.
            values.push(end);
            values
        }
    }
}
359
/// Helper: Generate `num` linearly spaced `usize` values from `start` to
/// `end`, both inclusive.
///
/// * `num == 0` yields an empty vector.
/// * `num == 1` yields `[start]`.
/// * Otherwise the first element is exactly `start`, the last is exactly
///   `end`, and intermediate points are rounded towards `start`.
///
/// The points are interpolated with exact integer arithmetic. Going through
/// `f64` and truncating can land one below the intended value whenever the
/// floating-point product falls just short of an integer (for example the
/// last point of `(0, 1, 50)` would come out as `0`). A descending range
/// (`end < start`) produces a descending sequence instead of underflowing
/// on `end - start`.
fn linspace_usize(start: usize, end: usize, num: usize) -> Vec<usize> {
    match num {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let ascending = end >= start;
            let span = if ascending { end - start } else { start - end };
            let intervals = (num - 1) as u128;

            (0..num)
                .map(|i| {
                    // Widen to u128 so `span * i` cannot overflow; the
                    // quotient is at most `span`, so it fits back in usize.
                    let offset = (span as u128 * i as u128 / intervals) as usize;
                    if ascending {
                        start + offset
                    } else {
                        start - offset
                    }
                })
                .collect()
        }
    }
}
372
373/// Helper: Fisher-Yates shuffle
374fn shuffle<T>(vec: &mut [T]) {
375    let mut rng = rand::thread_rng();
376    let len = vec.len();
377    for i in 0..len {
378        let j = rng.gen_range(i..len);
379        vec.swap(i, j);
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a grid configuration built from its three step counts.
    fn grid_config(
        temp_steps: usize,
        top_p_steps: usize,
        max_tokens_steps: usize,
    ) -> GridSearchConfig {
        GridSearchConfig {
            temp_steps,
            top_p_steps,
            max_tokens_steps,
        }
    }

    #[test]
    fn test_linspace() {
        let points = linspace(0.0, 1.0, 5);
        assert_eq!(points.len(), 5);
        assert_eq!(points[0], 0.0);
        assert_eq!(points[4], 1.0);
        assert!((points[2] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_linspace_usize() {
        let points = linspace_usize(0, 100, 5);
        assert_eq!(points.len(), 5);
        assert_eq!(points[0], 0);
        assert_eq!(points[4], 100);
    }

    #[test]
    fn test_grid_search_creation() {
        let grid = GridSearch::new(ParameterRange::default(), grid_config(3, 3, 2));
        assert_eq!(grid.total_configs(), 3 * 3 * 2);
    }

    #[test]
    fn test_grid_search_iteration() {
        let mut grid = GridSearch::new(ParameterRange::default(), grid_config(2, 2, 2));

        // Drain the search and count what it hands out.
        let drained = std::iter::from_fn(|| grid.next()).count();

        assert_eq!(drained, 8);
        assert!(grid.is_complete());
    }

    #[test]
    fn test_grid_search_coverage() {
        let bounds = ParameterRange {
            temp_min: 0.0,
            temp_max: 1.0,
            top_p_min: 0.8,
            top_p_max: 1.0,
            max_tokens_min: 512,
            max_tokens_max: 2048,
        };

        let grid = GridSearch::new(bounds.clone(), grid_config(3, 3, 2));
        let points = grid.all_configs();

        // Both ends of every floating-point axis must appear in the grid.
        assert!(points.iter().any(|p| p.temperature == bounds.temp_min));
        assert!(points.iter().any(|p| p.temperature == bounds.temp_max));
        assert!(points.iter().any(|p| p.top_p == bounds.top_p_min));
        assert!(points.iter().any(|p| p.top_p == bounds.top_p_max));
    }

    #[test]
    fn test_random_search_creation() {
        let fresh = RandomSearch::new(ParameterRange::default(), 10);
        assert!(!fresh.is_complete());
    }

    #[test]
    fn test_random_search_sampling() {
        let bounds = ParameterRange::default();
        let mut sampler = RandomSearch::new(bounds.clone(), 20);

        let drawn: Vec<ParameterConfig> = std::iter::from_fn(|| sampler.next()).collect();

        assert!(drawn.iter().all(|sample| bounds.contains(sample)));
        assert_eq!(drawn.len(), 20);
        assert!(sampler.is_complete());
    }

    #[test]
    fn test_random_search_reset() {
        let mut sampler = RandomSearch::new(ParameterRange::default(), 5);

        // Exhaust the budget first.
        while sampler.next().is_some() {}
        assert!(sampler.is_complete());

        sampler.reset();
        assert!(!sampler.is_complete());
    }

    #[test]
    fn test_random_search_generate_all() {
        let mut sampler = RandomSearch::new(ParameterRange::default(), 15);

        let batch = sampler.generate_all();
        assert_eq!(batch.len(), 15);
        assert!(sampler.is_complete());
    }

    #[test]
    fn test_lhs_creation() {
        let sampler = LatinHypercubeSampling::new(ParameterRange::default(), 10);
        assert!(!sampler.is_complete());
    }

    #[test]
    fn test_lhs_sampling() {
        let bounds = ParameterRange::default();
        let mut sampler = LatinHypercubeSampling::new(bounds.clone(), 20);

        let drawn: Vec<ParameterConfig> = std::iter::from_fn(|| sampler.next()).collect();

        assert!(drawn.iter().all(|sample| bounds.contains(sample)));
        assert!(drawn.len() > 0);
        assert!(sampler.is_complete());
    }

    #[test]
    fn test_lhs_coverage() {
        let bounds = ParameterRange::default();
        let sampler = LatinHypercubeSampling::new(bounds.clone(), 50);
        let samples = sampler.all_configs();
        let sample_count = samples.len() as f64;

        // LHS should spread samples across the range, so each mean should
        // sit near the midpoint of its interval.
        let mean_temp: f64 = samples.iter().map(|s| s.temperature).sum::<f64>() / sample_count;
        let mean_top_p: f64 = samples.iter().map(|s| s.top_p).sum::<f64>() / sample_count;

        let temp_centre = (bounds.temp_min + bounds.temp_max) / 2.0;
        let top_p_centre = (bounds.top_p_min + bounds.top_p_max) / 2.0;

        assert!((mean_temp - temp_centre).abs() < 0.3);
        assert!((mean_top_p - top_p_centre).abs() < 0.1);
    }

    #[test]
    fn test_shuffle() {
        let reference: Vec<usize> = (0..10).collect();
        let mut shuffled = reference.clone();

        shuffle(&mut shuffled);

        // Same multiset of elements.
        let mut resorted = shuffled.clone();
        resorted.sort();
        assert_eq!(resorted, reference);

        // Different order; this can only fail with probability 1/10!.
        assert_ne!(shuffled, reference);
    }

    #[test]
    fn test_search_manager_grid() {
        let mut manager =
            ParameterSearchManager::with_grid_search(ParameterRange::default(), grid_config(2, 2, 2));
        assert_eq!(manager.strategy(), SearchStrategy::Grid);

        let drained = std::iter::from_fn(|| manager.next()).count();

        assert_eq!(drained, 8);
        assert!(manager.is_complete());
    }

    #[test]
    fn test_search_manager_random() {
        let mut manager = ParameterSearchManager::with_random_search(ParameterRange::default(), 10);
        assert_eq!(manager.strategy(), SearchStrategy::Random);

        let drained = std::iter::from_fn(|| manager.next()).count();

        assert_eq!(drained, 10);
        assert!(manager.is_complete());
    }

    #[test]
    fn test_search_manager_lhs() {
        let mut manager = ParameterSearchManager::with_lhs(ParameterRange::default(), 15);
        assert_eq!(manager.strategy(), SearchStrategy::LatinHypercube);

        let drained = std::iter::from_fn(|| manager.next()).count();

        assert!(drained > 0);
        assert!(manager.is_complete());
    }
}
616}