#[cfg(feature = "aws-auth")]
pub(crate) mod aws;
pub(crate) mod oidc;
mod plain;
mod sasl;
mod scram;
#[cfg(test)]
mod test;
mod x509;
use std::{borrow::Cow, fmt::Debug, str::FromStr};
use derivative::Derivative;
use hmac::{digest::KeyInit, Mac};
use rand::Rng;
use serde::Deserialize;
use typed_builder::TypedBuilder;
use self::scram::ScramVersion;
use crate::{
bson::Document,
client::options::ServerApi,
cmap::{Command, Connection, StreamDescription},
error::{Error, ErrorKind, Result},
};
// Canonical mechanism names: `AuthMechanism::as_str` returns these and `AuthMechanism::from_str`
// matches against them (case-sensitively).

// SCRAM family.
/// Canonical name for SCRAM-SHA-1.
const SCRAM_SHA_1_STR: &'static str = "SCRAM-SHA-1";
/// Canonical name for SCRAM-SHA-256.
const SCRAM_SHA_256_STR: &'static str = "SCRAM-SHA-256";

// MONGODB-* mechanisms.
/// Canonical name for the legacy MONGODB-CR mechanism (parsed, but refused at auth time).
const MONGODB_CR_STR: &'static str = "MONGODB-CR";
/// Canonical name for MONGODB-X509.
const MONGODB_X509_STR: &'static str = "MONGODB-X509";
/// Canonical name for MONGODB-AWS.
const MONGODB_AWS_STR: &'static str = "MONGODB-AWS";
/// Canonical name for MONGODB-OIDC.
const MONGODB_OIDC_STR: &'static str = "MONGODB-OIDC";

