use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use super::CacheError;
use crate::auth0::cache::{crypto, Cache};
use crate::auth0::token::Token;
#[derive(Debug, thiserror::Error)]
pub enum RedisCacheError {
#[error(transparent)]
Serde(#[from] serde_json::Error),
#[error("redis error: {0}")]
Redis(#[from] redis::RedisError),
#[error("couldn't decrypt stored token: {0}")]
Crypto(#[from] crypto::CryptoError),
}
impl From<RedisCacheError> for super::CacheError {
fn from(val: RedisCacheError) -> Self {
CacheError(Box::new(val))
}
}
/// Token [`Cache`] backed by Redis; values are stored encrypted.
///
/// `Debug` is implemented by hand (instead of derived) so that formatting the cache,
/// e.g. in a log line, never prints the encryption key or the client's connection info.
#[derive(Clone)]
pub struct RedisCache {
    /// Redis client the cache operations connect through.
    client: redis::Client,
    /// Secret handed to `crypto::encrypt` / `crypto::decrypt`; must never be logged.
    encryption_key: String,
    /// Namespace used as the first segment of every token key.
    key_prefix: String,
}

impl std::fmt::Debug for RedisCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The encryption key is secret. The client's connection info is built from the
        // connection URL, which may embed credentials, so neither field is printed.
        f.debug_struct("RedisCache")
            .field("key_prefix", &self.key_prefix)
            .finish_non_exhaustive()
    }
}
impl RedisCache {
    /// Builds a cache backed by the Redis server at `redis_connection_url`.
    ///
    /// The URL is parsed and one connection is opened (then dropped immediately), so a
    /// misconfigured or unreachable server is reported here rather than on the first
    /// cache access.
    ///
    /// * `redis_key_prefix` - namespace used as the first segment of token keys.
    /// * `token_encryption_key` - secret passed to `crypto::encrypt` / `crypto::decrypt`.
    ///
    /// # Errors
    ///
    /// Returns [`RedisCacheError::Redis`] if the URL cannot be parsed or the probe
    /// connection fails.
    pub async fn new(
        redis_connection_url: String,
        redis_key_prefix: String,
        token_encryption_key: String,
    ) -> Result<Self, RedisCacheError> {
        let client: redis::Client = redis::Client::open(redis_connection_url)?;
        // Fail fast: verify connectivity now instead of on the first get/put.
        let _ = client.get_multiplexed_async_connection().await?;
        Ok(RedisCache {
            client,
            encryption_key: token_encryption_key,
            key_prefix: redis_key_prefix,
        })
    }

    /// Reads `key` from Redis and decrypts the stored bytes into a `T`.
    ///
    /// Returns `Ok(None)` when the key is absent, e.g. because it was never written or
    /// its TTL has elapsed.
    ///
    /// NOTE(review): each call opens a fresh multiplexed connection. Reusing one, or a
    /// reconnecting manager, would avoid a connection handshake per lookup.
    async fn get<T>(&self, key: String) -> Result<Option<T>, RedisCacheError>
    where
        for<'de> T: Deserialize<'de>,
    {
        self.client
            .get_multiplexed_async_connection()
            .await?
            .get::<_, Option<Vec<u8>>>(key)
            .await?
            .map(|value| crypto::decrypt(self.encryption_key.as_str(), value.as_slice()))
            .transpose()
            .map_err(Into::into)
    }

    /// Encrypts `v` and stores it under `key` with an expiry of `lifetime_in_seconds`.
    ///
    /// A zero lifetime means the value is already stale. Redis also rejects `SETEX` with
    /// a non-positive expiry, so in that case the write is skipped and `Ok(())` is
    /// returned instead of surfacing a Redis error.
    async fn put<T: Serialize>(
        &self,
        key: String,
        lifetime_in_seconds: u64,
        v: T,
    ) -> Result<(), RedisCacheError> {
        if lifetime_in_seconds == 0 {
            return Ok(());
        }
        // Encrypt first, so an encryption failure never opens a connection.
        let encrypted_value: Vec<u8> = crypto::encrypt(&v, self.encryption_key.as_str())?;
        let mut connection = self.client.get_multiplexed_async_connection().await?;
        let _: () = connection
            .set_ex(key, encrypted_value, lifetime_in_seconds)
            .await?;
        Ok(())
    }
}
#[async_trait::async_trait]
impl Cache for RedisCache {
    /// Looks up the cached token for the (`client_id`, `audience`) pair.
    ///
    /// Yields `Ok(None)` on a cache miss. Redis and decryption failures are converted
    /// into the generic [`CacheError`].
    async fn get_token(&self, client_id: &str, audience: &str) -> Result<Option<Token>, CacheError> {
        let redis_key = token_key(&self.key_prefix, client_id, audience);
        let cached: Option<Token> = self.get(redis_key).await?;
        Ok(cached)
    }

    /// Stores `value_ref` for the (`client_id`, `audience`) pair, using the token's own
    /// `lifetime_in_seconds()` as the expiry passed to `put`.
    async fn put_token(&self, client_id: &str, audience: &str, value_ref: &Token) -> Result<(), CacheError> {
        let redis_key = token_key(&self.key_prefix, client_id, audience);
        let ttl_seconds = value_ref.lifetime_in_seconds();
        self.put(redis_key, ttl_seconds, value_ref).await?;
        Ok(())
    }
}
/// Version segment embedded in every token key. Changing it changes every key, so
/// entries written under the previous value are no longer found.
const TOKEN_VERSION: &str = "2";

/// Builds the Redis key `<key_prefix>:<TOKEN_PREFIX>:<caller>:<TOKEN_VERSION>:<audience>`.
///
/// NOTE(review): segments are joined with ':' without escaping. A `caller` that itself
/// contains ':' could therefore produce the same key as a different (caller, audience)
/// pair. This is presumably fine for typical client ids; confirm.
fn token_key(key_prefix: &str, caller: &str, audience: &str) -> String {
    format!("{key_prefix}:{}:{caller}:{TOKEN_VERSION}:{audience}", super::TOKEN_PREFIX)
}