use anyhow::Result;
use async_trait::async_trait;
use std::collections::HashMap;
/// Free-form key/value properties passed along when credentials are requested.
///
/// Keys and values are plain strings and are not interpreted by this type;
/// any meaning is assigned by whoever reads them.
///
/// NOTE(review): the derived `Debug` prints every property verbatim (unlike
/// `Credentials`, which redacts). Do not store secrets here if contexts may be
/// logged — confirm against callers.
#[derive(Debug, Clone, Default)]
pub struct CredentialContext {
    /// Property values keyed by name.
    pub properties: HashMap<String, String>,
}

impl CredentialContext {
    /// Creates a context with no properties (same as `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns this context with `key` set to `value`.
    ///
    /// An existing value for the same key is replaced. Because `self` is
    /// consumed and returned for chaining, ignoring the result loses the
    /// property — hence `#[must_use]`.
    #[must_use]
    pub fn with_property(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.properties.insert(key.into(), value.into());
        self
    }

    /// Returns the value stored under `key`, or `None` if the key is absent.
    #[must_use]
    pub fn get(&self, key: &str) -> Option<&str> {
        self.properties.get(key).map(String::as_str)
    }
}
/// A source of [`Credentials`], used as the trait object `Box<dyn IdentityProvider>`.
///
/// The `Send + Sync` supertraits let a provider be shared across threads.
#[async_trait]
pub trait IdentityProvider: Send + Sync {
    /// Returns a heap-allocated copy of this provider.
    ///
    /// `Clone` itself is not object-safe, so this method is what makes
    /// `Box<dyn IdentityProvider>` cloneable (see the `Clone` impl below).
    fn clone_box(&self) -> Box<dyn IdentityProvider>;

    /// Produces credentials, with `context` supplying optional request properties.
    ///
    /// # Errors
    ///
    /// Returns an error when the provider cannot supply credentials; the
    /// specific failure conditions are implementation-defined.
    async fn get_credentials(&self, context: &CredentialContext) -> Result<Credentials>;
}

impl Clone for Box<dyn IdentityProvider> {
    /// Clones the boxed provider through [`IdentityProvider::clone_box`].
    fn clone(&self) -> Self {
        // Deref explicitly down to the trait object so dispatch always reaches the
        // concrete provider's `clone_box`, even if `Box<dyn IdentityProvider>`
        // were ever given an `IdentityProvider` impl of its own.
        (**self).clone_box()
    }
}
/// Credentials returned by an identity provider.
///
/// `Debug` is written by hand (below) so secret material never reaches logs.
///
/// NOTE(review): the derived `PartialEq` compares secrets with ordinary
/// `String` equality, which is not constant-time; avoid using it to check a
/// secret supplied by an untrusted party.
#[derive(Clone, PartialEq, Eq)]
pub enum Credentials {
    /// A username together with its password.
    UsernamePassword { username: String, password: String },
    /// A username together with an opaque token.
    Token { username: String, token: String },
    /// A certificate and private key (the field names indicate PEM text),
    /// optionally accompanied by a username.
    Certificate {
        cert_pem: String,
        key_pem: String,
        username: Option<String>,
    },
}

impl std::fmt::Debug for Credentials {
    /// Prints usernames as-is and shows the placeholder `"[REDACTED]"` in place
    /// of every password, token, certificate and key.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[REDACTED]";

        match self {
            Self::UsernamePassword { username, .. } => {
                let mut out = f.debug_struct("UsernamePassword");
                out.field("username", username);
                out.field("password", &REDACTED);
                out.finish()
            }
            Self::Token { username, .. } => {
                let mut out = f.debug_struct("Token");
                out.field("username", username);
                out.field("token", &REDACTED);
                out.finish()
            }
            Self::Certificate { username, .. } => {
                let mut out = f.debug_struct("Certificate");
                out.field("cert_pem", &REDACTED);
                out.field("key_pem", &REDACTED);
                out.field("username", username);
                out.finish()
            }
        }
    }
}
impl Credentials {
    /// Converts username/password or token credentials into `(username, secret)`.
    ///
    /// For [`Credentials::UsernamePassword`] the secret is the password; for
    /// [`Credentials::Token`] it is the token.
    ///
    /// # Errors
    ///
    /// Returns `Err(self)`, unchanged, for [`Credentials::Certificate`], so the
    /// caller keeps ownership and can try another conversion.
    pub fn try_into_auth_pair(self) -> std::result::Result<(String, String), Self> {
        match self {
            Credentials::UsernamePassword { username, password } => Ok((username, password)),
            Credentials::Token { username, token } => Ok((username, token)),
            // Named explicitly instead of a catch-all so that adding a new variant
            // is a compile error here rather than a silent `Err`.
            cert @ Credentials::Certificate { .. } => Err(cert),
        }
    }

    /// Converts certificate credentials into `(cert_pem, key_pem, username)`.
    ///
    /// # Errors
    ///
    /// Returns `Err(self)`, unchanged, for any non-certificate variant.
    pub fn try_into_certificate(
        self,
    ) -> std::result::Result<(String, String, Option<String>), Self> {
        match self {
            Credentials::Certificate {
                cert_pem,
                key_pem,
                username,
            } => Ok((cert_pem, key_pem, username)),
            // Exhaustive on purpose: a new variant must be classified explicitly.
            other @ (Credentials::UsernamePassword { .. } | Credentials::Token { .. }) => {
                Err(other)
            }
        }
    }

