1#![allow(clippy::needless_doctest_main)]
2#![warn(missing_debug_implementations)]
3mod state_machine;
23mod types;
24mod until_probabilities_converged;
25
26use derive_getters::{Dissolve, Getters};
27use derive_more::IsVariant;
28pub use optimal_core::prelude::*;
29use rand::prelude::*;
30use rand_xoshiro::{SplitMix64, Xoshiro256PlusPlus};
31
32use self::state_machine::DynState;
33pub use self::{types::*, until_probabilities_converged::*};
34
35#[cfg(feature = "serde")]
36use serde::{Deserialize, Serialize};
37
/// Error signalling that the length of a problem does not match the length of a state.
///
/// NOTE(review): this type is not constructed anywhere in this file; presumably code
/// elsewhere in the crate returns it when pairing a [`State`] with a problem — confirm.
#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
#[error("problem length does not match state length")]
pub struct MismatchedLengthError;
43
/// A running PBIL-style optimizer over boolean points (`Vec<bool>`).
///
/// Bundles the hyperparameters ([`Config`]), the current position in the optimization state
/// machine ([`State`]), and the objective function. Drive it with [`StreamingIterator`]:
/// every `advance` performs exactly one state-machine step (see [`StateKind`]).
///
/// The derived `Dissolve` is renamed to `into_parts`, which returns the fields in
/// declaration order.
#[derive(Clone, Debug, Getters, Dissolve)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[dissolve(rename = "into_parts")]
pub struct Pbil<B, F> {
    /// Hyperparameters controlling sampling, adjustment and mutation.
    config: Config,

    /// Current step of the optimization state machine, including the probability vector.
    state: State<B>,

