liminal_ark_relations/
linear.rs

1use liminal_ark_relation_macro::snark_relation;
2
3/// Linear equation relation: a*x + b = y
4///
5/// Relation with:
6///  - 1 private witness (x)
7///  - 3 constants       (a, b, y)
8#[snark_relation]
9mod relation {
10    #[cfg(feature = "circuit")]
11    use {
12        ark_r1cs_std::{alloc::AllocVar, eq::EqGadget, uint32::UInt32},
13        ark_std::vec::Vec,
14    };
15
16    #[relation_object_definition]
17    #[derive(Clone, Debug)]
18    struct LinearEquationRelation {
19        /// slope
20        #[constant]
21        pub a: u32,
22        /// private witness
23        #[private_input]
24        pub x: u32,
25        /// an intercept
26        #[constant]
27        pub b: u32,
28        /// constant
29        #[constant]
30        pub y: u32,
31    }
32
33    #[cfg(feature = "circuit")]
34    #[circuit_definition]
35    fn generate_constraints() {
36        // TODO: migrate from real values to values in the finite field (see FpVar)
37        // Watch out for overflows!!!
38        let x = UInt32::new_witness(ark_relations::ns!(cs, "x"), || self.x())?;
39        let b = UInt32::new_constant(ark_relations::ns!(cs, "b"), self.b())?;
40        let y = UInt32::new_constant(ark_relations::ns!(cs, "y"), self.y())?;
41
42        let mut left = ark_std::iter::repeat(x)
43            .take(*self.a() as usize)
44            .collect::<Vec<UInt32<_>>>();
45
46        left.push(b);
47
48        UInt32::addmany(&left)?.enforce_equal(&y)
49    }
50}
51
#[cfg(all(test, feature = "circuit"))]
mod tests {
    use ark_bls12_381::Bls12_381;
    use ark_groth16::Groth16;
    use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystem, ConstraintSystemRef};
    use ark_snark::SNARK;

    use super::*;
    use crate::environment::CircuitField;

    const A: u32 = 2;
    const X: u32 = 1;
    const B: u32 = 1;
    const Y: u32 = 3;

    /// Synthesizes the full-input circuit for the constants `A`, `B`, `Y` and the given
    /// witness `x`, returning the populated constraint system for inspection.
    fn synthesize(x: u32) -> ConstraintSystemRef<CircuitField> {
        let circuit = LinearEquationRelationWithFullInput::new(A, B, Y, x);

        let cs: ConstraintSystemRef<CircuitField> = ConstraintSystem::new_ref();
        circuit.generate_constraints(cs.clone()).unwrap();
        cs
    }

    #[test]
    fn linear_constraints_correctness() {
        let cs = synthesize(X);

        assert!(
            cs.is_satisfied().unwrap(),
            "unsatisfied constraint: {:?}",
            cs.which_is_unsatisfied()
        );
    }

    #[test]
    fn linear_constraints_reject_wrong_witness() {
        // A * (X + 1) + B = 2 * 2 + 1 = 5, which differs from Y = 3.
        let cs = synthesize(X + 1);

        assert!(!cs.is_satisfied().unwrap());
    }

    #[test]
    fn linear_proving_procedure() {
        let circuit_wo_input = LinearEquationRelationWithoutInput::new(A, B, Y);

        let mut rng = ark_std::test_rng();
        let (pk, vk) =
            Groth16::<Bls12_381>::circuit_specific_setup(circuit_wo_input, &mut rng).unwrap();

        let circuit_with_public_input = LinearEquationRelationWithPublicInput::new(A, B, Y);
        let input = circuit_with_public_input.serialize_public_input();

        let circuit_with_full_input = LinearEquationRelationWithFullInput::new(A, B, Y, X);

        let proof = Groth16::prove(&pk, circuit_with_full_input, &mut rng).unwrap();
        let valid_proof = Groth16::verify(&vk, &input, &proof).unwrap();
        assert!(valid_proof);
    }
}