1use std::error::Error;
2use std::fmt;
3use std::str::{self, FromStr};
4
5use bcrypt_only::bcrypt;
6pub use bcrypt_only::{KEY_SIZE_MAX, Salt, WorkFactor};
7pub use bcrypt_only::BcryptError as CompareError;
8
9mod base64;
10
11#[cfg(test)]
12mod tests;
13
/// Length in bytes of a formatted hash as produced by [`Hash::to_formatted`]: a 7-byte
/// `$2b$NN$` header, followed by 22 characters of encoded salt and 31 characters of encoded hash.
pub const FORMATTED_HASH_SIZE: usize = 7 + 22 + 31;
16
17#[derive(Clone, Copy, Debug, Eq, PartialEq)]
19pub enum HashError {
20 Length,
22 ZeroByte,
24 RandomError(getrandom::Error),
26}
27
28impl fmt::Display for HashError {
29 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30 match self {
31 HashError::Length => write!(f, "password too long"),
32 HashError::ZeroByte => write!(f, "password contains a NUL character"),
33 HashError::RandomError(err) => write!(f, "salt generation failed: {}", err),
34 }
35 }
36}
37
38impl Error for HashError {
39 fn source(&self) -> Option<&(dyn Error + 'static)> {
40 match self {
41 HashError::Length | HashError::ZeroByte => None,
42 HashError::RandomError(err) => Some(err),
43 }
44 }
45}
46
47#[derive(Clone, Debug)]
49pub struct Hash {
50 pub work_factor: WorkFactor,
52 pub salt: Salt,
54 pub hash: [u8; 23],
56}
57
58impl Hash {
59 pub fn to_formatted(&self) -> [u8; FORMATTED_HASH_SIZE] {
76 let mut formatted = [0_u8; 60];
77 formatted[..4].copy_from_slice(b"$2b$");
78 formatted[4] = b'0' + (self.work_factor.log_rounds() / 10) as u8;
79 formatted[5] = b'0' + (self.work_factor.log_rounds() % 10) as u8;
80 formatted[6] = b'$';
81 base64::encode(&self.salt.to_bytes(), &mut formatted[7..29]);
82 base64::encode(&self.hash, &mut formatted[29..60]);
83 formatted
84 }
85}
86
/// Reasons a string could not be parsed into a [`Hash`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParseError {
    /// The input is not exactly [`FORMATTED_HASH_SIZE`] bytes long.
    Length,
    /// The input does not start with `$2a$`, `$2b$` or `$2y$`.
    Prefix,
    /// The cost field could not be turned into a valid `WorkFactor`.
    WorkFactor,
    /// The salt segment could not be decoded.
    Salt,
    /// The hash (digest) segment could not be decoded.
    Hash,
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            ParseError::Length => "invalid length",
            ParseError::Prefix => "invalid prefix",
            ParseError::WorkFactor => "invalid work factor",
            ParseError::Salt => "invalid salt",
            ParseError::Hash => "invalid hash",
        };
        f.write_str(message)
    }
}

/// Parse errors have no underlying cause, so the default `source` (always `None`) is kept.
impl Error for ParseError {}
110
111impl FromStr for Hash {
113 type Err = ParseError;
114
115 fn from_str(s: &str) -> Result<Self, Self::Err> {
116 if s.len() != 60 {
117 return Err(ParseError::Length);
118 }
119
120 if !s.starts_with("$2a$") && !s.starts_with("$2b$") && !s.starts_with("$2y$") {
121 return Err(ParseError::Prefix);
122 }
123
124 let work_factor =
125 s.get(4..6)
126 .and_then(|rs| rs.parse().ok())
127 .and_then(WorkFactor::exp)
128 .ok_or(ParseError::WorkFactor)?;
129
130 let salt = {
131 let mut salt = [0_u8; 16];
132 base64::decode(&s.as_bytes()[7..29], &mut salt).map_err(|_| ParseError::Salt)?;
133 Salt::from_bytes(&salt)
134 };
135
136 let mut hash = [0_u8; 23];
137 base64::decode(&s.as_bytes()[29..60], &mut hash).map_err(|_| ParseError::Hash)?;
138
139 Ok(Self { work_factor, salt, hash })
140 }
141}
142
143pub fn hash(password: &str, work_factor: WorkFactor) -> Result<Hash, HashError> {
145 if password.len() > KEY_SIZE_MAX {
146 return Err(HashError::Length);
147 }
148
149 if password.contains('\0') {
150 return Err(HashError::ZeroByte);
151 }
152
153 let mut salt = [0_u8; 16];
154 getrandom::getrandom(&mut salt).map_err(HashError::RandomError)?;
155 let salt = Salt::from_bytes(&salt);
156
157 let hash = bcrypt(password.as_bytes(), &salt, work_factor).unwrap();
158 Ok(Hash { work_factor, salt, hash })
159}
160
161pub fn compare(password: &str, expected: &Hash) -> Result<bool, CompareError> {
163 let hash = bcrypt(password.as_bytes(), &expected.salt, expected.work_factor)?;
164 Ok(hash == expected.hash)
165}