Skip to main content

karpal_proof/
rewrite.rs

1use core::marker::PhantomData;
2
/// Marker trait: the implementing justification type validates rewriting the
/// type-level expression `Lhs` into `Rhs`.
///
/// The trait has no methods. Every impl is accepted as an axiom — nothing
/// checks that it corresponds to a real algebraic identity — so the set of
/// impls is the trusted base of the rewrite system. It is the bound that
/// [`Rewrite::witness`] requires of its `Via` parameter.
pub trait Justifies<Lhs, Rhs> {}
8
9/// Witness that type-level expression `Lhs` equals `Rhs`,
10/// justified by rule `Via`.
11///
12/// `Rewrite` is a zero-sized type — it carries no data, only
13/// type-level evidence that an algebraic identity has been invoked.
14///
15/// # Example
16///
17/// ```
18/// use karpal_proof::rewrite::*;
19///
20/// // Create a rewrite justified by associativity
21/// let step: Rewrite<AssocLeft, AssocRight, ByAssociativity> =
22///     Rewrite::witness();
23///
24/// // Reverse it
25/// let _back: Rewrite<AssocRight, AssocLeft, BySymmetry<ByAssociativity>> =
26///     step.sym();
27/// ```
28pub struct Rewrite<Lhs, Rhs, Via> {
29    _phantom: PhantomData<(Lhs, Rhs, Via)>,
30}
31
32impl<Lhs, Rhs, Via> Rewrite<Lhs, Rhs, Via> {
33    /// Construct a rewrite witness.
34    ///
35    /// Only compiles if `Via: Justifies<Lhs, Rhs>`.
36    pub fn witness() -> Self
37    where
38        Via: Justifies<Lhs, Rhs>,
39    {
40        Rewrite {
41            _phantom: PhantomData,
42        }
43    }
44
45    /// Reverse: if `Lhs = Rhs` then `Rhs = Lhs`.
46    pub fn sym(self) -> Rewrite<Rhs, Lhs, BySymmetry<Via>> {
47        Rewrite {
48            _phantom: PhantomData,
49        }
50    }
51
52    /// Chain: if `Lhs = Mid` (via self) and `Mid = Rhs2` (via next),
53    /// then `Lhs = Rhs2`.
54    ///
55    /// The `Rhs` of the first rewrite must be the same type as the
56    /// `Lhs` (i.e., `Mid`) of the second. This is enforced by sharing
57    /// the type parameter `Rhs`/`Mid`.
58    pub fn then<Rhs2, V2>(
59        self,
60        _next: Rewrite<Rhs, Rhs2, V2>,
61    ) -> Rewrite<Lhs, Rhs2, ByTransitivity<Via, V2, Rhs>> {
62        Rewrite {
63            _phantom: PhantomData,
64        }
65    }
66}
67
68// ---------------------------------------------------------------------------
69// Justification types (all ZSTs)
70// ---------------------------------------------------------------------------
71
/// Rule: associativity, `(a ∘ b) ∘ c = a ∘ (b ∘ c)`.
pub struct ByAssociativity;

/// Rule: commutativity, `b ∘ a = a ∘ b`.
pub struct ByCommutativity;

/// Rule: identity law, `a ∘ e = a`, or on the other side `e ∘ a = a`.
pub struct ByIdentity;

/// Rule: inverse law, `a ∘ a⁻¹ = e`.
pub struct ByInverse;

/// Rule: `*` distributes over `+`, `a * (b + c) = a*b + a*c`.
pub struct ByDistribution;

/// Rule: zero annihilates under `*`, `0 * a = 0`.
pub struct ByAnnihilation;
89
/// Chains two justifications: if `Lhs = Mid` via `V1` and `Mid = Rhs` via
/// `V2`, then `ByTransitivity<V1, V2, Mid>` justifies `Lhs = Rhs`.
///
/// `Mid` must be part of the type itself. The `Justifies<Lhs, Rhs>` impl for
/// this type would otherwise mention `Mid` only in its `where` clause. Rust
/// rejects impl type parameters that are not constrained by the self type or
/// the trait's type arguments (error E0207).
///
/// `Mid` defaults to `()`. A two-parameter `ByTransitivity<V1, V2>` therefore
/// only justifies chains whose midpoint is `()`; spell out the real midpoint
/// for any other chain.
pub struct ByTransitivity<V1, V2, Mid = ()>(PhantomData<(V1, V2, Mid)>);

