Skip to main content

karpal_proof/
rewrite.rs

1// Copyright (C) 2026 Industrial Algebra
2// SPDX-License-Identifier: Apache-2.0
3
4use core::marker::PhantomData;
5
/// Witness that justification `Via` validates rewriting `Lhs` to `Rhs`.
///
/// The trait has no methods: it is implemented on the justification type
/// itself, and the existence of an impl is the whole of the evidence.
/// `Rewrite::witness` is gated on `Via: Justifies<Lhs, Rhs>`.
///
/// The relation is directional. An impl of `Justifies<A, B>` says nothing
/// about `Justifies<B, A>`; rules that hold in both directions provide
/// both impls.
pub trait Justifies<Lhs, Rhs> {}
11
/// Witness that type-level expression `Lhs` equals `Rhs`,
/// justified by rule `Via`.
///
/// `Rewrite` is a zero-sized type — it carries no data, only
/// type-level evidence that an algebraic identity has been invoked.
///
/// The marker field is `PhantomData<fn() -> (Lhs, Rhs, Via)>`: a witness
/// never owns values of its parameters, so it is `Send`, `Sync` and `Unpin`
/// whatever `Lhs`, `Rhs` and `Via` are.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand so that they place no
/// bounds on the type parameters; a `#[derive]` would demand
/// `Lhs: Clone` and so on, which purely type-level markers need not satisfy.
/// `Default` is intentionally not implemented: it would let callers obtain a
/// witness without any `Justifies` evidence.
///
/// # Example
///
/// ```
/// use karpal_proof::rewrite::*;
///
/// // Create a rewrite justified by associativity
/// let step: Rewrite<AssocLeft, AssocRight, ByAssociativity> =
///     Rewrite::witness();
///
/// // Reverse it
/// let _back: Rewrite<AssocRight, AssocLeft, BySymmetry<ByAssociativity>> =
///     step.sym();
/// ```
pub struct Rewrite<Lhs, Rhs, Via> {
    _phantom: PhantomData<fn() -> (Lhs, Rhs, Via)>,
}

impl<Lhs, Rhs, Via> Clone for Rewrite<Lhs, Rhs, Via> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Lhs, Rhs, Via> Copy for Rewrite<Lhs, Rhs, Via> {}

impl<Lhs, Rhs, Via> core::fmt::Debug for Rewrite<Lhs, Rhs, Via> {
    /// Prints the witness as `Rewrite<Lhs, Rhs, Via>` using the compiler's
    /// names for the three parameters (diagnostic output only).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "Rewrite<{}, {}, {}>",
            core::any::type_name::<Lhs>(),
            core::any::type_name::<Rhs>(),
            core::any::type_name::<Via>(),
        )
    }
}
34
35impl<Lhs, Rhs, Via> Rewrite<Lhs, Rhs, Via> {
36    /// Construct a rewrite witness.
37    ///
38    /// Only compiles if `Via: Justifies<Lhs, Rhs>`.
39    pub fn witness() -> Self
40    where
41        Via: Justifies<Lhs, Rhs>,
42    {
43        Rewrite {
44            _phantom: PhantomData,
45        }
46    }
47
48    /// Reverse: if `Lhs = Rhs` then `Rhs = Lhs`.
49    pub fn sym(self) -> Rewrite<Rhs, Lhs, BySymmetry<Via>> {
50        Rewrite {
51            _phantom: PhantomData,
52        }
53    }
54
55    /// Chain: if `Lhs = Mid` (via self) and `Mid = Rhs2` (via next),
56    /// then `Lhs = Rhs2`.
57    ///
58    /// The `Rhs` of the first rewrite must be the same type as the
59    /// `Lhs` (i.e., `Mid`) of the second. This is enforced by sharing
60    /// the type parameter `Rhs`/`Mid`.
61    pub fn then<Rhs2, V2>(
62        self,
63        _next: Rewrite<Rhs, Rhs2, V2>,
64    ) -> Rewrite<Lhs, Rhs2, ByTransitivity<Via, V2, Rhs>> {
65        Rewrite {
66            _phantom: PhantomData,
67        }
68    }
69}
70
71// ---------------------------------------------------------------------------
72// Justification types (all ZSTs)
73// ---------------------------------------------------------------------------
74
/// Justified by associativity: `a ∘ (b ∘ c) = (a ∘ b) ∘ c`.
///
/// Unit struct used as the `Via` parameter of a rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ByAssociativity;

