Skip to main content

ml_dsa/
verifying.rs

1//! ML-DSA signature verification.
2
3use crate::{
4    B32, B64, EncodedVerifyingKey, MlDsaParams, MuBuilder, Signature,
5    algebra::{Elem, NttMatrix, NttVector, Vector},
6    crypto::H,
7    ntt::{Ntt, NttInverse},
8    param::ParameterSet,
9    param::VerifyingKeySize,
10    sampling::{expand_a, sample_in_ball},
11};
12use common::{Key, KeyExport, KeyInit, KeySizeUser};
13use module_lattice::MaybeBox;
14use shake::Shake256;
15use signature::{DigestVerifier, Error, MultipartVerifier};
16
17/// An ML-DSA verification key.
18#[derive(Clone, Debug, PartialEq)]
19pub struct VerifyingKey<P: ParameterSet> {
20    /// Public seed used to deterministically re-expand `A_hat`.
21    rho: B32,
22
23    /// High bits of the public key polynomial `t`.
24    t1: MaybeBox<Vector<P::K>>,
25
26    /// Precomputed expanded values.
27    precomputed_values: MaybeBox<PrecomputedValues<P>>,
28}
29
30/// Cached values derived from `rho` and `t1` at key construction time to avoid re-expanding them
31/// when verifying signatures.
32#[derive(Clone, Debug, PartialEq)]
33struct PrecomputedValues<P: ParameterSet> {
34    /// Expanded public matrix in NTT domain.
35    A_hat: NttMatrix<P::K, P::L>,
36
37    /// `2ᵈ ⋅ t1` which can be reused in signature verification.
38    t1_2d_hat: NttVector<P::K>,
39
40    /// Hash of the encoded public key, used to bind messages to the key this was precomputed from.
41    tr: B64,
42}
43
44impl<P: MlDsaParams> PrecomputedValues<P> {
45    fn new(t1: &Vector<P::K>, enc: &EncodedVerifyingKey<P>, A_hat: NttMatrix<P::K, P::L>) -> Self {
46        let t1_2d_hat = (Elem::new(1 << 13) * t1).ntt();
47        let tr = H::default().absorb(enc).squeeze_new();
48
49        Self {
50            A_hat,
51            t1_2d_hat,
52            tr,
53        }
54    }
55}
56
57impl<P: MlDsaParams> VerifyingKey<P> {
58    pub(crate) fn new(
59        rho: B32,
60        t1: Vector<P::K>,
61        A_hat: NttMatrix<P::K, P::L>,
62        enc: Option<EncodedVerifyingKey<P>>,
63    ) -> Self {
64        let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
65        let precomputed_values = PrecomputedValues::new(&t1, &enc, A_hat);
66
67        Self {
68            rho,
69            t1: MaybeBox::new(t1),
70            precomputed_values: MaybeBox::new(precomputed_values),
71        }
72    }
73
74    #[inline]
75    fn new_expand_a(rho: B32, t1: Vector<P::K>, enc: Option<EncodedVerifyingKey<P>>) -> Self {
76        let A_hat = expand_a(&rho);
77        Self::new(rho, t1, A_hat, enc)
78    }
79
80    /// Computes µ according to FIPS 204 for use in `ML-DSA.Sign` and `ML-DSA.Verify`.
81    ///
82    /// # Errors
83    /// Returns [`Error`] if the given `Mp` returns one.
84    pub fn compute_mu<F: FnOnce(&mut Shake256) -> Result<(), Error>>(
85        &self,
86        Mp: F,
87        ctx: &[u8],
88    ) -> Result<B64, Error> {
89        let mut mu = MuBuilder::new(&self.precomputed_values.tr, ctx);
90        Mp(mu.as_mut())?;
91        Ok(mu.finish())
92    }
93
94    /// Implementation of Algorithm 8: `ML-DSA.Verify_internal` algorithm from FIPS 204.
95    ///
96    /// It does not include the domain separator that distinguishes between the normal and
97    /// pre-hashed cases, and it does not separate the context string from the rest of the message.
98    pub fn verify_internal(&self, M: &[u8], sigma: &Signature<P>) -> bool
99    where
100        P: MlDsaParams,
101    {
102        let mu = MuBuilder::internal(&self.precomputed_values.tr, &[M]);
103        self.raw_verify_mu(&mu, sigma)
104    }
105
106    pub(crate) fn raw_verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool
107    where
108        P: MlDsaParams,
109    {
110        // Reconstruct w
111        let c = sample_in_ball(&sigma.c_tilde, P::TAU);
112
113        let z_hat = sigma.z.ntt();
114        let c_hat = c.ntt();
115        let Az_hat = &self.precomputed_values.A_hat * &z_hat;
116        let ct1_2d_hat = &c_hat * &self.precomputed_values.t1_2d_hat;
117
118        let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
119        let w1p = sigma.h.use_hint(&wp_approx);
120
121        let w1p_tilde = P::encode_w1(&w1p);
122        let cp_tilde = H::default()
123            .absorb(mu)
124            .absorb(&w1p_tilde)
125            .squeeze_new::<P::Lambda>();
126
127        sigma.c_tilde == cp_tilde
128    }
129
130    /// Implementation of Algorithm 3: `ML-DSA.Verify` from FIPS 204.
131    pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
132        self.raw_verify_with_context(&[M], ctx, sigma)
133    }
134
135    /// Implementation of Algorithm 3: `ML-DSA.Verify` from FIPS 204 with a pre-computed μ.
136    pub fn verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool {
137        self.raw_verify_mu(mu, sigma)
138    }
139
140    fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
141        if ctx.len() > 255 {
142            return false;
143        }
144
145        let mu = MuBuilder::new(&self.precomputed_values.tr, ctx).message(M);
146        self.verify_mu(&mu, sigma)
147    }
148
149    pub(crate) fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
150        let t1_enc = P::encode_t1(t1);
151        P::concat_vk(rho.clone(), t1_enc)
152    }
153
154    /// Encode the key in a fixed-size byte array.
155    ///
156    /// Implementation of Algorithm 22: `pkEncode` from FIPS 204.
157    #[must_use]
158    pub fn encode(&self) -> EncodedVerifyingKey<P> {
159        Self::encode_internal(&self.rho, &self.t1)
160    }
161
162    /// Decode the key from an appropriately sized byte array.
163    ///
164    /// Implementation of Algorithm 23: `pkDecode` from FIPS 204.
165    pub fn decode(enc: &EncodedVerifyingKey<P>) -> Self {
166        let (rho, t1_enc) = P::split_vk(enc);
167        let t1 = P::decode_t1(t1_enc);
168        Self::new_expand_a(rho.clone(), t1, Some(enc.clone()))
169    }
170}
171
172impl<P: MlDsaParams> KeySizeUser for VerifyingKey<P> {
173    type KeySize = VerifyingKeySize<P>;
174}
175
176impl<P: MlDsaParams> KeyInit for VerifyingKey<P> {
177    fn new(key: &Key<Self>) -> Self {
178        Self::decode(key)
179    }
180}
181
182impl<P: MlDsaParams> KeyExport for VerifyingKey<P> {
183    fn to_bytes(&self) -> Key<Self> {
184        self.encode()
185    }
186}
187
188impl<P: MlDsaParams> core::hash::Hash for VerifyingKey<P> {
189    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
190        self.encode().hash(state);
191    }
192}
193
194impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
195    fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
196        self.multipart_verify(&[msg], signature)
197    }
198}
199
200impl<P: MlDsaParams> MultipartVerifier<Signature<P>> for VerifyingKey<P> {
201    fn multipart_verify(&self, msg: &[&[u8]], signature: &Signature<P>) -> Result<(), Error> {
202        self.raw_verify_with_context(msg, &[], signature)
203            .then_some(())
204            .ok_or(Error::new())
205    }
206}
207
208impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P> {
209    fn verify_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
210        &self,
211        f: F,
212        signature: &Signature<P>,
213    ) -> Result<(), Error> {
214        let mut mu = MuBuilder::new(&self.precomputed_values.tr, &[]);
215        f(mu.as_mut())?;
216        let mu = mu.finish();
217
218        self.raw_verify_mu(&mu, signature)
219            .then_some(())
220            .ok_or(Error::new())
221    }
222}