ml_dsa/
pkcs8.rs

//! PKCS#8 private key and SPKI (`SubjectPublicKeyInfo`) public key encoding/decoding support.
2
3#![cfg(feature = "pkcs8")]
4
5use crate::{
6    EncodedVerifyingKey, KeyGen, KeyPair, MlDsa44, MlDsa65, MlDsa87, MlDsaParams, Signature,
7    SigningKey, VerifyingKey,
8};
9use ::pkcs8::{
10    AlgorithmIdentifierRef, PrivateKeyInfoRef,
11    der::{
12        self, AnyRef, Reader, TagNumber,
13        asn1::{ContextSpecific, OctetStringRef},
14    },
15    spki::{
16        self, AlgorithmIdentifier, AssociatedAlgorithmIdentifier, SignatureAlgorithmIdentifier,
17        SubjectPublicKeyInfoRef,
18    },
19};
20use const_oid::db::fips204;
21
22#[cfg(feature = "alloc")]
23use pkcs8::{
24    EncodePrivateKey, EncodePublicKey,
25    der::{
26        Encode, TagMode,
27        asn1::{BitString, BitStringRef},
28    },
29    spki::{SignatureBitStringEncoding, SubjectPublicKeyInfo},
30};
31
32/// Tag number for the seed value.
33const SEED_TAG_NUMBER: TagNumber = TagNumber(0);
34
35/// ML-KEM seed serialized as ASN.1.
36type SeedString<'a> = ContextSpecific<&'a OctetStringRef>;
37
38impl AssociatedAlgorithmIdentifier for MlDsa44 {
39    type Params = AnyRef<'static>;
40
41    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
42        oid: fips204::ID_ML_DSA_44,
43        parameters: None,
44    };
45}
46
47impl AssociatedAlgorithmIdentifier for MlDsa65 {
48    type Params = AnyRef<'static>;
49
50    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
51        oid: fips204::ID_ML_DSA_65,
52        parameters: None,
53    };
54}
55
56impl AssociatedAlgorithmIdentifier for MlDsa87 {
57    type Params = AnyRef<'static>;
58
59    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
60        oid: fips204::ID_ML_DSA_87,
61        parameters: None,
62    };
63}
64
65impl<P> AssociatedAlgorithmIdentifier for Signature<P>
66where
67    P: MlDsaParams,
68    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
69{
70    type Params = AnyRef<'static>;
71
72    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = P::ALGORITHM_IDENTIFIER;
73}
74
#[cfg(feature = "alloc")]
/// Wraps the encoded signature bytes in a BIT STRING with zero unused bits.
impl<P: MlDsaParams> SignatureBitStringEncoding for Signature<P> {
    fn to_bitstring(&self) -> der::Result<BitString> {
        let signature_bytes = self.encode();
        BitString::new(0, signature_bytes.to_vec())
    }
}
81
82impl<P> SignatureAlgorithmIdentifier for KeyPair<P>
83where
84    P: MlDsaParams,
85    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
86{
87    type Params = AnyRef<'static>;
88
89    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
90        Signature::<P>::ALGORITHM_IDENTIFIER;
91}
92
93impl<P> TryFrom<PrivateKeyInfoRef<'_>> for KeyPair<P>
94where
95    P: MlDsaParams,
96    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
97{
98    type Error = ::pkcs8::Error;
99
100    fn try_from(private_key_info: PrivateKeyInfoRef<'_>) -> ::pkcs8::Result<Self> {
101        private_key_info
102            .algorithm
103            .assert_algorithm_oid(P::ALGORITHM_IDENTIFIER.oid)?;
104
105        let mut reader = der::SliceReader::new(private_key_info.private_key.as_bytes())?;
106        let seed_string = SeedString::decode_implicit(&mut reader, SEED_TAG_NUMBER)?
107            .ok_or(pkcs8::Error::KeyMalformed)?;
108        let seed = seed_string
109            .value
110            .as_bytes()
111            .try_into()
112            .map_err(|_| pkcs8::Error::KeyMalformed)?;
113        reader.finish()?;
114
115        Ok(P::from_seed(&seed))
116    }
117}
118
/// Encodes the key pair as a PKCS#8 `PrivateKeyInfo` DER document.
///
/// Only the seed is serialized: it is encoded as a context-specific `[0]`
/// implicitly tagged OCTET STRING, and that encoding becomes the `privateKey`
/// OCTET STRING of a `PrivateKeyInfo` built from `P`'s algorithm identifier.
///
/// # Errors
///
/// Returns `pkcs8::Error::Asn1` (or a converted DER error) if any of the
/// intermediate DER encodings fail.
#[cfg(feature = "alloc")]
impl<P> EncodePrivateKey for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    fn to_pkcs8_der(&self) -> ::pkcs8::Result<der::SecretDocument> {
        // The DER encoding of the seed is secret key material. Keep it in a
        // `Zeroizing` buffer so it is wiped on drop instead of being left
        // behind in freed heap memory.
        let seed_der = der::zeroize::Zeroizing::new(
            SeedString {
                tag_mode: TagMode::Implicit,
                tag_number: SEED_TAG_NUMBER,
                value: OctetStringRef::new(&self.seed)?,
            }
            .to_der()?,
        );

