use std::borrow::Cow;
use std::io;
use base64;
use rand::distributions::{Distribution, Uniform};
use rand::{OsRng, Rng};
use ring::digest::SHA256_OUTPUT_LEN;
use ring::hmac;
use error::{Error, Field, Kind};
use utils::{find_proofs, hash_password};
use NONCE_LENGTH;
/// Deprecated alias for `ScramClient`.
///
/// As the deprecation note says, new code should name `ScramClient` directly.
#[deprecated(since = "0.2.0", note = "Please use `ScramClient` instead. (exported at crate root)")]
pub type ClientFirst<'a> = ScramClient<'a>;
fn parse_server_first(data: &str) -> Result<(&str, Vec<u8>, u16), Error> {
if data.len() < 2 {
return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
}
let mut parts = data.split(',').peekable();
match parts.peek() {
Some(part) if &part.as_bytes()[..2] == b"m=" => {
return Err(Error::UnsupportedExtension);
}
Some(_) => {}
None => {
return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
}
}
let nonce = match parts.next() {
Some(part) if &part.as_bytes()[..2] == b"r=" => &part[2..],
_ => {
return Err(Error::Protocol(Kind::ExpectedField(Field::Nonce)));
}
};
let salt = match parts.next() {
Some(part) if &part.as_bytes()[..2] == b"s=" => base64::decode(part[2..].as_bytes())
.map_err(|_| Error::Protocol(Kind::InvalidField(Field::Salt)))?,
_ => {
return Err(Error::Protocol(Kind::ExpectedField(Field::Salt)));
}
};
let iterations = match parts.next() {
Some(part) if &part.as_bytes()[..2] == b"i=" => part[2..]
.parse()
.map_err(|_| Error::Protocol(Kind::InvalidField(Field::Iterations)))?,
_ => {
return Err(Error::Protocol(Kind::ExpectedField(Field::Iterations)));
}
};
Ok((nonce, salt, iterations))
}
fn parse_server_final(data: &str) -> Result<Vec<u8>, Error> {
if data.len() < 2 {
return Err(Error::Protocol(Kind::ExpectedField(Field::VerifyOrError)));
}
match &data[..2] {
"v=" => base64::decode(&data.as_bytes()[2..])
.map_err(|_| Error::Protocol(Kind::InvalidField(Field::VerifyOrError))),
"e=" => Err(Error::Authentication(data[2..].to_string())),
_ => Err(Error::Protocol(Kind::ExpectedField(Field::VerifyOrError))),
}
}
/// Initial state of the client side of a SCRAM authentication exchange.
///
/// Created with `new` or `with_rng`. Calling `client_first` consumes this value and yields
/// the first client message together with the next state, `ServerFirst`.
#[derive(Debug)]
pub struct ScramClient<'a> {
/// GS2 header: `"n,,"`, or `"n,a=<authzid>,"` when an authorization identity was given.
gs2header: Cow<'static, str>,
/// Password supplied by the caller.
password: &'a str,
/// Client nonce made of `NONCE_LENGTH` randomly sampled characters.
nonce: String,
/// Authentication identity supplied by the caller.
authcid: &'a str,
}
impl<'a> ScramClient<'a> {
    /// Creates the initial client state, drawing the nonce from the operating system's
    /// random number generator.
    ///
    /// * `authcid` - the authentication identity (user name).
    /// * `password` - the password of that identity.
    /// * `authzid` - an optional authorization identity.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` if `OsRng::new()` fails.
    pub fn new(authcid: &'a str, password: &'a str, authzid: Option<&'a str>) -> io::Result<Self> {
        let rng = OsRng::new()?;
        Ok(Self::with_rng(authcid, password, authzid, rng))
    }

