refinement/
lib.rs

1//! Convenient creation of type-safe refinement types.
2//!
3//! This crate tries to capture the idea of a refinement type, which
4//! is a type endowed with a predicate which is assumed to hold for
5//! any element of the refined type.[^1]
6//!
7//! Refinement types are useful when only certain values of a type are expected at runtime.
8//! As an example, suppose there's a function that only logically works on even integers.
9//!
10//! ```should_panic
11//! fn takes_even(i: i32) {
12//!     if i % 2 == 0 {
13//!         println!("Received even number {}", i);
14//!     } else {
15//!         panic!("Received odd number");
16//!     }
17//! }
18//!
19//! takes_even(1);  // oops
20//! ```
21//!
22//! Using a refinement type, this function may be defined in a way where it is impossible to supply
23//! an odd number.
24//!
25//! ```
26//! use refinement::{Refinement, Predicate};
27//!
28//! struct Even;
29//!
30//! impl Predicate<i32> for Even {
31//!     fn test(x: &i32) -> bool {
32//!         *x % 2 == 0
33//!     }
34//! }
35//!
36//! type EvenInt = Refinement<i32, Even>;
37//!
38//! fn takes_even(i: EvenInt) {
39//!     println!("Received even number {}", i);
40//! }
41//!
42//! match EvenInt::new(4) {
43//!     Some(x) => takes_even(x),  // "Received even number 4"
44//!     None => { /* ... */ }      // Handle logical error
45//! }
46//!
47//! ```
//! [^1]: <https://en.wikipedia.org/wiki/Refinement_type>
49
50use std::borrow::Borrow;
51use std::convert::AsRef;
52use std::fmt;
53use std::marker::PhantomData;
54use std::ops::{
55    Add, BitAnd, BitOr, BitXor, Bound, Div, Index, Mul, Neg, Not, RangeBounds, Rem, Shl, Shr, Sub,
56};
57
/// Decides whether a value of type `T` belongs to a refinement type.
///
/// Implement this on a marker type (the examples use unit structs) and pair it with
/// [`Refinement`](self::Refinement) to obtain a type whose values have all passed the check.
///
/// # Example
///
/// ```
/// use refinement::Predicate;
///
/// struct LessThanTen;
///
/// impl Predicate<i32> for LessThanTen {
///     fn test(x: &i32) -> bool {
///         *x < 10
///     }
/// }
///
/// ```
pub trait Predicate<T> {
    /// Returns `true` exactly when `x` satisfies this predicate.
    ///
    /// This is an associated function (no `self`), so it is called as `P::test(&value)`.
    /// See [`Refinement`](Refinement) for usage examples.
    fn test(x: &T) -> bool;
}
82
/// A `Refinement` type ensures all values of a particular type satisfy a [`Predicate`].
///
/// Use [`as_inner`](Refinement::as_inner)/[`to_inner`](Refinement::to_inner) to access the
/// underlying value or [`into_inner`](Refinement::into_inner) to unwrap the value.
///
/// `Refinement` also implements many common standard library traits if the underlying
/// value also implements them. The equality, ordering and hashing impls declared right
/// after this struct constrain only `T`. The predicate `P` is a pure compile-time marker
/// (stored as `PhantomData<P>`), so it does not need to implement any of those traits.
///
/// # Examples
/// ```
/// use refinement::{Predicate, Refinement};
///
/// struct LessThanTen;
///
/// impl Predicate<i32> for LessThanTen {
///     fn test(x: &i32) -> bool {
///         *x < 10
///     }
/// }
///
/// type LessThanTenInt = Refinement<i32, LessThanTen>;
///
/// let x = LessThanTenInt::new(5);
/// assert!(x.is_some());
///
/// let y = LessThanTenInt::new(11);
/// assert!(y.is_none());
///
/// // Comparison works even though `LessThanTen` derives nothing.
/// assert_eq!(LessThanTenInt::new(3), LessThanTenInt::new(3));
/// assert!(LessThanTenInt::new(2) < LessThanTenInt::new(3));
/// ```
pub struct Refinement<T, P>(T, PhantomData<P>);

// The following impls are written by hand instead of derived: `#[derive]` would add a
// matching bound on the marker `P` as well, making e.g. `Refinement<i32, Even>`
// incomparable unless `Even` itself derived `PartialEq`/`Ord`/`Hash`. The marker carries
// no data, so only the wrapped value takes part in comparison and hashing. This also
// keeps `Eq`/`Hash` consistent with the `Borrow<T>` impl.

