1use crate::{
4 B32, B64, EncodedVerifyingKey, MlDsaParams, MuBuilder, Signature,
5 algebra::{Elem, NttMatrix, NttVector, Vector},
6 crypto::H,
7 ntt::{Ntt, NttInverse},
8 param::ParameterSet,
9 param::VerifyingKeySize,
10 sampling::{expand_a, sample_in_ball},
11};
12use common::{Key, KeyExport, KeyInit, KeySizeUser};
13use module_lattice::MaybeBox;
14use shake::Shake256;
15use signature::{DigestVerifier, Error, MultipartVerifier};
16
/// An ML-DSA verifying (public) key.
///
/// Holds the public seed `rho` and the vector `t1`, together with values derived from
/// them when the key is constructed so that each verification can reuse them.
///
/// The derived `PartialEq` compares every field, including the precomputed values.
#[derive(Clone, Debug, PartialEq)]
pub struct VerifyingKey<P: ParameterSet> {
    /// Public seed; `new_expand_a` expands the matrix `A_hat` from it via `expand_a`.
    rho: B32,

    /// The `t1` vector of the public key (presumably the high bits of `t`, as in
    /// FIPS 204 — not established in this file).
    t1: MaybeBox<Vector<P::K>>,

    /// Values computed once in `VerifyingKey::new` and read by every verification.
    precomputed_values: MaybeBox<PrecomputedValues<P>>,
}
29
/// Per-key values derived at construction time and reused by every verification.
#[derive(Clone, Debug, PartialEq)]
struct PrecomputedValues<P: ParameterSet> {
    /// Matrix A in the NTT domain, as supplied to `VerifyingKey::new`
    /// (presumably `expand_a(rho)`; `new_expand_a` passes exactly that).
    A_hat: NttMatrix<P::K, P::L>,

    /// NTT of `t1` scaled by `1 << 13`, subtracted (after multiplying by the challenge)
    /// in `raw_verify_mu`.
    t1_2d_hat: NttVector<P::K>,

    /// 64-byte hash `H` of the encoded verifying key; seeds every μ computation.
    tr: B64,
}
43
44impl<P: MlDsaParams> PrecomputedValues<P> {
45 fn new(t1: &Vector<P::K>, enc: &EncodedVerifyingKey<P>, A_hat: NttMatrix<P::K, P::L>) -> Self {
46 let t1_2d_hat = (Elem::new(1 << 13) * t1).ntt();
47 let tr = H::default().absorb(enc).squeeze_new();
48
49 Self {
50 A_hat,
51 t1_2d_hat,
52 tr,
53 }
54 }
55}
56
57impl<P: MlDsaParams> VerifyingKey<P> {
58 pub(crate) fn new(
59 rho: B32,
60 t1: Vector<P::K>,
61 A_hat: NttMatrix<P::K, P::L>,
62 enc: Option<EncodedVerifyingKey<P>>,
63 ) -> Self {
64 let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
65 let precomputed_values = PrecomputedValues::new(&t1, &enc, A_hat);
66
67 Self {
68 rho,
69 t1: MaybeBox::new(t1),
70 precomputed_values: MaybeBox::new(precomputed_values),
71 }
72 }
73
74 #[inline]
75 fn new_expand_a(rho: B32, t1: Vector<P::K>, enc: Option<EncodedVerifyingKey<P>>) -> Self {
76 let A_hat = expand_a(&rho);
77 Self::new(rho, t1, A_hat, enc)
78 }
79
80 pub fn compute_mu<F: FnOnce(&mut Shake256) -> Result<(), Error>>(
85 &self,
86 Mp: F,
87 ctx: &[u8],
88 ) -> Result<B64, Error> {
89 let mut mu = MuBuilder::new(&self.precomputed_values.tr, ctx);
90 Mp(mu.as_mut())?;
91 Ok(mu.finish())
92 }
93
94 pub fn verify_internal(&self, M: &[u8], sigma: &Signature<P>) -> bool
99 where
100 P: MlDsaParams,
101 {
102 let mu = MuBuilder::internal(&self.precomputed_values.tr, &[M]);
103 self.raw_verify_mu(&mu, sigma)
104 }
105
106 pub(crate) fn raw_verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool
107 where
108 P: MlDsaParams,
109 {
110 let c = sample_in_ball(&sigma.c_tilde, P::TAU);
112
113 let z_hat = sigma.z.ntt();
114 let c_hat = c.ntt();
115 let Az_hat = &self.precomputed_values.A_hat * &z_hat;
116 let ct1_2d_hat = &c_hat * &self.precomputed_values.t1_2d_hat;
117
118 let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
119 let w1p = sigma.h.use_hint(&wp_approx);
120
121 let w1p_tilde = P::encode_w1(&w1p);
122 let cp_tilde = H::default()
123 .absorb(mu)
124 .absorb(&w1p_tilde)
125 .squeeze_new::<P::Lambda>();
126
127 sigma.c_tilde == cp_tilde
128 }
129
130 pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
132 self.raw_verify_with_context(&[M], ctx, sigma)
133 }
134
135 pub fn verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool {
137 self.raw_verify_mu(mu, sigma)
138 }
139
140 fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
141 if ctx.len() > 255 {
142 return false;
143 }
144
145 let mu = MuBuilder::new(&self.precomputed_values.tr, ctx).message(M);
146 self.verify_mu(&mu, sigma)
147 }
148
149 pub(crate) fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
150 let t1_enc = P::encode_t1(t1);
151 P::concat_vk(rho.clone(), t1_enc)
152 }
153
154 #[must_use]
158 pub fn encode(&self) -> EncodedVerifyingKey<P> {
159 Self::encode_internal(&self.rho, &self.t1)
160 }
161
162 pub fn decode(enc: &EncodedVerifyingKey<P>) -> Self {
166 let (rho, t1_enc) = P::split_vk(enc);
167 let t1 = P::decode_t1(t1_enc);
168 Self::new_expand_a(rho.clone(), t1, Some(enc.clone()))
169 }
170}
171
/// The encoded size of a verifying key is the parameter set's `VerifyingKeySize`.
impl<P: MlDsaParams> KeySizeUser for VerifyingKey<P> {
    type KeySize = VerifyingKeySize<P>;
}
175
176impl<P: MlDsaParams> KeyInit for VerifyingKey<P> {
177 fn new(key: &Key<Self>) -> Self {
178 Self::decode(key)
179 }
180}
181
182impl<P: MlDsaParams> KeyExport for VerifyingKey<P> {
183 fn to_bytes(&self) -> Key<Self> {
184 self.encode()
185 }
186}
187
188impl<P: MlDsaParams> core::hash::Hash for VerifyingKey<P> {
189 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
190 self.encode().hash(state);
191 }
192}
193
194impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
195 fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
196 self.multipart_verify(&[msg], signature)
197 }
198}
199
200impl<P: MlDsaParams> MultipartVerifier<Signature<P>> for VerifyingKey<P> {
201 fn multipart_verify(&self, msg: &[&[u8]], signature: &Signature<P>) -> Result<(), Error> {
202 self.raw_verify_with_context(msg, &[], signature)
203 .then_some(())
204 .ok_or(Error::new())
205 }
206}
207
208impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P> {
209 fn verify_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
210 &self,
211 f: F,
212 signature: &Signature<P>,
213 ) -> Result<(), Error> {
214 let mut mu = MuBuilder::new(&self.precomputed_values.tr, &[]);
215 f(mu.as_mut())?;
216 let mu = mu.finish();
217
218 self.raw_verify_mu(&mu, signature)
219 .then_some(())
220 .ok_or(Error::new())
221 }
222}