use std::borrow::Cow;
use std::fmt;
use std::io::Read;
use std::io::Write;
use std::str::FromStr;
use base64::prelude::BASE64_STANDARD;
use base64::read::DecoderReader;
use base64::write::EncoderWriter;
use http::Uri;
use reqsign::aws::DefaultSigner as AwsDefaultSigner;
use reqsign::azure::DefaultSigner as AzureDefaultSigner;
use reqsign::google::DefaultSigner as GcsDefaultSigner;
use reqwest::Request;
use reqwest::header::{HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use url::Url;
use uv_netrc::Netrc;
use uv_redacted::DisplaySafeUrl;
use uv_static::EnvVars;
/// Value sent in the `x-ms-version` header of Azure-signed requests when the
/// request does not already carry that header.
const AZURE_STORAGE_VERSION: &str = "2023-11-03";
/// Credentials that can be rendered into an HTTP `Authorization` header.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Credentials {
/// Username/password pair, rendered as a `Basic` authorization header.
Basic {
/// The username; may be absent.
username: Username,
/// The password; may be absent.
password: Option<Password>,
},
/// Opaque token, rendered as a `Bearer` authorization header.
Bearer {
/// The raw token bytes.
token: Token,
},
}
/// An optional username.
///
/// [`Username::new`] maps an empty string to "no username"; the derived
/// `Default` is also "no username".
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash, Default, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Username(Option<String>);
impl Username {
pub(crate) fn new(value: Option<String>) -> Self {
Self(value.filter(|s| !s.is_empty()))
}
pub(crate) fn none() -> Self {
Self::new(None)
}
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
pub(crate) fn is_some(&self) -> bool {
self.0.is_some()
}
pub(crate) fn as_deref(&self) -> Option<&str> {
self.0.as_deref()
}
}
impl From<String> for Username {
fn from(value: String) -> Self {
Self::new(Some(value))
}
}
impl From<Option<String>> for Username {
fn from(value: Option<String>) -> Self {
Self::new(value)
}
}
/// A password.
///
/// `Debug` is implemented by hand for this type and prints `****` instead of
/// the value. The derived `Serialize` is transparent, i.e. it serializes the
/// inner string itself.
#[derive(Clone, PartialEq, Eq, Ord, PartialOrd, Hash, Default, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Password(String);
impl Password {
pub fn new(password: String) -> Self {
Self(password)
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn into_string(self) -> String {
self.0
}
}
impl fmt::Debug for Password {
    /// Always prints `****` so the password never appears in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("****")
    }
}
/// A bearer token, stored as raw bytes.
///
/// `Debug` is implemented by hand for this type and prints `****` instead of
/// the value. The type derives `Deserialize` but not `Serialize`.
#[derive(Clone, PartialEq, Eq, Ord, PartialOrd, Hash, Default, Deserialize)]
#[serde(transparent)]
pub struct Token(Vec<u8>);
impl Token {
pub fn new(token: Vec<u8>) -> Self {
Self(token)
}
pub fn as_slice(&self) -> &[u8] {
self.0.as_slice()
}
pub fn into_bytes(self) -> Vec<u8> {
self.0
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
impl fmt::Debug for Token {
    /// Always prints `****` so the token never appears in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("****")
    }
}
impl Credentials {
#[allow(dead_code)]
pub fn basic(username: Option<String>, password: Option<String>) -> Self {
Self::Basic {
username: Username::new(username),
password: password.map(Password),
}
}
/// Create `Bearer` credentials from raw token bytes.
#[allow(dead_code)]
pub fn bearer(token: Vec<u8>) -> Self {
    let token = Token::new(token);
    Self::Bearer { token }
}
/// The username as a string slice.
///
/// Returns `None` for `Bearer` credentials and for `Basic` credentials
/// without a username.
pub fn username(&self) -> Option<&str> {
    match self {
        Self::Bearer { .. } => None,
        Self::Basic { username: name, .. } => name.as_deref(),
    }
}
/// An owned copy of the username; the absent username for `Bearer` credentials.
pub(crate) fn to_username(&self) -> Username {
    match self {
        Self::Bearer { .. } => Username::none(),
        Self::Basic { username: name, .. } => name.clone(),
    }
}
/// The username, borrowed when possible.
///
/// `Basic` credentials lend their stored username; `Bearer` credentials
/// yield an owned, absent username.
pub(crate) fn as_username(&self) -> Cow<'_, Username> {
    match self {
        Self::Bearer { .. } => Cow::Owned(Username::none()),
        Self::Basic { username: name, .. } => Cow::Borrowed(name),
    }
}
/// The password as a string slice.
///
/// Returns `None` for `Bearer` credentials and for `Basic` credentials
/// without a password.
pub fn password(&self) -> Option<&str> {
    let Self::Basic { password: secret, .. } = self else {
        return None;
    };
    secret.as_ref().map(Password::as_str)
}
/// Whether these credentials carry a secret.
///
/// `Basic` credentials count as authenticated when a password is present
/// (the username is not considered); `Bearer` credentials when the token is
/// non-empty.
pub fn is_authenticated(&self) -> bool {
    match self {
        Self::Bearer { token } => !token.is_empty(),
        Self::Basic { password, .. } => password.is_some(),
    }
}
/// Whether these credentials carry no information at all.
///
/// `Basic` credentials are empty when both username and password are
/// absent; `Bearer` credentials are empty when the token has no bytes.
pub(crate) fn is_empty(&self) -> bool {
    match self {
        Self::Bearer { token } => token.is_empty(),
        Self::Basic { username, password } => password.is_none() && username.is_none(),
    }
}
/// Look up `Basic` credentials for `url` in a netrc file.
///
/// The entry for the URL's host is used, falling back to the `default`
/// entry. Returns `None` when the URL has no host, when no entry applies, or
/// when `username` is given and differs from the entry's login.
pub(crate) fn from_netrc(
    netrc: &Netrc,
    url: &DisplaySafeUrl,
    username: Option<&str>,
) -> Option<Self> {
    let host = url.host_str()?;
    let entry = match netrc.hosts.get(host) {
        Some(entry) => entry,
        None => netrc.hosts.get("default")?,
    };

    // A caller-supplied username must match the netrc login exactly.
    match username {
        Some(requested) if requested != entry.login => return None,
        _ => {}
    }

    let username = Username::new(Some(entry.login.clone()));
    let password = Some(Password(entry.password.clone()));
    Some(Self::Basic { username, password })
}
pub fn from_url(url: &Url) -> Option<Self> {
if url.username().is_empty() && url.password().is_none() {
return None;
}
Some(Self::Basic {
username: if url.username().is_empty() {
None
} else {
Some(
percent_encoding::percent_decode_str(url.username())
.decode_utf8()
.expect("An encoded username should always decode")
.into_owned(),
)
}
.into(),
password: url.password().map(|password| {
Password(
percent_encoding::percent_decode_str(password)
.decode_utf8()
.expect("An encoded password should always decode")
.into_owned(),
)
}),
})
}
/// Read `Basic` credentials for the index `name` from the environment.
///
/// The variable names come from `EnvVars::index_username` and
/// `EnvVars::index_password`. Returns `None` when neither variable can be
/// read.
pub fn from_env(name: impl AsRef<str>) -> Option<Self> {
    let index = name.as_ref();
    let username = std::env::var(EnvVars::index_username(index)).ok();
    let password = std::env::var(EnvVars::index_password(index)).ok();
    match (username, password) {
        (None, None) => None,
        (username, password) => Some(Self::basic(username, password)),
    }
}
/// Extract credentials from a request.
///
/// Credentials embedded in the request URL take precedence; otherwise the
/// `Authorization` header is parsed, if present.
pub(crate) fn from_request(request: &Request) -> Option<Self> {
    // The fallback must be lazy: an eagerly evaluated argument containing `?`
    // would return `None` for every request without an `Authorization`
    // header, discarding credentials already found in the URL.
    Self::from_url(request.url()).or_else(|| {
        request
            .headers()
            .get(reqwest::header::AUTHORIZATION)
            .and_then(Self::from_header_value)
    })
}
pub(crate) fn from_header_value(header: &HeaderValue) -> Option<Self> {
if let Some(mut value) = header.as_bytes().strip_prefix(b"Basic ") {
let mut decoder = DecoderReader::new(&mut value, &BASE64_STANDARD);
let mut buf = String::new();
decoder
.read_to_string(&mut buf)
.expect("HTTP Basic Authentication should be base64 encoded");
let (username, password) = buf
.split_once(':')
.expect("HTTP Basic Authentication should include a `:` separator");
let username = if username.is_empty() {
None
} else {
Some(username.to_string())
};
let password = if password.is_empty() {
None
} else {
Some(password.to_string())
};
return Some(Self::Basic {
username: Username::new(username),
password: password.map(Password),
});
}
if let Some(token) = header.as_bytes().strip_prefix(b"Bearer ") {
return Some(Self::Bearer {
token: Token::new(token.to_vec()),
});
}
None
}
pub fn to_header_value(&self) -> HeaderValue {
match self {
Self::Basic { .. } => {
let mut buf = b"Basic ".to_vec();
{
let mut encoder = EncoderWriter::new(&mut buf, &BASE64_STANDARD);
write!(encoder, "{}:", self.username().unwrap_or_default())
.expect("Write to base64 encoder should succeed");
if let Some(password) = self.password() {
write!(encoder, "{password}")
.expect("Write to base64 encoder should succeed");
}
}
let mut header =
HeaderValue::from_bytes(&buf).expect("base64 is always valid HeaderValue");
header.set_sensitive(true);
header
}
Self::Bearer { token } => {
let mut header = HeaderValue::from_bytes(&[b"Bearer ", token.as_slice()].concat())
.expect("Bearer token is always valid HeaderValue");
header.set_sensitive(true);
header
}
}
}
#[must_use]
pub fn apply(&self, mut url: DisplaySafeUrl) -> DisplaySafeUrl {
if let Some(username) = self.username() {
let _ = url.set_username(username);
}
if let Some(password) = self.password() {
let _ = url.set_password(Some(password));
}
url
}
/// Return `request` with its `Authorization` header set from these
/// credentials, replacing any existing value.
#[must_use]
pub fn authenticate(&self, mut request: Request) -> Request {
    let header = self.to_header_value();
    request
        .headers_mut()
        .insert(reqwest::header::AUTHORIZATION, header);
    request
}
}
/// A way of authenticating a request: either static [`Credentials`] or a
/// request signer for one of the supported providers.
#[derive(Clone, Debug)]
pub(crate) enum Authentication {
/// Static credentials applied as an `Authorization` header.
Credentials(Credentials),
/// Signer for AWS requests.
AwsSigner(AwsDefaultSigner),
/// Signer for GCS requests.
GcsSigner(GcsDefaultSigner),
/// Signer for Azure requests.
AzureSigner(AzureDefaultSigner),
}
/// Errors that can occur while authenticating a request with a signer.
#[derive(Debug, Error)]
pub(crate) enum AuthenticationError {
/// The request URL could not be parsed as an `http::Uri`.
#[error("Failed to convert request URL to URI")]
InvalidUri(#[from] http::uri::InvalidUri),
/// The intermediate `http::Request` handed to the signer could not be built.
#[error("Failed to build request for {provider} signing")]
BuildRequest {
/// Name of the provider whose signer was in use (e.g. `"AWS"`).
provider: &'static str,
/// The underlying builder error.
#[source]
source: http::Error,
},
/// The signer failed to sign the request.
#[error("Failed to sign request with {provider} credentials")]
Sign {
/// Name of the provider whose signer was in use (e.g. `"AWS"`).
provider: &'static str,
/// The underlying signing error.
#[source]
source: reqsign::Error,
},
}
impl PartialEq for Authentication {
    /// Compare two authentications.
    ///
    /// `Credentials` are compared by value. Two signers of the same provider
    /// always compare equal, regardless of their contents; different variants
    /// are never equal.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Credentials(lhs), Self::Credentials(rhs)) => lhs == rhs,
            (Self::AwsSigner(..), Self::AwsSigner(..))
            | (Self::GcsSigner(..), Self::GcsSigner(..))
            | (Self::AzureSigner(..), Self::AzureSigner(..)) => true,
            _ => false,
        }
    }
}

