optimal_pbil/
types.rs

1use core::convert::TryFrom;
2use std::f64::EPSILON;
3
4use derive_more::{Display, Into};
5use derive_num_bounded::{
6    derive_from_str_from_try_into, derive_into_inner, derive_new_from_bounded_float,
7    derive_new_from_lower_bounded, derive_try_from_from_new,
8};
9use num_traits::bounds::{LowerBounded, UpperBounded};
10use rand::distributions::Bernoulli;
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
15/// Number of samples generated
16/// during steps.
17#[derive(Clone, Copy, Debug, Display, PartialEq, Eq, PartialOrd, Ord, Into)]
18#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
19#[cfg_attr(feature = "serde", serde(into = "usize"))]
20#[cfg_attr(feature = "serde", serde(try_from = "usize"))]
21pub struct NumSamples(usize);
22
23impl Default for NumSamples {
24    fn default() -> Self {
25        Self(20)
26    }
27}
28
29impl LowerBounded for NumSamples {
30    fn min_value() -> Self {
31        Self(2)
32    }
33}
34
35derive_new_from_lower_bounded!(NumSamples(usize));
36derive_into_inner!(NumSamples(usize));
37derive_try_from_from_new!(NumSamples(usize));
38derive_from_str_from_try_into!(NumSamples(usize));
39
40/// Degree to adjust probabilities towards best point
41/// during steps.
42#[derive(Clone, Copy, Debug, Display, PartialEq, PartialOrd, Into)]
43#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
44#[cfg_attr(feature = "serde", serde(into = "f64"))]
45#[cfg_attr(feature = "serde", serde(try_from = "f64"))]
46pub struct AdjustRate(f64);
47
48impl Default for AdjustRate {
49    fn default() -> Self {
50        Self(0.1)
51    }
52}
53
54impl LowerBounded for AdjustRate {
55    fn min_value() -> Self {
56        Self(EPSILON)
57    }
58}
59
60impl UpperBounded for AdjustRate {
61    fn max_value() -> Self {
62        Self(1.)
63    }
64}
65
66derive_new_from_bounded_float!(AdjustRate(f64));
67derive_into_inner!(AdjustRate(f64));
68derive_try_from_from_new!(AdjustRate(f64));
69derive_from_str_from_try_into!(AdjustRate(f64));
70
71impl Eq for AdjustRate {}
72
73#[allow(clippy::derive_ord_xor_partial_ord)]
74impl Ord for AdjustRate {
75    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
76        // `f64` has total ordering for the the range of values allowed by this type.
77        unsafe { self.partial_cmp(other).unwrap_unchecked() }
78    }
79}
80
81/// Probability for each probability to mutate,
82/// independently.
83#[derive(Clone, Copy, Debug, Display, PartialEq, PartialOrd, Into)]
84#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
85#[cfg_attr(feature = "serde", serde(into = "f64"))]
86#[cfg_attr(feature = "serde", serde(try_from = "f64"))]
87pub struct MutationChance(f64);
88
89impl MutationChance {
90    /// Return recommended default mutation chance,
91    /// average of one mutation per step.
92    pub fn default_for(num_bits: usize) -> Self {
93        if num_bits == 0 {
94            Self(1.)
95        } else {
96            Self(1. / num_bits as f64)
97        }
98    }
99}
100
101impl LowerBounded for MutationChance {
102    fn min_value() -> Self {
103        Self(0.)
104    }
105}
106
107impl UpperBounded for MutationChance {
108    fn max_value() -> Self {
109        Self(1.)
110    }
111}
112
113impl From<MutationChance> for Bernoulli {
114    fn from(x: MutationChance) -> Self {
115        Bernoulli::new(x.into()).unwrap()
116    }
117}
118
119derive_new_from_bounded_float!(MutationChance(f64));
120derive_into_inner!(MutationChance(f64));
121derive_try_from_from_new!(MutationChance(f64));
122derive_from_str_from_try_into!(MutationChance(f64));
123
124impl Eq for MutationChance {}
125
126#[allow(clippy::derive_ord_xor_partial_ord)]
127impl Ord for MutationChance {
128    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
129        // `f64` has total ordering for the the range of values allowed by this type.
130        unsafe { self.partial_cmp(other).unwrap_unchecked() }
131    }
132}
133
134impl MutationChance {
135    pub fn is_zero(&self) -> bool {
136        self.0 == 0.0
137    }
138}
139
140/// Degree to adjust probability towards random value
141/// when mutating.
142#[derive(Clone, Copy, Debug, Display, PartialEq, PartialOrd, Into)]
143#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
144#[cfg_attr(feature = "serde", serde(into = "f64"))]
145#[cfg_attr(feature = "serde", serde(try_from = "f64"))]
146pub struct MutationAdjustRate(f64);
147
148impl Default for MutationAdjustRate {
149    fn default() -> Self {
150        Self(0.05)
151    }
152}
153
154impl LowerBounded for MutationAdjustRate {
155    fn min_value() -> Self {
156        Self(EPSILON)
157    }
158}
159
160impl UpperBounded for MutationAdjustRate {
161    fn max_value() -> Self {
162        Self(1.)
163    }
164}
165
166derive_new_from_bounded_float!(MutationAdjustRate(f64));
167derive_into_inner!(MutationAdjustRate(f64));
168derive_try_from_from_new!(MutationAdjustRate(f64));
169derive_from_str_from_try_into!(MutationAdjustRate(f64));
170
171impl Eq for MutationAdjustRate {}
172
173#[allow(clippy::derive_ord_xor_partial_ord)]
174impl Ord for MutationAdjustRate {
175    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
176        // `f64` has total ordering for the the range of values allowed by this type.
177        unsafe { self.partial_cmp(other).unwrap_unchecked() }
178    }
179}
180
181/// Probability for a sampled bit to be true.
182#[derive(Clone, Copy, Debug, Display, PartialEq, PartialOrd, Into)]
183#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
184#[cfg_attr(feature = "serde", serde(into = "f64"))]
185#[cfg_attr(feature = "serde", serde(try_from = "f64"))]
186pub struct Probability(f64);
187
188impl Probability {
189    /// # Safety
190    ///
191    /// This function is safe
192    /// if the given value
193    /// is within range `[0,1]`.
194    pub const unsafe fn new_unchecked(x: f64) -> Self {
195        Self(x)
196    }
197}
198
199impl Default for Probability {
200    fn default() -> Self {
201        Self(0.5)
202    }
203}
204
205impl LowerBounded for Probability {
206    fn min_value() -> Self {
207        Self(0.)
208    }
209}
210
211impl UpperBounded for Probability {
212    fn max_value() -> Self {
213        Self(1.)
214    }
215}
216
217derive_new_from_bounded_float!(Probability(f64));
218derive_into_inner!(Probability(f64));
219derive_try_from_from_new!(Probability(f64));
220derive_from_str_from_try_into!(Probability(f64));
221
222impl Eq for Probability {}
223
224#[allow(clippy::derive_ord_xor_partial_ord)]
225impl Ord for Probability {
226    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
227        // `f64` has total ordering for the the range of values allowed by this type.
228        unsafe { self.partial_cmp(other).unwrap_unchecked() }
229    }
230}
231
232/// PBIL can be considered done
233/// when all probabilities are above this threshold
234/// or below the inverse.
235#[derive(Clone, Copy, Debug)]
236#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
237#[cfg_attr(feature = "serde", serde(into = "Probability"))]
238#[cfg_attr(feature = "serde", serde(try_from = "Probability"))]
239pub struct ProbabilityThreshold {
240    ub: Probability,
241    lb: Probability,
242}
243
244/// Error returned when 'ConvergedThreshold' is given an invalid value.
245#[derive(Clone, Copy, Debug, Display, PartialEq, Eq)]
246pub enum InvalidProbabilityThresholdError {
247    /// Value is below the lower bound.
248    TooLow,
249    /// Value is above the upper bound.
250    TooHigh,
251}
252
253impl ProbabilityThreshold {
254    /// Return a new 'ConvergedThreshold' if given a valid value.
255    pub fn new(value: Probability) -> Result<Self, InvalidProbabilityThresholdError> {
256        if value < Self::min_value().into() {
257            Err(InvalidProbabilityThresholdError::TooLow)
258        } else if value > Self::max_value().into() {
259            Err(InvalidProbabilityThresholdError::TooHigh)
260        } else {
261            Ok(Self {
262                ub: value,
263                lb: Probability(1. - f64::from(value)),
264            })
265        }
266    }
267
268    /// Unwrap 'ConvergedThreshold' into inner value.
269    pub fn into_inner(self) -> Probability {
270        self.ub
271    }
272
273    /// Return the threshold upper bound.
274    pub fn upper_bound(&self) -> Probability {
275        self.ub
276    }
277
278    /// Return the threshold lower bound.
279    pub fn lower_bound(&self) -> Probability {
280        self.lb
281    }
282}
283
284impl LowerBounded for ProbabilityThreshold {
285    fn min_value() -> Self {
286        Self {
287            ub: Probability(0.5 + EPSILON),
288            lb: Probability(0.5 - EPSILON),
289        }
290    }
291}
292
293impl UpperBounded for ProbabilityThreshold {
294    fn max_value() -> Self {
295        Self {
296            ub: Probability(1. - EPSILON),
297            lb: Probability(EPSILON),
298        }
299    }
300}
301
302impl Default for ProbabilityThreshold {
303    fn default() -> Self {
304        Self {
305            ub: Probability(0.75),
306            lb: Probability(0.25),
307        }
308    }
309}
310
311// `lb` is fully dependent on `ub`.
312impl PartialEq for ProbabilityThreshold {
313    fn eq(&self, other: &Self) -> bool {
314        self.ub.eq(&other.ub)
315    }
316}
317impl Eq for ProbabilityThreshold {}
318impl PartialOrd for ProbabilityThreshold {
319    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
320        self.ub.partial_cmp(&other.ub)
321    }
322}
323impl Ord for ProbabilityThreshold {
324    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
325        // `f64` has total ordering for the the range of values allowed by this type.
326        unsafe { self.partial_cmp(other).unwrap_unchecked() }
327    }
328}
329
330impl From<ProbabilityThreshold> for Probability {
331    fn from(x: ProbabilityThreshold) -> Self {
332        x.ub
333    }
334}
335
336impl TryFrom<Probability> for ProbabilityThreshold {
337    type Error = InvalidProbabilityThresholdError;
338    fn try_from(value: Probability) -> Result<Self, Self::Error> {
339        Self::new(value)
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn num_samples_from_str_returns_correct_result() {
        assert_eq!("10".parse::<NumSamples>().unwrap(), NumSamples(10));
    }

    #[test]
    fn adjust_rate_from_str_returns_correct_result() {
        assert_eq!("0.2".parse::<AdjustRate>().unwrap(), AdjustRate(0.2));
    }

    #[test]
    fn mutation_chance_from_str_returns_correct_result() {
        assert_eq!(
            "0.2".parse::<MutationChance>().unwrap(),
            MutationChance(0.2)
        );
    }

    #[test]
    fn mutation_adjust_rate_from_str_returns_correct_result() {
        assert_eq!(
            "0.2".parse::<MutationAdjustRate>().unwrap(),
            MutationAdjustRate(0.2)
        );
    }

    #[test]
    fn probability_from_str_returns_correct_result() {
        assert_eq!("0.2".parse::<Probability>().unwrap(), Probability(0.2));
    }

    #[test]
    fn probability_ord_matches_float_order() {
        assert_eq!(
            Probability(0.2).cmp(&Probability(0.3)),
            std::cmp::Ordering::Less
        );
        assert_eq!(Probability(0.3).max(Probability(0.2)), Probability(0.3));
    }

    #[test]
    fn mutation_chance_default_for_averages_one_mutation_per_step() {
        assert_eq!(MutationChance::default_for(4), MutationChance(0.25));
        assert_eq!(MutationChance::default_for(0), MutationChance(1.));
    }

    #[test]
    fn probability_threshold_new_sets_both_bounds() {
        let threshold = ProbabilityThreshold::new(Probability(0.75)).unwrap();
        assert_eq!(threshold.upper_bound(), Probability(0.75));
        assert_eq!(threshold.lower_bound(), Probability(0.25));
        assert_eq!(threshold.into_inner(), Probability(0.75));
    }

    #[test]
    fn probability_threshold_new_rejects_out_of_range_values() {
        assert_eq!(
            ProbabilityThreshold::new(Probability(0.5)),
            Err(InvalidProbabilityThresholdError::TooLow)
        );
        assert_eq!(
            ProbabilityThreshold::new(Probability(1.)),
            Err(InvalidProbabilityThresholdError::TooHigh)
        );
    }

    #[test]
    fn probability_threshold_default_matches_new() {
        let default = ProbabilityThreshold::default();
        assert_eq!(
            default,
            ProbabilityThreshold::new(Probability(0.75)).unwrap()
        );
        // `PartialEq` only compares the upper bound, so check the lower one too.
        assert_eq!(default.lower_bound(), Probability(0.25));
    }
}