use std::borrow::Cow;
use base64;
use rand::distributions::{Distribution, Uniform};
use rand::{rngs::OsRng, Rng};
use ring::digest::SHA256_OUTPUT_LEN;
use ring::hmac;
use error::{Error, Field, Kind};
use utils::find_proofs;
use NONCE_LENGTH;
/// Server side of a SCRAM exchange.
///
/// Start an exchange with `handle_client_first`; each later step consumes the previous state.
pub struct ScramServer<P: AuthenticationProvider> {
    /// Supplies stored credentials and authorization decisions.
    provider: P,
}
/// Stored credentials for a single user, as returned by an `AuthenticationProvider`.
pub struct PasswordInfo {
    /// Passed unchanged to `find_proofs`. Presumably this is the salted password
    /// `Hi(password, salt, iterations)` from RFC 5802; confirm against that helper.
    hashed_password: Vec<u8>,
    /// Iteration count advertised to the client in the server-first-message (`i=`).
    iterations: u16,
    /// Salt advertised to the client, base64-encoded, in the server-first-message (`s=`).
    salt: Vec<u8>,
}
/// Final outcome of a SCRAM exchange, as reported by `ServerFinal::server_final`.
///
/// `Eq` and `Hash` are derived alongside `PartialEq`, so statuses can be used as keys in
/// hash maps and sets.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum AuthenticationStatus {
    /// The client proof matched, and any requested authzid was authorized by the provider.
    Authenticated,
    /// The client proof did not match the stored credentials.
    NotAuthenticated,
    /// The client proof matched, but the provider refused to let the authcid act as the
    /// requested authzid.
    NotAuthorized,
}
impl PasswordInfo {
pub fn new(hashed_password: Vec<u8>, iterations: u16, salt: Vec<u8>) -> Self {
PasswordInfo {
hashed_password,
iterations,
salt,
}
}
}
/// Source of stored credentials and authorization decisions for a `ScramServer`.
pub trait AuthenticationProvider {
    /// Looks up the stored credentials for `username`.
    ///
    /// Returning `None` makes `ScramServer::handle_client_first` fail with
    /// `Error::InvalidUser`.
    fn get_password_for(&self, username: &str) -> Option<PasswordInfo>;

