1use rand::Rng;
7use serde::{Deserialize, Serialize};
8use std::collections::VecDeque;
9
10use crate::adaptive_params::{ParameterConfig, ParameterRange};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum SearchStrategy {
15 Grid,
17 Random,
19 LatinHypercube,
21 Sobol,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct GridSearchConfig {
28 pub temp_steps: usize,
30 pub top_p_steps: usize,
32 pub max_tokens_steps: usize,
34}
35
36impl Default for GridSearchConfig {
37 fn default() -> Self {
38 Self {
39 temp_steps: 5,
40 top_p_steps: 4,
41 max_tokens_steps: 4,
42 }
43 }
44}
45
46pub struct GridSearch {
48 range: ParameterRange,
50 config: GridSearchConfig,
52 configurations: VecDeque<ParameterConfig>,
54}
55
56impl GridSearch {
57 pub fn new(range: ParameterRange, config: GridSearchConfig) -> Self {
59 let mut search = Self {
60 range,
61 config,
62 configurations: VecDeque::new(),
63 };
64 search.generate_grid();
65 search
66 }
67
68 pub fn with_defaults(range: ParameterRange) -> Self {
70 Self::new(range, GridSearchConfig::default())
71 }
72
73 fn generate_grid(&mut self) {
75 let temp_values = linspace(
76 self.range.temp_min,
77 self.range.temp_max,
78 self.config.temp_steps,
79 );
80
81 let top_p_values = linspace(
82 self.range.top_p_min,
83 self.range.top_p_max,
84 self.config.top_p_steps,
85 );
86
87 let max_tokens_values = linspace_usize(
88 self.range.max_tokens_min,
89 self.range.max_tokens_max,
90 self.config.max_tokens_steps,
91 );
92
93 for &temp in &temp_values {
94 for &top_p in &top_p_values {
95 for &max_tokens in &max_tokens_values {
96 if let Ok(config) = ParameterConfig::new(temp, top_p, max_tokens) {
97 self.configurations.push_back(config);
98 }
99 }
100 }
101 }
102 }
103
104 pub fn next(&mut self) -> Option<ParameterConfig> {
106 self.configurations.pop_front()
107 }
108
109 pub fn all_configs(&self) -> Vec<ParameterConfig> {
111 self.configurations.iter().cloned().collect()
112 }
113
114 pub fn total_configs(&self) -> usize {
116 self.config.temp_steps * self.config.top_p_steps * self.config.max_tokens_steps
117 }
118
119 pub fn is_complete(&self) -> bool {
121 self.configurations.is_empty()
122 }
123}
124
125pub struct RandomSearch {
127 range: ParameterRange,
129 num_samples: usize,
131 samples_generated: usize,
133}
134
135impl RandomSearch {
136 pub fn new(range: ParameterRange, num_samples: usize) -> Self {
138 Self {
139 range,
140 num_samples,
141 samples_generated: 0,
142 }
143 }
144
145 pub fn next(&mut self) -> Option<ParameterConfig> {
147 if self.samples_generated >= self.num_samples {
148 return None;
149 }
150
151 let mut rng = rand::thread_rng();
152
153 let temp = rng.gen_range(self.range.temp_min..=self.range.temp_max);
154 let top_p = rng.gen_range(self.range.top_p_min..=self.range.top_p_max);
155 let max_tokens = rng.gen_range(self.range.max_tokens_min..=self.range.max_tokens_max);
156
157 self.samples_generated += 1;
158
159 ParameterConfig::new(temp, top_p, max_tokens).ok()
160 }
161
162 pub fn generate_all(&mut self) -> Vec<ParameterConfig> {
164 let mut configs = Vec::new();
165 while let Some(config) = self.next() {
166 configs.push(config);
167 }
168 configs
169 }
170
171 pub fn is_complete(&self) -> bool {
173 self.samples_generated >= self.num_samples
174 }
175
176 pub fn reset(&mut self) {
178 self.samples_generated = 0;
179 }
180}
181
182pub struct LatinHypercubeSampling {
184 range: ParameterRange,
186 num_samples: usize,
188 configurations: VecDeque<ParameterConfig>,
190}
191
192impl LatinHypercubeSampling {
193 pub fn new(range: ParameterRange, num_samples: usize) -> Self {
195 let mut lhs = Self {
196 range,
197 num_samples,
198 configurations: VecDeque::new(),
199 };
200 lhs.generate_samples();
201 lhs
202 }
203
204 fn generate_samples(&mut self) {
206 let mut rng = rand::thread_rng();
207
208 let mut temp_indices: Vec<usize> = (0..self.num_samples).collect();
210 let mut top_p_indices: Vec<usize> = (0..self.num_samples).collect();
211 let mut tokens_indices: Vec<usize> = (0..self.num_samples).collect();
212
213 shuffle(&mut temp_indices);
215 shuffle(&mut top_p_indices);
216 shuffle(&mut tokens_indices);
217
218 for i in 0..self.num_samples {
220 let temp_cell = temp_indices[i] as f64 + rng.gen::<f64>();
222 let top_p_cell = top_p_indices[i] as f64 + rng.gen::<f64>();
223 let tokens_cell = tokens_indices[i] as f64 + rng.gen::<f64>();
224
225 let temp = self.range.temp_min
227 + (temp_cell / self.num_samples as f64)
228 * (self.range.temp_max - self.range.temp_min);
229
230 let top_p = self.range.top_p_min
231 + (top_p_cell / self.num_samples as f64)
232 * (self.range.top_p_max - self.range.top_p_min);
233
234 let max_tokens = self.range.max_tokens_min
235 + ((tokens_cell / self.num_samples as f64)
236 * (self.range.max_tokens_max - self.range.max_tokens_min) as f64)
237 as usize;
238
239 if let Ok(config) = ParameterConfig::new(temp, top_p, max_tokens) {
240 self.configurations.push_back(config);
241 }
242 }
243 }
244
245 pub fn next(&mut self) -> Option<ParameterConfig> {
247 self.configurations.pop_front()
248 }
249
250 pub fn all_configs(&self) -> Vec<ParameterConfig> {
252 self.configurations.iter().cloned().collect()
253 }
254
255 pub fn is_complete(&self) -> bool {
257 self.configurations.is_empty()
258 }
259}
260
261pub struct ParameterSearchManager {
263 strategy: SearchStrategy,
265 range: ParameterRange,
267 pub grid_search: Option<GridSearch>,
269 pub random_search: Option<RandomSearch>,
271 pub lhs_search: Option<LatinHypercubeSampling>,
273}
274
275impl ParameterSearchManager {
276 pub fn with_grid_search(range: ParameterRange, config: GridSearchConfig) -> Self {
278 Self {
279 strategy: SearchStrategy::Grid,
280 range: range.clone(),
281 grid_search: Some(GridSearch::new(range, config)),
282 random_search: None,
283 lhs_search: None,
284 }
285 }
286
287 pub fn with_random_search(range: ParameterRange, num_samples: usize) -> Self {
289 Self {
290 strategy: SearchStrategy::Random,
291 range: range.clone(),
292 grid_search: None,
293 random_search: Some(RandomSearch::new(range, num_samples)),
294 lhs_search: None,
295 }
296 }
297
298 pub fn with_lhs(range: ParameterRange, num_samples: usize) -> Self {
300 Self {
301 strategy: SearchStrategy::LatinHypercube,
302 range: range.clone(),
303 grid_search: None,
304 random_search: None,
305 lhs_search: Some(LatinHypercubeSampling::new(range, num_samples)),
306 }
307 }
308
309 pub fn next(&mut self) -> Option<ParameterConfig> {
311 match self.strategy {
312 SearchStrategy::Grid => self.grid_search.as_mut().and_then(|s| s.next()),
313 SearchStrategy::Random => self.random_search.as_mut().and_then(|s| s.next()),
314 SearchStrategy::LatinHypercube => self.lhs_search.as_mut().and_then(|s| s.next()),
315 SearchStrategy::Sobol => None, }
317 }
318
319 pub fn is_complete(&self) -> bool {
321 match self.strategy {
322 SearchStrategy::Grid => self
323 .grid_search
324 .as_ref()
325 .map(|s| s.is_complete())
326 .unwrap_or(true),
327 SearchStrategy::Random => self
328 .random_search
329 .as_ref()
330 .map(|s| s.is_complete())
331 .unwrap_or(true),
332 SearchStrategy::LatinHypercube => self
333 .lhs_search
334 .as_ref()
335 .map(|s| s.is_complete())
336 .unwrap_or(true),
337 SearchStrategy::Sobol => true,
338 }
339 }
340
341 pub fn strategy(&self) -> SearchStrategy {
343 self.strategy
344 }
345}
346
/// Returns `num` evenly spaced values from `start` to `end`, both inclusive.
///
/// * `num == 0` yields an empty vector.
/// * `num == 1` yields `[start]`.
/// * Otherwise the first element is exactly `start` and the last element is
///   exactly `end`.
///
/// The last element is pinned to `end` instead of being computed as
/// `start + (num - 1) * step`. That product can land one ulp away from `end`
/// (for example `49.0 * (1.0 / 49.0)` is `0.9999999999999999`), which would
/// break exact endpoint comparisons and could overshoot a validated bound.
fn linspace(start: f64, end: f64, num: usize) -> Vec<f64> {
    match num {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let last = num - 1;
            let step = (end - start) / last as f64;
            (0..num)
                .map(|i| if i == last { end } else { start + i as f64 * step })
                .collect()
        }
    }
}
359
/// Returns `num` evenly spaced integers from `start` to `end`, both inclusive.
///
/// * `num == 0` yields an empty vector.
/// * `num == 1` yields `[start]`.
/// * Otherwise element `i` is `start` moved towards `end` by
///   `floor(i * |end - start| / (num - 1))`, so the first element is `start`
///   and the last is exactly `end`.
///
/// The offset is computed with integer arithmetic. A floating-point step can
/// round just below a whole number and then truncate one short of `end`
/// (for example `49.0 * (1.0 / 49.0)` truncates to `0`). Descending ranges
/// (`end < start`) are supported instead of underflowing on `end - start`.
fn linspace_usize(start: usize, end: usize, num: usize) -> Vec<usize> {
    if num == 0 {
        return Vec::new();
    }
    if num == 1 {
        return vec![start];
    }

    let ascending = start <= end;
    let span = if ascending { end - start } else { start - end };

    // Widen to u128 so `i * span` cannot overflow; both factors fit in 64 bits.
    let span = span as u128;
    let intervals = (num - 1) as u128;

    (0..num)
        .map(|i| {
            // The quotient never exceeds `span`, so it always fits in `usize`.
            let offset = (i as u128 * span / intervals) as usize;
            if ascending {
                start + offset
            } else {
                start - offset
            }
        })
        .collect()
}
372
373fn shuffle<T>(vec: &mut [T]) {
375 let mut rng = rand::thread_rng();
376 let len = vec.len();
377 for i in 0..len {
378 let j = rng.gen_range(i..len);
379 vec.swap(i, j);
380 }
381}
382
/// Unit tests for the sampling helpers, the individual samplers and the
/// search manager.
#[cfg(test)]
mod tests {
    use super::*;