// Other SASL mechanisms.
/// Canonical name for GSSAPI (parsed, but not implemented).
const GSSAPI_STR: &'static str = "GSSAPI";
/// Canonical name for PLAIN.
const PLAIN_STR: &'static str = "PLAIN";
/// An authentication mechanism the driver can be configured with.
///
/// Each variant's canonical name (e.g. `"SCRAM-SHA-256"`) is returned by `as_str` and accepted
/// by the `FromStr` impl. Note that the derived `Deserialize` impl has no rename attributes, so
/// it matches the Rust variant identifiers (e.g. `ScramSha256`), not the canonical names.
#[derive(Clone, Deserialize, PartialEq, Debug)]
#[non_exhaustive]
pub enum AuthMechanism {
    /// MONGODB-CR. Recognized when parsing, but refused when authentication is attempted.
    MongoDbCr,
    /// SCRAM-SHA-1; requires a username.
    ScramSha1,
    /// SCRAM-SHA-256; requires a username.
    ScramSha256,
    /// MONGODB-X509; no password allowed and the source, if set, must be `$external`.
    MongoDbX509,
    /// GSSAPI. Recognized when parsing, but authentication reports it as not yet implemented.
    Gssapi,
    /// PLAIN; requires a non-empty username and a password.
    Plain,
    /// MONGODB-AWS. Only present when the `aws-auth` feature is enabled.
    #[cfg(feature = "aws-auth")]
    MongoDbAws,
    /// MONGODB-OIDC; no password allowed and the source, if set, must be `$external`.
    MongoDbOidc,
}
impl AuthMechanism {
fn from_scram_version(scram: &ScramVersion) -> Self {
match scram {
ScramVersion::Sha1 => Self::ScramSha1,
ScramVersion::Sha256 => Self::ScramSha256,
}
}
pub(crate) fn from_stream_description(description: &StreamDescription) -> AuthMechanism {
let scram_sha_256_found = description
.sasl_supported_mechs
.as_ref()
.map(|ms| ms.iter().any(|m| m == AuthMechanism::ScramSha256.as_str()))
.unwrap_or(false);
if scram_sha_256_found {
AuthMechanism::ScramSha256
} else {
AuthMechanism::ScramSha1
}
}
pub fn validate_credential(&self, credential: &Credential) -> Result<()> {
match self {
AuthMechanism::ScramSha1 | AuthMechanism::ScramSha256 => {
if credential.username.is_none() {
return Err(ErrorKind::InvalidArgument {
message: "No username provided for SCRAM authentication".to_string(),
}
.into());
};
Ok(())
}
AuthMechanism::MongoDbX509 => {
if credential.password.is_some() {
return Err(ErrorKind::InvalidArgument {
message: "A password cannot be specified with MONGODB-X509".to_string(),
}
.into());
}
if credential.source.as_deref().unwrap_or("$external") != "$external" {
return Err(ErrorKind::InvalidArgument {
message: "only $external may be specified as an auth source for \
MONGODB-X509"
.to_string(),
}
.into());
}
Ok(())
}
AuthMechanism::Plain => {
if credential.username.is_none() {
return Err(ErrorKind::InvalidArgument {
message: "No username provided for PLAIN authentication".to_string(),
}
.into());
}
if credential.username.as_deref() == Some("") {
return Err(ErrorKind::InvalidArgument {
message: "Username for PLAIN authentication must be non-empty".to_string(),
}
.into());
}
if credential.password.is_none() {
return Err(ErrorKind::InvalidArgument {
message: "No password provided for PLAIN authentication".to_string(),
}
.into());
}
Ok(())
}
#[cfg(feature = "aws-auth")]
AuthMechanism::MongoDbAws => {
if credential.username.is_some() && credential.password.is_none() {
return Err(ErrorKind::InvalidArgument {
message: "Username cannot be provided without password for MONGODB-AWS \
authentication"
.to_string(),
}
.into());
}
Ok(())
}
AuthMechanism::MongoDbOidc => {
let is_automatic = credential
.mechanism_properties
.as_ref()
.map_or(false, |p| p.contains_key("PROVIDER_NAME"));
if credential.username.is_some() && is_automatic {
return Err(Error::invalid_argument(
"username and PROVIDER_NAME cannot both be specified for MONGODB-OIDC \
authentication",
));
}
if credential
.source
.as_ref()
.map_or(false, |s| s != "$external")
{
return Err(Error::invalid_argument(
"source must be $external for MONGODB-OIDC authentication",
));
}
if credential.password.is_some() {
return Err(Error::invalid_argument(
"password must not be set for MONGODB-OIDC authentication",
));
}
Ok(())
}
_ => Ok(()),
}
}
pub fn as_str(&self) -> &'static str {
match self {
AuthMechanism::ScramSha1 => SCRAM_SHA_1_STR,
AuthMechanism::ScramSha256 => SCRAM_SHA_256_STR,
AuthMechanism::MongoDbCr => MONGODB_CR_STR,
AuthMechanism::MongoDbX509 => MONGODB_X509_STR,
AuthMechanism::Gssapi => GSSAPI_STR,
AuthMechanism::Plain => PLAIN_STR,
#[cfg(feature = "aws-auth")]
AuthMechanism::MongoDbAws => MONGODB_AWS_STR,
AuthMechanism::MongoDbOidc => MONGODB_OIDC_STR,
}
}
pub(crate) fn default_source<'a>(&'a self, uri_db: Option<&'a str>) -> &'a str {
match self {
AuthMechanism::ScramSha1 | AuthMechanism::ScramSha256 | AuthMechanism::MongoDbCr => {
uri_db.unwrap_or("admin")
}
AuthMechanism::MongoDbX509 => "$external",
AuthMechanism::Plain => uri_db.unwrap_or("$external"),
AuthMechanism::MongoDbOidc => "$external",
#[cfg(feature = "aws-auth")]
AuthMechanism::MongoDbAws => "$external",
AuthMechanism::Gssapi => "",
}
}
pub(crate) fn build_speculative_client_first(
&self,
credential: &Credential,
) -> Result<Option<ClientFirst>> {
match self {
Self::ScramSha1 => {
let client_first = ScramVersion::Sha1.build_speculative_client_first(credential)?;
Ok(Some(ClientFirst::Scram(ScramVersion::Sha1, client_first)))
}
Self::ScramSha256 => {
let client_first =
ScramVersion::Sha256.build_speculative_client_first(credential)?;
Ok(Some(ClientFirst::Scram(ScramVersion::Sha256, client_first)))
}
Self::MongoDbX509 => Ok(Some(ClientFirst::X509(Box::new(
x509::build_speculative_client_first(credential),
)))),
Self::Plain => Ok(None),
Self::MongoDbOidc => Ok(None),
#[cfg(feature = "aws-auth")]
AuthMechanism::MongoDbAws => Ok(None),
AuthMechanism::MongoDbCr => Err(ErrorKind::Authentication {
message: "MONGODB-CR is deprecated and not supported by this driver. Use SCRAM \
for password-based authentication instead"
.into(),
}
.into()),
_ => Err(ErrorKind::Authentication {
message: format!("Authentication mechanism {:?} not yet implemented.", self),
}
.into()),
}
}
pub(crate) async fn authenticate_stream(
&self,
stream: &mut Connection,
credential: &Credential,
server_api: Option<&ServerApi>,
#[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
) -> Result<()> {
self.validate_credential(credential)?;
match self {
AuthMechanism::ScramSha1 => {
ScramVersion::Sha1
.authenticate_stream(stream, credential, server_api, None)
.await
}
AuthMechanism::ScramSha256 => {
ScramVersion::Sha256
.authenticate_stream(stream, credential, server_api, None)
.await
}
AuthMechanism::MongoDbX509 => {
x509::authenticate_stream(stream, credential, server_api, None).await
}
AuthMechanism::Plain => {
plain::authenticate_stream(stream, credential, server_api).await
}
#[cfg(feature = "aws-auth")]
AuthMechanism::MongoDbAws => {
aws::authenticate_stream(stream, credential, server_api, http_client).await
}
AuthMechanism::MongoDbCr => Err(ErrorKind::Authentication {
message: "MONGODB-CR is deprecated and not supported by this driver. Use SCRAM \
for password-based authentication instead"
.into(),
}
.into()),
AuthMechanism::MongoDbOidc => {
oidc::authenticate_stream(stream, credential, server_api).await
}
_ => Err(ErrorKind::Authentication {
message: format!("Authentication mechanism {:?} not yet implemented.", self),
}
.into()),
}
}
}
impl FromStr for AuthMechanism {
    type Err = Error;

