Skip to main content

karpal_proof_derive/
lib.rs

//! Derive macros for automatic algebraic law verification.
//!
//! These macros generate `#[cfg(test)]` modules containing proptest-based
//! property tests that verify algebraic laws for user-defined types.
//!
//! # Usage
//!
//! ```ignore
//! use karpal_proof_derive::VerifySemigroup;
//!
//! #[derive(Clone, Debug, PartialEq, VerifySemigroup)]
//! #[verify(strategy = "(0u32..100).prop_map(MyWrapper)")]
//! struct MyWrapper(u32);
//! ```
//!
//! # Attributes
//!
//! - `#[verify(strategy = "...")]` — **Required**. A proptest strategy expression
//!   that generates values of the annotated type. Examples:
//!   - `"any::<MyType>()"` for types implementing `Arbitrary`
//!   - `"(0i32..50).prop_map(MyWrapper)"` for newtype wrappers around a numeric
//!     range (a bare range such as `"0u32..100"` yields the inner number type,
//!     not the annotated type, so it must be mapped)
//!
//! - `#[verify(epsilon = "1e-10")]` — Optional. Use approximate floating-point
//!   comparison instead of exact `PartialEq`. The generated tests will use
//!   `(left - right).abs() < epsilon` instead of `prop_assert_eq!`.
//!   `VerifyLattice` ignores this attribute and always compares exactly.
27
28use proc_macro::TokenStream;
29use quote::{format_ident, quote};
30use syn::{DeriveInput, Expr, parse_macro_input};
31
32/// Extract a string-valued attribute from `#[verify(key = "...")]`.
33fn extract_verify_attr(input: &DeriveInput, key: &str) -> Option<proc_macro2::TokenStream> {
34    for attr in &input.attrs {
35        if !attr.path().is_ident("verify") {
36            continue;
37        }
38        let Ok(nested) = attr.parse_args_with(
39            syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
40        ) else {
41            continue;
42        };
43        for meta in &nested {
44            if let syn::Meta::NameValue(nv) = meta
45                && nv.path.is_ident(key)
46                && let Expr::Lit(syn::ExprLit {
47                    lit: syn::Lit::Str(s),
48                    ..
49                }) = &nv.value
50            {
51                let tokens: proc_macro2::TokenStream = s
52                    .value()
53                    .parse()
54                    .unwrap_or_else(|_| panic!("{key} must be a valid Rust expression"));
55                return Some(tokens);
56            }
57        }
58    }
59    None
60}
61
62/// Extract the `strategy` string from `#[verify(strategy = "...")]`.
63fn extract_strategy(input: &DeriveInput) -> proc_macro2::TokenStream {
64    extract_verify_attr(input, "strategy")
65        .expect("#[verify(strategy = \"...\")] attribute is required")
66}
67
68/// Extract the optional `epsilon` string from `#[verify(epsilon = "...")]`.
69fn extract_epsilon(input: &DeriveInput) -> Option<proc_macro2::TokenStream> {
70    extract_verify_attr(input, "epsilon")
71}
72
73/// Generate an assertion: either exact or approximate.
74fn make_assert(
75    epsilon: &Option<proc_macro2::TokenStream>,
76    left: proc_macro2::TokenStream,
77    right: proc_macro2::TokenStream,
78) -> proc_macro2::TokenStream {
79    if let Some(eps) = epsilon {
80        // For types that support subtraction and abs (floats)
81        quote! {
82            let __left = #left;
83            let __right = #right;
84            prop_assert!(
85                (__left - __right).abs() < #eps,
86                "expected {:?} ≈ {:?} (epsilon {})", __left, __right, #eps
87            );
88        }
89    } else {
90        quote! {
91            prop_assert_eq!(#left, #right);
92        }
93    }
94}
95
96/// Derive `VerifySemigroup`: generates proptest for associativity of `combine`.
97///
98/// Requires: `T: Semigroup + Clone + Debug + PartialEq`
99#[proc_macro_derive(VerifySemigroup, attributes(verify))]
100pub fn derive_verify_semigroup(input: TokenStream) -> TokenStream {
101    let input = parse_macro_input!(input as DeriveInput);
102    let name = &input.ident;
103    let strategy = extract_strategy(&input);
104    let epsilon = extract_epsilon(&input);
105    let mod_name = format_ident!("verify_semigroup_{}", name.to_string().to_lowercase());
106
107    let assert_assoc = make_assert(
108        &epsilon,
109        quote! { a.clone().combine(b.clone()).combine(c.clone()) },
110        quote! { a.combine(b.combine(c)) },
111    );
112
113    let output = quote! {
114        #[cfg(test)]
115        mod #mod_name {
116            use super::*;
117            use karpal_core::Semigroup;
118            use proptest::prelude::*;
119
120            proptest! {
121                #[test]
122                fn associativity(a in #strategy, b in #strategy, c in #strategy) {
123                    #assert_assoc
124                }
125            }
126        }
127    };
128    output.into()
129}
130
131/// Derive `VerifyMonoid`: generates proptest for left/right identity of `empty`.
132///
133/// Requires: `T: Monoid + Clone + Debug + PartialEq`
134#[proc_macro_derive(VerifyMonoid, attributes(verify))]
135pub fn derive_verify_monoid(input: TokenStream) -> TokenStream {
136    let input = parse_macro_input!(input as DeriveInput);
137    let name = &input.ident;
138    let strategy = extract_strategy(&input);
139    let epsilon = extract_epsilon(&input);
140    let mod_name = format_ident!("verify_monoid_{}", name.to_string().to_lowercase());
141
142    let assert_left = make_assert(
143        &epsilon,
144        quote! { <#name as Monoid>::empty().combine(a.clone()) },
145        quote! { a },
146    );
147
148    let assert_right = make_assert(
149        &epsilon,
150        quote! { a.clone().combine(<#name as Monoid>::empty()) },
151        quote! { a },
152    );
153
154    let output = quote! {
155        #[cfg(test)]
156        mod #mod_name {
157            use super::*;
158            use karpal_core::{Semigroup, Monoid};
159            use proptest::prelude::*;
160
161            proptest! {
162                #[test]
163                fn left_identity(a in #strategy) {
164                    #assert_left
165                }
166
167                #[test]
168                fn right_identity(a in #strategy) {
169                    #assert_right
170                }
171            }
172        }
173    };
174    output.into()
175}
176
177/// Derive `VerifyGroup`: generates proptest for left/right inverse.
178///
179/// Requires: `T: Group + Clone + Debug + PartialEq`
180#[proc_macro_derive(VerifyGroup, attributes(verify))]
181pub fn derive_verify_group(input: TokenStream) -> TokenStream {
182    let input = parse_macro_input!(input as DeriveInput);
183    let name = &input.ident;
184    let strategy = extract_strategy(&input);
185    let epsilon = extract_epsilon(&input);
186    let mod_name = format_ident!("verify_group_{}", name.to_string().to_lowercase());
187
188    let assert_left = make_assert(
189        &epsilon,
190        quote! { a.clone().invert().combine(a.clone()) },
191        quote! { <#name as Monoid>::empty() },
192    );
193
194    let assert_right = make_assert(
195        &epsilon,
196        quote! { a.clone().combine(a.clone().invert()) },
197        quote! { <#name as Monoid>::empty() },
198    );
199
200    let output = quote! {
201        #[cfg(test)]
202        mod #mod_name {
203            use super::*;
204            use karpal_core::{Semigroup, Monoid};
205            use karpal_algebra::Group;
206            use proptest::prelude::*;
207
208            proptest! {
209                #[test]
210                fn left_inverse(a in #strategy) {
211                    #assert_left
212                }
213
214                #[test]
215                fn right_inverse(a in #strategy) {
216                    #assert_right
217                }
218            }
219        }
220    };
221    output.into()
222}
223
224/// Derive `VerifyCommutative`: generates proptest for commutativity of `combine`.
225///
226/// Requires: `T: Semigroup + Clone + Debug + PartialEq`
227#[proc_macro_derive(VerifyCommutative, attributes(verify))]
228pub fn derive_verify_commutative(input: TokenStream) -> TokenStream {
229    let input = parse_macro_input!(input as DeriveInput);
230    let name = &input.ident;
231    let strategy = extract_strategy(&input);
232    let epsilon = extract_epsilon(&input);
233    let mod_name = format_ident!("verify_commutative_{}", name.to_string().to_lowercase());
234
235    let assert_comm = make_assert(
236        &epsilon,
237        quote! { a.clone().combine(b.clone()) },
238        quote! { b.combine(a) },
239    );
240
241    let output = quote! {
242        #[cfg(test)]
243        mod #mod_name {
244            use super::*;
245            use karpal_core::Semigroup;
246            use proptest::prelude::*;
247
248            proptest! {
249                #[test]
250                fn commutativity(a in #strategy, b in #strategy) {
251                    #assert_comm
252                }
253            }
254        }
255    };
256    output.into()
257}
258
259/// Derive `VerifySemiring`: generates proptests for additive monoid,
260/// multiplicative monoid, distributivity, and zero annihilation.
261///
262/// Requires: `T: Semiring + Clone + Debug + PartialEq`
263#[proc_macro_derive(VerifySemiring, attributes(verify))]
264pub fn derive_verify_semiring(input: TokenStream) -> TokenStream {
265    let input = parse_macro_input!(input as DeriveInput);
266    let name = &input.ident;
267    let strategy = extract_strategy(&input);
268    let epsilon = extract_epsilon(&input);
269    let mod_name = format_ident!("verify_semiring_{}", name.to_string().to_lowercase());
270
271    let assert_add_assoc = make_assert(
272        &epsilon,
273        quote! { a.clone().add(b.clone()).add(c.clone()) },
274        quote! { a.clone().add(b.clone().add(c.clone())) },
275    );
276
277    let assert_add_comm = make_assert(
278        &epsilon,
279        quote! { a.clone().add(b.clone()) },
280        quote! { b.clone().add(a.clone()) },
281    );
282
283    let assert_add_left_id = make_assert(
284        &epsilon,
285        quote! { <#name as Semiring>::zero().add(a.clone()) },
286        quote! { a.clone() },
287    );
288
289    let assert_add_right_id = make_assert(
290        &epsilon,
291        quote! { a.clone().add(<#name as Semiring>::zero()) },
292        quote! { a.clone() },
293    );
294
295    let assert_mul_assoc = make_assert(
296        &epsilon,
297        quote! { a.clone().mul(b.clone()).mul(c.clone()) },
298        quote! { a.clone().mul(b.clone().mul(c.clone())) },
299    );
300
301    let assert_mul_left_id = make_assert(
302        &epsilon,
303        quote! { <#name as Semiring>::one().mul(a.clone()) },
304        quote! { a.clone() },
305    );
306
307    let assert_mul_right_id = make_assert(
308        &epsilon,
309        quote! { a.clone().mul(<#name as Semiring>::one()) },
310        quote! { a.clone() },
311    );
312
313    let assert_left_dist = make_assert(
314        &epsilon,
315        quote! { a.clone().mul(b.clone().add(c.clone())) },
316        quote! { a.clone().mul(b.clone()).add(a.clone().mul(c.clone())) },
317    );
318
319    let assert_right_dist = make_assert(
320        &epsilon,
321        quote! { a.clone().add(b.clone()).mul(c.clone()) },
322        quote! { a.clone().mul(c.clone()).add(b.clone().mul(c.clone())) },
323    );
324
325    let assert_zero_left = make_assert(
326        &epsilon,
327        quote! { <#name as Semiring>::zero().mul(a.clone()) },
328        quote! { <#name as Semiring>::zero() },
329    );
330
331    let assert_zero_right = make_assert(
332        &epsilon,
333        quote! { a.clone().mul(<#name as Semiring>::zero()) },
334        quote! { <#name as Semiring>::zero() },
335    );
336
337    let output = quote! {
338        #[cfg(test)]
339        mod #mod_name {
340            use super::*;
341            use karpal_algebra::Semiring;
342            use proptest::prelude::*;
343
344            proptest! {
345                // Additive commutative monoid
346                #[test]
347                fn add_associativity(a in #strategy, b in #strategy, c in #strategy) {
348                    #assert_add_assoc
349                }
350
351                #[test]
352                fn add_commutativity(a in #strategy, b in #strategy) {
353                    #assert_add_comm
354                }
355
356                #[test]
357                fn add_identity(a in #strategy) {
358                    #assert_add_left_id
359                    #assert_add_right_id
360                }
361
362                // Multiplicative monoid
363                #[test]
364                fn mul_associativity(a in #strategy, b in #strategy, c in #strategy) {
365                    #assert_mul_assoc
366                }
367
368                #[test]
369                fn mul_identity(a in #strategy) {
370                    #assert_mul_left_id
371                    #assert_mul_right_id
372                }
373
374                // Distributivity
375                #[test]
376                fn left_distributivity(a in #strategy, b in #strategy, c in #strategy) {
377                    #assert_left_dist
378                }
379
380                #[test]
381                fn right_distributivity(a in #strategy, b in #strategy, c in #strategy) {
382                    #assert_right_dist
383                }
384
385                // Zero annihilation
386                #[test]
387                fn zero_annihilation(a in #strategy) {
388                    #assert_zero_left
389                    #assert_zero_right
390                }
391            }
392        }
393    };
394    output.into()
395}
396
397/// Derive `VerifyRing`: generates proptest for additive inverse (negate).
398///
399/// Requires: `T: Ring + Clone + Debug + PartialEq`
400#[proc_macro_derive(VerifyRing, attributes(verify))]
401pub fn derive_verify_ring(input: TokenStream) -> TokenStream {
402    let input = parse_macro_input!(input as DeriveInput);
403    let name = &input.ident;
404    let strategy = extract_strategy(&input);
405    let epsilon = extract_epsilon(&input);
406    let mod_name = format_ident!("verify_ring_{}", name.to_string().to_lowercase());
407
408    let assert_left = make_assert(
409        &epsilon,
410        quote! { a.clone().negate().add(a.clone()) },
411        quote! { <#name as Semiring>::zero() },
412    );
413
414    let assert_right = make_assert(
415        &epsilon,
416        quote! { a.clone().add(a.clone().negate()) },
417        quote! { <#name as Semiring>::zero() },
418    );
419
420    let output = quote! {
421        #[cfg(test)]
422        mod #mod_name {
423            use super::*;
424            use karpal_algebra::{Semiring, Ring};
425            use proptest::prelude::*;
426
427            proptest! {
428                #[test]
429                fn left_additive_inverse(a in #strategy) {
430                    #assert_left
431                }
432
433                #[test]
434                fn right_additive_inverse(a in #strategy) {
435                    #assert_right
436                }
437            }
438        }
439    };
440    output.into()
441}
442
443/// Derive `VerifyLattice`: generates proptests for associativity, commutativity,
444/// idempotency, and absorption of `meet`/`join`.
445///
446/// Requires: `T: Lattice + Clone + Debug + PartialEq`
447#[proc_macro_derive(VerifyLattice, attributes(verify))]
448pub fn derive_verify_lattice(input: TokenStream) -> TokenStream {
449    let input = parse_macro_input!(input as DeriveInput);
450    let name = &input.ident;
451    let strategy = extract_strategy(&input);
452    let mod_name = format_ident!("verify_lattice_{}", name.to_string().to_lowercase());
453
454    // Lattice laws don't use epsilon — meet/join should be exact
455    let output = quote! {
456        #[cfg(test)]
457        mod #mod_name {
458            use super::*;
459            use karpal_algebra::Lattice;
460            use proptest::prelude::*;
461
462            proptest! {
463                // Associativity
464                #[test]
465                fn join_associativity(a in #strategy, b in #strategy, c in #strategy) {
466                    prop_assert_eq!(
467                        a.clone().join(b.clone()).join(c.clone()),
468                        a.join(b.join(c))
469                    );
470                }
471
472                #[test]
473                fn meet_associativity(a in #strategy, b in #strategy, c in #strategy) {
474                    prop_assert_eq!(
475                        a.clone().meet(b.clone()).meet(c.clone()),
476                        a.meet(b.meet(c))
477                    );
478                }
479
480                // Commutativity
481                #[test]
482                fn join_commutativity(a in #strategy, b in #strategy) {
483                    prop_assert_eq!(a.clone().join(b.clone()), b.join(a));
484                }
485
486                #[test]
487                fn meet_commutativity(a in #strategy, b in #strategy) {
488                    prop_assert_eq!(a.clone().meet(b.clone()), b.meet(a));
489                }
490
491                // Idempotency
492                #[test]
493                fn join_idempotency(a in #strategy) {
494                    prop_assert_eq!(a.clone().join(a.clone()), a);
495                }
496
497                #[test]
498                fn meet_idempotency(a in #strategy) {
499                    prop_assert_eq!(a.clone().meet(a.clone()), a);
500                }
501
502                // Absorption
503                #[test]
504                fn absorption_join_meet(a in #strategy, b in #strategy) {
505                    prop_assert_eq!(a.clone().join(a.clone().meet(b)), a);
506                }
507
508                #[test]
509                fn absorption_meet_join(a in #strategy, b in #strategy) {
510                    prop_assert_eq!(a.clone().meet(a.clone().join(b)), a);
511                }
512            }
513        }
514    };
515    output.into()
516}