use crate::resource::{RequestContext, TenantContext};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Marker trait for the type-level authentication state of a [`Credential`].
///
/// The two implementors below are zero-sized unit structs. The
/// `Send + Sync + 'static` bounds let credentials tagged with them cross task
/// and thread boundaries.
pub trait AuthState: Send + Sync + 'static {}

/// State tag for a credential that has not been validated.
#[derive(Debug, Clone, Copy)]
pub struct Unauthenticated;

/// State tag for a credential that has been marked as authenticated.
#[derive(Debug, Clone, Copy)]
pub struct Authenticated;

impl AuthState for Unauthenticated {}
impl AuthState for Authenticated {}
/// A credential value tagged at the type level with its authentication state.
///
/// `S` is one of the [`AuthState`] markers ([`Unauthenticated`] or
/// [`Authenticated`]). It is carried only through `PhantomData`, so the tag
/// adds no runtime data.
///
/// `Debug` is implemented by hand and never prints the secret. A credential
/// that ends up in a `{:?}` log line therefore does not leak.
#[derive(Clone)]
pub struct Credential<S: AuthState> {
    pub(crate) value: String,
    pub(crate) _phantom: PhantomData<S>,
}

impl<S: AuthState> std::fmt::Debug for Credential<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never expose the secret; the field is shown only as a placeholder.
        f.debug_struct("Credential")
            .field("value", &"<redacted>")
            .finish()
    }
}

impl Credential<Unauthenticated> {
    /// Wraps a raw, not-yet-validated credential value.
    pub fn new(value: impl Into<String>) -> Self {
        Self {
            value: value.into(),
            _phantom: PhantomData,
        }
    }

    /// Returns the raw secret so a validator can check it.
    pub fn raw_value(&self) -> &str {
        &self.value
    }
}

impl Credential<Authenticated> {
    /// Builds a credential that is already tagged as authenticated.
    // This crate-internal constructor has no caller in this module; the allow
    // keeps the dead-code lint quiet if nothing else in the crate uses it.
    #[allow(dead_code)]
    pub(crate) fn authenticated(value: String) -> Self {
        Self {
            value,
            _phantom: PhantomData,
        }
    }

    /// Returns the secret of an authenticated credential.
    pub fn authenticated_value(&self) -> &str {
        &self.value
    }
}
/// Record of a successful credential validation.
///
/// Constructed only inside the crate (`new` is `pub(crate)`) and exposed
/// read-only through the accessor methods.
#[derive(Debug, Clone)]
pub struct AuthenticationWitness {
    /// Tenant the credential resolved to.
    pub(crate) tenant_context: TenantContext,
    /// Identifier for the credential. In this module the validator passes the
    /// hex SHA-256 digest, not the raw secret.
    pub(crate) credential_hash: String,
    /// UTC time at which the witness was created.
    pub(crate) validated_at: chrono::DateTime<chrono::Utc>,
}

impl AuthenticationWitness {
    /// Creates a witness stamped with the current UTC time.
    pub(crate) fn new(tenant_context: TenantContext, credential_hash: String) -> Self {
        let validated_at = chrono::Utc::now();
        AuthenticationWitness {
            tenant_context,
            credential_hash,
            validated_at,
        }
    }

    /// Tenant this witness authenticates.
    pub fn tenant_context(&self) -> &TenantContext {
        &self.tenant_context
    }

    /// When the witness was created.
    pub fn validated_at(&self) -> chrono::DateTime<chrono::Utc> {
        self.validated_at
    }

    /// Credential identifier recorded at creation.
    pub fn credential_hash(&self) -> &str {
        self.credential_hash.as_str()
    }
}
#[derive(Debug, Clone)]
pub struct TenantAuthority {
witness: AuthenticationWitness,
}
impl TenantAuthority {
pub fn from_witness(witness: AuthenticationWitness) -> Self {
Self { witness }
}
pub fn witness(&self) -> &AuthenticationWitness {
&self.witness
}
pub fn tenant_id(&self) -> &str {
&self.witness.tenant_context.tenant_id
}
pub fn client_id(&self) -> &str {
&self.witness.tenant_context.client_id
}
}
/// A single-use wrapper around an unauthenticated credential.
///
/// [`LinearCredential::consume`] takes `self` by value, so after the credential
/// is handed off the compiler rejects any further use of the wrapper.
/// Linearity comes entirely from move semantics; no runtime flag is needed.
///
/// `Debug` is implemented by hand so the secret never appears in logs.
pub struct LinearCredential {
    inner: Credential<Unauthenticated>,
}

impl LinearCredential {
    /// Wraps a raw credential value.
    pub fn new(value: impl Into<String>) -> Self {
        Self {
            inner: Credential::new(value),
        }
    }

    /// Consumes the wrapper and yields the underlying unauthenticated credential.
    pub fn consume(self) -> Credential<Unauthenticated> {
        self.inner
    }

    /// Reports whether the credential has been consumed.
    ///
    /// This always returns `false`. `consume` moves the wrapper, so a consumed
    /// `LinearCredential` can never be observed. The method is kept for API
    /// compatibility.
    pub fn is_consumed(&self) -> bool {
        false
    }
}

