use std::{fs::File, io::Read, time::Duration};
use chrono::{offset::Utc, DateTime};
use hmac::Hmac;
use lazy_static::lazy_static;
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use tokio::sync::Mutex;
use crate::{
bson::{doc, rawdoc, spec::BinarySubtype, Binary, Bson, Document},
client::{
auth::{
self,
sasl::{SaslContinue, SaslResponse, SaslStart},
AuthMechanism,
Credential,
},
options::ServerApi,
},
cmap::Connection,
error::{Error, Result},
runtime::HttpClient,
serde_util,
};
/// Host used when building the ECS credentials URI (see `AwsCredential::get_from_ecs`).
const AWS_ECS_IP: &str = "169.254.170.2";
/// Host used for the EC2 token, role-name and credential requests
/// (see `AwsCredential::get_from_ec2`).
const AWS_EC2_IP: &str = "169.254.169.254";
/// `chrono` format string for the timestamp sent as the `d` payload field and signed as the
/// `x-amz-date` header.
const AWS_LONG_DATE_FMT: &str = "%Y%m%dT%H%M%SZ";
/// Mechanism name passed to the error constructors and response parsers in this module.
const MECH_NAME: &str = "MONGODB-AWS";
// Global slot holding the most recently fetched AWS credential.
//
// `authenticate_stream_inner` only stores credentials that carry an expiration, and
// `authenticate_stream` resets the slot to `None` whenever authentication fails.
lazy_static! {
static ref CACHED_CREDENTIAL: Mutex<Option<AwsCredential>> = Mutex::new(None);
}
/// Performs MONGODB-AWS authentication on `conn`.
///
/// Delegates to `authenticate_stream_inner`. If that fails for any reason, the globally
/// cached AWS credential is discarded before the error is handed back, so the next attempt
/// fetches fresh credentials.
pub(super) async fn authenticate_stream(
    conn: &mut Connection,
    credential: &Credential,
    server_api: Option<&ServerApi>,
    http_client: &HttpClient,
) -> Result<()> {
    let outcome = authenticate_stream_inner(conn, credential, server_api, http_client).await;
    if outcome.is_err() {
        // A failed attempt invalidates whatever credential is currently cached.
        *CACHED_CREDENTIAL.lock().await = None;
    }
    outcome
}
/// Runs the two-message MONGODB-AWS SASL conversation over `conn`.
///
/// Steps, in order:
/// 1. Reject any explicit auth source other than `$external`.
/// 2. Send a `SaslStart` carrying a fresh client nonce and validate the server's reply.
/// 3. Obtain an AWS credential, reusing the cached one when it has not expired.
/// 4. Send a `SaslContinue` carrying the computed authorization header, the timestamp and
///    (when present) the session token.
/// 5. Require the final reply to echo the conversation id and to be marked done.
///
/// Errors from any step are returned unchanged; cache invalidation on failure is handled by
/// the caller (`authenticate_stream`).
async fn authenticate_stream_inner(
    conn: &mut Connection,
    credential: &Credential,
    server_api: Option<&ServerApi>,
    http_client: &HttpClient,
) -> Result<()> {
    const EXTERNAL_DB: &str = "$external";

    // An unset source defaults to `$external`; any other explicit value is an error.
    if matches!(credential.source.as_deref(), Some(db) if db != EXTERNAL_DB) {
        return Err(Error::authentication_error(
            MECH_NAME,
            "auth source must be $external",
        ));
    }
    let source = EXTERNAL_DB;

    // --- client first ---
    let nonce = auth::generate_nonce_bytes();
    let first_payload = doc! {
        "r": Binary { subtype: BinarySubtype::Generic, bytes: nonce.to_vec() },
        // NOTE(review): 110 is ASCII `n`; presumably mirrors the `x-mongodb-gs2-cb-flag:n`
        // header that is signed later -- confirm against the mechanism spec.
        "p": 110i32,
    };
    let mut first_payload_bytes = Vec::new();
    first_payload.to_writer(&mut first_payload_bytes)?;

    let start_command = SaslStart::new(
        source.into(),
        AuthMechanism::MongoDbAws,
        first_payload_bytes,
        server_api.cloned(),
    )
    .into_command();
    let first_reply = conn.send_command(start_command, None).await?;
    let server_first = ServerFirst::parse(first_reply.auth_response_body(MECH_NAME)?)?;
    server_first.validate(&nonce)?;

    // --- credential lookup ---
    // The lock is only held while inspecting the cache; it is released before any fetch.
    let reusable = {
        let slot = CACHED_CREDENTIAL.lock().await;
        match &*slot {
            Some(cached) if !cached.is_expired() => Some(cached.clone()),
            _ => None,
        }
    };
    let aws_credential = match reusable {
        Some(cached) => cached,
        None => {
            let fetched = AwsCredential::get(credential, http_client).await?;
            // Only credentials that carry an expiration are worth caching.
            if fetched.expiration.is_some() {
                *CACHED_CREDENTIAL.lock().await = Some(fetched.clone());
            }
            fetched
        }
    };

    // --- client second ---
    // The same instant is used for the signature and for the `d` field.
    let now = Utc::now();
    let authorization_header = aws_credential.compute_authorization_header(
        now,
        &server_first.sts_host,
        &server_first.server_nonce,
    )?;
    let mut second_payload = doc! {
        "a": authorization_header,
        "d": now.format(AWS_LONG_DATE_FMT).to_string(),
    };
    if let Some(token) = aws_credential.session_token {
        second_payload.insert("t", token);
    }
    let mut second_payload_bytes = Vec::new();
    second_payload.to_writer(&mut second_payload_bytes)?;

    let continue_command = SaslContinue::new(
        source.into(),
        server_first.conversation_id.clone(),
        second_payload_bytes,
        server_api.cloned(),
    )
    .into_command();
    let second_reply = conn.send_command(continue_command, None).await?;
    let server_second =
        SaslResponse::parse(MECH_NAME, second_reply.auth_response_body(MECH_NAME)?)?;

    // The conversation must be the one we started, and it must be finished.
    let same_conversation = server_second.conversation_id == server_first.conversation_id;
    if !(same_conversation && server_second.done) {
        return Err(Error::invalid_authentication_response(MECH_NAME));
    }
    Ok(())
}
/// AWS credential material used to sign the MONGODB-AWS authorization header.
///
/// Deserialized with PascalCase keys (`AccessKeyId`, `SecretAccessKey`, `SessionToken`,
/// `Expiration`), or constructed directly by `AwsCredential::get` from explicitly supplied
/// keys.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub(crate) struct AwsCredential {
/// Access key id; placed in the `Credential=` part of the authorization header.
access_key_id: String,
/// Secret key; seeds the HMAC signing-key derivation.
secret_access_key: String,
/// Optional session token. Also accepted under the key `Token` when deserializing.
#[serde(alias = "Token")]
session_token: Option<String>,
/// Optional expiration time; absent keys deserialize to `None`.
/// NOTE(review): the helper name suggests both numeric and string encodings are accepted --
/// confirm in `serde_util`.
#[serde(
default,
deserialize_with = "serde_util::deserialize_datetime_option_from_double_or_string"
)]
expiration: Option<bson::DateTime>,
}
impl AwsCredential {
    /// Resolves the AWS credential to authenticate with.
    ///
    /// Sources are tried in this order:
    /// 1. Explicit keys: `credential.username` / `credential.password`, each falling back to
    ///    the `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY` environment variables. The
    ///    session token comes from the `AWS_SESSION_TOKEN` mechanism property or, failing
    ///    that, the environment variable of the same name. Such credentials have no
    ///    expiration.
    /// 2. `AssumeRoleWithWebIdentity`, when both `AWS_WEB_IDENTITY_TOKEN_FILE` and
    ///    `AWS_ROLE_ARN` are set.
    /// 3. The ECS endpoint, when `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` is set.
    /// 4. The EC2 endpoint otherwise.
    ///
    /// # Errors
    /// Returns an authentication error when only one of access key / secret key is
    /// available, or when a session token is supplied without both keys. Errors from the
    /// remote sources are propagated.
    pub(crate) async fn get(credential: &Credential, http_client: &HttpClient) -> Result<Self> {
        let access_key = credential
            .username
            .clone()
            .or_else(|| std::env::var("AWS_ACCESS_KEY_ID").ok());
        let secret_key = credential
            .password
            .clone()
            .or_else(|| std::env::var("AWS_SECRET_ACCESS_KEY").ok());
        let session_token = credential
            .mechanism_properties
            .as_ref()
            .and_then(|d| d.get_str("AWS_SESSION_TOKEN").ok())
            .map(|s| s.to_string())
            .or_else(|| std::env::var("AWS_SESSION_TOKEN").ok());

        match (access_key, secret_key) {
            // Both keys present: use them directly, no remote lookup needed.
            (Some(access_key_id), Some(secret_access_key)) => {
                return Ok(Self {
                    access_key_id,
                    secret_access_key,
                    session_token,
                    expiration: None,
                });
            }
            // Neither present: fall through to the remote credential sources.
            (None, None) => {}
            // Exactly one present is a configuration error.
            _ => {
                return Err(Error::authentication_error(
                    MECH_NAME,
                    "cannot specify only one of access key and secret key; either both or neither \
                     must be provided",
                ));
            }
        }

        if session_token.is_some() {
            return Err(Error::authentication_error(
                MECH_NAME,
                "cannot specify session token without both access key and secret key",
            ));
        }

        if let (Ok(token_file), Ok(role_arn)) = (
            std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"),
            std::env::var("AWS_ROLE_ARN"),
        ) {
            return Self::get_from_assume_role_with_web_identity_request(
                token_file,
                role_arn,
                http_client,
            )
            .await;
        }

        if let Ok(relative_uri) = std::env::var("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") {
            Self::get_from_ecs(relative_uri, http_client).await
        } else {
            Self::get_from_ec2(http_client).await
        }
    }