    /// `linspace` returns the requested number of points and hits both ends.
    #[test]
    fn test_linspace() {
        let values = linspace(0.0, 1.0, 5);
        assert_eq!(values.len(), 5);
        assert_eq!(values[0], 0.0);
        assert_eq!(values[4], 1.0);
        assert!((values[2] - 0.5).abs() < 1e-10);
    }

    /// `linspace_usize` returns the requested number of points and hits both ends.
    #[test]
    fn test_linspace_usize() {
        let values = linspace_usize(0, 100, 5);
        assert_eq!(values.len(), 5);
        assert_eq!(values[0], 0);
        assert_eq!(values[4], 100);
    }

    /// The nominal grid size is the product of the step counts.
    #[test]
    fn test_grid_search_creation() {
        let range = ParameterRange::default();
        let config = GridSearchConfig {
            temp_steps: 3,
            top_p_steps: 3,
            max_tokens_steps: 2,
        };

        let search = GridSearch::new(range, config);
        assert_eq!(search.total_configs(), 3 * 3 * 2);
    }

    /// Draining a 2x2x2 grid yields eight configurations and then completes.
    #[test]
    fn test_grid_search_iteration() {
        let range = ParameterRange::default();
        let config = GridSearchConfig {
            temp_steps: 2,
            top_p_steps: 2,
            max_tokens_steps: 2,
        };

        let mut search = GridSearch::new(range, config);
        let mut count = 0;

        while search.next().is_some() {
            count += 1;
        }

        assert_eq!(count, 8);
        assert!(search.is_complete());
    }