    /// Creates the initial client state, drawing the nonce from the supplied `rng`.
    ///
    /// The GS2 header becomes `"n,a=<authzid>,"` when `authzid` is given and `"n,,"`
    /// otherwise. Note that `authzid` is inserted verbatim, without any escaping.
    pub fn with_rng<R: Rng>(
        authcid: &'a str,
        password: &'a str,
        authzid: Option<&'a str>,
        mut rng: R,
    ) -> Self {
        let gs2header: Cow<'static, str> = match authzid {
            Some(authzid) => format!("n,a={},", authzid).into(),
            None => "n,,".into(),
        };
        // Every sampled value above 43 is shifted up by one, so byte 44 (',') is never
        // produced: a comma would terminate the nonce attribute inside the message.
        let nonce: String = Uniform::from(33..125)
            .sample_iter(&mut rng)
            .map(|x: u8| if x > 43 { (x + 1) as char } else { x as char })
            .take(NONCE_LENGTH)
            .collect();
        ScramClient {
            gs2header,
            password,
            authcid,
            nonce,
        }
    }

    /// Consumes the client state and builds the first client message.
    ///
    /// Returns the next state of the exchange and the message to send to the server,
    /// which is the GS2 header followed by `n=<escaped authcid>,r=<nonce>`.
    ///
    /// In the user name, `=` is encoded as `=3D` and `,` as `=2C`. Names that contain
    /// neither character are borrowed unchanged.
    pub fn client_first(self) -> (ServerFirst<'a>, String) {
        let needs_escaping = self.authcid.chars().any(|chr| chr == ',' || chr == '=');
        let escaped_authcid: Cow<'a, str> = if needs_escaping {
            // '=' must be escaped first; otherwise the '=' introduced by "=2C" would
            // itself be rewritten to "=3D".
            self.authcid.replace('=', "=3D").replace(',', "=2C").into()
        } else {
            self.authcid.into()
        };
        let client_first_bare = format!("n={},r={}", escaped_authcid, self.nonce);
        let client_first = format!("{}{}", self.gs2header, client_first_bare);
        let server_first = ServerFirst {
            gs2header: self.gs2header,
            password: self.password,
            client_nonce: self.nonce,
            client_first_bare,
        };
        (server_first, client_first)
    }
}
/// State of the exchange after the first client message has been produced.
///
/// Returned by `ScramClient::client_first`. `handle_server_first` consumes it together
/// with the server's first message.
#[derive(Debug)]
pub struct ServerFirst<'a> {
/// GS2 header carried over from `ScramClient`.
gs2header: Cow<'static, str>,
/// Password carried over from `ScramClient`.
password: &'a str,
/// Nonce that was sent in the first client message.
client_nonce: String,
/// The `n=<authcid>,r=<nonce>` part of the first client message (without the GS2 header).
client_first_bare: String,
}
impl<'a> ServerFirst<'a> {
    /// Processes the server's first message and prepares the final client message.
    ///
    /// `server_first` is the raw message received from the server. On success the
    /// returned `ClientFinal` holds the `c=..,r=..,p=..` message to send next and the
    /// server signature that the server's last message is later compared against.
    ///
    /// # Errors
    ///
    /// * Any error produced by `parse_server_first`.
    /// * `Error::Protocol(Kind::InvalidNonce)` if the nonce returned by the server does
    ///   not start with the nonce this client sent.
    pub fn handle_server_first(self, server_first: &str) -> Result<ClientFinal, Error> {
        let (combined_nonce, salt, rounds) = parse_server_first(server_first)?;

        // The server must echo our nonce as the prefix of its own.
        let echoes_client_nonce = combined_nonce.starts_with(self.client_nonce.as_str());
        if !echoes_client_nonce {
            return Err(Error::Protocol(Kind::InvalidNonce));
        }

        let salted_password = hash_password(self.password, rounds, &salt);
        let (client_proof, server_signature): ([u8; SHA256_OUTPUT_LEN], hmac::Signature) =
            find_proofs(
                &self.gs2header,
                &self.client_first_bare,
                &server_first,
                &salted_password,
                combined_nonce,
            );

        let encoded_header = base64::encode(self.gs2header.as_bytes());
        let encoded_proof = base64::encode(&client_proof);
        let client_final = format!(
            "c={},r={},p={}",
            encoded_header, combined_nonce, encoded_proof
        );

        Ok(ClientFinal {
            server_signature,
            client_final,
        })
    }
}
/// State of the exchange after the server's first message has been processed.
///
/// Returned by `ServerFirst::handle_server_first`.
#[derive(Debug)]
pub struct ClientFinal {
/// Server signature returned by `find_proofs`; handed on to `ServerFinal`.
server_signature: hmac::Signature,
/// The final client message, of the form `c=..,r=..,p=..`.
client_final: String,
}
impl ClientFinal {
#[inline]
pub fn client_final(self) -> (ServerFinal, String) {
let server_final = ServerFinal {
server_signature: self.server_signature,
};
(server_final, self.client_final)
}
}
/// Last state of the client side of the exchange.
///
/// Returned by `ClientFinal::client_final`.
#[derive(Debug)]
pub struct ServerFinal {
/// Signature that the verifier in the server's final message is compared against.
server_signature: hmac::Signature,
}
impl ServerFinal {
pub fn handle_server_final(self, server_final: &str) -> Result<(), Error> {
if self.server_signature.as_ref() == &*parse_server_final(server_final)? {
Ok(())
} else {
Err(Error::InvalidServer)
}
}
}