impl Eq for Authentication {}
impl From<Credentials> for Authentication {
fn from(credentials: Credentials) -> Self {
Self::Credentials(credentials)
}
}
impl From<AwsDefaultSigner> for Authentication {
fn from(signer: AwsDefaultSigner) -> Self {
Self::AwsSigner(signer)
}
}
impl From<GcsDefaultSigner> for Authentication {
fn from(signer: GcsDefaultSigner) -> Self {
Self::GcsSigner(signer)
}
}
impl From<AzureDefaultSigner> for Authentication {
fn from(signer: AzureDefaultSigner) -> Self {
Self::AzureSigner(signer)
}
}
impl Authentication {
/// The password of the wrapped credentials; always `None` for signers.
pub(crate) fn password(&self) -> Option<&str> {
    match self {
        Self::AwsSigner(..) | Self::GcsSigner(..) | Self::AzureSigner(..) => None,
        Self::Credentials(inner) => inner.password(),
    }
}
/// The username of the wrapped credentials; always `None` for signers.
pub(crate) fn username(&self) -> Option<&str> {
    match self {
        Self::AwsSigner(..) | Self::GcsSigner(..) | Self::AzureSigner(..) => None,
        Self::Credentials(inner) => inner.username(),
    }
}
/// The username of the wrapped credentials, borrowed when possible.
///
/// Signers yield an owned, absent username.
pub(crate) fn as_username(&self) -> Cow<'_, Username> {
    match self {
        Self::AwsSigner(..) | Self::GcsSigner(..) | Self::AzureSigner(..) => {
            let absent = Username::none();
            Cow::Owned(absent)
        }
        Self::Credentials(inner) => inner.as_username(),
    }
}
/// An owned copy of the wrapped credentials' username; the absent username
/// for signers.
pub(crate) fn to_username(&self) -> Username {
    match self {
        Self::AwsSigner(..) | Self::GcsSigner(..) | Self::AzureSigner(..) => Username::none(),
        Self::Credentials(inner) => inner.to_username(),
    }
}
/// Whether this authentication carries a secret.
///
/// Signers are always reported as authenticated; credentials defer to
/// [`Credentials::is_authenticated`].
pub(crate) fn is_authenticated(&self) -> bool {
    match self {
        Self::AwsSigner(..) | Self::GcsSigner(..) | Self::AzureSigner(..) => true,
        Self::Credentials(inner) => inner.is_authenticated(),
    }
}
/// Whether this authentication carries no information.
///
/// Signers are never reported as empty; credentials defer to
/// [`Credentials::is_empty`].
pub(crate) fn is_empty(&self) -> bool {
    match self {
        Self::AwsSigner(..) | Self::GcsSigner(..) | Self::AzureSigner(..) => false,
        Self::Credentials(inner) => inner.is_empty(),
    }
}
/// Authenticate `request`.
///
/// Static credentials are applied as an `Authorization` header. For the
/// signer variants the request's method, URL, and headers are copied into
/// `http` request parts, signed, and the resulting headers and path/query
/// are copied back onto the request. Azure requests additionally get an
/// `x-ms-version` header before signing, unless one is already present.
///
/// # Errors
///
/// Returns [`AuthenticationError::InvalidUri`] if the request URL cannot be
/// parsed as an `http::Uri`, [`AuthenticationError::BuildRequest`] if the
/// intermediate request cannot be built, and [`AuthenticationError::Sign`]
/// if the signer fails.
pub(crate) async fn authenticate(
    &self,
    request: Request,
) -> Result<Request, AuthenticationError> {
    match self {
        Self::Credentials(credentials) => Ok(credentials.authenticate(request)),
        Self::AwsSigner(signer) => {
            let mut parts = Self::signing_parts(&request, "AWS")?;
            signer.sign(&mut parts, None).await.map_err(|source| {
                AuthenticationError::Sign {
                    provider: "AWS",
                    source,
                }
            })?;
            Ok(Self::apply_signed_parts(request, parts))
        }
        Self::GcsSigner(signer) => {
            let mut parts = Self::signing_parts(&request, "GCS")?;
            signer.sign(&mut parts, None).await.map_err(|source| {
                AuthenticationError::Sign {
                    provider: "GCS",
                    source,
                }
            })?;
            Ok(Self::apply_signed_parts(request, parts))
        }
        Self::AzureSigner(signer) => {
            let mut parts = Self::signing_parts(&request, "Azure")?;
            // Only set the version when the caller has not chosen one.
            parts
                .headers
                .entry(HeaderName::from_static("x-ms-version"))
                .or_insert(HeaderValue::from_static(AZURE_STORAGE_VERSION));
            signer.sign(&mut parts, None).await.map_err(|source| {
                AuthenticationError::Sign {
                    provider: "Azure",
                    source,
                }
            })?;
            Ok(Self::apply_signed_parts(request, parts))
        }
    }
}