    /// Decides whether the authenticated identity `authcid` may act as `authzid`.
    ///
    /// This is consulted only when the client supplied an authzid. By default, a user may act
    /// only as themselves.
    fn authorize(&self, authcid: &str, authzid: &str) -> bool {
        authzid == authcid
    }
}
fn parse_client_first(data: &str) -> Result<(&str, Option<&str>, &str), Error> {
let mut parts = data.split(',');
if let Some(part) = parts.next() {
if let Some(cb) = part.chars().next() {
if cb == 'p' {
return Err(Error::UnsupportedExtension);
}
if cb != 'n' && cb != 'y' || part.len() > 1 {
return Err(Error::Protocol(Kind::InvalidField(Field::ChannelBinding)));
}
} else {
return Err(Error::Protocol(Kind::ExpectedField(Field::ChannelBinding)));
}
} else {
return Err(Error::Protocol(Kind::ExpectedField(Field::ChannelBinding)));
}
let authzid = if let Some(part) = parts.next() {
if part.is_empty() {
None
} else if part.len() < 2 || &part.as_bytes()[..2] != b"a=" {
return Err(Error::Protocol(Kind::ExpectedField(Field::Authzid)));
} else {
Some(&part[2..])
}
} else {
return Err(Error::Protocol(Kind::ExpectedField(Field::Authzid)));
};
let authcid = parse_part!(parts, Authcid, b"n=");
let nonce = match parts.next() {
Some(part) if &part.as_bytes()[..2] == b"r=" => &part[2..],
_ => {
return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
}
};
Ok((authcid, authzid, nonce))
}
/// Parses a SCRAM `client-final-message` of the form `c=<gs2 header>,r=<nonce>,p=<proof>`.
///
/// Returns `(gs2header, nonce, proof)` with the attribute prefixes stripped. The GS2 header and
/// the proof are still base64-encoded at this point. The attributes must appear in exactly this
/// order; a missing or unprefixed one is rejected by `parse_part!`.
fn parse_client_final(data: &str) -> Result<(&str, &str, &str), Error> {
    let mut attributes = data.split(',');
    let channel_binding = parse_part!(attributes, GS2Header, b"c=");
    let combined_nonce = parse_part!(attributes, Nonce, b"r=");
    let client_proof = parse_part!(attributes, Proof, b"p=");
    Ok((channel_binding, combined_nonce, client_proof))
}
impl<P: AuthenticationProvider> ScramServer<P> {
pub fn new(provider: P) -> Self {
ScramServer { provider }
}
pub fn handle_client_first<'a>(
&'a self,
client_first: &'a str,
) -> Result<ServerFirst<'a, P>, Error> {
let (authcid, authzid, client_nonce) = parse_client_first(client_first)?;
let password_info = self
.provider
.get_password_for(authcid)
.ok_or_else(|| Error::InvalidUser(authcid.to_string()))?;
Ok(ServerFirst {
client_nonce,
authcid,
authzid,
provider: &self.provider,
password_info,
})
}
}
/// State after the client-first-message has been accepted.
///
/// Call `server_first` (or `server_first_with_rng`) to produce the server-first-message.
pub struct ServerFirst<'a, P: 'a + AuthenticationProvider> {
    /// Nonce chosen by the client, borrowed from the client-first-message.
    client_nonce: &'a str,
    /// Authentication identity (`n=`), borrowed from the client-first-message.
    authcid: &'a str,
    /// Authorization identity (`a=`), if the client sent one.
    authzid: Option<&'a str>,
    /// Credential provider of the owning `ScramServer`.
    provider: &'a P,
    /// Stored credentials for `authcid`.
    password_info: PasswordInfo,
}
impl<'a, P: AuthenticationProvider> ServerFirst<'a, P> {
pub fn server_first(self) -> (ClientFinal<'a, P>, String) {
self.server_first_with_rng(&mut OsRng)
}
pub fn server_first_with_rng<R: Rng>(self, rng: &mut R) -> (ClientFinal<'a, P>, String) {
let mut nonce = String::with_capacity(self.client_nonce.len() + NONCE_LENGTH);
nonce.push_str(self.client_nonce);
nonce.extend(
Uniform::from(33..125)
.sample_iter(rng)
.map(|x: u8| if x > 43 { (x + 1) as char } else { x as char })
.take(NONCE_LENGTH),
);
let gs2header: Cow<'static, str> = match self.authzid {
Some(authzid) => format!("n,a={},", authzid).into(),
None => "n,,".into(),
};
let client_first_bare: Cow<'static, str> =
format!("n={},r={}", self.authcid, self.client_nonce).into();
let server_first: Cow<'static, str> = format!(
"r={},s={},i={}",
nonce,
base64::encode(self.password_info.salt.as_slice()),
self.password_info.iterations
)
.into();
(
ClientFinal {
hashed_password: self.password_info.hashed_password,
nonce,
gs2header,
client_first_bare,
server_first: server_first.clone(),
authcid: self.authcid,
authzid: self.authzid,
provider: self.provider,
},
server_first.into_owned(),
)
}
}
/// State after the server-first-message has been produced.
///
/// Call `handle_client_final` with the client's reply.
pub struct ClientFinal<'a, P: 'a + AuthenticationProvider> {
    /// Stored hashed password, passed to `find_proofs` to derive the expected proofs.
    hashed_password: Vec<u8>,
    /// Combined nonce: the client's nonce followed by the server's random suffix.
    nonce: String,
    /// GS2 header as rebuilt by the server (`n,,` or `n,a=<authzid>,`).
    gs2header: Cow<'static, str>,
    /// Rebuilt `client-first-message-bare` (`n=<authcid>,r=<client nonce>`).
    client_first_bare: Cow<'static, str>,
    /// The server-first-message that was produced for the client.
    server_first: Cow<'static, str>,
    /// Authentication identity from the client-first-message.
    authcid: &'a str,
    /// Requested authorization identity, if any.
    authzid: Option<&'a str>,
    /// Credential provider, consulted for authzid authorization.
    provider: &'a P,
}
impl<'a, P: AuthenticationProvider> ClientFinal<'a, P> {
pub fn handle_client_final(self, client_final: &str) -> Result<ServerFinal, Error> {
let (gs2header_enc, nonce, proof) = parse_client_final(client_final)?;
if !self.verify_header(gs2header_enc) {
return Err(Error::Protocol(Kind::InvalidField(Field::GS2Header)));
}
if !self.verify_nonce(nonce) {
return Err(Error::Protocol(Kind::InvalidField(Field::Nonce)));
}
if let Some(signature) = self.verify_proof(proof)? {
if let Some(authzid) = self.authzid {
if self.provider.authorize(self.authcid, authzid) {
Ok(ServerFinal {
status: AuthenticationStatus::Authenticated,
signature,
})
} else {
Ok(ServerFinal {
status: AuthenticationStatus::NotAuthorized,
signature: format!(
"e=User '{}' not authorized to act as '{}'",
self.authcid, authzid
),
})
}
} else {
Ok(ServerFinal {
status: AuthenticationStatus::Authenticated,
signature,
})
}
} else {
Ok(ServerFinal {
status: AuthenticationStatus::NotAuthenticated,
signature: "e=Invalid Password".to_string(),
})
}
}
fn verify_header(&self, gs2header: &str) -> bool {
let server_gs2header = base64::encode(self.gs2header.as_bytes());
server_gs2header == gs2header
}
fn verify_nonce(&self, nonce: &str) -> bool {
nonce == self.nonce
}
fn verify_proof(&self, proof: &str) -> Result<Option<String>, Error> {
let (client_proof, server_signature): ([u8; SHA256_OUTPUT_LEN], hmac::Tag) = find_proofs(
&self.gs2header,
&self.client_first_bare,
&self.server_first,
self.hashed_password.as_slice(),
&self.nonce,
);
let proof = if let Ok(proof) = base64::decode(proof.as_bytes()) {
proof
} else {
return Err(Error::Protocol(Kind::InvalidField(Field::Proof)));
};
if proof != client_proof {
return Ok(None);
}
let server_signature_string = format!("v={}", base64::encode(server_signature.as_ref()));
Ok(Some(server_signature_string))
}
}
/// Outcome of the exchange, produced by `ClientFinal::handle_client_final`.
pub struct ServerFinal {
    /// Whether the client was authenticated and, if requested, authorized.
    status: AuthenticationStatus,
    /// Server-final-message to send: `v=<server signature>` on success, otherwise `e=<reason>`.
    signature: String,
}
impl ServerFinal {
pub fn server_final(self) -> (AuthenticationStatus, String) {
(self.status, self.signature)
}
}
#[cfg(test)]
mod tests {
    use super::super::{Error, Field, Kind};
    use super::{parse_client_final, parse_client_first};

