use crate::config::vars::{CS_CLIENT_ID, CS_CLIENT_KEY};
use std::future::Future;
use thiserror::Error;
use uuid::Uuid;
use super::ClientKey;
/// Errors produced by a [`KeyProvider`].
///
/// The variant matters for control flow: [`FallbackKeyProvider`] treats
/// `NotConfigured` (and only `NotConfigured`) as permission to consult its
/// fallback provider; every other variant is propagated unchanged.
#[derive(Debug, Error)]
pub enum KeyProviderError {
    /// The provider has no key source to read from (e.g. a required
    /// environment variable is unset, or an `Option` provider is `None`).
    #[error("Client key not configured: {0}")]
    NotConfigured(String),
    /// A key source was present but its value was unusable (e.g. the client id
    /// is not a UUID or the key material failed to parse).
    #[error("Invalid client key: {0}")]
    InvalidKey(String),
    /// Loading the key failed. Not constructed by any provider in this module;
    /// presumably intended for providers backed by external sources — confirm.
    #[error("Failed to load client key: {0}")]
    LoadError(String),
}
/// A source of the [`ClientKey`] used by this client.
///
/// Implementations signal "no key source configured here" with
/// [`KeyProviderError::NotConfigured`], which allows providers to be chained
/// with [`FallbackKeyProvider`]. Any other error means the source exists but
/// could not yield a usable key.
pub trait KeyProvider: Send + Sync + 'static {
    /// Produces the client key, or an error explaining why it could not.
    ///
    /// The returned future is required to be `Send`, so it can be moved across
    /// threads by the executor awaiting it.
    fn client_key(&self) -> impl Future<Output = Result<ClientKey, KeyProviderError>> + Send;
}
/// Loads the client key from the [`CS_CLIENT_ID`] and [`CS_CLIENT_KEY`]
/// environment variables.
///
/// The variables are read on every call to [`KeyProvider::client_key`], so the
/// current process environment is always used.
pub struct EnvKeyProvider;

impl EnvKeyProvider {
    /// Parses raw environment values into a [`ClientKey`].
    ///
    /// # Errors
    ///
    /// Returns [`KeyProviderError::InvalidKey`] if `client_id` is not a valid
    /// UUID, or if `client_key` is rejected by [`ClientKey::from_hex_v1`].
    fn parse(client_id: &str, client_key: &str) -> Result<ClientKey, KeyProviderError> {
        let uuid = Uuid::parse_str(client_id)
            .map_err(|e| KeyProviderError::InvalidKey(format!("invalid {CS_CLIENT_ID}: {e}")))?;
        ClientKey::from_hex_v1(uuid, client_key)
            .map_err(|e| KeyProviderError::InvalidKey(format!("invalid {CS_CLIENT_KEY}: {e}")))
    }

    /// Reads the required environment variable `name`.
    ///
    /// # Errors
    ///
    /// * [`KeyProviderError::NotConfigured`] if the variable is unset, so that a
    ///   [`FallbackKeyProvider`] may move on to its fallback.
    /// * [`KeyProviderError::InvalidKey`] if the variable is set but is not
    ///   valid Unicode. It was deliberately configured, so reporting it as
    ///   "not set" (and silently falling back) would hide the misconfiguration.
    fn read_var(name: &str) -> Result<String, KeyProviderError> {
        std::env::var(name).map_err(|err| match err {
            std::env::VarError::NotPresent => {
                KeyProviderError::NotConfigured(format!("{name} environment variable not set"))
            }
            std::env::VarError::NotUnicode(_) => KeyProviderError::InvalidKey(format!(
                "{name} environment variable is not valid unicode"
            )),
        })
    }
}

