Skip to main content

reflect_nat/
nat.rs

1//! Peano-encoded type-level natural numbers with type-level arithmetic.
2
3use reify_reflect_core::{Reflect, RuntimeValue};
4use std::marker::PhantomData;
5
/// Type-level zero.
///
/// `Z` is a zero-sized unit struct, so every value of it is interchangeable.
/// Besides `Default`, `Debug`, `Clone` and `Copy`, it derives `PartialEq`,
/// `Eq` and `Hash` so it can be compared and used as a map/set key like any
/// other unit type.
///
/// # Examples
///
/// ```
/// use reflect_nat::Z;
/// use reify_reflect_core::{Reflect, RuntimeValue};
///
/// assert_eq!(Z::reflect(), RuntimeValue::Nat(0));
/// assert_eq!(Z, Z::default());
/// ```
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Z;
18
/// Type-level successor. `S<N>` represents `N + 1`.
///
/// The only field is a [`PhantomData<N>`](PhantomData), so `S<N>` is
/// zero-sized and has exactly one value per `N`. Equality and hashing are
/// implemented by delegating to that `PhantomData`. Its impls place no bound
/// on `N`, so `S<N>` is `Eq + Hash` for every `N`.
///
/// # Examples
///
/// ```
/// use reflect_nat::{Z, S};
/// use reify_reflect_core::{Reflect, RuntimeValue};
///
/// type One = S<Z>;
/// type Two = S<S<Z>>;
/// assert_eq!(One::reflect(), RuntimeValue::Nat(1));
/// assert_eq!(Two::reflect(), RuntimeValue::Nat(2));
/// assert!(One::default() == One::default());
/// ```
#[derive(Default, Debug, Clone, Copy)]
pub struct S<N>(pub PhantomData<N>);

impl<N> PartialEq for S<N> {
    fn eq(&self, other: &Self) -> bool {
        // Delegate to `PhantomData`, whose impl is unbounded in `N`; a derive
        // would needlessly require `N: PartialEq`.
        self.0 == other.0
    }
}

impl<N> Eq for S<N> {}

