optimal_pbil/
lib.rs

1#![allow(clippy::needless_doctest_main)]
2#![warn(missing_debug_implementations)]
3// `missing_docs` does not work with `IsVariant`,
4// see <https://github.com/JelteF/derive_more/issues/215>.
5// #![warn(missing_docs)]
6
//! Population-based incremental learning (PBIL).
//!
//! # Examples
//!
//! The example below runs PBIL on 16-bit points,
//! minimizing the number of `true` bits,
//! and prints the resulting point.
//!
//! ```
//! use optimal_pbil::*;
//!
//! println!(
//!     "{:?}",
//!     UntilProbabilitiesConvergedConfig::default()
//!         .start(Config::start_default_for(16, |point| point.iter().filter(|x| **x).count()))
//!         .argmin()
//! );
//! ```
21
22mod state_machine;
23mod types;
24mod until_probabilities_converged;
25
26use derive_getters::{Dissolve, Getters};
27use derive_more::IsVariant;
28pub use optimal_core::prelude::*;
29use rand::prelude::*;
30use rand_xoshiro::{SplitMix64, Xoshiro256PlusPlus};
31
32use self::state_machine::DynState;
33pub use self::{types::*, until_probabilities_converged::*};
34
35#[cfg(feature = "serde")]
36use serde::{Deserialize, Serialize};
37
38/// Error returned when
39/// problem length does not match state length.
40#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
41#[error("problem length does not match state length")]
42pub struct MismatchedLengthError;
43
44/// A running PBIL optimizer.
45#[derive(Clone, Debug, Getters, Dissolve)]
46#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
47#[dissolve(rename = "into_parts")]
48pub struct Pbil<B, F> {
49    /// Optimizer configuration.
50    config: Config,
51
52    /// State of optimizer.
53    state: State<B>,
54
55    /// Objective function to minimize.
56    obj_func: F,
57}
58
59/// PBIL configuration parameters.
60#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
61#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
62pub struct Config {
63    /// Number of samples generated
64    /// during steps.
65    pub num_samples: NumSamples,
66    /// Degree to adjust probabilities towards best point
67    /// during steps.
68    pub adjust_rate: AdjustRate,
69    /// Probability for each probability to mutate,
70    /// independently.
71    pub mutation_chance: MutationChance,
72    /// Degree to adjust probability towards random value
73    /// when mutating.
74    pub mutation_adjust_rate: MutationAdjustRate,
75}
76
77/// PBIL state.
78#[derive(Clone, Debug, PartialEq)]
79#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
80#[cfg_attr(feature = "serde", serde(transparent))]
81pub struct State<B>(DynState<B>);
82
83/// PBIL state kind.
84#[derive(Clone, Debug, PartialEq, IsVariant)]
85#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
86pub enum StateKind {
87    /// Iteration started.
88    Started,
89    /// Sample generated.
90    Sampled,
91    /// Sample evaluated.
92    Evaluated,
93    /// Samples compared.
94    Compared,
95    /// Probabilities adjusted.
96    Adjusted,
97    /// Probabilities mutated.
98    Mutated,
99    /// Iteration finished.
100    Finished,
101}
102
103impl<B, F> Pbil<B, F> {
104    fn new(state: State<B>, config: Config, obj_func: F) -> Self {
105        Self {
106            config,
107            obj_func,
108            state,
109        }
110    }
111}
112
113impl<B, F> Pbil<B, F>
114where
115    F: Fn(&[bool]) -> B,
116{
117    /// Return value of the best point discovered.
118    pub fn best_point_value(&self) -> B {
119        (self.obj_func)(&self.best_point())
120    }
121}
122
123impl<B, F> StreamingIterator for Pbil<B, F>
124where
125    B: PartialOrd,
126    F: Fn(&[bool]) -> B,
127{
128    type Item = Self;
129
130    fn advance(&mut self) {
131        replace_with::replace_with_or_abort(&mut self.state.0, |state| match state {
132            DynState::Started(x) => {
133                DynState::SampledFirst(x.into_initialized_sampling().into_sampled_first())
134            }
135            DynState::SampledFirst(x) => {
136                DynState::EvaluatedFirst(x.into_evaluated_first(&self.obj_func))
137            }
138            DynState::EvaluatedFirst(x) => DynState::Sampled(x.into_sampled()),
139            DynState::Sampled(x) => DynState::Evaluated(x.into_evaluated(&self.obj_func)),
140            DynState::Evaluated(x) => DynState::Compared(x.into_compared()),
141            DynState::Compared(x) => {
142                if x.samples_generated < self.config.num_samples.into_inner() {
143                    DynState::Sampled(x.into_sampled())
144                } else {
145                    DynState::Adjusted(x.into_adjusted(self.config.adjust_rate))
146                }
147            }
148            DynState::Adjusted(x) => {
149                if self.config.mutation_chance.into_inner() > 0.0 {
150                    DynState::Mutated(x.into_mutated(
151                        self.config.mutation_chance,
152                        self.config.mutation_adjust_rate,
153                    ))
154                } else {
155                    DynState::Finished(x.into_finished())
156                }
157            }
158            DynState::Mutated(x) => DynState::Finished(x.into_finished()),
159            DynState::Finished(x) => DynState::Started(x.into_started()),
160        });
161    }
162
163    fn get(&self) -> Option<&Self::Item> {
164        Some(self)
165    }
166}
167
168impl<B, F> Optimizer for Pbil<B, F> {
169    type Point = Vec<bool>;
170
171    fn best_point(&self) -> Self::Point {
172        self.state.best_point()
173    }
174}
175
176impl Config {
177    /// Return a new PBIL configuration.
178    pub fn new(
179        num_samples: NumSamples,
180        adjust_rate: AdjustRate,
181        mutation_chance: MutationChance,
182        mutation_adjust_rate: MutationAdjustRate,
183    ) -> Self {
184        Self {
185            num_samples,
186            adjust_rate,
187            mutation_chance,
188            mutation_adjust_rate,
189        }
190    }
191}
192
193impl Config {
194    /// Return default 'Config'.
195    pub fn default_for(num_bits: usize) -> Self {
196        Self {
197            num_samples: NumSamples::default(),
198            adjust_rate: AdjustRate::default(),
199            mutation_chance: MutationChance::default_for(num_bits),
200            mutation_adjust_rate: MutationAdjustRate::default(),
201        }
202    }
203
204    /// Return this optimizer default
205    /// running on the given problem.
206    ///
207    /// # Arguments
208    ///
209    /// - `len`: number of bits in each point
210    /// - `obj_func`: objective function to minimize
211    pub fn start_default_for<B, F>(len: usize, obj_func: F) -> Pbil<B, F>
212    where
213        F: Fn(&[bool]) -> B,
214    {
215        Self::default_for(len).start(len, obj_func)
216    }
217
218    /// Return this optimizer
219    /// running on the given problem.
220    ///
221    /// This may be nondeterministic.
222    ///
223    /// # Arguments
224    ///
225    /// - `len`: number of bits in each point
226    /// - `obj_func`: objective function to minimize
227    pub fn start<B, F>(self, len: usize, obj_func: F) -> Pbil<B, F>
228    where
229        F: Fn(&[bool]) -> B,
230    {
231        Pbil::new(State::initial(len), self, obj_func)
232    }
233
234    /// Return this optimizer
235    /// running on the given problem
236    /// initialized using `rng`.
237    ///
238    /// # Arguments
239    ///
240    /// - `len`: number of bits in each point
241    /// - `obj_func`: objective function to minimize
242    /// - `rng`: source of randomness
243    pub fn start_using<B, F>(self, len: usize, obj_func: F, rng: &mut SplitMix64) -> Pbil<B, F>
244    where
245        F: Fn(&[bool]) -> B,
246    {
247        Pbil::new(State::initial_using(len, rng), self, obj_func)
248    }
249
250    /// Return this optimizer
251    /// running on the given problem.
252    /// if the given `state` is valid.
253    ///
254    /// # Arguments
255    ///
256    /// - `obj_func`: objective function to minimize
257    /// - `state`: PBIL state to start from
258    pub fn start_from<B, F>(self, obj_func: F, state: State<B>) -> Pbil<B, F>
259    where
260        F: Fn(&[bool]) -> B,
261    {
262        Pbil::new(state, self, obj_func)
263    }
264}
265
266impl<B> State<B> {
267    /// Return recommended initial state.
268    ///
269    /// # Arguments
270    ///
271    /// - `len`: number of bits in each sample
272    fn initial(len: usize) -> Self {
273        Self::new(
274            [Probability::default()].repeat(len),
275            Xoshiro256PlusPlus::from_entropy(),
276        )
277    }
278
279    /// Return recommended initial state.
280    ///
281    /// # Arguments
282    ///
283    /// - `len`: number of bits in each sample
284    /// - `rng`: source of randomness
285    fn initial_using<R>(len: usize, rng: R) -> Self
286    where
287        R: Rng,
288    {
289        Self::new(
290            [Probability::default()].repeat(len),
291            Xoshiro256PlusPlus::from_rng(rng).expect("RNG should initialize"),
292        )
293    }
294
295    /// Return custom initial state.
296    pub fn new(probabilities: Vec<Probability>, rng: Xoshiro256PlusPlus) -> Self {
297        Self(DynState::new(probabilities, rng))
298    }
299
300    /// Return number of bits being optimized.
301    #[allow(clippy::len_without_is_empty)]
302    pub fn len(&self) -> usize {
303        self.probabilities().len()
304    }
305
306    /// Return data to be evaluated.
307    pub fn evaluatee(&self) -> Option<&[bool]> {
308        match &self.0 {
309            DynState::Started(_) => None,
310            DynState::SampledFirst(x) => Some(&x.sample),
311            DynState::EvaluatedFirst(_) => None,
312            DynState::Sampled(x) => Some(&x.sample),
313            DynState::Evaluated(_) => None,
314            DynState::Compared(_) => None,
315            DynState::Adjusted(_) => None,
316            DynState::Mutated(_) => None,
317            DynState::Finished(_) => None,
318        }
319    }
320
321    /// Return result of evaluation.
322    pub fn evaluation(&self) -> Option<&B> {
323        match &self.0 {
324            DynState::Started(_) => None,
325            DynState::SampledFirst(_) => None,
326            DynState::EvaluatedFirst(x) => Some(x.sample.value()),
327            DynState::Sampled(_) => None,
328            DynState::Evaluated(x) => Some(x.sample.value()),
329            DynState::Compared(_) => None,
330            DynState::Adjusted(_) => None,
331            DynState::Mutated(_) => None,
332            DynState::Finished(_) => None,
333        }
334    }
335
336    /// Return sample if stored.
337    pub fn sample(&self) -> Option<&[bool]> {
338        match &self.0 {
339            DynState::Started(_) => None,
340            DynState::SampledFirst(x) => Some(&x.sample),
341            DynState::EvaluatedFirst(x) => Some(x.sample.x()),
342            DynState::Sampled(x) => Some(&x.sample),
343            DynState::Evaluated(x) => Some(x.sample.x()),
344            DynState::Compared(_) => None,
345            DynState::Adjusted(_) => None,
346            DynState::Mutated(_) => None,
347            DynState::Finished(_) => None,
348        }
349    }
350
351    /// Return the best point discovered.
352    pub fn best_point(&self) -> Vec<bool> {
353        self.probabilities()
354            .iter()
355            .map(|p| f64::from(*p) >= 0.5)
356            .collect()
357    }
358
359    /// Return kind of state of inner state-machine.
360    pub fn kind(&self) -> StateKind {
361        match self.0 {
362            DynState::Started(_) => StateKind::Started,
363            DynState::SampledFirst(_) => StateKind::Sampled,
364            DynState::EvaluatedFirst(_) => StateKind::Evaluated,
365            DynState::Sampled(_) => StateKind::Sampled,
366            DynState::Evaluated(_) => StateKind::Evaluated,
367            DynState::Compared(_) => StateKind::Compared,
368            DynState::Adjusted(_) => StateKind::Adjusted,
369            DynState::Mutated(_) => StateKind::Mutated,
370            DynState::Finished(_) => StateKind::Finished,
371        }
372    }
373}
374
375impl<B> Probabilities for State<B> {
376    fn probabilities(&self) -> &[Probability] {
377        match &self.0 {
378            DynState::Started(x) => &x.probabilities,
379            DynState::SampledFirst(x) => x.probabilities.probabilities(),
380            DynState::EvaluatedFirst(x) => x.probabilities.probabilities(),
381            DynState::Sampled(x) => x.probabilities.probabilities(),
382            DynState::Evaluated(x) => x.probabilities.probabilities(),
383            DynState::Compared(x) => x.probabilities.probabilities(),
384            DynState::Adjusted(x) => &x.probabilities,
385            DynState::Mutated(x) => &x.probabilities,
386            DynState::Finished(x) => &x.probabilities,
387        }
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// With a mutation chance of zero,
    /// the optimizer should never be observed in the mutated state.
    #[test]
    fn pbil_should_not_mutate_if_chance_is_zero() {
        let config = Config {
            num_samples: NumSamples::default(),
            adjust_rate: AdjustRate::default(),
            mutation_chance: MutationChance::new(0.0).unwrap(),
            mutation_adjust_rate: MutationAdjustRate::default(),
        };
        let count_ones = |point: &[bool]| point.iter().filter(|bit| **bit).count();

        config
            .start(16, count_ones)
            .inspect(|pbil| assert!(!pbil.state().kind().is_mutated()))
            .nth(100);
    }
}
407}