    /// Exchanges the web identity token stored in `token_file` for credentials by sending an
    /// `AssumeRoleWithWebIdentity` request for `role_arn` to `https://sts.amazonaws.com/`.
    ///
    /// The role session name is taken from `AWS_ROLE_SESSION_NAME`, or is a random
    /// 10-character alphanumeric string when that variable is unset.
    ///
    /// # Errors
    /// Returns an authentication error if the token file cannot be opened, read, or decoded
    /// as UTF-8, and an unknown authentication error if the request fails or the response
    /// lacks the expected nested `Credentials` document.
    async fn get_from_assume_role_with_web_identity_request(
        token_file: String,
        role_arn: String,
        http_client: &HttpClient,
    ) -> Result<Self> {
        // NOTE(review): this is synchronous file I/O inside an async fn.
        let mut file = File::open(&token_file).map_err(|_| {
            Error::authentication_error(MECH_NAME, "could not open identity token file")
        })?;
        let mut buffer = Vec::<u8>::new();
        file.read_to_end(&mut buffer).map_err(|_| {
            Error::authentication_error(MECH_NAME, "could not read identity token file")
        })?;
        let token = std::str::from_utf8(&buffer).map_err(|_| {
            Error::authentication_error(MECH_NAME, "could not read identity token file")
        })?;

        let session_name = std::env::var("AWS_ROLE_SESSION_NAME")
            .unwrap_or_else(|_| Alphanumeric.sample_string(&mut rand::thread_rng(), 10));

        let query = rawdoc! {
            "Action": "AssumeRoleWithWebIdentity",
            "RoleSessionName": session_name,
            "RoleArn": role_arn,
            "WebIdentityToken": token,
            "Version": "2011-06-15",
        };
        let response = http_client
            .get("https://sts.amazonaws.com/")
            .headers(&[("Accept", "application/json")])
            .query(query)
            .send::<Document>()
            .await
            .map_err(|_| Error::unknown_authentication_error(MECH_NAME))?;

        // The credentials are nested three documents deep in the response.
        let credential = response
            .get_document("AssumeRoleWithWebIdentityResponse")
            .and_then(|d| d.get_document("AssumeRoleWithWebIdentityResult"))
            .and_then(|d| d.get_document("Credentials"))
            .map_err(|_| Error::unknown_authentication_error(MECH_NAME))?
            .to_owned();
        Ok(bson::from_document(credential)?)
    }