    /// Parses a canonical mechanism name such as `"SCRAM-SHA-256"`. Matching is case-sensitive.
    ///
    /// `"MONGODB-AWS"` is recognized, but it yields an error unless the `aws-auth` feature is
    /// enabled. Any other unrecognized string is an invalid-argument error.
    fn from_str(name: &str) -> Result<Self> {
        let mechanism = match name {
            SCRAM_SHA_1_STR => AuthMechanism::ScramSha1,
            SCRAM_SHA_256_STR => AuthMechanism::ScramSha256,
            MONGODB_CR_STR => AuthMechanism::MongoDbCr,
            MONGODB_X509_STR => AuthMechanism::MongoDbX509,
            GSSAPI_STR => AuthMechanism::Gssapi,
            PLAIN_STR => AuthMechanism::Plain,
            MONGODB_OIDC_STR => AuthMechanism::MongoDbOidc,
            #[cfg(feature = "aws-auth")]
            MONGODB_AWS_STR => AuthMechanism::MongoDbAws,
            #[cfg(not(feature = "aws-auth"))]
            MONGODB_AWS_STR => {
                return Err(ErrorKind::InvalidArgument {
                    message: "MONGODB-AWS auth is only supported with the aws-auth feature flag and \
                              the tokio runtime"
                        .into(),
                }
                .into());
            }
            unknown => {
                return Err(ErrorKind::InvalidArgument {
                    message: format!("invalid mechanism string: {}", unknown),
                }
                .into());
            }
        };
        Ok(mechanism)
    }
}
/// Authentication settings used when establishing connections.
///
/// `Debug` is implemented by hand below so that none of the contents (including `password`)
/// are ever printed.
#[derive(Clone, Default, Deserialize, TypedBuilder, Derivative)]
#[derivative(PartialEq)]
#[builder(field_defaults(default, setter(into)))]
#[non_exhaustive]
pub struct Credential {
    /// Username to authenticate as.
    ///
    /// - Required by SCRAM and PLAIN; PLAIN also rejects an empty string.
    /// - Must not be combined with the `PROVIDER_NAME` mechanism property for MONGODB-OIDC.
    pub username: Option<String>,
    /// Database to authenticate against. When `None`, the mechanism's default source applies.
    /// MONGODB-X509 and MONGODB-OIDC only accept `"$external"`.
    pub source: Option<String>,
    /// Password, if the mechanism uses one. Required by PLAIN; rejected by MONGODB-X509 and
    /// MONGODB-OIDC.
    pub password: Option<String>,
    /// Mechanism to use. When `None`, one is chosen from the server's `saslSupportedMechs`:
    /// SCRAM-SHA-256 if offered, otherwise SCRAM-SHA-1.
    pub mechanism: Option<AuthMechanism>,
    /// Extra mechanism-specific properties (e.g. `PROVIDER_NAME` for MONGODB-OIDC).
    pub mechanism_properties: Option<Document>,
    /// OIDC callbacks. Skipped by `Deserialize` and ignored by the derived `PartialEq`.
    #[serde(skip)]
    #[derivative(Debug = "ignore", PartialEq = "ignore")]
    pub(crate) oidc_callbacks: Option<oidc::Callbacks>,
}
impl Credential {
    /// Returns the database this credential authenticates against.
    ///
    /// An explicitly configured `source` takes precedence. Otherwise the mechanism's default
    /// source is used, or `"admin"` when no mechanism is set either.
    pub(crate) fn resolved_source(&self) -> &str {
        // An explicit `source` used to be ignored here. Mechanism negotiation then queried
        // `admin.<username>` even when the user was configured to authenticate elsewhere.
        self.source.as_deref().unwrap_or_else(|| {
            self.mechanism
                .as_ref()
                .map(|m| m.default_source(None))
                .unwrap_or("admin")
        })
    }