/// Reverses a justification: if `Lhs = Rhs` via `V`, then `BySymmetry<V>`
/// justifies `Rhs = Lhs`.
pub struct BySymmetry<V>(PhantomData<V>);
96
97// ---------------------------------------------------------------------------
98// Expression marker types for common algebraic patterns
99// ---------------------------------------------------------------------------
100
/// Marker for the left-associated shape `(a ∘ b) ∘ c`.
pub struct AssocLeft;
/// Marker for the right-associated shape `a ∘ (b ∘ c)`.
pub struct AssocRight;

/// Marker for `a ∘ b`, operands in their original order.
pub struct CombineAB;
/// Marker for `b ∘ a`, operands swapped.
pub struct CombineBA;

/// Marker for an element combined with the identity: `a ∘ e` or `e ∘ a`.
pub struct WithIdentity;
/// Marker for the bare element `a`.
pub struct JustA;

/// Marker for an element combined with its inverse: `a ∘ a⁻¹`.
pub struct WithInverse;
/// Marker for the identity element `e`.
pub struct Identity;

/// Marker for the factored form `a * (b + c)`.
pub struct Undistributed;
/// Marker for the expanded form `a*b + a*c`.
pub struct Distributed;

/// Marker for the product `0 * a`.
pub struct ZeroTimes;
/// Marker for the zero element `0`.
pub struct Zero;
130
131// ---------------------------------------------------------------------------
132// Justifies implementations
133// ---------------------------------------------------------------------------
134
135// Associativity: (a∘b)∘c = a∘(b∘c) and reverse
136impl Justifies<AssocLeft, AssocRight> for ByAssociativity {}
137impl Justifies<AssocRight, AssocLeft> for ByAssociativity {}
138
139// Commutativity: a∘b = b∘a and reverse
140impl Justifies<CombineAB, CombineBA> for ByCommutativity {}
141impl Justifies<CombineBA, CombineAB> for ByCommutativity {}
142
143// Identity: a∘e = a and reverse
144impl Justifies<WithIdentity, JustA> for ByIdentity {}
145impl Justifies<JustA, WithIdentity> for ByIdentity {}
146
147// Inverse: a∘a⁻¹ = e and reverse
148impl Justifies<WithInverse, Identity> for ByInverse {}
149impl Justifies<Identity, WithInverse> for ByInverse {}
150
151// Distribution: a*(b+c) = a*b + a*c and reverse
152impl Justifies<Undistributed, Distributed> for ByDistribution {}
153impl Justifies<Distributed, Undistributed> for ByDistribution {}
154
155// Zero annihilation: 0*a = 0 and reverse
156impl Justifies<ZeroTimes, Zero> for ByAnnihilation {}
157impl Justifies<Zero, ZeroTimes> for ByAnnihilation {}
158
159// Symmetry: if V justifies Lhs→Rhs, BySymmetry<V> justifies Rhs→Lhs
160impl<V, Lhs, Rhs> Justifies<Rhs, Lhs> for BySymmetry<V> where V: Justifies<Lhs, Rhs> {}
161
162// Transitivity: if V1 justifies Lhs→Mid and V2 justifies Mid→Rhs,
163// ByTransitivity<V1, V2, Mid> justifies Lhs→Rhs
164impl<V1, V2, Lhs, Mid, Rhs> Justifies<Lhs, Rhs> for ByTransitivity<V1, V2, Mid>
165where
166    V1: Justifies<Lhs, Mid>,
167    V2: Justifies<Mid, Rhs>,
168{
169}
170
#[cfg(test)]
mod tests {
    //! Compile-time checks: each test passes as soon as it type-checks,
    //! because a `Rewrite` can only be built when the justification holds.

    use super::*;