    /// Panicking form of [`Credentials::try_into_auth_pair`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is [`Credentials::Certificate`].
    #[deprecated(note = "Use try_into_auth_pair() which returns Result instead of panicking")]
    pub(crate) fn into_auth_pair(self) -> (String, String) {
        self.try_into_auth_pair()
            .unwrap_or_else(|_| panic!("Certificate credentials cannot be converted to an auth pair. Use try_into_auth_pair() or try_into_certificate() instead."))
    }

    /// Panicking form of [`Credentials::try_into_certificate`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Credentials::Certificate`].
    #[deprecated(note = "Use try_into_certificate() which returns Result instead of panicking")]
    pub(crate) fn into_certificate(self) -> (String, String, Option<String>) {
        self.try_into_certificate()
            .unwrap_or_else(|_| panic!("Not certificate credentials. Use try_into_certificate() or try_into_auth_pair() instead."))
    }

    /// Returns `true` if these are [`Credentials::Certificate`] credentials.
    #[must_use]
    pub fn is_certificate(&self) -> bool {
        matches!(self, Credentials::Certificate { .. })
    }
}
// The password-based provider lives in the `password` submodule; it is
// re-exported here so callers can name `PasswordIdentityProvider` directly
// without referring to the submodule path.
mod password;
pub use password::PasswordIdentityProvider;
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_password_provider() {
        let provider = PasswordIdentityProvider::new("testuser", "testpass");
        let credentials = provider
            .get_credentials(&CredentialContext::default())
            .await
            .unwrap();
        match credentials {
            Credentials::UsernamePassword { username, password } => {
                assert_eq!(username, "testuser");
                assert_eq!(password, "testpass");
            }
            _ => panic!("Expected UsernamePassword credentials"),
        }
    }

    #[tokio::test]
    async fn test_provider_clone() {
        let provider: Box<dyn IdentityProvider> =
            Box::new(PasswordIdentityProvider::new("user", "pass"));
        let cloned = provider.clone();
        let credentials = cloned
            .get_credentials(&CredentialContext::default())
            .await
            .unwrap();
        assert!(matches!(credentials, Credentials::UsernamePassword { .. }));
    }

    #[test]
    fn test_credential_context_properties() {
        let ctx = CredentialContext::new()
            .with_property("region", "eu")
            .with_property("region", "us");
        // A later value for the same key replaces the earlier one.
        assert_eq!(ctx.get("region"), Some("us"));
        assert_eq!(ctx.get("missing"), None);
        assert_eq!(ctx.properties.len(), 1);
    }

    #[test]
    fn test_debug_redacts_secrets() {
        let password = Credentials::UsernamePassword {
            username: "alice".into(),
            password: "hunter2".into(),
        };
        let rendered = format!("{:?}", password);
        assert!(rendered.contains("alice"));
        assert!(!rendered.contains("hunter2"));

        let token = Credentials::Token {
            username: "bob".into(),
            token: "tok-secret".into(),
        };
        let rendered = format!("{:?}", token);
        assert!(rendered.contains("bob"));
        assert!(!rendered.contains("tok-secret"));

        let cert = Credentials::Certificate {
            cert_pem: "CERT-BODY".into(),
            key_pem: "PRIVATE-KEY-BODY".into(),
            username: Some("carol".into()),
        };
        // Check the pretty form too; it goes through the same redacting impl.
        let rendered = format!("{:#?}", cert);
        assert!(rendered.contains("carol"));
        assert!(!rendered.contains("CERT-BODY"));
        assert!(!rendered.contains("PRIVATE-KEY-BODY"));
    }

    #[test]
    fn test_try_into_auth_pair_username_password() {
        let creds = Credentials::UsernamePassword {
            username: "user".into(),
            password: "pass".into(),
        };
        let (u, p) = creds.try_into_auth_pair().unwrap();
        assert_eq!(u, "user");
        assert_eq!(p, "pass");
    }

    #[test]
    fn test_try_into_auth_pair_token() {
        let creds = Credentials::Token {
            username: "user".into(),
            token: "tok".into(),
        };
        let (u, t) = creds.try_into_auth_pair().unwrap();
        assert_eq!(u, "user");
        assert_eq!(t, "tok");
    }

    #[test]
    fn test_try_into_auth_pair_rejects_certificate() {
        let creds = Credentials::Certificate {
            cert_pem: "cert".into(),
            key_pem: "key".into(),
            username: None,
        };
        let result = creds.try_into_auth_pair();
        assert!(result.is_err());
        let returned = result.unwrap_err();
        assert!(returned.is_certificate());
    }

    #[test]
    fn test_try_into_certificate_success() {
        let creds = Credentials::Certificate {
            cert_pem: "cert".into(),
            key_pem: "key".into(),
            username: Some("user".into()),
        };
        let (c, k, u) = creds.try_into_certificate().unwrap();
        assert_eq!(c, "cert");
        assert_eq!(k, "key");
        assert_eq!(u, Some("user".into()));
    }

    #[test]
    fn test_try_into_certificate_rejects_password() {
        let creds = Credentials::UsernamePassword {
            username: "user".into(),
            password: "pass".into(),
        };
        let result = creds.clone().try_into_certificate();
        // The rejected value must come back intact so callers can fall back.
        assert_eq!(result.unwrap_err(), creds);
    }

    #[test]
    fn test_try_into_certificate_rejects_token() {
        let creds = Credentials::Token {
            username: "user".into(),
            token: "tok".into(),
        };
        let result = creds.clone().try_into_certificate();
        assert_eq!(result.unwrap_err(), creds);
    }
}