impl KeyProvider for EnvKeyProvider {
    /// Reads both variables (id first, then key) and parses them.
    ///
    /// # Errors
    ///
    /// `NotConfigured` if either variable is unset; `InvalidKey` if either is
    /// non-Unicode or fails to parse.
    async fn client_key(&self) -> Result<ClientKey, KeyProviderError> {
        let client_id = Self::read_var(CS_CLIENT_ID)?;
        let client_key = Self::read_var(CS_CLIENT_KEY)?;
        tracing::debug!("loading client key from environment variables");
        Self::parse(&client_id, &client_key)
    }
}
pub struct StaticKeyProvider(ClientKey);
impl StaticKeyProvider {
pub fn new(key: ClientKey) -> Self {
Self(key)
}
}
impl KeyProvider for StaticKeyProvider {
async fn client_key(&self) -> Result<ClientKey, KeyProviderError> {
Ok(self.0.clone())
}
}
/// An optional provider: `Some` delegates to the wrapped provider, while `None`
/// reports [`KeyProviderError::NotConfigured`], which lets `None` in the primary
/// slot of a [`FallbackKeyProvider`] defer to the fallback.
impl<T> KeyProvider for Option<T>
where
    T: KeyProvider,
{
    async fn client_key(&self) -> Result<ClientKey, KeyProviderError> {
        let Some(inner) = self.as_ref() else {
            return Err(KeyProviderError::NotConfigured(String::from(
                "no explicit key provided",
            )));
        };
        inner.client_key().await
    }
}
/// Tries `primary` first and consults `fallback` only when `primary` reports
/// [`KeyProviderError::NotConfigured`].
///
/// Any other error from `primary` (e.g. [`KeyProviderError::InvalidKey`]) is
/// returned immediately: a source that is configured but broken should be
/// surfaced, not silently replaced by the fallback.
pub struct FallbackKeyProvider<P, F> {
    primary: P,
    fallback: F,
}

impl<P: KeyProvider, F: KeyProvider> FallbackKeyProvider<P, F> {
    /// Creates a provider that prefers `primary` and falls back to `fallback`.
    pub fn new(primary: P, fallback: F) -> Self {
        Self { primary, fallback }
    }
}

