1use crate::{AlgorithmIdentifierRef, Error, Result};
4use der::{
5 asn1::{AnyRef, ObjectIdentifier, OctetStringRef},
6 Decode, DecodeValue, Encode, EncodeValue, ErrorKind, Length, Reader, Sequence, Tag, Tagged,
7 Writer,
8};
9
10pub const PBKDF2_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.5.12");
12
13pub const HMAC_WITH_SHA1_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.2.7");
15
16pub const HMAC_WITH_SHA224_OID: ObjectIdentifier =
18 ObjectIdentifier::new_unwrap("1.2.840.113549.2.8");
19
20pub const HMAC_WITH_SHA256_OID: ObjectIdentifier =
22 ObjectIdentifier::new_unwrap("1.2.840.113549.2.9");
23
24pub const HMAC_WITH_SHA384_OID: ObjectIdentifier =
26 ObjectIdentifier::new_unwrap("1.2.840.113549.2.10");
27
28pub const HMAC_WITH_SHA512_OID: ObjectIdentifier =
30 ObjectIdentifier::new_unwrap("1.2.840.113549.2.11");
31
32pub const SCRYPT_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.3.6.1.4.1.11591.4.11");
36
37type ScryptCost = u64;
39
/// Key derivation function selected by a KDF `AlgorithmIdentifier`.
///
/// Only PBKDF2 and scrypt are representable. The enum is `#[non_exhaustive]`,
/// so more KDFs may be added without a breaking change.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Kdf<'a> {
    /// PBKDF2 and its parameters (identified by [`PBKDF2_OID`]).
    Pbkdf2(Pbkdf2Params<'a>),

    /// scrypt and its parameters (identified by [`SCRYPT_OID`]).
    Scrypt(ScryptParams<'a>),
}
50
51impl<'a> Kdf<'a> {
52 pub fn key_length(&self) -> Option<u16> {
55 match self {
56 Self::Pbkdf2(params) => params.key_length,
57 Self::Scrypt(params) => params.key_length,
58 }
59 }
60
61 pub fn oid(&self) -> ObjectIdentifier {
63 match self {
64 Self::Pbkdf2(_) => PBKDF2_OID,
65 Self::Scrypt(_) => SCRYPT_OID,
66 }
67 }
68
69 pub fn pbkdf2(&self) -> Option<&Pbkdf2Params<'a>> {
71 match self {
72 Self::Pbkdf2(params) => Some(params),
73 _ => None,
74 }
75 }
76
77 pub fn scrypt(&self) -> Option<&ScryptParams<'a>> {
79 match self {
80 Self::Scrypt(params) => Some(params),
81 _ => None,
82 }
83 }
84
85 pub fn is_pbkdf2(&self) -> bool {
87 self.pbkdf2().is_some()
88 }
89
90 pub fn is_scrypt(&self) -> bool {
92 self.scrypt().is_some()
93 }
94
95 pub fn to_alg_params_invalid(&self) -> Error {
98 Error::AlgorithmParametersInvalid { oid: self.oid() }
99 }
100}
101
102impl<'a> DecodeValue<'a> for Kdf<'a> {
103 fn decode_value<R: Reader<'a>>(reader: &mut R, header: der::Header) -> der::Result<Self> {
104 AlgorithmIdentifierRef::decode_value(reader, header)?.try_into()
105 }
106}
107
108impl EncodeValue for Kdf<'_> {
109 fn value_len(&self) -> der::Result<Length> {
110 self.oid().encoded_len()?
111 + match self {
112 Self::Pbkdf2(params) => params.encoded_len()?,
113 Self::Scrypt(params) => params.encoded_len()?,
114 }
115 }
116
117 fn encode_value(&self, writer: &mut impl Writer) -> der::Result<()> {
118 self.oid().encode(writer)?;
119
120 match self {
121 Self::Pbkdf2(params) => params.encode(writer)?,
122 Self::Scrypt(params) => params.encode(writer)?,
123 }
124
125 Ok(())
126 }
127}
128
129impl<'a> Sequence<'a> for Kdf<'a> {}
130
131impl<'a> From<Pbkdf2Params<'a>> for Kdf<'a> {
132 fn from(params: Pbkdf2Params<'a>) -> Self {
133 Kdf::Pbkdf2(params)
134 }
135}
136
137impl<'a> From<ScryptParams<'a>> for Kdf<'a> {
138 fn from(params: ScryptParams<'a>) -> Self {
139 Kdf::Scrypt(params)
140 }
141}
142
143impl<'a> TryFrom<AlgorithmIdentifierRef<'a>> for Kdf<'a> {
144 type Error = der::Error;
145
146 fn try_from(alg: AlgorithmIdentifierRef<'a>) -> der::Result<Self> {
147 if let Some(params) = alg.parameters {
148 match alg.oid {
149 PBKDF2_OID => params.try_into().map(Self::Pbkdf2),
150 SCRYPT_OID => params.try_into().map(Self::Scrypt),
151 oid => Err(ErrorKind::OidUnknown { oid }.into()),
152 }
153 } else {
154 Err(Tag::OctetString.value_error())
155 }
156 }
157}
158
/// PBKDF2 parameters.
///
/// Encoded as a DER `SEQUENCE` containing, in order:
/// - the salt as an `OCTET STRING`;
/// - the iteration count;
/// - the optional key length;
/// - the PRF `AlgorithmIdentifier`, omitted when it equals [`Pbkdf2Prf::default`].
///
/// NOTE(review): presumably corresponds to `PBKDF2-params` in RFC 8018, restricted
/// to the OCTET STRING salt form. Confirm against the RFC if relying on the exact schema.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Pbkdf2Params<'a> {
    /// PBKDF2 salt, encoded/decoded as an `OCTET STRING`.
    pub salt: &'a [u8],

    /// PBKDF2 iteration count.
    ///
    /// Only [`Pbkdf2Params::hmac_with_sha256`] checks this against
    /// [`Pbkdf2Params::MAX_ITERATION_COUNT`]. Decoding and direct struct
    /// construction do not.
    pub iteration_count: u32,

    /// Optional key length; nothing is encoded for it when `None`.
    pub key_length: Option<u16>,

    /// Pseudo-random function used by PBKDF2; left out of the encoding when it
    /// is the default PRF.
    pub prf: Pbkdf2Prf,
}
190
191impl<'a> Pbkdf2Params<'a> {
192 pub const MAX_ITERATION_COUNT: u32 = 100_000_000;
202
203 const INVALID_ERR: Error = Error::AlgorithmParametersInvalid { oid: PBKDF2_OID };
204
205 pub fn hmac_with_sha256(iteration_count: u32, salt: &'a [u8]) -> Result<Self> {
207 if iteration_count > Self::MAX_ITERATION_COUNT {
208 return Err(Self::INVALID_ERR);
209 }
210 Ok(Self {
211 salt,
212 iteration_count,
213 key_length: None,
214 prf: Pbkdf2Prf::HmacWithSha256,
215 })
216 }
217}
218
219impl<'a> DecodeValue<'a> for Pbkdf2Params<'a> {
220 fn decode_value<R: Reader<'a>>(reader: &mut R, header: der::Header) -> der::Result<Self> {
221 AnyRef::decode_value(reader, header)?.try_into()
222 }
223}
224
225impl EncodeValue for Pbkdf2Params<'_> {
226 fn value_len(&self) -> der::Result<Length> {
227 let len = OctetStringRef::new(self.salt)?.encoded_len()?
228 + self.iteration_count.encoded_len()?
229 + self.key_length.encoded_len()?;
230
231 if self.prf == Pbkdf2Prf::default() {
232 len
233 } else {
234 len + self.prf.encoded_len()?
235 }
236 }
237
238 fn encode_value(&self, writer: &mut impl Writer) -> der::Result<()> {
239 OctetStringRef::new(self.salt)?.encode(writer)?;
240 self.iteration_count.encode(writer)?;
241 self.key_length.encode(writer)?;
242
243 if self.prf == Pbkdf2Prf::default() {
244 Ok(())
245 } else {
246 self.prf.encode(writer)
247 }
248 }
249}
250
251impl<'a> Sequence<'a> for Pbkdf2Params<'a> {}
252
253impl<'a> TryFrom<AnyRef<'a>> for Pbkdf2Params<'a> {
254 type Error = der::Error;
255
256 fn try_from(any: AnyRef<'a>) -> der::Result<Self> {
257 any.sequence(|reader| {
258 Ok(Self {
260 salt: OctetStringRef::decode(reader)?.as_bytes(),
261 iteration_count: reader.decode()?,
262 key_length: reader.decode()?,
263 prf: Option::<AlgorithmIdentifierRef<'_>>::decode(reader)?
264 .map(TryInto::try_into)
265 .transpose()?
266 .unwrap_or_default(),
267 })
268 })
269 }
270}
271
/// Pseudo-random function used by PBKDF2.
///
/// Each variant maps to one `HMAC_WITH_*_OID` constant via [`Pbkdf2Prf::oid`].
/// The default ([`Pbkdf2Prf::default`]) is [`Pbkdf2Prf::HmacWithSha1`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Pbkdf2Prf {
    /// HMAC with SHA-1 ([`HMAC_WITH_SHA1_OID`]).
    HmacWithSha1,

    /// HMAC with SHA-224 ([`HMAC_WITH_SHA224_OID`]).
    HmacWithSha224,

    /// HMAC with SHA-256 ([`HMAC_WITH_SHA256_OID`]).
    HmacWithSha256,

    /// HMAC with SHA-384 ([`HMAC_WITH_SHA384_OID`]).
    HmacWithSha384,

    /// HMAC with SHA-512 ([`HMAC_WITH_SHA512_OID`]).
    HmacWithSha512,
}
291
292impl Pbkdf2Prf {
293 pub fn oid(self) -> ObjectIdentifier {
295 match self {
296 Self::HmacWithSha1 => HMAC_WITH_SHA1_OID,
297 Self::HmacWithSha224 => HMAC_WITH_SHA224_OID,
298 Self::HmacWithSha256 => HMAC_WITH_SHA256_OID,
299 Self::HmacWithSha384 => HMAC_WITH_SHA384_OID,
300 Self::HmacWithSha512 => HMAC_WITH_SHA512_OID,
301 }
302 }
303}
304
305impl Default for Pbkdf2Prf {
314 fn default() -> Self {
315 Self::HmacWithSha1
316 }
317}
318
319impl<'a> TryFrom<AlgorithmIdentifierRef<'a>> for Pbkdf2Prf {
320 type Error = der::Error;
321
322 fn try_from(alg: AlgorithmIdentifierRef<'a>) -> der::Result<Self> {
323 if let Some(params) = alg.parameters {
324 if !params.is_null() {
326 return Err(params.tag().value_error());
327 }
328 } else {
329 return Err(Tag::Null.value_error());
331 }
332
333 match alg.oid {
334 HMAC_WITH_SHA1_OID => Ok(Self::HmacWithSha1),
335 HMAC_WITH_SHA224_OID => Ok(Self::HmacWithSha224),
336 HMAC_WITH_SHA256_OID => Ok(Self::HmacWithSha256),
337 HMAC_WITH_SHA384_OID => Ok(Self::HmacWithSha384),
338 HMAC_WITH_SHA512_OID => Ok(Self::HmacWithSha512),
339 oid => Err(ErrorKind::OidUnknown { oid }.into()),
340 }
341 }
342}
343
344impl<'a> From<Pbkdf2Prf> for AlgorithmIdentifierRef<'a> {
345 fn from(prf: Pbkdf2Prf) -> Self {
346 let parameters = der::asn1::Null;
348
349 AlgorithmIdentifierRef {
350 oid: prf.oid(),
351 parameters: Some(parameters.into()),
352 }
353 }
354}
355
356impl Encode for Pbkdf2Prf {
357 fn encoded_len(&self) -> der::Result<Length> {
358 AlgorithmIdentifierRef::try_from(*self)?.encoded_len()
359 }
360
361 fn encode(&self, writer: &mut impl Writer) -> der::Result<()> {
362 AlgorithmIdentifierRef::try_from(*self)?.encode(writer)
363 }
364}
365
/// scrypt parameters.
///
/// Encoded as a DER `SEQUENCE` of: salt (`OCTET STRING`), cost parameter (N),
/// block size (r), parallelization (p), and an optional key length.
/// NOTE(review): this ordering presumably mirrors `scrypt-params` from
/// RFC 7914 §7. Confirm against the RFC.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct ScryptParams<'a> {
    /// scrypt salt, encoded/decoded as an `OCTET STRING`.
    pub salt: &'a [u8],

    /// Cost parameter (N).
    ///
    /// Conversion to `scrypt::Params` requires this to be a power of two.
    pub cost_parameter: ScryptCost,

    /// Block size (r); corresponds to `scrypt::Params::r()`.
    pub block_size: u16,

    /// Parallelization (p); corresponds to `scrypt::Params::p()`.
    pub parallelization: u16,

    /// Optional key length; nothing is encoded for it when `None`.
    pub key_length: Option<u16>,
}
396
397impl<'a> ScryptParams<'a> {
398 #[cfg(feature = "pbes2")]
399 const INVALID_ERR: Error = Error::AlgorithmParametersInvalid { oid: SCRYPT_OID };
400
401 #[cfg(feature = "pbes2")]
405 pub fn from_params_and_salt(params: scrypt::Params, salt: &'a [u8]) -> Result<Self> {
406 Ok(Self {
407 salt,
408 cost_parameter: 1 << params.log_n(),
409 block_size: params.r().try_into().map_err(|_| Self::INVALID_ERR)?,
410 parallelization: params.p().try_into().map_err(|_| Self::INVALID_ERR)?,
411 key_length: None,
412 })
413 }
414}
415
416impl<'a> DecodeValue<'a> for ScryptParams<'a> {
417 fn decode_value<R: Reader<'a>>(reader: &mut R, header: der::Header) -> der::Result<Self> {
418 AnyRef::decode_value(reader, header)?.try_into()
419 }
420}
421
422impl EncodeValue for ScryptParams<'_> {
423 fn value_len(&self) -> der::Result<Length> {
424 OctetStringRef::new(self.salt)?.encoded_len()?
425 + self.cost_parameter.encoded_len()?
426 + self.block_size.encoded_len()?
427 + self.parallelization.encoded_len()?
428 + self.key_length.encoded_len()?
429 }
430
431 fn encode_value(&self, writer: &mut impl Writer) -> der::Result<()> {
432 OctetStringRef::new(self.salt)?.encode(writer)?;
433 self.cost_parameter.encode(writer)?;
434 self.block_size.encode(writer)?;
435 self.parallelization.encode(writer)?;
436 self.key_length.encode(writer)?;
437 Ok(())
438 }
439}
440
441impl<'a> Sequence<'a> for ScryptParams<'a> {}
442
443impl<'a> TryFrom<AnyRef<'a>> for ScryptParams<'a> {
444 type Error = der::Error;
445
446 fn try_from(any: AnyRef<'a>) -> der::Result<Self> {
447 any.sequence(|reader| {
448 Ok(Self {
449 salt: OctetStringRef::decode(reader)?.as_bytes(),
450 cost_parameter: reader.decode()?,
451 block_size: reader.decode()?,
452 parallelization: reader.decode()?,
453 key_length: reader.decode()?,
454 })
455 })
456 }
457}
458
#[cfg(feature = "pbes2")]
impl<'a> TryFrom<ScryptParams<'a>> for scrypt::Params {
    type Error = Error;