/// Justified by commutativity: `a ∘ b = b ∘ a`.
///
/// Unit struct used as the `Via` parameter of a rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ByCommutativity;

/// Justified by identity law: `a ∘ e = a` or `e ∘ a = a`.
///
/// Unit struct used as the `Via` parameter of a rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ByIdentity;

/// Justified by inverse law: `a ∘ a⁻¹ = e`.
///
/// Unit struct used as the `Via` parameter of a rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ByInverse;

/// Justified by distribution: `a * (b + c) = a*b + a*c`.
///
/// Unit struct used as the `Via` parameter of a rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ByDistribution;

/// Justified by zero annihilation: `0 * a = 0`.
///
/// Unit struct used as the `Via` parameter of a rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ByAnnihilation;
92
/// Chain two justifications: if `V1` justifies `Lhs = Mid` and `V2`
/// justifies `Mid = Rhs`, then `ByTransitivity<V1, V2, Mid>` justifies
/// `Lhs = Rhs`.
///
/// `Mid` is carried as a type parameter so that the generic `Justifies`
/// impl for this type can name the intermediate expression: Rust rejects an
/// impl type parameter that appears neither in the implemented trait nor in
/// the self type. `Rewrite::then` fills it with the expression shared by
/// the two chained steps. It defaults to `()` when omitted.
pub struct ByTransitivity<V1, V2, Mid = ()>(PhantomData<(V1, V2, Mid)>);
96
/// Reverse a justification: if `V` justifies `Lhs = Rhs`, then
/// `BySymmetry<V>` justifies `Rhs = Lhs`.
///
/// This is the justification attached to the result of `Rewrite::sym`.
/// The private `PhantomData` field only records `V` at the type level.
pub struct BySymmetry<V>(PhantomData<V>);
99
100// ---------------------------------------------------------------------------
101// Expression marker types for common algebraic patterns
102// ---------------------------------------------------------------------------
103
/// `(a ∘ b) ∘ c` — left-associated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct AssocLeft;
/// `a ∘ (b ∘ c)` — right-associated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct AssocRight;

/// `a ∘ b`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct CombineAB;
/// `b ∘ a`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct CombineBA;

/// `a ∘ e` or `e ∘ a`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WithIdentity;
/// Just `a`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct JustA;

/// `a ∘ a⁻¹`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WithInverse;
/// The identity element `e`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Identity;

/// `a * (b + c)` — undistributed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Undistributed;
/// `a*b + a*c` — distributed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Distributed;