/// Copy the method, URL, and headers of `request` into `http` request parts
/// suitable for handing to a signer. `provider` labels any build error.
fn signing_parts(
    request: &Request,
    provider: &'static str,
) -> Result<http::request::Parts, AuthenticationError> {
    let uri = Uri::from_str(request.url().as_str())?;
    let mut http_req = http::Request::builder()
        .method(request.method().clone())
        .uri(uri)
        .body(())
        .map_err(|source| AuthenticationError::BuildRequest { provider, source })?;
    *http_req.headers_mut() = request.headers().clone();
    let (parts, ()) = http_req.into_parts();
    Ok(parts)
}

/// Copy the headers and the path/query of signed `parts` back onto `request`.
fn apply_signed_parts(mut request: Request, parts: http::request::Parts) -> Request {
    request.headers_mut().extend(parts.headers);
    if let Some(path_and_query) = parts.uri.path_and_query() {
        request.url_mut().set_path(path_and_query.path());
        request.url_mut().set_query(path_and_query.query());
    }
    request
}
}
// Unit tests for `Credentials` parsing/rendering and for `Authentication`
// request signing.
#[cfg(test)]
mod tests {
use insta::assert_debug_snapshot;
use reqsign::aws::Credential as AwsCredential;
use reqsign::azure::Credential as AzureCredential;
use reqsign::{Context, ProvideCredential};
use super::*;
// Credential provider that never yields an AWS credential; used to exercise
// the signing-failure path.
#[derive(Debug)]
struct EmptyAwsCredentialProvider;
impl ProvideCredential for EmptyAwsCredentialProvider {
type Credential = AwsCredential;
async fn provide_credential(
&self,
_ctx: &Context,
) -> reqsign::Result<Option<Self::Credential>> {
Ok(None)
}
}
// Credential provider that never yields an Azure credential; used to exercise
// the signing-failure path.
#[derive(Debug)]
struct EmptyAzureCredentialProvider;
impl ProvideCredential for EmptyAzureCredentialProvider {
type Credential = AzureCredential;
async fn provide_credential(
&self,
_ctx: &Context,
) -> reqsign::Result<Option<Self::Credential>> {
Ok(None)
}
}
// A URL without userinfo yields no credentials.
#[test]
fn from_url_no_credentials() {
let url = &Url::parse("https://example.com/simple/first/").unwrap();
assert_eq!(Credentials::from_url(url), None);
}
// Both username and password are extracted from the URL.
#[test]
fn from_url_username_and_password() {
let url = &Url::parse("https://example.com/simple/first/").unwrap();
let mut auth_url = url.clone();
auth_url.set_username("user").unwrap();
auth_url.set_password(Some("password")).unwrap();
let credentials = Credentials::from_url(&auth_url).unwrap();
assert_eq!(credentials.username(), Some("user"));
assert_eq!(credentials.password(), Some("password"));
}
// A password without a username still yields credentials.
#[test]
fn from_url_no_username() {
let url = &Url::parse("https://example.com/simple/first/").unwrap();
let mut auth_url = url.clone();
auth_url.set_password(Some("password")).unwrap();
let credentials = Credentials::from_url(&auth_url).unwrap();
assert_eq!(credentials.username(), None);
assert_eq!(credentials.password(), Some("password"));
}
// `https://:token@host` parses as an absent username plus a password, and
// counts as authenticated.
#[test]
fn from_url_empty_username_with_password() {
let url = Url::parse("https://:token@example.com/simple/first/").unwrap();
let credentials = Credentials::from_url(&url).unwrap();
assert_eq!(credentials.username(), None);
assert_eq!(credentials.password(), Some("token"));
assert!(
credentials.is_authenticated(),
"URL with empty username but password should be considered authenticated"
);
}
// A username without a password still yields credentials.
#[test]
fn from_url_no_password() {
let url = &Url::parse("https://example.com/simple/first/").unwrap();
let mut auth_url = url.clone();
auth_url.set_username("user").unwrap();
let credentials = Credentials::from_url(&auth_url).unwrap();
assert_eq!(credentials.username(), Some("user"));
assert_eq!(credentials.password(), None);
}
// URL credentials render to the expected `Basic` header and parse back to
// the same credentials.
#[test]
fn authenticated_request_from_url() {
let url = Url::parse("https://example.com/simple/first/").unwrap();
let mut auth_url = url.clone();
auth_url.set_username("user").unwrap();
auth_url.set_password(Some("password")).unwrap();
let credentials = Credentials::from_url(&auth_url).unwrap();
let mut request = Request::new(reqwest::Method::GET, url);
request = credentials.authenticate(request);
let mut header = request
.headers()
.get(reqwest::header::AUTHORIZATION)
.expect("Authorization header should be set")
.clone();
// Clear the sensitive flag so the snapshot shows the actual value.
header.set_sensitive(false);
assert_debug_snapshot!(header, @r#""Basic dXNlcjpwYXNzd29yZA==""#);
assert_eq!(Credentials::from_header_value(&header), Some(credentials));
}
// A username containing `@` round-trips through the URL and the header.
#[test]
fn authenticated_request_from_url_with_percent_encoded_user() {
let url = Url::parse("https://example.com/simple/first/").unwrap();
let mut auth_url = url.clone();
auth_url.set_username("user@domain").unwrap();
auth_url.set_password(Some("password")).unwrap();
let credentials = Credentials::from_url(&auth_url).unwrap();
let mut request = Request::new(reqwest::Method::GET, url);
request = credentials.authenticate(request);
let mut header = request
.headers()
.get(reqwest::header::AUTHORIZATION)
.expect("Authorization header should be set")
.clone();
// Clear the sensitive flag so the snapshot shows the actual value.
header.set_sensitive(false);
assert_debug_snapshot!(header, @r#""Basic dXNlckBkb21haW46cGFzc3dvcmQ=""#);
assert_eq!(Credentials::from_header_value(&header), Some(credentials));
}
// A password containing `=` round-trips through the URL and the header.
#[test]
fn authenticated_request_from_url_with_percent_encoded_password() {
let url = Url::parse("https://example.com/simple/first/").unwrap();
let mut auth_url = url.clone();
auth_url.set_username("user").unwrap();
auth_url.set_password(Some("password==")).unwrap();
let credentials = Credentials::from_url(&auth_url).unwrap();
let mut request = Request::new(reqwest::Method::GET, url);
request = credentials.authenticate(request);
let mut header = request
.headers()
.get(reqwest::header::AUTHORIZATION)
.expect("Authorization header should be set")
.clone();
// Clear the sensitive flag so the snapshot shows the actual value.
header.set_sensitive(false);
assert_debug_snapshot!(header, @r#""Basic dXNlcjpwYXNzd29yZD09""#);
assert_eq!(Credentials::from_header_value(&header), Some(credentials));
}
// Signing with a static Azure bearer token sets the `Authorization`,
// `x-ms-date`, and `x-ms-version` headers.
#[tokio::test]
async fn authenticated_request_with_azure_signer() {
let signer = reqsign::azure::default_signer().with_credential_provider(
reqsign::azure::StaticCredentialProvider::new_bearer_token("token"),
);
let authentication = Authentication::from(signer);
let request = Request::new(
reqwest::Method::GET,
Url::parse("https://account.blob.core.windows.net/container/blob.whl").unwrap(),
);
let request = authentication.authenticate(request).await.unwrap();
let authorization = request
.headers()
.get(reqwest::header::AUTHORIZATION)
.expect("Authorization header should be set");
assert_eq!(authorization.to_str().unwrap(), "Bearer token");
assert!(request.headers().contains_key("x-ms-date"));
assert_eq!(
request
.headers()
.get("x-ms-version")
.expect("x-ms-version header should be set")
.to_str()
.unwrap(),
AZURE_STORAGE_VERSION
);
}
// An AWS signer without any credential surfaces a `Sign` error.
#[tokio::test]
async fn authenticated_request_with_aws_signer_missing_credentials() {
let signer = reqsign::aws::default_signer("s3", "us-east-1")
.with_credential_provider(EmptyAwsCredentialProvider);
let authentication = Authentication::from(signer);
let request = Request::new(
reqwest::Method::GET,
Url::parse("https://s3.amazonaws.com/bucket/blob.whl").unwrap(),
);
let err = authentication.authenticate(request).await.unwrap_err();
insta::assert_snapshot!(
err.to_string(),
@"Failed to sign request with AWS credentials"
);
}
// An Azure signer without any credential surfaces a `Sign` error.
#[tokio::test]
async fn authenticated_request_with_azure_signer_missing_credentials() {
let signer =
reqsign::azure::default_signer().with_credential_provider(EmptyAzureCredentialProvider);
let authentication = Authentication::from(signer);
let request = Request::new(
reqwest::Method::GET,
Url::parse("https://account.blob.core.windows.net/container/blob.whl").unwrap(),
);
let err = authentication.authenticate(request).await.unwrap_err();
insta::assert_snapshot!(
err.to_string(),
@"Failed to sign request with Azure credentials"
);
}
// `Debug` output of `Basic` credentials shows the username but redacts the
// password.
#[test]
fn test_password_redaction() {
let credentials =
Credentials::basic(Some(String::from("user")), Some(String::from("password")));
insta::assert_compact_debug_snapshot!(credentials, @r#"Basic { username: Username(Some("user")), password: Some(****) }"#);
}
// `Debug` output of `Bearer` credentials redacts the token.
#[test]
fn test_bearer_token_redaction() {
let token = "super_secret_token";
let credentials = Credentials::bearer(token.into());
insta::assert_compact_debug_snapshot!(credentials, @"Bearer { token: **** }");
}
}