    /// Convert owned [`ScryptParams`] by delegating to the by-reference
    /// conversion, which performs all validation.
    fn try_from(params: ScryptParams<'a>) -> Result<scrypt::Params> {
        scrypt::Params::try_from(&params)
    }
}
467
#[cfg(feature = "pbes2")]
impl<'a> TryFrom<&ScryptParams<'a>> for scrypt::Params {
    type Error = Error;

    /// Convert [`ScryptParams`] into `scrypt::Params`.
    ///
    /// The cost parameter N must be a power of two. `log2(N)`, `block_size` (r)
    /// and `parallelization` (p) are passed to `scrypt::Params::new`. The output
    /// length given to `scrypt::Params::new` is always
    /// `scrypt::Params::RECOMMENDED_LEN`; `key_length` is not consulted here.
    ///
    /// # Errors
    ///
    /// Returns `Error::AlgorithmParametersInvalid` (scrypt OID) when N is zero
    /// or not a power of two, or when `scrypt::Params::new` rejects the values.
    fn try_from(params: &ScryptParams<'a>) -> Result<scrypt::Params> {
        let n = params.cost_parameter;

        // N may come straight from decoded, untrusted DER. Reject 0 and
        // non-powers-of-two up front. A `bits - leading_zeros - 1` log2 would
        // underflow for N == 0, which panics in debug builds.
        if !n.is_power_of_two() {
            return Err(ScryptParams::INVALID_ERR);
        }

        // For a non-zero power of two, log2 equals the trailing-zero count,
        // which is at most 63 for a u64 and therefore fits in a u8.
        let log_n = n.trailing_zeros() as u8;

        scrypt::Params::new(
            log_n,
            params.block_size.into(),
            params.parallelization.into(),
            scrypt::Params::RECOMMENDED_LEN,
        )
        .map_err(|_| ScryptParams::INVALID_ERR)
    }
}
490}