    /// Fetches credentials from the ECS endpoint at `AWS_ECS_IP`.
    ///
    /// `relative_uri` is the value of `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`. Leading
    /// slashes are stripped before joining so the request path contains a single `/`
    /// regardless of whether the variable starts with one.
    async fn get_from_ecs(relative_uri: String, http_client: &HttpClient) -> Result<Self> {
        let uri = format!(
            "http://{}/{}",
            AWS_ECS_IP,
            relative_uri.trim_start_matches('/')
        );
        http_client
            .get(&uri)
            .send()
            .await
            .map_err(|_| Error::unknown_authentication_error(MECH_NAME))
    }

    /// Fetches credentials from the EC2 endpoint at `AWS_EC2_IP`.
    ///
    /// Three requests are made: a `PUT` for a temporary metadata token (30 second TTL), a
    /// `GET` for the role name, and a `GET` for that role's credentials. Any failure is
    /// reported as an unknown authentication error.
    async fn get_from_ec2(http_client: &HttpClient) -> Result<Self> {
        let temporary_token = http_client
            .put(&format!("http://{}/latest/api/token", AWS_EC2_IP))
            .headers(&[("X-aws-ec2-metadata-token-ttl-seconds", "30")])
            .send_and_get_string()
            .await
            .map_err(|_| Error::unknown_authentication_error(MECH_NAME))?;

        let role_name_uri = format!(
            "http://{}/latest/meta-data/iam/security-credentials/",
            AWS_EC2_IP
        );
        let role_name = http_client
            .get(&role_name_uri)
            .headers(&[("X-aws-ec2-metadata-token", &temporary_token[..])])
            .send_and_get_string()
            .await
            .map_err(|_| Error::unknown_authentication_error(MECH_NAME))?;

        // `role_name_uri` already ends with '/', so append the role name directly rather
        // than inserting a second separator.
        let credential_uri = format!("{}{}", role_name_uri, role_name);
        http_client
            .get(&credential_uri)
            .headers(&[("X-aws-ec2-metadata-token", &temporary_token[..])])
            .send()
            .await
            .map_err(|_| Error::unknown_authentication_error(MECH_NAME))
    }

