use std::{any::Any, collections::HashMap, str, sync::Arc};
use axum::{
async_trait,
extract::FromRequestParts,
http::{
header::{self},
request::Parts,
StatusCode,
},
};
use sec::Secret;
use thiserror::Error;
use crate::{storage::ImageLocation, ImageDigest};
use super::{
www_authenticate::{self},
ContainerRegistry,
};
/// Credentials as presented by a client, before any [`AuthProvider`] has checked them.
#[derive(Debug)]
pub enum Unverified {
    /// Username and password taken from the `Authorization` header (as parsed by
    /// `www_authenticate::basic_auth_response`).
    UsernameAndPassword {
        username: String,
        password: Secret<String>,
    },
    /// Produced when the request has no `Authorization` header at all.
    NoCredentials,
}

impl Unverified {
    /// Returns `true` only for [`Unverified::NoCredentials`].
    #[inline(always)]
    pub fn is_no_credentials(&self) -> bool {
        match self {
            Unverified::UsernameAndPassword { .. } => false,
            Unverified::NoCredentials => true,
        }
    }
}
/// Pulls raw (unchecked) credentials out of the request's `Authorization` header.
///
/// A missing header yields [`Unverified::NoCredentials`]. A header that cannot be parsed
/// by `www_authenticate::basic_auth_response`, or whose username/password is not valid
/// UTF-8, is rejected with `400 Bad Request`.
#[async_trait]
impl<S> FromRequestParts<S> for Unverified {
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        // Decodes one credential component into an owned `String`, mapping invalid
        // UTF-8 to `400 Bad Request`.
        fn decode_utf8(raw: &[u8]) -> Result<String, StatusCode> {
            str::from_utf8(raw)
                .map(|text| text.to_owned())
                .map_err(|_| StatusCode::BAD_REQUEST)
        }

        let header_value = match parts.headers.get(header::AUTHORIZATION) {
            Some(value) => value,
            None => return Ok(Unverified::NoCredentials),
        };

        // Any input left over after the parsed credentials is ignored.
        let (_remainder, credentials) =
            www_authenticate::basic_auth_response(header_value.as_bytes())
                .map_err(|_| StatusCode::BAD_REQUEST)?;

        let username = decode_utf8(&credentials.username)?;
        let password = Secret::new(decode_utf8(&credentials.password)?);

        Ok(Unverified::UsernameAndPassword { username, password })
    }
}
/// Credentials that an [`AuthProvider`] has accepted.
///
/// The payload is type-erased: each provider stores whatever it needs (a username, `()`,
/// a wrapper enum, ...) and recovers it later with [`ValidCredentials::extract_ref`].
#[derive(Debug)]
pub struct ValidCredentials(pub Box<dyn Any + Send + Sync>);

impl ValidCredentials {
    /// Type-erases `inner` and wraps it as accepted credentials.
    #[inline(always)]
    pub fn new<T: Send + Sync + 'static>(inner: T) -> Self {
        let erased: Box<dyn Any + Send + Sync> = Box::new(inner);
        Self(erased)
    }

    /// Borrows the stored payload as a `T`.
    ///
    /// # Panics
    ///
    /// Panics if the payload is not a `T`, i.e. the credentials were produced by a
    /// different provider than the one now inspecting them.
    pub fn extract_ref<T: 'static>(&self) -> &T {
        let payload: Option<&T> = self.0.downcast_ref();
        payload.expect("could not downcast `ValidCredentials` into expected type - was auth provider called with the wrong set of credentials?")
    }
}
/// Extracts credentials that the registry's configured auth provider has accepted.
///
/// Malformed `Authorization` headers are rejected with `400 Bad Request` (propagated from
/// the [`Unverified`] extractor); credentials the provider refuses are rejected with
/// `401 Unauthorized`.
#[async_trait]
impl FromRequestParts<Arc<ContainerRegistry>> for ValidCredentials {
    type Rejection = StatusCode;