    /// Adds mechanism negotiation to a handshake command when the mechanism must be negotiated.
    ///
    /// If a username is set but no mechanism is, this inserts
    /// `saslSupportedMechs: "<resolved source>.<username>"` into `command`. Otherwise `command`
    /// is left untouched. Presumably the server's reply populates the stream description's
    /// `sasl_supported_mechs`, which `AuthMechanism::from_stream_description` reads.
    pub(crate) fn append_needed_mechanism_negotiation(&self, command: &mut Document) {
        if let (Some(username), None) = (self.username.as_ref(), self.mechanism.as_ref()) {
            command.insert(
                "saslSupportedMechs",
                format!("{}.{}", self.resolved_source(), username),
            );
        }
    }

    /// Authenticates `conn` with this credential.
    ///
    /// - Returns `Ok(())` without doing anything if the connection's initial server type
    ///   cannot authenticate.
    /// - If `first_round` holds the result of a speculative first exchange, authentication
    ///   resumes from it.
    /// - Otherwise the configured mechanism is used. If there is none, a mechanism is chosen
    ///   from the stream description.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the stream description and from the mechanism's
    /// authentication conversation.
    pub(crate) async fn authenticate_stream(
        &self,
        conn: &mut Connection,
        server_api: Option<&ServerApi>,
        first_round: Option<FirstRound>,
        #[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
    ) -> Result<()> {
        let stream_description = conn.stream_description()?;

        // Some server types do not support authentication at all; skip them silently.
        if !stream_description.initial_server_type.can_auth() {
            return Ok(());
        }

        // Resume a speculative conversation started during the handshake, if any.
        if let Some(first_round) = first_round {
            return match first_round {
                FirstRound::Scram(version, first_round) => {
                    version
                        .authenticate_stream(conn, self, server_api, first_round)
                        .await
                }
                FirstRound::X509(server_first) => {
                    x509::authenticate_stream(conn, self, server_api, server_first).await
                }
            };
        }

        let mechanism = match self.mechanism {
            None => Cow::Owned(AuthMechanism::from_stream_description(stream_description)),
            Some(ref m) => Cow::Borrowed(m),
        };

        mechanism
            .authenticate_stream(
                conn,
                self,
                server_api,
                #[cfg(feature = "aws-auth")]
                http_client,
            )
            .await
    }