    /// Builds the `AWS4-HMAC-SHA256` authorization header value for an STS
    /// `GetCallerIdentity` request.
    ///
    /// * `date` - timestamp used for `x-amz-date` and the credential scope.
    /// * `host` - STS host announced by the server; also determines the signing region.
    /// * `server_nonce` - server nonce, signed (base64-encoded) as `x-mongodb-server-nonce`.
    ///
    /// The request itself is never sent from here; only its canonical form is hashed and
    /// signed with a key derived from `secret_access_key`.
    ///
    /// # Errors
    /// Propagates any error returned by `auth::mac`.
    fn compute_authorization_header(
        &self,
        date: DateTime<Utc>,
        host: &str,
        server_nonce: &[u8],
    ) -> Result<String> {
        let date_str = date.format(AWS_LONG_DATE_FMT).to_string();

        // The security-token header line and its signed-header entry are only present when
        // a session token is available.
        let token = self
            .session_token
            .as_ref()
            .map(|s| format!("x-amz-security-token:{}\n", s))
            .unwrap_or_default();
        let token_signed_header = if self.session_token.is_some() {
            "x-amz-security-token;"
        } else {
            ""
        };

        #[rustfmt::skip]
        let signed_headers = format!(
            "\
                content-length;\
                content-type;\
                host;\
                x-amz-date;\
                {token_signed_header}\
                x-mongodb-gs2-cb-flag;\
                x-mongodb-server-nonce\
            ",
            token_signed_header = token_signed_header,
        );

        let body = "Action=GetCallerIdentity&Version=2011-06-15";
        let hashed_body = hex::encode(Sha256::digest(body.as_bytes()));

        let nonce = base64::encode(server_nonce);

        // Canonical request. `content-length:43` is the byte length of `body` above.
        #[rustfmt::skip]
        let request = format!(
            "\
                POST\n\
                /\n\n\
                content-length:43\n\
                content-type:application/x-www-form-urlencoded\n\
                host:{host}\n\
                x-amz-date:{date}\n\
                {token}\
                x-mongodb-gs2-cb-flag:n\n\
                x-mongodb-server-nonce:{nonce}\n\n\
                {signed_headers}\n\
                {hashed_body}\
            ",
            host = host,
            date = date_str,
            token = token,
            nonce = nonce,
            signed_headers = signed_headers,
            hashed_body = hashed_body,
        );

        let hashed_request = hex::encode(Sha256::digest(request.as_bytes()));

        let small_date = date.format("%Y%m%d").to_string();

        // The region is the second dot-separated label of the host; the global endpoint and
        // single-label hosts map to `us-east-1`.
        let region = if host == "sts.amazonaws.com" {
            "us-east-1"
        } else {
            host.split('.').nth(1).unwrap_or("us-east-1")
        };

        #[rustfmt::skip]
        let string_to_sign = format!(
            "\
                AWS4-HMAC-SHA256\n\
                {full_date}\n\
                {small_date}/{region}/sts/aws4_request\n\
                {hashed_request}\
            ",
            full_date = date_str,
            small_date = small_date,
            region = region,
            hashed_request = hashed_request,
        );

        // Signing-key derivation chain: secret -> date -> region -> service -> request.
        let first_hmac_key = format!("AWS4{}", self.secret_access_key);
        let k_date =
            auth::mac::<Hmac<Sha256>>(first_hmac_key.as_ref(), small_date.as_ref(), MECH_NAME)?;
        let k_region = auth::mac::<Hmac<Sha256>>(k_date.as_ref(), region.as_ref(), MECH_NAME)?;
        let k_service = auth::mac::<Hmac<Sha256>>(k_region.as_ref(), b"sts", MECH_NAME)?;
        let k_signing = auth::mac::<Hmac<Sha256>>(k_service.as_ref(), b"aws4_request", MECH_NAME)?;

        let signature_bytes =
            auth::mac::<Hmac<Sha256>>(k_signing.as_ref(), string_to_sign.as_ref(), MECH_NAME)?;
        let signature = hex::encode(signature_bytes);

        #[rustfmt::skip]
        let auth_header = format!(
            "\
                AWS4-HMAC-SHA256 \
                Credential={access_key}/{small_date}/{region}/sts/aws4_request, \
                SignedHeaders={signed_headers}, \
                Signature={signature}\
            ",
            access_key = self.access_key_id,
            small_date = small_date,
            region = region,
            signed_headers = signed_headers,
            signature = signature
        );

        Ok(auth_header)
    }

