optimal_pbil/until_probabilities_converged.rs

1use derive_getters::Getters;
2use optimal_core::prelude::*;
3
4use crate::{types::*, Pbil};
5
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
/// A type containing an array of probabilities.
///
/// [`UntilProbabilitiesConverged`] reads these
/// through this trait
/// to decide whether its wrapped iterator has converged.
pub trait Probabilities {
    /// Return probabilities
    /// as a borrowed slice.
    fn probabilities(&self) -> &[Probability];
}
14
15impl<B, F> Probabilities for Pbil<B, F> {
16    fn probabilities(&self) -> &[Probability] {
17        self.state().probabilities()
18    }
19}
20
/// PBIL runner
/// to check for converged probabilities.
///
/// Wraps a streaming iterator
/// and reports itself done
/// once the wrapped iterator is done
/// or every probability it exposes
/// lies strictly above the configured threshold's upper bound
/// or strictly below its lower bound.
///
/// Construct with [`UntilProbabilitiesConvergedConfig::start`].
#[derive(Clone, Debug, Getters)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UntilProbabilitiesConverged<I> {
    /// Configuration,
    /// holding the convergence threshold.
    config: UntilProbabilitiesConvergedConfig,
    /// Wrapped iterator.
    it: I,
}
29
/// Config for PBIL runner
/// to check for converged probabilities.
///
/// Call [`UntilProbabilitiesConvergedConfig::start`]
/// to wrap an iterator.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UntilProbabilitiesConvergedConfig {
    /// Probability convergence parameter.
    ///
    /// A probability counts as converged
    /// when it is strictly above `threshold.upper_bound()`
    /// or strictly below `threshold.lower_bound()`.
    pub threshold: ProbabilityThreshold,
}
38
39impl UntilProbabilitiesConvergedConfig {
40    /// Return this runner
41    /// wrapping the given iterator.
42    pub fn start<I>(self, it: I) -> UntilProbabilitiesConverged<I> {
43        UntilProbabilitiesConverged { config: self, it }
44    }
45}
46
47impl<I> UntilProbabilitiesConverged<I> {
48    /// Return configuration and iterator.
49    pub fn into_inner(self) -> (UntilProbabilitiesConvergedConfig, I) {
50        (self.config, self.it)
51    }
52}
53
54impl<I> StreamingIterator for UntilProbabilitiesConverged<I>
55where
56    I: StreamingIterator + Probabilities,
57{
58    type Item = I::Item;
59
60    fn advance(&mut self) {
61        self.it.advance()
62    }
63
64    fn get(&self) -> Option<&Self::Item> {
65        self.it.get()
66    }
67
68    fn is_done(&self) -> bool {
69        self.it.is_done()
70            || self.it.probabilities().iter().all(|p| {
71                p > &self.config.threshold.upper_bound() || p < &self.config.threshold.lower_bound()
72            })
73    }
74}