use async_trait::async_trait;
use papaya::HashMapRef;
use redis::{Client, Commands};
use serde_json;
use std::{error::Error, hash::RandomState};
use super::certificates::{Certificate, SerializableCertificate};
use super::store_trait::Store;
/// Time-to-live, in seconds, applied to challenge entries when they are
/// written to Redis.
const CHALLENGE_TTL_SECONDS: u64 = 300;
/// Certificate and challenge store backed by Redis, with in-process caches
/// placed in front of it.
pub struct RedisStore {
/// Pool of Redis connections used for every Redis read and write.
pool: r2d2::Pool<redis::Client>,
/// In-process cache of certificates, keyed by domain.
cache: papaya::HashMap<String, Certificate>,
/// In-process cache of challenges, keyed by domain; each value is a
/// `(token, proof)` pair.
challenge_cache: papaya::HashMap<String, (String, String)>,
}
impl RedisStore {
    /// Creates a store that talks to the Redis server at `redis_url`.
    ///
    /// A `redis::Client` is opened for the URL and wrapped in an `r2d2`
    /// connection pool; both in-process caches start out empty.
    ///
    /// # Errors
    ///
    /// Returns an error if the client cannot be created from `redis_url` or
    /// if the connection pool cannot be built.
    pub fn new(redis_url: &str) -> Result<Self, Box<dyn Error>> {
        let client = Client::open(redis_url)?;
        let pool = r2d2::Pool::builder().build(client)?;
        Ok(RedisStore {
            pool,
            cache: papaya::HashMap::new(),
            challenge_cache: papaya::HashMap::new(),
        })
    }

    /// Builds the Redis key under which the certificate for `domain` is stored.
    fn certificate_key(domain: &str) -> String {
        format!("proksi:cert:{domain}")
    }

    /// Builds the Redis key under which the challenge for `domain` is stored.
    fn challenge_key(domain: &str) -> String {
        format!("proksi:challenge:{domain}")
    }

    /// Reads the string stored at `key` in Redis.
    ///
    /// Returns `None` when no pooled connection is available, when the Redis
    /// command fails, or when the key does not exist. A failed connection
    /// checkout is treated as a miss rather than a panic, because callers
    /// only promise an `Option` and must keep working while Redis is down.
    fn fetch_string(&self, key: &str) -> Option<String> {
        let mut conn = self.pool.get().ok()?;
        let value: Option<String> = conn.get(key).ok()?;
        value
    }

    /// Loads the certificate for `domain` directly from Redis, bypassing the
    /// in-process cache.
    ///
    /// Returns `None` if the value cannot be fetched, is not valid JSON for a
    /// `SerializableCertificate`, or cannot be converted into a `Certificate`.
    fn load_from_redis(&self, domain: &str) -> Option<Certificate> {
        let data = self.fetch_string(&Self::certificate_key(domain))?;
        let serializable = serde_json::from_str::<SerializableCertificate>(&data).ok()?;
        Certificate::from_serializable(serializable).ok()
    }

    /// Loads the `(token, proof)` challenge for `domain` directly from Redis,
    /// bypassing the in-process cache.
    ///
    /// Returns `None` if the value cannot be fetched or is not valid JSON for
    /// a pair of strings.
    fn load_challenge_from_redis(&self, domain: &str) -> Option<(String, String)> {
        let data = self.fetch_string(&Self::challenge_key(domain))?;
        serde_json::from_str::<(String, String)>(&data).ok()
    }
}
// NOTE: every Redis call below goes through the synchronous `redis::Commands`
// API on a pooled connection, so these `async fn`s perform their Redis I/O
// inline rather than at an `.await` point.
#[async_trait]
impl Store for RedisStore {
    /// Returns the certificate for `domain`.
    ///
    /// The in-process cache is consulted first; on a miss the certificate is
    /// loaded from Redis and, if found, copied into the cache.
    async fn get_certificate(&self, domain: &str) -> Option<Certificate> {
        if let Some(cert) = self.cache.pin().get(domain) {
            return Some(cert.clone());
        }
        let cert = self.load_from_redis(domain)?;
        self.cache.pin().insert(domain.to_string(), cert.clone());
        Some(cert)
    }