    /// Returns the access key id.
    #[cfg(feature = "in-use-encryption-unstable")]
    pub(crate) fn access_key(&self) -> &str {
        &self.access_key_id
    }

    /// Returns the secret access key.
    #[cfg(feature = "in-use-encryption-unstable")]
    pub(crate) fn secret_key(&self) -> &str {
        &self.secret_access_key
    }

    /// Returns the session token, if any.
    #[cfg(feature = "in-use-encryption-unstable")]
    pub(crate) fn session_token(&self) -> Option<&str> {
        self.session_token.as_deref()
    }

    /// Reports whether this credential should no longer be used.
    ///
    /// A credential counts as expired when it has no expiration at all, or when fewer than
    /// five minutes remain until its expiration.
    fn is_expired(&self) -> bool {
        match self.expiration {
            Some(expiration) => {
                expiration.saturating_duration_since(bson::DateTime::now())
                    < Duration::from_secs(5 * 60)
            }
            None => true,
        }
    }
}
/// Parsed form of the server's reply to the `SaslStart` command.
struct ServerFirst {
/// Conversation id from the SASL response; echoed in `SaslContinue` and compared against
/// the final reply.
conversation_id: Bson,
/// Server nonce (payload field `s`); `validate` requires it to be 64 bytes and to start
/// with the client nonce.
server_nonce: Vec<u8>,
/// STS host (payload field `h`) used when computing the authorization header.
sts_host: String,
/// `done` flag from the SASL response; `validate` rejects the reply if it is set.
done: bool,
}
/// Wire shape of the payload embedded in the server's first SASL response.
///
/// `deny_unknown_fields` makes deserialization fail if the payload carries any key other
/// than `s` and `h`.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct ServerFirstPayload {
/// Server nonce, read from key `s` as raw bytes.
#[serde(rename = "s", with = "serde_bytes")]
server_nonce: Vec<u8>,
/// STS host, read from key `h`.
#[serde(rename = "h")]
sts_host: String,
}
impl ServerFirst {
fn parse(response: Document) -> Result<Self> {
let SaslResponse {
conversation_id,
payload,
done,
} = SaslResponse::parse(MECH_NAME, response)?;
let ServerFirstPayload {
server_nonce,
sts_host,
} = bson::from_slice(payload.as_slice())
.map_err(|_| Error::invalid_authentication_response(MECH_NAME))?;
Ok(Self {
conversation_id,
server_nonce,
sts_host,
done,
})
}
fn validate(&self, nonce: &[u8]) -> Result<()> {
if self.done {
Err(Error::authentication_error(
MECH_NAME,
"handshake terminated early",
))
} else if !self.server_nonce.starts_with(nonce) {
Err(Error::authentication_error(MECH_NAME, "mismatched nonce"))
} else if self.server_nonce.len() != 64 {
Err(Error::authentication_error(
MECH_NAME,
"incorrect length server nonce",
))
} else if self.sts_host.is_empty() {
Err(Error::authentication_error(
MECH_NAME,
"sts host must be non-empty",
))
} else if self.sts_host.as_bytes().len() > 255 {
Err(Error::authentication_error(
MECH_NAME,
"sts host cannot be more than 255 bytes",
))
} else if self.sts_host.split('.').any(|s| s.is_empty()) {
Err(Error::authentication_error(
MECH_NAME,
"sts host cannot contain empty labels",
))
} else {
Ok(())
}
}
}
/// Test-only helpers for inspecting and manipulating the global credential cache.
///
/// Every helper other than `cached_credential` and `clear_cached_credential` panics when
/// the cache is empty.
#[cfg(test)]
pub(crate) mod test_utils {
    use super::{AwsCredential, CACHED_CREDENTIAL};

