mod scram;
#[cfg(test)]
mod test;
use std::{borrow::Cow, str::FromStr};
use bson::Document;
use rand::Rng;
use typed_builder::TypedBuilder;
use self::scram::ScramVersion;
use crate::{
cmap::{Connection, StreamDescription},
error::{Error, ErrorKind, Result},
};
// Canonical names of each `AuthMechanism` variant. `AuthMechanism::as_str` returns these
// strings, and the `FromStr` impl for `AuthMechanism` accepts exactly these (case-sensitive).
const SCRAM_SHA_1_STR: &str = "SCRAM-SHA-1";
const SCRAM_SHA_256_STR: &str = "SCRAM-SHA-256";
const MONGODB_CR_STR: &str = "MONGODB-CR";
const GSSAPI_STR: &str = "GSSAPI";
const MONGODB_X509_STR: &str = "MONGODB-X509";
const PLAIN_STR: &str = "PLAIN";
/// An authentication mechanism that a `Credential` can request.
///
/// `as_str` produces each variant's canonical name, and the `FromStr` impl parses it back.
#[derive(Debug, PartialEq, Clone)]
pub enum AuthMechanism {
    /// MONGODB-CR. It parses, but `authenticate_stream` always rejects it with an error.
    MongoDbCr,
    /// SCRAM-SHA-1. Negotiation falls back to this when the server does not advertise
    /// SCRAM-SHA-256.
    ScramSha1,
    /// SCRAM-SHA-256. Negotiation prefers this when the server advertises it.
    ScramSha256,
    /// MONGODB-X509. It parses, but authentication is not yet implemented.
    MongoDbX509,
    /// GSSAPI. It parses, but authentication is not yet implemented.
    Gssapi,
    /// PLAIN. It parses, but authentication is not yet implemented.
    Plain,
}
impl AuthMechanism {
pub(crate) fn from_stream_description(description: &StreamDescription) -> AuthMechanism {
let scram_sha_256_found = description
.sasl_supported_mechs
.as_ref()
.map(|ms| ms.iter().any(|m| m == AuthMechanism::ScramSha256.as_str()))
.unwrap_or(false);
if scram_sha_256_found {
AuthMechanism::ScramSha256
} else {
AuthMechanism::ScramSha1
}
}
pub fn validate_credential(&self, credential: &Credential) -> Result<()> {
match self {
AuthMechanism::ScramSha1 | AuthMechanism::ScramSha256 => {
if credential.username.is_none() {
return Err(ErrorKind::ArgumentError {
message: "No username provided for SCRAM authentication".to_string(),
}
.into());
};
Ok(())
}
_ => Ok(()),
}
}
pub fn as_str(&self) -> &'static str {
match self {
AuthMechanism::ScramSha1 => SCRAM_SHA_1_STR,
AuthMechanism::ScramSha256 => SCRAM_SHA_256_STR,
AuthMechanism::MongoDbCr => MONGODB_CR_STR,
AuthMechanism::MongoDbX509 => MONGODB_X509_STR,
AuthMechanism::Gssapi => GSSAPI_STR,
AuthMechanism::Plain => PLAIN_STR,
}
}
pub(crate) fn default_source<'a>(&'a self, uri_db: Option<&'a str>) -> &'a str {
match self {
AuthMechanism::ScramSha1 | AuthMechanism::ScramSha256 | AuthMechanism::MongoDbCr => {
uri_db.unwrap_or("admin")
}
_ => "",
}
}
pub(crate) fn authenticate_stream(
&self,
stream: &mut Connection,
credential: &Credential,
) -> Result<()> {
match self {
AuthMechanism::ScramSha1 => ScramVersion::Sha1.authenticate_stream(stream, credential),
AuthMechanism::ScramSha256 => {
ScramVersion::Sha256.authenticate_stream(stream, credential)
}
AuthMechanism::MongoDbCr => Err(ErrorKind::AuthenticationError {
message: "MONGODB-CR is deprecated and not supported by this driver. Use SCRAM \
for password-based authentication instead"
.into(),
}
.into()),
_ => Err(ErrorKind::AuthenticationError {
message: format!("Authentication mechanism {:?} not yet implemented.", self),
}
.into()),
}
}
}
impl FromStr for AuthMechanism {
    type Err = Error;

