refinement_types/
length.rs

1//! Predicates based on length.
2
3use core::{fmt, marker::PhantomData};
4
5#[cfg(feature = "diagnostics")]
6use miette::Diagnostic;
7
8use thiserror::Error;
9
10use crate::{
11    core::Predicate,
12    logic::{And, Not},
13};
14
/// Represents types whose values have a well-defined length.
///
/// Implementors decide what "length" means for them; the predicates in this
/// module only ever compare the returned `usize` against their bounds.
pub trait HasLength {
    /// Returns the length of `self`.
    fn length(&self) -> usize;
}
20
21/// Represents errors that occur when the provided value has
22/// length greater than or equal to some bound.
23#[derive(Debug, Error)]
24#[error("received value with length >= {other}")]
25#[cfg_attr(
26    feature = "diagnostics",
27    derive(Diagnostic),
28    diagnostic(code(length::lt), help("make sure the length is less than {other}"))
29)]
30pub struct LessError {
31    /// The length against which the check was performed (the `N`).
32    pub other: usize,
33}
34
35impl LessError {
36    /// Constructs [`Self`].
37    pub const fn new(other: usize) -> Self {
38        Self { other }
39    }
40}
41
42/// Checks whether the given value has length less than `N`.
43pub struct Less<const N: usize> {
44    private: PhantomData<()>,
45}
46
47impl<const N: usize, T: HasLength + ?Sized> Predicate<T> for Less<N> {
48    type Error = LessError;
49
50    fn check(value: &T) -> Result<(), Self::Error> {
51        if value.length() < N {
52            Ok(())
53        } else {
54            Err(Self::Error::new(N))
55        }
56    }
57
58    fn expect(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
59        write!(formatter, "value with length < {N}")
60    }
61
62    fn expect_code(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
63        write!(formatter, "length::lt<{N}>")
64    }
65}
66
67/// Represents errors that occur when the provided value has
68/// length greater than some bound.
69#[derive(Debug, Error)]
70#[error("received value with length > {other}")]
71#[cfg_attr(
72    feature = "diagnostics",
73    derive(Diagnostic),
74    diagnostic(
75        code(length::le),
76        help("make sure the length is less than or equal to {other}")
77    )
78)]
79pub struct LessOrEqualError {
80    /// The length against which the check was performed (the `N`).
81    pub other: usize,
82}
83
84impl LessOrEqualError {
85    /// Constructs [`Self`].
86    pub const fn new(other: usize) -> Self {
87        Self { other }
88    }
89}
90
91/// Checks whether the given value has length less than or equal to `N`.
92pub struct LessOrEqual<const N: usize> {
93    private: PhantomData<()>,
94}
95
96impl<const N: usize, T: HasLength + ?Sized> Predicate<T> for LessOrEqual<N> {
97    type Error = LessOrEqualError;
98
99    fn check(value: &T) -> Result<(), Self::Error> {
100        if value.length() <= N {
101            Ok(())
102        } else {
103            Err(Self::Error::new(N))
104        }
105    }
106
107    fn expect(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
108        write!(formatter, "value with length <= {N}")
109    }
110
111    fn expect_code(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
112        write!(formatter, "length::le<{N}>")
113    }
114}
115
116/// Represents errors that occur when the provided value has
117/// length less than or equal to some bound.
118#[derive(Debug, Error)]
119#[error("received value with length <= {other}")]
120#[cfg_attr(
121    feature = "diagnostics",
122    derive(Diagnostic),
123    diagnostic(code(length::gt), help("make sure the length is greater than {other}"))
124)]
125pub struct GreaterError {
126    /// The length against which the check was performed (the `N`).
127    pub other: usize,
128}
129
130impl GreaterError {
131    /// Constructs [`Self`].
132    pub const fn new(other: usize) -> Self {
133        Self { other }
134    }
135}
136
137/// Checks whether the given value has length greater than `N`.
138pub struct Greater<const N: usize> {
139    private: PhantomData<()>,
140}
141
142impl<const N: usize, T: HasLength + ?Sized> Predicate<T> for Greater<N> {
143    type Error = GreaterError;
144
145    fn check(value: &T) -> Result<(), Self::Error> {
146        if value.length() > N {
147            Ok(())
148        } else {
149            Err(Self::Error::new(N))
150        }
151    }
152
153    fn expect(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
154        write!(formatter, "value with length > {N}")
155    }
156
157    fn expect_code(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
158        write!(formatter, "length::gt<{N}>")
159    }
160}
161
162/// Represents errors that occur when the provided value has
163/// length less than some bound.
164#[derive(Debug, Error)]
165#[error("received value with length < {other}")]
166#[cfg_attr(
167    feature = "diagnostics",
168    derive(Diagnostic),
169    diagnostic(
170        code(length::ge),
171        help("make sure the length is greater than or equal to {other}")
172    )
173)]
174pub struct GreaterOrEqualError {
175    /// The length against which the check was performed (the `N`).
176    pub other: usize,
177}
178
179impl GreaterOrEqualError {
180    /// Constructs [`Self`].
181    pub const fn new(other: usize) -> Self {
182        Self { other }
183    }
184}
185
186/// Checks whether the given value has length greater than or equal to `N`.
187pub struct GreaterOrEqual<const N: usize> {
188    private: PhantomData<()>,
189}
190
191impl<const N: usize, T: HasLength + ?Sized> Predicate<T> for GreaterOrEqual<N> {
192    type Error = GreaterOrEqualError;
193
194    fn check(value: &T) -> Result<(), Self::Error> {
195        if value.length() >= N {
196            Ok(())
197        } else {
198            Err(Self::Error::new(N))
199        }
200    }
201
202    fn expect(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
203        write!(formatter, "value with length >= {N}")
204    }
205
206    fn expect_code(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
207        write!(formatter, "length::ge<{N}>")
208    }
209}
210
211/// Represents errors that occur when the provided value has
212/// length not equal to some bound.
213#[derive(Debug, Error)]
214#[error("received value with length != {other}")]
215#[cfg_attr(
216    feature = "diagnostics",
217    derive(Diagnostic),
218    diagnostic(code(length::eq), help("make sure the length is equal to {other}"))
219)]
220pub struct EqualError {
221    /// The length against which the check was performed (the `N`).
222    pub other: usize,
223}
224
225impl EqualError {
226    /// Constructs [`Self`].
227    pub const fn new(other: usize) -> Self {
228        Self { other }
229    }
230}
231
232/// Checks whether the given value has length equal to `N`.
233pub struct Equal<const N: usize> {
234    private: PhantomData<()>,
235}
236
237impl<const N: usize, T: HasLength + ?Sized> Predicate<T> for Equal<N> {
238    type Error = EqualError;
239
240    fn check(value: &T) -> Result<(), Self::Error> {
241        if value.length() == N {
242            Ok(())
243        } else {
244            Err(Self::Error::new(N))
245        }
246    }
247
248    fn expect(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
249        write!(formatter, "value with length == {N}")
250    }
251
252    fn expect_code(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
253        write!(formatter, "length::eq<{N}>")
254    }
255}
256
257/// Represents errors that occur when the provided value has
258/// length equal to some bound.
259#[derive(Debug, Error)]
260#[error("received value with length == {other}")]
261#[cfg_attr(
262    feature = "diagnostics",
263    derive(Diagnostic),
264    diagnostic(code(length::ne), help("make sure the length is not equal to {other}"))
265)]
266pub struct NotEqualError {
267    /// The length against which the check was performed (the `N`).
268    pub other: usize,
269}
270
271impl NotEqualError {
272    /// Constructs [`Self`].
273    pub const fn new(other: usize) -> Self {
274        Self { other }
275    }
276}
277
278/// Checks whether the given value has length not equal to `N`.
279pub struct NotEqual<const N: usize> {
280    private: PhantomData<()>,
281}
282
283impl<const N: usize, T: HasLength + ?Sized> Predicate<T> for NotEqual<N> {
284    type Error = NotEqualError;
285
286    #[allow(clippy::if_not_else)]
287    fn check(value: &T) -> Result<(), Self::Error> {
288        if value.length() != N {
289            Ok(())
290        } else {
291            Err(Self::Error::new(N))
292        }
293    }
294
295    fn expect(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
296        write!(formatter, "value with length != {N}")
297    }
298
299    fn expect_code(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
300        write!(formatter, "length::ne<{N}>")
301    }
302}
303
304/// Represents `(M, N)` intervals.
305pub type Open<const M: usize, const N: usize> = And<Greater<M>, Less<N>>;
306
307/// Represents `[M, N)` intervals.
308pub type ClosedOpen<const M: usize, const N: usize> = And<GreaterOrEqual<M>, Less<N>>;
309
310/// Represents `(M, N]` intervals.
311pub type OpenClosed<const M: usize, const N: usize> = And<Greater<M>, LessOrEqual<N>>;
312
313/// Represents `[M, N]` intervals.
314pub type Closed<const M: usize, const N: usize> = And<GreaterOrEqual<M>, LessOrEqual<N>>;
315
316/// Checks whether the given value has zero length.
317pub type Zero = Equal<0>;
318
319/// Checks whether the given value has non-zero length.
320pub type NonZero = NotEqual<0>;
321
322/// Represents errors when the provided value has
323/// length divided by [`divisor`] not equal to [`modulo`].
324///
325/// [`divisor`]: Self::divisor
326/// [`modulo`]: Self::modulo
327#[derive(Debug, Error)]
328#[error("received value % {divisor} != {modulo}")]
329pub struct ModuloError {
330    /// The divisor that the value length should be divided by (the `D`).
331    pub divisor: usize,
332    /// The expected modulo of the length division (the `M`).
333    pub modulo: usize,
334}
335
336impl ModuloError {
337    /// Constructs [`Self`].
338    pub const fn new(divisor: usize, modulo: usize) -> Self {
339        Self { divisor, modulo }
340    }
341}
342
343/// Checks whether the given value length divided by `D` has modulo `M`.
344pub struct Modulo<const D: usize, const M: usize> {
345    private: PhantomData<()>,
346}
347
348impl<const D: usize, const M: usize, T: HasLength + ?Sized> Predicate<T> for Modulo<D, M> {
349    type Error = ModuloError;
350
351    fn check(value: &T) -> Result<(), Self::Error> {
352        if value.length() % D == M {
353            Ok(())
354        } else {
355            Err(Self::Error::new(D, M))
356        }
357    }
358
359    fn expect(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
360        write!(formatter, "length % {D} == {M}")
361    }
362
363    fn expect_code(formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
364        write!(formatter, "length::mod<{D}, {M}>")
365    }
366}
367
368/// Checks whether the given value length is divisible by `D`.
369pub type Divisible<const D: usize> = Modulo<D, 0>;
370
371/// Checks whether the given value length is even.
372pub type Even = Divisible<2>;
373
374/// Checks whether the given value length is odd.
375pub type Odd = Not<Even>;
376
377// core
378
379impl HasLength for str {
380    fn length(&self) -> usize {
381        self.len()
382    }
383}
384
385impl<T> HasLength for [T] {
386    fn length(&self) -> usize {
387        self.len()
388    }
389}
390
391impl<T: HasLength + ?Sized> HasLength for &T {
392    fn length(&self) -> usize {
393        T::length(self)
394    }
395}
396
397// prelude imports
398
399#[cfg(feature = "alloc")]
400use alloc::{boxed::Box, string::String, vec::Vec};
401
/// Boxed values report the length of their contents.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<T: HasLength + ?Sized> HasLength for Box<T> {
    fn length(&self) -> usize {
        (**self).length()
    }
}