    #[inline(always)]
    async fn from_request_parts(
        parts: &mut Parts,
        state: &Arc<ContainerRegistry>,
    ) -> Result<Self, Self::Rejection> {
        let unverified = Unverified::from_request_parts(parts, state).await?;

        state
            .auth_provider
            .check_credentials(&unverified)
            .await
            .ok_or(StatusCode::UNAUTHORIZED)
    }
}
/// Access level granted to a set of credentials on an image or blob.
///
/// The discriminants are chosen so that `ReadWrite` equals `ReadOnly | WriteOnly`
/// (`0b100 | 0b010`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Permissions {
    /// Neither reading nor writing is allowed.
    NoAccess = 0b000,
    /// Writing is allowed, reading is not.
    WriteOnly = 0b010,
    /// Reading is allowed, writing is not.
    ReadOnly = 0b100,
    /// Both reading and writing are allowed.
    ReadWrite = 0b110,
}
impl Permissions {
#[inline(always)]
#[must_use = "should not check read permissions and discard the result"]
pub fn has_read_permission(self) -> bool {
match self {
Permissions::NoAccess | Permissions::WriteOnly => false,
Permissions::ReadOnly | Permissions::ReadWrite => true,
}
}
#[inline(always)]
#[must_use = "should not check write permissions and discard the result"]
pub fn has_write_permission(self) -> bool {
match self {
Permissions::NoAccess | Permissions::ReadOnly => false,
Permissions::WriteOnly | Permissions::ReadWrite => true,
}
}
#[inline(always)]
pub fn require_read(self) -> Result<(), MissingPermission> {
if !self.has_read_permission() {
Err(MissingPermission)
} else {
Ok(())
}
}
#[inline(always)]
pub fn require_write(self) -> Result<(), MissingPermission> {
if !self.has_write_permission() {
Err(MissingPermission)
} else {
Ok(())
}
}
}
#[derive(Debug, Error)]
#[error("not permitted")]
pub struct MissingPermission;
/// Pluggable authentication and authorization backend.
///
/// [`AuthProvider::check_credentials`] turns client-supplied [`Unverified`] credentials
/// into [`ValidCredentials`] (or rejects them). Those credentials are expected to be
/// handed back to the *same* provider when asking for permissions, since each provider
/// stores its own payload type inside them.
#[async_trait]
pub trait AuthProvider: Send + Sync {
    /// Authenticates a request. `None` means the credentials were rejected.
    async fn check_credentials(&self, unverified: &Unverified) -> Option<ValidCredentials>;

    /// Access level `creds` holds on the image at `image`.
    async fn image_permissions(
        &self,
        creds: &ValidCredentials,
        image: &ImageLocation,
    ) -> Permissions;

    /// Access level `creds` holds on the blob identified by `blob`.
    async fn blob_permissions(
        &self,
        creds: &ValidCredentials,
        blob: &ImageDigest,
    ) -> Permissions;
}
/// Wraps another provider and additionally admits requests that present no credentials,
/// granting those a fixed access level. Requests that do present credentials are
/// authenticated and authorized by the wrapped provider.
#[derive(Debug)]
pub struct Anonymous<A> {
    /// Access level granted to requests without credentials.
    anon_permissions: Permissions,
    /// Provider consulted for requests that present credentials.
    inner: A,
}

impl<A> Anonymous<A> {
    /// Grants `anon_permissions` to credential-less requests and defers all others to
    /// `inner`.
    pub fn new(anon_permissions: Permissions, inner: A) -> Self {
        Anonymous {
            inner,
            anon_permissions,
        }
    }
}
/// Payload stored inside the [`ValidCredentials`] issued by [`Anonymous`].
#[derive(Debug)]
enum AnonCreds {
    /// The request presented no credentials (`Unverified::NoCredentials`).
    Anonymous,
    /// The wrapped provider accepted the request. This holds that provider's own
    /// credentials so they can be handed back to it unchanged.
    Valid(ValidCredentials),
}

#[async_trait]
impl<A> AuthProvider for Anonymous<A>
where
    A: AuthProvider,
{
    /// Issues anonymous credentials for credential-less requests. Otherwise it asks the
    /// wrapped provider and, on success, wraps its credentials in [`AnonCreds::Valid`].
    async fn check_credentials(&self, unverified: &Unverified) -> Option<ValidCredentials> {
        let creds = match unverified {
            Unverified::NoCredentials => AnonCreds::Anonymous,
            // The inner credentials must be wrapped: the permission methods below always
            // downcast to `AnonCreds`, so returning them bare would make every
            // authenticated request panic there.
            Unverified::UsernameAndPassword { .. } => {
                AnonCreds::Valid(self.inner.check_credentials(unverified).await?)
            }
        };
        Some(ValidCredentials::new(creds))
    }

    /// Anonymous requests get `anon_permissions`. Authenticated requests are answered by
    /// the wrapped provider, using its own (unwrapped) credentials.
    async fn image_permissions(
        &self,
        creds: &ValidCredentials,
        image: &ImageLocation,
    ) -> Permissions {
        match creds.extract_ref::<AnonCreds>() {
            AnonCreds::Anonymous => self.anon_permissions,
            AnonCreds::Valid(inner_creds) => {
                self.inner.image_permissions(inner_creds, image).await
            }
        }
    }

    /// Anonymous requests get `anon_permissions`. Authenticated requests are answered by
    /// the wrapped provider, using its own (unwrapped) credentials.
    async fn blob_permissions(&self, creds: &ValidCredentials, blob: &ImageDigest) -> Permissions {
        match creds.extract_ref::<AnonCreds>() {
            AnonCreds::Anonymous => self.anon_permissions,
            AnonCreds::Valid(inner_creds) => self.inner.blob_permissions(inner_creds, blob).await,
        }
    }
}
/// Fixed-level provider: accepts *any* username/password without checking it and grants
/// `*self` on every image and blob. Requests without credentials are rejected.
#[async_trait]
impl AuthProvider for Permissions {
    #[inline(always)]
    async fn check_credentials(&self, unverified: &Unverified) -> Option<ValidCredentials> {
        match unverified {
            Unverified::NoCredentials => None,
            Unverified::UsernameAndPassword { .. } => Some(ValidCredentials::new(())),
        }
    }