    /// Objective function applied to sampled points; `B` is the type of its result.
    obj_func: F,
}
58
/// Hyperparameters for a [`Pbil`] run.
///
/// The derived `PartialOrd`/`Ord` compare fields lexicographically in declaration order,
/// so do not reorder the fields below.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Config {
    /// Number of samples per iteration; sampling repeats while the number of samples
    /// generated so far is below this value.
    pub num_samples: NumSamples,
    /// Rate handed to the adjustment step that updates the probabilities after sampling.
    pub adjust_rate: AdjustRate,
    /// Chance handed to the mutation step; when it is not greater than `0.0` the mutation
    /// step is skipped entirely.
    pub mutation_chance: MutationChance,
    /// Adjustment rate handed to the mutation step.
    pub mutation_adjust_rate: MutationAdjustRate,
}
76
/// Snapshot of a [`Pbil`] run: the probability vector, the RNG, and the current step of the
/// state machine. `B` is the type of objective values stored while a sample is evaluated.
///
/// With the `serde` feature, this serializes exactly like the inner state (`transparent`).
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct State<B>(DynState<B>);
82
83#[derive(Clone, Debug, PartialEq, IsVariant)]
85#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
86pub enum StateKind {
87 Started,
89 Sampled,
91 Evaluated,
93 Compared,
95 Adjusted,
97 Mutated,
99 Finished,
101}
102
103impl<B, F> Pbil<B, F> {
104 fn new(state: State<B>, config: Config, obj_func: F) -> Self {
105 Self {
106 config,
107 obj_func,
108 state,
109 }
110 }
111}
112
113impl<B, F> Pbil<B, F>
114where
115 F: Fn(&[bool]) -> B,
116{
117 pub fn best_point_value(&self) -> B {
119 (self.obj_func)(&self.best_point())
120 }
121}
122
123impl<B, F> StreamingIterator for Pbil<B, F>
124where
125 B: PartialOrd,
126 F: Fn(&[bool]) -> B,
127{
128 type Item = Self;
129
130 fn advance(&mut self) {
131 replace_with::replace_with_or_abort(&mut self.state.0, |state| match state {
132 DynState::Started(x) => {
133 DynState::SampledFirst(x.into_initialized_sampling().into_sampled_first())
134 }
135 DynState::SampledFirst(x) => {
136 DynState::EvaluatedFirst(x.into_evaluated_first(&self.obj_func))
137 }
138 DynState::EvaluatedFirst(x) => DynState::Sampled(x.into_sampled()),
139 DynState::Sampled(x) => DynState::Evaluated(x.into_evaluated(&self.obj_func)),
140 DynState::Evaluated(x) => DynState::Compared(x.into_compared()),
141 DynState::Compared(x) => {
142 if x.samples_generated < self.config.num_samples.into_inner() {
143 DynState::Sampled(x.into_sampled())
144 } else {
145 DynState::Adjusted(x.into_adjusted(self.config.adjust_rate))
146 }
147 }
148 DynState::Adjusted(x) => {
149 if self.config.mutation_chance.into_inner() > 0.0 {
150 DynState::Mutated(x.into_mutated(
151 self.config.mutation_chance,
152 self.config.mutation_adjust_rate,
153 ))
154 } else {
155 DynState::Finished(x.into_finished())
156 }
157 }
158 DynState::Mutated(x) => DynState::Finished(x.into_finished()),
159 DynState::Finished(x) => DynState::Started(x.into_started()),
160 });
161 }
162
163 fn get(&self) -> Option<&Self::Item> {
164 Some(self)
165 }
166}
167
168impl<B, F> Optimizer for Pbil<B, F> {
169 type Point = Vec<bool>;
170
171 fn best_point(&self) -> Self::Point {
172 self.state.best_point()
173 }
174}
175
176impl Config {
177 pub fn new(
179 num_samples: NumSamples,
180 adjust_rate: AdjustRate,
181 mutation_chance: MutationChance,
182 mutation_adjust_rate: MutationAdjustRate,
183 ) -> Self {
184 Self {
185 num_samples,
186 adjust_rate,
187 mutation_chance,
188 mutation_adjust_rate,
189 }
190 }
191}
192
193impl Config {
194 pub fn default_for(num_bits: usize) -> Self {
196 Self {
197 num_samples: NumSamples::default(),
198 adjust_rate: AdjustRate::default(),
199 mutation_chance: MutationChance::default_for(num_bits),
200 mutation_adjust_rate: MutationAdjustRate::default(),
201 }
202 }
203
204 pub fn start_default_for<B, F>(len: usize, obj_func: F) -> Pbil<B, F>
212 where
213 F: Fn(&[bool]) -> B,
214 {
215 Self::default_for(len).start(len, obj_func)
216 }
217
218 pub fn start<B, F>(self, len: usize, obj_func: F) -> Pbil<B, F>
228 where
229 F: Fn(&[bool]) -> B,
230 {
231 Pbil::new(State::initial(len), self, obj_func)
232 }
233
234 pub fn start_using<B, F>(self, len: usize, obj_func: F, rng: &mut SplitMix64) -> Pbil<B, F>
244 where
245 F: Fn(&[bool]) -> B,
246 {
247 Pbil::new(State::initial_using(len, rng), self, obj_func)
248 }
249
250 pub fn start_from<B, F>(self, obj_func: F, state: State<B>) -> Pbil<B, F>
259 where
260 F: Fn(&[bool]) -> B,
261 {
262 Pbil::new(state, self, obj_func)
263 }
264}
265
266impl<B> State<B> {
267 fn initial(len: usize) -> Self {
273 Self::new(
274 [Probability::default()].repeat(len),
275 Xoshiro256PlusPlus::from_entropy(),
276 )
277 }
278
279 fn initial_using<R>(len: usize, rng: R) -> Self
286 where
287 R: Rng,
288 {
289 Self::new(
290 [Probability::default()].repeat(len),
291 Xoshiro256PlusPlus::from_rng(rng).expect("RNG should initialize"),
292 )
293 }
294
295 pub fn new(probabilities: Vec<Probability>, rng: Xoshiro256PlusPlus) -> Self {
297 Self(DynState::new(probabilities, rng))
298 }
299
300 #[allow(clippy::len_without_is_empty)]
302 pub fn len(&self) -> usize {
303 self.probabilities().len()
304 }
305
306 pub fn evaluatee(&self) -> Option<&[bool]> {
308 match &self.0 {
309 DynState::Started(_) => None,
310 DynState::SampledFirst(x) => Some(&x.sample),
311 DynState::EvaluatedFirst(_) => None,
312 DynState::Sampled(x) => Some(&x.sample),
313 DynState::Evaluated(_) => None,
314 DynState::Compared(_) => None,
315 DynState::Adjusted(_) => None,
316 DynState::Mutated(_) => None,
317 DynState::Finished(_) => None,
318 }
319 }
320
321 pub fn evaluation(&self) -> Option<&B> {
323 match &self.0 {
324 DynState::Started(_) => None,
325 DynState::SampledFirst(_) => None,
326 DynState::EvaluatedFirst(x) => Some(x.sample.value()),
327 DynState::Sampled(_) => None,
328 DynState::Evaluated(x) => Some(x.sample.value()),
329 DynState::Compared(_) => None,
330 DynState::Adjusted(_) => None,
331 DynState::Mutated(_) => None,
332 DynState::Finished(_) => None,
333 }
334 }
335
336 pub fn sample(&self) -> Option<&[bool]> {
338 match &self.0 {
339 DynState::Started(_) => None,
340 DynState::SampledFirst(x) => Some(&x.sample),
341 DynState::EvaluatedFirst(x) => Some(x.sample.x()),
342 DynState::Sampled(x) => Some(&x.sample),
343 DynState::Evaluated(x) => Some(x.sample.x()),
344 DynState::Compared(_) => None,
345 DynState::Adjusted(_) => None,
346 DynState::Mutated(_) => None,
347 DynState::Finished(_) => None,
348 }
349 }
350
351 pub fn best_point(&self) -> Vec<bool> {
353 self.probabilities()
354 .iter()
355 .map(|p| f64::from(*p) >= 0.5)
356 .collect()
357 }
358
359 pub fn kind(&self) -> StateKind {
361 match self.0 {
362 DynState::Started(_) => StateKind::Started,
363 DynState::SampledFirst(_) => StateKind::Sampled,
364 DynState::EvaluatedFirst(_) => StateKind::Evaluated,
365 DynState::Sampled(_) => StateKind::Sampled,
366 DynState::Evaluated(_) => StateKind::Evaluated,
367 DynState::Compared(_) => StateKind::Compared,
368 DynState::Adjusted(_) => StateKind::Adjusted,
369 DynState::Mutated(_) => StateKind::Mutated,
370 DynState::Finished(_) => StateKind::Finished,
371 }
372 }
373}
374
375impl<B> Probabilities for State<B> {
376 fn probabilities(&self) -> &[Probability] {
377 match &self.0 {
378 DynState::Started(x) => &x.probabilities,
379 DynState::SampledFirst(x) => x.probabilities.probabilities(),
380 DynState::EvaluatedFirst(x) => x.probabilities.probabilities(),
381 DynState::Sampled(x) => x.probabilities.probabilities(),
382 DynState::Evaluated(x) => x.probabilities.probabilities(),
383 DynState::Compared(x) => x.probabilities.probabilities(),
384 DynState::Adjusted(x) => &x.probabilities,
385 DynState::Mutated(x) => &x.probabilities,
386 DynState::Finished(x) => &x.probabilities,
387 }
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// With a mutation chance of exactly zero, the state machine must never enter `Mutated`.
    #[test]
    fn pbil_should_not_mutate_if_chance_is_zero() {
        let config = Config::new(
            NumSamples::default(),
            AdjustRate::default(),
            MutationChance::new(0.0).unwrap(),
            MutationAdjustRate::default(),
        );
        // Objective: number of `true` bits in the point.
        let count_ones = |point: &[bool]| point.iter().filter(|&&bit| bit).count();

        config
            .start(16, count_ones)
            .inspect(|pbil| assert!(!pbil.state().kind().is_mutated()))
            .nth(100);
    }
}
407}