Skip to main content

rstat/params/
constraints.rs

1use crate::linalg::{Matrix, Vector};
2use failure::{Backtrace, Fail};
3use num::{zero, PrimInt, Zero};
4use std::{
5    fmt::{self, Debug, Display},
6    ops::Deref,
7};
8
/// Validates `$x` against a constraint written in comparison syntax.
///
/// Evaluates to the `Result` returned by `Constraint::check`, with `$x`
/// passed by value:
///
/// * `assert_constraint!(x+)`     checks `Positive`
/// * `assert_constraint!(x == t)` checks `Equal(t)`
/// * `assert_constraint!(x < t)`  checks `LessThan(t)`
/// * `assert_constraint!(x <= t)` checks `LessThanOrEqual(t)`
/// * `assert_constraint!(x > t)`  checks `GreaterThan(t)`
/// * `assert_constraint!(x >= t)` checks `GreaterThanOrEqual(t)`
///
/// `$x` is a single token tree (identifier, literal or parenthesised
/// expression) in every arm. It cannot be an `expr` fragment, because `expr`
/// may not be followed by a comparison operator. The bound is an `expr`, so
/// multi-token bounds such as `-1.0` are accepted.
macro_rules! assert_constraint {
    ($x:tt +) => {
        $crate::params::constraints::Constraint::check($crate::params::constraints::Positive, $x)
    };
    ($x:tt == $t:expr) => {
        $crate::params::constraints::Constraint::check($crate::params::constraints::Equal($t), $x)
    };
    ($x:tt < $t:expr) => {
        $crate::params::constraints::Constraint::check(
            $crate::params::constraints::LessThan($t),
            $x,
        )
    };
    ($x:tt <= $t:expr) => {
        $crate::params::constraints::Constraint::check(
            $crate::params::constraints::LessThanOrEqual($t),
            $x,
        )
    };
    ($x:tt > $t:expr) => {
        $crate::params::constraints::Constraint::check(
            $crate::params::constraints::GreaterThan($t),
            $x,
        )
    };
    ($x:tt >= $t:expr) => {
        $crate::params::constraints::Constraint::check(
            $crate::params::constraints::GreaterThanOrEqual($t),
            $x,
        )
    };
}
41
42pub struct UnsatisfiedConstraintError<T> {
43    pub value: T,
44    pub target: Option<String>,
45    pub constraint: Box<dyn Constraint<T>>,
46}
47
48impl<T: Debug + Send + Sync + 'static> Fail for UnsatisfiedConstraintError<T> {
49    fn name(&self) -> Option<&str> { Some("UnsatisfiedConstraint") }
50
51    fn cause(&self) -> Option<&dyn Fail> { None }
52
53    fn backtrace(&self) -> Option<&Backtrace> { None }
54}
55
56impl<T> UnsatisfiedConstraintError<T> {
57    pub fn new(value: T, constraint: Box<dyn Constraint<T>>) -> Self {
58        UnsatisfiedConstraintError {
59            value,
60            target: None,
61            constraint,
62        }
63    }
64
65    pub fn with_target<S: ToString>(self, target: S) -> Self {
66        UnsatisfiedConstraintError {
67            value: self.value,
68            target: Some(target.to_string()),
69            constraint: self.constraint,
70        }
71    }
72}
73
74impl<T: Debug> Debug for UnsatisfiedConstraintError<T> {
75    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> {
76        if let Some(ref target) = self.target {
77            f.debug_struct("UnsatisfiedConstraintError")
78                .field("value", &self.value)
79                .field("target", target)
80                .field("constraint", &self.constraint)
81                .finish()
82        } else {
83            f.debug_struct("UnsatisfiedConstraintError")
84                .field("value", &self.value)
85                .field("constraint", &self.constraint)
86                .finish()
87        }
88    }
89}
90
91impl<T: Debug> Display for UnsatisfiedConstraintError<T> {
92    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> {
93        if let Some(ref target) = self.target {
94            write!(
95                f,
96                "Constraint {} on {:?} is unsatisfied for value {:?}.",
97                self.constraint, target, self.value
98            )
99        } else {
100            write!(
101                f,
102                "Constraint {} is unsatisfied for value {:?}.",
103                self.constraint, self.value
104            )
105        }
106    }
107}
108
109pub(crate) type Result<T> = std::result::Result<T, UnsatisfiedConstraintError<T>>;
110
111pub trait Constraint<T>: Display + Debug + Send + Sync {
112    fn is_satisfied_by(&self, value: &T) -> bool;
113
114    fn check(self, value: T) -> Result<T>
115    where
116        T: Debug,
117        Self: Sized + 'static,
118    {
119        if self.is_satisfied_by(&value) {
120            Ok(value)
121        } else {
122            Err(UnsatisfiedConstraintError::new(value, Box::new(self)))
123        }
124    }
125}
126
127pub type Constraints<T> = Vec<Box<dyn Constraint<T>>>;
128
129impl<T, C: Constraint<T> + ?Sized> Constraint<T> for Box<C> {
130    fn is_satisfied_by(&self, value: &T) -> bool { self.deref().is_satisfied_by(value) }
131}
132
/// Implements `Display` for `$type` by writing the type's source text, as
/// produced by `stringify!`, verbatim.
macro_rules! impl_display {
    ($type:ty) => {
        impl Display for $type {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                // `write_str` emits the text as-is. Passing it to `write!` would
                // treat it as a format string and misparse any `{`/`}` in it.
                f.write_str(stringify!($type))
            }
        }
    };
}
140
/// Declares `Display` (via `impl_display!`) plus a `Constraint<T>` impl for
/// the unit constraint type `$name`.
///
/// The caller passes in the receiver identifier (`self`) and the argument name
/// so the caller-written `$body` block can refer to them. The bounded form
/// `Name<Bound + ...>` restricts `T` by the listed traits; the plain form
/// `Name` leaves `T` unbounded. The two forms are distinguished by the token
/// after `$name` (`;` versus `<`), so their order here does not matter.
macro_rules! impl_constraint {
    ($name:ident; $this:ident, $arg:ident, $body:block) => {
        impl_display!($name);

        impl<T> Constraint<T> for $name {
            fn is_satisfied_by(&$this, $arg: &T) -> bool { $body }
        }
    };
    ($name:ident<$first:ident $(+ $rest:ident)*>; $this:ident, $arg:ident, $body:block) => {
        impl_display!($name);

        impl<T: $first $(+ $rest)*> Constraint<T> for $name {
            fn is_satisfied_by(&$this, $arg: &T) -> bool { $body }
        }
    };
}
157
/// Applies the wrapped constraint element-wise. For the scalar types this
/// module implements it for, `All(c)` simply forwards to `c`.
#[derive(Debug, Clone, Copy)]
pub struct All<C>(pub C);