    /// Parses a canonical mechanism name (e.g. `"SCRAM-SHA-256"`). Matching is exact and
    /// case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns an `ArgumentError` naming the input when it is not a recognised mechanism.
    fn from_str(s: &str) -> Result<Self> {
        let mechanism = match s {
            SCRAM_SHA_256_STR => AuthMechanism::ScramSha256,
            SCRAM_SHA_1_STR => AuthMechanism::ScramSha1,
            MONGODB_X509_STR => AuthMechanism::MongoDbX509,
            MONGODB_CR_STR => AuthMechanism::MongoDbCr,
            PLAIN_STR => AuthMechanism::Plain,
            GSSAPI_STR => AuthMechanism::Gssapi,
            unknown => {
                return Err(ErrorKind::ArgumentError {
                    message: format!("invalid mechanism string: {}", unknown),
                }
                .into());
            }
        };
        Ok(mechanism)
    }
}
/// Authentication settings used to authenticate a connection.
///
/// Every field is an `Option`, and `#[builder(default)]` makes each one default to `None` in the
/// generated builder when it is not set.
#[derive(Clone, Debug, Default, TypedBuilder, PartialEq)]
pub struct Credential {
    /// The user to authenticate as. The SCRAM mechanisms require it
    /// (see `AuthMechanism::validate_credential`).
    #[builder(default)]
    pub username: Option<String>,

    /// The authentication database (source) for this credential.
    #[builder(default)]
    pub source: Option<String>,

    /// The user's password, if any.
    #[builder(default)]
    pub password: Option<String>,

    /// The mechanism to authenticate with. When `None`, `Credential::authenticate_stream`
    /// picks one from the server's advertised SASL mechanisms.
    #[builder(default)]
    pub mechanism: Option<AuthMechanism>,

    /// Additional mechanism-specific properties.
    // NOTE(review): not read anywhere in this file; confirm which mechanisms consume it.
    #[builder(default)]
    pub mechanism_properties: Option<Document>,
}
impl Credential {
    /// Test-only helper that converts this credential into a BSON document.
    ///
    /// * `username` is present only when set.
    /// * `password` is always present, as `Null` when unset.
    /// * The source is stored under the `db` key when set.
    /// * Other fields are not included.
    #[cfg(test)]
    pub(crate) fn into_document(mut self) -> Document {
        use bson::Bson;

        let mut doc = Document::new();

        if let Some(s) = self.username.take() {
            doc.insert("username", s);
        }

        if let Some(s) = self.password.take() {
            doc.insert("password", s);
        } else {
            doc.insert("password", Bson::Null);
        }

        if let Some(s) = self.source.take() {
            doc.insert("db", s);
        }

        doc
    }

    /// Returns the authentication database this credential targets.
    ///
    /// An explicitly configured `source` always wins. Without one, it uses the mechanism's
    /// default source (with no URI database). Without a mechanism either, it returns `"admin"`.
    ///
    /// Previously an explicit `source` was ignored. That made mechanism negotiation advertise
    /// `admin.<user>` even for users defined in another database.
    pub(crate) fn resolved_source(&self) -> &str {
        match self.source {
            Some(ref source) => source.as_str(),
            None => self
                .mechanism
                .as_ref()
                .map(|m| m.default_source(None))
                .unwrap_or("admin"),
        }
    }

    /// Adds mechanism negotiation to `command` when it is needed.
    ///
    /// This applies only when a username is set and no mechanism is. It inserts
    /// `saslSupportedMechs: "<source>.<username>"` into `command`.
    // NOTE(review): presumably `command` is the connection handshake, whose reply populates
    // `StreamDescription::sasl_supported_mechs` for `AuthMechanism::from_stream_description`.
    pub(crate) fn append_needed_mechanism_negotiation(&self, command: &mut Document) {
        if let (Some(username), None) = (self.username.as_ref(), self.mechanism.as_ref()) {
            command.insert(
                "saslSupportedMechs",
                format!("{}.{}", self.resolved_source(), username),
            );
        }
    }

    /// Authenticates `conn` with this credential.
    ///
    /// This does nothing and returns `Ok(())` when the connected server type cannot
    /// authenticate. When `mechanism` is unset, one is chosen with
    /// `AuthMechanism::from_stream_description`.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the stream description and from the mechanism's
    /// authentication conversation.
    pub(crate) fn authenticate_stream(&self, conn: &mut Connection) -> Result<()> {
        let stream_description = conn.stream_description()?;

        // Some server types do not support authentication; skip them.
        if !stream_description.server_type.can_auth() {
            return Ok(());
        };

        // Borrow the configured mechanism when present; only the negotiated one is owned.
        let mechanism = match self.mechanism {
            None => Cow::Owned(AuthMechanism::from_stream_description(stream_description)),
            Some(ref m) => Cow::Borrowed(m),
        };

        mechanism.authenticate_stream(conn, self)
    }
}
/// Generates a random nonce: 32 bytes from the thread-local RNG, base64-encoded.
// NOTE(review): presumably consumed by the `scram` submodule's client-first message; confirm.
pub(crate) fn generate_nonce() -> String {
    let mut rng = rand::thread_rng();

    // One `gen()` call per byte, in order.
    let mut bytes = [0u8; 32];
    for byte in bytes.iter_mut() {
        *byte = rng.gen();
    }

    base64::encode(&bytes[..])
}