use std::collections::HashMap;

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::metrics::accuracy;
use crate::rng::FastRng;
use crate::split::{k_fold, stratified_k_fold, ScoringFn};

use super::{evaluate_combo, CvResult, ParamValue, Tunable};

/// A sampling distribution for a single hyperparameter.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ParamDistribution {
    /// A fixed set of candidate values, sampled uniformly.
    Categorical(Vec<ParamValue>),
    /// A float drawn uniformly between `low` and `high`.
    Uniform {
        low: f64,
        high: f64,
    },
    /// A float drawn log-uniformly between `low` and `high`; both bounds
    /// should be positive.
    LogUniform {
        low: f64,
        high: f64,
    },
    /// An integer drawn uniformly from `low..=high`.
    IntUniform {
        low: usize,
        high: usize,
    },
}

/// A parameter space: maps each parameter name to its sampling distribution.
pub type ParamSpace = HashMap<String, ParamDistribution>;
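// A minimal sketch of building a mixed space (the parameter names here are
// illustrative, not tied to any particular model):
//
//     let mut space = ParamSpace::new();
//     space.insert("max_depth".into(), ParamDistribution::IntUniform { low: 2, high: 10 });
//     space.insert("learning_rate".into(), ParamDistribution::LogUniform { low: 1e-4, high: 1e-1 });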
/// Hyperparameter search with cross-validation, guided by a Tree-structured
/// Parzen Estimator (TPE)-style surrogate: after `n_initial` random trials,
/// new candidates are chosen to maximize the density ratio of good to bad
/// past results.
#[non_exhaustive]
pub struct BayesSearchCV {
    base_model: Box<dyn Tunable>,
    param_space: ParamSpace,
    n_iter: usize,
    n_initial: usize,
    gamma: f64,
    cv: usize,
    scorer: ScoringFn,
    seed: u64,
    stratified: bool,
    best_params_: Option<HashMap<String, ParamValue>>,
    best_score_: f64,
    cv_results_: Vec<CvResult>,
}
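// End-to-end usage sketch (mirrors the tests at the bottom of this file;
// `DecisionTreeClassifier` is one available `Tunable` implementor):
//
//     let mut space = ParamSpace::new();
//     space.insert("max_depth".into(), ParamDistribution::IntUniform { low: 2, high: 10 });
//     let search = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
//         .n_iter(30)
//         .cv(5)
//         .fit(&data)?;
//     println!("best params {:?}, score {:.3}", search.best_params(), search.best_score());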
// Manual `Debug`: the `base_model` trait object is not `Debug`, so we print
// a summary of the search configuration instead of deriving.
impl std::fmt::Debug for BayesSearchCV {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BayesSearchCV")
            .field("n_iter", &self.n_iter)
            .field("n_initial", &self.n_initial)
            .field("gamma", &self.gamma)
            .field("cv", &self.cv)
            .field("seed", &self.seed)
            .field("stratified", &self.stratified)
            .field("best_score_", &self.best_score_)
            .field("cv_results_len", &self.cv_results_.len())
            .finish()
    }
}
impl BayesSearchCV {
    /// Creates a new search with defaults: 30 iterations, 10 random warm-up
    /// trials, `gamma = 0.25`, 5-fold CV, accuracy scoring, and seed 42.
    pub fn new(model: impl Tunable + 'static, param_space: ParamSpace) -> Self {
        Self {
            base_model: Box::new(model),
            param_space,
            n_iter: 30,
            n_initial: 10,
            gamma: 0.25,
            cv: 5,
            scorer: accuracy,
            seed: 42,
            stratified: false,
            best_params_: None,
            best_score_: f64::NEG_INFINITY,
            cv_results_: Vec::new(),
        }
    }
    /// Sets the total number of parameter combinations to evaluate.
    pub fn n_iter(mut self, n: usize) -> Self {
        self.n_iter = n;
        self
    }

    /// Sets the number of random warm-up trials before TPE proposals begin.
    pub fn n_initial(mut self, n: usize) -> Self {
        self.n_initial = n;
        self
    }

    /// Sets the quantile used to split past trials into "good" and "bad".
    pub fn gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    /// Sets the number of cross-validation folds.
    pub fn cv(mut self, k: usize) -> Self {
        self.cv = k;
        self
    }

    /// Sets the scoring function used to rank parameter combinations.
    pub fn scoring(mut self, scorer: ScoringFn) -> Self {
        self.scorer = scorer;
        self
    }

    /// Sets the RNG seed used for fold shuffling and candidate sampling.
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// If `true`, uses stratified folds that preserve class proportions.
    pub fn stratified(mut self, stratified: bool) -> Self {
        self.stratified = stratified;
        self
    }
    /// Runs the search: validates inputs, builds the CV folds, performs the
    /// random warm-up phase, then the TPE-guided phase.
    pub fn fit(mut self, data: &Dataset) -> Result<Self> {
        if self.cv < 2 {
            return Err(ScryLearnError::InvalidParameter(format!(
                "cv must be >= 2, got {}",
                self.cv
            )));
        }
        if self.param_space.is_empty() {
            return Err(ScryLearnError::InvalidParameter(
                "parameter space is empty".into(),
            ));
        }
        if self.n_iter == 0 {
            return Err(ScryLearnError::InvalidParameter(
                "n_iter must be >= 1".into(),
            ));
        }

        let folds = if self.stratified {
            stratified_k_fold(data, self.cv, self.seed)
        } else {
            k_fold(data, self.cv, self.seed)
        };

        let mut rng = FastRng::new(self.seed);

        // Fix the parameter order so sampling is deterministic for a given
        // seed, regardless of HashMap iteration order.
        let param_names: Vec<String> = {
            let mut names: Vec<String> = self.param_space.keys().cloned().collect();
            names.sort();
            names
        };

        // Warm-up: purely random exploration, clamped to the total budget.
        let n_initial = self.n_initial.min(self.n_iter);
        for _ in 0..n_initial {
            let combo = sample_random(&self.param_space, &param_names, &mut rng);
            let result = evaluate_combo(&*self.base_model, &combo, &folds, self.scorer)?;
            self.update_best(&result);
            self.cv_results_.push(result);
        }
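        // TPE-guided phase. Each iteration splits the finite-scored history
        // at the top-gamma quantile into "good" and "bad" sets, fits a
        // factored KDE to each, and keeps the random candidate maximizing
        // the density ratio l(x) / g(x): the quantity TPE (Bergstra et al.,
        // 2011) uses as a stand-in for expected improvement.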
        let n_tpe = self.n_iter - n_initial;
        for _ in 0..n_tpe {
            let mut scores: Vec<f64> = self
                .cv_results_
                .iter()
                .map(|r| r.mean_score)
                .filter(|s| s.is_finite())
                .collect();
            scores.sort_by(|a, b| a.total_cmp(b));

            // No finite scores yet (e.g. every trial so far returned NaN):
            // fall back to a random draw instead of indexing an empty vec.
            if scores.is_empty() {
                let combo = sample_random(&self.param_space, &param_names, &mut rng);
                let result = evaluate_combo(&*self.base_model, &combo, &folds, self.scorer)?;
                self.update_best(&result);
                self.cv_results_.push(result);
                continue;
            }

            // The top-gamma fraction (at least one trial) counts as "good".
            let n_good = ((scores.len() as f64 * self.gamma).ceil() as usize).max(1);
            let threshold = scores[scores.len().saturating_sub(n_good)];

            let (good, bad): (Vec<&CvResult>, Vec<&CvResult>) = self
                .cv_results_
                .iter()
                .filter(|r| r.mean_score.is_finite())
                .partition(|r| r.mean_score >= threshold);

            let combo = if bad.is_empty() {
                // Everything is "good": no contrast to exploit, so explore.
                sample_random(&self.param_space, &param_names, &mut rng)
            } else {
                let good_kde = build_factored_kde(&good, &param_names, &self.param_space);
                let bad_kde = build_factored_kde(&bad, &param_names, &self.param_space);

                // Score random candidates by the density ratio; keep the best.
                let n_candidates = 100;
                let mut best_candidate = sample_random(&self.param_space, &param_names, &mut rng);
                let mut best_ei = f64::NEG_INFINITY;

                for _ in 0..n_candidates {
                    let candidate = sample_random(&self.param_space, &param_names, &mut rng);
                    let l = evaluate_kde(&good_kde, &candidate, &param_names, &self.param_space);
                    let g = evaluate_kde(&bad_kde, &candidate, &param_names, &self.param_space);
                    // Guard against dividing by a vanishing bad-density.
                    let ei = if g > 1e-300 { l / g } else { l * 1e300 };
                    if ei > best_ei {
                        best_ei = ei;
                        best_candidate = candidate;
                    }
                }
                best_candidate
            };

            let result = evaluate_combo(&*self.base_model, &combo, &folds, self.scorer)?;
            self.update_best(&result);
            self.cv_results_.push(result);
        }
        if self.best_params_.is_none() {
            return Err(ScryLearnError::InvalidParameter(
                "all parameter combinations produced NaN scores".into(),
            ));
        }

        Ok(self)
    }

    /// Returns the best parameter combination found.
    ///
    /// # Panics
    ///
    /// Panics if called before a successful `fit()`.
    pub fn best_params(&self) -> &HashMap<String, ParamValue> {
        self.best_params_.as_ref().expect("call fit() first")
    }

    /// Returns the best mean CV score (`f64::NEG_INFINITY` before `fit()`).
    pub fn best_score(&self) -> f64 {
        self.best_score_
    }

    /// Returns one `CvResult` per evaluated parameter combination.
    pub fn cv_results(&self) -> &[CvResult] {
        &self.cv_results_
    }

    /// Records `result` as the new best if its mean score is finite and
    /// improves on the current best.
    fn update_best(&mut self, result: &CvResult) {
        if result.mean_score.is_finite()
            && (self.best_params_.is_none() || result.mean_score > self.best_score_)
        {
            self.best_score_ = result.mean_score;
            self.best_params_ = Some(result.params.clone());
        }
    }
}
/// Draws one random combination from `space`, visiting parameters in the
/// fixed order of `param_names` so a given seed reproduces the same draws.
fn sample_random(
    space: &ParamSpace,
    param_names: &[String],
    rng: &mut FastRng,
) -> HashMap<String, ParamValue> {
    let mut combo = HashMap::new();
    for name in param_names {
        let dist = &space[name];
        let value = match dist {
            ParamDistribution::Categorical(values) => {
                let idx = rng.usize(0..values.len());
                values[idx].clone()
            }
            ParamDistribution::Uniform { low, high } => {
                ParamValue::Float(low + rng.f64() * (high - low))
            }
            ParamDistribution::LogUniform { low, high } => {
                // Sample uniformly in log space, then map back: this spreads
                // draws evenly across orders of magnitude.
                let log_low = low.ln();
                let log_high = high.ln();
                ParamValue::Float((log_low + rng.f64() * (log_high - log_low)).exp())
            }
            ParamDistribution::IntUniform { low, high } => {
                if high > low {
                    ParamValue::Int(low + rng.usize(0..=(high - low)))
                } else {
                    ParamValue::Int(*low)
                }
            }
        };
        combo.insert(name.clone(), value);
    }
    combo
}
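// The surrogate treats parameters as independent, so each density estimate
// factorizes as a product of one-dimensional pieces:
//
//     density(x) = prod_j density_j(x_j)
//
// Continuous and integer parameters get a Gaussian KDE over values mapped
// into [0, 1] by `normalize_param`; categorical parameters get smoothed
// category frequencies.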
/// A one-dimensional density estimate for a single parameter.
enum ParamKde {
    /// Gaussian KDE over normalized observations.
    Continuous {
        observations: Vec<f64>,
        bandwidth: f64,
    },
    /// Smoothed per-category probabilities, indexed like the distribution's
    /// value list.
    Categorical {
        probs: Vec<f64>,
    },
}

/// Per-parameter KDEs, one entry per parameter name.
struct FactoredKde {
    kdes: Vec<(String, ParamKde)>,
}
/// Fits one KDE per parameter to the given observations.
fn build_factored_kde(
    observations: &[&CvResult],
    param_names: &[String],
    space: &ParamSpace,
) -> FactoredKde {
    let mut kdes = Vec::with_capacity(param_names.len());

    for name in param_names {
        let dist = &space[name];
        if let ParamDistribution::Categorical(values) = dist {
            let n_categories = values.len();
            // Add-one (Laplace) smoothing: every category keeps nonzero mass
            // even if it never appeared among the observations.
            let mut counts = vec![1.0_f64; n_categories];
            for obs in observations {
                if let Some(val) = obs.params.get(name) {
                    if let Some(idx) = values.iter().position(|v| v == val) {
                        counts[idx] += 1.0;
                    }
                }
            }
            let total: f64 = counts.iter().sum();
            let probs: Vec<f64> = counts.iter().map(|c| c / total).collect();
            kdes.push((name.clone(), ParamKde::Categorical { probs }));
        } else {
            let obs_normalized: Vec<f64> = observations
                .iter()
                .filter_map(|r| r.params.get(name))
                .map(|v| normalize_param(v, dist))
                .collect();

            // Rule-of-thumb bandwidth: n^(-1/5), the classic one-dimensional
            // Scott/Silverman rate (the data is already normalized to
            // [0, 1], so no scale factor is applied).
            let bw = if obs_normalized.is_empty() {
                1.0
            } else {
                (obs_normalized.len() as f64).powf(-1.0 / 5.0)
            };

            kdes.push((
                name.clone(),
                ParamKde::Continuous {
                    observations: obs_normalized,
                    bandwidth: bw,
                },
            ));
        }
    }

    FactoredKde { kdes }
}
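// Each continuous factor is a standard Gaussian KDE evaluated in the
// normalized coordinate:
//
//     p(x) = 1 / (n * h * sqrt(2 * pi)) * sum_i exp(-((x - x_i) / h)^2 / 2)
//
// Factors are multiplied in log space and exponentiated once at the end,
// which avoids underflow; the per-factor floors keep the logs finite.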
/// Evaluates the factored KDE density at `candidate`.
fn evaluate_kde(
    kde: &FactoredKde,
    candidate: &HashMap<String, ParamValue>,
    _param_names: &[String],
    space: &ParamSpace,
) -> f64 {
    let mut log_density = 0.0_f64;

    for (name, param_kde) in &kde.kdes {
        let Some(val) = candidate.get(name) else {
            continue;
        };
        let dist = &space[name];

        match param_kde {
            ParamKde::Continuous {
                observations,
                bandwidth,
            } => {
                let x = normalize_param(val, dist);
                let n = observations.len() as f64;
                if n < 1.0 {
                    continue;
                }
                let mut density_sum = 0.0_f64;
                for &obs in observations {
                    let z = (x - obs) / bandwidth;
                    density_sum += (-0.5 * z * z).exp();
                }
                let density = density_sum / (n * bandwidth * (std::f64::consts::TAU).sqrt());
                // Floor before taking the log so one far-away point cannot
                // drive the whole product to negative infinity.
                log_density += density.max(1e-300).ln();
            }
            ParamKde::Categorical { probs } => {
                if let ParamDistribution::Categorical(values) = dist {
                    if let Some(idx) = values.iter().position(|v| v == val) {
                        log_density += probs[idx].max(1e-300).ln();
                    } else {
                        // Unseen category: fall back to a uniform probability.
                        log_density += (1.0 / probs.len() as f64).ln();
                    }
                }
            }
        }
    }

    log_density.exp()
}
/// Maps a parameter value into [0, 1] according to its distribution, so all
/// continuous KDEs operate on a common scale.
fn normalize_param(value: &ParamValue, dist: &ParamDistribution) -> f64 {
    match (value, dist) {
        (ParamValue::Float(v), ParamDistribution::Uniform { low, high }) => {
            if (high - low).abs() < 1e-300 {
                0.5
            } else {
                (v - low) / (high - low)
            }
        }
        (ParamValue::Float(v), ParamDistribution::LogUniform { low, high }) => {
            let log_low = low.ln();
            let log_high = high.ln();
            if (log_high - log_low).abs() < 1e-300 {
                0.5
            } else {
                (v.ln() - log_low) / (log_high - log_low)
            }
        }
        (ParamValue::Int(v), ParamDistribution::IntUniform { low, high }) => {
            if high == low {
                0.5
            } else {
                (*v as f64 - *low as f64) / (*high as f64 - *low as f64)
            }
        }
        // Mismatched value/distribution pairs: clamp or fall back to the
        // midpoint rather than erroring.
        (ParamValue::Float(v), _) => v.clamp(0.0, 1.0),
        (ParamValue::Int(v), _) => (*v as f64).clamp(0.0, 1.0),
        _ => 0.5,
    }
}
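// Worked example: for LogUniform { low: 1e-4, high: 1e-1 }, the value 1e-2
// normalizes to (ln 1e-2 - ln 1e-4) / (ln 1e-1 - ln 1e-4) = 2/3, i.e. two of
// the three decades above the lower bound.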
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::DecisionTreeClassifier;

    /// Three well-separated clusters of 30 points each, loosely modeled on
    /// the iris dataset.
    fn iris_like() -> Dataset {
        let n_per_class = 30;
        let n = n_per_class * 3;
        let mut f0 = Vec::with_capacity(n);
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut f3 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);

        let mut rng = FastRng::new(123);

        for _ in 0..n_per_class {
            f0.push(1.0 + rng.f64() * 0.5);
            f1.push(1.0 + rng.f64() * 0.5);
            f2.push(0.5 + rng.f64() * 0.3);
            f3.push(0.1 + rng.f64() * 0.2);
            target.push(0.0);
        }
        for _ in 0..n_per_class {
            f0.push(5.0 + rng.f64() * 0.5);
            f1.push(3.0 + rng.f64() * 0.5);
            f2.push(3.5 + rng.f64() * 0.5);
            f3.push(1.0 + rng.f64() * 0.3);
            target.push(1.0);
        }
        for _ in 0..n_per_class {
            f0.push(6.5 + rng.f64() * 0.5);
            f1.push(3.0 + rng.f64() * 0.5);
            f2.push(5.5 + rng.f64() * 0.5);
            f3.push(2.0 + rng.f64() * 0.3);
            target.push(2.0);
        }

        Dataset::new(
            vec![f0, f1, f2, f3],
            target,
            vec![
                "sepal_len".into(),
                "sepal_wid".into(),
                "petal_len".into(),
                "petal_wid".into(),
            ],
            "species",
        )
    }
    #[test]
    fn test_bayes_search_int_uniform() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 10 },
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(15)
            .n_initial(5)
            .cv(3)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert!(
            result.best_score() > 0.7,
            "bayes best score {:.3} too low",
            result.best_score()
        );
        assert_eq!(result.cv_results().len(), 15);
        assert!(result.best_params().contains_key("max_depth"));
    }

    #[test]
    fn test_bayes_search_categorical() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::Categorical(vec![
                ParamValue::Int(2),
                ParamValue::Int(4),
                ParamValue::Int(6),
                ParamValue::Int(8),
            ]),
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(10)
            .n_initial(4)
            .cv(3)
            .seed(99)
            .fit(&data)
            .unwrap();

        assert!(
            result.best_score() > 0.5,
            "bayes categorical best score {:.3} too low",
            result.best_score()
        );
        assert!(result.best_params().contains_key("max_depth"));
    }
    #[test]
    fn test_bayes_search_mixed_space() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 8 },
        );
        space.insert(
            "min_samples_split".into(),
            ParamDistribution::IntUniform { low: 2, high: 10 },
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(12)
            .n_initial(5)
            .cv(3)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 12);
        assert!(result.best_params().contains_key("max_depth"));
        assert!(result.best_params().contains_key("min_samples_split"));
    }

    #[test]
    fn test_bayes_search_stratified() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 8 },
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(10)
            .n_initial(5)
            .cv(3)
            .stratified(true)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert!(
            result.best_score() > 0.7,
            "stratified bayes best score {:.3} too low",
            result.best_score()
        );
    }
    #[test]
    fn test_bayes_search_empty_space() {
        let data = iris_like();
        let space = ParamSpace::new();
        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space).fit(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_bayes_search_n_iter_zero() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 8 },
        );
        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(0)
            .fit(&data);
        assert!(result.is_err());
    }
    #[test]
    fn test_bayes_search_all_initial() {
        // n_initial exceeding n_iter: the warm-up phase is clamped to the
        // total budget, so exactly n_iter evaluations run.
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 6 },
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(5)
            .n_initial(10)
            .cv(3)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 5);
    }
    #[test]
    fn test_bayes_search_gbc() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "n_estimators".into(),
            ParamDistribution::Categorical(vec![
                ParamValue::Int(5),
                ParamValue::Int(10),
                ParamValue::Int(20),
            ]),
        );
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 4 },
        );

        let result = BayesSearchCV::new(crate::tree::GradientBoostingClassifier::new(), space)
            .n_iter(10)
            .n_initial(5)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert!(
            result.best_score() > 0.5,
            "gbc bayes best score {:.3} too low",
            result.best_score()
        );
    }
    #[test]
    fn test_normalize_param() {
        let dist = ParamDistribution::Uniform {
            low: 0.0,
            high: 10.0,
        };
        let val = ParamValue::Float(5.0);
        let norm = normalize_param(&val, &dist);
        assert!((norm - 0.5).abs() < 1e-10);

        let dist_int = ParamDistribution::IntUniform { low: 0, high: 10 };
        let val_int = ParamValue::Int(5);
        let norm_int = normalize_param(&val_int, &dist_int);
        assert!((norm_int - 0.5).abs() < 1e-10);
    }
}