/// Owned strings report their length in bytes, exactly like `str`.
#[cfg(any(feature = "alloc", feature = "std"))]
impl HasLength for String {
    fn length(&self) -> usize {
        self.as_str().len()
    }
}

/// Vectors report their number of elements, exactly like slices.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<T> HasLength for Vec<T> {
    fn length(&self) -> usize {
        self.as_slice().len()
    }
}
422
423// clone-on-write
424
425#[cfg(feature = "alloc")]
426use alloc::borrow::{Cow, ToOwned};
427
428#[cfg(all(not(feature = "alloc"), feature = "std"))]
429use std::borrow::{Cow, ToOwned};
430
/// Clone-on-write values report the length of the data they currently
/// hold, whether borrowed or owned.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<T: ToOwned + HasLength + ?Sized> HasLength for Cow<'_, T> {
    fn length(&self) -> usize {
        let contents: &T = self;
        contents.length()
    }
}
437
438// pointers
439
440#[cfg(feature = "alloc")]
441use alloc::rc::Rc;
442
443#[cfg(all(not(feature = "alloc"), feature = "std"))]
444use std::rc::Rc;
445
/// Reference-counted pointers report the length of the shared value.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<T: HasLength + ?Sized> HasLength for Rc<T> {
    fn length(&self) -> usize {
        (**self).length()
    }
}
452
453#[cfg(feature = "alloc")]
454use alloc::sync::Arc;
455
456#[cfg(all(not(feature = "alloc"), feature = "std"))]
457use std::sync::Arc;
458
/// Atomically reference-counted pointers report the length of the shared value.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<T: HasLength + ?Sized> HasLength for Arc<T> {
    fn length(&self) -> usize {
        let shared: &T = self;
        shared.length()
    }
}
465
466// shared collections
467
468#[cfg(feature = "alloc")]
469use alloc::collections::{BTreeMap, BTreeSet, BinaryHeap, LinkedList, VecDeque};
470
471#[cfg(all(not(feature = "alloc"), feature = "std"))]
472use std::collections::{BTreeMap, BTreeSet, BinaryHeap, LinkedList, VecDeque};
473
/// Ordered maps report their number of key-value pairs.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<K, V> HasLength for BTreeMap<K, V> {
    fn length(&self) -> usize {
        BTreeMap::len(self)
    }
}

