use hashing::{Algorithm, Output};
use primitives::Primitive;
use serde::de::Error;
use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer};
use serde_mcf;
use serde_mcf::{base64, base64bcrypt, Hashes};
use std::fmt;
/// Classification of the identifier that starts a serialized `Output`.
///
/// Built by `VariantVisitor` and consumed by `OutputVisitor`, which uses it to
/// decide how the remainder of the input is parsed.
#[derive(PartialEq, Debug)]
enum SupportedVariants {
    /// One of the bcrypt identifiers (`Bcrypt`, `Bcrypta`, `Bcryptx`, `Bcrypty`,
    /// `Bcryptb`); the remainder is parsed as `BcryptFields`.
    Bcrypt(Hashes),
    /// Any other recognised identifier; the remainder is parsed as `McfFields`.
    Mcf(Hashes),
    /// A Pasta-specific encoding, selected by an empty or `!` identifier.
    Pasta(PastaVariants),
}
/// The Pasta-specific encodings distinguished by `VariantVisitor`.
#[derive(PartialEq, Debug)]
enum PastaVariants {
    /// Selected by an empty identifier; the remainder is a full
    /// `serde_mcf::McfHash` describing a single primitive.
    Single,
    /// Selected by the `!` identifier; the remainder is a `PastaNest`
    /// (an outer primitive plus a complete inner `Output`).
    Nested,
}
/// Field names handed to `deserialize_struct` when deserializing an `Output`.
static VAR_STRUCT: [&str; 2] = [
    "variant",   // identifier, parsed as `SupportedVariants`
    "remaining", // everything after the identifier, parsed per variant
];
/// Remaining fields of a bcrypt hash, read after a
/// `SupportedVariants::Bcrypt` identifier.
// NOTE(review): formats such as serde_mcf presumably read struct fields in
// declaration order, so do not reorder these without confirming.
#[derive(Deserialize)]
struct BcryptFields {
    /// Cost factor; passed to `primitives::Bcrypt::new`.
    cost: u32,
    /// Combined salt-and-hash segment, decoded via `serde_mcf::base64bcrypt`.
    /// `.0` is used as the salt and `.1` as the hash.
    #[serde(with = "base64bcrypt")]
    salthash: (Vec<u8>, Vec<u8>),
}
/// Remaining fields of a generic MCF hash, read after a
/// `SupportedVariants::Mcf` identifier.
// NOTE(review): formats such as serde_mcf presumably read struct fields in
// declaration order, so do not reorder these without confirming.
#[derive(Deserialize)]
struct McfFields {
    /// Algorithm parameters; combined with the identifier to build the `Primitive`.
    params: serde_mcf::Map<String, serde_mcf::Value>,
    /// Salt bytes, (de)serialized via `serde_mcf::base64`.
    #[serde(with = "base64")]
    pub salt: Vec<u8>,
    /// Hash bytes, (de)serialized via `serde_mcf::base64`.
    #[serde(with = "base64")]
    pub hash: Vec<u8>,
}
/// Remaining fields of a nested (`!`) Pasta hash: an outer primitive
/// together with a complete inner `Output`.
// NOTE(review): formats such as serde_mcf presumably read struct fields in
// declaration order, so do not reorder these without confirming.
#[derive(Deserialize)]
struct PastaNest {
    /// Identifier of the outer primitive.
    outer_id: Hashes,
    /// Parameters of the outer primitive.
    outer_params: serde_mcf::Map<String, serde_mcf::Value>,
    /// Inner hash: its algorithm becomes the nested inner algorithm, and its
    /// salt and hash become the result's salt and hash.
    inner: Output,
}
impl<'de> Deserialize<'de> for Output {
    /// Requests a two-field struct named `"var_container"` (field names from
    /// `VAR_STRUCT`) and lets `OutputVisitor` assemble the `Output`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let fields: &'static [&'static str] = &VAR_STRUCT;
        let visitor = OutputVisitor;
        deserializer.deserialize_struct("var_container", fields, visitor)
    }
}
struct OutputVisitor;
impl<'de> Visitor<'de> for OutputVisitor {
type Value = Output;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("an identifier")
}
fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: de::MapAccess<'de>,
{
let _: Option<String> = map.next_key()?;
let var: SupportedVariants = map.next_value()?;
match var {
SupportedVariants::Bcrypt(_) => {
let _: Option<String> = map.next_key()?;
let fields: BcryptFields = map.next_value()?;
let prim = ::primitives::Bcrypt::new(fields.cost);
if prim == ::primitives::Poisoned.into() {
return Err(V::Error::custom(format!(
"failed to deserialize as {:?}",
var
)));
}
Ok(Output {
alg: Algorithm::Single(prim),
salt: fields.salthash.0,
hash: fields.salthash.1,
})
}
SupportedVariants::Mcf(var) => {
let _: Option<String> = map.next_key()?;
let fields: McfFields = map.next_value()?;
let prim = ::primitives::Primitive::from((&var, &fields.params));
if prim == ::primitives::Poisoned.into() {
return Err(V::Error::custom(format!(
"failed to deserialize as {:?}",
var
)));
}
Ok(Output {
alg: Algorithm::Single(prim),
salt: fields.salt,
hash: fields.hash,
})
}
SupportedVariants::Pasta(var) => {
match var {
PastaVariants::Single => {
let _: Option<String> = map.next_key()?;
let output: serde_mcf::McfHash = map.next_value()?;
let prim =
::primitives::Primitive::from((&output.algorithm, &output.parameters));
if prim == ::primitives::Poisoned.into() {
return Err(V::Error::custom(format!(
"failed to deserialize as {:?}",
var
)));
}
Ok(Output {
alg: Algorithm::Single(prim),
salt: output.salt,
hash: output.hash,
})
}
PastaVariants::Nested => {
let _: Option<String> = map.next_key()?;
let fields: PastaNest = map.next_value()?;
let prim =
::primitives::Primitive::from((&fields.outer_id, &fields.outer_params));
if prim == ::primitives::Poisoned.into() {
return Err(V::Error::custom(format!(
"failed to deserialize as {:?}",
var
)));
}
Ok(Output {
alg: Algorithm::Nested {
outer: prim,
inner: Box::new(fields.inner.alg.clone()),
},
salt: fields.inner.salt,
hash: fields.inner.hash,
})
}
}
}
}
}
}
impl<'de> Deserialize<'de> for SupportedVariants {
    /// Requests an identifier from the deserializer and classifies it with
    /// `VariantVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = VariantVisitor;
        deserializer.deserialize_identifier(visitor)
    }
}
/// Serde visitor that classifies a hash identifier string as a `SupportedVariants`.
struct VariantVisitor;

