Skip to main content

rsa/pss/
verifying_key.rs

1use super::{verify_digest_into, GenericSignature};
2use crate::key::GenericRsaPublicKey;
3use crate::traits::{modular::ModulusParams, PublicKeyParts, UnsignedModularInt};
4use core::marker::PhantomData;
5#[cfg(feature = "alloc")]
6use crypto_bigint::{modular::BoxedMontyParams, BoxedUint};
7use digest::{Digest, FixedOutputReset, Update};
8use signature::{hazmat::PrehashVerifier, DigestVerifier, Verifier};
9
10#[cfg(all(feature = "alloc", feature = "encoding"))]
11use crate::RsaPublicKey;
12#[cfg(feature = "encoding")]
13use {
14    crate::encoding::ID_RSASSA_PSS,
15    const_oid::AssociatedOid,
16    pkcs8::{Document, EncodePublicKey},
17    spki::{der::AnyRef, AlgorithmIdentifierRef, AssociatedAlgorithmIdentifier},
18};
19#[cfg(feature = "serde")]
20use {
21    serdect::serde::{de, ser, Deserialize, Serialize},
22    spki::DecodePublicKey,
23};
24
25/// Verifying key for checking the validity of RSASSA-PSS signatures as
26/// described in [RFC8017 § 8.1].
27///
28/// [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1
29#[derive(Debug)]
30pub struct GenericVerifyingKey<D, T, M>
31where
32    D: Digest,
33    T: UnsignedModularInt,
34    M: ModulusParams<Modulus = T>,
35{
36    pub(super) inner: GenericRsaPublicKey<T, M>,
37    pub(super) salt_len: Option<usize>,
38    pub(super) phantom: PhantomData<D>,
39}
40
41/// Boxed RSASSA-PSS verifying key alias.
42#[cfg(feature = "alloc")]
43pub type VerifyingKey<D> = GenericVerifyingKey<D, BoxedUint, BoxedMontyParams>;
44
45impl<D, T, M> GenericVerifyingKey<D, T, M>
46where
47    D: Digest,
48    T: UnsignedModularInt,
49    M: ModulusParams<Modulus = T>,
50{
51    /// Create a new RSASSA-PSS verifying key.
52    /// Digest output size is used as a salt length.
53    pub fn new(key: GenericRsaPublicKey<T, M>) -> Self {
54        Self::new_with_salt_len(key, <D as Digest>::output_size())
55    }
56
57    /// Create a new RSASSA-PSS verifying key.
58    pub fn new_with_salt_len(key: GenericRsaPublicKey<T, M>, salt_len: usize) -> Self {
59        Self {
60            inner: key,
61            salt_len: Some(salt_len),
62            phantom: Default::default(),
63        }
64    }
65
66    /// Create a new RSASSA-PSS verifying key.
67    /// Attempts to automatically detect the salt length.
68    pub fn new_with_auto_salt_len(key: GenericRsaPublicKey<T, M>) -> Self {
69        Self {
70            inner: key,
71            salt_len: None,
72            phantom: Default::default(),
73        }
74    }
75
76    /// Return specified salt length for this key
77    pub fn salt_len(&self) -> Option<usize> {
78        self.salt_len
79    }
80}
81
82impl<D, T, M> GenericVerifyingKey<D, T, M>
83where
84    D: Digest + FixedOutputReset,
85    T: UnsignedModularInt,
86    M: ModulusParams<Modulus = T>,
87{
88    fn verify_prehash_signature(
89        &self,
90        prehash: &[u8],
91        signature: &GenericSignature<T>,
92    ) -> signature::Result<()> {
93        let mut storage = self.inner.n().as_ref().to_be_bytes();
94        verify_digest_into::<D, _, T>(
95            &self.inner,
96            prehash,
97            signature.inner(),
98            self.salt_len,
99            storage.as_mut(),
100        )
101        .map_err(Into::into)
102    }
103}
104
105//
106// `*Verifier` trait impls
107//
108
109impl<D, T, M> DigestVerifier<D, GenericSignature<T>> for GenericVerifyingKey<D, T, M>
110where
111    D: Digest + FixedOutputReset + Update,
112    T: UnsignedModularInt,
113    M: ModulusParams<Modulus = T>,
114{
115    fn verify_digest<F: Fn(&mut D) -> signature::Result<()>>(
116        &self,
117        f: F,
118        signature: &GenericSignature<T>,
119    ) -> signature::Result<()> {
120        let mut digest = D::new();
121        f(&mut digest)?;
122        self.verify_prehash_signature(&digest.finalize(), signature)
123    }
124}
125
126impl<D, T, M> PrehashVerifier<GenericSignature<T>> for GenericVerifyingKey<D, T, M>
127where
128    D: Digest + FixedOutputReset,
129    T: UnsignedModularInt,
130    M: ModulusParams<Modulus = T>,
131{
132    fn verify_prehash(
133        &self,
134        prehash: &[u8],
135        signature: &GenericSignature<T>,
136    ) -> signature::Result<()> {
137        self.verify_prehash_signature(prehash, signature)
138    }
139}
140
141impl<D, T, M> Verifier<GenericSignature<T>> for GenericVerifyingKey<D, T, M>
142where
143    D: Digest + FixedOutputReset,
144    T: UnsignedModularInt,
145    M: ModulusParams<Modulus = T>,
146{
147    fn verify(&self, msg: &[u8], signature: &GenericSignature<T>) -> signature::Result<()> {
148        self.verify_prehash_signature(&D::digest(msg), signature)
149    }
150}
151
152//
153// Other trait impls
154//
155
156impl<D, T, M> AsRef<GenericRsaPublicKey<T, M>> for GenericVerifyingKey<D, T, M>
157where
158    D: Digest,
159    T: UnsignedModularInt,
160    M: ModulusParams<Modulus = T>,
161{
162    fn as_ref(&self) -> &GenericRsaPublicKey<T, M> {
163        &self.inner
164    }
165}
166
/// Reports the PKCS#1 algorithm identifier (`pkcs1::ALGORITHM_ID`) for the
/// key, not the `ID_RSASSA_PSS` identifier.
#[cfg(all(feature = "encoding", feature = "alloc"))]
impl<D: Digest> AssociatedAlgorithmIdentifier for VerifyingKey<D> {
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = pkcs1::ALGORITHM_ID;
}
177
178// Implemented manually so we don't have to bind D with Clone
179impl<D, T, M> Clone for GenericVerifyingKey<D, T, M>
180where
181    D: Digest,
182    T: UnsignedModularInt,
183    M: ModulusParams<Modulus = T> + Clone,
184{
185    fn clone(&self) -> Self {
186        Self {
187            inner: self.inner.clone(),
188            salt_len: self.salt_len,
189            phantom: Default::default(),
190        }
191    }
192}
193
#[cfg(all(feature = "encoding", feature = "alloc"))]
impl<D: Digest> EncodePublicKey for VerifyingKey<D> {
    /// Encode the wrapped RSA public key as DER.
    ///
    /// Only the public key is encoded; the configured salt length is not.
    fn to_public_key_der(&self) -> spki::Result<Document> {
        let public_key = &self.inner;
        public_key.to_public_key_der()
    }
}
204
205impl<D, T, M> From<GenericRsaPublicKey<T, M>> for GenericVerifyingKey<D, T, M>
206where
207    D: Digest,
208    T: UnsignedModularInt,
209    M: ModulusParams<Modulus = T>,
210{
211    fn from(key: GenericRsaPublicKey<T, M>) -> Self {
212        Self::new(key)
213    }
214}
215
216impl<D, T, M> From<GenericVerifyingKey<D, T, M>> for GenericRsaPublicKey<T, M>
217where
218    D: Digest,
219    T: UnsignedModularInt,
220    M: ModulusParams<Modulus = T>,
221{
222    fn from(key: GenericVerifyingKey<D, T, M>) -> Self {
223        key.inner
224    }
225}
226
#[cfg(all(feature = "encoding", feature = "alloc"))]
impl<D> TryFrom<pkcs8::SubjectPublicKeyInfoRef<'_>> for VerifyingKey<D>
where
    D: Digest + AssociatedOid,
{
    type Error = spki::Error;

    /// Decode a verifying key from an SPKI structure.
    ///
    /// Both the `ID_RSASSA_PSS` and the PKCS#1 (`pkcs1::ALGORITHM_OID`)
    /// algorithm OIDs are accepted; anything else yields
    /// [`spki::Error::OidUnknown`]. The key is built with
    /// [`GenericVerifyingKey::new`], so it uses the default salt length.
    fn try_from(spki: pkcs8::SubjectPublicKeyInfoRef<'_>) -> spki::Result<Self> {
        if !matches!(spki.algorithm.oid, ID_RSASSA_PSS | pkcs1::ALGORITHM_OID) {
            return Err(spki::Error::OidUnknown {
                oid: spki.algorithm.oid,
            });
        }

        let public_key = RsaPublicKey::try_from(spki)?;
        Ok(Self::new(public_key))
    }
}
248
249impl<D, T, M> PartialEq for GenericVerifyingKey<D, T, M>
250where
251    D: Digest,
252    T: UnsignedModularInt + PartialEq,
253    M: ModulusParams<Modulus = T>,
254{
255    fn eq(&self, other: &Self) -> bool {
256        self.inner == other.inner && self.salt_len == other.salt_len
257    }
258}
259
#[cfg(all(feature = "serde", feature = "alloc"))]
impl<D: Digest> Serialize for VerifyingKey<D> {
    /// Serialize the DER-encoded public key via
    /// `serdect::slice::serialize_hex_lower_or_bin`.
    ///
    /// The salt length is not part of the encoding, so a key with a
    /// non-default salt length does not survive a round trip unchanged.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let encoded = self.to_public_key_der().map_err(ser::Error::custom)?;
        serdect::slice::serialize_hex_lower_or_bin(&encoded, serializer)
    }
}
274
#[cfg(all(feature = "serde", feature = "alloc"))]
impl<'de, D> Deserialize<'de> for VerifyingKey<D>
where
    D: Digest + AssociatedOid,
{
    /// Deserialize a DER-encoded public key supplied as hex or raw bytes.
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: de::Deserializer<'de>,
    {
        let encoded = serdect::slice::deserialize_hex_or_bin_vec(deserializer)?;
        Self::from_public_key_der(encoded.as_slice()).map_err(de::Error::custom)
    }
}
289
#[cfg(test)]
mod tests {
    /// Serializing a verifying key in a human-readable format yields the hex
    /// encoding of its DER public key.
    #[test]
    #[cfg(all(feature = "hazmat", feature = "serde", feature = "private-key"))]
    fn test_serde() {
        use super::*;
        use crate::RsaPrivateKey;
        use rand::rngs::ChaCha8Rng;
        use rand_core::SeedableRng;
        use serde_test::{assert_tokens, Configure, Token};
        use sha2::Sha256;

        const EXPECTED_HEX: &str =
            "3024300d06092a864886f70d01010105000313003010020900ab240c3361d02e370203010001";

        // A fixed seed keeps the (deliberately tiny, unchecked) 64-bit key,
        // and therefore the expected encoding, reproducible.
        let mut rng = ChaCha8Rng::from_seed([42; 32]);
        let private_key =
            RsaPrivateKey::new_unchecked(&mut rng, 64).expect("failed to generate key");
        let key = VerifyingKey::<Sha256>::new(private_key.to_public_key());

        assert_tokens(&key.readable(), &[Token::Str(EXPECTED_HEX)]);
    }
}