/// Ordered sets report their number of elements.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<T> HasLength for BTreeSet<T> {
    fn length(&self) -> usize {
        BTreeSet::len(self)
    }
}

/// Binary heaps report their number of elements.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<T> HasLength for BinaryHeap<T> {
    fn length(&self) -> usize {
        BinaryHeap::len(self)
    }
}

/// Linked lists report their number of elements.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<T> HasLength for LinkedList<T> {
    fn length(&self) -> usize {
        LinkedList::len(self)
    }
}

/// Double-ended queues report their number of elements.
#[cfg(any(feature = "alloc", feature = "std"))]
impl<T> HasLength for VecDeque<T> {
    fn length(&self) -> usize {
        VecDeque::len(self)
    }
}
508
509// collections
510
511#[cfg(feature = "std")]
512use std::collections::{HashMap, HashSet};
513
/// Hash maps report their number of key-value pairs, for any hasher `S`.
#[cfg(feature = "std")]
impl<K, V, S> HasLength for HashMap<K, V, S> {
    fn length(&self) -> usize {
        HashMap::len(self)
    }
}

/// Hash sets report their number of elements, for any hasher `S`.
#[cfg(feature = "std")]
impl<T, S> HasLength for HashSet<T, S> {
    fn length(&self) -> usize {
        HashSet::len(self)
    }
}
527
528// OS strings
529
530#[cfg(feature = "std")]
531use std::ffi::{OsStr, OsString};
532
/// OS string slices report the value of [`OsStr::len`]; see its
/// documentation for what that number measures on each platform.
#[cfg(feature = "std")]
impl HasLength for OsStr {
    fn length(&self) -> usize {
        OsStr::len(self)
    }
}

/// Owned OS strings report the same length as their borrowed [`OsStr`] form.
#[cfg(feature = "std")]
impl HasLength for OsString {
    fn length(&self) -> usize {
        self.as_os_str().len()
    }
}
546
547// paths (via underlying strings)
548
549#[cfg(feature = "std")]
550use std::path::{Path, PathBuf};
551
/// Paths report the length of their underlying OS string.
#[cfg(feature = "std")]
impl HasLength for Path {
    fn length(&self) -> usize {
        self.as_os_str().len()
    }
}

/// Owned paths report the same length as their borrowed [`Path`] form.
#[cfg(feature = "std")]
impl HasLength for PathBuf {
    fn length(&self) -> usize {
        self.as_path().length()
    }
}