/// `0 * a`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ZeroTimes;
/// `0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Zero;
133
134// ---------------------------------------------------------------------------
135// Justifies implementations
136// ---------------------------------------------------------------------------
137
138// Associativity: (a∘b)∘c = a∘(b∘c) and reverse
139impl Justifies<AssocLeft, AssocRight> for ByAssociativity {}
140impl Justifies<AssocRight, AssocLeft> for ByAssociativity {}
141
142// Commutativity: a∘b = b∘a and reverse
143impl Justifies<CombineAB, CombineBA> for ByCommutativity {}
144impl Justifies<CombineBA, CombineAB> for ByCommutativity {}
145
146// Identity: a∘e = a and reverse
147impl Justifies<WithIdentity, JustA> for ByIdentity {}
148impl Justifies<JustA, WithIdentity> for ByIdentity {}
149
150// Inverse: a∘a⁻¹ = e and reverse
151impl Justifies<WithInverse, Identity> for ByInverse {}
152impl Justifies<Identity, WithInverse> for ByInverse {}
153
154// Distribution: a*(b+c) = a*b + a*c and reverse
155impl Justifies<Undistributed, Distributed> for ByDistribution {}
156impl Justifies<Distributed, Undistributed> for ByDistribution {}
157
158// Zero annihilation: 0*a = 0 and reverse
159impl Justifies<ZeroTimes, Zero> for ByAnnihilation {}
160impl Justifies<Zero, ZeroTimes> for ByAnnihilation {}
161
162// Symmetry: if V justifies Lhs→Rhs, BySymmetry<V> justifies Rhs→Lhs
163impl<V, Lhs, Rhs> Justifies<Rhs, Lhs> for BySymmetry<V> where V: Justifies<Lhs, Rhs> {}
164
165// Transitivity: if V1 justifies Lhs→Mid and V2 justifies Mid→Rhs,
166// ByTransitivity<V1, V2, Mid> justifies Lhs→Rhs
167impl<V1, V2, Lhs, Mid, Rhs> Justifies<Lhs, Rhs> for ByTransitivity<V1, V2, Mid>
168where
169    V1: Justifies<Lhs, Mid>,
170    V2: Justifies<Mid, Rhs>,
171{
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    // These tests pass by compiling: every `witness()` call and every type
    // annotation is a claim the type checker has to accept.

    #[test]
    fn associativity_rewrite() {
        let _ = Rewrite::<AssocLeft, AssocRight, ByAssociativity>::witness();
        let _ = Rewrite::<AssocRight, AssocLeft, ByAssociativity>::witness();
    }

    #[test]
    fn commutativity_rewrite() {
        let _ = Rewrite::<CombineAB, CombineBA, ByCommutativity>::witness();
    }

    #[test]
    fn identity_rewrite() {
        let _ = Rewrite::<WithIdentity, JustA, ByIdentity>::witness();
        let _ = Rewrite::<JustA, WithIdentity, ByIdentity>::witness();
    }

    #[test]
    fn inverse_rewrite() {
        let _ = Rewrite::<WithInverse, Identity, ByInverse>::witness();
    }

    #[test]
    fn distribution_rewrite() {
        let _ = Rewrite::<Undistributed, Distributed, ByDistribution>::witness();
    }

    #[test]
    fn annihilation_rewrite() {
        let _ = Rewrite::<ZeroTimes, Zero, ByAnnihilation>::witness();
    }

    #[test]
    fn symmetry() {
        let forward = Rewrite::<AssocLeft, AssocRight, ByAssociativity>::witness();
        let _backward: Rewrite<AssocRight, AssocLeft, BySymmetry<ByAssociativity>> =
            forward.sym();
    }

    #[test]
    fn transitivity_chain() {
        // Round trip WithIdentity → JustA → WithIdentity. The two steps meet
        // at `JustA`, which shows up as the third `ByTransitivity` parameter.
        type RoundTrip =
            Rewrite<WithIdentity, WithIdentity, ByTransitivity<ByIdentity, ByIdentity, JustA>>;

        let drop_identity = Rewrite::<WithIdentity, JustA, ByIdentity>::witness();
        let add_identity = Rewrite::<JustA, WithIdentity, ByIdentity>::witness();
        let _chained: RoundTrip = drop_identity.then(add_identity);
    }

    #[test]
    fn three_step_chain() {
        // WithInverse → Identity → WithInverse → Identity. `then` nests to
        // the left, so the first two steps form the inner justification.
        type FirstTwo = ByTransitivity<ByInverse, ByInverse, Identity>;
        type AllThree = ByTransitivity<FirstTwo, ByInverse, WithInverse>;

        let cancel = Rewrite::<WithInverse, Identity, ByInverse>::witness();
        let expand = Rewrite::<Identity, WithInverse, ByInverse>::witness();
        let cancel_again = Rewrite::<WithInverse, Identity, ByInverse>::witness();
        let _: Rewrite<WithInverse, Identity, AllThree> =
            cancel.then(expand).then(cancel_again);
    }
}
243}