    /// The grid includes the exact lower and upper bound of each float axis.
    #[test]
    fn test_grid_search_coverage() {
        let range = ParameterRange {
            temp_min: 0.0,
            temp_max: 1.0,
            top_p_min: 0.8,
            top_p_max: 1.0,
            max_tokens_min: 512,
            max_tokens_max: 2048,
        };

        let config = GridSearchConfig {
            temp_steps: 3,
            top_p_steps: 3,
            max_tokens_steps: 2,
        };

        let search = GridSearch::new(range.clone(), config);
        let configs = search.all_configs();

        assert!(configs.iter().any(|c| c.temperature == range.temp_min));
        assert!(configs.iter().any(|c| c.temperature == range.temp_max));
        assert!(configs.iter().any(|c| c.top_p == range.top_p_min));
        assert!(configs.iter().any(|c| c.top_p == range.top_p_max));
    }

    /// A fresh random search with a non-zero budget is not complete.
    #[test]
    fn test_random_search_creation() {
        let range = ParameterRange::default();
        let search = RandomSearch::new(range, 10);
        assert!(!search.is_complete());
    }

    /// Every random sample lies inside the range and the budget is honoured.
    #[test]
    fn test_random_search_sampling() {
        let range = ParameterRange::default();
        let mut search = RandomSearch::new(range.clone(), 20);

        let mut count = 0;
        while let Some(config) = search.next() {
            assert!(range.contains(&config));
            count += 1;
        }

        assert_eq!(count, 20);
        assert!(search.is_complete());
    }

