use std::io;
use nom::be_u8;
use num_traits::FromPrimitive;
use rand::{CryptoRng, Rng};
use crate::pgp::crypto::hash::HashAlgorithm;
use crate::pgp::errors::Result;
use crate::pgp::ser::Serialize;
/// Exponent bias used when decoding the one-octet coded iteration count
/// (`(16 + (c & 15)) << ((c >> 4) + EXPBIAS)`, see `StringToKey::count`).
const EXPBIAS: u32 = 6;
/// An OpenPGP string-to-key (S2K) specifier: the recipe used to turn a
/// passphrase into symmetric key material.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StringToKey {
/// Which S2K variant this specifier uses.
typ: StringToKeyType,
/// Hash algorithm used for the derivation.
hash: HashAlgorithm,
/// Salt octets. `Some` for `Salted` / `IteratedAndSalted` specifiers
/// (8 octets when parsed or generated by this module), `None` otherwise.
salt: Option<Vec<u8>>,
/// Coded (one-octet) iteration count. `Some` only for `IteratedAndSalted`;
/// decode it with `count()`.
count: Option<u8>,
}
impl StringToKey {
    /// Creates an iterated-and-salted specifier using the default hash
    /// algorithm and coded count `224` (decodes to `16 << 20` octets hashed).
    pub fn new_default<R: CryptoRng + Rng>(rng: &mut R) -> Self {
        Self::new_iterated(rng, HashAlgorithm::default(), 224)
    }

    /// Creates an iterated-and-salted specifier using `hash`, the coded
    /// iteration octet `count`, and a fresh 8-octet salt drawn from `rng`.
    pub fn new_iterated<R: CryptoRng + Rng>(rng: &mut R, hash: HashAlgorithm, count: u8) -> Self {
        let salt = {
            let mut bytes = vec![0u8; 8];
            rng.fill(bytes.as_mut_slice());
            bytes
        };

        Self {
            typ: StringToKeyType::IteratedAndSalted,
            hash,
            salt: Some(salt),
            count: Some(count),
        }
    }
}
impl StringToKey {
    /// Returns the decoded iteration count, i.e. the number of octets to feed
    /// into the hash for iterated S2K, or `None` when no count is present.
    ///
    /// The stored octet `c` is decoded as
    /// `(16 + (c & 15)) << ((c >> 4) + EXPBIAS)`.
    pub fn count(&self) -> Option<usize> {
        self.count.map(|c| {
            // Largest coded value (255) gives 31 << 21, which fits in a u32.
            ((16u32 + u32::from(c & 15)) << (u32::from(c >> 4) + EXPBIAS)) as usize
        })
    }

    /// Returns the salt octets, if this specifier carries a salt.
    pub fn salt(&self) -> Option<&[u8]> {
        self.salt.as_deref()
    }

    /// Returns the hash algorithm used for derivation.
    pub fn hash(&self) -> HashAlgorithm {
        self.hash
    }

    /// Returns the S2K variant of this specifier.
    pub fn typ(&self) -> StringToKeyType {
        self.typ
    }

