1use std::ops::{Add, Mul};
3use k256::{AffinePoint, ProjectivePoint, Scalar};
4use k256::elliptic_curve::ops::Invert;
5use merlin::Transcript;
6use serde::{Deserialize, Serialize};
7use crate::transcript;
8use crate::util::*;
9
/// Public parameters of the weight norm linear argument (WNLA).
///
/// A commitment to a linear vector `l` and a norm vector `n` has the form
/// `v·g + <h_vec, l> + <g_vec, n>`, where `v = <c, l> + weight_vector_mul(n, n, mu)`
/// (see [`WeightNormLinearArgument::commit`]).
#[derive(Clone, Debug)]
pub struct WeightNormLinearArgument {
    /// Generator that carries the scalar `v = <c, l> + weight_vector_mul(n, n, mu)`.
    pub g: ProjectivePoint,
    /// Generators paired with the norm vector `n`.
    pub g_vec: Vec<ProjectivePoint>,
    /// Generators paired with the linear vector `l`.
    pub h_vec: Vec<ProjectivePoint>,
    /// Coefficients of the linear form `<c, l>`.
    pub c: Vec<Scalar>,
    /// Folding factor: scales the first part of `reduce(g_vec)` each round, while its
    /// inverse scales the first part of `reduce(n)`. The next round uses the current
    /// `mu` as its `rho`.
    pub rho: Scalar,
    /// Weight passed to `weight_vector_mul` for the norm of `n`; squared every round.
    pub mu: Scalar,
}
20
/// A weight norm linear argument proof.
///
/// `r` and `x` hold one point per folding round, ordered innermost-first: the last
/// element belongs to the outermost round. `l` and `n` are the final vectors that are
/// revealed once the instance is small enough to be checked directly.
#[derive(Clone, Debug)]
pub struct Proof {
    /// Per-round commitments `R`, built from the second parts returned by `reduce`.
    pub r: Vec<ProjectivePoint>,
    /// Per-round cross-term commitments `X`.
    pub x: Vec<ProjectivePoint>,
    /// Final linear vector, paired with `h_vec` of the last folded instance.
    pub l: Vec<Scalar>,
    /// Final norm vector, paired with `g_vec` of the last folded instance.
    pub n: Vec<Scalar>,
}
31
/// Serde-friendly form of [`Proof`], with every point stored in affine coordinates.
///
/// Fields mirror [`Proof`] one-to-one; convert with the `From` impls in either direction.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SerializableProof {
    /// Per-round `R` points (affine).
    pub r: Vec<AffinePoint>,
    /// Per-round `X` points (affine).
    pub x: Vec<AffinePoint>,
    /// Final linear vector.
    pub l: Vec<Scalar>,
    /// Final norm vector.
    pub n: Vec<Scalar>,
}
40
41impl From<&SerializableProof> for Proof {
42 fn from(value: &SerializableProof) -> Self {
43 return Proof {
44 r: value.r.iter().map(ProjectivePoint::from).collect::<Vec<ProjectivePoint>>(),
45 x: value.x.iter().map(ProjectivePoint::from).collect::<Vec<ProjectivePoint>>(),
46 l: value.l.clone(),
47 n: value.n.clone(),
48 };
49 }
50}
51
52impl From<&Proof> for SerializableProof {
53 fn from(value: &Proof) -> Self {
54 return SerializableProof {
55 r: value.r.iter().map(|r_val| r_val.to_affine()).collect::<Vec<AffinePoint>>(),
56 x: value.x.iter().map(|x_val| x_val.to_affine()).collect::<Vec<AffinePoint>>(),
57 l: value.l.clone(),
58 n: value.n.clone(),
59 };
60 }
61}
62
63impl WeightNormLinearArgument {
64 pub fn commit(&self, l: &[Scalar], n: &[Scalar]) -> ProjectivePoint {
67 let v = vector_mul(&self.c, l).add(weight_vector_mul(n, n, &self.mu));
68 self.
69 g.mul(v).
70 add(vector_mul(&self.h_vec, l)).
71 add(vector_mul(&self.g_vec, n))
72 }
73
74 pub fn verify(&self, commitment: &ProjectivePoint, t: &mut Transcript, proof: Proof) -> bool {
76 if proof.x.len() != proof.r.len() {
77 return false;
78 }
79
80 if proof.x.is_empty() {
81 return commitment.eq(&self.commit(&proof.l, &proof.n));
82 }
83
84 let (c0, c1) = reduce(&self.c);
85 let (g0, g1) = reduce(&self.g_vec);
86 let (h0, h1) = reduce(&self.h_vec);
87
88 transcript::app_point(b"wnla_com", commitment, t);
89 transcript::app_point(b"wnla_x", proof.x.last().unwrap(), t);
90 transcript::app_point(b"wnla_r", proof.r.last().unwrap(), t);
91 t.append_u64(b"l.sz", self.h_vec.len() as u64);
92 t.append_u64(b"n.sz", self.g_vec.len() as u64);
93
94 let y = transcript::get_challenge(b"wnla_challenge", t);
95
96 let h_ = vector_add(&h0, &vector_mul_on_scalar(&h1, &y));
97 let g_ = vector_add(&vector_mul_on_scalar(&g0, &self.rho), &vector_mul_on_scalar(&g1, &y));
98 let c_ = vector_add(&c0, &vector_mul_on_scalar(&c1, &y));
99
100 let com_ = commitment.
101 add(&proof.x.last().unwrap().mul(&y)).
102 add(&proof.r.last().unwrap().mul(&y.mul(&y).sub(&Scalar::ONE)));
103
104 let wnla = WeightNormLinearArgument {
105 g: self.g,
106 g_vec: g_,
107 h_vec: h_,
108 c: c_,
109 rho: self.mu,
110 mu: self.mu.mul(&self.mu),
111 };
112
113 let proof_ = Proof {
114 r: proof.r[..proof.r.len() - 1].to_vec(),
115 x: proof.x[..proof.x.len() - 1].to_vec(),
116 l: proof.l,
117 n: proof.n,
118 };
119
120 wnla.verify(&com_, t, proof_)
121 }
122
123 pub fn prove(&self, commitment: &ProjectivePoint, t: &mut Transcript, l: Vec<Scalar>, n: Vec<Scalar>) -> Proof {
126 if l.len() + n.len() < 6 {
127 return Proof {
128 r: vec![],
129 x: vec![],
130 l,
131 n,
132 };
133 }
134
135 let rho_inv = self.rho.invert_vartime().unwrap();
136
137 let (c0, c1) = reduce(&self.c);
138 let (l0, l1) = reduce(&l);
139 let (n0, n1) = reduce(&n);
140 let (g0, g1) = reduce(&self.g_vec);
141 let (h0, h1) = reduce(&self.h_vec);
142
143 let mu2 = self.mu.mul(&self.mu);
144
145 let vx = weight_vector_mul(&n0, &n1, &mu2).
146 mul(&rho_inv.mul(&Scalar::from(2u32))).
147 add(&vector_mul(&c0, &l1)).
148 add(&vector_mul(&c1, &l0));
149
150 let vr = weight_vector_mul(&n1, &n1, &mu2).add(&vector_mul(&c1, &l1));
151
152 let x = self.g.mul(vx).
153 add(&vector_mul(&h0, &l1)).
154 add(&vector_mul(&h1, &l0)).
155 add(&vector_mul(&g0, &vector_mul_on_scalar(&n1, &self.rho))).
156 add(&vector_mul(&g1, &vector_mul_on_scalar(&n0, &rho_inv)));
157
158 let r = self.g.mul(vr).
159 add(vector_mul(&h1, &l1)).
160 add(vector_mul(&g1, &n1));
161
162 transcript::app_point(b"wnla_com", commitment, t);
163 transcript::app_point(b"wnla_x", &x, t);
164 transcript::app_point(b"wnla_r", &r, t);
165 t.append_u64(b"l.sz", l.len() as u64);
166 t.append_u64(b"n.sz", n.len() as u64);
167
168 let y = transcript::get_challenge(b"wnla_challenge", t);
169
170 let h_ = vector_add(&h0, &vector_mul_on_scalar(&h1, &y));
171 let g_ = vector_add(&vector_mul_on_scalar(&g0, &self.rho), &vector_mul_on_scalar(&g1, &y));
172 let c_ = vector_add(&c0, &vector_mul_on_scalar(&c1, &y));
173
174 let l_ = vector_add(&l0, &vector_mul_on_scalar(&l1, &y));
175 let n_ = vector_add(&vector_mul_on_scalar(&n0, &rho_inv), &vector_mul_on_scalar(&n1, &y));
176
177 let wnla = WeightNormLinearArgument {
178 g: self.g,
179 g_vec: g_,
180 h_vec: h_,
181 c: c_,
182 rho: self.mu,
183 mu: mu2,
184 };
185
186 let mut proof = wnla.prove(&wnla.commit(&l_, &n_), t, l_, n_);
187 proof.r.push(r);
188 proof.x.push(x);
189 proof
190 }
191}