//! Sampler trait and implementations for parameter sampling.
//!
//! A sampler generates parameter values for each trial. It receives a
//! [`Distribution`] describing the parameter space, a monotonically increasing
//! `trial_id`, and the list of all [`CompletedTrial`]s so far, and returns a
//! [`ParamValue`] that matches the distribution variant.
//!
//! # Available samplers
//!
//! ## Single-objective
//!
//! | Sampler | Algorithm | Best for |
//! |---------|-----------|----------|
//! | [`RandomSampler`] | Uniform independent sampling | Baselines, startup phases |
//! | [`TpeSampler`] | Tree-structured Parzen Estimator (TPE) | General-purpose Bayesian optimization |
//! | [`TpeSampler`] (multivariate) | Multivariate TPE, modeling parameters jointly | Correlated parameters |
//! | [`GridSampler`] | Exhaustive grid evaluation | Small discrete spaces |
//! | [`SobolSampler`]\* | Quasi-random Sobol sequences | Uniform coverage without a model |
//! | [`CmaEsSampler`]\* | Covariance Matrix Adaptation Evolution Strategy | Continuous, non-separable problems |
//! | [`GpSampler`]\* | Gaussian process with Expected Improvement (EI) | Expensive, low-dimensional functions |
//! | [`DESampler`] | Differential Evolution | Population-based search, multi-modal landscapes |
//! | [`BohbSampler`] | Bayesian Optimization + `HyperBand` | Combined sampling and pruning |
//!
//! \*Requires a feature flag (`sobol`, `cma-es`, or `gp`).
//!
//! ## Multi-objective
//!
//! | Sampler | Algorithm | Best for |
//! |---------|-----------|----------|
//! | [`Nsga2Sampler`] | NSGA-II | General multi-objective problems with 2-3 objectives |
//! | [`Nsga3Sampler`] | NSGA-III | Many-objective problems (4+ objectives) |
//! | [`MoeadSampler`] | MOEA/D with decomposition | Structured Pareto front exploration |
//! | [`MotpeSampler`] | Multi-objective TPE | Bayesian multi-objective optimization |
//!
//! # Implementing a custom sampler
//!
//! Implement the [`Sampler`] trait and its single method, [`Sampler::sample`]:
//!
//! ```rust
//! use optimizer::sampler::{Sampler, CompletedTrial};
//! use optimizer::distribution::Distribution;
//! use optimizer::param::ParamValue;
//!
//! /// A sampler that always picks the midpoint of each distribution.
//! struct MidpointSampler;
//!
//! impl Sampler for MidpointSampler {
//!     fn sample(
//!         &self,
//!         distribution: &Distribution,
//!         _trial_id: u64,
//!         _history: &[CompletedTrial],
//!     ) -> ParamValue {
//!         match distribution {
//!             // Arithmetic midpoint; ignores the log-scale flag and step size.
//!             Distribution::Float(fd) => {
//!                 ParamValue::Float((fd.low + fd.high) / 2.0)
//!             }
//!             Distribution::Int(id) => {
//!                 ParamValue::Int((id.low + id.high) / 2)
//!             }
//!             // Index of the middle choice.
//!             Distribution::Categorical(cd) => {
//!                 ParamValue::Categorical(cd.n_choices / 2)
//!             }
//!         }
//!     }
//! }
//! ```
//!
//! The arguments to [`Sampler::sample`]:
//!
//! - **`distribution`**: a [`Distribution::Float`], [`Distribution::Int`], or
//!   [`Distribution::Categorical`] describing the parameter space: bounds,
//!   log-scale flag, and optional step size for numeric parameters, or the
//!   number of choices for categorical ones.
//! - **`trial_id`**: a monotonically increasing identifier, useful for
//!   deterministic RNG seeding (see [Stateless vs stateful samplers]).
//! - **`history`**: all completed trials so far. It may be empty (for example,
//!   on the first trial). Model-based samplers use it to guide future sampling.
//! - **Return value**: the [`ParamValue`] variant *must* match the
//!   distribution variant (`Float` → `ParamValue::Float`, etc.).
//!
//! [Stateless vs stateful samplers]: #stateless-vs-stateful-samplers
//!
//! # Stateless vs stateful samplers
//!
//! **Stateless** samplers derive all randomness deterministically from the
//! seed, the `trial_id`, and the distribution. They use an [`AtomicU64`]
//! call-sequence counter to disambiguate multiple calls within the same trial,
//! but need no `Mutex`. See [`RandomSampler`] and [`TpeSampler`] for this pattern.
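//!
//! As an illustration, here is a minimal sketch of the stateless pattern using
//! only `std`. The SplitMix64 mixer, the hypothetical `StatelessUniform` type,
//! and the way the seed, `trial_id`, and call counter are combined are
//! assumptions for this example, not the hashing that [`RandomSampler`] or
//! [`TpeSampler`] actually perform; the distribution enters only through its
//! bounds here.
//!
//! ```rust
//! use std::sync::atomic::{AtomicU64, Ordering};
//!
//! /// One SplitMix64 step: a fast, well-mixed 64-bit hash of `z`.
//! fn splitmix64(mut z: u64) -> u64 {
//!     z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
//!     z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
//!     z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
//!     z ^ (z >> 31)
//! }
//!
//! /// Hypothetical stateless sampler for a float in `[low, high]`.
//! struct StatelessUniform {
//!     seed: u64,
//!     // Disambiguates several calls made within the same trial.
//!     call_seq: AtomicU64,
//! }
//!
//! impl StatelessUniform {
//!     fn sample(&self, low: f64, high: f64, trial_id: u64) -> f64 {
//!         let call = self.call_seq.fetch_add(1, Ordering::Relaxed);
//!         // Chain every input through the mixer so each one affects all bits.
//!         let h = splitmix64(splitmix64(splitmix64(self.seed) ^ trial_id) ^ call);
//!         // The top 53 bits give a uniform f64 in [0, 1).
//!         let u = (h >> 11) as f64 / (1u64 << 53) as f64;
//!         low + u * (high - low)
//!     }
//! }
//! ```
//!
//! Two instances built with the same seed replay the same values as long as
//! calls arrive in the same order, because nothing else feeds the output.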
//!
//! **Stateful** samplers maintain mutable state (e.g. a population pool)
//! across calls. Wrap the mutable state in `parking_lot::Mutex<State>` and hold
//! the lock for the duration of [`Sampler::sample`]. See [`DESampler`] and
//! [`GridSampler`] for this pattern.
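//!
//! A stateful sampler can be sketched the same way. The hypothetical
//! `CyclingSampler` below walks a fixed list of candidate values, keeping its
//! cursor behind a `std::sync::Mutex` so the example needs no extra crates
//! (with `parking_lot::Mutex`, `lock()` returns the guard directly, without
//! the `unwrap()`).
//!
//! ```rust
//! use std::sync::Mutex;
//!
//! /// Mutable state carried across `sample` calls.
//! struct State {
//!     next: usize,
//! }
//!
//! /// Hypothetical stateful sampler that cycles through fixed candidates.
//! struct CyclingSampler {
//!     candidates: Vec<f64>,
//!     state: Mutex<State>,
//! }
//!
//! impl CyclingSampler {
//!     fn new(candidates: Vec<f64>) -> Self {
//!         assert!(!candidates.is_empty(), "need at least one candidate");
//!         Self { candidates, state: Mutex::new(State { next: 0 }) }
//!     }
//!
//!     fn sample(&self) -> f64 {
//!         // Hold the lock for the whole read-modify-write so each call
//!         // advances the cursor exactly once, even under contention.
//!         let mut state = self.state.lock().unwrap();
//!         let value = self.candidates[state.next];
//!         state.next = (state.next + 1) % self.candidates.len();
//!         value
//!     }
//! }
//! ```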
//!
//! [`AtomicU64`]: core::sync::atomic::AtomicU64
//!
//! # Cold start handling
//!
//! Model-based samplers need completed trials before their surrogate model is
//! useful. The standard pattern is to check `history.len() < n_startup_trials`
//! and fall back to random sampling during the startup phase. Expose this as a
//! builder parameter so users can tune the trade-off between exploration and
//! exploitation. See [`TpeSampler`] for a reference implementation.
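//!
//! A minimal sketch of that branch, with the history reduced to hypothetical
//! `(param, objective)` pairs for a single float parameter. The model step is
//! a deliberately trivial stand-in (it returns the best parameter observed so
//! far) rather than a real surrogate such as TPE's density estimators, and the
//! function and its argument names are illustrative only.
//!
//! ```rust
//! /// One SplitMix64 step, used for deterministic pseudo-random numbers.
//! fn splitmix64(mut z: u64) -> u64 {
//!     z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
//!     z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
//!     z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
//!     z ^ (z >> 31)
//! }
//!
//! /// Samples a float in `[low, high]` given `(param, objective)` history.
//! fn sample_with_cold_start(
//!     low: f64,
//!     high: f64,
//!     trial_id: u64,
//!     history: &[(f64, f64)],
//!     n_startup_trials: usize,
//! ) -> f64 {
//!     if history.len() < n_startup_trials {
//!         // Startup phase: deterministic pseudo-random point from `trial_id`.
//!         let u = (splitmix64(trial_id) >> 11) as f64 / (1u64 << 53) as f64;
//!         return low + u * (high - low);
//!     }
//!     // Model phase (stand-in): exploit the lowest objective seen so far,
//!     // falling back to the midpoint if there is no history at all.
//!     history
//!         .iter()
//!         .min_by(|a, b| a.1.total_cmp(&b.1))
//!         .map(|&(param, _)| param)
//!         .unwrap_or((low + high) / 2.0)
//! }
//! ```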
//!
//! # Reading trial history
//!
//! The `history` slice contains only finished trials, never pending ones. A
//! finished trial may still have been pruned or may have failed, so check its
//! `state` where that matters. Common operations:
//!
//! - **Extract a parameter value:**
//!   `trial.params.get(&param_id)` returns `Option<&ParamValue>`;
//!   [`CompletedTrial::get`] returns the typed value for a parameter.
//! - **Find the best trial (minimization):**
//!   `history.iter().min_by(|a, b| a.value.total_cmp(&b.value))`.
//!   Unlike `partial_cmp(..).unwrap()`, `total_cmp` does not panic on NaN.
//! - **Filter by state:**
//!   `history.iter().filter(|t| t.state == TrialState::Complete)`.
//! - **Check feasibility:**
//!   `trial.is_feasible()` returns `true` when all constraints are ≤ 0.
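//!
//! The operations above combine into a common query: the best trial that
//! completed successfully and satisfies its constraints. To keep the sketch
//! self-contained it uses hypothetical `Trial` and `State` types that mirror
//! only the relevant fields of [`CompletedTrial`] and the `TrialState`
//! variants; with the real types, the filter and comparison should carry over
//! unchanged.
//!
//! ```rust
//! /// Hypothetical mirror of the trial states (Complete, Pruned, Failed).
//! #[derive(Clone, Copy, Debug, PartialEq)]
//! enum State {
//!     Complete,
//!     Pruned,
//!     Failed,
//! }
//!
//! /// Hypothetical mirror of the `CompletedTrial` fields used below.
//! struct Trial {
//!     value: f64,
//!     state: State,
//!     constraints: Vec<f64>,
//! }
//!
//! impl Trial {
//!     /// Same rule as `CompletedTrial::is_feasible`: every constraint <= 0.
//!     fn is_feasible(&self) -> bool {
//!         self.constraints.iter().all(|&c| c <= 0.0)
//!     }
//! }
//!
//! /// Lowest-value trial that completed successfully and is feasible.
//! fn best_feasible(history: &[Trial]) -> Option<&Trial> {
//!     history
//!         .iter()
//!         .filter(|t| t.state == State::Complete && t.is_feasible())
//!         // `total_cmp` gives NaN a fixed place in the order instead of panicking.
//!         .min_by(|a, b| a.value.total_cmp(&b.value))
//! }
//! ```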
//!
//! # Thread safety
//!
//! The [`Sampler`] trait requires `Send + Sync`. [`Study`](crate::Study) stores
//! the sampler as `Arc<dyn Sampler>`, so multiple threads may call
//! [`Sampler::sample`] concurrently.
//!
//! - **Stateless:** `AtomicU64` counters satisfy `Send + Sync` without locking.
//! - **Stateful:** use `parking_lot::Mutex` (the crate convention) or
//!   `std::sync::Mutex` to protect mutable state.
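//!
//! The sketch below shows what the `Send + Sync` bound buys: one sampler
//! behind an `Arc<dyn ...>` called from several threads at once. The
//! `CallCounter` trait, `AtomicCounter` type, and `call_concurrently` helper
//! are hypothetical stand-ins for [`Sampler`], a stateless implementation, and
//! a parallel study.
//!
//! ```rust
//! use std::sync::atomic::{AtomicU64, Ordering};
//! use std::sync::Arc;
//! use std::thread;
//!
//! /// Hypothetical stand-in for a sampler trait with the same bounds.
//! trait CallCounter: Send + Sync {
//!     fn next_call(&self) -> u64;
//! }
//!
//! /// Lock-free state: a single atomic call counter.
//! struct AtomicCounter(AtomicU64);
//!
//! impl CallCounter for AtomicCounter {
//!     fn next_call(&self) -> u64 {
//!         self.0.fetch_add(1, Ordering::Relaxed)
//!     }
//! }
//!
//! /// Calls one shared counter from `threads` threads, `calls_per_thread`
//! /// times each, and returns how many calls were recorded.
//! fn call_concurrently(threads: usize, calls_per_thread: usize) -> u64 {
//!     let shared: Arc<dyn CallCounter> = Arc::new(AtomicCounter(AtomicU64::new(0)));
//!     let handles: Vec<_> = (0..threads)
//!         .map(|_| {
//!             let counter = Arc::clone(&shared);
//!             thread::spawn(move || {
//!                 for _ in 0..calls_per_thread {
//!                     counter.next_call();
//!                 }
//!             })
//!         })
//!         .collect();
//!     for handle in handles {
//!         handle.join().unwrap();
//!     }
//!     // `fetch_add` returns the previous value, i.e. the number of calls so far.
//!     shared.next_call()
//! }
//! ```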
//!
//! # Testing custom samplers
//!
//! Recommended test categories:
//!
//! 1. **Bounds compliance**: sample many values and assert they fall within
//!    the distribution range.
//! 2. **Step / log-scale correctness**: verify that discretized and
//!    log-scaled distributions produce valid values.
//! 3. **Reproducibility**: the same seed, inputs, and call order must produce
//!    the same output.
//! 4. **History sensitivity**: model-based samplers should produce different
//!    (ideally better) samples as history grows.
//! 5. **Empty history**: `sample()` must not panic when `history` is empty.
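//!
//! A sketch of the first three categories, written against a hypothetical
//! `sample_float` function that snaps a seeded uniform draw down onto an
//! optional step grid. The hashing and snapping rule are assumptions for
//! illustration; for a real sampler, call [`Sampler::sample`] with a
//! [`Distribution::Float`] instead, and mark each `check_*` function
//! `#[test]`.
//!
//! ```rust
//! /// One SplitMix64 step, used for deterministic pseudo-random numbers.
//! fn splitmix64(mut z: u64) -> u64 {
//!     z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
//!     z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
//!     z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
//!     z ^ (z >> 31)
//! }
//!
//! /// Hypothetical sampler: uniform in `[low, high]`, snapped down to
//! /// `low + k * step` when a step is given.
//! fn sample_float(seed: u64, trial_id: u64, low: f64, high: f64, step: Option<f64>) -> f64 {
//!     let h = splitmix64(splitmix64(seed) ^ trial_id);
//!     let u = (h >> 11) as f64 / (1u64 << 53) as f64;
//!     let x = low + u * (high - low);
//!     match step {
//!         Some(s) => low + ((x - low) / s).floor() * s,
//!         None => x,
//!     }
//! }
//!
//! /// 1. Bounds compliance over many trials.
//! fn check_bounds() {
//!     for trial_id in 0..10_000 {
//!         let x = sample_float(7, trial_id, -5.0, 5.0, None);
//!         assert!((-5.0..=5.0).contains(&x), "out of bounds: {x}");
//!     }
//! }
//!
//! /// 2. Step correctness: values stay in range and on the grid.
//! fn check_step() {
//!     for trial_id in 0..1_000 {
//!         let x = sample_float(7, trial_id, 0.0, 1.0, Some(0.25));
//!         assert!((0.0..=1.0).contains(&x), "out of bounds: {x}");
//!         let k = x / 0.25;
//!         assert!((k - k.round()).abs() < 1e-9, "off-grid value: {x}");
//!     }
//! }
//!
//! /// 3. Reproducibility: the same seed and trial id give the same value.
//! fn check_reproducible() {
//!     for trial_id in 0..100 {
//!         assert_eq!(
//!             sample_float(1, trial_id, 0.0, 1.0, None),
//!             sample_float(1, trial_id, 0.0, 1.0, None),
//!         );
//!     }
//! }
//! ```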
//!
//! # Using a custom sampler with Study
//!
//! ```rust
//! use optimizer::{Direction, Study};
//! use optimizer::sampler::{Sampler, CompletedTrial};
//! use optimizer::distribution::Distribution;
//! use optimizer::param::ParamValue;
//!
//! struct MySampler;
//! impl Sampler for MySampler {
//!     fn sample(
//!         &self,
//!         distribution: &Distribution,
//!         _trial_id: u64,
//!         _history: &[CompletedTrial],
//!     ) -> ParamValue {
//!         match distribution {
//!             Distribution::Float(fd) => ParamValue::Float(fd.low),
//!             Distribution::Int(id) => ParamValue::Int(id.low),
//!             Distribution::Categorical(_) => ParamValue::Categorical(0),
//!         }
//!     }
//! }
//!
//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, MySampler);
//! ```
//!
//! The sampler is wrapped in `Arc<dyn Sampler>` internally.
//!
//! # Reference implementations
//!
//! - [`RandomSampler`]: simplest sampler; stateless, ignores history.
//! - [`TpeSampler`]: model-based with cold-start fallback.
//! - [`DESampler`]: stateful, population-based.
//! - [`GridSampler`]: deterministic, exhaustive search.

pub mod bohb;
#[cfg(feature = "cma-es")]
pub mod cma_es;
pub(crate) mod common;
pub mod de;
pub(crate) mod genetic;
#[cfg(feature = "gp")]
pub mod gp;
pub mod grid;
pub mod moead;
pub mod motpe;
pub mod nsga2;
pub mod nsga3;
pub mod random;
#[cfg(feature = "sobol")]
pub mod sobol;
pub mod tpe;

use std::collections::HashMap;

pub use bohb::BohbSampler;
#[cfg(feature = "cma-es")]
pub use cma_es::CmaEsSampler;
pub use de::{DESampler, DEStrategy};
#[cfg(feature = "gp")]
pub use gp::GpSampler;
pub use grid::GridSampler;
pub use moead::{Decomposition, MoeadSampler};
pub use motpe::MotpeSampler;
pub use nsga2::Nsga2Sampler;
pub use nsga3::Nsga3Sampler;
pub use random::RandomSampler;
#[cfg(feature = "sobol")]
pub use sobol::SobolSampler;
pub use tpe::TpeSampler;

use crate::distribution::Distribution;
use crate::param::ParamValue;
use crate::parameter::{ParamId, Parameter};
use crate::trial::AttrValue;
use crate::types::TrialState;

/// A completed trial with its parameters, distributions, and objective value.
///
/// This struct stores the results of a finished trial: all sampled parameter
/// values, their distributions, and the objective value returned by the
/// objective function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CompletedTrial<V = f64> {
    /// The unique identifier for this trial.
    pub id: u64,
    /// The sampled parameter values, keyed by parameter id.
    pub params: HashMap<ParamId, ParamValue>,
    /// The parameter distributions used, keyed by parameter id.
    pub distributions: HashMap<ParamId, Distribution>,
    /// Human-readable labels for parameters, keyed by parameter id.
    pub param_labels: HashMap<ParamId, String>,
    /// The objective value returned by the objective function.
    pub value: V,
    /// Intermediate objective values reported during the trial, as
    /// `(step, value)` pairs.
    pub intermediate_values: Vec<(u64, f64)>,
    /// The state of the trial (Complete, Pruned, or Failed).
    pub state: TrialState,
    /// User-defined attributes stored during the trial.
    pub user_attrs: HashMap<String, AttrValue>,
    /// Constraint values for this trial; each one is satisfied when `<= 0.0`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub constraints: Vec<f64>,
}

impl<V> CompletedTrial<V> {
    /// Creates a new completed trial.
    pub fn new(
        id: u64,
        params: HashMap<ParamId, ParamValue>,
        distributions: HashMap<ParamId, Distribution>,
        param_labels: HashMap<ParamId, String>,
        value: V,
    ) -> Self {
        Self {
            id,
            params,
            distributions,
            param_labels,
            value,
            intermediate_values: Vec::new(),
            state: TrialState::Complete,
            user_attrs: HashMap::new(),
            constraints: Vec::new(),
        }
    }

    /// Creates a new completed trial with intermediate values and user attributes.
    pub fn with_intermediate_values(
        id: u64,
        params: HashMap<ParamId, ParamValue>,
        distributions: HashMap<ParamId, Distribution>,
        param_labels: HashMap<ParamId, String>,
        value: V,
        intermediate_values: Vec<(u64, f64)>,
        user_attrs: HashMap<String, AttrValue>,
    ) -> Self {
        Self {
            id,
            params,
            distributions,
            param_labels,
            value,
            intermediate_values,
            state: TrialState::Complete,
            user_attrs,
            constraints: Vec::new(),
        }
    }

    /// Returns the typed value for the given parameter.
    ///
    /// Looks up the parameter by its unique id and casts the stored
    /// [`ParamValue`] to the parameter's typed value.
    ///
    /// Returns `None` if the parameter was not used in this trial or if
    /// the stored value is incompatible with the parameter type (e.g., a
    /// `Float` value stored for an `IntParam`).
    ///
    /// # Examples
    ///
    /// ```
    /// use optimizer::parameter::{FloatParam, Parameter};
    /// use optimizer::{Direction, Study};
    ///
    /// let study: Study<f64> = Study::new(Direction::Minimize);
    /// let x = FloatParam::new(-10.0, 10.0);
    ///
    /// study
    ///     .optimize(5, |trial: &mut optimizer::Trial| {
    ///         let val = x.suggest(trial)?;
    ///         Ok::<_, optimizer::Error>(val * val)
    ///     })
    ///     .unwrap();
    ///
    /// let best = study.best_trial().unwrap();
    /// let x_val: f64 = best.get(&x).unwrap();
    /// assert!((-10.0..=10.0).contains(&x_val));
    /// ```
    pub fn get<P: Parameter>(&self, param: &P) -> Option<P::Value> {
        self.params
            .get(&param.id())
            .and_then(|v| param.cast_param_value(v).ok())
    }

    /// Returns `true` if all constraints are satisfied (values <= 0.0).
    ///
    /// A trial with no constraints is considered feasible.
    #[must_use]
    pub fn is_feasible(&self) -> bool {
        self.constraints.iter().all(|&c| c <= 0.0)
    }

    /// Gets a user attribute by key.
    #[must_use]
    pub fn user_attr(&self, key: &str) -> Option<&AttrValue> {
        self.user_attrs.get(key)
    }

    /// Returns all user attributes.
    #[must_use]
    pub fn user_attrs(&self) -> &HashMap<String, AttrValue> {
        &self.user_attrs
    }

    /// Validates that all floating-point fields are finite (neither NaN nor
    /// infinite).
    ///
    /// Checks distribution bounds, parameter values, constraints, and
    /// intermediate values. Returns a description of the first invalid
    /// field found, or `Ok(())` if everything is valid.
    ///
    /// # Errors
    ///
    /// Returns a `String` describing the first non-finite value found.
    pub fn validate(&self) -> core::result::Result<(), String> {
        for (id, dist) in &self.distributions {
            if let Distribution::Float(fd) = dist {
                if !fd.low.is_finite() {
                    return Err(format!(
                        "trial {}: float distribution for param {id} has non-finite low bound ({})",
                        self.id, fd.low
                    ));
                }
                if !fd.high.is_finite() {
                    return Err(format!(
                        "trial {}: float distribution for param {id} has non-finite high bound ({})",
                        self.id, fd.high
                    ));
                }
                if let Some(step) = fd.step
                    && !step.is_finite()
                {
                    return Err(format!(
                        "trial {}: float distribution for param {id} has non-finite step ({step})",
                        self.id
                    ));
                }
            }
        }

        for (id, pv) in &self.params {
            if let ParamValue::Float(v) = pv
                && !v.is_finite()
            {
                return Err(format!(
                    "trial {}: param {id} has non-finite float value ({v})",
                    self.id
                ));
            }
        }

        for (i, &c) in self.constraints.iter().enumerate() {
            if !c.is_finite() {
                return Err(format!(
                    "trial {}: constraint[{i}] is non-finite ({c})",
                    self.id
                ));
            }
        }

        for &(step, v) in &self.intermediate_values {
            if !v.is_finite() {
                return Err(format!(
                    "trial {}: intermediate value at step {step} is non-finite ({v})",
                    self.id
                ));
            }
        }

        Ok(())
    }
}

/// A pending (running) trial with its parameters and distributions, but no objective value yet.
///
/// This struct represents a trial that has started and sampled its parameters
/// but has not yet returned an objective value. It is used with the constant
/// liar strategy for parallel optimization.
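///
/// The constant liar idea can be sketched with plain `(param, objective)`
/// pairs: each pending point is temporarily recorded with a fake ("lie")
/// objective value so that a model-based sampler steers away from regions
/// that are already being evaluated. The `with_lies` helper is hypothetical,
/// and the lie itself is a tuning choice (commonly the minimum, mean, or
/// maximum of observed values); this sketch uses the maximum, a pessimistic
/// lie when minimizing.
///
/// ```rust
/// /// Returns completed observations plus one lied-about entry per pending point.
/// fn with_lies(completed: &[(f64, f64)], pending: &[f64]) -> Vec<(f64, f64)> {
///     let mut out = completed.to_vec();
///     // With no observations there is nothing to base a lie on.
///     let Some(lie) = completed.iter().map(|&(_, v)| v).max_by(|a, b| a.total_cmp(b)) else {
///         return out;
///     };
///     out.extend(pending.iter().map(|&p| (p, lie)));
///     out
/// }
/// ```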
#[derive(Clone, Debug)]
pub struct PendingTrial {
    /// The unique identifier for this trial.
    pub id: u64,
    /// The sampled parameter values, keyed by parameter id.
    pub params: HashMap<ParamId, ParamValue>,
    /// The parameter distributions used, keyed by parameter id.
    pub distributions: HashMap<ParamId, Distribution>,
    /// Human-readable labels for parameters, keyed by parameter id.
    pub param_labels: HashMap<ParamId, String>,
}

impl PendingTrial {
    /// Creates a new pending trial.
    #[must_use]
    pub fn new(
        id: u64,
        params: HashMap<ParamId, ParamValue>,
        distributions: HashMap<ParamId, Distribution>,
        param_labels: HashMap<ParamId, String>,
    ) -> Self {
        Self {
            id,
            params,
            distributions,
            param_labels,
        }
    }
}

/// Trait for pluggable parameter sampling strategies.
///
/// Samplers are responsible for generating parameter values based on
/// the distribution and historical trial data. The trait requires
/// `Send + Sync` to support concurrent and async optimization.
///
/// # Implementing a custom sampler
///
/// ```
/// use optimizer::sampler::{Sampler, CompletedTrial};
/// use optimizer::distribution::Distribution;
/// use optimizer::param::ParamValue;
///
/// /// Perturbs the best value seen so far with deterministic noise.
/// ///
/// /// For simplicity this assumes a single float parameter, because
/// /// `sample` does not receive the parameter id.
/// struct NoisySampler {
///     noise_scale: f64,
///     seed: u64,
/// }
///
/// impl Sampler for NoisySampler {
///     fn sample(
///         &self,
///         distribution: &Distribution,
///         trial_id: u64,
///         history: &[CompletedTrial],
///     ) -> ParamValue {
///         match distribution {
///             Distribution::Float(fd) => {
///                 // Center on the best (lowest-value) trial's float parameter,
///                 // or on the midpoint when there is no usable history yet.
///                 let center = history
///                     .iter()
///                     .min_by(|a, b| a.value.total_cmp(&b.value))
///                     .and_then(|best| {
///                         best.params.values().find_map(|v| match v {
///                             ParamValue::Float(f) => Some(*f),
///                             _ => None,
///                         })
///                     })
///                     .unwrap_or((fd.low + fd.high) / 2.0);
///                 // Deterministic noise derived from the seed and trial id.
///                 let phase = trial_id.wrapping_add(self.seed) as f64 * 0.1;
///                 let noise = phase.sin() * self.noise_scale;
///                 // Clamp so the sample never leaves the distribution bounds.
///                 ParamValue::Float((center + noise).clamp(fd.low, fd.high))
///             }
///             Distribution::Int(id) => ParamValue::Int((id.low + id.high) / 2),
///             Distribution::Categorical(cd) => {
///                 ParamValue::Categorical(trial_id as usize % cd.n_choices)
///             }
///         }
///     }
/// }
/// ```
///
/// See the [module-level documentation](self) for a comprehensive guide
/// covering cold start handling, thread safety patterns, and testing.
pub trait Sampler: Send + Sync {
    /// Samples a parameter value from the given distribution.
    ///
    /// # Arguments
    ///
    /// * `distribution` - The parameter distribution to sample from.
    /// * `trial_id` - The unique ID of the trial being sampled for.
    /// * `history` - Completed trials so far, for informed sampling.
    ///
    /// # Returns
    ///
    /// A `ParamValue` sampled from the distribution, whose variant matches
    /// the distribution variant.
    fn sample(
        &self,
        distribution: &Distribution,
        trial_id: u64,
        history: &[CompletedTrial],
    ) -> ParamValue;
}