use crate::{
crypto::{self, Keccak256},
error::Error,
Protected, SecretKey,
};
use rand::{thread_rng, RngCore};
use serde::{Deserialize, Serialize};
/// A byte buffer that (de)serializes as a hex string.
///
/// Encoding and decoding are delegated to the `bytes` module in this file: output is the
/// string produced by `rustc_hex::ToHex`, and input must be a string of even length that
/// `rustc_hex::FromHex` accepts. The deserializer does not strip a `0x` prefix before
/// decoding.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Bytes(#[serde(with = "bytes")] pub Vec<u8>);
/// Top-level key file (keystore) document.
///
/// Field names go through `rename_all = "camelCase"`. Every field name here is a single
/// lowercase word, so the JSON keys are simply `id`, `version`, `crypto` and `address`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct KeyFile {
    /// Identifier of the key file, kept as an opaque string (presumably a UUID; not validated here).
    pub id: String,
    /// Key file format version. NOTE(review): the value is not checked anywhere in this file.
    pub version: u64,
    /// Encrypted key material together with the cipher/KDF/MAC parameters needed to decrypt it.
    pub crypto: Crypto,
    /// Address associated with the key, hex-encoded via [`Bytes`]; `None` when absent.
    pub address: Option<Bytes>,
}
impl KeyFile {
    /// Builds a [`SecretKey`] from this file's [`Crypto`] section and `password`.
    ///
    /// The work is delegated to `SecretKey::from_crypto`. Presumably that decrypts the
    /// section and interprets the plaintext as a secret key; see that function for details.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SecretKey::from_crypto`.
    pub fn to_secret_key(&self, password: &Protected) -> Result<SecretKey, Error> {
        let crypto_section = &self.crypto;
        let secret = SecretKey::from_crypto(crypto_section, password)?;
        Ok(secret)
    }
}
/// The `crypto` section of a key file: ciphertext plus cipher, KDF and MAC parameters.
///
/// `kdf` is `#[serde(flatten)]`ed and [`Kdf`] is adjacently tagged (`tag = "kdf"`,
/// `content = "kdfparams"`). The KDF therefore appears as two sibling keys of this
/// object: `kdf` (the algorithm name) and `kdfparams` (its parameters).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Crypto {
    /// Symmetric cipher; only `aes-128-ctr` is recognised.
    pub cipher: Cipher,
    /// Parameters for the cipher (the AES-CTR initialisation vector).
    pub cipherparams: Aes128Ctr,
    /// Encrypted payload, hex-encoded.
    pub ciphertext: Bytes,
    /// Key-derivation function and its parameters (flattened into `kdf` / `kdfparams`).
    #[serde(flatten)]
    pub kdf: Kdf,
    /// Keccak-256 of `crypto::derive_mac(<derived MAC key>, ciphertext)`.
    /// `Crypto::decrypt` checks it to detect a wrong password before decrypting.
    pub mac: Bytes,
}
/// Supported symmetric ciphers; serialized by the names given in `#[serde(rename)]`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Cipher {
    /// AES-128 in CTR mode, serialized as `"aes-128-ctr"`.
    #[serde(rename = "aes-128-ctr")]
    Aes128Ctr,
}
/// Cipher parameters for [`Cipher::Aes128Ctr`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Aes128Ctr {
    /// Initialisation vector, hex-encoded. `Crypto::encrypt` generates 16 random bytes.
    pub iv: Bytes,
}
/// Key-derivation function used to turn the password into encryption and MAC keys.
///
/// The enum is adjacently tagged: the variant name (camelCased, i.e. `"pbkdf2"` or `"scrypt"`)
/// goes under the `kdf` key and the variant's parameters under `kdfparams`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase", tag = "kdf", content = "kdfparams")]
pub enum Kdf {
    /// Iteration-based derivation; the only KDF `Crypto::encrypt` produces.
    Pbkdf2(Pbkdf2),
    /// scrypt derivation; supported by `Crypto::decrypt` only.
    Scrypt(Scrypt),
}
/// Parameters for [`Kdf::Pbkdf2`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Pbkdf2 {
    /// Iteration count passed to `crypto::derive_key_iterations`.
    pub c: u32,
    /// Derived key length. `Crypto::encrypt` writes `crypto::KEY_LENGTH`; `Crypto::decrypt`
    /// does not read it.
    pub dklen: u32,
    /// Pseudo-random function; only HMAC-SHA256 is recognised.
    pub prf: Prf,
    /// KDF salt, hex-encoded. `Crypto::encrypt` generates 32 random bytes.
    pub salt: Bytes,
}
/// Parameters for [`Kdf::Scrypt`].
///
/// `Crypto::decrypt` passes `n`, `p`, `r` to `crypto::scrypt::derive_key` in that order.
/// Presumably these are the usual scrypt cost parameters (N = CPU/memory cost, r = block
/// size, p = parallelism); confirm against `crypto::scrypt`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Scrypt {
    /// Derived key length. NOTE(review): not read by `Crypto::decrypt`.
    pub dklen: u32,
    pub p: u32,
    pub n: u32,
    pub r: u32,
    /// KDF salt, hex-encoded.
    pub salt: Bytes,
}
/// Pseudo-random functions recognised for [`Pbkdf2`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Prf {
    /// HMAC-SHA256, serialized as `"hmac-sha256"`.
    #[serde(rename = "hmac-sha256")]
    HmacSha256,
}
impl Crypto {
pub fn encrypt(plain: &[u8], password: &Protected, iterations: u32) -> Result<Self, Error> {
let mut rng = thread_rng();
let mut salt = [0u8; 32];
let mut iv = [0u8; 16];
rng.fill_bytes(&mut salt);
rng.fill_bytes(&mut iv);
let (derived_left_bits, derived_right_bits) =
crypto::derive_key_iterations(password.as_ref(), &salt, iterations);
let plain_len = plain.len();
let mut ciphertext = Bytes(vec![0u8; plain_len]);
crypto::aes::encrypt_128_ctr(&derived_left_bits, &iv, plain, &mut *ciphertext.0)
.map_err(crypto::Error::from)
.map_err(Error::Crypto)?;
let mac = crypto::derive_mac(&derived_right_bits, &*ciphertext.0).keccak256();
Ok(Crypto {
cipher: Cipher::Aes128Ctr,
cipherparams: Aes128Ctr { iv: Bytes(iv.to_vec()) },
ciphertext,
kdf: Kdf::Pbkdf2(Pbkdf2 {
c: iterations,
dklen: crypto::KEY_LENGTH as u32,
prf: Prf::HmacSha256,
salt: Bytes(salt.to_vec()),
}),
mac: Bytes(mac.to_vec()),
})
}
pub fn decrypt(&self, password: &Protected) -> Result<Vec<u8>, Error> {
let (left_bits, right_bits) = match self.kdf {
Kdf::Pbkdf2(ref params) => crypto::derive_key_iterations(password.as_ref(), ¶ms.salt.0, params.c),
Kdf::Scrypt(ref params) => {
crypto::scrypt::derive_key(password.as_ref(), ¶ms.salt.0, params.n, params.p, params.r)
.map_err(Error::ScryptError)?
}
};
let mac = crypto::derive_mac(&right_bits, &self.ciphertext.0).keccak256();
if !crypto::is_equal(&mac, &self.mac.0) {
return Err(Error::InvalidPassword);
}
let mut plain = Vec::new();
plain.resize(self.ciphertext.0.len(), 0);
crypto::aes::decrypt_128_ctr(&left_bits, &self.cipherparams.iv.0, &self.ciphertext.0, &mut plain)
.map_err(crypto::Error::from)
.map_err(Error::Crypto)?;
Ok(plain)
}
}
/// Serde `with` helpers that represent a `Vec<u8>` as a hex string.
mod bytes {
    use std::fmt;

