Skip to main content

karpal_proof_derive/
lib.rs

1// Copyright (C) 2026 Industrial Algebra
2// SPDX-License-Identifier: Apache-2.0
3
4//! Derive macros for automatic algebraic law verification.
5//!
6//! These macros generate `#[cfg(test)]` modules containing proptest-based
7//! property tests that verify algebraic laws for user-defined types.
8//!
9//! # Usage
10//!
11//! ```ignore
12//! use karpal_proof_derive::VerifySemigroup;
13//!
14//! #[derive(Clone, Debug, PartialEq, VerifySemigroup)]
15//! #[verify(strategy = "0u32..100")]
16//! struct MyWrapper(u32);
17//! ```
18//!
19//! # Attributes
20//!
21//! - `#[verify(strategy = "...")]` — **Required**. A proptest strategy expression
22//!   that generates values of the annotated type. Examples:
23//!   - `"0u32..100"` for numeric ranges
24//!   - `"any::<MyType>()"` for types implementing `Arbitrary`
25//!   - `"(0i32..50).prop_map(MyWrapper)"` for newtype wrappers
26//!
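//! - `#[verify(epsilon = "1e-10")]` — Optional. Use approximate floating-point
//!   comparison instead of exact `PartialEq`. The generated tests assert
//!   `(left - right).abs() < epsilon` instead of using `prop_assert_eq!`, so the
//!   type must support `-` and an `abs()` method whose result can be compared
//!   with the epsilon expression (e.g. `f32`/`f64`). `VerifyLattice` ignores this
//!   attribute and always compares exactly.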
30
31use proc_macro::TokenStream;
32use quote::{format_ident, quote};
33use syn::{DeriveInput, Expr, parse_macro_input};
34
35/// Extract a string-valued attribute from `#[verify(key = "...")]`.
36fn extract_verify_attr(input: &DeriveInput, key: &str) -> Option<proc_macro2::TokenStream> {
37    for attr in &input.attrs {
38        if !attr.path().is_ident("verify") {
39            continue;
40        }
41        let Ok(nested) = attr.parse_args_with(
42            syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
43        ) else {
44            continue;
45        };
46        for meta in &nested {
47            if let syn::Meta::NameValue(nv) = meta
48                && nv.path.is_ident(key)
49                && let Expr::Lit(syn::ExprLit {
50                    lit: syn::Lit::Str(s),
51                    ..
52                }) = &nv.value
53            {
54                let tokens: proc_macro2::TokenStream = s
55                    .value()
56                    .parse()
57                    .unwrap_or_else(|_| panic!("{key} must be a valid Rust expression"));
58                return Some(tokens);
59            }
60        }
61    }
62    None
63}
64
65/// Extract the `strategy` string from `#[verify(strategy = "...")]`.
66fn extract_strategy(input: &DeriveInput) -> proc_macro2::TokenStream {
67    extract_verify_attr(input, "strategy")
68        .expect("#[verify(strategy = \"...\")] attribute is required")
69}
70
71/// Extract the optional `epsilon` string from `#[verify(epsilon = "...")]`.
72fn extract_epsilon(input: &DeriveInput) -> Option<proc_macro2::TokenStream> {
73    extract_verify_attr(input, "epsilon")
74}
75
76/// Generate an assertion: either exact or approximate.
77fn make_assert(
78    epsilon: &Option<proc_macro2::TokenStream>,
79    left: proc_macro2::TokenStream,
80    right: proc_macro2::TokenStream,
81) -> proc_macro2::TokenStream {
82    if let Some(eps) = epsilon {
83        // For types that support subtraction and abs (floats)
84        quote! {
85            let __left = #left;
86            let __right = #right;
87            prop_assert!(
88                (__left - __right).abs() < #eps,
89                "expected {:?} ≈ {:?} (epsilon {})", __left, __right, #eps
90            );
91        }
92    } else {
93        quote! {
94            prop_assert_eq!(#left, #right);
95        }
96    }
97}
98
99/// Derive `VerifySemigroup`: generates proptest for associativity of `combine`.
100///
101/// Requires: `T: Semigroup + Clone + Debug + PartialEq`
102#[proc_macro_derive(VerifySemigroup, attributes(verify))]
103pub fn derive_verify_semigroup(input: TokenStream) -> TokenStream {
104    let input = parse_macro_input!(input as DeriveInput);
105    let name = &input.ident;
106    let strategy = extract_strategy(&input);
107    let epsilon = extract_epsilon(&input);
108    let mod_name = format_ident!("verify_semigroup_{}", name.to_string().to_lowercase());
109
110    let assert_assoc = make_assert(
111        &epsilon,
112        quote! { a.clone().combine(b.clone()).combine(c.clone()) },
113        quote! { a.combine(b.combine(c)) },
114    );
115
116    let output = quote! {
117        #[cfg(test)]
118        mod #mod_name {
119            use super::*;
120            use karpal_core::Semigroup;
121            use proptest::prelude::*;
122
123            proptest! {
124                #[test]
125                fn associativity(a in #strategy, b in #strategy, c in #strategy) {
126                    #assert_assoc
127                }
128            }
129        }
130    };
131    output.into()
132}
133
134/// Derive `VerifyMonoid`: generates proptest for left/right identity of `empty`.
135///
136/// Requires: `T: Monoid + Clone + Debug + PartialEq`
137#[proc_macro_derive(VerifyMonoid, attributes(verify))]
138pub fn derive_verify_monoid(input: TokenStream) -> TokenStream {
139    let input = parse_macro_input!(input as DeriveInput);
140    let name = &input.ident;
141    let strategy = extract_strategy(&input);
142    let epsilon = extract_epsilon(&input);
143    let mod_name = format_ident!("verify_monoid_{}", name.to_string().to_lowercase());
144
145    let assert_left = make_assert(
146        &epsilon,
147        quote! { <#name as Monoid>::empty().combine(a.clone()) },
148        quote! { a },
149    );
150
151    let assert_right = make_assert(
152        &epsilon,
153        quote! { a.clone().combine(<#name as Monoid>::empty()) },
154        quote! { a },
155    );
156
157    let output = quote! {
158        #[cfg(test)]
159        mod #mod_name {
160            use super::*;
161            use karpal_core::{Semigroup, Monoid};
162            use proptest::prelude::*;
163
164            proptest! {
165                #[test]
166                fn left_identity(a in #strategy) {
167                    #assert_left
168                }
169
170                #[test]
171                fn right_identity(a in #strategy) {
172                    #assert_right
173                }
174            }
175        }
176    };
177    output.into()
178}
179
180/// Derive `VerifyGroup`: generates proptest for left/right inverse.
181///
182/// Requires: `T: Group + Clone + Debug + PartialEq`
183#[proc_macro_derive(VerifyGroup, attributes(verify))]
184pub fn derive_verify_group(input: TokenStream) -> TokenStream {
185    let input = parse_macro_input!(input as DeriveInput);
186    let name = &input.ident;
187    let strategy = extract_strategy(&input);
188    let epsilon = extract_epsilon(&input);
189    let mod_name = format_ident!("verify_group_{}", name.to_string().to_lowercase());
190
191    let assert_left = make_assert(
192        &epsilon,
193        quote! { a.clone().invert().combine(a.clone()) },
194        quote! { <#name as Monoid>::empty() },
195    );
196
197    let assert_right = make_assert(
198        &epsilon,
199        quote! { a.clone().combine(a.clone().invert()) },
200        quote! { <#name as Monoid>::empty() },
201    );
202
203    let output = quote! {
204        #[cfg(test)]
205        mod #mod_name {
206            use super::*;
207            use karpal_core::{Semigroup, Monoid};
208            use karpal_algebra::Group;
209            use proptest::prelude::*;
210
211            proptest! {
212                #[test]
213                fn left_inverse(a in #strategy) {
214                    #assert_left
215                }
216
217                #[test]
218                fn right_inverse(a in #strategy) {
219                    #assert_right
220                }
221            }
222        }
223    };
224    output.into()
225}
226
227/// Derive `VerifyCommutative`: generates proptest for commutativity of `combine`.
228///
229/// Requires: `T: Semigroup + Clone + Debug + PartialEq`
230#[proc_macro_derive(VerifyCommutative, attributes(verify))]
231pub fn derive_verify_commutative(input: TokenStream) -> TokenStream {
232    let input = parse_macro_input!(input as DeriveInput);
233    let name = &input.ident;
234    let strategy = extract_strategy(&input);
235    let epsilon = extract_epsilon(&input);
236    let mod_name = format_ident!("verify_commutative_{}", name.to_string().to_lowercase());
237
238    let assert_comm = make_assert(
239        &epsilon,
240        quote! { a.clone().combine(b.clone()) },
241        quote! { b.combine(a) },
242    );
243
244    let output = quote! {
245        #[cfg(test)]
246        mod #mod_name {
247            use super::*;
248            use karpal_core::Semigroup;
249            use proptest::prelude::*;
250
251            proptest! {
252                #[test]
253                fn commutativity(a in #strategy, b in #strategy) {
254                    #assert_comm
255                }
256            }
257        }
258    };
259    output.into()
260}
261
262/// Derive `VerifySemiring`: generates proptests for additive monoid,
263/// multiplicative monoid, distributivity, and zero annihilation.
264///
265/// Requires: `T: Semiring + Clone + Debug + PartialEq`
266#[proc_macro_derive(VerifySemiring, attributes(verify))]
267pub fn derive_verify_semiring(input: TokenStream) -> TokenStream {
268    let input = parse_macro_input!(input as DeriveInput);
269    let name = &input.ident;
270    let strategy = extract_strategy(&input);
271    let epsilon = extract_epsilon(&input);
272    let mod_name = format_ident!("verify_semiring_{}", name.to_string().to_lowercase());
273
274    let assert_add_assoc = make_assert(
275        &epsilon,
276        quote! { a.clone().add(b.clone()).add(c.clone()) },
277        quote! { a.clone().add(b.clone().add(c.clone())) },
278    );
279
280    let assert_add_comm = make_assert(
281        &epsilon,
282        quote! { a.clone().add(b.clone()) },
283        quote! { b.clone().add(a.clone()) },
284    );
285
286    let assert_add_left_id = make_assert(
287        &epsilon,
288        quote! { <#name as Semiring>::zero().add(a.clone()) },
289        quote! { a.clone() },
290    );
291
292    let assert_add_right_id = make_assert(
293        &epsilon,
294        quote! { a.clone().add(<#name as Semiring>::zero()) },
295        quote! { a.clone() },
296    );
297
298    let assert_mul_assoc = make_assert(
299        &epsilon,
300        quote! { a.clone().mul(b.clone()).mul(c.clone()) },
301        quote! { a.clone().mul(b.clone().mul(c.clone())) },
302    );
303
304    let assert_mul_left_id = make_assert(
305        &epsilon,
306        quote! { <#name as Semiring>::one().mul(a.clone()) },
307        quote! { a.clone() },
308    );
309
310    let assert_mul_right_id = make_assert(
311        &epsilon,
312        quote! { a.clone().mul(<#name as Semiring>::one()) },
313        quote! { a.clone() },
314    );
315
316    let assert_left_dist = make_assert(
317        &epsilon,
318        quote! { a.clone().mul(b.clone().add(c.clone())) },
319        quote! { a.clone().mul(b.clone()).add(a.clone().mul(c.clone())) },
320    );
321
322    let assert_right_dist = make_assert(
323        &epsilon,
324        quote! { a.clone().add(b.clone()).mul(c.clone()) },
325        quote! { a.clone().mul(c.clone()).add(b.clone().mul(c.clone())) },
326    );
327
328    let assert_zero_left = make_assert(
329        &epsilon,
330        quote! { <#name as Semiring>::zero().mul(a.clone()) },
331        quote! { <#name as Semiring>::zero() },
332    );
333
334    let assert_zero_right = make_assert(
335        &epsilon,
336        quote! { a.clone().mul(<#name as Semiring>::zero()) },
337        quote! { <#name as Semiring>::zero() },
338    );
339
340    let output = quote! {
341        #[cfg(test)]
342        mod #mod_name {
343            use super::*;
344            use karpal_algebra::Semiring;
345            use proptest::prelude::*;
346
347            proptest! {
348                // Additive commutative monoid
349                #[test]
350                fn add_associativity(a in #strategy, b in #strategy, c in #strategy) {
351                    #assert_add_assoc
352                }
353
354                #[test]
355                fn add_commutativity(a in #strategy, b in #strategy) {
356                    #assert_add_comm
357                }
358
359                #[test]
360                fn add_identity(a in #strategy) {
361                    #assert_add_left_id
362                    #assert_add_right_id
363                }
364
365                // Multiplicative monoid
366                #[test]
367                fn mul_associativity(a in #strategy, b in #strategy, c in #strategy) {
368                    #assert_mul_assoc
369                }
370
371                #[test]
372                fn mul_identity(a in #strategy) {
373                    #assert_mul_left_id
374                    #assert_mul_right_id
375                }
376
377                // Distributivity
378                #[test]
379                fn left_distributivity(a in #strategy, b in #strategy, c in #strategy) {
380                    #assert_left_dist
381                }
382
383                #[test]
384                fn right_distributivity(a in #strategy, b in #strategy, c in #strategy) {
385                    #assert_right_dist
386                }
387
388                // Zero annihilation
389                #[test]
390                fn zero_annihilation(a in #strategy) {
391                    #assert_zero_left
392                    #assert_zero_right
393                }
394            }
395        }
396    };
397    output.into()
398}
399
400/// Derive `VerifyRing`: generates proptest for additive inverse (negate).
401///
402/// Requires: `T: Ring + Clone + Debug + PartialEq`
403#[proc_macro_derive(VerifyRing, attributes(verify))]
404pub fn derive_verify_ring(input: TokenStream) -> TokenStream {
405    let input = parse_macro_input!(input as DeriveInput);
406    let name = &input.ident;
407    let strategy = extract_strategy(&input);
408    let epsilon = extract_epsilon(&input);
409    let mod_name = format_ident!("verify_ring_{}", name.to_string().to_lowercase());
410
411    let assert_left = make_assert(
412        &epsilon,
413        quote! { a.clone().negate().add(a.clone()) },
414        quote! { <#name as Semiring>::zero() },
415    );
416
417    let assert_right = make_assert(
418        &epsilon,
419        quote! { a.clone().add(a.clone().negate()) },
420        quote! { <#name as Semiring>::zero() },
421    );
422
423    let output = quote! {
424        #[cfg(test)]
425        mod #mod_name {
426            use super::*;
427            use karpal_algebra::{Semiring, Ring};
428            use proptest::prelude::*;
429
430            proptest! {
431                #[test]
432                fn left_additive_inverse(a in #strategy) {
433                    #assert_left
434                }
435
436                #[test]
437                fn right_additive_inverse(a in #strategy) {
438                    #assert_right
439                }
440            }
441        }
442    };
443    output.into()
444}
445
446/// Derive `VerifyLattice`: generates proptests for associativity, commutativity,
447/// idempotency, and absorption of `meet`/`join`.
448///
449/// Requires: `T: Lattice + Clone + Debug + PartialEq`
450#[proc_macro_derive(VerifyLattice, attributes(verify))]
451pub fn derive_verify_lattice(input: TokenStream) -> TokenStream {
452    let input = parse_macro_input!(input as DeriveInput);
453    let name = &input.ident;
454    let strategy = extract_strategy(&input);
455    let mod_name = format_ident!("verify_lattice_{}", name.to_string().to_lowercase());
456
457    // Lattice laws don't use epsilon — meet/join should be exact
458    let output = quote! {
459        #[cfg(test)]
460        mod #mod_name {
461            use super::*;
462            use karpal_algebra::Lattice;
463            use proptest::prelude::*;
464
465            proptest! {
466                // Associativity
467                #[test]
468                fn join_associativity(a in #strategy, b in #strategy, c in #strategy) {
469                    prop_assert_eq!(
470                        a.clone().join(b.clone()).join(c.clone()),
471                        a.join(b.join(c))
472                    );
473                }
474
475                #[test]
476                fn meet_associativity(a in #strategy, b in #strategy, c in #strategy) {
477                    prop_assert_eq!(
478                        a.clone().meet(b.clone()).meet(c.clone()),
479                        a.meet(b.meet(c))
480                    );
481                }
482
483                // Commutativity
484                #[test]
485                fn join_commutativity(a in #strategy, b in #strategy) {
486                    prop_assert_eq!(a.clone().join(b.clone()), b.join(a));
487                }
488
489                #[test]
490                fn meet_commutativity(a in #strategy, b in #strategy) {
491                    prop_assert_eq!(a.clone().meet(b.clone()), b.meet(a));
492                }
493
494                // Idempotency
495                #[test]
496                fn join_idempotency(a in #strategy) {
497                    prop_assert_eq!(a.clone().join(a.clone()), a);
498                }
499
500                #[test]
501                fn meet_idempotency(a in #strategy) {
502                    prop_assert_eq!(a.clone().meet(a.clone()), a);
503                }
504
505                // Absorption
506                #[test]
507                fn absorption_join_meet(a in #strategy, b in #strategy) {
508                    prop_assert_eq!(a.clone().join(a.clone().meet(b)), a);
509                }
510
511                #[test]
512                fn absorption_meet_join(a in #strategy, b in #strategy) {
513                    prop_assert_eq!(a.clone().meet(a.clone().join(b)), a);
514                }
515            }
516        }
517    };
518    output.into()
519}