impl<C: Display> Display for All<C> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let All(inner) = self;
        write!(f, "All({})", inner)
    }
}
164
165impl<C: Constraint<f64>> Constraint<f64> for All<C> {
166    fn check(self, value: f64) -> Result<f64>
167    where Self: Sized + 'static {
168        self.0.check(value)
169    }
170
171    fn is_satisfied_by(&self, value: &f64) -> bool { self.0.is_satisfied_by(value) }
172}
173
174impl<C: Constraint<i64>> Constraint<i64> for All<C> {
175    fn check(self, value: i64) -> Result<i64>
176    where Self: Sized + 'static {
177        self.0.check(value)
178    }
179
180    fn is_satisfied_by(&self, value: &i64) -> bool { self.0.is_satisfied_by(value) }
181}
182
183impl<C: Constraint<usize>> Constraint<usize> for All<C> {
184    fn check(self, value: usize) -> Result<usize>
185    where Self: Sized + 'static {
186        self.0.check(value)
187    }
188
189    fn is_satisfied_by(&self, value: &usize) -> bool { self.0.is_satisfied_by(value) }
190}
191
192impl<T, C: Constraint<T>> Constraint<Vec<T>> for All<C> {
193    fn check(self, value: Vec<T>) -> Result<Vec<T>>
194    where
195        Vec<T>: Debug,
196        Self: Sized + 'static,
197    {
198        for v in value.iter() {
199            if !self.0.is_satisfied_by(v) {
200                return Err(UnsatisfiedConstraintError::new(value, Box::new(self)));
201            }
202        }
203
204        Ok(value)
205    }
206
207    fn is_satisfied_by(&self, value: &Vec<T>) -> bool {
208        value.iter().all(|v| self.0.is_satisfied_by(v))
209    }
210}
211
212impl<T, C: Constraint<T>> Constraint<[T; 2]> for All<C> {
213    fn check(self, value: [T; 2]) -> Result<[T; 2]>
214    where
215        [T; 2]: Debug,
216        Self: Sized + 'static,
217    {
218        for v in value.iter() {
219            if !self.0.is_satisfied_by(v) {
220                return Err(UnsatisfiedConstraintError::new(value, Box::new(self)));
221            }
222        }
223
224        Ok(value)
225    }
226
227    fn is_satisfied_by(&self, value: &[T; 2]) -> bool {
228        value.iter().all(|v| self.0.is_satisfied_by(v))
229    }
230}
231
232impl<T, C: Constraint<T>> Constraint<Vector<T>> for All<C> {
233    fn check(self, value: Vector<T>) -> Result<Vector<T>>
234    where
235        Vector<T>: Debug,
236        Self: Sized + 'static,
237    {
238        for v in value.iter() {
239            if !self.0.is_satisfied_by(v) {
240                return Err(UnsatisfiedConstraintError::new(value, Box::new(self)));
241            }
242        }
243
244        Ok(value)
245    }
246
247    fn is_satisfied_by(&self, value: &Vector<T>) -> bool {
248        value.iter().all(|v| self.0.is_satisfied_by(v))
249    }
250}
251
252impl<T, C: Constraint<T>> Constraint<Matrix<T>> for All<C> {
253    fn check(self, value: Matrix<T>) -> Result<Matrix<T>>
254    where
255        Matrix<T>: Debug,
256        Self: Sized + 'static,
257    {
258        for v in value.iter() {
259            if !self.0.is_satisfied_by(&v) {
260                return Err(UnsatisfiedConstraintError::new(value, Box::new(self)));
261            }
262        }
263
264        Ok(value)
265    }
266
267    fn is_satisfied_by(&self, value: &Matrix<T>) -> bool {
268        value.iter().all(|v| self.0.is_satisfied_by(v))
269    }
270}
271
272#[derive(Debug, Clone, Copy)]
273pub struct Not<C>(pub C);
274
275impl<C: Display> Display for Not<C> {
276    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Not({})", self.0) }
277}
278
279impl<T, C: Constraint<T>> Constraint<T> for Not<C> {
280    fn is_satisfied_by(&self, value: &T) -> bool { !self.0.is_satisfied_by(value) }
281}
282
283#[derive(Debug, Clone, Copy)]
284pub struct Or<C1, C2>(pub (C1, C2));
285
286impl<C1: Display, C2: Display> Display for Or<C1, C2> {
287    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
288        write!(f, "Or({} || {})", (self.0).0, (self.0).1)
289    }
290}
291
292impl<T, C1: Constraint<T>, C2: Constraint<T>> Constraint<T> for Or<C1, C2> {
293    fn check(self, value: T) -> Result<T>
294    where
295        T: Debug,
296        Self: Sized + 'static,
297    {
298        let (c1, c2) = self.0;
299
300        c1.check(value).or_else(|uc| c2.check(uc.value))
301    }
302
303    fn is_satisfied_by(&self, value: &T) -> bool {
304        (self.0).0.is_satisfied_by(value) || (self.0).1.is_satisfied_by(value)
305    }
306}
307
308#[derive(Debug, Clone, Copy)]
309pub struct And<C1, C2>(pub (C1, C2));
310
311impl<C1: Display, C2: Display> Display for And<C1, C2> {
312    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
313        write!(f, "And({} || {})", (self.0).0, (self.0).1)
314    }
315}
316
317impl<T, C1: Constraint<T>, C2: Constraint<T>> Constraint<T> for And<C1, C2> {
318    fn check(self, value: T) -> Result<T>
319    where
320        T: Debug,
321        Self: Sized + 'static,
322    {
323        let (c1, c2) = self.0;
324
325        c1.check(value).and_then(|value| c2.check(value))
326    }
327
328    fn is_satisfied_by(&self, value: &T) -> bool {
329        (self.0).0.is_satisfied_by(value) && (self.0).1.is_satisfied_by(value)
330    }
331}
332
333#[derive(Debug, Clone, Copy)]
334pub struct Equal<T>(pub T);
335
336impl<T: PartialEq + Debug> Display for Equal<T> {
337    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Equal({:?})", self.0) }
338}
339
340impl<T: PartialEq + Send + Sync + Debug> Constraint<T> for Equal<T> {
341    fn is_satisfied_by(&self, value: &T) -> bool { value == &self.0 }
342}
343
344#[derive(Debug, Clone, Copy)]
345pub struct LessThan<T>(pub T);
346
347impl<T: PartialOrd + Debug> Display for LessThan<T> {
348    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "LessThan({:?})", self.0) }
349}
350
351impl<T: PartialOrd + Send + Sync + Debug> Constraint<T> for LessThan<T> {
352    fn is_satisfied_by(&self, value: &T) -> bool { value < &self.0 }
353}
354
355#[derive(Debug, Clone, Copy)]
356pub struct Negative;
357
358impl_constraint!(Negative<PartialOrd + Zero>; self, value, { value < &zero() });
359
360#[derive(Debug, Clone, Copy)]
361pub struct LessThanOrEqual<T>(pub T);
362
363impl<T: PartialOrd + Debug> Display for LessThanOrEqual<T> {
364    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
365        write!(f, "LessThanOrEqual({:?})", self.0)
366    }
367}
368
369impl<T: PartialOrd + Send + Sync + Debug> Constraint<T> for LessThanOrEqual<T> {
370    fn is_satisfied_by(&self, value: &T) -> bool { value <= &self.0 }
371}
372
373#[derive(Debug, Clone, Copy)]
374pub struct NonPositive;
375
376impl_constraint!(NonPositive<PartialOrd + Zero>; self, value, { value <= &zero() });
377
378#[derive(Debug, Clone, Copy)]
379pub struct GreaterThan<T>(pub T);
380
381impl<T: PartialOrd + Debug> Display for GreaterThan<T> {
382    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "GreaterThan({:?})", self.0) }
383}
384
385impl<T: PartialOrd + Send + Sync + Debug> Constraint<T> for GreaterThan<T> {
386    fn is_satisfied_by(&self, value: &T) -> bool { value > &self.0 }
387}
388
389#[derive(Debug, Clone, Copy)]
390pub struct Positive;
391
392impl_constraint!(Positive<PartialOrd + Zero>; self, value, { value > &zero() });
393
394#[derive(Debug, Clone, Copy)]
395pub struct GreaterThanOrEqual<T>(pub T);
396
397impl<T: PartialOrd + Debug> Display for GreaterThanOrEqual<T> {
398    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
399        write!(f, "GreaterThanOrEqual({:?})", self.0)
400    }
401}
402
403impl<T: PartialOrd + Send + Sync + Debug> Constraint<T> for GreaterThanOrEqual<T> {
404    fn is_satisfied_by(&self, value: &T) -> bool { value >= &self.0 }
405}
406
407#[derive(Debug, Clone, Copy)]
408pub struct Interval<T> {
409    pub lb: T,
410    pub ub: T,
411}
412
413impl<T: PartialOrd + Debug> Display for Interval<T> {
414    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
415        write!(f, "[{:?}, {:?}]", self.lb, self.ub)
416    }
417}
418
419impl<T: PartialOrd + Send + Sync + Debug> Constraint<T> for Interval<T> {
420    fn is_satisfied_by(&self, value: &T) -> bool { value >= &self.lb && value <= &self.ub }
421}
422
423#[derive(Debug, Clone, Copy)]
424pub struct NonNegative;
425
426impl_constraint!(NonNegative<PartialOrd + Zero>; self, value, { value >= &zero() });
427
428#[derive(Debug, Clone, Copy)]
429pub struct Natural;
430
431impl_constraint!(Natural<PrimInt>; self, value, { value > &zero() });
432
433#[derive(Debug, Clone, Copy)]
434pub struct Empty;
435
436impl Display for Empty {
437    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Empty") }
438}
439
440impl<T> Constraint<Vec<T>> for Empty {
441    fn is_satisfied_by(&self, vec: &Vec<T>) -> bool { vec.len() == 0 }
442}
443
444impl<T> Constraint<Vector<T>> for Empty {
445    fn is_satisfied_by(&self, vector: &Vector<T>) -> bool { vector.len() == 0 }
446}
447
448#[derive(Debug, Clone, Copy)]
449pub struct Square;
450
451impl Display for Square {
452    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Square") }
453}
454
455impl<T> Constraint<Matrix<T>> for Square {
456    fn is_satisfied_by(&self, matrix: &Matrix<T>) -> bool { matrix.is_square() }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_non_negative() {
        let c = NonNegative;

        assert_eq!(c.to_string(), "NonNegative");

        assert!(!c.is_satisfied_by(&-1.50));
        assert!(!c.is_satisfied_by(&-0.50));

        // Zero is the boundary that distinguishes NonNegative from Positive.
        assert!(c.is_satisfied_by(&0.0));
        assert!(c.is_satisfied_by(&0.50));
        assert!(c.is_satisfied_by(&1.50));

        // NaN compares false against zero, so it is rejected.
        assert!(!c.is_satisfied_by(&f64::NAN));

        // The constraint is generic over `PartialOrd + Zero`, not just floats.
        assert!(!c.is_satisfied_by(&-1i64));
        assert!(c.is_satisfied_by(&0i64));
        assert!(c.is_satisfied_by(&3usize));
    }
}