impl std::fmt::Debug for LinearCredential {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Do not delegate to the inner credential: never render the secret.
        f.debug_struct("LinearCredential")
            .field("inner", &"<redacted>")
            .finish()
    }
}
/// Token recording that a credential was consumed. It carries no information.
///
/// The private unit field stops code outside this module from building one
/// with a struct literal. Instances come from the crate-internal `new`.
#[derive(Debug)]
pub struct ConsumedCredential {
    _private: (),
}

impl ConsumedCredential {
    /// Mints a consumption token.
    pub(crate) fn new() -> Self {
        ConsumedCredential { _private: () }
    }
}
/// Outcome of an authentication attempt.
///
/// Both variants carry a [`ConsumedCredential`], so success and failure alike
/// record that the credential was used up.
#[derive(Debug)]
pub enum AuthenticationResult {
    /// The credential was accepted; `witness` records the result.
    Success {
        witness: AuthenticationWitness,
        consumed: ConsumedCredential,
    },
    /// The credential was rejected.
    Failed { consumed: ConsumedCredential },
}
#[derive(Debug, Clone)]
pub struct AuthenticatedRequestContext {
inner: RequestContext,
authority: TenantAuthority,
}
impl AuthenticatedRequestContext {
pub fn from_witness(witness: AuthenticationWitness) -> Self {
let tenant_context = witness.tenant_context().clone();
let authority = TenantAuthority::from_witness(witness);
let inner = RequestContext::with_tenant_generated_id(tenant_context);
Self { inner, authority }
}
pub fn with_request_id(witness: AuthenticationWitness, request_id: String) -> Self {
let tenant_context = witness.tenant_context().clone();
let authority = TenantAuthority::from_witness(witness);
let inner = RequestContext::with_tenant(request_id, tenant_context);
Self { inner, authority }
}
pub fn request_context(&self) -> &RequestContext {
&self.inner
}
pub fn authority(&self) -> &TenantAuthority {
&self.authority
}
pub fn tenant_id(&self) -> &str {
self.authority.tenant_id()
}
pub fn client_id(&self) -> &str {
self.authority.client_id()
}
pub fn request_id(&self) -> &str {
&self.inner.request_id
}
}
/// Long-lived authenticated state from which per-request contexts are minted.
#[derive(Debug, Clone)]
pub struct AuthenticatedContext {
    authority: TenantAuthority,
}

impl AuthenticatedContext {
    /// Wraps `witness` in a [`TenantAuthority`].
    pub fn from_witness(witness: AuthenticationWitness) -> Self {
        let authority = TenantAuthority::from_witness(witness);
        Self { authority }
    }

    /// Builds a new [`AuthenticatedRequestContext`].
    ///
    /// Each call clones the stored witness and passes it to
    /// `AuthenticatedRequestContext::from_witness`.
    pub fn to_request_context(&self) -> AuthenticatedRequestContext {
        AuthenticatedRequestContext::from_witness(self.authority().witness().clone())
    }

    /// The authority held by this context.
    pub fn authority(&self) -> &TenantAuthority {
        &self.authority
    }
}
/// Validates opaque credentials against an in-memory tenant registry.
///
/// Credentials are never stored in plaintext. The registry is keyed by the
/// lowercase hex SHA-256 digest of each credential, which is the same value
/// reported by [`AuthenticationWitness::credential_hash`]. Cloning the
/// validator is cheap, and every clone shares one registry through the `Arc`.
#[derive(Clone)]
pub struct AuthenticationValidator {
    /// Hex SHA-256 of a credential -> tenant it authenticates as.
    credentials: Arc<RwLock<HashMap<String, TenantContext>>>,
}

impl AuthenticationValidator {
    /// Creates a validator with an empty registry.
    pub fn new() -> Self {
        Self {
            credentials: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Hex-encoded SHA-256 digest of `raw`; used as the registry key.
    fn hash_credential(raw: &str) -> String {
        let mut hasher = Sha256::new();
        hasher.update(raw.as_bytes());
        format!("{:x}", hasher.finalize())
    }

    /// Registers `credential` as authenticating `tenant_context`.
    ///
    /// Only the credential's digest is retained. Registering the same
    /// credential again replaces its tenant.
    pub async fn register_credential(&self, credential: &str, tenant_context: TenantContext) {
        // Hash before taking the write lock to keep the critical section short.
        let key = Self::hash_credential(credential);
        self.credentials.write().await.insert(key, tenant_context);
    }

    /// Consumes `credential` and, if it is registered, returns a witness for
    /// its tenant.
    ///
    /// # Errors
    ///
    /// Returns [`AuthenticationError::InvalidCredential`] when no registered
    /// credential matches.
    pub async fn authenticate(
        &self,
        credential: LinearCredential,
    ) -> Result<AuthenticationWitness, AuthenticationError> {
        let raw_cred = credential.consume();
        // Marks the linear hand-off; the token itself is not returned.
        let _consumed_proof = ConsumedCredential::new();
        let credential_hash = Self::hash_credential(raw_cred.raw_value());
        // Scope the read guard so it is released before building the witness.
        let tenant_context = {
            let creds = self.credentials.read().await;
            creds.get(&credential_hash).cloned()
        };
        match tenant_context {
            Some(tenant_context) => Ok(AuthenticationWitness::new(tenant_context, credential_hash)),
            None => Err(AuthenticationError::InvalidCredential),
        }
    }
}

impl std::fmt::Debug for AuthenticationValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit the registry: it maps credential digests to tenants.
        f.debug_struct("AuthenticationValidator").finish_non_exhaustive()
    }
}