impl<'de> Visitor<'de> for VariantVisitor {
    type Value = SupportedVariants;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("an identifier")
    }

    /// Maps the identifier to a variant:
    /// - `""` maps to `Pasta(Single)`, and `"!"` maps to `Pasta(Nested)`.
    /// - The bcrypt ids known to `Hashes` map to `Bcrypt`.
    /// - Any other id known to `Hashes::from_id` maps to `Mcf`.
    /// - Anything else is an `"unknown MCF variant"` error.
    ///
    /// This is implemented as `visit_str` rather than only `visit_borrowed_str`,
    /// so that deserializers supplying transient or owned strings are also
    /// accepted. serde's default `visit_borrowed_str` and `visit_string` both
    /// forward here.
    fn visit_str<E>(self, val: &str) -> Result<Self::Value, E>
    where
        E: Error,
    {
        let variant = match val {
            "" => SupportedVariants::Pasta(PastaVariants::Single),
            "!" => SupportedVariants::Pasta(PastaVariants::Nested),
            other => {
                let variant = Hashes::from_id(other)
                    .ok_or_else(|| E::custom(format!("unknown MCF variant: {}", other)))?;
                match variant {
                    Hashes::Bcrypt
                    | Hashes::Bcrypta
                    | Hashes::Bcryptx
                    | Hashes::Bcrypty
                    | Hashes::Bcryptb => SupportedVariants::Bcrypt(variant),
                    _ => SupportedVariants::Mcf(variant),
                }
            }
        };
        Ok(variant)
    }
}
impl<'de> Deserialize<'de> for Primitive {
    /// Deserializes a `Primitive` from an `id` plus a `params` map.
    ///
    /// # Errors
    /// Returns `"failed to deserialize"` when the pair converts to the
    /// `Poisoned` placeholder primitive.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Shape of the serialized form: an identifier and its parameter map.
        #[derive(Deserialize)]
        struct PrimitiveStruct {
            id: Hashes,
            params: serde_mcf::Map<String, serde_mcf::Value>,
        }

        let PrimitiveStruct { id, params } = PrimitiveStruct::deserialize(deserializer)?;
        let primitive: Self = (&id, &params).into();
        if primitive == ::primitives::Poisoned.into() {
            Err(D::Error::custom("failed to deserialize"))
        } else {
            Ok(primitive)
        }
    }
}
#[cfg(test)]
mod test {
    // `hash` is deliberately rebound to unrelated inputs in the tests below.
    #![allow(clippy::shadow_unrelated)]
    use super::*;
    use serde_mcf;
    use serde_yaml;

    /// Identifier parsing: a known MCF id succeeds; a non-identifier fails.
    #[test]
    fn variant_tests() {
        let variant = "$argon2i";
        assert_eq!(
            serde_mcf::from_str::<SupportedVariants>(variant).unwrap(),
            SupportedVariants::Mcf(Hashes::Argon2i)
        );
        let not_a_variant = "12";
        assert!(serde_yaml::from_str::<SupportedVariants>(not_a_variant).is_err());
    }

    /// Unknown identifiers and invalid parameter maps are rejected.
    #[test]
    fn hash_tests() {
        let hash = "$$non-existant$$$";
        assert!(serde_mcf::from_str::<Output>(hash).is_err());
        let hash = "$argon2i$fake_map=12$salt$hash";
        assert!(serde_mcf::from_str::<Output>(hash).is_err());
    }

    /// A well-formed bcrypt hash parses; a non-numeric cost does not.
    #[test]
    fn de_bcrypt() {
        let hash = "$2a$10$175ikf/E6E.73e83.fJRbODnYWBwmfS0ENdzUBZbedUNGO.99wJfa";
        assert!(serde_mcf::from_str::<Output>(hash).is_ok());
        let broken_hash = "$2a$purple$175ikf/E6E.73e83.fJRbODnYWBwmfS0ENdzUBZbedUNGO.99wJfa";
        assert!(serde_mcf::from_str::<Output>(broken_hash).is_err());
    }
}