1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use std::error::Error;
use std::fmt;
use std::str::{self, FromStr};

use bcrypt_only::bcrypt;
pub use bcrypt_only::{KEY_SIZE_MAX, Salt, WorkFactor};
pub use bcrypt_only::BcryptError as CompareError;

mod base64;

#[cfg(test)]
mod tests;

/// Length in bytes of the textual form produced by [`Hash::to_formatted`]:
/// the 4-byte `$2b$` prefix, 2 cost digits, a `$` separator, 22 characters of
/// encoded salt and 31 characters of encoded hash.
pub const FORMATTED_HASH_SIZE: usize = 60;

/// Error returned by [`hash`] when a password cannot be hashed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HashError {
	/// The password is longer than [`KEY_SIZE_MAX`] bytes.
	Length,
	/// The password contains a NUL (`'\0'`) character.
	ZeroByte,
	/// `getrandom` failed to fill the salt buffer.
	RandomError(getrandom::Error),
}

impl fmt::Display for HashError {
	/// Writes a short lowercase description of the failure. For
	/// `RandomError`, the underlying `getrandom` error is appended.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		let message = match self {
			HashError::RandomError(cause) => {
				return write!(f, "salt generation failed: {}", cause);
			}
			HashError::ZeroByte => "password contains a NUL character",
			HashError::Length => "password too long",
		};
		f.write_str(message)
	}
}

impl Error for HashError {
	/// Only `RandomError` wraps an underlying cause; the validation
	/// variants have none.
	fn source(&self) -> Option<&(dyn Error + 'static)> {
		match *self {
			HashError::RandomError(ref cause) => Some(cause),
			HashError::ZeroByte | HashError::Length => None,
		}
	}
}

/// A bcrypt hash together with the parameters used to produce it, as
/// returned by [`hash`] or parsed via `FromStr`, and consumed by [`compare`].
#[derive(Clone, Debug)]
pub struct Hash {
	/// Cost parameter passed to `bcrypt`.
	pub work_factor: WorkFactor,
	/// Salt passed to `bcrypt`.
	pub salt: Salt,
	/// Raw 23-byte `bcrypt` output.
	pub hash: [u8; 23],
}

impl Hash {
	/// Serializes the hash into its `$2b$NN$<salt><hash>` text form.
	///
	/// The output always uses the `$2b$` prefix, regardless of which prefix
	/// (if any) the hash was originally parsed from. Salt and hash bytes are
	/// encoded with this crate's `base64` module.
	pub fn to_formatted(&self) -> [u8; FORMATTED_HASH_SIZE] {
		let mut formatted = [0_u8; FORMATTED_HASH_SIZE];
		let log_rounds = self.work_factor.log_rounds();

		formatted[..4].copy_from_slice(b"$2b$");
		// The cost is always written as exactly two decimal digits (e.g. `04`).
		formatted[4] = b'0' + (log_rounds / 10) as u8;
		formatted[5] = b'0' + (log_rounds % 10) as u8;
		formatted[6] = b'$';
		// 16 salt bytes occupy 22 encoded characters; 23 hash bytes occupy 31.
		base64::encode(&self.salt.to_bytes(), &mut formatted[7..29]);
		base64::encode(&self.hash, &mut formatted[29..]);
		formatted
	}
}

/// Reason a string could not be parsed into a [`Hash`] via `FromStr`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ParseError {
	/// The input does not have the expected total length.
	Length,
	/// The input does not start with `$2a$`, `$2b$` or `$2y$`.
	Prefix,
	/// The cost field could not be turned into a valid [`WorkFactor`].
	WorkFactor,
	/// The salt section could not be decoded.
	Salt,
	/// The hash section could not be decoded.
	Hash,
}

impl fmt::Display for ParseError {
	/// Writes a fixed, lowercase description of the parse failure.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		let description = match *self {
			ParseError::Hash => "invalid hash",
			ParseError::Salt => "invalid salt",
			ParseError::WorkFactor => "invalid work factor",
			ParseError::Prefix => "invalid prefix",
			ParseError::Length => "invalid length",
		};
		f.write_str(description)
	}
}

impl Error for ParseError {}

impl FromStr for Hash {
	type Err = ParseError;

	fn from_str(s: &str) -> Result<Self, Self::Err> {
		if s.len() != 60 {
			return Err(ParseError::Length);
		}

		if !s.starts_with("$2a$") && !s.starts_with("$2b$") && !s.starts_with("$2y$") {
			return Err(ParseError::Prefix);
		}

		let work_factor =
			s.get(4..6)
				.and_then(|rs| rs.parse().ok())
				.and_then(WorkFactor::exp)
				.ok_or(ParseError::WorkFactor)?;

		let salt = {
			let mut salt = [0_u8; 16];
			base64::decode(&s.as_bytes()[7..29], &mut salt).map_err(|_| ParseError::Salt)?;
			Salt::from_bytes(&salt)
		};

		let mut hash = [0_u8; 23];
		base64::decode(&s.as_bytes()[29..60], &mut hash).map_err(|_| ParseError::Hash)?;

		Ok(Self { work_factor, salt, hash })
	}
}

/// Hashes `password` with a freshly generated random 16-byte salt at the
/// given work factor.
///
/// # Errors
///
/// * [`HashError::Length`] if `password` is longer than [`KEY_SIZE_MAX`] bytes.
/// * [`HashError::ZeroByte`] if `password` contains a NUL character.
/// * [`HashError::RandomError`] if `getrandom` fails to produce the salt.
pub fn hash(password: &str, work_factor: WorkFactor) -> Result<Hash, HashError> {
	if password.len() > KEY_SIZE_MAX {
		return Err(HashError::Length);
	}

	if password.contains('\0') {
		return Err(HashError::ZeroByte);
	}

	let mut salt_bytes = [0_u8; 16];
	getrandom::getrandom(&mut salt_bytes).map_err(HashError::RandomError)?;
	let salt = Salt::from_bytes(&salt_bytes);

	// The checks above are intended to mirror the input validation `bcrypt`
	// performs, so a failure here would indicate a bug rather than bad input.
	let hash = bcrypt(password.as_bytes(), &salt, work_factor)
		.expect("bcrypt rejected a password that passed the length and NUL-byte checks");
	Ok(Hash { work_factor, salt, hash })
}

/// Returns whether hashing `password` with the salt and work factor stored
/// in `expected` reproduces `expected.hash`.
///
/// The digest comparison is written to inspect every byte instead of
/// stopping at the first mismatch, to avoid leaking how many leading bytes
/// matched through timing.
///
/// # Errors
///
/// Propagates the [`CompareError`] returned by `bcrypt` when it rejects
/// `password`.
pub fn compare(password: &str, expected: &Hash) -> Result<bool, CompareError> {
	let computed = bcrypt(password.as_bytes(), &expected.salt, expected.work_factor)?;
	// OR together the XOR of every byte pair; the result is zero only if all
	// bytes are equal. Unlike `==`, this has no per-byte early exit.
	let difference = computed
		.iter()
		.zip(expected.hash.iter())
		.fold(0_u8, |acc, (a, b)| acc | (a ^ b));
	Ok(difference == 0)
}