impl<N> std::hash::Hash for S<N> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Same reasoning as `PartialEq`: no `N: Hash` bound is needed.
        std::hash::Hash::hash(&self.0, state);
    }
}
34
/// Trait implemented by the type-level natural numbers: [`Z`], and `S<N>`
/// for every `N: Nat`.
///
/// It serves two purposes. It acts as a bound restricting generic parameters
/// to Peano numerals. It also exposes a numeral's value at runtime through
/// [`Nat::to_u64`]. Because of that method it is not a pure marker trait.
///
/// # Examples
///
/// ```
/// use reflect_nat::{Z, S, Nat};
///
/// fn require_nat<N: Nat>() {}
/// require_nat::<Z>();
/// require_nat::<S<Z>>();
///
/// assert_eq!(<S<S<Z>>>::to_u64(), 2);
/// ```
pub trait Nat {
    /// The runtime `u64` value of this type-level natural.
    fn to_u64() -> u64;
}
50
51impl Nat for Z {
52    fn to_u64() -> u64 {
53        0
54    }
55}
56
57impl<N: Nat> Nat for S<N> {
58    fn to_u64() -> u64 {
59        1 + N::to_u64()
60    }
61}
62
63impl Reflect for Z {
64    type Value = RuntimeValue;
65
66    fn reflect() -> Self::Value {
67        RuntimeValue::Nat(0)
68    }
69}
70
71impl<N: Nat> Reflect for S<N> {
72    type Value = RuntimeValue;
73
74    fn reflect() -> Self::Value {
75        RuntimeValue::Nat(<S<N> as Nat>::to_u64())
76    }
77}
78
79// ---------------------------------------------------------------------------
80// Type-level arithmetic
81// ---------------------------------------------------------------------------
82
83/// Type-level addition. `Add<A, B>` computes `A + B`.
84///
85/// # Examples
86///
87/// ```
88/// use reflect_nat::{Z, S, Add};
89/// use reify_reflect_core::{Reflect, RuntimeValue};
90///
91/// // 2 + 3 = 5
92/// type Two = S<S<Z>>;
93/// type Three = S<S<S<Z>>>;
94/// type Five = <Two as Add<Three>>::Result;
95/// assert_eq!(Five::reflect(), RuntimeValue::Nat(5));
96/// ```
97pub trait Add<Rhs> {
98    /// The resulting type-level natural.
99    type Result: Nat;
100}
101
102// Z + N = N
103impl<N: Nat> Add<N> for Z {
104    type Result = N;
105}
106
107// S<M> + N = S<M + N>
108impl<M: Nat + Add<N>, N: Nat> Add<N> for S<M>
109where
110    <M as Add<N>>::Result: Nat,
111{
112    type Result = S<<M as Add<N>>::Result>;
113}
114
115/// Type-level multiplication. `Mul<A, B>` computes `A * B`.
116///
117/// # Examples
118///
119/// ```
120/// use reflect_nat::{Z, S, Mul};
121/// use reify_reflect_core::{Reflect, RuntimeValue};
122///
123/// // 2 * 3 = 6
124/// type Two = S<S<Z>>;
125/// type Three = S<S<S<Z>>>;
126/// type Six = <Two as Mul<Three>>::Result;
127/// assert_eq!(Six::reflect(), RuntimeValue::Nat(6));
128/// ```
129pub trait Mul<Rhs> {
130    /// The resulting type-level natural.
131    type Result: Nat;
132}
133
134// Z * N = Z
135impl<N: Nat> Mul<N> for Z {
136    type Result = Z;
137}
138
139// S<M> * N = N + (M * N)
140impl<M, N> Mul<N> for S<M>
141where
142    M: Nat + Mul<N>,
143    N: Nat + Add<<M as Mul<N>>::Result>,
144    <M as Mul<N>>::Result: Nat,
145    <N as Add<<M as Mul<N>>::Result>>::Result: Nat,
146{
147    type Result = <N as Add<<M as Mul<N>>::Result>>::Result;
148}
149
/// Type-level strict less-than: `<A as Lt<B>>::VALUE` is `true` exactly when
/// `A < B`.
///
/// The impls in this module strip one `S` from both operands until either
/// side reaches [`Z`]. The answer is an associated constant, so it is usable
/// in `const` contexts.
///
/// # Examples
///
/// ```
/// use reflect_nat::{Z, S, Lt};
///
/// // 0 < 1 is true
/// assert!(<Z as Lt<S<Z>>>::VALUE);
///
/// // 2 < 1 is false
/// type Two = S<S<Z>>;
/// assert!(!<Two as Lt<S<Z>>>::VALUE);
/// ```
pub trait Lt<Rhs> {
    /// `true` if `Self < Rhs` at the type level.
    const VALUE: bool;
}
168
169// Z < Z = false
170impl Lt<Z> for Z {
171    const VALUE: bool = false;
172}
173
174// Z < S<N> = true
175impl<N: Nat> Lt<S<N>> for Z {
176    const VALUE: bool = true;
177}
178
179// S<M> < Z = false
180impl<M: Nat> Lt<Z> for S<M> {
181    const VALUE: bool = false;
182}
183
184// S<M> < S<N> = M < N
185impl<M: Nat + Lt<N>, N: Nat> Lt<S<N>> for S<M> {
186    const VALUE: bool = <M as Lt<N>>::VALUE;
187}
188
189// ---------------------------------------------------------------------------
190// Convenience type aliases
191// ---------------------------------------------------------------------------
192
193/// Type alias for 0.
194pub type N0 = Z;
195/// Type alias for 1.
196pub type N1 = S<N0>;
197/// Type alias for 2.
198pub type N2 = S<N1>;
199/// Type alias for 3.
200pub type N3 = S<N2>;
201/// Type alias for 4.
202pub type N4 = S<N3>;
203/// Type alias for 5.
204pub type N5 = S<N4>;
205/// Type alias for 6.
206pub type N6 = S<N5>;
207/// Type alias for 7.
208pub type N7 = S<N6>;
209/// Type alias for 8.
210pub type N8 = S<N7>;
211
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_reflects_to_nat_0() {
        assert_eq!(Z::reflect(), RuntimeValue::Nat(0));
    }

    #[test]
    fn successor_reflects_correctly() {
        assert_eq!(S::<Z>::reflect(), RuntimeValue::Nat(1));
        assert_eq!(<S<S<Z>>>::reflect(), RuntimeValue::Nat(2));
        assert_eq!(<S<S<S<Z>>>>::reflect(), RuntimeValue::Nat(3));
    }

    #[test]
    fn nat_to_u64() {
        assert_eq!(Z::to_u64(), 0);
        assert_eq!(<S<Z>>::to_u64(), 1);
        assert_eq!(<S<S<S<S<S<Z>>>>>>::to_u64(), 5);
    }

    #[test]
    fn type_aliases() {
        assert_eq!(N0::to_u64(), 0);
        assert_eq!(N1::to_u64(), 1);
        assert_eq!(N5::to_u64(), 5);
        assert_eq!(N8::to_u64(), 8);
    }

    #[test]
    fn addition() {
        // 0 + 3 = 3
        assert_eq!(<<Z as Add<N3>>::Result as Nat>::to_u64(), 3);
        // 2 + 3 = 5
        assert_eq!(<<N2 as Add<N3>>::Result as Nat>::to_u64(), 5);
        // 3 + 2 = 5 (commutes in value)
        assert_eq!(<<N3 as Add<N2>>::Result as Nat>::to_u64(), 5);
        // 1 + 0 = 1
        assert_eq!(<<N1 as Add<N0>>::Result as Nat>::to_u64(), 1);
    }

    #[test]
    fn multiplication() {
        // 0 * 3 = 0
        assert_eq!(<<Z as Mul<N3>>::Result as Nat>::to_u64(), 0);
        // 3 * 0 = 0 (zero on the right goes through the recursive case)
        assert_eq!(<<N3 as Mul<N0>>::Result as Nat>::to_u64(), 0);
        // 2 * 3 = 6
        assert_eq!(<<N2 as Mul<N3>>::Result as Nat>::to_u64(), 6);
        // 4 * 2 = 8
        assert_eq!(<<N4 as Mul<N2>>::Result as Nat>::to_u64(), 8);
        // 1 * 5 = 5
        assert_eq!(<<N1 as Mul<N5>>::Result as Nat>::to_u64(), 5);
        // 3 * 1 = 3
        assert_eq!(<<N3 as Mul<N1>>::Result as Nat>::to_u64(), 3);
    }

    #[test]
    fn arithmetic_produces_canonical_types() {
        // These bindings only type-check if the computed `Result` is exactly
        // the same Peano type as the alias, not merely one with equal value.
        let _: std::marker::PhantomData<N5> =
            std::marker::PhantomData::<<N2 as Add<N3>>::Result>;
        let _: std::marker::PhantomData<N6> =
            std::marker::PhantomData::<<N2 as Mul<N3>>::Result>;
    }

    #[test]
    fn arithmetic_results_reflect() {
        assert_eq!(
            <<N2 as Add<N3>>::Result as Reflect>::reflect(),
            RuntimeValue::Nat(5)
        );
        assert_eq!(
            <<N2 as Mul<N3>>::Result as Reflect>::reflect(),
            RuntimeValue::Nat(6)
        );
    }

    #[test]
    // `Lt::VALUE` is an associated constant, so clippy treats every check
    // below as an assertion on a constant. They are intentional: the point
    // is to verify the type-level computation that produced the constant.
    #[allow(clippy::assertions_on_constants)]
    fn less_than() {
        assert!(!<Z as Lt<Z>>::VALUE);
        assert!(<Z as Lt<S<Z>>>::VALUE);
        assert!(!<S<Z> as Lt<Z>>::VALUE);
        assert!(<N2 as Lt<N5>>::VALUE);
        assert!(!<N5 as Lt<N2>>::VALUE);
        assert!(!<N3 as Lt<N3>>::VALUE);
    }

    #[test]
    fn reflect_returns_runtime_value() {
        assert_eq!(N5::reflect(), RuntimeValue::Nat(5));
        assert_eq!(N0::reflect(), RuntimeValue::Nat(0));
    }
}