        let private_key = OctetStringRef::new(&seed_der)?;
        let private_key_info = PrivateKeyInfoRef::new(P::ALGORITHM_IDENTIFIER, private_key);
        ::pkcs8::SecretDocument::encode_msg(&private_key_info).map_err(::pkcs8::Error::Asn1)
    }
}
138
139impl<P> SignatureAlgorithmIdentifier for SigningKey<P>
140where
141    P: MlDsaParams,
142    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
143{
144    type Params = AnyRef<'static>;
145
146    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
147        Signature::<P>::ALGORITHM_IDENTIFIER;
148}
149
150impl<P> TryFrom<PrivateKeyInfoRef<'_>> for SigningKey<P>
151where
152    P: MlDsaParams,
153    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
154{
155    type Error = ::pkcs8::Error;
156
157    fn try_from(private_key_info: ::pkcs8::PrivateKeyInfoRef<'_>) -> ::pkcs8::Result<Self> {
158        let keypair = KeyPair::try_from(private_key_info)?;
159
160        Ok(keypair.signing_key)
161    }
162}
163
164impl<P> SignatureAlgorithmIdentifier for VerifyingKey<P>
165where
166    P: MlDsaParams,
167    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
168{
169    type Params = AnyRef<'static>;
170
171    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
172        Signature::<P>::ALGORITHM_IDENTIFIER;
173}
174
#[cfg(feature = "alloc")]
/// Encodes the verifying key as a DER `SubjectPublicKeyInfo`. The algorithm is
/// `P`'s identifier, and the encoded key bytes form the `subjectPublicKey` BIT
/// STRING with zero unused bits.
impl<P> EncodePublicKey for VerifyingKey<P>
where
    P: MlDsaParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    fn to_public_key_der(&self) -> spki::Result<der::Document> {
        let encoded_key = self.encode();

        let public_key_info = SubjectPublicKeyInfo {
            algorithm: P::ALGORITHM_IDENTIFIER,
            subject_public_key: BitStringRef::new(0, &encoded_key)?,
        };

        der::Document::try_from(public_key_info)
    }
}
192
193impl<P> TryFrom<SubjectPublicKeyInfoRef<'_>> for VerifyingKey<P>
194where
195    P: MlDsaParams,
196    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
197{
198    type Error = spki::Error;
199
200    fn try_from(spki: SubjectPublicKeyInfoRef<'_>) -> spki::Result<Self> {
201        spki.algorithm
202            .assert_algorithm_oid(P::ALGORITHM_IDENTIFIER.oid)?;
203
204        Ok(Self::decode(
205            &EncodedVerifyingKey::<P>::try_from(
206                spki.subject_public_key
207                    .as_bytes()
208                    .ok_or_else(|| der::Tag::BitString.value_error().to_error())?,
209            )
210            .map_err(|_| ::pkcs8::Error::KeyMalformed)?,
211        ))
212    }
213}