impl Default for AuthenticationValidator {
    fn default() -> Self {
        Self::new()
    }
}
/// Errors reported by the authentication layer.
#[derive(Debug, thiserror::Error)]
pub enum AuthenticationError {
    /// No registered credential matched the one presented.
    #[error("Invalid credential provided")]
    InvalidCredential,

    /// The credential was recognized but is no longer valid. Not currently
    /// returned by `AuthenticationValidator::authenticate`.
    #[error("Credential has been revoked")]
    CredentialRevoked,

    /// The authentication backend could not be reached. Not currently
    /// returned by `AuthenticationValidator::authenticate`.
    #[error("Authentication system unavailable")]
    SystemUnavailable,
}
/// A resource provider whose operations require an authenticated request.
pub trait AuthenticatedProvider {
    /// Error type returned by this provider's operations.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Lists resources of `resource_type` for the request described by
    /// `context`.
    ///
    /// The returned future must be `Send`. Scoping to the tenant in `context`
    /// is presumably the implementor's responsibility; confirm per provider.
    fn list_resources_authenticated(
        &self,
        resource_type: &str,
        context: &AuthenticatedRequestContext,
    ) -> impl std::future::Future<Output = Result<Vec<crate::resource::Resource>, Self::Error>> + Send;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::resource::TenantContext;
    use sha2::{Digest, Sha256};

    /// Hex SHA-256 of `value`, computed independently of the validator so the
    /// tests can pin the witness's `credential_hash`.
    fn sha256_hex(value: &str) -> String {
        let mut hasher = Sha256::new();
        hasher.update(value.as_bytes());
        format!("{:x}", hasher.finalize())
    }

    #[test]
    fn test_linear_credential_consumption() {
        let cred = LinearCredential::new("test-key");
        assert!(!cred.is_consumed());
        let raw = cred.consume();
        assert_eq!(raw.raw_value(), "test-key");
    }

    #[tokio::test]
    async fn test_authentication_flow() {
        let validator = AuthenticationValidator::new();
        let tenant_ctx = TenantContext::new("test-tenant".to_string(), "test-client".to_string());
        validator.register_credential("valid-key", tenant_ctx).await;
        let cred = LinearCredential::new("valid-key");
        let witness = validator.authenticate(cred).await.unwrap();
        // The witness carries a digest of the credential, never the raw value.
        assert_eq!(witness.credential_hash(), sha256_hex("valid-key"));
        assert_ne!(witness.credential_hash(), "valid-key");
        let auth_context = AuthenticatedRequestContext::from_witness(witness);
        assert_eq!(auth_context.tenant_id(), "test-tenant");
        assert_eq!(auth_context.client_id(), "test-client");
    }

    #[tokio::test]
    async fn test_invalid_authentication() {
        let validator = AuthenticationValidator::new();
        let cred = LinearCredential::new("invalid-key");
        let result = validator.authenticate(cred).await;
        assert!(matches!(result, Err(AuthenticationError::InvalidCredential)));
    }

    #[tokio::test]
    async fn test_near_miss_credential_rejected() {
        let validator = AuthenticationValidator::new();
        let tenant_ctx = TenantContext::new("t".to_string(), "c".to_string());
        validator.register_credential("valid-key", tenant_ctx).await;
        let result = validator.authenticate(LinearCredential::new("valid-key ")).await;
        assert!(matches!(result, Err(AuthenticationError::InvalidCredential)));
    }

    #[tokio::test]
    async fn test_clones_share_registry() {
        let validator = AuthenticationValidator::new();
        let clone = validator.clone();
        let tenant_ctx = TenantContext::new("shared".to_string(), "client".to_string());
        clone.register_credential("shared-key", tenant_ctx).await;
        let witness = validator
            .authenticate(LinearCredential::new("shared-key"))
            .await
            .unwrap();
        let authority = TenantAuthority::from_witness(witness);
        assert_eq!(authority.tenant_id(), "shared");
        assert_eq!(authority.client_id(), "client");
    }

    #[test]
    fn test_type_level_authentication_states() {
        let unauth = Credential::<Unauthenticated>::new("test");
        assert_eq!(unauth.raw_value(), "test");
    }

    #[test]
    fn test_witness_types() {
        let tenant_ctx = TenantContext::new("test".to_string(), "client".to_string());
        let witness = AuthenticationWitness::new(tenant_ctx, "hash".to_string());
        assert_eq!(witness.credential_hash(), "hash");
        let authority = TenantAuthority::from_witness(witness);
        assert_eq!(authority.tenant_id(), "test");
        assert_eq!(authority.client_id(), "client");
    }
}