optimizer/sampler/tpe/sampler.rs
1//! Tree-Parzen Estimator (TPE) sampler implementation.
2//!
3//! TPE is a Bayesian optimization algorithm that models the objective function
4//! using two probability distributions: one for promising (good) parameter values
5//! and one for unpromising (bad) parameter values.
6//!
7//! # Gamma Strategies
8//!
9//! The gamma parameter controls what fraction of trials are considered "good".
//! The built-in strategies below implement the [`GammaStrategy`] trait and are all
//! available from the `tpe` module:
//!
//! - [`FixedGamma`]: Constant gamma value (default: 0.25)
//! - [`LinearGamma`](crate::sampler::tpe::LinearGamma): Linear interpolation between
//!   min and max based on trial count
//! - [`SqrtGamma`](crate::sampler::tpe::SqrtGamma): Gamma decreases as 1/√n
//!   (similar to Optuna)
//! - [`HyperoptGamma`](crate::sampler::tpe::HyperoptGamma): Hyperopt-style adaptive gamma
16//!
17//! You can also implement your own strategy by implementing the [`GammaStrategy`] trait.
18//!
19//! # Examples
20//!
21//! Using a built-in gamma strategy:
22//!
23//! ```
24//! use optimizer::sampler::tpe::{SqrtGamma, TpeSampler};
25//!
26//! let sampler = TpeSampler::builder()
27//! .gamma_strategy(SqrtGamma::default())
28//! .build()
29//! .unwrap();
30//! ```
31//!
32//! Implementing a custom gamma strategy:
33//!
34//! ```
35//! use optimizer::sampler::tpe::{GammaStrategy, TpeSampler};
36//!
37//! #[derive(Debug, Clone)]
38//! struct MyGamma {
39//! base: f64,
40//! }
41//!
42//! impl GammaStrategy for MyGamma {
43//! fn gamma(&self, n_trials: usize) -> f64 {
44//! (self.base + 0.01 * n_trials as f64).min(0.5)
45//! }
46//!
47//! fn clone_box(&self) -> Box<dyn GammaStrategy> {
48//! Box::new(self.clone())
49//! }
50//! }
51//!
52//! let sampler = TpeSampler::builder()
53//! .gamma_strategy(MyGamma { base: 0.1 })
54//! .build()
55//! .unwrap();
56//! ```
57
58use core::fmt::Debug;
59use core::sync::atomic::{AtomicU64, Ordering};
60use std::sync::Arc;
61
62use crate::distribution::Distribution;
63use crate::error::{Error, Result};
64use crate::param::ParamValue;
65use crate::rng_util;
66use crate::sampler::common;
67use crate::sampler::tpe::gamma::{FixedGamma, GammaStrategy};
68use crate::sampler::{CompletedTrial, Sampler};
69
70use super::common as tpe_common;
71
// ============================================================================
// Gamma strategies: defined in `crate::sampler::tpe::gamma` (imported above)
// ============================================================================
75
76// ============================================================================
77// TPE Sampler
78// ============================================================================
79
80/// A Tree-Parzen Estimator (TPE) sampler for Bayesian optimization.
81///
82/// TPE works by splitting completed trials into two groups based on their
83/// objective values: good trials (below the gamma quantile) and bad trials
84/// (above the gamma quantile). It then fits kernel density estimators (KDE)
85/// to each group and samples new points that maximize the ratio l(x)/g(x),
86/// where l(x) is the density of good trials and g(x) is the density of bad trials.
87///
88/// During the startup phase (when fewer than `n_startup_trials` are completed),
89/// TPE falls back to random sampling to gather initial data.
90///
91/// # Gamma Strategies
92///
93/// The gamma quantile can be configured using different strategies via the
94/// [`GammaStrategy`] trait.
95///
96/// # Examples
97///
98/// ```
99/// use optimizer::sampler::tpe::TpeSampler;
100///
101/// // Create with default settings (FixedGamma at 0.25)
102/// let sampler = TpeSampler::new();
103///
104/// // Create with custom settings using the builder
105/// let sampler = TpeSampler::builder()
106/// .gamma(0.15) // Shorthand for FixedGamma::new(0.15)
107/// .n_startup_trials(20)
108/// .n_ei_candidates(32)
109/// .seed(42)
110/// .build()
111/// .unwrap();
112/// ```
113///
114/// Using a different gamma strategy:
115///
116/// ```
117/// use optimizer::sampler::tpe::{SqrtGamma, TpeSampler};
118///
119/// let sampler = TpeSampler::builder()
120/// .gamma_strategy(SqrtGamma::default())
121/// .build()
122/// .unwrap();
123/// ```
124pub struct TpeSampler {
125 /// Strategy for computing the gamma quantile.
126 gamma_strategy: Arc<dyn GammaStrategy>,
127 /// Number of trials before TPE kicks in (uses random sampling before this).
128 n_startup_trials: usize,
129 /// Number of candidate samples to evaluate when selecting the next point.
130 n_ei_candidates: usize,
131 /// Optional fixed bandwidth for KDE. If None, uses Scott's rule.
132 kde_bandwidth: Option<f64>,
133 /// Base seed for deterministic per-call RNG derivation (no mutex needed).
134 seed: u64,
135 /// Monotonic counter to disambiguate calls with identical (`trial_id`, distribution).
136 call_seq: AtomicU64,
137}
138
139impl TpeSampler {
140 /// Creates a new TPE sampler with default settings.
141 ///
142 /// Default settings:
143 /// - gamma strategy: [`FixedGamma`] with gamma = 0.25
144 /// - `n_startup_trials`: 10 (random sampling for first 10 trials)
145 /// - `n_ei_candidates`: 24 (evaluate 24 candidates per sample)
146 /// - `kde_bandwidth`: None (uses Scott's rule for automatic bandwidth)
147 #[must_use]
148 pub fn new() -> Self {
149 Self {
150 gamma_strategy: Arc::new(FixedGamma::default()),
151 n_startup_trials: 10,
152 n_ei_candidates: 24,
153 kde_bandwidth: None,
154 seed: fastrand::u64(..),
155 call_seq: AtomicU64::new(0),
156 }
157 }
158
159 /// Creates a builder for configuring a TPE sampler.
160 ///
161 /// # Examples
162 ///
163 /// ```
164 /// use optimizer::sampler::tpe::TpeSampler;
165 ///
166 /// let sampler = TpeSampler::builder()
167 /// .gamma(0.15)
168 /// .n_startup_trials(20)
169 /// .n_ei_candidates(32)
170 /// .seed(42)
171 /// .build()
172 /// .unwrap();
173 /// ```
174 #[must_use]
175 pub fn builder() -> TpeSamplerBuilder {
176 TpeSamplerBuilder::new()
177 }
178
179 /// Creates a new TPE sampler with custom configuration.
180 ///
181 /// This method uses a fixed gamma value. For more advanced gamma strategies,
182 /// use [`TpeSampler::with_strategy`] or the builder pattern with
183 /// [`TpeSamplerBuilder::gamma_strategy`].
184 ///
185 /// # Arguments
186 ///
187 /// * `gamma` - Fraction of trials to consider "good" (0.0 to 1.0).
188 /// * `n_startup_trials` - Number of random trials before TPE sampling.
189 /// * `n_ei_candidates` - Number of candidates to evaluate per sample.
190 /// * `kde_bandwidth` - Optional fixed bandwidth for KDE. If None, uses Scott's rule.
191 /// * `seed` - Optional seed for reproducibility.
192 ///
193 /// # Errors
194 ///
195 /// Returns `Error::InvalidGamma` if gamma is not in (0.0, 1.0).
196 /// Returns `Error::InvalidBandwidth` if `kde_bandwidth` is Some but not positive.
197 pub fn with_config(
198 gamma: f64,
199 n_startup_trials: usize,
200 n_ei_candidates: usize,
201 kde_bandwidth: Option<f64>,
202 seed: Option<u64>,
203 ) -> Result<Self> {
204 let gamma_strategy = FixedGamma::new(gamma)?;
205 Self::with_strategy(
206 gamma_strategy,
207 n_startup_trials,
208 n_ei_candidates,
209 kde_bandwidth,
210 seed,
211 )
212 }
213
214 /// Creates a new TPE sampler with a custom gamma strategy.
215 ///
216 /// # Arguments
217 ///
218 /// * `gamma_strategy` - The strategy for computing the gamma quantile.
219 /// * `n_startup_trials` - Number of random trials before TPE sampling.
220 /// * `n_ei_candidates` - Number of candidates to evaluate per sample.
221 /// * `kde_bandwidth` - Optional fixed bandwidth for KDE. If None, uses Scott's rule.
222 /// * `seed` - Optional seed for reproducibility.
223 ///
224 /// # Errors
225 ///
226 /// Returns `Error::InvalidBandwidth` if `kde_bandwidth` is Some but not positive.
227 ///
228 /// # Examples
229 ///
230 /// ```
231 /// use optimizer::sampler::tpe::{SqrtGamma, TpeSampler};
232 ///
233 /// let sampler = TpeSampler::with_strategy(
234 /// SqrtGamma::default(),
235 /// 10, // n_startup_trials
236 /// 24, // n_ei_candidates
237 /// None, // kde_bandwidth
238 /// Some(42), // seed
239 /// )
240 /// .unwrap();
241 /// ```
242 pub fn with_strategy<G: GammaStrategy + 'static>(
243 gamma_strategy: G,
244 n_startup_trials: usize,
245 n_ei_candidates: usize,
246 kde_bandwidth: Option<f64>,
247 seed: Option<u64>,
248 ) -> Result<Self> {
249 if let Some(bw) = kde_bandwidth
250 && bw <= 0.0
251 {
252 return Err(Error::InvalidBandwidth(bw));
253 }
254
255 Ok(Self {
256 gamma_strategy: Arc::new(gamma_strategy),
257 n_startup_trials,
258 n_ei_candidates,
259 kde_bandwidth,
260 seed: seed.unwrap_or_else(|| fastrand::u64(..)),
261 call_seq: AtomicU64::new(0),
262 })
263 }
264
265 /// Returns the gamma strategy used by this sampler.
266 #[must_use]
267 pub fn gamma_strategy(&self) -> &dyn GammaStrategy {
268 self.gamma_strategy.as_ref()
269 }
270
271 /// Splits trials into good and bad groups based on the gamma quantile.
272 ///
273 /// The gamma value is computed dynamically using the configured [`GammaStrategy`].
274 ///
275 /// Returns (`good_trials`, `bad_trials`) where `good_trials` contains trials
276 /// with values below the gamma quantile (for minimization).
277 #[allow(
278 clippy::cast_precision_loss,
279 clippy::cast_possible_truncation,
280 clippy::cast_sign_loss
281 )]
282 #[must_use]
283 fn split_trials<'a>(
284 &self,
285 history: &'a [CompletedTrial],
286 ) -> (Vec<&'a CompletedTrial>, Vec<&'a CompletedTrial>) {
287 if history.is_empty() {
288 return (vec![], vec![]);
289 }
290
291 // Compute gamma using the strategy and clamp to valid range
292 let gamma = self
293 .gamma_strategy
294 .gamma(history.len())
295 .clamp(f64::EPSILON, 1.0 - f64::EPSILON);
296
297 // Calculate the split point (gamma quantile)
298 // Ensure at least 1 trial in each group if possible
299 let n_good = ((history.len() as f64 * gamma).ceil() as usize)
300 .max(1)
301 .min(history.len() - 1);
302
303 // Use quickselect (O(n)) to partition indices instead of full sort (O(n log n)).
304 // We only need to know which trials are in the top gamma-quantile, not their order.
305 let mut indices: Vec<usize> = (0..history.len()).collect();
306 if n_good > 0 {
307 indices.select_nth_unstable_by(n_good - 1, |&a, &b| {
308 history[a]
309 .value
310 .partial_cmp(&history[b].value)
311 .unwrap_or(core::cmp::Ordering::Equal)
312 });
313 }
314
315 let good: Vec<_> = indices[..n_good].iter().map(|&i| &history[i]).collect();
316 let bad: Vec<_> = indices[n_good..].iter().map(|&i| &history[i]).collect();
317
318 (good, bad)
319 }
320}
321
322impl Default for TpeSampler {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328/// Builder for configuring a [`TpeSampler`].
329///
330/// This builder allows fluent configuration of TPE hyperparameters.
331///
332/// # Examples
333///
334/// Using a fixed gamma value:
335///
336/// ```
337/// use optimizer::sampler::tpe::TpeSamplerBuilder;
338///
339/// let sampler = TpeSamplerBuilder::new()
340/// .gamma(0.15)
341/// .n_startup_trials(20)
342/// .n_ei_candidates(32)
343/// .seed(42)
344/// .build()
345/// .unwrap();
346/// ```
347///
348/// Using a custom gamma strategy:
349///
350/// ```
351/// use optimizer::sampler::tpe::{SqrtGamma, TpeSamplerBuilder};
352///
353/// let sampler = TpeSamplerBuilder::new()
354/// .gamma_strategy(SqrtGamma::default())
355/// .n_startup_trials(20)
356/// .build()
357/// .unwrap();
358/// ```
359#[derive(Debug, Clone)]
360pub struct TpeSamplerBuilder {
361 gamma_strategy: Box<dyn GammaStrategy>,
362 /// Raw gamma value for deferred validation (Some if `gamma()` was called)
363 raw_gamma: Option<f64>,
364 n_startup_trials: usize,
365 n_ei_candidates: usize,
366 kde_bandwidth: Option<f64>,
367 seed: Option<u64>,
368}
369
370impl TpeSamplerBuilder {
371 /// Creates a new builder with default settings.
372 ///
373 /// Default settings:
374 /// - gamma strategy: [`FixedGamma`] with gamma = 0.25
375 /// - `n_startup_trials`: 10 (random sampling for first 10 trials)
376 /// - `n_ei_candidates`: 24 (evaluate 24 candidates per sample)
377 /// - `kde_bandwidth`: None (uses Scott's rule for automatic bandwidth)
378 /// - seed: None (use OS-provided entropy)
379 #[must_use]
380 pub fn new() -> Self {
381 Self {
382 gamma_strategy: Box::new(FixedGamma::default()),
383 raw_gamma: None,
384 n_startup_trials: 10,
385 n_ei_candidates: 24,
386 kde_bandwidth: None,
387 seed: None,
388 }
389 }
390
391 /// Sets a fixed gamma value for splitting trials into good/bad groups.
392 ///
393 /// This is a convenience method that creates a [`FixedGamma`] strategy.
394 /// For more advanced gamma strategies, use [`gamma_strategy`](Self::gamma_strategy).
395 ///
396 /// A gamma of 0.25 means the top 25% of trials (by objective value) are
397 /// considered "good" and used to build the l(x) distribution.
398 ///
399 /// # Arguments
400 ///
401 /// * `gamma` - Quantile value, must be in (0.0, 1.0).
402 ///
403 /// # Examples
404 ///
405 /// ```
406 /// use optimizer::sampler::tpe::TpeSamplerBuilder;
407 ///
408 /// let sampler = TpeSamplerBuilder::new()
409 /// .gamma(0.10) // Use top 10% as "good" trials
410 /// .build()
411 /// .unwrap();
412 /// ```
413 ///
414 /// # Note
415 ///
416 /// Validation happens at `build()` time. If gamma is not in (0.0, 1.0),
417 /// `build()` will return `Err(Error::InvalidGamma)`.
418 #[must_use]
419 pub fn gamma(mut self, gamma: f64) -> Self {
420 // We defer validation to build() time for consistency with the existing API
421 // Store the raw value for validation later
422 self.raw_gamma = Some(gamma);
423 self
424 }
425
426 /// Sets a custom gamma strategy for splitting trials into good/bad groups.
427 ///
428 /// The gamma strategy determines what fraction of trials are considered
429 /// "good" based on the number of completed trials. This allows the gamma
430 /// value to adapt dynamically during optimization.
431 ///
432 /// # Arguments
433 ///
434 /// * `strategy` - A type implementing [`GammaStrategy`].
435 ///
436 /// # Examples
437 ///
438 /// Using built-in strategies:
439 ///
440 /// ```
441 /// use optimizer::sampler::tpe::{LinearGamma, SqrtGamma, TpeSamplerBuilder};
442 ///
443 /// // Square root strategy (Optuna-style)
444 /// let sampler = TpeSamplerBuilder::new()
445 /// .gamma_strategy(SqrtGamma::default())
446 /// .build()
447 /// .unwrap();
448 ///
449 /// // Linear interpolation strategy
450 /// let sampler = TpeSamplerBuilder::new()
451 /// .gamma_strategy(LinearGamma::new(0.1, 0.3, 50).unwrap())
452 /// .build()
453 /// .unwrap();
454 /// ```
455 ///
456 /// Using a custom strategy:
457 ///
458 /// ```
459 /// use optimizer::sampler::tpe::{GammaStrategy, TpeSamplerBuilder};
460 ///
461 /// #[derive(Debug, Clone)]
462 /// struct MyGamma;
463 ///
464 /// impl GammaStrategy for MyGamma {
465 /// fn gamma(&self, n_trials: usize) -> f64 {
466 /// 0.25 // Always return 0.25
467 /// }
468 /// fn clone_box(&self) -> Box<dyn GammaStrategy> {
469 /// Box::new(self.clone())
470 /// }
471 /// }
472 ///
473 /// let sampler = TpeSamplerBuilder::new()
474 /// .gamma_strategy(MyGamma)
475 /// .build()
476 /// .unwrap();
477 /// ```
478 #[must_use]
479 pub fn gamma_strategy<G: GammaStrategy + 'static>(mut self, strategy: G) -> Self {
480 self.gamma_strategy = Box::new(strategy);
481 self.raw_gamma = None; // Clear any raw gamma set by gamma()
482 self
483 }
484
485 /// Sets the number of startup trials before TPE sampling begins.
486 ///
487 /// During the startup phase, the sampler uses uniform random sampling
488 /// to gather initial data. Once `n_startup_trials` have completed,
489 /// TPE-based sampling begins.
490 ///
491 /// # Arguments
492 ///
493 /// * `n` - Number of random trials before TPE kicks in.
494 ///
495 /// # Examples
496 ///
497 /// ```
498 /// use optimizer::sampler::tpe::TpeSamplerBuilder;
499 ///
500 /// let sampler = TpeSamplerBuilder::new()
501 /// .n_startup_trials(20) // Random sample first 20 trials
502 /// .build()
503 /// .unwrap();
504 /// ```
505 #[must_use]
506 pub fn n_startup_trials(mut self, n: usize) -> Self {
507 self.n_startup_trials = n;
508 self
509 }
510
511 /// Sets the number of EI (Expected Improvement) candidates to evaluate.
512 ///
513 /// When sampling a new point, TPE generates this many candidates from
514 /// the l(x) distribution and selects the one with the highest l(x)/g(x)
515 /// ratio.
516 ///
517 /// # Arguments
518 ///
519 /// * `n` - Number of candidates to evaluate per sample.
520 ///
521 /// # Examples
522 ///
523 /// ```
524 /// use optimizer::sampler::tpe::TpeSamplerBuilder;
525 ///
526 /// let sampler = TpeSamplerBuilder::new()
527 /// .n_ei_candidates(48) // Evaluate more candidates
528 /// .build()
529 /// .unwrap();
530 /// ```
531 #[must_use]
532 pub fn n_ei_candidates(mut self, n: usize) -> Self {
533 self.n_ei_candidates = n;
534 self
535 }
536
537 /// Sets a fixed bandwidth for the kernel density estimator.
538 ///
539 /// By default, TPE uses Scott's rule to automatically select the bandwidth
540 /// based on the sample data. Use this method to override with a fixed value.
541 ///
542 /// Smaller bandwidths give more localized, peaky distributions.
543 /// Larger bandwidths give smoother, more spread-out distributions.
544 ///
545 /// # Arguments
546 ///
547 /// * `bandwidth` - The fixed bandwidth (standard deviation) for Gaussian kernels.
548 ///
549 /// # Examples
550 ///
551 /// ```
552 /// use optimizer::sampler::tpe::TpeSamplerBuilder;
553 ///
554 /// let sampler = TpeSamplerBuilder::new()
555 /// .kde_bandwidth(0.5) // Fixed bandwidth of 0.5
556 /// .build()
557 /// .unwrap();
558 /// ```
559 ///
560 /// # Note
561 ///
562 /// Validation happens at `build()` time. If bandwidth is not positive,
563 /// `build()` will return `Err(Error::InvalidBandwidth)`.
564 #[must_use]
565 pub fn kde_bandwidth(mut self, bandwidth: f64) -> Self {
566 self.kde_bandwidth = Some(bandwidth);
567 self
568 }
569
570 /// Sets a seed for reproducible sampling.
571 ///
572 /// # Arguments
573 ///
574 /// * `seed` - Seed value for the random number generator.
575 ///
576 /// # Examples
577 ///
578 /// ```
579 /// use optimizer::sampler::tpe::TpeSamplerBuilder;
580 ///
581 /// let sampler = TpeSamplerBuilder::new()
582 /// .seed(42) // Reproducible results
583 /// .build()
584 /// .unwrap();
585 /// ```
586 #[must_use]
587 pub fn seed(mut self, seed: u64) -> Self {
588 self.seed = Some(seed);
589 self
590 }
591
592 /// Builds the configured [`TpeSampler`].
593 ///
594 /// # Errors
595 ///
596 /// Returns `Error::InvalidGamma` if a fixed gamma value was set and is not in (0.0, 1.0).
597 /// Returns `Error::InvalidBandwidth` if `kde_bandwidth` is Some but not positive.
598 ///
599 /// # Examples
600 ///
601 /// ```
602 /// use optimizer::sampler::tpe::TpeSamplerBuilder;
603 ///
604 /// let sampler = TpeSamplerBuilder::new()
605 /// .gamma(0.15)
606 /// .n_startup_trials(20)
607 /// .n_ei_candidates(32)
608 /// .seed(42)
609 /// .build()
610 /// .unwrap();
611 /// ```
612 pub fn build(self) -> Result<TpeSampler> {
613 // Determine the gamma strategy to use
614 let gamma_strategy: Arc<dyn GammaStrategy> = if let Some(raw) = self.raw_gamma {
615 // Validate and create FixedGamma from raw value
616 Arc::new(FixedGamma::new(raw)?)
617 } else {
618 Arc::from(self.gamma_strategy)
619 };
620
621 // Validate bandwidth
622 if let Some(bw) = self.kde_bandwidth
623 && bw <= 0.0
624 {
625 return Err(Error::InvalidBandwidth(bw));
626 }
627
628 Ok(TpeSampler {
629 gamma_strategy,
630 n_startup_trials: self.n_startup_trials,
631 n_ei_candidates: self.n_ei_candidates,
632 kde_bandwidth: self.kde_bandwidth,
633 seed: self.seed.unwrap_or_else(|| fastrand::u64(..)),
634 call_seq: AtomicU64::new(0),
635 })
636 }
637}
638
639impl Default for TpeSamplerBuilder {
640 fn default() -> Self {
641 Self::new()
642 }
643}
644
645/// Deterministically pick a parameter value matching `target_dist` from a
646/// trial. When multiple parameters share the same distribution (e.g. two
647/// `FloatParam::new(-5.0, 5.0)` in the same study), the one with the
648/// smallest [`ParamId`] is chosen so behavior does not depend on
649/// `HashMap` iteration order.
650fn find_matching_value<'t>(
651 t: &'t CompletedTrial,
652 target_dist: &Distribution,
653) -> Option<&'t ParamValue> {
654 t.distributions
655 .iter()
656 .filter(|(_, dist)| *dist == target_dist)
657 .min_by_key(|(id, _)| *id)
658 .and_then(|(id, _)| t.params.get(id))
659}
660
661impl TpeSampler {
662 fn sample_float(
663 &self,
664 d: &crate::distribution::FloatDistribution,
665 good_trials: &[&CompletedTrial],
666 bad_trials: &[&CompletedTrial],
667 rng: &mut fastrand::Rng,
668 ) -> ParamValue {
669 let target_dist = Distribution::Float(d.clone());
670 let good_values: Vec<f64> = good_trials
671 .iter()
672 .filter_map(|t| match find_matching_value(t, &target_dist)? {
673 ParamValue::Float(f) => Some(*f),
674 _ => None,
675 })
676 .collect();
677
678 let bad_values: Vec<f64> = bad_trials
679 .iter()
680 .filter_map(|t| match find_matching_value(t, &target_dist)? {
681 ParamValue::Float(f) => Some(*f),
682 _ => None,
683 })
684 .collect();
685
686 if good_values.is_empty() || bad_values.is_empty() {
687 return ParamValue::Float(rng_util::f64_range(rng, d.low, d.high));
688 }
689
690 let value = tpe_common::sample_tpe_float(
691 d,
692 good_values,
693 bad_values,
694 self.n_ei_candidates,
695 self.kde_bandwidth,
696 rng,
697 );
698 ParamValue::Float(value)
699 }
700
701 fn sample_int(
702 &self,
703 d: &crate::distribution::IntDistribution,
704 good_trials: &[&CompletedTrial],
705 bad_trials: &[&CompletedTrial],
706 rng: &mut fastrand::Rng,
707 ) -> ParamValue {
708 let target_dist = Distribution::Int(d.clone());
709 let good_values: Vec<i64> = good_trials
710 .iter()
711 .filter_map(|t| match find_matching_value(t, &target_dist)? {
712 ParamValue::Int(i) => Some(*i),
713 _ => None,
714 })
715 .collect();
716
717 let bad_values: Vec<i64> = bad_trials
718 .iter()
719 .filter_map(|t| match find_matching_value(t, &target_dist)? {
720 ParamValue::Int(i) => Some(*i),
721 _ => None,
722 })
723 .collect();
724
725 if good_values.is_empty() || bad_values.is_empty() {
726 return common::sample_random(rng, &Distribution::Int(d.clone()));
727 }
728
729 let value = tpe_common::sample_tpe_int(
730 d,
731 good_values,
732 bad_values,
733 self.n_ei_candidates,
734 self.kde_bandwidth,
735 rng,
736 );
737 ParamValue::Int(value)
738 }
739
740 #[allow(clippy::unused_self)]
741 fn sample_categorical(
742 &self,
743 d: &crate::distribution::CategoricalDistribution,
744 good_trials: &[&CompletedTrial],
745 bad_trials: &[&CompletedTrial],
746 rng: &mut fastrand::Rng,
747 ) -> ParamValue {
748 let target_dist = Distribution::Categorical(d.clone());
749 let good_indices: Vec<usize> = good_trials
750 .iter()
751 .filter_map(|t| match find_matching_value(t, &target_dist)? {
752 ParamValue::Categorical(i) => Some(*i),
753 _ => None,
754 })
755 .collect();
756
757 let bad_indices: Vec<usize> = bad_trials
758 .iter()
759 .filter_map(|t| match find_matching_value(t, &target_dist)? {
760 ParamValue::Categorical(i) => Some(*i),
761 _ => None,
762 })
763 .collect();
764
765 if good_indices.is_empty() || bad_indices.is_empty() {
766 return common::sample_random(rng, &Distribution::Categorical(d.clone()));
767 }
768
769 let index =
770 tpe_common::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng);
771 ParamValue::Categorical(index)
772 }
773}
774
775impl Sampler for TpeSampler {
776 fn sample(
777 &self,
778 distribution: &Distribution,
779 trial_id: u64,
780 history: &[CompletedTrial],
781 ) -> ParamValue {
782 let seq = self.call_seq.fetch_add(1, Ordering::Relaxed);
783 let mut rng = fastrand::Rng::with_seed(rng_util::mix_seed(
784 self.seed,
785 trial_id,
786 rng_util::distribution_fingerprint(distribution).wrapping_add(seq),
787 ));
788
789 // Fall back to random sampling during startup phase
790 if history.len() < self.n_startup_trials {
791 return common::sample_random(&mut rng, distribution);
792 }
793
794 // Split trials into good and bad groups
795 let (good_trials, bad_trials) = self.split_trials(history);
796
797 // Need at least 1 trial in each group for TPE
798 if good_trials.is_empty() || bad_trials.is_empty() {
799 return common::sample_random(&mut rng, distribution);
800 }
801
802 match distribution {
803 Distribution::Float(d) => self.sample_float(d, &good_trials, &bad_trials, &mut rng),
804 Distribution::Int(d) => self.sample_int(d, &good_trials, &bad_trials, &mut rng),
805 Distribution::Categorical(d) => {
806 self.sample_categorical(d, &good_trials, &bad_trials, &mut rng)
807 }
808 }
809 }
810}
811
#[cfg(test)]
#[allow(
    clippy::similar_names,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss
)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::distribution::{CategoricalDistribution, FloatDistribution, IntDistribution};
    use crate::parameter::ParamId;

    /// Builds a completed trial from `(param_id, value, distribution)` triples.
    fn create_trial(
        id: u64,
        value: f64,
        params: Vec<(ParamId, ParamValue, Distribution)>,
    ) -> CompletedTrial {
        let (param_map, dist_map): (HashMap<_, _>, HashMap<_, _>) = params
            .into_iter()
            .map(|(param_id, pv, dist)| ((param_id, pv), (param_id, dist)))
            .unzip();
        CompletedTrial::new(id, param_map, dist_map, HashMap::new(), value)
    }

    /// Continuous distribution over `[0, 1]` shared by most tests.
    fn unit_float() -> Distribution {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    }

    /// Twenty trials whose objective equals their index `i` and whose only
    /// parameter (`x_id`) takes the value `i / 20`.
    fn ramp_history(x_id: ParamId, dist: &Distribution) -> Vec<CompletedTrial> {
        (0..20u32)
            .map(|i| {
                let objective = f64::from(i);
                create_trial(
                    u64::from(i),
                    objective,
                    vec![(x_id, ParamValue::Float(objective / 20.0), dist.clone())],
                )
            })
            .collect()
    }

    /// Asserts that the sampler's gamma strategy reports `expected`.
    fn assert_fixed_gamma(sampler: &TpeSampler, expected: f64) {
        assert!((sampler.gamma_strategy().gamma(0) - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_tpe_sampler_new() {
        let sampler = TpeSampler::new();
        assert_fixed_gamma(&sampler, 0.25);
        assert_eq!(sampler.n_startup_trials, 10);
        assert_eq!(sampler.n_ei_candidates, 24);
    }

    #[test]
    fn test_tpe_sampler_with_config() {
        let sampler = TpeSampler::with_config(0.15, 20, 32, None, Some(42)).unwrap();
        assert_fixed_gamma(&sampler, 0.15);
        assert_eq!(sampler.n_startup_trials, 20);
        assert_eq!(sampler.n_ei_candidates, 32);
    }

    #[test]
    fn test_tpe_sampler_invalid_gamma_zero() {
        assert!(matches!(
            TpeSampler::with_config(0.0, 10, 24, None, None),
            Err(Error::InvalidGamma(_))
        ));
    }

    #[test]
    fn test_tpe_sampler_invalid_gamma_one() {
        assert!(matches!(
            TpeSampler::with_config(1.0, 10, 24, None, None),
            Err(Error::InvalidGamma(_))
        ));
    }

    #[test]
    fn test_tpe_startup_random_sampling() {
        let sampler = TpeSampler::with_config(0.25, 10, 24, None, Some(42)).unwrap();
        let dist = unit_float();
        // An empty history keeps the sampler in its random startup phase.
        let history: Vec<CompletedTrial> = Vec::new();

        for trial_id in 0..100 {
            let ParamValue::Float(v) = sampler.sample(&dist, trial_id, &history) else {
                panic!("Expected Float value");
            };
            assert!((0.0..=1.0).contains(&v));
        }
    }

    #[test]
    fn test_tpe_split_trials() {
        let sampler = TpeSampler::with_config(0.25, 10, 24, None, Some(42)).unwrap();
        let dist = unit_float();
        let history = ramp_history(ParamId::new(), &dist);

        let (good, bad) = sampler.split_trials(&history);

        // ceil(20 * 0.25) = 5 good trials, leaving 15 bad ones.
        assert_eq!(good.len(), 5);
        assert_eq!(bad.len(), 15);
        // The objective equals the trial index, so only trials 0..5 qualify.
        assert!(good.iter().all(|trial| trial.value < 5.0));
    }

    #[test]
    fn test_tpe_samples_float_with_history() {
        let sampler = TpeSampler::with_config(0.25, 5, 24, None, Some(42)).unwrap();
        let dist = unit_float();
        let x_id = ParamId::new();

        // Objective (x - 0.2)^2 is minimised at x = 0.2.
        let history: Vec<CompletedTrial> = (0..20)
            .map(|i| {
                let x = f64::from(i) / 20.0;
                create_trial(
                    i as u64,
                    (x - 0.2).powi(2),
                    vec![(x_id, ParamValue::Float(x), dist.clone())],
                )
            })
            .collect();

        let samples: Vec<f64> = (0..100)
            .filter_map(|offset| match sampler.sample(&dist, 100 + offset, &history) {
                ParamValue::Float(v) => Some(v),
                _ => None,
            })
            .collect();

        // The sample mean should lean toward 0.2 rather than the midpoint 0.5.
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(
            mean < 0.5,
            "Mean {mean} should be less than 0.5 (biased toward good region near 0.2)"
        );
    }

    #[test]
    fn test_tpe_categorical_sampling() {
        let sampler = TpeSampler::with_config(0.25, 5, 24, None, Some(42)).unwrap();
        let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 4 });
        let cat_id = ParamId::new();

        // Category 1 is the only one that achieves the best (lowest) objective.
        let history: Vec<CompletedTrial> = (0..20)
            .map(|i| {
                let category = i % 4;
                let value = if category == 1 { 0.0 } else { 1.0 };
                create_trial(
                    i as u64,
                    value,
                    vec![(
                        cat_id,
                        ParamValue::Categorical(category as usize),
                        dist.clone(),
                    )],
                )
            })
            .collect();

        let mut counts = [0usize; 4];
        for offset in 0..100 {
            if let ParamValue::Categorical(idx) = sampler.sample(&dist, 100 + offset, &history) {
                counts[idx] += 1;
            }
        }

        let favourite = counts[1];
        assert!(
            [0, 2, 3].iter().all(|&other| favourite > counts[other]),
            "Category 1 should be most common: {counts:?}"
        );
    }

    #[test]
    fn test_tpe_int_sampling() {
        let sampler = TpeSampler::with_config(0.25, 5, 24, None, Some(42)).unwrap();
        let dist = Distribution::Int(IntDistribution {
            low: 0,
            high: 100,
            log_scale: false,
            step: None,
        });
        let x_id = ParamId::new();

        // x takes 0, 5, ..., 95 and the objective is minimised at x = 30.
        let history: Vec<CompletedTrial> = (0..20)
            .map(|i| {
                let x = i * 5;
                create_trial(
                    i as u64,
                    ((x as f64) - 30.0).powi(2),
                    vec![(x_id, ParamValue::Int(x), dist.clone())],
                )
            })
            .collect();

        for offset in 0..50 {
            match sampler.sample(&dist, 100 + offset, &history) {
                ParamValue::Int(v) => assert!((0..=100).contains(&v), "Value {v} out of range"),
                _ => panic!("Expected Int value"),
            }
        }
    }

    #[test]
    fn test_tpe_reproducibility() {
        let dist = unit_float();
        let history = ramp_history(ParamId::new(), &dist);

        let sampler1 = TpeSampler::with_config(0.25, 5, 24, None, Some(12345)).unwrap();
        let sampler2 = TpeSampler::with_config(0.25, 5, 24, None, Some(12345)).unwrap();

        for trial_id in 0..10 {
            assert_eq!(
                sampler1.sample(&dist, trial_id, &history),
                sampler2.sample(&dist, trial_id, &history),
                "Samples should be identical with same seed"
            );
        }
    }

    #[test]
    fn test_tpe_sampler_builder_default() {
        let sampler = TpeSamplerBuilder::new().build().unwrap();
        assert_fixed_gamma(&sampler, 0.25);
        assert_eq!(sampler.n_startup_trials, 10);
        assert_eq!(sampler.n_ei_candidates, 24);
    }

    #[test]
    fn test_tpe_sampler_builder_custom() {
        let sampler = TpeSamplerBuilder::new()
            .gamma(0.15)
            .n_startup_trials(20)
            .n_ei_candidates(32)
            .seed(42)
            .build()
            .unwrap();
        assert_fixed_gamma(&sampler, 0.15);
        assert_eq!(sampler.n_startup_trials, 20);
        assert_eq!(sampler.n_ei_candidates, 32);
    }

    #[test]
    fn test_tpe_sampler_builder_via_sampler() {
        let sampler = TpeSampler::builder()
            .gamma(0.10)
            .n_startup_trials(15)
            .n_ei_candidates(48)
            .build()
            .unwrap();
        assert_fixed_gamma(&sampler, 0.10);
        assert_eq!(sampler.n_startup_trials, 15);
        assert_eq!(sampler.n_ei_candidates, 48);
    }

    #[test]
    fn test_tpe_sampler_builder_partial() {
        // Only gamma is overridden; everything else keeps its default.
        let sampler = TpeSamplerBuilder::new().gamma(0.20).build().unwrap();
        assert_fixed_gamma(&sampler, 0.20);
        assert_eq!(sampler.n_startup_trials, 10);
        assert_eq!(sampler.n_ei_candidates, 24);
    }

    #[test]
    fn test_tpe_sampler_builder_invalid_gamma() {
        assert!(matches!(
            TpeSamplerBuilder::new().gamma(1.5).build(),
            Err(Error::InvalidGamma(_))
        ));
    }

    #[test]
    fn test_tpe_sampler_builder_reproducibility() {
        let dist = unit_float();
        let history = ramp_history(ParamId::new(), &dist);

        let make_sampler = || {
            TpeSampler::builder()
                .seed(99999)
                .n_startup_trials(5)
                .build()
                .unwrap()
        };
        let sampler1 = make_sampler();
        let sampler2 = make_sampler();

        for trial_id in 0..10 {
            assert_eq!(
                sampler1.sample(&dist, trial_id, &history),
                sampler2.sample(&dist, trial_id, &history),
                "Builder-created samplers with same seed should be identical"
            );
        }
    }
}