1#![cfg(feature = "pkcs8")]
11
12pub use ::pkcs8::{DecodePrivateKey, DecodePublicKey, spki::AssociatedAlgorithmIdentifier};
13pub use const_oid::AssociatedOid;
14
15#[cfg(feature = "alloc")]
16pub use ::pkcs8::{EncodePrivateKey, EncodePublicKey};
17
18use crate::{
19 DecapsulationKey, EncapsulationKey, MlKem512, MlKem768, MlKem1024,
20 param::{EncapsulationKeySize, KemParams},
21 pke::EncryptionKey,
22};
23use ::pkcs8::{
24 der::{
25 AnyRef, Reader, SliceReader, TagNumber,
26 asn1::{ContextSpecific, OctetStringRef},
27 },
28 spki,
29};
30use array::Array;
31
32#[cfg(feature = "alloc")]
33use {
34 ::kem::KeyExport,
35 ::pkcs8::der::{Encode, TagMode, asn1::BitStringRef},
36};
37
38const SEED_TAG_NUMBER: TagNumber = TagNumber(0);
40
41type SeedString<'a> = ContextSpecific<&'a OctetStringRef>;
43
44impl AssociatedOid for MlKem512 {
45 const OID: ::pkcs8::ObjectIdentifier = const_oid::db::fips203::ID_ALG_ML_KEM_512;
46}
47
48impl AssociatedOid for MlKem768 {
49 const OID: ::pkcs8::ObjectIdentifier = const_oid::db::fips203::ID_ALG_ML_KEM_768;
50}
51
52impl AssociatedOid for MlKem1024 {
53 const OID: ::pkcs8::ObjectIdentifier = const_oid::db::fips203::ID_ALG_ML_KEM_1024;
54}
55
56impl AssociatedAlgorithmIdentifier for MlKem512 {
57 type Params = ::pkcs8::der::AnyRef<'static>;
58
59 const ALGORITHM_IDENTIFIER: spki::AlgorithmIdentifier<Self::Params> =
60 spki::AlgorithmIdentifier {
61 oid: Self::OID,
62 parameters: None,
63 };
64}
65
66impl AssociatedAlgorithmIdentifier for MlKem768 {
67 type Params = ::pkcs8::der::AnyRef<'static>;
68
69 const ALGORITHM_IDENTIFIER: spki::AlgorithmIdentifier<Self::Params> =
70 spki::AlgorithmIdentifier {
71 oid: Self::OID,
72 parameters: None,
73 };
74}
75
76impl AssociatedAlgorithmIdentifier for MlKem1024 {
77 type Params = ::pkcs8::der::AnyRef<'static>;
78
79 const ALGORITHM_IDENTIFIER: spki::AlgorithmIdentifier<Self::Params> =
80 spki::AlgorithmIdentifier {
81 oid: Self::OID,
82 parameters: None,
83 };
84}
85
86impl<P> AssociatedAlgorithmIdentifier for EncapsulationKey<P>
87where
88 P: KemParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
89{
90 type Params = P::Params;
91
92 const ALGORITHM_IDENTIFIER: spki::AlgorithmIdentifier<Self::Params> = P::ALGORITHM_IDENTIFIER;
93}
94
#[cfg(feature = "alloc")]
impl<P> EncodePublicKey for EncapsulationKey<P>
where
    P: KemParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    /// Serialize this encapsulation key as a DER-encoded `SubjectPublicKeyInfo`.
    ///
    /// The algorithm is `P::ALGORITHM_IDENTIFIER`. The `subjectPublicKey` BIT STRING holds
    /// the bytes returned by `to_bytes`, with zero unused bits.
    ///
    /// # Errors
    ///
    /// Returns an error if the BIT STRING cannot be constructed or the structure fails to
    /// encode as DER.
    fn to_public_key_der(&self) -> spki::Result<pkcs8::Document> {
        let encoded_key = self.to_bytes();

        let info = ::pkcs8::SubjectPublicKeyInfo {
            algorithm: P::ALGORITHM_IDENTIFIER,
            // The key is a whole number of octets, so no bits are unused.
            subject_public_key: BitStringRef::new(0, &encoded_key)?,
        };

        pkcs8::Document::try_from(info)
    }
}
113
114impl<P> TryFrom<::pkcs8::SubjectPublicKeyInfoRef<'_>> for EncapsulationKey<P>
115where
116 P: KemParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
117{
118 type Error = spki::Error;
119
120 fn try_from(spki: ::pkcs8::SubjectPublicKeyInfoRef<'_>) -> Result<Self, spki::Error> {
123 if spki.algorithm.oid != P::ALGORITHM_IDENTIFIER.oid {
124 return Err(spki::Error::OidUnknown {
125 oid: P::ALGORITHM_IDENTIFIER.oid,
126 });
127 }
128
129 let bitstring_of_encapsulation_key = spki.subject_public_key;
130 let enc_key = match bitstring_of_encapsulation_key.as_bytes() {
131 Some(bytes) => {
132 let arr: Array<u8, EncapsulationKeySize<P>> = match bytes.try_into() {
133 Ok(array) => array,
134 Err(_) => return Err(spki::Error::KeyMalformed),
135 };
136 EncryptionKey::from_bytes(&arr).map_err(|_| spki::Error::KeyMalformed)?
137 }
138 None => return Err(spki::Error::KeyMalformed),
139 };
140
141 Ok(Self::from_encryption_key(enc_key))
142 }
143}
144
145impl<P> AssociatedAlgorithmIdentifier for DecapsulationKey<P>
146where
147 P: KemParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
148{
149 type Params = P::Params;
150
151 const ALGORITHM_IDENTIFIER: spki::AlgorithmIdentifier<Self::Params> = P::ALGORITHM_IDENTIFIER;
152}
153
#[cfg(feature = "alloc")]
impl<P> EncodePrivateKey for DecapsulationKey<P>
where
    P: KemParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    /// Serialize this decapsulation key as a DER-encoded PKCS#8 `PrivateKeyInfo`.
    ///
    /// Only the seed form is produced. The `privateKey` OCTET STRING contains the DER
    /// encoding of `[0] IMPLICIT OCTET STRING`, which holds the seed returned by `to_seed`.
    ///
    /// NOTE(review): `seed` and the DER buffer returned by `to_der` hold secret key material.
    /// Neither is explicitly zeroized in this function. Confirm whether that is acceptable.
    ///
    /// # Errors
    ///
    /// - `KeyError::Invalid` if `to_seed` returns `None`.
    /// - An ASN.1 error if any encoding step fails.
    fn to_pkcs8_der(&self) -> ::pkcs8::Result<pkcs8::SecretDocument> {
        let seed = self.to_seed().ok_or(pkcs8::KeyError::Invalid)?;
        let seed_octets = OctetStringRef::new(&seed)?;

        // Implicit tagging replaces the OCTET STRING tag with context-specific `[0]`.
        let tagged_seed = SeedString {
            tag_mode: TagMode::Implicit,
            tag_number: SEED_TAG_NUMBER,
            value: seed_octets,
        };
        let seed_der = tagged_seed.to_der()?;

        let private_key_info = pkcs8::PrivateKeyInfoRef::new(
            P::ALGORITHM_IDENTIFIER,
            OctetStringRef::new(&seed_der)?,
        );

        pkcs8::SecretDocument::encode_msg(&private_key_info).map_err(pkcs8::Error::Asn1)
    }
}
176
177impl<P> TryFrom<::pkcs8::PrivateKeyInfoRef<'_>> for DecapsulationKey<P>
178where
179 P: KemParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
180{
181 type Error = ::pkcs8::Error;
182
183 fn try_from(private_key_info_ref: ::pkcs8::PrivateKeyInfoRef<'_>) -> Result<Self, Self::Error> {
186 private_key_info_ref
187 .algorithm
188 .assert_algorithm_oid(P::ALGORITHM_IDENTIFIER.oid)?;
189
190 let mut reader = SliceReader::new(private_key_info_ref.private_key.as_bytes())?;
191 let seed_string = SeedString::decode_implicit(&mut reader, SEED_TAG_NUMBER)?
192 .ok_or(pkcs8::KeyError::Invalid)?;
193 let seed = seed_string
194 .value
195 .as_bytes()
196 .try_into()
197 .map_err(|_| pkcs8::KeyError::Invalid)?; reader.finish()?;
199
200 Ok(Self::from_seed(seed))
201 }
202}