    /// Derives `key_size` octets of key material from `passphrase`.
    ///
    /// When one digest is too short, several hash contexts are used. Context
    /// `n` (0-based) is first preloaded with `n` zero octets. The digests are
    /// concatenated and truncated to exactly `key_size` octets.
    ///
    /// # Errors
    /// Returns an error if a hasher cannot be created for the hash algorithm.
    /// It also fails if the hash yields an empty digest, or if the S2K type is
    /// not `Simple`, `Salted` or `IteratedAndSalted`.
    ///
    /// # Panics
    /// Panics if a salted type has no salt, or `IteratedAndSalted` has no
    /// count. Values built by `new_iterated` or the parser always carry them.
    pub fn derive_key(&self, passphrase: &str, key_size: usize) -> Result<Vec<u8>> {
        let mut key = Vec::with_capacity(key_size);
        // Index of the current hash context; also the number of zero octets
        // it is preloaded with.
        let mut round = 0usize;

        // Loop on the octets actually produced instead of a float ceiling
        // division over `digest_size()`. This is exact integer logic, and
        // cannot run away if the reported digest size is 0.
        while key.len() < key_size {
            let mut hasher = self.hash.new_hasher()?;
            if round > 0 {
                hasher.update(&vec![0u8; round][..]);
            }

            match self.typ {
                StringToKeyType::Simple => {
                    hasher.update(passphrase.as_bytes());
                }
                StringToKeyType::Salted => {
                    hasher.update(self.salt.as_ref().expect("missing salt"));
                    hasher.update(passphrase.as_bytes());
                }
                StringToKeyType::IteratedAndSalted => {
                    let salt = self.salt.as_ref().expect("missing salt");
                    let pw = passphrase.as_bytes();
                    let data_size = salt.len() + pw.len();
                    // `count` is the number of octets to hash. The full
                    // salt || passphrase is always hashed at least once, even
                    // if the coded count is smaller.
                    let mut count = self.count().expect("missing count").max(data_size);

                    while count > data_size {
                        hasher.update(salt);
                        hasher.update(pw);
                        count -= data_size;
                    }

                    // Hash the final, possibly partial, salt || passphrase prefix.
                    if count < salt.len() {
                        hasher.update(&salt[..count]);
                    } else {
                        hasher.update(salt);
                        hasher.update(&pw[..count - salt.len()]);
                    }
                }
                _ => unimplemented_err!("S2K {:?} is not available", self.typ),
            }

            let digest = hasher.finish();
            if digest.is_empty() {
                // Without this guard the loop would never make progress.
                unimplemented_err!("S2K hash {:?} produced an empty digest", self.hash);
            }
            let take = (key_size - key.len()).min(digest.len());
            key.extend_from_slice(&digest[..take]);
            round += 1;
        }

        Ok(key)
    }
}
/// Identifier of an OpenPGP S2K specifier type. The discriminant is the type
/// octet that starts a specifier on the wire: it is read by `s2k_parser` and
/// written by `StringToKey::to_writer`.
#[repr(u8)]
#[derive(Debug, PartialEq, Eq, Copy, Clone, FromPrimitive)]
pub enum StringToKeyType {
/// Hashes the passphrase directly (no salt, no count).
Simple = 0,
/// Hashes an 8-octet salt followed by the passphrase.
Salted = 1,
/// Reserved identifier; `derive_key` does not implement it.
Reserved = 2,
/// Repeatedly hashes salt || passphrase until the decoded octet count is reached.
IteratedAndSalted = 3,
/// Private/experimental identifiers 100-110 (RFC 4880 reserves this range).
/// They carry no parameters in this module, and `derive_key` does not
/// implement them.
Private100 = 100,
Private101 = 101,
Private102 = 102,
Private103 = 103,
Private104 = 104,
Private105 = 105,
Private106 = 106,
Private107 = 107,
Private108 = 108,
Private109 = 109,
Private110 = 110,
}
impl Default for StringToKeyType {
fn default() -> Self {
StringToKeyType::IteratedAndSalted
}
}
impl StringToKeyType {
pub fn param_len(self) -> usize {
match self {
StringToKeyType::Simple => 1,
StringToKeyType::Salted => 9,
StringToKeyType::IteratedAndSalted => 10,
_ => 0,
}
}
}
/// Whether specifiers of type `typ` carry a salt after the hash octet.
fn has_salt(typ: StringToKeyType) -> bool {
    typ == StringToKeyType::Salted || typ == StringToKeyType::IteratedAndSalted
}

/// Whether specifiers of type `typ` carry a coded iteration-count octet.
fn has_count(typ: StringToKeyType) -> bool {
    typ == StringToKeyType::IteratedAndSalted
}
// Parses an S2K specifier. The wire layout is:
// - a type octet;
// - a hash-algorithm octet;
// - an 8-octet salt, for salted types only (`has_salt`);
// - a one-octet coded iteration count, for the iterated type only (`has_count`).
// Parsing fails (via `map_opt!`) when the type or hash octet does not map to
// a known enum value. Types without parameters (e.g. Reserved, Private*) read
// no octets past the hash octet.
#[rustfmt::skip]
named!(pub s2k_parser<StringToKey>, do_parse!(
typ: map_opt!(be_u8, StringToKeyType::from_u8)
>> hash: map_opt!(be_u8, HashAlgorithm::from_u8)
>> salt: cond!(has_salt(typ), map!(take!(8), |v| v.to_vec()))
>> count: cond!(has_count(typ), be_u8)
>> (StringToKey {
typ,
hash,
salt,
count,
})
));
impl Serialize for StringToKey {
    /// Writes the specifier in wire order: the type octet and the
    /// hash-algorithm octet, followed by the salt and the coded count when
    /// they are present.
    fn to_writer<W: io::Write>(&self, writer: &mut W) -> Result<()> {
        let header = [self.typ as u8, self.hash as u8];
        writer.write_all(&header)?;

        if let Some(salt) = &self.salt {
            writer.write_all(salt)?;
        }

        if let Some(coded) = self.count {
            writer.write_all(&[coded])?;
        }

        Ok(())
    }
}