// bolero_generator/bounded.rs

1use crate::{Driver, ValueGenerator};
2use core::{
3    marker::PhantomData,
4    ops::{Bound, RangeBounds},
5};
6
/// Crate-internal helpers for [`Bound`], analogous to `Option::as_ref` / `Option::map`.
///
/// NOTE(review): newer std releases ship inherent `Bound::as_ref` / `Bound::map`; with
/// method-call syntax those inherent methods take precedence over this trait's methods.
pub(crate) trait BoundExt<T> {
    /// Borrows the endpoint, keeping the `Included` / `Excluded` / `Unbounded` variant.
    fn as_ref(&self) -> Bound<&T>;

    /// Transforms the endpoint with `f`, keeping the variant; `Unbounded` never calls `f`.
    fn map<U, F>(self, f: F) -> Bound<U>
    where
        F: FnOnce(T) -> U;
}
11
12impl<T> BoundExt<T> for Bound<T> {
13    #[inline(always)]
14    fn as_ref(&self) -> Bound<&T> {
15        match self {
16            Self::Excluded(v) => Bound::Excluded(v),
17            Self::Included(v) => Bound::Included(v),
18            Self::Unbounded => Bound::Unbounded,
19        }
20    }
21
22    #[inline(always)]
23    fn map<U, F: FnOnce(T) -> U>(self, f: F) -> Bound<U> {
24        match self {
25            Self::Excluded(v) => Bound::Excluded(f(v)),
26            Self::Included(v) => Bound::Included(f(v)),
27            Self::Unbounded => Bound::Unbounded,
28        }
29    }
30}
31
32pub trait BoundedValue<B = Self>: 'static + Sized {
33    fn gen_bounded<D: Driver>(driver: &mut D, min: Bound<&B>, max: Bound<&B>) -> Option<Self>;
34
35    fn mutate_bounded<D: Driver>(
36        &mut self,
37        driver: &mut D,
38        min: Bound<&B>,
39        max: Bound<&B>,
40    ) -> Option<()> {
41        *self = Self::gen_bounded(driver, min, max)?;
42        Some(())
43    }
44}
45
/// Implements [`ValueGenerator`] for the `core::ops` range type named by `$ty`.
///
/// The generated impl yields values of the range's element type `T`, delegating to
/// [`BoundedValue`] with the range's start and end bounds.
macro_rules! range_generator {
    ($ty:ident) => {
        impl<T: BoundedValue> ValueGenerator for core::ops::$ty<T> {
            type Output = T;

            #[inline]
            fn generate<D: Driver>(&self, driver: &mut D) -> Option<Self::Output> {
                // Both endpoints go straight to the value type's bounded generator.
                T::gen_bounded(driver, self.start_bound(), self.end_bound())
            }

            #[inline]
            fn mutate<D: Driver>(&self, driver: &mut D, value: &mut Self::Output) -> Option<()> {
                T::mutate_bounded(value, driver, self.start_bound(), self.end_bound())
            }
        }
    };
}
67
68range_generator!(Range);
69range_generator!(RangeFrom);
70range_generator!(RangeInclusive);
71range_generator!(RangeTo);
72range_generator!(RangeToInclusive);
73
74impl<T: BoundedValue> ValueGenerator for (core::ops::Bound<T>, core::ops::Bound<T>) {
75    type Output = T;
76
77    #[inline]
78    fn generate<D: Driver>(&self, driver: &mut D) -> Option<Self::Output> {
79        let min = self.start_bound();
80        let max = self.end_bound();
81        T::gen_bounded(driver, min, max)
82    }
83
84    #[inline]
85    fn mutate<D: Driver>(&self, driver: &mut D, value: &mut Self::Output) -> Option<()> {
86        let min = self.start_bound();
87        let max = self.end_bound();
88        value.mutate_bounded(driver, min, max)
89    }
90}
91
/// A value generator for `T` whose output is constrained by the range-like value `B`
/// (anything the `impl` blocks accept as `RangeBounds<T>`).
#[derive(Debug)]
pub struct BoundedGenerator<T, B> {
    /// The bounds every generated or mutated value is constrained by.
    range_bounds: B,
    /// Ties the produced type `T` to the generator without storing a `T`.
    output: PhantomData<T>,
}
97
98impl<T: BoundedValue, B: RangeBounds<T>> BoundedGenerator<T, B> {
99    pub fn new(range_bounds: B) -> Self {
100        BoundedGenerator {
101            range_bounds,
102            output: PhantomData,
103        }
104    }
105
106    pub fn bounds<NewB: RangeBounds<T>>(self, range_bounds: NewB) -> BoundedGenerator<T, NewB> {
107        BoundedGenerator {
108            range_bounds,
109            output: PhantomData,
110        }
111    }
112}
113
114impl<T: BoundedValue, B: RangeBounds<T>> ValueGenerator for BoundedGenerator<T, B> {
115    type Output = T;
116
117    #[inline]
118    fn generate<D: Driver>(&self, driver: &mut D) -> Option<Self::Output> {
119        let min = self.range_bounds.start_bound();
120        let max = self.range_bounds.end_bound();
121        T::gen_bounded(driver, min, max)
122    }
123
124    #[inline]
125    fn mutate<D: Driver>(&self, driver: &mut D, value: &mut Self::Output) -> Option<()> {
126        let min = self.range_bounds.start_bound();
127        let max = self.range_bounds.end_bound();
128        value.mutate_bounded(driver, min, max)
129    }
130}
131
/// Smoke test: builds a `u8` generator constrained to `0..32` through the
/// `.with().bounds(..)` builder and runs it through the `generator_test!` harness.
///
/// The harness result is discarded (`let _ =`), so this only checks that the
/// expression compiles and the harness completes without panicking.
/// NOTE(review): `generator_test!` and `produce` come from elsewhere in the crate;
/// `produce` is presumably brought into scope by the macro itself — confirm before
/// moving the generator expression outside the macro call.
#[test]
fn with_bounds_test() {
    let _ = generator_test!(produce::<u8>().with().bounds(0..32));
}