1#![deny(
40 bad_style,
41 dead_code,
42 improper_ctypes,
43 rustdoc::broken_intra_doc_links,
44 non_shorthand_field_patterns,
45 no_mangle_generic_items,
46 overflowing_literals,
47 path_statements,
48 patterns_in_fns_without_body,
49 private_bounds,
50 private_interfaces,
51 unconditional_recursion,
52 unused,
53 unused_allocation,
54 unused_comparisons,
55 unused_parens,
56 while_true,
57 missing_debug_implementations,
58 missing_copy_implementations,
59 missing_docs,
60 trivial_casts,
61 trivial_numeric_casts,
62 unnameable_types,
63 unused_extern_crates,
64 unused_import_braces,
65 unused_qualifications,
66 unused_results
67)]
68#![no_std]
69
70use generic_array::{ArrayLength, GenericArray};
71use typenum::ToInt;
72
73use zeroize::Zeroize;
74
/// Key material that can be handed to a [`PseudoRandomFunction`].
///
/// The key is only reachable through [`key_handle`](Self::key_handle), which yields an
/// implementation-defined handle that the PRF implementation knows how to use.
pub trait PseudoRandomFunctionKey {
    /// Handle type through which a PRF implementation accesses the key.
    type KeyHandle;

    /// Returns the handle that identifies this key to a PRF implementation.
    fn key_handle(&self) -> &Self::KeyHandle;
}
83
84pub trait PseudoRandomFunction<'a> {
89 type KeyHandle;
91 type PrfOutputSize: ArrayLength<u8> + ToInt<usize>;
93 type Error;
95
96 fn init(
110 &mut self,
111 key: &'a dyn PseudoRandomFunctionKey<KeyHandle = Self::KeyHandle>,
112 ) -> Result<(), Self::Error>;
113
114 fn update(&mut self, msg: &[u8]) -> Result<(), Self::Error>;
129
130 fn finish(&mut self, out: &mut [u8]) -> Result<usize, Self::Error>;
145}
146
/// Parameters for KBKDF counter mode.
#[derive(Clone, Copy, Debug)]
pub struct CounterMode {
    /// Width of the block counter in bits.
    ///
    /// The counter is written as the low `counter_length / 8` bytes of its big-endian
    /// representation, so this should be a multiple of 8 no larger than the bit width of
    /// `usize`.
    pub counter_length: usize,
}
153
/// Parameters for KBKDF feedback mode.
#[derive(Clone, Copy, Debug)]
pub struct FeedbackMode<'a> {
    /// Optional initial chaining value used as the feedback input of the first block.
    ///
    /// When present it must be exactly `PrfOutputSize` bytes long. When `None`, the first
    /// PRF invocation has no chaining value.
    pub iv: Option<&'a [u8]>,
    /// Width of the optional block counter in bits, encoded big-endian in
    /// `counter_length / 8` bytes.
    ///
    /// `None` omits the counter. Counter locations other than
    /// [`CounterLocation::NoCounter`] require a counter.
    pub counter_length: Option<usize>,
}
162
/// Parameters for KBKDF double-pipeline iteration mode.
///
/// A first PRF chain produces chaining values; each output block is computed by a second PRF
/// invocation over the current chaining value, the optional counter and the fixed input.
#[derive(Clone, Copy, Debug)]
pub struct DoublePipelineIterationMode {
    /// Width of the optional block counter in bits, encoded big-endian in
    /// `counter_length / 8` bytes.
    ///
    /// `None` omits the counter. Counter locations other than
    /// [`CounterLocation::NoCounter`] require a counter.
    pub counter_length: Option<usize>,
}
169
170#[derive(Copy, Clone, Debug)]
172pub enum KDFMode<'a> {
173 CounterMode(CounterMode),
175 FeedbackMode(FeedbackMode<'a>),
177 DoublePipelineIterationMode(DoublePipelineIterationMode),
179}
180
/// Position of the block counter within the PRF input when [`FixedInput`] is used.
///
/// Not every location is valid for every mode; an unsupported location makes the KDF panic.
#[derive(Clone, Copy, Debug)]
pub enum CounterLocation {
    /// No counter is included. Accepted by all modes.
    NoCounter,
    /// The counter precedes the fixed input. Counter mode only.
    BeforeFixedInput,
    /// The counter precedes the chaining value. Feedback and double-pipeline modes only.
    BeforeIter,
    /// The counter is inserted into the fixed input at the given byte offset. Counter mode only.
    MiddleOfFixedInput(usize),
    /// The counter follows the fixed input. Accepted by all modes.
    AfterFixedInput,
    /// The counter sits between the chaining value and the fixed input. Feedback and
    /// double-pipeline modes only.
    AfterIter,
}
197
198#[derive(Debug)]
200pub struct FixedInput<'a> {
201 pub fixed_input: &'a [u8],
203 pub counter_location: CounterLocation,
205}
206
/// Fixed input assembled by the KDF as `Label || 0x00 || Context || [L]_2`.
///
/// `[L]_2` is the derived key length in bits encoded as a 32-bit big-endian integer.
#[derive(Debug)]
pub struct SpecifiedInput<'a> {
    /// Label identifying the purpose of the derived key material.
    pub label: &'a [u8],
    /// Context information bound into the derivation.
    pub context: &'a [u8],
}
215
216#[derive(Debug)]
218pub enum InputType<'a> {
219 FixedInput(FixedInput<'a>),
222 SpecifiedInput(SpecifiedInput<'a>),
224}
225
226pub fn kbkdf<'a, T: PseudoRandomFunction<'a>>(
241 kdf_mode: &KDFMode,
242 input_type: &InputType,
243 key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
244 prf: &mut T,
245 derived_key: &mut [u8],
246) -> Result<(), T::Error> {
247 match kdf_mode {
248 KDFMode::CounterMode(counter_mode) => {
249 kbkdf_counter::<T>(counter_mode, input_type, key, prf, derived_key)
250 }
251 KDFMode::FeedbackMode(feedback_mode) => {
252 kbkdf_feedback::<T>(feedback_mode, input_type, key, prf, derived_key)
253 }
254 KDFMode::DoublePipelineIterationMode(double_pipeline) => {
255 kbkdf_double_pipeline::<T>(double_pipeline, input_type, key, prf, derived_key)
256 }
257 }
258}
259
260fn kbkdf_counter<'a, T: PseudoRandomFunction<'a>>(
261 counter_mode: &CounterMode,
262 input_type: &InputType,
263 key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
264 prf: &mut T,
265 derived_key: &mut [u8],
266) -> Result<(), T::Error> {
267 let l = derived_key.len() * 8;
269 let h = T::PrfOutputSize::to_int() * 8;
270 let n = calculate_counter(l, h);
271 let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
272 assert!(
273 n < 2_usize.pow(counter_mode.counter_length as u32),
274 "Invalid derived key length"
275 );
276 for i in 1..=n {
277 prf.init(key)?;
278 let counter = i.to_be_bytes();
279 let counter = &counter[(counter.len() - counter_mode.counter_length / 8)..];
280 match input_type {
281 InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
282 CounterLocation::NoCounter => prf.update(fixed_input.fixed_input)?,
283 CounterLocation::BeforeFixedInput => {
284 prf.update(counter)?;
285 prf.update(fixed_input.fixed_input)?;
286 }
287 CounterLocation::MiddleOfFixedInput(position) => {
288 prf.update(&fixed_input.fixed_input[..position])?;
289 prf.update(counter)?;
290 prf.update(&fixed_input.fixed_input[position..])?;
291 }
292 CounterLocation::AfterFixedInput => {
293 prf.update(fixed_input.fixed_input)?;
294 prf.update(counter)?;
295 }
296 _ => panic!(
297 "Invalid counter location for KBKDF In Counter Mode: {:?}",
298 fixed_input.counter_location
299 ),
300 },
301 InputType::SpecifiedInput(specified_input) => {
302 prf.update(counter)?;
303 prf.update(specified_input.label)?;
304 prf.update(b"\0")?;
305 prf.update(specified_input.context)?;
306 let length = (l as u32).to_be_bytes();
307 prf.update(&length)?;
308 }
309 }
310 let _ = prf.finish(intermediate_key.as_mut_slice())?;
311 insert_result(i, intermediate_key.as_slice(), derived_key);
312 intermediate_key.zeroize();
313 }
314
315 Ok(())
316}
317
318fn kbkdf_double_pipeline<'a, T: PseudoRandomFunction<'a>>(
319 double_feedback: &DoublePipelineIterationMode,
320 input_type: &InputType,
321 key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
322 prf: &mut T,
323 derived_key: &mut [u8],
324) -> Result<(), T::Error> {
325 let l = derived_key.len() * 8;
326 let h = T::PrfOutputSize::to_int() * 8;
327 let n = calculate_counter(l, h);
328 let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
329 let mut feedback = GenericArray::<u8, T::PrfOutputSize>::default();
330 let length = (l as u32).to_be_bytes();
331 assert!(
332 n < 2_usize.pow(32),
333 "Invalid length provided for derived key"
334 );
335 for i in 1..=n {
336 let counter = i.to_be_bytes();
337 let counter = feedback_counter(double_feedback.counter_length, counter.as_slice());
338 prf.init(key)?;
340 if i == 1 {
341 match input_type {
342 InputType::FixedInput(fixed_input) => {
343 prf.update(fixed_input.fixed_input)?;
344 }
345 InputType::SpecifiedInput(specified_input) => {
346 prf.update(specified_input.label)?;
347 prf.update(b"\0")?;
348 prf.update(specified_input.context)?;
349 prf.update(length.as_slice())?;
350 }
351 }
352 } else {
353 prf.update(feedback.as_slice())?;
354 }
355 let _ = prf.finish(feedback.as_mut_slice())?;
356
357 prf.init(key)?;
358
359 match input_type {
360 InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
361 CounterLocation::NoCounter => {
362 prf.update(feedback.as_slice())?;
363 prf.update(fixed_input.fixed_input)?;
364 }
365 CounterLocation::BeforeIter => {
366 prf.update(
367 counter
368 .expect("Counter length not provided for BeforeIter counter location"),
369 )?;
370 prf.update(feedback.as_slice())?;
371 prf.update(fixed_input.fixed_input)?;
372 }
373 CounterLocation::AfterFixedInput => {
374 prf.update(feedback.as_slice())?;
375 prf.update(fixed_input.fixed_input)?;
376 prf.update(counter.expect(
377 "Counter length not provided for AfterFixedInput counter location",
378 ))?;
379 }
380 CounterLocation::AfterIter => {
381 prf.update(feedback.as_slice())?;
382 prf.update(
383 counter
384 .expect("Counter length not provided for AfterIter counter location"),
385 )?;
386 prf.update(fixed_input.fixed_input)?;
387 }
388 _ => panic!(
389 "Invalid counter location for double feedback: {:?}",
390 fixed_input.counter_location
391 ),
392 },
393 InputType::SpecifiedInput(specified_input) => {
394 prf.update(feedback.as_slice())?;
395 if let Some(counter) = counter {
396 prf.update(counter)?;
397 }
398 prf.update(specified_input.label)?;
399 prf.update(b"\0")?;
400 prf.update(specified_input.context)?;
401 prf.update(&length)?;
402 }
403 }
404
405 let _ = prf.finish(intermediate_key.as_mut_slice())?;
406 insert_result(i, intermediate_key.as_slice(), derived_key);
407 intermediate_key.zeroize();
408 }
409
410 Ok(())
411}
412
413fn kbkdf_feedback<'a, T: PseudoRandomFunction<'a>>(
414 feedback_mode: &FeedbackMode,
415 input_type: &InputType,
416 key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
417 prf: &mut T,
418 derived_key: &mut [u8],
419) -> Result<(), T::Error> {
420 let l = derived_key.len() * 8;
421 let h = T::PrfOutputSize::to_int() * 8;
422 let n = calculate_counter(l, h);
423
424 let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
425 let mut has_intermediate = feedback_mode.iv.is_some();
426 if let Some(iv) = feedback_mode.iv {
427 assert_eq!(iv.len(), T::PrfOutputSize::to_int());
428 intermediate_key.copy_from_slice(iv);
429 }
430 assert!(n < 2_usize.pow(32), "Invalid derived_key length provided");
431 for i in 1..=n {
432 prf.init(key)?;
433 let counter = i.to_be_bytes();
434 let counter = feedback_counter(feedback_mode.counter_length, counter.as_slice());
435 match input_type {
436 InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
437 CounterLocation::NoCounter => {
438 if has_intermediate {
439 prf.update(intermediate_key.as_slice())?;
440 }
441 prf.update(fixed_input.fixed_input)?;
442 }
443 CounterLocation::BeforeIter => {
444 prf.update(
445 counter
446 .expect("Counter length not provided for BeforeIter counter location"),
447 )?;
448 if has_intermediate {
449 prf.update(intermediate_key.as_slice())?;
450 }
451 prf.update(fixed_input.fixed_input)?;
452 }
453 CounterLocation::AfterIter => {
454 if has_intermediate {
455 prf.update(intermediate_key.as_slice())?;
456 }
457 prf.update(
458 counter
459 .expect("Counter length not provided for AfterIter counter location"),
460 )?;
461 prf.update(fixed_input.fixed_input)?;
462 }
463 CounterLocation::AfterFixedInput => {
464 if has_intermediate {
465 prf.update(intermediate_key.as_slice())?;
466 }
467 prf.update(fixed_input.fixed_input)?;
468 prf.update(counter.expect(
469 "Counter length not provided for AfterFixedInput counter location",
470 ))?;
471 }
472 _ => panic!(
473 "Invalid counter location provided for KDF feedback mode: {:?}",
474 fixed_input.counter_location
475 ),
476 },
477 InputType::SpecifiedInput(specified_input) => {
478 if has_intermediate {
479 prf.update(intermediate_key.as_slice())?;
480 }
481 if let Some(counter) = counter {
482 prf.update(counter)?;
483 }
484 prf.update(specified_input.label)?;
485 prf.update(b"\0")?;
486 prf.update(specified_input.context)?;
487 let length = (l as u32).to_be_bytes();
488 prf.update(&length)?;
489 }
490 }
491 let _ = prf.finish(intermediate_key.as_mut_slice())?;
492 insert_result(i, intermediate_key.as_slice(), derived_key);
493 has_intermediate = true;
494 }
495
496 Ok(())
497}
498
/// Returns the number of PRF output blocks needed to cover the derived key, i.e.
/// `ceil(derived_key_len_bits / prf_output_size_in_bits)`.
///
/// # Panics
///
/// Panics on division by zero if `prf_output_size_in_bits` is 0.
fn calculate_counter(derived_key_len_bits: usize, prf_output_size_in_bits: usize) -> usize {
    let whole_blocks = derived_key_len_bits / prf_output_size_in_bits;
    let has_partial_block = derived_key_len_bits % prf_output_size_in_bits != 0;
    whole_blocks + usize::from(has_partial_block)
}
507
/// Returns the last `counter_length / 8` bytes of `counter`, or `None` when no counter length
/// is configured.
///
/// For a big-endian encoding these are the least-significant bytes of the counter value.
///
/// # Panics
///
/// Panics if `counter_length / 8` exceeds `counter.len()`.
fn feedback_counter(counter_length: Option<usize>, counter: &[u8]) -> Option<&[u8]> {
    counter_length.map(|bits| {
        let start = counter.len() - bits / 8;
        &counter[start..]
    })
}
514
/// Copies one PRF output block into its slot of the derived key.
///
/// Block `counter` (1-based) is written at byte offset `(counter - 1) * intermediate.len()`.
/// The final block is truncated when `result` is not a whole multiple of the block size.
///
/// # Panics
///
/// Panics if the block's starting offset does not lie inside `result`.
fn insert_result(counter: usize, intermediate: &[u8], result: &mut [u8]) {
    let low_index = (counter - 1) * intermediate.len();
    assert!(
        low_index < result.len(),
        "The starting insert index should not exceed bounds of result slice"
    );
    // Clamping with `min` already keeps the end inside `result`, so no second check is needed.
    let high_index = (low_index + intermediate.len()).min(result.len());
    result[low_index..high_index].copy_from_slice(&intermediate[..high_index - low_index]);
}