//! Sampler trait and implementations for parameter sampling.
//!
//! A sampler generates parameter values for each trial. It receives a
//! [`Distribution`] describing the parameter space, a monotonically increasing
//! `trial_id`, and the list of all [`CompletedTrial`]s so far, and returns a
//! [`ParamValue`] that matches the distribution variant.
//!
//! # Available samplers
//!
//! ## Single-objective
//!
//! | Sampler | Algorithm | Best for |
//! |---------|-----------|----------|
//! | [`RandomSampler`] | Uniform independent sampling | Baselines, startup phases |
//! | [`TpeSampler`] | Tree-structured Parzen Estimator | General-purpose Bayesian optimization |
//! | [`TpeSampler`] (multivariate) | Multivariate Tree-structured Parzen Estimator | Correlated parameters |
//! | [`GridSampler`] | Exhaustive grid evaluation | Small discrete spaces |
//! | [`SobolSampler`]\* | Quasi-random Sobol sequences | Uniform coverage without model |
//! | [`CmaEsSampler`]\* | Covariance Matrix Adaptation | Continuous, non-separable problems |
//! | [`GpSampler`]\* | Gaussian Process with EI | Expensive, low-dimensional functions |
//! | [`DESampler`] | Differential Evolution | Population-based, multi-modal landscapes |
//! | [`BohbSampler`] | Bayesian Optimization + `HyperBand` | Combined sampling and pruning |
//!
//! \*Requires a feature flag (`sobol`, `cma-es`, or `gp`).
//!
//! ## Multi-objective
//!
//! | Sampler | Algorithm | Best for |
//! |---------|-----------|----------|
//! | [`Nsga2Sampler`] | NSGA-II | General multi-objective with 2-3 objectives |
//! | [`Nsga3Sampler`] | NSGA-III | Many-objective (4+ objectives) |
//! | [`MoeadSampler`] | MOEA/D with decomposition | Structured Pareto front exploration |
//! | [`MotpeSampler`] | Multi-objective TPE | Bayesian multi-objective |
//!
//! # Implementing a custom sampler
//!
//! Implement the [`Sampler`] trait with its single method:
//!
//! ```rust
//! use optimizer::sampler::{Sampler, CompletedTrial};
//! use optimizer::distribution::Distribution;
//! use optimizer::param::ParamValue;
//!
//! /// A sampler that always picks the midpoint of each distribution.
//! struct MidpointSampler;
//!
//! impl Sampler for MidpointSampler {
//!     fn sample(
//!         &self,
//!         distribution: &Distribution,
//!         _trial_id: u64,
//!         _history: &[CompletedTrial],
//!     ) -> ParamValue {
//!         match distribution {
//!             Distribution::Float(fd) => {
//!                 ParamValue::Float((fd.low + fd.high) / 2.0)
//!             }
//!             Distribution::Int(id) => {
//!                 ParamValue::Int((id.low + id.high) / 2)
//!             }
//!             Distribution::Categorical(cd) => {
//!                 ParamValue::Categorical(cd.n_choices / 2)
//!             }
//!         }
//!     }
//! }
//! ```
//!
//! The arguments to [`Sampler::sample`]:
//!
//! - **`distribution`** — a [`Distribution::Float`], [`Distribution::Int`], or
//!   [`Distribution::Categorical`] that describes the parameter bounds,
//!   log-scale flag, and optional step size.
//! - **`trial_id`** — a monotonically increasing identifier. Useful for
//!   deterministic RNG seeding (see [Stateless vs stateful samplers]).
//! - **`history`** — all completed trials so far. May be empty on the first
//!   trial. Model-based samplers use this to guide future sampling.
//! - **Return value** — the [`ParamValue`] variant *must* match the
//!   distribution variant (`Float` → `ParamValue::Float`, etc.).
//!
//! [Stateless vs stateful samplers]: #stateless-vs-stateful-samplers
//!
//! # Stateless vs stateful samplers
//!
//! **Stateless** samplers derive all randomness from a deterministic function
//! of `seed + trial_id + distribution`. They use an [`AtomicU64`] call-sequence
//! counter to disambiguate multiple calls within the same trial, but need no
//! `Mutex`. See [`RandomSampler`] and [`TpeSampler`] for this pattern.
//!
//! **Stateful** samplers maintain mutable state (e.g. a population pool)
//! across calls. Wrap mutable state in `parking_lot::Mutex<State>` and lock
//! for the duration of [`Sampler::sample`]. See [`DESampler`] and
//! [`GridSampler`] for this pattern.
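//!
//! A minimal sketch of the stateless pattern, using only the standard
//! library. `StatelessSeeder`, `mix`, and the SplitMix64-style constants are
//! illustrative assumptions, not this crate's implementation, and the
//! distribution fingerprint is left out for brevity:
//!
//! ```rust
//! use std::sync::atomic::{AtomicU64, Ordering};
//!
//! /// SplitMix64-style finalizer: scrambles a 64-bit input (a bijection).
//! fn mix(mut z: u64) -> u64 {
//!     z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
//!     z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
//!     z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
//!     z ^ (z >> 31)
//! }
//!
//! /// Hypothetical helper that derives one RNG seed per `sample` call.
//! struct StatelessSeeder {
//!     seed: u64,
//!     calls: AtomicU64,
//! }
//!
//! impl StatelessSeeder {
//!     fn new(seed: u64) -> Self {
//!         Self { seed, calls: AtomicU64::new(0) }
//!     }
//!
//!     /// Combines the base seed, the trial id, and a call-sequence number.
//!     fn next_seed(&self, trial_id: u64) -> u64 {
//!         // `fetch_add` returns the previous value, so every call (even
//!         // from several threads) observes a distinct sequence number.
//!         let call = self.calls.fetch_add(1, Ordering::Relaxed);
//!         mix(mix(self.seed ^ mix(trial_id)) ^ call)
//!     }
//! }
//! ```
//!
//! Two seeders built from the same base seed yield the same seed sequence for
//! the same call order, while repeated calls within one trial still differ.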
//!
//! [`AtomicU64`]: core::sync::atomic::AtomicU64
//!
//! # Cold start handling
//!
//! Model-based samplers need completed trials before their surrogate model is
//! useful. The standard pattern is to check `history.len() < n_startup_trials`
//! and fall back to random sampling during the startup phase. Expose this as a
//! builder parameter so users can tune the trade-off between exploration and
//! exploitation. See [`TpeSampler`] for a reference implementation.
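//!
//! The startup check can be sketched as follows. `ColdStartConfig`, `Phase`,
//! and the default of 10 startup trials are hypothetical names and values
//! chosen for illustration; they are not taken from [`TpeSampler`]:
//!
//! ```rust
//! /// Which strategy a model-based sampler should use for the next sample.
//! #[derive(Debug, PartialEq)]
//! enum Phase {
//!     /// Not enough history yet: sample at random.
//!     Startup,
//!     /// Enough history: consult the surrogate model.
//!     Model,
//! }
//!
//! /// Hypothetical sampler configuration with a tunable startup length.
//! struct ColdStartConfig {
//!     n_startup_trials: usize,
//! }
//!
//! impl ColdStartConfig {
//!     /// Assumed default of 10 startup trials, for illustration only.
//!     fn new() -> Self {
//!         Self { n_startup_trials: 10 }
//!     }
//!
//!     /// Builder-style setter so callers can tune the startup length.
//!     fn n_startup_trials(mut self, n: usize) -> Self {
//!         self.n_startup_trials = n;
//!         self
//!     }
//!
//!     /// Applies the `history.len() < n_startup_trials` check.
//!     fn phase(&self, history_len: usize) -> Phase {
//!         if history_len < self.n_startup_trials {
//!             Phase::Startup
//!         } else {
//!             Phase::Model
//!         }
//!     }
//! }
//! ```
//!
//! A larger `n_startup_trials` spends more trials exploring at random before
//! the model takes over; a smaller value starts exploiting sooner.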
//!
//! # Reading trial history
//!
//! The `history` slice contains only finished trials, never pending ones. A
//! finished trial's `state` may be `Complete`, `Pruned`, or `Failed`.
//! Common operations:
//!
//! - **Extract a parameter value:**
//!   `trial.params.get(&param_id)` returns `Option<&ParamValue>`.
//! - **Find the best trial** (minimizing, `f64` values):
//!   `history.iter().min_by(|a, b| a.value.total_cmp(&b.value))`.
//!   Unlike `partial_cmp(..).unwrap()`, `total_cmp` does not panic on NaN.
//! - **Filter by state:**
//!   `history.iter().filter(|t| t.state == TrialState::Complete)`.
//! - **Check feasibility:**
//!   `trial.is_feasible()` returns `true` when all constraints are ≤ 0.
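//!
//! The operations above can be combined, for example to find the best
//! feasible trial. The sketch uses a hypothetical stand-in type,
//! `TrialRecord`, that mirrors only the `value` and `constraints` fields of
//! [`CompletedTrial`], and it assumes a minimizing study:
//!
//! ```rust
//! /// Stand-in for the fields of a completed trial used here (hypothetical).
//! struct TrialRecord {
//!     value: f64,
//!     constraints: Vec<f64>,
//! }
//!
//! impl TrialRecord {
//!     /// Feasible when every constraint is <= 0.0 (vacuously true if none).
//!     fn is_feasible(&self) -> bool {
//!         self.constraints.iter().all(|&c| c <= 0.0)
//!     }
//! }
//!
//! /// Returns the feasible trial with the smallest objective value, if any.
//! fn best_feasible(history: &[TrialRecord]) -> Option<&TrialRecord> {
//!     history
//!         .iter()
//!         .filter(|t| t.is_feasible())
//!         // `total_cmp` is a total order on `f64`, so NaN cannot panic here.
//!         .min_by(|a, b| a.value.total_cmp(&b.value))
//! }
//! ```
//!
//! Returning `Option` keeps the empty-history case explicit for the caller.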
//!
//! # Thread safety
//!
//! The [`Sampler`] trait requires `Send + Sync`. [`Study`](crate::Study) stores
//! the sampler as `Arc<dyn Sampler>`, so multiple threads may call
//! [`Sampler::sample`] concurrently.
//!
//! - **Stateless:** `AtomicU64` counters satisfy `Send + Sync` without locking.
//! - **Stateful:** use `parking_lot::Mutex` (the crate convention) or
//!   `std::sync::Mutex` to protect mutable state.
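//!
//! A sketch of the stateful pattern under concurrent access. It uses
//! `std::sync::Mutex` to stay within the standard library; `GridCursor` and
//! `drain_concurrently` are hypothetical names and do not describe how
//! [`GridSampler`] is implemented:
//!
//! ```rust
//! use std::sync::{Arc, Mutex};
//! use std::thread;
//!
//! /// Hypothetical stateful core: hands out grid indices in order.
//! struct GridCursor {
//!     next: Mutex<usize>,
//!     /// Number of grid points; must be non-zero.
//!     len: usize,
//! }
//!
//! impl GridCursor {
//!     fn new(len: usize) -> Self {
//!         Self { next: Mutex::new(0), len }
//!     }
//!
//!     /// Takes `&self`, like `Sampler::sample`; the lock guards the mutation.
//!     fn next_index(&self) -> usize {
//!         let mut next = self.next.lock().expect("cursor mutex poisoned");
//!         let index = *next % self.len;
//!         *next += 1;
//!         index
//!     }
//! }
//!
//! /// Calls `next_index` once from each of `n_threads` threads and returns
//! /// the indices in sorted order.
//! fn drain_concurrently(cursor: Arc<GridCursor>, n_threads: usize) -> Vec<usize> {
//!     let handles: Vec<_> = (0..n_threads)
//!         .map(|_| {
//!             let cursor = Arc::clone(&cursor);
//!             thread::spawn(move || cursor.next_index())
//!         })
//!         .collect();
//!     let mut seen: Vec<usize> = handles
//!         .into_iter()
//!         .map(|h| h.join().expect("worker thread panicked"))
//!         .collect();
//!     seen.sort_unstable();
//!     seen
//! }
//! ```
//!
//! Because the read and the increment happen under one lock, no two threads
//! receive the same index until the cursor wraps around.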
//!
//! # Testing custom samplers
//!
//! Recommended test categories:
//!
//! 1. **Bounds compliance** — sample many values and assert they fall within
//!    the distribution range.
//! 2. **Step / log-scale correctness** — verify that discretized and
//!    log-scaled distributions produce valid values.
//! 3. **Reproducibility** — the same seed must produce the same output.
//! 4. **History sensitivity** — model-based samplers should produce different
//!    (better) samples as history grows.
//! 5. **Empty history** — `sample()` must not panic when `history` is empty.
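//!
//! As an illustration of categories 1 and 2, the sketch below checks bounds
//! compliance for a log-scaled draw. `sample_log_uniform` and
//! `all_within_bounds` are hypothetical helpers written for this example; a
//! real test would call the sampler under test instead:
//!
//! ```rust
//! /// Maps a unit-interval draw `u` in [0, 1] onto `[low, high]` on a log
//! /// scale. Hypothetical helper; assumes `0 < low <= high`.
//! fn sample_log_uniform(low: f64, high: f64, u: f64) -> f64 {
//!     let value = (low.ln() + u * (high.ln() - low.ln())).exp();
//!     // Guard against floating-point round-off at the edges of the range.
//!     value.clamp(low, high)
//! }
//!
//! /// Bounds-compliance check: sweeps `n + 1` evenly spaced draws and
//! /// reports whether every sample stays inside `[low, high]`.
//! /// Assumes `n > 0`.
//! fn all_within_bounds(low: f64, high: f64, n: u32) -> bool {
//!     (0..=n).all(|i| {
//!         let u = f64::from(i) / f64::from(n);
//!         let x = sample_log_uniform(low, high, u);
//!         (low..=high).contains(&x)
//!     })
//! }
//! ```
//!
//! Sweeping a deterministic grid of draws, rather than random ones, keeps the
//! test itself reproducible and always exercises both endpoints.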
//!
//! # Using a custom sampler with Study
//!
//! ```rust
//! use optimizer::{Direction, Study};
//! use optimizer::sampler::{Sampler, CompletedTrial};
//! use optimizer::distribution::Distribution;
//! use optimizer::param::ParamValue;
//!
//! struct MySampler;
//! impl Sampler for MySampler {
//!     fn sample(
//!         &self,
//!         distribution: &Distribution,
//!         _trial_id: u64,
//!         _history: &[CompletedTrial],
//!     ) -> ParamValue {
//!         match distribution {
//!             Distribution::Float(fd) => ParamValue::Float(fd.low),
//!             Distribution::Int(id) => ParamValue::Int(id.low),
//!             Distribution::Categorical(_) => ParamValue::Categorical(0),
//!         }
//!     }
//! }
//!
//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, MySampler);
//! ```
//!
//! The sampler is wrapped in `Arc<dyn Sampler>` internally.
//!
//! # Reference implementations
//!
//! - [`RandomSampler`] — simplest sampler; stateless, ignores history.
//! - [`TpeSampler`] — model-based with cold start fallback.
//! - [`DESampler`] — stateful, population-based.
//! - [`GridSampler`] — deterministic, exhaustive search.

pub mod bohb;
#[cfg(feature = "cma-es")]
pub mod cma_es;
pub(crate) mod common;
pub mod de;
pub(crate) mod genetic;
#[cfg(feature = "gp")]
pub mod gp;
pub mod grid;
pub mod moead;
pub mod motpe;
pub mod nsga2;
pub mod nsga3;
pub mod random;
#[cfg(feature = "sobol")]
pub mod sobol;
pub mod tpe;

use std::collections::HashMap;

pub use bohb::BohbSampler;
#[cfg(feature = "cma-es")]
pub use cma_es::CmaEsSampler;
pub use de::{DESampler, DEStrategy};
#[cfg(feature = "gp")]
pub use gp::GpSampler;
pub use grid::GridSampler;
pub use moead::{Decomposition, MoeadSampler};
pub use motpe::MotpeSampler;
pub use nsga2::Nsga2Sampler;
pub use nsga3::Nsga3Sampler;
pub use random::RandomSampler;
#[cfg(feature = "sobol")]
pub use sobol::SobolSampler;
pub use tpe::TpeSampler;

use crate::distribution::Distribution;
use crate::param::ParamValue;
use crate::parameter::{ParamId, Parameter};
use crate::trial::AttrValue;
use crate::types::TrialState;

/// A completed trial with its parameters, distributions, and objective value.
///
/// This struct stores the results of a completed trial, including all sampled
/// parameter values, their distributions, and the objective value returned
/// by the objective function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CompletedTrial<V = f64> {
    /// The unique identifier for this trial.
    pub id: u64,
    /// The sampled parameter values, keyed by parameter id.
    pub params: HashMap<ParamId, ParamValue>,
    /// The parameter distributions used, keyed by parameter id.
    pub distributions: HashMap<ParamId, Distribution>,
    /// Human-readable labels for parameters, keyed by parameter id.
    pub param_labels: HashMap<ParamId, String>,
    /// The objective value returned by the objective function.
    pub value: V,
    /// Intermediate objective values reported during the trial.
    pub intermediate_values: Vec<(u64, f64)>,
    /// The state of the trial (Complete, Pruned, or Failed).
    pub state: TrialState,
    /// User-defined attributes stored during the trial.
    pub user_attrs: HashMap<String, AttrValue>,
    /// Constraint values for this trial (<=0.0 means feasible).
    #[cfg_attr(feature = "serde", serde(default))]
    pub constraints: Vec<f64>,
}

impl<V> CompletedTrial<V> {
    /// Creates a new completed trial.
    pub fn new(
        id: u64,
        params: HashMap<ParamId, ParamValue>,
        distributions: HashMap<ParamId, Distribution>,
        param_labels: HashMap<ParamId, String>,
        value: V,
    ) -> Self {
        Self {
            id,
            params,
            distributions,
            param_labels,
            value,
            intermediate_values: Vec::new(),
            state: TrialState::Complete,
            user_attrs: HashMap::new(),
            constraints: Vec::new(),
        }
    }

    /// Creates a new completed trial with intermediate values and user attributes.
    pub fn with_intermediate_values(
        id: u64,
        params: HashMap<ParamId, ParamValue>,
        distributions: HashMap<ParamId, Distribution>,
        param_labels: HashMap<ParamId, String>,
        value: V,
        intermediate_values: Vec<(u64, f64)>,
        user_attrs: HashMap<String, AttrValue>,
    ) -> Self {
        Self {
            id,
            params,
            distributions,
            param_labels,
            value,
            intermediate_values,
            state: TrialState::Complete,
            user_attrs,
            constraints: Vec::new(),
        }
    }

    /// Returns the typed value for the given parameter.
    ///
    /// Looks up the parameter by its unique id and casts the stored
    /// [`ParamValue`] to the parameter's typed value.
    ///
    /// Returns `None` if the parameter was not used in this trial or if
    /// the stored value is incompatible with the parameter type (e.g., a
    /// `Float` value stored for an `IntParam`).
    ///
    /// # Examples
    ///
    /// ```
    /// use optimizer::parameter::{FloatParam, Parameter};
    /// use optimizer::{Direction, Study};
    ///
    /// let study: Study<f64> = Study::new(Direction::Minimize);
    /// let x = FloatParam::new(-10.0, 10.0);
    ///
    /// study
    ///     .optimize(5, |trial: &mut optimizer::Trial| {
    ///         let val = x.suggest(trial)?;
    ///         Ok::<_, optimizer::Error>(val * val)
    ///     })
    ///     .unwrap();
    ///
    /// let best = study.best_trial().unwrap();
    /// let x_val: f64 = best.get(&x).unwrap();
    /// assert!((-10.0..=10.0).contains(&x_val));
    /// ```
    pub fn get<P: Parameter>(&self, param: &P) -> Option<P::Value> {
        self.params
            .get(&param.id())
            .and_then(|v| param.cast_param_value(v).ok())
    }

    /// Returns `true` if all constraints are satisfied (values <= 0.0).
    ///
    /// A trial with no constraints is considered feasible.
    #[must_use]
    pub fn is_feasible(&self) -> bool {
        self.constraints.iter().all(|&c| c <= 0.0)
    }

    /// Gets a user attribute by key.
    #[must_use]
    pub fn user_attr(&self, key: &str) -> Option<&AttrValue> {
        self.user_attrs.get(key)
    }

    /// Returns all user attributes.
    #[must_use]
    pub fn user_attrs(&self) -> &HashMap<String, AttrValue> {
        &self.user_attrs
    }

    /// Validates that all floating-point fields are finite (not NaN or
    /// Infinity).
    ///
    /// Checks distribution bounds, parameter values, constraints, and
    /// intermediate values. Returns a description of the first invalid
    /// field found, or `Ok(())` if everything is valid.
    ///
    /// # Errors
    ///
    /// Returns a `String` describing the first non-finite value found.
    pub fn validate(&self) -> core::result::Result<(), String> {
        for (id, dist) in &self.distributions {
            if let Distribution::Float(fd) = dist {
                if !fd.low.is_finite() {
                    return Err(format!(
                        "trial {}: float distribution for param {id} has non-finite low bound ({})",
                        self.id, fd.low
                    ));
                }
                if !fd.high.is_finite() {
                    return Err(format!(
                        "trial {}: float distribution for param {id} has non-finite high bound ({})",
                        self.id, fd.high
                    ));
                }
                if let Some(step) = fd.step
                    && !step.is_finite()
                {
                    return Err(format!(
                        "trial {}: float distribution for param {id} has non-finite step ({step})",
                        self.id
                    ));
                }
            }
        }

        for (id, pv) in &self.params {
            if let ParamValue::Float(v) = pv
                && !v.is_finite()
            {
                return Err(format!(
                    "trial {}: param {id} has non-finite float value ({v})",
                    self.id
                ));
            }
        }

        for (i, &c) in self.constraints.iter().enumerate() {
            if !c.is_finite() {
                return Err(format!(
                    "trial {}: constraint[{i}] is non-finite ({c})",
                    self.id
                ));
            }
        }

        for &(step, v) in &self.intermediate_values {
            if !v.is_finite() {
                return Err(format!(
                    "trial {}: intermediate value at step {step} is non-finite ({v})",
                    self.id
                ));
            }
        }

        Ok(())
    }
}

/// A pending (running) trial with its parameters and distributions, but no objective value yet.
///
/// This struct represents a trial that has been started and has sampled parameters,
/// but is still running and hasn't returned an objective value. It is used with the
/// constant liar strategy for parallel optimization.
#[derive(Clone, Debug)]
pub struct PendingTrial {
    /// The unique identifier for this trial.
    pub id: u64,
    /// The sampled parameter values, keyed by parameter id.
    pub params: HashMap<ParamId, ParamValue>,
    /// The parameter distributions used, keyed by parameter id.
    pub distributions: HashMap<ParamId, Distribution>,
    /// Human-readable labels for parameters, keyed by parameter id.
    pub param_labels: HashMap<ParamId, String>,
}

impl PendingTrial {
    /// Creates a new pending trial.
    #[must_use]
    pub fn new(
        id: u64,
        params: HashMap<ParamId, ParamValue>,
        distributions: HashMap<ParamId, Distribution>,
        param_labels: HashMap<ParamId, String>,
    ) -> Self {
        Self {
            id,
            params,
            distributions,
            param_labels,
        }
    }
}

/// Trait for pluggable parameter sampling strategies.
///
/// Samplers are responsible for generating parameter values based on
/// the distribution and historical trial data. The trait requires
/// `Send + Sync` to support concurrent and async optimization.
///
/// # Implementing a custom sampler
///
/// ```
/// use optimizer::sampler::{Sampler, CompletedTrial};
/// use optimizer::distribution::Distribution;
/// use optimizer::param::ParamValue;
///
/// struct NoisySampler {
///     noise_scale: f64,
///     seed: u64,
/// }
///
/// impl Sampler for NoisySampler {
///     fn sample(
///         &self,
///         distribution: &Distribution,
///         trial_id: u64,
///         history: &[CompletedTrial],
///     ) -> ParamValue {
///         match distribution {
///             Distribution::Float(fd) => {
///                 // Center on the first float parameter value found in
///                 // history, or fall back to the midpoint.
///                 let center = if history.is_empty() {
///                     (fd.low + fd.high) / 2.0
///                 } else {
///                     history.iter()
///                         .filter_map(|t| t.params.values().next())
///                         .filter_map(|v| if let ParamValue::Float(f) = v { Some(*f) } else { None })
///                         .next()
///                         .unwrap_or((fd.low + fd.high) / 2.0)
///                 };
///                 let noise = (trial_id as f64 * 0.1).sin() * self.noise_scale;
///                 // Clamp so the noisy value stays inside the distribution bounds.
///                 ParamValue::Float((center + noise).clamp(fd.low, fd.high))
///             }
///             Distribution::Int(id) => ParamValue::Int((id.low + id.high) / 2),
///             Distribution::Categorical(cd) => ParamValue::Categorical(trial_id as usize % cd.n_choices),
///         }
///     }
/// }
/// ```
///
/// See the [module-level documentation](self) for a comprehensive guide
/// covering cold start handling, thread safety patterns, and testing.
pub trait Sampler: Send + Sync {
    /// Samples a parameter value from the given distribution.
    ///
    /// # Arguments
    ///
    /// * `distribution` - The parameter distribution to sample from.
    /// * `trial_id` - The unique ID of the trial being sampled for.
    /// * `history` - Historical completed trials for informed sampling.
    ///
    /// # Returns
    ///
    /// A `ParamValue` sampled from the distribution.
    fn sample(
        &self,
        distribution: &Distribution,
        trial_id: u64,
        history: &[CompletedTrial],
    ) -> ParamValue;
}