/// Two refined values are equal when their wrapped values are equal.
impl<T: PartialEq, P> PartialEq for Refinement<T, P> {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl<T: Eq, P> Eq for Refinement<T, P> {}

/// Orders refined values by their wrapped values (`None` when those are incomparable).
impl<T: PartialOrd, P> PartialOrd for Refinement<T, P> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

/// Totally orders refined values by their wrapped values.
impl<T: Ord, P> Ord for Refinement<T, P> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.cmp(&other.0)
    }
}

/// Hashes exactly like the wrapped value, as required for `Borrow<T>` lookups.
impl<T: std::hash::Hash, P> std::hash::Hash for Refinement<T, P> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.0, state);
    }
}
113
114impl<T, P> Refinement<T, P>
115where
116    P: Predicate<T>,
117{
118    /// Create a refined value from the underlying type `T`.
119    ///
120    /// Returns `x` under the refinement type if `x` satisfies `P`, otherwise returns [`None`](std::option::Option::None).
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// use refinement::{Predicate, Refinement};
126    ///
127    /// struct NonEmpty;
128    ///
129    /// impl Predicate<String> for NonEmpty {
130    ///     fn test(x: &String) -> bool {
131    ///        !x.is_empty()
132    ///     }
133    /// }
134    ///
135    /// type NonEmptyString = Refinement<String, NonEmpty>;
136    ///
137    /// let s1 = NonEmptyString::new(String::from("Hello"));
138    /// assert!(s1.is_some());
139    ///
140    /// let s2 = NonEmptyString::new(String::from(""));
141    /// assert!(s2.is_none());
142    /// ```
143    pub fn new(x: T) -> Option<Self> {
144        if P::test(&x) {
145            Some(Refinement(x, PhantomData))
146        } else {
147            None
148        }
149    }
150
151    /// Unwrap the underlying value, consuming `self`.
152    ///
153    /// # Examples
154    ///
155    /// ```
156    /// use refinement::{Predicate, Refinement};
157    ///
158    /// struct ThreeDigit;
159    ///
160    /// impl Predicate<String> for ThreeDigit {
161    ///     fn test(x: &String) -> bool {
162    ///        x.chars().count() == 3 && x.chars().filter(|c| c.is_ascii_digit()).count() == 3
163    ///     }
164    /// }
165    ///
166    /// type ThreeDigitString = Refinement<String, ThreeDigit>;
167    ///
168    /// let s = ThreeDigitString::new(String::from("123"));
169    ///
170    /// assert_eq!(String::from("123"), s.unwrap().into_inner());
171    /// ```
172    pub fn into_inner(self) -> T {
173        self.0
174    }
175}
176
177impl<T, P> std::fmt::Debug for Refinement<T, P>
178where
179    T: std::fmt::Debug
180{
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_tuple("Refinement")
183            .field(&self.0) // : T
184            .field(&format_args!("_")) // : PhantomData<P>
185            .finish()
186    }
187}
188
189impl<T: Clone, P> Clone for Refinement<T, P> {
190    fn clone(&self) -> Self {
191        Self(self.0.clone(), self.1.clone())
192    }
193}
194
195impl<T: Copy, P> Copy for Refinement<T, P> {}
196
197impl<T, P> std::ops::Deref for Refinement<T, P> {
198    type Target = T;
199
200    fn deref(&self) -> &Self::Target {
201        &self.0
202    }
203}
204
205impl<T, P> std::ops::DerefMut for Refinement<T, P> {
206    fn deref_mut(&mut self) -> &mut Self::Target {
207        &mut self.0
208    }
209}
210
211impl<T, P> Refinement<T, P>
212where
213    T: Clone,
214    P: Predicate<T>,
215{
216    /// Retrieve the underlying value without consuming `self`.
217    ///
218    /// # Examples
219    ///
220    /// ```
221    /// use refinement::{Predicate, Refinement};
222    ///
223    /// struct ThreeDigit;
224    ///
225    /// impl Predicate<String> for ThreeDigit {
226    ///     fn test(x: &String) -> bool {
227    ///        x.chars().count() == 3 && x.chars().filter(|c| c.is_ascii_digit()).count() == 3
228    ///     }
229    /// }
230    ///
231    /// type ThreeDigitString = Refinement<String, ThreeDigit>;
232    ///
233    /// let s = ThreeDigitString::new(String::from("123"));
234    ///
235    /// assert_eq!(String::from("123"), s.unwrap().to_inner());
236    /// ```
237    pub fn to_inner(&self) -> T {
238        self.0.clone()
239    }
240}
241
242impl<T, P> Refinement<T, P>
243where
244    T: Copy,
245    P: Predicate<T>,
246{
247    /// Retrieve the underlying value for [`Copy`] types without consuming `self`.
248    ///
249    /// # Examples
250    ///
251    /// ```
252    /// use refinement::{Predicate, Refinement};
253    ///
254    /// struct LessThanTen;
255    ///
256    /// impl Predicate<i32> for LessThanTen {
257    ///     fn test(x: &i32) -> bool {
258    ///         *x < 10
259    ///     }
260    /// }
261    ///
262    /// type LessThanTenInt = Refinement<i32, LessThanTen>;
263    ///
264    /// let x = LessThanTenInt::new(5);
265    /// assert_eq!(5, x.unwrap().as_inner());
266    /// ```
267    pub fn as_inner(&self) -> T {
268        self.0
269    }
270}
271
272impl<T, P> Borrow<T> for Refinement<T, P>
273where
274    P: Predicate<T>,
275{
276    fn borrow(&self) -> &T {
277        &self.0
278    }
279}
280
281impl<T, P> AsRef<T> for Refinement<T, P>
282where
283    P: Predicate<T>,
284{
285    fn as_ref(&self) -> &T {
286        &self.0
287    }
288}
289
290impl<T, P> Add<T> for Refinement<T, P>
291where
292    T: Add<Output = T>,
293    P: Predicate<T>,
294{
295    type Output = Option<Self>;
296    fn add(self, rhs: T) -> Self::Output {
297        Self::new(self.0 + rhs)
298    }
299}
300
301impl<T, P> Add for Refinement<T, P>
302where
303    T: Add<Output = T>,
304    P: Predicate<T>,
305{
306    type Output = Option<Self>;
307    fn add(self, rhs: Self) -> Self::Output {
308        Self::new(self.0 + rhs.0)
309    }
310}
311
312impl<T, P> BitAnd<T> for Refinement<T, P>
313where
314    T: BitAnd<Output = T>,
315    P: Predicate<T>,
316{
317    type Output = Option<Self>;
318    fn bitand(self, rhs: T) -> Self::Output {
319        Self::new(self.0 & rhs)
320    }
321}
322
323impl<T, P> BitAnd for Refinement<T, P>
324where
325    T: BitAnd<Output = T>,
326    P: Predicate<T>,
327{
328    type Output = Option<Self>;
329    fn bitand(self, rhs: Self) -> Self::Output {
330        Self::new(self.0 & rhs.0)
331    }
332}
333
334impl<T, P> BitOr<T> for Refinement<T, P>
335where
336    T: BitOr<Output = T>,
337    P: Predicate<T>,
338{
339    type Output = Option<Self>;
340    fn bitor(self, rhs: T) -> Self::Output {
341        Self::new(self.0 | rhs)
342    }
343}
344
345impl<T, P> BitOr for Refinement<T, P>
346where
347    T: BitOr<Output = T>,
348    P: Predicate<T>,
349{
350    type Output = Option<Self>;
351    fn bitor(self, rhs: Self) -> Self::Output {
352        Self::new(self.0 | rhs.0)
353    }
354}
355
356impl<T, P> BitXor<T> for Refinement<T, P>
357where
358    T: BitXor<Output = T>,
359    P: Predicate<T>,
360{
361    type Output = Option<Self>;
362    fn bitxor(self, rhs: T) -> Self::Output {
363        Self::new(self.0 ^ rhs)
364    }
365}
366
367impl<T, P> BitXor for Refinement<T, P>
368where
369    T: BitXor<Output = T>,
370    P: Predicate<T>,
371{
372    type Output = Option<Self>;
373    fn bitxor(self, rhs: Self) -> Self::Output {
374        Self::new(self.0 ^ rhs.0)
375    }
376}
377
378impl<T, P> Div<T> for Refinement<T, P>
379where
380    T: Div<Output = T>,
381    P: Predicate<T>,
382{
383    type Output = Option<Self>;
384    fn div(self, rhs: T) -> Self::Output {
385        Self::new(self.0 / rhs)
386    }
387}
388
389impl<T, P> Div for Refinement<T, P>
390where
391    T: Div<Output = T>,
392    P: Predicate<T>,
393{
394    type Output = Option<Self>;
395    fn div(self, rhs: Self) -> Self::Output {
396        Self::new(self.0 / rhs.0)
397    }
398}
399
400impl<T, P, I> Index<I> for Refinement<T, P>
401where
402    T: Index<I>,
403    P: Predicate<T>,
404{
405    type Output = T::Output;
406    fn index(&self, index: I) -> &Self::Output {
407        self.0.index(index)
408    }
409}
410
411impl<T, P> Mul<T> for Refinement<T, P>
412where
413    T: Mul<Output = T>,
414    P: Predicate<T>,
415{
416    type Output = Option<Self>;
417    fn mul(self, rhs: T) -> Self::Output {
418        Self::new(self.0 * rhs)
419    }
420}
421
422impl<T, P> Mul for Refinement<T, P>
423where
424    T: Mul<Output = T>,
425    P: Predicate<T>,
426{
427    type Output = Option<Self>;
428    fn mul(self, rhs: Self) -> Self::Output {
429        Self::new(self.0 * rhs.0)
430    }
431}
432
433impl<T, P> Neg for Refinement<T, P>
434where
435    T: Neg<Output = T>,
436    P: Predicate<T>,
437{
438    type Output = Option<Self>;
439    fn neg(self) -> Self::Output {
440        Self::new(self.0.neg())
441    }
442}
443
444impl<T, P> Not for Refinement<T, P>
445where
446    T: Not<Output = T>,
447    P: Predicate<T>,
448{
449    type Output = Option<Self>;
450    fn not(self) -> Self::Output {
451        Self::new(self.0.not())
452    }
453}
454
455impl<T, P, B> RangeBounds<B> for Refinement<T, P>
456where
457    T: RangeBounds<B>,
458    P: Predicate<T>,
459{
460    fn start_bound(&self) -> Bound<&B> {
461        self.0.start_bound()
462    }
463
464    fn end_bound(&self) -> Bound<&B> {
465        self.0.end_bound()
466    }
467}
468
469impl<T, P> Rem<T> for Refinement<T, P>
470where
471    T: Rem<Output = T>,
472    P: Predicate<T>,
473{
474    type Output = Option<Self>;
475    fn rem(self, rhs: T) -> Self::Output {
476        Self::new(self.0 % rhs)
477    }
478}
479
480impl<T, P> Rem for Refinement<T, P>
481where
482    T: Rem<Output = T>,
483    P: Predicate<T>,
484{
485    type Output = Option<Self>;
486    fn rem(self, rhs: Self) -> Self::Output {
487        Self::new(self.0 % rhs.0)
488    }
489}
490
491impl<T, P> Shl<T> for Refinement<T, P>
492where
493    T: Shl<Output = T>,
494    P: Predicate<T>,
495{
496    type Output = Option<Self>;
497    fn shl(self, rhs: T) -> Self::Output {
498        Self::new(self.0 << rhs)
499    }
500}
501
502impl<T, P> Shl for Refinement<T, P>
503where
504    T: Shl<Output = T>,
505    P: Predicate<T>,
506{
507    type Output = Option<Self>;
508    fn shl(self, rhs: Self) -> Self::Output {
509        Self::new(self.0 << rhs.0)
510    }
511}
512
513impl<T, P> Shr<T> for Refinement<T, P>
514where
515    T: Shr<Output = T>,
516    P: Predicate<T>,
517{
518    type Output = Option<Self>;
519    fn shr(self, rhs: T) -> Self::Output {
520        Self::new(self.0 >> rhs)
521    }
522}
523
524impl<T, P> Shr for Refinement<T, P>
525where
526    T: Shr<Output = T>,
527    P: Predicate<T>,
528{
529    type Output = Option<Self>;
530    fn shr(self, rhs: Self) -> Self::Output {
531        Self::new(self.0 >> rhs.0)
532    }
533}
534
535impl<T, P> Sub<T> for Refinement<T, P>
536where
537    T: Sub<Output = T>,
538    P: Predicate<T>,
539{
540    type Output = Option<Self>;
541    fn sub(self, rhs: T) -> Self::Output {
542        Self::new(self.0 - rhs)
543    }
544}
545
546impl<T, P> Sub for Refinement<T, P>
547where
548    T: Sub<Output = T>,
549    P: Predicate<T>,
550{
551    type Output = Option<Self>;
552    fn sub(self, rhs: Self) -> Self::Output {
553        Self::new(self.0 - rhs.0)
554    }
555}
556
557impl<T, P> fmt::Display for Refinement<T, P>
558where
559    T: fmt::Display,
560    P: Predicate<T>,
561{
562    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
563        self.0.fmt(f)
564    }
565}