Skip to main content

const_reify/
nat_reify.rs

//! True value→type reification for natural numbers.
//!
//! This module provides [`NatCallback`], [`Nat2Callback`], [`reify_nat`],
//! and [`reify_nat2`] for genuine runtime-to-type-level dispatch: inside
//! the callback, the runtime value is available as a const-generic parameter.
//!
//! ## Why a trait, not a closure?
//!
//! Closures cannot be generic, so a closure could only ever receive the value
//! as a plain runtime `u64`. Likewise, with `&dyn HasModulus` the const
//! generic is erased. By using a trait with a const-generic method, the full
//! `N` is available inside the callback, enabling construction of types like
//! `Mod<N>`, const-generic arithmetic, and type-safe invariants parameterized
//! by the reified value.
//!
//! ## Composing two reified values
//!
//! [`reify_nat2`] nests two dispatches so both `A` and `B` are known
//! const generics inside the callback.
//!
//! Supported values are `0..=255`; larger values panic.
//!
//! # Examples
//!
//! ```
//! use const_reify::nat_reify::{NatCallback, reify_nat};
//!
//! struct ModMul { a: u64, b: u64 }
//!
//! impl NatCallback<u64> for ModMul {
//!     fn call<const N: u64>(&self) -> u64 {
//!         if N == 0 { return 0; }
//!         (self.a % N) * (self.b % N) % N
//!     }
//! }
//!
//! let result = reify_nat(7, &ModMul { a: 10, b: 20 });
//! assert_eq!(result, (10 % 7) * (20 % 7) % 7);
//! ```
//!
//! Two values:
//!
//! ```
//! use const_reify::nat_reify::{Nat2Callback, reify_nat2};
//!
//! struct Add;
//! impl Nat2Callback<u64> for Add {
//!     fn call<const A: u64, const B: u64>(&self) -> u64 { A + B }
//! }
//!
//! assert_eq!(reify_nat2(5, 3, &Add), 8);
//! ```
49
/// A computation parameterized by one const-generic `u64`.
///
/// [`reify_nat`] selects the instantiation of [`call`](NatCallback::call)
/// whose `N` equals a runtime value. Inside the method body, `N` is therefore
/// a compile-time constant that can parameterize other const-generic types.
///
/// # Examples
///
/// ```
/// use const_reify::nat_reify::{NatCallback, reify_nat};
///
/// struct Halve;
/// impl NatCallback<u64> for Halve {
///     fn call<const N: u64>(&self) -> u64 { N / 2 }
/// }
///
/// assert_eq!(reify_nat(9, &Halve), 4);
/// assert_eq!(reify_nat(10, &Halve), 5);
/// ```
pub trait NatCallback<R> {
    /// Runs the computation with `N` fixed to the reified value and returns
    /// its result.
    fn call<const N: u64>(&self) -> R;
}
71
/// A computation parameterized by two const-generic `u64`s.
///
/// [`reify_nat2`] selects the instantiation of [`call`](Nat2Callback::call)
/// that matches a pair of runtime values. The first value becomes `A` and the
/// second becomes `B`.
///
/// # Examples
///
/// ```
/// use const_reify::nat_reify::{Nat2Callback, reify_nat2};
///
/// struct Larger;
/// impl Nat2Callback<u64> for Larger {
///     fn call<const A: u64, const B: u64>(&self) -> u64 {
///         if A > B { A } else { B }
///     }
/// }
///
/// assert_eq!(reify_nat2(4, 9, &Larger), 9);
/// assert_eq!(reify_nat2(200, 17, &Larger), 200);
/// ```
pub trait Nat2Callback<R> {
    /// Runs the computation with `A` and `B` fixed to the reified values.
    fn call<const A: u64, const B: u64>(&self) -> R;
}
92
93/// Reify a runtime `u64` (0..=255) into a const-generic context.
94///
95/// # Panics
96///
97/// Panics if `val > 255`.
98///
99/// # Examples
100///
101/// ```
102/// use const_reify::nat_reify::{NatCallback, reify_nat};
103///
104/// struct Square;
105/// impl NatCallback<u64> for Square {
106///     fn call<const N: u64>(&self) -> u64 { N * N }
107/// }
108///
109/// assert_eq!(reify_nat(12, &Square), 144);
110/// ```
111pub fn reify_nat<C: NatCallback<R>, R>(val: u64, callback: &C) -> R {
112    macro_rules! dispatch {
113        ($($n:literal),*) => {
114            match val {
115                $( $n => callback.call::<$n>(), )*
116                other => panic!(
117                    "const-reify: value {} is out of supported range 0..=255", other
118                ),
119            }
120        };
121    }
122
123    dispatch!(
124        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
125        25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
126        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
127        71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
128        94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
129        113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130,
130        131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
131        149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166,
132        167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184,
133        185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202,
134        203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220,
135        221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
136        239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255
137    )
138}
139
140/// Reify two runtime `u64` values (each 0..=255) into a const-generic context.
141///
142/// Nested dispatch: both `A` and `B` are known const generics in the callback.
143///
144/// # Panics
145///
146/// Panics if either value > 255.
147///
148/// # Examples
149///
150/// ```
151/// use const_reify::nat_reify::{Nat2Callback, reify_nat2};
152///
153/// struct Lt;
154/// impl Nat2Callback<bool> for Lt {
155///     fn call<const A: u64, const B: u64>(&self) -> bool { A < B }
156/// }
157///
158/// assert_eq!(reify_nat2(3, 5, &Lt), true);
159/// assert_eq!(reify_nat2(5, 3, &Lt), false);
160/// ```
161pub fn reify_nat2<C: Nat2Callback<R>, R>(a: u64, b: u64, callback: &C) -> R {
162    struct Outer<'a, C, R> {
163        b: u64,
164        inner: &'a C,
165        _r: std::marker::PhantomData<R>,
166    }
167
168    impl<C: Nat2Callback<R>, R> NatCallback<R> for Outer<'_, C, R> {
169        fn call<const A: u64>(&self) -> R {
170            struct Inner<'a, const A: u64, C, R> {
171                inner: &'a C,
172                _r: std::marker::PhantomData<R>,
173            }
174
175            impl<const A: u64, C: Nat2Callback<R>, R> NatCallback<R> for Inner<'_, A, C, R> {
176                fn call<const B: u64>(&self) -> R {
177                    self.inner.call::<A, B>()
178                }
179            }
180
181            reify_nat(
182                self.b,
183                &Inner::<A, C, R> {
184                    inner: self.inner,
185                    _r: std::marker::PhantomData,
186                },
187            )
188        }
189    }
190
191    reify_nat(
192        a,
193        &Outer {
194            b,
195            inner: callback,
196            _r: std::marker::PhantomData,
197        },
198    )
199}
200
// ---------------------------------------------------------------------------
// Closure-based ergonomic wrappers
// ---------------------------------------------------------------------------
204
205/// Wrapper that adapts a `Fn(u64) -> R` closure into a [`NatCallback`].
206///
207/// Inside the dispatch, the const generic `N` is passed to the closure as
208/// a plain `u64`. This loses the ability to construct types parameterized
209/// by `N`, but covers the common case where you just need the value.
210///
211/// Prefer [`reify_nat_fn`] which uses this internally.
212pub struct FnNat<F>(pub F);
213
214impl<F: Fn(u64) -> R, R> NatCallback<R> for FnNat<F> {
215    fn call<const N: u64>(&self) -> R {
216        (self.0)(N)
217    }
218}
219
220/// Wrapper that adapts a `Fn(u64, u64) -> R` closure into a [`Nat2Callback`].
221///
222/// Prefer [`reify_nat2_fn`] which uses this internally.
223pub struct FnNat2<F>(pub F);
224
225impl<F: Fn(u64, u64) -> R, R> Nat2Callback<R> for FnNat2<F> {
226    fn call<const A: u64, const B: u64>(&self) -> R {
227        (self.0)(A, B)
228    }
229}
230
231/// Ergonomic single-value reification with a closure.
232///
233/// The closure receives the reified value as a plain `u64`. For cases
234/// where you need the actual const generic (e.g., to construct `Mod<N>`),
235/// use [`reify_nat`] with a [`NatCallback`] impl instead.
236///
237/// # Examples
238///
239/// ```
240/// use const_reify::nat_reify::reify_nat_fn;
241///
242/// assert_eq!(reify_nat_fn(5, |n| n * n), 25);
243/// assert_eq!(reify_nat_fn(10, |n| n + 1), 11);
244/// assert_eq!(reify_nat_fn(0, |n| n == 0), true);
245/// ```
246pub fn reify_nat_fn<F: Fn(u64) -> R, R>(val: u64, f: F) -> R {
247    reify_nat(val, &FnNat(f))
248}
249
250/// Ergonomic two-value reification with a closure.
251///
252/// The closure receives both reified values as plain `u64`s.
253///
254/// # Examples
255///
256/// ```
257/// use const_reify::nat_reify::reify_nat2_fn;
258///
259/// assert_eq!(reify_nat2_fn(5, 3, |a, b| a + b), 8);
260/// assert_eq!(reify_nat2_fn(6, 7, |a, b| a * b), 42);
261/// assert_eq!(reify_nat2_fn(3, 5, |a, b| a < b), true);
262/// ```
263pub fn reify_nat2_fn<F: Fn(u64, u64) -> R, R>(a: u64, b: u64, f: F) -> R {
264    reify_nat2(a, b, &FnNat2(f))
265}
266
// ---------------------------------------------------------------------------
// Macros for defining callbacks and inline reification
// ---------------------------------------------------------------------------
270
/// Define a [`NatCallback`] struct with minimal boilerplate.
///
/// Two forms are accepted. Both optionally take outer attributes (doc
/// comments, `#[derive(...)]`, ...) and a visibility before the struct name.
/// The with-fields form also accepts a visibility on each field.
///
/// **Stateless**: no captured data, just a const-generic body:
/// ```
/// use const_reify::{def_nat_callback, nat_reify::reify_nat};
///
/// def_nat_callback!(Square -> u64 { N * N });
///
/// assert_eq!(reify_nat(5, &Square), 25);
/// ```
///
/// **With fields**: captures runtime data alongside the const generic. The
/// `|s|` binder names `&self` inside the body:
/// ```
/// use const_reify::{def_nat_callback, nat_reify::reify_nat};
///
/// def_nat_callback!(ModMul { a: u64, b: u64 } -> u64 {
///     |s| if N == 0 { 0 } else { (s.a % N) * (s.b % N) % N }
/// });
///
/// assert_eq!(reify_nat(7, &ModMul { a: 10, b: 20 }), 4);
/// ```
///
/// **Attributes and visibility**:
/// ```
/// use const_reify::{def_nat_callback, nat_reify::reify_nat};
///
/// def_nat_callback!(
///     /// Adds a fixed offset to `N`.
///     #[derive(Clone, Copy, Debug)]
///     pub AddOffset { pub offset: u64 } -> u64 { |s| N + s.offset }
/// );
///
/// let cb = AddOffset { offset: 5 };
/// assert_eq!(reify_nat(10, &cb), 15);
/// ```
///
/// In every form, `N` refers to the const-generic `u64` parameter.
#[macro_export]
macro_rules! def_nat_callback {
    // Stateless: def_nat_callback!([#[attr]...] [vis] Name -> RetType { body using N })
    ($(#[$meta:meta])* $vis:vis $name:ident -> $ret:ty $body:block) => {
        $(#[$meta])*
        $vis struct $name;

        impl $crate::nat_reify::NatCallback<$ret> for $name {
            // `$body` is caller-supplied code; don't lint identifier casing inside it.
            #[allow(non_snake_case)]
            fn call<const N: u64>(&self) -> $ret $body
        }
    };

    // With fields:
    // def_nat_callback!([#[attr]...] [vis] Name { [vis] field: Type, ... } -> RetType { |s| body })
    (
        $(#[$meta:meta])*
        $vis:vis $name:ident { $($fvis:vis $field:ident : $fty:ty),* $(,)? }
        -> $ret:ty { |$s:ident| $($body:tt)* }
    ) => {
        $(#[$meta])*
        $vis struct $name {
            $( $fvis $field: $fty, )*
        }

        impl $crate::nat_reify::NatCallback<$ret> for $name {
            // `$body` and the `$s` binder are caller-supplied; don't lint their casing.
            #[allow(non_snake_case)]
            fn call<const N: u64>(&self) -> $ret {
                let $s = self;
                $($body)*
            }
        }
    };
}
323
/// Define a [`Nat2Callback`] struct with minimal boilerplate.
///
/// Two forms are accepted. Both optionally take outer attributes (doc
/// comments, `#[derive(...)]`, ...) and a visibility before the struct name.
/// The with-fields form also accepts a visibility on each field.
///
/// **Stateless:**
/// ```
/// use const_reify::{def_nat2_callback, nat_reify::reify_nat2};
///
/// def_nat2_callback!(Add -> u64 { A + B });
///
/// assert_eq!(reify_nat2(5, 3, &Add), 8);
/// ```
///
/// **With fields** (the `|s|` binder names `&self` inside the body):
/// ```
/// use const_reify::{def_nat2_callback, nat_reify::reify_nat2};
///
/// def_nat2_callback!(ScaledSum { scale: u64 } -> u64 { |s| (A + B) * s.scale });
///
/// assert_eq!(reify_nat2(5, 3, &ScaledSum { scale: 10 }), 80);
/// ```
///
/// **Attributes and visibility:**
/// ```
/// use const_reify::{def_nat2_callback, nat_reify::reify_nat2};
///
/// def_nat2_callback!(
///     #[derive(Clone, Copy)]
///     pub Weighted { pub weight: u64 } -> u64 { |s| A * s.weight + B }
/// );
///
/// assert_eq!(reify_nat2(4, 1, &Weighted { weight: 3 }), 13);
/// ```
///
/// `A` and `B` refer to the two const-generic `u64` parameters.
#[macro_export]
macro_rules! def_nat2_callback {
    // Stateless: def_nat2_callback!([#[attr]...] [vis] Name -> RetType { body using A, B })
    ($(#[$meta:meta])* $vis:vis $name:ident -> $ret:ty $body:block) => {
        $(#[$meta])*
        $vis struct $name;

        impl $crate::nat_reify::Nat2Callback<$ret> for $name {
            // `$body` is caller-supplied code; don't lint identifier casing inside it.
            #[allow(non_snake_case)]
            fn call<const A: u64, const B: u64>(&self) -> $ret $body
        }
    };

    // With fields:
    // def_nat2_callback!([#[attr]...] [vis] Name { [vis] field: Type, ... } -> RetType { |s| body })
    (
        $(#[$meta:meta])*
        $vis:vis $name:ident { $($fvis:vis $field:ident : $fty:ty),* $(,)? }
        -> $ret:ty { |$s:ident| $($body:tt)* }
    ) => {
        $(#[$meta])*
        $vis struct $name {
            $( $fvis $field: $fty, )*
        }

        impl $crate::nat_reify::Nat2Callback<$ret> for $name {
            // `$body` and the `$s` binder are caller-supplied; don't lint their casing.
            #[allow(non_snake_case)]
            fn call<const A: u64, const B: u64>(&self) -> $ret {
                let $s = self;
                $($body)*
            }
        }
    };
}
372
#[cfg(test)]
mod tests {
    use super::*;

    // --- Single-value dispatch ---

    struct Identity;
    impl NatCallback<u64> for Identity {
        fn call<const N: u64>(&self) -> u64 {
            N
        }
    }

    #[test]
    fn reify_nat_identity() {
        // Every supported value must land on the arm whose const equals it.
        for v in 0..=255u64 {
            assert_eq!(reify_nat(v, &Identity), v);
        }
    }

    struct Square;
    impl NatCallback<u64> for Square {
        fn call<const N: u64>(&self) -> u64 {
            N * N
        }
    }

    #[test]
    fn reify_nat_square() {
        assert_eq!(reify_nat(0, &Square), 0);
        assert_eq!(reify_nat(5, &Square), 25);
        assert_eq!(reify_nat(12, &Square), 144);
    }

    struct IsEven;
    impl NatCallback<bool> for IsEven {
        fn call<const N: u64>(&self) -> bool {
            N % 2 == 0
        }
    }

    #[test]
    fn reify_nat_predicate() {
        assert!(reify_nat(0, &IsEven));
        assert!(!reify_nat(1, &IsEven));
        assert!(reify_nat(42, &IsEven));
    }

    #[test]
    #[should_panic(expected = "out of supported range")]
    fn reify_nat_out_of_range() {
        reify_nat(256, &Identity);
    }

    // --- Two-value tests ---

    struct Add2;
    impl Nat2Callback<u64> for Add2 {
        fn call<const A: u64, const B: u64>(&self) -> u64 {
            A + B
        }
    }

    #[test]
    fn reify_nat2_add() {
        assert_eq!(reify_nat2(5, 3, &Add2), 8);
        assert_eq!(reify_nat2(0, 0, &Add2), 0);
        assert_eq!(reify_nat2(100, 155, &Add2), 255);
    }

    #[test]
    fn reify_nat2_upper_boundary() {
        assert_eq!(reify_nat2(255, 255, &Add2), 510);
        assert!(!reify_nat2(255, 255, &Lt2));
    }

    #[test]
    #[should_panic(expected = "value 256 is out of supported range")]
    fn reify_nat2_first_out_of_range() {
        reify_nat2(256, 0, &Add2);
    }

    #[test]
    #[should_panic(expected = "value 300 is out of supported range")]
    fn reify_nat2_second_out_of_range() {
        reify_nat2(0, 300, &Add2);
    }

    #[test]
    #[should_panic(expected = "value 400 is out of supported range")]
    fn reify_nat2_reports_first_when_both_out_of_range() {
        // `a` is dispatched before `b`, so its value is the one reported.
        reify_nat2(400, 500, &Add2);
    }

    struct Mul2;
    impl Nat2Callback<u64> for Mul2 {
        fn call<const A: u64, const B: u64>(&self) -> u64 {
            A * B
        }
    }

    #[test]
    fn reify_nat2_mul() {
        assert_eq!(reify_nat2(6, 7, &Mul2), 42);
        assert_eq!(reify_nat2(0, 255, &Mul2), 0);
    }

    struct Lt2;
    impl Nat2Callback<bool> for Lt2 {
        fn call<const A: u64, const B: u64>(&self) -> bool {
            A < B
        }
    }

    #[test]
    fn reify_nat2_lt() {
        assert!(reify_nat2(3, 5, &Lt2));
        assert!(!reify_nat2(5, 3, &Lt2));
        assert!(!reify_nat2(5, 5, &Lt2));
    }

    // --- The real power: type-safe modular arithmetic with runtime modulus ---

    #[derive(Debug, Clone, Copy, PartialEq)]
    struct Mod<const M: u64> {
        value: u64,
    }

    impl<const M: u64> Mod<M> {
        fn new(v: u64) -> Self {
            Mod {
                // Modulus 0 is reachable through `reify_nat(0, ..)`; map it to 0
                // instead of dividing by zero.
                value: if M == 0 { 0 } else { v % M },
            }
        }

        fn mul(self, other: Self) -> Self {
            // Both operands are already reduced below `M`, and `M <= 255` here
            // because `reify_nat` only dispatches 0..=255, so this cannot overflow.
            Self::new(self.value * other.value)
        }

        fn pow(self, exp: u64) -> Self {
            // Square-and-multiply over the bits of `exp`.
            let mut result = Self::new(1);
            let mut base = self;
            let mut e = exp;
            while e > 0 {
                if e % 2 == 1 {
                    result = result.mul(base);
                }
                base = base.mul(base);
                e /= 2;
            }
            result
        }
    }

    struct ModPow {
        base: u64,
        exp: u64,
    }

    impl NatCallback<u64> for ModPow {
        fn call<const M: u64>(&self) -> u64 {
            // M is a const generic — we can construct Mod<M> here!
            // The type system ensures all arithmetic stays in the same modulus.
            let b = Mod::<M>::new(self.base);
            b.pow(self.exp).value
        }
    }

    #[test]
    fn modular_exponentiation_with_runtime_modulus() {
        // 3^5 mod 7 = 243 mod 7 = 5
        assert_eq!(reify_nat(7, &ModPow { base: 3, exp: 5 }), 243 % 7);

        // Fermat's little theorem: a^(p-1) ≡ 1 (mod p) for prime p, gcd(a,p)=1
        assert_eq!(reify_nat(7, &ModPow { base: 3, exp: 6 }), 1);
        assert_eq!(reify_nat(11, &ModPow { base: 2, exp: 10 }), 1);
        assert_eq!(reify_nat(13, &ModPow { base: 5, exp: 12 }), 1);
    }

    // --- Closure-based ergonomic API ---

    #[test]
    fn reify_nat_fn_basic() {
        assert_eq!(reify_nat_fn(5, |n| n * n), 25);
        assert!(reify_nat_fn(0, |n| n == 0));
        assert_eq!(reify_nat_fn(255, |n| n), 255);
    }

    #[test]
    fn reify_nat_fn_captures_environment() {
        let offset = 100u64;
        assert_eq!(reify_nat_fn(5, |n| n + offset), 105);
    }

    #[test]
    #[should_panic(expected = "value 256 is out of supported range")]
    fn reify_nat_fn_out_of_range() {
        reify_nat_fn(256, |n| n);
    }

    #[test]
    fn reify_nat2_fn_basic() {
        assert_eq!(reify_nat2_fn(5, 3, |a, b| a + b), 8);
        assert_eq!(reify_nat2_fn(6, 7, |a, b| a * b), 42);
        assert!(reify_nat2_fn(3, 5, |a, b| a < b));
    }

    #[test]
    fn reify_nat2_fn_captures_environment() {
        let scale = 10u64;
        assert_eq!(reify_nat2_fn(5, 3, |a, b| (a + b) * scale), 80);
    }

    // --- Macro-defined callbacks ---

    def_nat_callback!(Cube -> u64 { N * N * N });

    #[test]
    fn macro_stateless_callback() {
        assert_eq!(reify_nat(3, &Cube), 27);
        assert_eq!(reify_nat(5, &Cube), 125);
    }

    def_nat_callback!(AddOffset { offset: u64 } -> u64 { |s| N + s.offset });

    #[test]
    fn macro_callback_with_fields() {
        assert_eq!(reify_nat(10, &AddOffset { offset: 5 }), 15);
    }

    def_nat2_callback!(Hypotenuse2 -> u64 { A * A + B * B });

    #[test]
    fn macro_nat2_stateless() {
        assert_eq!(reify_nat2(3, 4, &Hypotenuse2), 25); // 3² + 4² = 25
    }

    def_nat2_callback!(ScaledDiff { scale: u64 } -> u64 {
        |s| if A > B { (A - B) * s.scale } else { (B - A) * s.scale }
    });

    #[test]
    fn macro_nat2_with_fields() {
        assert_eq!(reify_nat2(10, 3, &ScaledDiff { scale: 5 }), 35);
    }
}