impl<P: KeyProvider, F: KeyProvider> KeyProvider for FallbackKeyProvider<P, F> {
    /// Returns the primary provider's key, or the fallback's if the primary is
    /// not configured.
    ///
    /// # Errors
    ///
    /// * Any non-`NotConfigured` error from the primary provider, unchanged.
    /// * Any error from the fallback provider. If the fallback is also
    ///   `NotConfigured`, the returned `NotConfigured` message joins the
    ///   primary's and the fallback's reasons so neither is lost.
    async fn client_key(&self) -> Result<ClientKey, KeyProviderError> {
        let primary_reason = match self.primary.client_key().await {
            Ok(key) => {
                tracing::debug!("using primary key provider");
                return Ok(key);
            }
            Err(KeyProviderError::NotConfigured(reason)) => reason,
            // Configured-but-broken must not be masked by the fallback.
            Err(e) => return Err(e),
        };

        tracing::debug!("primary key provider not configured, trying fallback");
        match self.fallback.client_key().await {
            Err(KeyProviderError::NotConfigured(fallback_reason)) => Err(
                KeyProviderError::NotConfigured(format!("{primary_reason}; {fallback_reason}")),
            ),
            other => other,
        }
    }
}
#[cfg(test)]
// Tests unwrap freely: a panic is the intended failure mode for a broken expectation.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use recipher::keyset::{EncryptionKeySet, ProxyKeySet};

    /// Builds a v1 client key with a random id and freshly generated key material.
    fn random_client_key() -> ClientKey {
        let ek_a = EncryptionKeySet::generate().unwrap();
        let ek_b = EncryptionKeySet::generate().unwrap();
        let keyset = ProxyKeySet::generate(&ek_a, &ek_b);
        ClientKey::new_v1(Uuid::new_v4(), keyset)
    }

    /// Always reports `NotConfigured` — the only error that lets a fallback run.
    struct NotConfiguredProvider;
    impl KeyProvider for NotConfiguredProvider {
        async fn client_key(&self) -> Result<ClientKey, KeyProviderError> {
            Err(KeyProviderError::NotConfigured("not configured".into()))
        }
    }

    /// Always reports `InvalidKey`.
    struct InvalidKeyProvider;
    impl KeyProvider for InvalidKeyProvider {
        async fn client_key(&self) -> Result<ClientKey, KeyProviderError> {
            Err(KeyProviderError::InvalidKey("bad key".into()))
        }
    }

    /// Always reports `LoadError`.
    struct LoadErrorProvider;
    impl KeyProvider for LoadErrorProvider {
        async fn client_key(&self) -> Result<ClientKey, KeyProviderError> {
            Err(KeyProviderError::LoadError("load failed".into()))
        }
    }

    mod static_provider {
        use super::*;

        #[tokio::test]
        async fn returns_the_wrapped_key() {
            let key = random_client_key();
            let expected_id = key.key_id;
            let provider = StaticKeyProvider::new(key);
            let result = provider.client_key().await.unwrap();
            assert_eq!(
                result.key_id, expected_id,
                "should return the same key_id that was provided"
            );
        }

        #[tokio::test]
        async fn returns_the_same_key_on_every_call() {
            let key = random_client_key();
            let expected_id = key.key_id;
            let provider = StaticKeyProvider::new(key);
            let first = provider.client_key().await.unwrap();
            let second = provider.client_key().await.unwrap();
            assert_eq!(first.key_id, expected_id, "first call should return the wrapped key");
            assert_eq!(second.key_id, expected_id, "second call should return the wrapped key");
        }
    }

    mod env_provider_parse {
        use super::*;

        #[test]
        fn returns_client_key_for_valid_inputs() {
            let key = random_client_key();
            let uuid = key.key_id;
            let hex = key.to_hex_v1().unwrap();
            let result = EnvKeyProvider::parse(&uuid.to_string(), &hex).unwrap();
            assert_eq!(
                result.key_id, uuid,
                "parsed key should have the same client_id"
            );
        }

        mod given_invalid_uuid {
            use super::*;

            #[test]
            fn returns_invalid_key_error() {
                let err = EnvKeyProvider::parse("not-a-uuid", "deadbeef").unwrap_err();
                assert!(
                    matches!(err, KeyProviderError::InvalidKey(_)),
                    "expected InvalidKey for bad UUID, got: {err:?}"
                );
            }
        }

        mod given_valid_uuid_but_wrong_key_length {
            use super::*;

            #[test]
            fn returns_invalid_key_error() {
                let uuid = Uuid::new_v4();
                let err = EnvKeyProvider::parse(&uuid.to_string(), "deadbeef").unwrap_err();
                assert!(
                    matches!(err, KeyProviderError::InvalidKey(_)),
                    "expected InvalidKey for truncated key material, got: {err:?}"
                );
            }
        }

        mod given_invalid_hex_characters {
            use super::*;

            #[test]
            fn returns_invalid_key_error() {
                let uuid = Uuid::new_v4();
                let err =
                    EnvKeyProvider::parse(&uuid.to_string(), "not-valid-hex!!").unwrap_err();
                assert!(
                    matches!(err, KeyProviderError::InvalidKey(_)),
                    "expected InvalidKey for non-hex characters, got: {err:?}"
                );
            }
        }
    }

    mod fallback_provider {
        use super::*;

        mod given_primary_succeeds {
            use super::*;

            #[tokio::test]
            async fn returns_primary_key() {
                let primary_key = random_client_key();
                let primary_id = primary_key.key_id;
                let fallback_key = random_client_key();
                let provider = FallbackKeyProvider::new(
                    StaticKeyProvider::new(primary_key),
                    StaticKeyProvider::new(fallback_key),
                );
                let result = provider.client_key().await.unwrap();
                assert_eq!(
                    result.key_id, primary_id,
                    "should return the primary provider's key"
                );
            }
        }

        mod given_primary_not_configured {
            use super::*;

            #[tokio::test]
            async fn returns_fallback_key() {
                let fallback_key = random_client_key();
                let fallback_id = fallback_key.key_id;
                let provider = FallbackKeyProvider::new(
                    NotConfiguredProvider,
                    StaticKeyProvider::new(fallback_key),
                );
                let result = provider.client_key().await.unwrap();
                assert_eq!(
                    result.key_id, fallback_id,
                    "should fall through to the secondary provider"
                );
            }

            #[tokio::test]
            async fn propagates_fallback_invalid_key() {
                let provider = FallbackKeyProvider::new(NotConfiguredProvider, InvalidKeyProvider);
                let err = provider.client_key().await.unwrap_err();
                assert!(
                    matches!(err, KeyProviderError::InvalidKey(_)),
                    "should propagate the fallback's InvalidKey, got: {err:?}"
                );
            }
        }

        mod given_both_not_configured {
            use super::*;

            #[tokio::test]
            async fn returns_not_configured() {
                let provider =
                    FallbackKeyProvider::new(NotConfiguredProvider, NotConfiguredProvider);
                let err = provider.client_key().await.unwrap_err();
                assert!(
                    matches!(err, KeyProviderError::NotConfigured(_)),
                    "should report NotConfigured when neither provider has a key, got: {err:?}"
                );
            }
        }

        mod given_primary_returns_invalid_key {
            use super::*;

            #[tokio::test]
            async fn does_not_fall_through() {
                let fallback_key = random_client_key();
                let provider = FallbackKeyProvider::new(
                    InvalidKeyProvider,
                    StaticKeyProvider::new(fallback_key),
                );
                let err = provider.client_key().await.unwrap_err();
                assert!(
                    matches!(err, KeyProviderError::InvalidKey(_)),
                    "should propagate InvalidKey without trying fallback, got: {err:?}"
                );
            }
        }

        mod given_primary_returns_load_error {
            use super::*;

            #[tokio::test]
            async fn does_not_fall_through() {
                let fallback_key = random_client_key();
                let provider = FallbackKeyProvider::new(
                    LoadErrorProvider,
                    StaticKeyProvider::new(fallback_key),
                );
                let err = provider.client_key().await.unwrap_err();
                assert!(
                    matches!(err, KeyProviderError::LoadError(_)),
                    "should propagate LoadError without trying fallback, got: {err:?}"
                );
            }
        }
    }

    mod option_provider {
        use super::*;

        #[tokio::test]
        async fn some_delegates_to_inner() {
            let key = random_client_key();
            let expected_id = key.key_id;
            let provider: Option<StaticKeyProvider> = Some(StaticKeyProvider::new(key));
            let result = provider.client_key().await.unwrap();
            assert_eq!(
                result.key_id, expected_id,
                "Some(provider) should delegate to the inner provider"
            );
        }

        #[tokio::test]
        async fn none_returns_not_configured() {
            let provider: Option<StaticKeyProvider> = None;
            let err = provider.client_key().await.unwrap_err();
            assert!(
                matches!(err, KeyProviderError::NotConfigured(_)),
                "None should return NotConfigured, got: {err:?}"
            );
        }

        #[tokio::test]
        async fn none_triggers_fallback() {
            let fallback_key = random_client_key();
            let fallback_id = fallback_key.key_id;
            let provider = FallbackKeyProvider::new(
                Option::<StaticKeyProvider>::None,
                StaticKeyProvider::new(fallback_key),
            );
            let result = provider.client_key().await.unwrap();
            assert_eq!(
                result.key_id, fallback_id,
                "None primary should trigger fallback"
            );
        }

        #[tokio::test]
        async fn some_prevents_fallback() {
            let primary_key = random_client_key();
            let primary_id = primary_key.key_id;
            let fallback_key = random_client_key();
            let provider = FallbackKeyProvider::new(
                Some(StaticKeyProvider::new(primary_key)),
                StaticKeyProvider::new(fallback_key),
            );
            let result = provider.client_key().await.unwrap();
            assert_eq!(
                result.key_id, primary_id,
                "Some primary should prevent fallback"
            );
        }
    }
}