    use serde::{de, Deserializer, Serializer};

    /// Writes `bytes` as the hex string produced by `rustc_hex::ToHex`.
    pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded: String = rustc_hex::ToHex::to_hex(bytes);
        serializer.serialize_str(&encoded)
    }

    /// Reads a hex string and decodes it into bytes.
    ///
    /// Odd-length input is rejected with `invalid_length`. Any other decoding failure from
    /// `rustc_hex` is reported through `de::Error::custom` with the underlying message.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct HexVisitor;

        impl<'a> de::Visitor<'a> for HexVisitor {
            type Value = Vec<u8>;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a hex string of even length")
            }

            // Owned and borrowed strings reach this method through serde's default
            // `visit_string` / `visit_borrowed_str`, which forward to `visit_str`.
            fn visit_str<E: de::Error>(self, hex: &str) -> Result<Self::Value, E> {
                let len = hex.len();
                if len % 2 == 0 {
                    ::rustc_hex::FromHex::from_hex(&hex).map_err(|err| E::custom(err.to_string()))
                } else {
                    Err(E::invalid_length(len, &self))
                }
            }
        }

        deserializer.deserialize_str(HexVisitor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The bundled sample key file parses.
    #[test]
    fn should_deserialize() {
        let _keyfile: KeyFile = serde_json::from_str(include_str!("../res/wallet.json")).unwrap();
    }

    /// Encrypting and then decrypting with the same password yields the original data.
    #[test]
    fn decrypt_encrypt() {
        let data = &b"It was the year they finally immanentized the Eschaton."[..];
        let password = Protected::new(b"discord".to_vec());
        let crypto = Crypto::encrypt(data, &password, 10240).unwrap();
        let decrypted = crypto.decrypt(&password).unwrap();
        assert_eq!(data, decrypted.as_slice());
    }

    /// A wrong password must be rejected by the MAC check rather than yielding garbage.
    #[test]
    fn decrypt_with_wrong_password_fails() {
        let data = &b"It was the year they finally immanentized the Eschaton."[..];
        let crypto = Crypto::encrypt(data, &Protected::new(b"discord".to_vec()), 10240).unwrap();
        match crypto.decrypt(&Protected::new(b"concord".to_vec())) {
            Err(Error::InvalidPassword) => {}
            Err(_) => panic!("expected Error::InvalidPassword, got a different error"),
            Ok(_) => panic!("decryption with a wrong password unexpectedly succeeded"),
        }
    }

    /// `Crypto` survives a JSON round trip. This exercises the hex `bytes` helpers and the
    /// flattened, adjacently tagged `kdf` / `kdfparams` representation.
    #[test]
    fn crypto_json_roundtrip() {
        let password = Protected::new(b"discord".to_vec());
        let crypto = Crypto::encrypt(b"roundtrip", &password, 1024).unwrap();
        let json = serde_json::to_string(&crypto).unwrap();
        let parsed: Crypto = serde_json::from_str(&json).unwrap();
        assert_eq!(crypto, parsed);
        assert_eq!(parsed.decrypt(&password).unwrap(), b"roundtrip".to_vec());
    }
}