    /// Always the configured level, regardless of credentials or image.
    #[inline(always)]
    async fn image_permissions(
        &self,
        _creds: &ValidCredentials,
        _image: &ImageLocation,
    ) -> Permissions {
        let granted: Permissions = *self;
        granted
    }

    /// Always the configured level, regardless of credentials or blob.
    #[inline(always)]
    async fn blob_permissions(
        &self,
        _creds: &ValidCredentials,
        _blob: &ImageDigest,
    ) -> Permissions {
        let granted: Permissions = *self;
        granted
    }
}
/// Username → password table. A request is accepted when its username is present and its
/// password matches the stored one (compared via `constant_time_eq`). Accepted users get
/// `ReadWrite` on everything, and their credentials carry the username as a `String`.
#[async_trait]
impl AuthProvider for HashMap<String, Secret<String>> {
    async fn check_credentials(&self, unverified: &Unverified) -> Option<ValidCredentials> {
        let (username, password) = match unverified {
            Unverified::UsernameAndPassword { username, password } => (username, password),
            Unverified::NoCredentials => return None,
        };

        // Unknown usernames are rejected here.
        let stored_password = self.get(username)?;

        let password_ok = constant_time_eq::constant_time_eq(
            stored_password.reveal().as_bytes(),
            password.reveal().as_bytes(),
        );

        if password_ok {
            Some(ValidCredentials::new(username.clone()))
        } else {
            None
        }
    }

    /// Every authenticated user may read and write every image.
    #[inline(always)]
    async fn image_permissions(
        &self,
        _creds: &ValidCredentials,
        _image: &ImageLocation,
    ) -> Permissions {
        Permissions::ReadWrite
    }

    /// Every authenticated user may read and write every blob.
    #[inline(always)]
    async fn blob_permissions(
        &self,
        _creds: &ValidCredentials,
        _blob: &ImageDigest,
    ) -> Permissions {
        Permissions::ReadWrite
    }
}
/// Forwards every [`AuthProvider`] call to the boxed provider.
///
/// The permission methods must delegate too. Returning a fixed value here would silently
/// override the inner provider's authorization policy for anything stored in a `Box`.
#[async_trait]
impl<T> AuthProvider for Box<T>
where
    T: AuthProvider,
{
    #[inline(always)]
    async fn check_credentials(&self, unverified: &Unverified) -> Option<ValidCredentials> {
        <T as AuthProvider>::check_credentials(self, unverified).await
    }

    #[inline(always)]
    async fn image_permissions(
        &self,
        creds: &ValidCredentials,
        image: &ImageLocation,
    ) -> Permissions {
        <T as AuthProvider>::image_permissions(self, creds, image).await
    }

    #[inline(always)]
    async fn blob_permissions(
        &self,
        creds: &ValidCredentials,
        blob: &ImageDigest,
    ) -> Permissions {
        <T as AuthProvider>::blob_permissions(self, creds, blob).await
    }
}
/// Forwards every [`AuthProvider`] call to the shared provider.
///
/// The permission methods must delegate too. Returning a fixed value here would silently
/// override the inner provider's authorization policy for anything stored in an `Arc`.
#[async_trait]
impl<T> AuthProvider for Arc<T>
where
    T: AuthProvider,
{
    #[inline(always)]
    async fn check_credentials(&self, unverified: &Unverified) -> Option<ValidCredentials> {
        <T as AuthProvider>::check_credentials(self, unverified).await
    }

    #[inline(always)]
    async fn image_permissions(
        &self,
        creds: &ValidCredentials,
        image: &ImageLocation,
    ) -> Permissions {
        <T as AuthProvider>::image_permissions(self, creds, image).await
    }

    #[inline(always)]
    async fn blob_permissions(
        &self,
        creds: &ValidCredentials,
        blob: &ImageDigest,
    ) -> Permissions {
        <T as AuthProvider>::blob_permissions(self, creds, blob).await
    }
}
/// Shared-password provider: any username is accepted as long as the password equals
/// this secret (compared via `constant_time_eq`). Accepted requests get `ReadWrite` on
/// everything.
#[async_trait]
impl AuthProvider for Secret<String> {
    #[inline(always)]
    async fn check_credentials(&self, unverified: &Unverified) -> Option<ValidCredentials> {
        if let Unverified::UsernameAndPassword { password, .. } = unverified {
            let password_ok = constant_time_eq::constant_time_eq(
                password.reveal().as_bytes(),
                self.reveal().as_bytes(),
            );
            if password_ok {
                return Some(ValidCredentials::new(()));
            }
        }
        None
    }

    /// Every authenticated request may read and write every image.
    #[inline(always)]
    async fn image_permissions(
        &self,
        _creds: &ValidCredentials,
        _image: &ImageLocation,
    ) -> Permissions {
        Permissions::ReadWrite
    }

    /// Every authenticated request may read and write every blob.
    #[inline(always)]
    async fn blob_permissions(
        &self,
        _creds: &ValidCredentials,
        _blob: &ImageDigest,
    ) -> Permissions {
        Permissions::ReadWrite
    }
}