    /// Error produced when the attribute for `field` is missing or lacks its prefix.
    fn expected(field: Field) -> Error {
        Error::Protocol(Kind::ExpectedField(field))
    }

    /// Error produced when the attribute for `field` is present but malformed.
    fn invalid(field: Field) -> Error {
        Error::Protocol(Kind::InvalidField(field))
    }

    #[test]
    fn test_parse_client_first_success() {
        let cases = vec![
            ("n,,n=user,r=abcdefghijk", ("user", None, "abcdefghijk")),
            (
                "y,a=other user,n=user,r=abcdef=hijk",
                ("user", Some("other user"), "abcdef=hijk"),
            ),
            ("n,,n=,r=", ("", None, "")),
        ];
        for (input, want) in cases {
            assert_eq!(parse_client_first(input).unwrap(), want, "input: {:?}", input);
        }
    }

    #[test]
    fn test_parse_client_first_missing_fields() {
        let cases = vec![
            ("n,,n=user", expected(Field::Nonce)),
            ("n,,r=user", expected(Field::Authcid)),
            ("n,n=user,r=abc", expected(Field::Authzid)),
            (",,n=user,r=abc", expected(Field::ChannelBinding)),
            ("", expected(Field::ChannelBinding)),
            (",,,", expected(Field::ChannelBinding)),
        ];
        for (input, want) in cases {
            assert_eq!(parse_client_first(input).unwrap_err(), want, "input: {:?}", input);
        }
    }

    #[test]
    fn test_parse_client_first_invalid_data() {
        let cases = vec![
            ("a,,n=user,r=abc", invalid(Field::ChannelBinding)),
            ("p,,n=user,r=abc", Error::UnsupportedExtension),
            ("nn,,n=user,r=abc", invalid(Field::ChannelBinding)),
            ("n,,n,r=abc", expected(Field::Authcid)),
        ];
        for (input, want) in cases {
            assert_eq!(parse_client_first(input).unwrap_err(), want, "input: {:?}", input);
        }
    }

    #[test]
    fn test_parse_client_final_success() {
        let cases = vec![
            ("c=abc,r=abcefg,p=783232", ("abc", "abcefg", "783232")),
            ("c=,r=,p=", ("", "", "")),
        ];
        for (input, want) in cases {
            assert_eq!(parse_client_final(input).unwrap(), want, "input: {:?}", input);
        }
    }

    #[test]
    fn test_parse_client_final_missing_fields() {
        let cases = vec![
            ("c=whatever,r=something", expected(Field::Proof)),
            ("c=whatever,p=words", expected(Field::Nonce)),
            ("c=whatever", expected(Field::Nonce)),
            ("c=", expected(Field::Nonce)),
            ("", expected(Field::GS2Header)),
            ("r=anonce", expected(Field::GS2Header)),
        ];
        for (input, want) in cases {
            assert_eq!(parse_client_final(input).unwrap_err(), want, "input: {:?}", input);
        }
    }
}