    /// `reset` makes an exhausted random search usable again.
    #[test]
    fn test_random_search_reset() {
        let range = ParameterRange::default();
        let mut search = RandomSearch::new(range, 5);

        while search.next().is_some() {}
        assert!(search.is_complete());

        search.reset();
        assert!(!search.is_complete());
    }

    /// `generate_all` produces the whole budget in one call.
    #[test]
    fn test_random_search_generate_all() {
        let range = ParameterRange::default();
        let mut search = RandomSearch::new(range, 15);

        let configs = search.generate_all();
        assert_eq!(configs.len(), 15);
        assert!(search.is_complete());
    }

    /// A freshly built LHS sampler has pending samples.
    #[test]
    fn test_lhs_creation() {
        let range = ParameterRange::default();
        let lhs = LatinHypercubeSampling::new(range, 10);
        assert!(!lhs.is_complete());
    }

    /// Every LHS sample lies inside the range and the sampler drains fully.
    #[test]
    fn test_lhs_sampling() {
        let range = ParameterRange::default();
        let mut lhs = LatinHypercubeSampling::new(range.clone(), 20);

        let mut count = 0;
        while let Some(config) = lhs.next() {
            assert!(range.contains(&config));
            count += 1;
        }

        assert!(count > 0);
        assert!(lhs.is_complete());
    }

    /// The LHS sample means sit close to the midpoint of each float axis.
    #[test]
    fn test_lhs_coverage() {
        let range = ParameterRange::default();
        let lhs = LatinHypercubeSampling::new(range.clone(), 50);

        let configs = lhs.all_configs();

        let avg_temp: f64 =
            configs.iter().map(|c| c.temperature).sum::<f64>() / configs.len() as f64;
        let avg_top_p: f64 = configs.iter().map(|c| c.top_p).sum::<f64>() / configs.len() as f64;

        let temp_mid = (range.temp_min + range.temp_max) / 2.0;
        let top_p_mid = (range.top_p_min + range.top_p_max) / 2.0;

        assert!((avg_temp - temp_mid).abs() < 0.3);
        assert!((avg_top_p - top_p_mid).abs() < 0.1);
    }

    /// `shuffle` produces a permutation of its input and handles degenerate
    /// slices.
    #[test]
    fn test_shuffle() {
        // 100 elements instead of 10: the chance that a correct shuffle
        // returns the identity permutation is 1/100!, so the inequality
        // check below cannot fail spuriously in practice.
        let original: Vec<usize> = (0..100).collect();
        let mut vec = original.clone();

        shuffle(&mut vec);

        // Same multiset of elements: nothing lost or duplicated.
        let mut sorted = vec.clone();
        sorted.sort();
        assert_eq!(sorted, original);

        // The order actually changed.
        assert_ne!(vec, original);

        // Empty and single-element slices must be accepted and left as-is.
        let mut empty: Vec<usize> = Vec::new();
        shuffle(&mut empty);
        assert!(empty.is_empty());

        let mut single = vec![7usize];
        shuffle(&mut single);
        assert_eq!(single, vec![7usize]);
    }

    /// The manager dispatches to the grid sampler and drains it.
    #[test]
    fn test_search_manager_grid() {
        let range = ParameterRange::default();
        let config = GridSearchConfig {
            temp_steps: 2,
            top_p_steps: 2,
            max_tokens_steps: 2,
        };

        let mut manager = ParameterSearchManager::with_grid_search(range, config);
        assert_eq!(manager.strategy(), SearchStrategy::Grid);

        let mut count = 0;
        while manager.next().is_some() {
            count += 1;
        }

        assert_eq!(count, 8);
        assert!(manager.is_complete());
    }

    /// The manager dispatches to the random sampler and honours its budget.
    #[test]
    fn test_search_manager_random() {
        let range = ParameterRange::default();
        let mut manager = ParameterSearchManager::with_random_search(range, 10);
        assert_eq!(manager.strategy(), SearchStrategy::Random);

        let mut count = 0;
        while manager.next().is_some() {
            count += 1;
        }

        assert_eq!(count, 10);
        assert!(manager.is_complete());
    }

    /// The manager dispatches to the LHS sampler and drains it.
    #[test]
    fn test_search_manager_lhs() {
        let range = ParameterRange::default();
        let mut manager = ParameterSearchManager::with_lhs(range, 15);
        assert_eq!(manager.strategy(), SearchStrategy::LatinHypercube);

        let mut count = 0;
        while manager.next().is_some() {
            count += 1;
        }

        assert!(count > 0);
        assert!(manager.is_complete());
    }
}