Skip to main content

optimizer/sampler/tpe/
sampler.rs

1//! Tree-Parzen Estimator (TPE) sampler implementation.
2//!
3//! TPE is a Bayesian optimization algorithm that models the objective function
4//! using two probability distributions: one for promising (good) parameter values
5//! and one for unpromising (bad) parameter values.
6//!
7//! # Gamma Strategies
8//!
9//! The gamma parameter controls what fraction of trials are considered "good".
10//! This module provides several built-in strategies via the [`GammaStrategy`] trait:
11//!
12//! - [`FixedGamma`]: Constant gamma value (default: 0.25)
13//! - [`LinearGamma`]: Linear interpolation between min and max based on trial count
14//! - [`SqrtGamma`]: Gamma decreases as 1/√n (similar to Optuna)
15//! - [`HyperoptGamma`]: Hyperopt-style adaptive gamma
16//!
17//! You can also implement your own strategy by implementing the [`GammaStrategy`] trait.
18//!
19//! # Examples
20//!
21//! Using a built-in gamma strategy:
22//!
23//! ```
24//! use optimizer::sampler::tpe::{SqrtGamma, TpeSampler};
25//!
26//! let sampler = TpeSampler::builder()
27//!     .gamma_strategy(SqrtGamma::default())
28//!     .build()
29//!     .unwrap();
30//! ```
31//!
32//! Implementing a custom gamma strategy:
33//!
34//! ```
35//! use optimizer::sampler::tpe::{GammaStrategy, TpeSampler};
36//!
37//! #[derive(Debug, Clone)]
38//! struct MyGamma {
39//!     base: f64,
40//! }
41//!
42//! impl GammaStrategy for MyGamma {
43//!     fn gamma(&self, n_trials: usize) -> f64 {
44//!         (self.base + 0.01 * n_trials as f64).min(0.5)
45//!     }
46//!
47//!     fn clone_box(&self) -> Box<dyn GammaStrategy> {
48//!         Box::new(self.clone())
49//!     }
50//! }
51//!
52//! let sampler = TpeSampler::builder()
53//!     .gamma_strategy(MyGamma { base: 0.1 })
54//!     .build()
55//!     .unwrap();
56//! ```
57
58use core::fmt::Debug;
59use core::sync::atomic::{AtomicU64, Ordering};
60use std::sync::Arc;
61
62use crate::distribution::Distribution;
63use crate::error::{Error, Result};
64use crate::param::ParamValue;
65use crate::rng_util;
66use crate::sampler::common;
67use crate::sampler::tpe::gamma::{FixedGamma, GammaStrategy};
68use crate::sampler::{CompletedTrial, Sampler};
69
70use super::common as tpe_common;
71
// ============================================================================
// Gamma strategies are defined in `crate::sampler::tpe::gamma` (imported above)
// ============================================================================
75
76// ============================================================================
77// TPE Sampler
78// ============================================================================
79
80/// A Tree-Parzen Estimator (TPE) sampler for Bayesian optimization.
81///
82/// TPE works by splitting completed trials into two groups based on their
83/// objective values: good trials (below the gamma quantile) and bad trials
84/// (above the gamma quantile). It then fits kernel density estimators (KDE)
85/// to each group and samples new points that maximize the ratio l(x)/g(x),
86/// where l(x) is the density of good trials and g(x) is the density of bad trials.
87///
88/// During the startup phase (when fewer than `n_startup_trials` are completed),
89/// TPE falls back to random sampling to gather initial data.
90///
91/// # Gamma Strategies
92///
93/// The gamma quantile can be configured using different strategies via the
94/// [`GammaStrategy`] trait.
95///
96/// # Examples
97///
98/// ```
99/// use optimizer::sampler::tpe::TpeSampler;
100///
101/// // Create with default settings (FixedGamma at 0.25)
102/// let sampler = TpeSampler::new();
103///
104/// // Create with custom settings using the builder
105/// let sampler = TpeSampler::builder()
106///     .gamma(0.15)  // Shorthand for FixedGamma::new(0.15)
107///     .n_startup_trials(20)
108///     .n_ei_candidates(32)
109///     .seed(42)
110///     .build()
111///     .unwrap();
112/// ```
113///
114/// Using a different gamma strategy:
115///
116/// ```
117/// use optimizer::sampler::tpe::{SqrtGamma, TpeSampler};
118///
119/// let sampler = TpeSampler::builder()
120///     .gamma_strategy(SqrtGamma::default())
121///     .build()
122///     .unwrap();
123/// ```
124pub struct TpeSampler {
125    /// Strategy for computing the gamma quantile.
126    gamma_strategy: Arc<dyn GammaStrategy>,
127    /// Number of trials before TPE kicks in (uses random sampling before this).
128    n_startup_trials: usize,
129    /// Number of candidate samples to evaluate when selecting the next point.
130    n_ei_candidates: usize,
131    /// Optional fixed bandwidth for KDE. If None, uses Scott's rule.
132    kde_bandwidth: Option<f64>,
133    /// Base seed for deterministic per-call RNG derivation (no mutex needed).
134    seed: u64,
135    /// Monotonic counter to disambiguate calls with identical (`trial_id`, distribution).
136    call_seq: AtomicU64,
137}
138
139impl TpeSampler {
140    /// Creates a new TPE sampler with default settings.
141    ///
142    /// Default settings:
143    /// - gamma strategy: [`FixedGamma`] with gamma = 0.25
144    /// - `n_startup_trials`: 10 (random sampling for first 10 trials)
145    /// - `n_ei_candidates`: 24 (evaluate 24 candidates per sample)
146    /// - `kde_bandwidth`: None (uses Scott's rule for automatic bandwidth)
147    #[must_use]
148    pub fn new() -> Self {
149        Self {
150            gamma_strategy: Arc::new(FixedGamma::default()),
151            n_startup_trials: 10,
152            n_ei_candidates: 24,
153            kde_bandwidth: None,
154            seed: fastrand::u64(..),
155            call_seq: AtomicU64::new(0),
156        }
157    }
158
159    /// Creates a builder for configuring a TPE sampler.
160    ///
161    /// # Examples
162    ///
163    /// ```
164    /// use optimizer::sampler::tpe::TpeSampler;
165    ///
166    /// let sampler = TpeSampler::builder()
167    ///     .gamma(0.15)
168    ///     .n_startup_trials(20)
169    ///     .n_ei_candidates(32)
170    ///     .seed(42)
171    ///     .build()
172    ///     .unwrap();
173    /// ```
174    #[must_use]
175    pub fn builder() -> TpeSamplerBuilder {
176        TpeSamplerBuilder::new()
177    }
178
179    /// Creates a new TPE sampler with custom configuration.
180    ///
181    /// This method uses a fixed gamma value. For more advanced gamma strategies,
182    /// use [`TpeSampler::with_strategy`] or the builder pattern with
183    /// [`TpeSamplerBuilder::gamma_strategy`].
184    ///
185    /// # Arguments
186    ///
187    /// * `gamma` - Fraction of trials to consider "good" (0.0 to 1.0).
188    /// * `n_startup_trials` - Number of random trials before TPE sampling.
189    /// * `n_ei_candidates` - Number of candidates to evaluate per sample.
190    /// * `kde_bandwidth` - Optional fixed bandwidth for KDE. If None, uses Scott's rule.
191    /// * `seed` - Optional seed for reproducibility.
192    ///
193    /// # Errors
194    ///
195    /// Returns `Error::InvalidGamma` if gamma is not in (0.0, 1.0).
196    /// Returns `Error::InvalidBandwidth` if `kde_bandwidth` is Some but not positive.
197    pub fn with_config(
198        gamma: f64,
199        n_startup_trials: usize,
200        n_ei_candidates: usize,
201        kde_bandwidth: Option<f64>,
202        seed: Option<u64>,
203    ) -> Result<Self> {
204        let gamma_strategy = FixedGamma::new(gamma)?;
205        Self::with_strategy(
206            gamma_strategy,
207            n_startup_trials,
208            n_ei_candidates,
209            kde_bandwidth,
210            seed,
211        )
212    }
213
214    /// Creates a new TPE sampler with a custom gamma strategy.
215    ///
216    /// # Arguments
217    ///
218    /// * `gamma_strategy` - The strategy for computing the gamma quantile.
219    /// * `n_startup_trials` - Number of random trials before TPE sampling.
220    /// * `n_ei_candidates` - Number of candidates to evaluate per sample.
221    /// * `kde_bandwidth` - Optional fixed bandwidth for KDE. If None, uses Scott's rule.
222    /// * `seed` - Optional seed for reproducibility.
223    ///
224    /// # Errors
225    ///
226    /// Returns `Error::InvalidBandwidth` if `kde_bandwidth` is Some but not positive.
227    ///
228    /// # Examples
229    ///
230    /// ```
231    /// use optimizer::sampler::tpe::{SqrtGamma, TpeSampler};
232    ///
233    /// let sampler = TpeSampler::with_strategy(
234    ///     SqrtGamma::default(),
235    ///     10,       // n_startup_trials
236    ///     24,       // n_ei_candidates
237    ///     None,     // kde_bandwidth
238    ///     Some(42), // seed
239    /// )
240    /// .unwrap();
241    /// ```
242    pub fn with_strategy<G: GammaStrategy + 'static>(
243        gamma_strategy: G,
244        n_startup_trials: usize,
245        n_ei_candidates: usize,
246        kde_bandwidth: Option<f64>,
247        seed: Option<u64>,
248    ) -> Result<Self> {
249        if let Some(bw) = kde_bandwidth
250            && bw <= 0.0
251        {
252            return Err(Error::InvalidBandwidth(bw));
253        }
254
255        Ok(Self {
256            gamma_strategy: Arc::new(gamma_strategy),
257            n_startup_trials,
258            n_ei_candidates,
259            kde_bandwidth,
260            seed: seed.unwrap_or_else(|| fastrand::u64(..)),
261            call_seq: AtomicU64::new(0),
262        })
263    }
264
265    /// Returns the gamma strategy used by this sampler.
266    #[must_use]
267    pub fn gamma_strategy(&self) -> &dyn GammaStrategy {
268        self.gamma_strategy.as_ref()
269    }
270
271    /// Splits trials into good and bad groups based on the gamma quantile.
272    ///
273    /// The gamma value is computed dynamically using the configured [`GammaStrategy`].
274    ///
275    /// Returns (`good_trials`, `bad_trials`) where `good_trials` contains trials
276    /// with values below the gamma quantile (for minimization).
277    #[allow(
278        clippy::cast_precision_loss,
279        clippy::cast_possible_truncation,
280        clippy::cast_sign_loss
281    )]
282    #[must_use]
283    fn split_trials<'a>(
284        &self,
285        history: &'a [CompletedTrial],
286    ) -> (Vec<&'a CompletedTrial>, Vec<&'a CompletedTrial>) {
287        if history.is_empty() {
288            return (vec![], vec![]);
289        }
290
291        // Compute gamma using the strategy and clamp to valid range
292        let gamma = self
293            .gamma_strategy
294            .gamma(history.len())
295            .clamp(f64::EPSILON, 1.0 - f64::EPSILON);
296
297        // Calculate the split point (gamma quantile)
298        // Ensure at least 1 trial in each group if possible
299        let n_good = ((history.len() as f64 * gamma).ceil() as usize)
300            .max(1)
301            .min(history.len() - 1);
302
303        // Use quickselect (O(n)) to partition indices instead of full sort (O(n log n)).
304        // We only need to know which trials are in the top gamma-quantile, not their order.
305        let mut indices: Vec<usize> = (0..history.len()).collect();
306        if n_good > 0 {
307            indices.select_nth_unstable_by(n_good - 1, |&a, &b| {
308                history[a]
309                    .value
310                    .partial_cmp(&history[b].value)
311                    .unwrap_or(core::cmp::Ordering::Equal)
312            });
313        }
314
315        let good: Vec<_> = indices[..n_good].iter().map(|&i| &history[i]).collect();
316        let bad: Vec<_> = indices[n_good..].iter().map(|&i| &history[i]).collect();
317
318        (good, bad)
319    }
320}
321
322impl Default for TpeSampler {
323    fn default() -> Self {
324        Self::new()
325    }
326}
327
328/// Builder for configuring a [`TpeSampler`].
329///
330/// This builder allows fluent configuration of TPE hyperparameters.
331///
332/// # Examples
333///
334/// Using a fixed gamma value:
335///
336/// ```
337/// use optimizer::sampler::tpe::TpeSamplerBuilder;
338///
339/// let sampler = TpeSamplerBuilder::new()
340///     .gamma(0.15)
341///     .n_startup_trials(20)
342///     .n_ei_candidates(32)
343///     .seed(42)
344///     .build()
345///     .unwrap();
346/// ```
347///
348/// Using a custom gamma strategy:
349///
350/// ```
351/// use optimizer::sampler::tpe::{SqrtGamma, TpeSamplerBuilder};
352///
353/// let sampler = TpeSamplerBuilder::new()
354///     .gamma_strategy(SqrtGamma::default())
355///     .n_startup_trials(20)
356///     .build()
357///     .unwrap();
358/// ```
359#[derive(Debug, Clone)]
360pub struct TpeSamplerBuilder {
361    gamma_strategy: Box<dyn GammaStrategy>,
362    /// Raw gamma value for deferred validation (Some if `gamma()` was called)
363    raw_gamma: Option<f64>,
364    n_startup_trials: usize,
365    n_ei_candidates: usize,
366    kde_bandwidth: Option<f64>,
367    seed: Option<u64>,
368}
369
370impl TpeSamplerBuilder {
371    /// Creates a new builder with default settings.
372    ///
373    /// Default settings:
374    /// - gamma strategy: [`FixedGamma`] with gamma = 0.25
375    /// - `n_startup_trials`: 10 (random sampling for first 10 trials)
376    /// - `n_ei_candidates`: 24 (evaluate 24 candidates per sample)
377    /// - `kde_bandwidth`: None (uses Scott's rule for automatic bandwidth)
378    /// - seed: None (use OS-provided entropy)
379    #[must_use]
380    pub fn new() -> Self {
381        Self {
382            gamma_strategy: Box::new(FixedGamma::default()),
383            raw_gamma: None,
384            n_startup_trials: 10,
385            n_ei_candidates: 24,
386            kde_bandwidth: None,
387            seed: None,
388        }
389    }
390
391    /// Sets a fixed gamma value for splitting trials into good/bad groups.
392    ///
393    /// This is a convenience method that creates a [`FixedGamma`] strategy.
394    /// For more advanced gamma strategies, use [`gamma_strategy`](Self::gamma_strategy).
395    ///
396    /// A gamma of 0.25 means the top 25% of trials (by objective value) are
397    /// considered "good" and used to build the l(x) distribution.
398    ///
399    /// # Arguments
400    ///
401    /// * `gamma` - Quantile value, must be in (0.0, 1.0).
402    ///
403    /// # Examples
404    ///
405    /// ```
406    /// use optimizer::sampler::tpe::TpeSamplerBuilder;
407    ///
408    /// let sampler = TpeSamplerBuilder::new()
409    ///     .gamma(0.10)  // Use top 10% as "good" trials
410    ///     .build()
411    ///     .unwrap();
412    /// ```
413    ///
414    /// # Note
415    ///
416    /// Validation happens at `build()` time. If gamma is not in (0.0, 1.0),
417    /// `build()` will return `Err(Error::InvalidGamma)`.
418    #[must_use]
419    pub fn gamma(mut self, gamma: f64) -> Self {
420        // We defer validation to build() time for consistency with the existing API
421        // Store the raw value for validation later
422        self.raw_gamma = Some(gamma);
423        self
424    }
425
426    /// Sets a custom gamma strategy for splitting trials into good/bad groups.
427    ///
428    /// The gamma strategy determines what fraction of trials are considered
429    /// "good" based on the number of completed trials. This allows the gamma
430    /// value to adapt dynamically during optimization.
431    ///
432    /// # Arguments
433    ///
434    /// * `strategy` - A type implementing [`GammaStrategy`].
435    ///
436    /// # Examples
437    ///
438    /// Using built-in strategies:
439    ///
440    /// ```
441    /// use optimizer::sampler::tpe::{LinearGamma, SqrtGamma, TpeSamplerBuilder};
442    ///
443    /// // Square root strategy (Optuna-style)
444    /// let sampler = TpeSamplerBuilder::new()
445    ///     .gamma_strategy(SqrtGamma::default())
446    ///     .build()
447    ///     .unwrap();
448    ///
449    /// // Linear interpolation strategy
450    /// let sampler = TpeSamplerBuilder::new()
451    ///     .gamma_strategy(LinearGamma::new(0.1, 0.3, 50).unwrap())
452    ///     .build()
453    ///     .unwrap();
454    /// ```
455    ///
456    /// Using a custom strategy:
457    ///
458    /// ```
459    /// use optimizer::sampler::tpe::{GammaStrategy, TpeSamplerBuilder};
460    ///
461    /// #[derive(Debug, Clone)]
462    /// struct MyGamma;
463    ///
464    /// impl GammaStrategy for MyGamma {
465    ///     fn gamma(&self, n_trials: usize) -> f64 {
466    ///         0.25 // Always return 0.25
467    ///     }
468    ///     fn clone_box(&self) -> Box<dyn GammaStrategy> {
469    ///         Box::new(self.clone())
470    ///     }
471    /// }
472    ///
473    /// let sampler = TpeSamplerBuilder::new()
474    ///     .gamma_strategy(MyGamma)
475    ///     .build()
476    ///     .unwrap();
477    /// ```
478    #[must_use]
479    pub fn gamma_strategy<G: GammaStrategy + 'static>(mut self, strategy: G) -> Self {
480        self.gamma_strategy = Box::new(strategy);
481        self.raw_gamma = None; // Clear any raw gamma set by gamma()
482        self
483    }
484
485    /// Sets the number of startup trials before TPE sampling begins.
486    ///
487    /// During the startup phase, the sampler uses uniform random sampling
488    /// to gather initial data. Once `n_startup_trials` have completed,
489    /// TPE-based sampling begins.
490    ///
491    /// # Arguments
492    ///
493    /// * `n` - Number of random trials before TPE kicks in.
494    ///
495    /// # Examples
496    ///
497    /// ```
498    /// use optimizer::sampler::tpe::TpeSamplerBuilder;
499    ///
500    /// let sampler = TpeSamplerBuilder::new()
501    ///     .n_startup_trials(20)  // Random sample first 20 trials
502    ///     .build()
503    ///     .unwrap();
504    /// ```
505    #[must_use]
506    pub fn n_startup_trials(mut self, n: usize) -> Self {
507        self.n_startup_trials = n;
508        self
509    }
510
511    /// Sets the number of EI (Expected Improvement) candidates to evaluate.
512    ///
513    /// When sampling a new point, TPE generates this many candidates from
514    /// the l(x) distribution and selects the one with the highest l(x)/g(x)
515    /// ratio.
516    ///
517    /// # Arguments
518    ///
519    /// * `n` - Number of candidates to evaluate per sample.
520    ///
521    /// # Examples
522    ///
523    /// ```
524    /// use optimizer::sampler::tpe::TpeSamplerBuilder;
525    ///
526    /// let sampler = TpeSamplerBuilder::new()
527    ///     .n_ei_candidates(48)  // Evaluate more candidates
528    ///     .build()
529    ///     .unwrap();
530    /// ```
531    #[must_use]
532    pub fn n_ei_candidates(mut self, n: usize) -> Self {
533        self.n_ei_candidates = n;
534        self
535    }
536
537    /// Sets a fixed bandwidth for the kernel density estimator.
538    ///
539    /// By default, TPE uses Scott's rule to automatically select the bandwidth
540    /// based on the sample data. Use this method to override with a fixed value.
541    ///
542    /// Smaller bandwidths give more localized, peaky distributions.
543    /// Larger bandwidths give smoother, more spread-out distributions.
544    ///
545    /// # Arguments
546    ///
547    /// * `bandwidth` - The fixed bandwidth (standard deviation) for Gaussian kernels.
548    ///
549    /// # Examples
550    ///
551    /// ```
552    /// use optimizer::sampler::tpe::TpeSamplerBuilder;
553    ///
554    /// let sampler = TpeSamplerBuilder::new()
555    ///     .kde_bandwidth(0.5)  // Fixed bandwidth of 0.5
556    ///     .build()
557    ///     .unwrap();
558    /// ```
559    ///
560    /// # Note
561    ///
562    /// Validation happens at `build()` time. If bandwidth is not positive,
563    /// `build()` will return `Err(Error::InvalidBandwidth)`.
564    #[must_use]
565    pub fn kde_bandwidth(mut self, bandwidth: f64) -> Self {
566        self.kde_bandwidth = Some(bandwidth);
567        self
568    }
569
570    /// Sets a seed for reproducible sampling.
571    ///
572    /// # Arguments
573    ///
574    /// * `seed` - Seed value for the random number generator.
575    ///
576    /// # Examples
577    ///
578    /// ```
579    /// use optimizer::sampler::tpe::TpeSamplerBuilder;
580    ///
581    /// let sampler = TpeSamplerBuilder::new()
582    ///     .seed(42)  // Reproducible results
583    ///     .build()
584    ///     .unwrap();
585    /// ```
586    #[must_use]
587    pub fn seed(mut self, seed: u64) -> Self {
588        self.seed = Some(seed);
589        self
590    }
591
592    /// Builds the configured [`TpeSampler`].
593    ///
594    /// # Errors
595    ///
596    /// Returns `Error::InvalidGamma` if a fixed gamma value was set and is not in (0.0, 1.0).
597    /// Returns `Error::InvalidBandwidth` if `kde_bandwidth` is Some but not positive.
598    ///
599    /// # Examples
600    ///
601    /// ```
602    /// use optimizer::sampler::tpe::TpeSamplerBuilder;
603    ///
604    /// let sampler = TpeSamplerBuilder::new()
605    ///     .gamma(0.15)
606    ///     .n_startup_trials(20)
607    ///     .n_ei_candidates(32)
608    ///     .seed(42)
609    ///     .build()
610    ///     .unwrap();
611    /// ```
612    pub fn build(self) -> Result<TpeSampler> {
613        // Determine the gamma strategy to use
614        let gamma_strategy: Arc<dyn GammaStrategy> = if let Some(raw) = self.raw_gamma {
615            // Validate and create FixedGamma from raw value
616            Arc::new(FixedGamma::new(raw)?)
617        } else {
618            Arc::from(self.gamma_strategy)
619        };
620
621        // Validate bandwidth
622        if let Some(bw) = self.kde_bandwidth
623            && bw <= 0.0
624        {
625            return Err(Error::InvalidBandwidth(bw));
626        }
627
628        Ok(TpeSampler {
629            gamma_strategy,
630            n_startup_trials: self.n_startup_trials,
631            n_ei_candidates: self.n_ei_candidates,
632            kde_bandwidth: self.kde_bandwidth,
633            seed: self.seed.unwrap_or_else(|| fastrand::u64(..)),
634            call_seq: AtomicU64::new(0),
635        })
636    }
637}
638
639impl Default for TpeSamplerBuilder {
640    fn default() -> Self {
641        Self::new()
642    }
643}
644
645/// Deterministically pick a parameter value matching `target_dist` from a
646/// trial. When multiple parameters share the same distribution (e.g. two
647/// `FloatParam::new(-5.0, 5.0)` in the same study), the one with the
648/// smallest [`ParamId`] is chosen so behavior does not depend on
649/// `HashMap` iteration order.
650fn find_matching_value<'t>(
651    t: &'t CompletedTrial,
652    target_dist: &Distribution,
653) -> Option<&'t ParamValue> {
654    t.distributions
655        .iter()
656        .filter(|(_, dist)| *dist == target_dist)
657        .min_by_key(|(id, _)| *id)
658        .and_then(|(id, _)| t.params.get(id))
659}
660
661impl TpeSampler {
662    fn sample_float(
663        &self,
664        d: &crate::distribution::FloatDistribution,
665        good_trials: &[&CompletedTrial],
666        bad_trials: &[&CompletedTrial],
667        rng: &mut fastrand::Rng,
668    ) -> ParamValue {
669        let target_dist = Distribution::Float(d.clone());
670        let good_values: Vec<f64> = good_trials
671            .iter()
672            .filter_map(|t| match find_matching_value(t, &target_dist)? {
673                ParamValue::Float(f) => Some(*f),
674                _ => None,
675            })
676            .collect();
677
678        let bad_values: Vec<f64> = bad_trials
679            .iter()
680            .filter_map(|t| match find_matching_value(t, &target_dist)? {
681                ParamValue::Float(f) => Some(*f),
682                _ => None,
683            })
684            .collect();
685
686        if good_values.is_empty() || bad_values.is_empty() {
687            return ParamValue::Float(rng_util::f64_range(rng, d.low, d.high));
688        }
689
690        let value = tpe_common::sample_tpe_float(
691            d,
692            good_values,
693            bad_values,
694            self.n_ei_candidates,
695            self.kde_bandwidth,
696            rng,
697        );
698        ParamValue::Float(value)
699    }
700
701    fn sample_int(
702        &self,
703        d: &crate::distribution::IntDistribution,
704        good_trials: &[&CompletedTrial],
705        bad_trials: &[&CompletedTrial],
706        rng: &mut fastrand::Rng,
707    ) -> ParamValue {
708        let target_dist = Distribution::Int(d.clone());
709        let good_values: Vec<i64> = good_trials
710            .iter()
711            .filter_map(|t| match find_matching_value(t, &target_dist)? {
712                ParamValue::Int(i) => Some(*i),
713                _ => None,
714            })
715            .collect();
716
717        let bad_values: Vec<i64> = bad_trials
718            .iter()
719            .filter_map(|t| match find_matching_value(t, &target_dist)? {
720                ParamValue::Int(i) => Some(*i),
721                _ => None,
722            })
723            .collect();
724
725        if good_values.is_empty() || bad_values.is_empty() {
726            return common::sample_random(rng, &Distribution::Int(d.clone()));
727        }
728
729        let value = tpe_common::sample_tpe_int(
730            d,
731            good_values,
732            bad_values,
733            self.n_ei_candidates,
734            self.kde_bandwidth,
735            rng,
736        );
737        ParamValue::Int(value)
738    }
739
740    #[allow(clippy::unused_self)]
741    fn sample_categorical(
742        &self,
743        d: &crate::distribution::CategoricalDistribution,
744        good_trials: &[&CompletedTrial],
745        bad_trials: &[&CompletedTrial],
746        rng: &mut fastrand::Rng,
747    ) -> ParamValue {
748        let target_dist = Distribution::Categorical(d.clone());
749        let good_indices: Vec<usize> = good_trials
750            .iter()
751            .filter_map(|t| match find_matching_value(t, &target_dist)? {
752                ParamValue::Categorical(i) => Some(*i),
753                _ => None,
754            })
755            .collect();
756
757        let bad_indices: Vec<usize> = bad_trials
758            .iter()
759            .filter_map(|t| match find_matching_value(t, &target_dist)? {
760                ParamValue::Categorical(i) => Some(*i),
761                _ => None,
762            })
763            .collect();
764
765        if good_indices.is_empty() || bad_indices.is_empty() {
766            return common::sample_random(rng, &Distribution::Categorical(d.clone()));
767        }
768
769        let index =
770            tpe_common::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng);
771        ParamValue::Categorical(index)
772    }
773}
774
775impl Sampler for TpeSampler {
776    fn sample(
777        &self,
778        distribution: &Distribution,
779        trial_id: u64,
780        history: &[CompletedTrial],
781    ) -> ParamValue {
782        let seq = self.call_seq.fetch_add(1, Ordering::Relaxed);
783        let mut rng = fastrand::Rng::with_seed(rng_util::mix_seed(
784            self.seed,
785            trial_id,
786            rng_util::distribution_fingerprint(distribution).wrapping_add(seq),
787        ));
788
789        // Fall back to random sampling during startup phase
790        if history.len() < self.n_startup_trials {
791            return common::sample_random(&mut rng, distribution);
792        }
793
794        // Split trials into good and bad groups
795        let (good_trials, bad_trials) = self.split_trials(history);
796
797        // Need at least 1 trial in each group for TPE
798        if good_trials.is_empty() || bad_trials.is_empty() {
799            return common::sample_random(&mut rng, distribution);
800        }
801
802        match distribution {
803            Distribution::Float(d) => self.sample_float(d, &good_trials, &bad_trials, &mut rng),
804            Distribution::Int(d) => self.sample_int(d, &good_trials, &bad_trials, &mut rng),
805            Distribution::Categorical(d) => {
806                self.sample_categorical(d, &good_trials, &bad_trials, &mut rng)
807            }
808        }
809    }
810}
811
812#[cfg(test)]
813#[allow(
814    clippy::similar_names,
815    clippy::cast_sign_loss,
816    clippy::cast_precision_loss
817)]
818mod tests {
819    use std::collections::HashMap;
820
821    use super::*;
822    use crate::distribution::{CategoricalDistribution, FloatDistribution, IntDistribution};
823    use crate::parameter::ParamId;
824
825    fn create_trial(
826        id: u64,
827        value: f64,
828        params: Vec<(ParamId, ParamValue, Distribution)>,
829    ) -> CompletedTrial {
830        let mut param_map = HashMap::new();
831        let mut dist_map = HashMap::new();
832        for (param_id, pv, dist) in params {
833            param_map.insert(param_id, pv);
834            dist_map.insert(param_id, dist);
835        }
836        CompletedTrial::new(id, param_map, dist_map, HashMap::new(), value)
837    }
838
839    #[test]
840    fn test_tpe_sampler_new() {
841        let sampler = TpeSampler::new();
842        // Default uses FixedGamma with 0.25
843        assert!((sampler.gamma_strategy().gamma(0) - 0.25).abs() < f64::EPSILON);
844        assert_eq!(sampler.n_startup_trials, 10);
845        assert_eq!(sampler.n_ei_candidates, 24);
846    }
847
848    #[test]
849    fn test_tpe_sampler_with_config() {
850        let sampler = TpeSampler::with_config(0.15, 20, 32, None, Some(42)).unwrap();
851        // with_config uses FixedGamma
852        assert!((sampler.gamma_strategy().gamma(0) - 0.15).abs() < f64::EPSILON);
853        assert_eq!(sampler.n_startup_trials, 20);
854        assert_eq!(sampler.n_ei_candidates, 32);
855    }
856
857    #[test]
858    fn test_tpe_sampler_invalid_gamma_zero() {
859        let result = TpeSampler::with_config(0.0, 10, 24, None, None);
860        assert!(matches!(result, Err(Error::InvalidGamma(_))));
861    }
862
863    #[test]
864    fn test_tpe_sampler_invalid_gamma_one() {
865        let result = TpeSampler::with_config(1.0, 10, 24, None, None);
866        assert!(matches!(result, Err(Error::InvalidGamma(_))));
867    }
868
869    #[test]
870    fn test_tpe_startup_random_sampling() {
871        let sampler = TpeSampler::with_config(0.25, 10, 24, None, Some(42)).unwrap();
872        let dist = Distribution::Float(FloatDistribution {
873            low: 0.0,
874            high: 1.0,
875            log_scale: false,
876            step: None,
877        });
878
879        // With fewer than n_startup_trials, should use random sampling
880        let history: Vec<CompletedTrial> = vec![];
881
882        for i in 0..100 {
883            let value = sampler.sample(&dist, i, &history);
884            if let ParamValue::Float(v) = value {
885                assert!((0.0..=1.0).contains(&v));
886            } else {
887                panic!("Expected Float value");
888            }
889        }
890    }
891
#[test]
fn test_tpe_split_trials() {
    let sampler = TpeSampler::with_config(0.25, 10, 24, None, Some(42)).unwrap();

    let dist = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });

    // 20 trials whose objective values are 0.0, 1.0, ..., 19.0.
    let x_id = ParamId::new();
    let history: Vec<CompletedTrial> = (0..20u32)
        .map(|i| {
            create_trial(
                u64::from(i),
                f64::from(i),
                vec![(x_id, ParamValue::Float(f64::from(i) / 20.0), dist.clone())],
            )
        })
        .collect();

    let (good, bad) = sampler.split_trials(&history);

    // With gamma = 0.25 and 20 trials, 5 trials are good and 15 are bad.
    assert_eq!(good.len(), 5);
    assert_eq!(bad.len(), 15);

    // The split must partition by objective value. The five lowest values
    // (0..=4) are good and every remaining trial is bad. Checking only the
    // good side would miss a split that put low-valued trials into `bad`.
    for trial in &good {
        assert!(trial.value < 5.0, "good trial has value {}", trial.value);
    }
    for trial in &bad {
        assert!(trial.value >= 5.0, "bad trial has value {}", trial.value);
    }
}
926
#[test]
fn test_tpe_samples_float_with_history() {
    let sampler = TpeSampler::with_config(0.25, 5, 24, None, Some(42)).unwrap();

    let dist = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });

    // History over x = 0.00, 0.05, ..., 0.95 with objective (x - 0.2)^2,
    // which is minimized at x = 0.2, so the "good" trials cluster there.
    let x_id = ParamId::new();
    let history: Vec<CompletedTrial> = (0..20u32)
        .map(|i| {
            let x = f64::from(i) / 20.0;
            let value = (x - 0.2).powi(2);
            create_trial(
                u64::from(i),
                value,
                vec![(x_id, ParamValue::Float(x), dist.clone())],
            )
        })
        .collect();

    // TPE should bias toward values near 0.2. Any non-Float sample is a bug,
    // so panic instead of silently skipping it. Skipping could leave
    // `samples` empty and turn the mean below into NaN.
    let mut samples = Vec::with_capacity(100);
    for i in 0..100 {
        let ParamValue::Float(v) = sampler.sample(&dist, 100 + i, &history) else {
            panic!("Expected Float value");
        };
        assert!((0.0..=1.0).contains(&v), "Sample {v} out of range");
        samples.push(v);
    }

    // The sample mean must fall below the domain midpoint (0.5), i.e. the
    // draws lean toward the good region around 0.2.
    let mean: f64 = samples.iter().sum::<f64>() / samples.len() as f64;
    assert!(
        mean < 0.5,
        "Mean {mean} should be less than 0.5 (biased toward good region near 0.2)"
    );
}
969
#[test]
fn test_tpe_categorical_sampling() {
    let sampler = TpeSampler::with_config(0.25, 5, 24, None, Some(42)).unwrap();

    let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 4 });

    // Trials cycle through categories 0..4. Category 1 always has the best
    // (lowest) objective value, 0.0; every other category scores 1.0.
    let cat_id = ParamId::new();
    let history: Vec<CompletedTrial> = (0..20u32)
        .map(|i| {
            let category = (i % 4) as usize;
            let value = if category == 1 { 0.0 } else { 1.0 };
            create_trial(
                u64::from(i),
                value,
                vec![(cat_id, ParamValue::Categorical(category), dist.clone())],
            )
        })
        .collect();

    // TPE should favor category 1. A non-Categorical sample or an
    // out-of-range index is a bug, so fail loudly instead of skipping it.
    let mut counts = [0usize; 4];
    for i in 0..100 {
        let ParamValue::Categorical(idx) = sampler.sample(&dist, 100 + i, &history) else {
            panic!("Expected Categorical value");
        };
        assert!(idx < counts.len(), "Category index {idx} out of range");
        counts[idx] += 1;
    }

    // Category 1 should be sampled strictly more often than any other.
    assert!(
        counts[1] > counts[0] && counts[1] > counts[2] && counts[1] > counts[3],
        "Category 1 should be most common: {counts:?}"
    );
}
1010
#[test]
fn test_tpe_int_sampling() {
    let sampler = TpeSampler::with_config(0.25, 5, 24, None, Some(42)).unwrap();

    let dist = Distribution::Int(IntDistribution {
        low: 0,
        high: 100,
        log_scale: false,
        step: None,
    });

    // History over x = 0, 5, 10, ..., 95 with objective (x - 30)^2, so the
    // best trials lie near x = 30.
    let x_id = ParamId::new();
    let history: Vec<CompletedTrial> = (0..20)
        .map(|i| {
            let x = i * 5;
            let value = ((x as f64) - 30.0).powi(2);
            create_trial(
                i as u64,
                value,
                vec![(x_id, ParamValue::Int(x), dist.clone())],
            )
        })
        .collect();

    // Only validity is asserted here. Every TPE draw must be an Int inside
    // [low, high] = [0, 100]. The bias toward x = 30 is NOT checked by this
    // test.
    for i in 0..50 {
        let ParamValue::Int(v) = sampler.sample(&dist, 100 + i, &history) else {
            panic!("Expected Int value");
        };
        assert!((0..=100).contains(&v), "Value {v} out of range");
    }
}
1046
#[test]
fn test_tpe_reproducibility() {
    let dist = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });

    // Shared history: objective i at x = i / 20 for i in 0..20.
    let x_id = ParamId::new();
    let history: Vec<CompletedTrial> = (0..20u32)
        .map(|i| {
            create_trial(
                u64::from(i),
                f64::from(i),
                vec![(x_id, ParamValue::Float(f64::from(i) / 20.0), dist.clone())],
            )
        })
        .collect();

    // Two independently constructed samplers with identical config and seed
    // must produce identical draws for the same trial ids.
    let make_sampler = || TpeSampler::with_config(0.25, 5, 24, None, Some(12345)).unwrap();
    let (first, second) = (make_sampler(), make_sampler());

    for trial_id in 0..10 {
        assert_eq!(
            first.sample(&dist, trial_id, &history),
            second.sample(&dist, trial_id, &history),
            "Samples should be identical with same seed"
        );
    }
}
1076
#[test]
fn test_tpe_sampler_builder_default() {
    // An unconfigured builder yields gamma 0.25, 10 startup trials and
    // 24 EI candidates.
    let sampler = TpeSamplerBuilder::new().build().unwrap();

    let gamma = sampler.gamma_strategy().gamma(0);
    assert!((gamma - 0.25).abs() < f64::EPSILON);
    assert_eq!(sampler.n_startup_trials, 10);
    assert_eq!(sampler.n_ei_candidates, 24);
}
1085
#[test]
fn test_tpe_sampler_builder_custom() {
    // Every explicitly set option must be reflected in the built sampler.
    let sampler = TpeSamplerBuilder::new()
        .gamma(0.15)
        .n_startup_trials(20)
        .n_ei_candidates(32)
        .seed(42)
        .build()
        .unwrap();

    let gamma_error = (sampler.gamma_strategy().gamma(0) - 0.15).abs();
    assert!(gamma_error < f64::EPSILON);
    assert_eq!(sampler.n_startup_trials, 20);
    assert_eq!(sampler.n_ei_candidates, 32);
}
1099
#[test]
fn test_tpe_sampler_builder_via_sampler() {
    // `TpeSampler::builder()` is an entry point equivalent to
    // `TpeSamplerBuilder::new()`.
    let built = TpeSampler::builder()
        .gamma(0.10)
        .n_startup_trials(15)
        .n_ei_candidates(48)
        .build();
    let sampler = built.unwrap();

    let gamma = sampler.gamma_strategy().gamma(0);
    assert!((gamma - 0.10).abs() < f64::EPSILON);
    assert_eq!(sampler.n_startup_trials, 15);
    assert_eq!(sampler.n_ei_candidates, 48);
}
1112
#[test]
fn test_tpe_sampler_builder_partial() {
    // Only gamma is overridden; the remaining options keep their defaults.
    let sampler = TpeSamplerBuilder::new().gamma(0.20).build().unwrap();

    let gamma_error = (sampler.gamma_strategy().gamma(0) - 0.20).abs();
    assert!(gamma_error < f64::EPSILON);
    assert_eq!(sampler.n_startup_trials, 10); // default
    assert_eq!(sampler.n_ei_candidates, 24); // default
}
1121
#[test]
fn test_tpe_sampler_builder_invalid_gamma() {
    // A gamma of 1.5 cannot be a fraction of trials, so `build()` must
    // report InvalidGamma.
    let built = TpeSamplerBuilder::new().gamma(1.5).build();
    assert!(matches!(built, Err(Error::InvalidGamma(_))));
}
1127
#[test]
fn test_tpe_sampler_builder_reproducibility() {
    let dist = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });

    // Shared history: objective i at x = i / 20 for i in 0..20.
    let x_id = ParamId::new();
    let history: Vec<CompletedTrial> = (0..20u32)
        .map(|i| {
            let x = ParamValue::Float(f64::from(i) / 20.0);
            create_trial(u64::from(i), f64::from(i), vec![(x_id, x, dist.clone())])
        })
        .collect();

    // Samplers built through the builder with the same seed must agree
    // draw-for-draw.
    let build_sampler = || {
        TpeSampler::builder()
            .seed(99999)
            .n_startup_trials(5)
            .build()
            .unwrap()
    };
    let sampler1 = build_sampler();
    let sampler2 = build_sampler();

    for trial_id in 0..10 {
        assert_eq!(
            sampler1.sample(&dist, trial_id, &history),
            sampler2.sample(&dist, trial_id, &history),
            "Builder-created samplers with same seed should be identical"
        );
    }
}
1168}