    /// Test-only serializer exposing the non-secret credential options (source, mechanism
    /// name, mechanism properties) under lowercase client-option keys.
    #[cfg(test)]
    pub(crate) fn serialize_for_client_options<S>(
        credential: &Option<Credential>,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::Serialize;

        #[derive(serde::Serialize)]
        struct CredentialHelper<'a> {
            authsource: Option<&'a String>,
            authmechanism: Option<&'a str>,
            authmechanismproperties: Option<&'a Document>,
        }

        let state = credential.as_ref().map(|c| CredentialHelper {
            authsource: c.source.as_ref(),
            authmechanism: c.mechanism.as_ref().map(|s| s.as_str()),
            authmechanismproperties: c.mechanism_properties.as_ref(),
        });
        state.serialize(serializer)
    }
}
impl Debug for Credential {
    /// Formats as `Credential("REDACTED")` so no credential field is ever printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A `&str` placeholder formats exactly like the equivalent `String`, with no allocation.
        let mut redacted = f.debug_tuple("Credential");
        redacted.field(&"REDACTED");
        redacted.finish()
    }
}
/// The client's opening message for speculative authentication, as produced by
/// `AuthMechanism::build_speculative_client_first` (SCRAM and MONGODB-X509 only).
pub(crate) enum ClientFirst {
    /// The SCRAM variant in use and its client-first state.
    Scram(ScramVersion, scram::ClientFirst),
    /// The prebuilt MONGODB-X509 authentication command.
    X509(Box<Command>),
}
impl ClientFirst {
    /// Returns the command body document that carries this first message.
    pub(crate) fn to_document(&self) -> Document {
        match self {
            Self::X509(command) => command.body.clone(),
            Self::Scram(version, client_first) => {
                let command = client_first.to_command(version);
                command.body
            }
        }
    }

    /// Pairs this first message with the server's reply, producing the state needed to
    /// resume authentication.
    pub(crate) fn into_first_round(self, server_first: Document) -> FirstRound {
        match self {
            // For X.509 only the server reply is retained; the prebuilt command is dropped.
            Self::X509(_) => FirstRound::X509(server_first),
            Self::Scram(version, client_first) => {
                let round = scram::FirstRound {
                    client_first,
                    server_first,
                };
                FirstRound::Scram(version, round)
            }
        }
    }
}
/// The outcome of a speculative first exchange, built by `ClientFirst::into_first_round` and
/// consumed by `Credential::authenticate_stream` to resume authentication.
#[derive(Debug)]
pub(crate) enum FirstRound {
    /// SCRAM: the SCRAM variant plus the client-first/server-first pair.
    Scram(ScramVersion, scram::FirstRound),
    /// MONGODB-X509: the server's reply document.
    X509(Document),
}
/// Returns 32 random bytes drawn from the thread-local RNG.
pub(crate) fn generate_nonce_bytes() -> [u8; 32] {
    let mut rng = rand::thread_rng();
    rng.gen::<[u8; 32]>()
}
/// Returns a fresh nonce: 32 random bytes, base64-encoded.
pub(crate) fn generate_nonce() -> String {
    base64::encode(generate_nonce_bytes())
}
fn mac<M: Mac + KeyInit>(
key: &[u8],
input: &[u8],
auth_mechanism: &str,
) -> Result<impl AsRef<[u8]>> {
let mut mac = <M as Mac>::new_from_slice(key)
.map_err(|_| Error::unknown_authentication_error(auth_mechanism))?;
mac.update(input);
Ok(mac.finalize().into_bytes())
}