    #[test]
    fn associativity_rewrite() {
        let _: Rewrite<AssocLeft, AssocRight, ByAssociativity> = Rewrite::witness();
        let _: Rewrite<AssocRight, AssocLeft, ByAssociativity> = Rewrite::witness();
    }

    #[test]
    fn commutativity_rewrite() {
        let _: Rewrite<CombineAB, CombineBA, ByCommutativity> = Rewrite::witness();
        let _: Rewrite<CombineBA, CombineAB, ByCommutativity> = Rewrite::witness();
    }

    #[test]
    fn identity_rewrite() {
        let _: Rewrite<WithIdentity, JustA, ByIdentity> = Rewrite::witness();
        let _: Rewrite<JustA, WithIdentity, ByIdentity> = Rewrite::witness();
    }

    #[test]
    fn inverse_rewrite() {
        let _: Rewrite<WithInverse, Identity, ByInverse> = Rewrite::witness();
        let _: Rewrite<Identity, WithInverse, ByInverse> = Rewrite::witness();
    }

    #[test]
    fn distribution_rewrite() {
        let _: Rewrite<Undistributed, Distributed, ByDistribution> = Rewrite::witness();
        let _: Rewrite<Distributed, Undistributed, ByDistribution> = Rewrite::witness();
    }

    #[test]
    fn annihilation_rewrite() {
        let _: Rewrite<ZeroTimes, Zero, ByAnnihilation> = Rewrite::witness();
        let _: Rewrite<Zero, ZeroTimes, ByAnnihilation> = Rewrite::witness();
    }

    #[test]
    fn symmetry() {
        let step: Rewrite<AssocLeft, AssocRight, ByAssociativity> = Rewrite::witness();
        let _back: Rewrite<AssocRight, AssocLeft, BySymmetry<ByAssociativity>> = step.sym();
    }

    #[test]
    fn composite_justifications_witness_directly() {
        // Exercises the blanket `BySymmetry` / `ByTransitivity` impls through
        // `witness()` itself, not only through `sym()` / `then()`.
        let _: Rewrite<AssocRight, AssocLeft, BySymmetry<ByAssociativity>> = Rewrite::witness();
        let _: Rewrite<
            WithIdentity,
            WithIdentity,
            ByTransitivity<ByIdentity, ByIdentity, JustA>,
        > = Rewrite::witness();
    }

    #[test]
    fn transitivity_chain() {
        // The expression markers are fixed shapes, so an associativity step
        // cannot feed a commutativity step (`AssocRight` is not `CombineAB`).
        // Chain the identity law with its reverse instead:
        // WithIdentity → JustA → WithIdentity.
        let step1: Rewrite<WithIdentity, JustA, ByIdentity> = Rewrite::witness();
        let step2: Rewrite<JustA, WithIdentity, ByIdentity> = Rewrite::witness();
        let _chained: Rewrite<
            WithIdentity,
            WithIdentity,
            ByTransitivity<ByIdentity, ByIdentity, JustA>,
        > = step1.then(step2);
    }

    #[test]
    fn three_step_chain() {
        // WithInverse → Identity → WithInverse → Identity
        let s1: Rewrite<WithInverse, Identity, ByInverse> = Rewrite::witness();
        let s2: Rewrite<Identity, WithInverse, ByInverse> = Rewrite::witness();
        let s3: Rewrite<WithInverse, Identity, ByInverse> = Rewrite::witness();
        let _: Rewrite<
            WithInverse,
            Identity,
            ByTransitivity<ByTransitivity<ByInverse, ByInverse, Identity>, ByInverse, WithInverse>,
        > = s1.then(s2).then(s3);
    }

    #[test]
    fn witnesses_and_justifications_are_zero_sized() {
        assert_eq!(
            core::mem::size_of::<Rewrite<AssocLeft, AssocRight, ByAssociativity>>(),
            0
        );
        assert_eq!(
            core::mem::size_of::<ByTransitivity<ByIdentity, ByIdentity, JustA>>(),
            0
        );
        assert_eq!(core::mem::size_of::<BySymmetry<ByInverse>>(), 0);
    }
}