    /// Returns a copy of whatever is currently cached.
    pub(crate) async fn cached_credential() -> Option<AwsCredential> {
        let slot = CACHED_CREDENTIAL.lock().await;
        slot.clone()
    }

    /// Empties the cache.
    pub(crate) async fn clear_cached_credential() {
        let mut slot = CACHED_CREDENTIAL.lock().await;
        *slot = None;
    }

    /// Overwrites the cached access key id with an invalid value.
    pub(crate) async fn poison_cached_credential() {
        let mut slot = CACHED_CREDENTIAL.lock().await;
        let cached = slot.as_mut().unwrap();
        cached.access_key_id = "bad".into();
    }

    /// Returns the cached access key id.
    pub(crate) async fn cached_access_key_id() -> String {
        let cached = cached_credential().await.unwrap();
        cached.access_key_id
    }

    /// Returns the cached secret access key.
    pub(crate) async fn cached_secret_access_key() -> String {
        let cached = cached_credential().await.unwrap();
        cached.secret_access_key
    }

    /// Returns the cached session token, if the cached credential has one.
    pub(crate) async fn cached_session_token() -> Option<String> {
        let cached = cached_credential().await.unwrap();
        cached.session_token
    }

    /// Returns the cached credential's expiration; panics if it has none.
    pub(crate) async fn cached_expiration() -> bson::DateTime {
        let cached = cached_credential().await.unwrap();
        cached.expiration.unwrap()
    }

    /// Replaces the cached credential's expiration.
    pub(crate) async fn set_cached_expiration(expiration: bson::DateTime) {
        let mut slot = CACHED_CREDENTIAL.lock().await;
        slot.as_mut().unwrap().expiration = Some(expiration);
    }
}