    /// Stores `cert` for `domain` in Redis and then in the in-process cache.
    ///
    /// # Errors
    ///
    /// Returns an error if no pooled connection is available, if the
    /// certificate cannot be serialized, or if the Redis `SET` fails. The
    /// cache is only updated after Redis has accepted the write.
    async fn set_certificate(&self, domain: &str, cert: Certificate) -> Result<(), Box<dyn Error>> {
        let mut conn = self.pool.get()?;
        let key = Self::certificate_key(domain);
        let serializable_cert = cert.to_serializable()?;
        let cert_json = serde_json::to_string(&serializable_cert)?;
        conn.set::<String, String, String>(key, cert_json)?;
        self.cache.pin().insert(domain.to_string(), cert);
        Ok(())
    }

    /// Returns a pinned view of all known certificates, keyed by domain.
    ///
    /// Certificates that exist in Redis but not yet in the in-process cache
    /// are loaded into the cache first; entries already cached are left
    /// untouched. Redis entries that fail to load or decode are skipped. If
    /// no pooled connection is available or the key listing fails, the view
    /// contains only what is already cached.
    ///
    /// The returned view is of this store's own cache. No process-wide
    /// scratch map is involved, so concurrent callers (or multiple
    /// `RedisStore` instances) cannot clear or overwrite each other's result.
    async fn get_certificates(
        &self,
    ) -> HashMapRef<'_, String, Certificate, RandomState, seize::LocalGuard<'_>> {
        // Best effort: a failed connection checkout must not panic here,
        // since the signature has no way to report an error.
        if let Ok(mut conn) = self.pool.get() {
            // Derive both the match pattern and the prefix from the key
            // builder so they cannot drift from `certificate_key`.
            let prefix = Self::certificate_key("");
            let pattern = Self::certificate_key("*");
            if let Ok(keys) = conn.keys::<_, Vec<String>>(&pattern) {
                let cache = self.cache.pin();
                for key in keys {
                    let Some(domain) = key.strip_prefix(prefix.as_str()) else {
                        continue;
                    };
                    // Cached entries win; only fill in what is missing.
                    if cache.contains_key(domain) {
                        continue;
                    }
                    let Ok(cert_data) = conn.get::<_, String>(&key) else {
                        continue;
                    };
                    let Ok(serializable_cert) =
                        serde_json::from_str::<SerializableCertificate>(&cert_data)
                    else {
                        continue;
                    };
                    if let Ok(cert) = Certificate::from_serializable(serializable_cert) {
                        cache.insert(domain.to_string(), cert);
                    }
                }
            }
        }
        self.cache.pin()
    }

    /// Returns the `(token, proof)` challenge for `domain`.
    ///
    /// The in-process cache is consulted first; on a miss the challenge is
    /// loaded from Redis and, if found, copied into the cache.
    ///
    /// NOTE(review): entries in the in-process challenge cache are never
    /// removed here, while the Redis copy is written with a TTL of
    /// `CHALLENGE_TTL_SECONDS`; a cached challenge can therefore outlive its
    /// Redis entry. Confirm whether that is intended.
    async fn get_challenge(&self, domain: &str) -> Option<(String, String)> {
        if let Some(challenge) = self.challenge_cache.pin().get(domain) {
            return Some(challenge.clone());
        }
        let challenge = self.load_challenge_from_redis(domain)?;
        self.challenge_cache
            .pin()
            .insert(domain.to_string(), challenge.clone());
        Some(challenge)
    }

    /// Stores the `(token, proof)` challenge for `domain` in Redis with a TTL
    /// of `CHALLENGE_TTL_SECONDS`, then in the in-process cache.
    ///
    /// # Errors
    ///
    /// Returns an error if no pooled connection is available, if the
    /// challenge cannot be serialized, or if the Redis write fails. The cache
    /// is only updated after Redis has accepted the write.
    async fn set_challenge(
        &self,
        domain: &str,
        token: String,
        proof: String,
    ) -> Result<(), Box<dyn Error>> {
        let mut conn = self.pool.get()?;
        let key = Self::challenge_key(domain);
        // `token` and `proof` are owned and not used again, so move them
        // into the tuple instead of cloning.
        let challenge_tuple = (token, proof);
        let challenge_json = serde_json::to_string(&challenge_tuple)?;
        conn.set_ex::<String, String, String>(key, challenge_json, CHALLENGE_TTL_SECONDS)?;
        self.challenge_cache
            .pin()
            .insert(domain.to_string(), challenge_tuple);
        Ok(())
    }
}