use crate::config::validation::DEFAULT_RETRY_INTERVAL_MS;
use crate::error::{CacheError, Result};
use crate::utils::validate_cache_key as utils_validate_cache_key;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{debug, error, info, instrument, warn};
/// Returns `true` when `identifier` is a conservative, unquoted SQL identifier.
///
/// The accepted shape is: one ASCII letter or `_`, followed by any number of
/// ASCII letters, ASCII digits, or `_`. The empty string, a leading digit,
/// whitespace, punctuation, quotes and all non-ASCII characters are rejected,
/// so an accepted name can be spliced into SQL text without quoting.
/// Reserved words (e.g. `select`) are *not* rejected.
// NOTE(review): this function is called by `SqlDbLoader` in this file, so the
// allow below looks redundant; kept to avoid changing lint behaviour.
#[allow(dead_code)]
pub fn validate_sql_identifier(identifier: &str) -> bool {
    let mut rest = identifier.chars();
    match rest.next() {
        // The first character may not be a digit.
        Some(head) if head == '_' || head.is_ascii_alphabetic() => {
            rest.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
        }
        // Empty input or a disallowed first character.
        _ => false,
    }
}
/// Boolean form of the crate-wide cache key check.
///
/// Delegates to `crate::utils::validate_cache_key` and reports only whether it
/// succeeded; the detailed error it produces is discarded.
pub fn validate_cache_key(key: &str) -> bool {
    let verdict = utils_validate_cache_key(key);
    verdict.is_ok()
}
/// Escapes `value` for use inside a single-quoted SQL string literal.
///
/// Single quotes are doubled (`'` -> `''`). Backslash, NUL, double quote,
/// newline, carriage return and tab are written as backslash escape sequences.
/// Every other character is copied unchanged.
///
/// NOTE(review): the backslash sequences are only interpreted by dialects that
/// honour backslash escapes (e.g. MySQL's default mode). Under standard-conforming
/// strings they would be stored literally. Callers in this file pre-validate
/// keys, so those characters should not reach this function in practice.
fn escape_sql_string(value: &str) -> String {
    let mut out = String::with_capacity(value.len() * 2);
    for ch in value.chars() {
        let replacement = match ch {
            '\'' => "''",
            '\\' => "\\\\",
            '\0' => "\\0",
            '"' => "\\\"",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            other => {
                // Ordinary character: copy through and move on.
                out.push(other);
                continue;
            }
        };
        out.push_str(replacement);
    }
    out
}
/// Asynchronous source of cache values, consulted by [`DbFallbackManager`].
///
/// Implementors must be `Send + Sync + Debug` so they can be shared behind an
/// `Arc<dyn DbLoader>` and printed for diagnostics.
#[async_trait]
pub trait DbLoader: Send + Sync + std::fmt::Debug {
    /// Loads the value stored under `key`.
    ///
    /// `Ok(None)` means no value exists for `key`; `DbFallbackManager` treats it
    /// as a definitive miss and does not retry.
    async fn load(&self, key: &str) -> Result<Option<Vec<u8>>>;
    /// Loads several keys at once, returning `(key, value)` pairs.
    ///
    /// NOTE(review): presumably keys without a stored value are simply absent
    /// from the result; confirm per implementation.
    async fn load_batch(&self, keys: Vec<String>) -> Result<Vec<(String, Vec<u8>)>>;
    /// Health flag. `DbFallbackManager` skips the loader entirely when this
    /// returns `false`.
    fn is_healthy(&self) -> bool;
}
/// Database fallback for cache reads.
///
/// Wraps a [`DbLoader`] with an on/off switch, a per-call timeout, and a retry
/// budget for single-key loads.
pub struct DbFallbackManager {
    /// Backend queried for values.
    loader: Arc<dyn DbLoader>,
    /// When `false`, the fallback entry points return empty results without
    /// touching `loader`.
    enabled: bool,
    /// Timeout, in milliseconds, applied to each individual loader call.
    timeout_ms: u64,
    /// Extra attempts after the first for single-key loads; batch loads are
    /// attempted once.
    max_retries: u32,
}
impl std::fmt::Debug for DbFallbackManager {
    /// Prints the manager's settings plus the loader's current health flag.
    ///
    /// The loader itself is not printed; it is probed via `is_healthy()` each
    /// time the manager is formatted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("DbFallbackManager");
        builder.field("enabled", &self.enabled);
        builder.field("timeout_ms", &self.timeout_ms);
        builder.field("max_retries", &self.max_retries);
        builder.field("loader_healthy", &self.loader.is_healthy());
        builder.finish()
    }
}
impl DbFallbackManager {
    /// Creates a manager around `loader`.
    ///
    /// * `enabled` — when `false`, both fallback entry points return an empty
    ///   result without calling the loader.
    /// * `timeout_ms` — time budget, in milliseconds, for each individual loader call.
    /// * `max_retries` — extra attempts after the first failed single-key load;
    ///   batch loads are attempted once.
    pub fn new(
        loader: Arc<dyn DbLoader>,
        enabled: bool,
        timeout_ms: u64,
        max_retries: u32,
    ) -> Self {
        Self {
            loader,
            enabled,
            timeout_ms,
            max_retries,
        }
    }

    /// Loads `key` from the database, retrying failed attempts with exponential backoff.
    ///
    /// Returns `Ok(None)` without querying when fallback is disabled or the loader
    /// reports itself unhealthy. Also returns `Ok(None)` when the loader finds no
    /// value. An attempt that exceeds `timeout_ms` counts as a failure and is
    /// retried, so a slow database is never mistaken for a missing key.
    ///
    /// # Errors
    /// Once all `max_retries + 1` attempts have failed, returns the error from the
    /// final attempt (`CacheError::Timeout` if that attempt timed out).
    #[instrument(skip(self), level = "info")]
    pub async fn fallback_load(&self, key: &str) -> Result<Option<Vec<u8>>> {
        if !self.enabled {
            debug!("Database fallback is disabled");
            return Ok(None);
        }
        if !self.loader.is_healthy() {
            error!("Database loader is not healthy, skipping fallback");
            return Ok(None);
        }
        info!("Attempting database fallback for key: {}", key);
        let mut last_error = None;
        for attempt in 0..=self.max_retries {
            if attempt > 0 {
                debug!("Retry attempt {} for key: {}", attempt, key);
            }
            match self.try_load_with_timeout(key).await {
                Ok(Some(data)) => {
                    info!("Successfully loaded data from database for key: {}", key);
                    return Ok(Some(data));
                }
                Ok(None) => {
                    // A definitive miss: retrying would not change the answer.
                    debug!("No data found in database for key: {}", key);
                    return Ok(None);
                }
                Err(e) => {
                    error!("Failed to load data from database for key {}: {}", key, e);
                    last_error = Some(e);
                    // No sleep after the final attempt: nothing follows it.
                    if attempt < self.max_retries {
                        let backoff_ms = Self::retry_backoff_ms(attempt);
                        tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
                    }
                }
            }
        }
        error!("All retry attempts failed for key: {}", key);
        // The loop body runs at least once, so `last_error` is always set here;
        // the fallback value only guards against that invariant changing.
        Err(last_error.unwrap_or_else(|| {
            CacheError::DatabaseError("All fallback attempts failed".to_string())
        }))
    }

    /// Loads several keys in one loader call bounded by `timeout_ms`, with no retries.
    ///
    /// Returns an empty vector without querying when fallback is disabled or the
    /// loader is unhealthy.
    ///
    /// # Errors
    /// Returns the loader's error, or `CacheError::Timeout` if the call exceeded
    /// `timeout_ms`.
    #[instrument(skip(self), level = "info")]
    pub async fn fallback_load_batch(&self, keys: Vec<String>) -> Result<Vec<(String, Vec<u8>)>> {
        if !self.enabled {
            debug!("Database fallback is disabled");
            return Ok(Vec::new());
        }
        if !self.loader.is_healthy() {
            error!("Database loader is not healthy, skipping batch fallback");
            return Ok(Vec::new());
        }
        info!("Attempting batch database fallback for {} keys", keys.len());
        // `keys` is owned and not used afterwards, so it is moved into the loader
        // instead of cloning the whole vector.
        match tokio::time::timeout(
            tokio::time::Duration::from_millis(self.timeout_ms),
            self.loader.load_batch(keys),
        )
        .await
        {
            Ok(Ok(results)) => {
                info!("Successfully loaded {} items from database", results.len());
                Ok(results)
            }
            Ok(Err(e)) => {
                error!("Failed to batch load from database: {}", e);
                Err(e)
            }
            Err(_) => {
                error!(
                    "Batch database fallback timed out after {}ms",
                    self.timeout_ms
                );
                Err(CacheError::Timeout(format!(
                    "Batch fallback timeout after {}ms",
                    self.timeout_ms
                )))
            }
        }
    }

    /// Runs a single `loader.load(key)` call bounded by `timeout_ms`.
    ///
    /// A timeout is reported as `CacheError::Timeout`, matching the batch path,
    /// rather than `Ok(None)`. Reporting it as "not found" would skip the retry
    /// loop in `fallback_load` and hide database slowness from callers.
    async fn try_load_with_timeout(&self, key: &str) -> Result<Option<Vec<u8>>> {
        match tokio::time::timeout(
            tokio::time::Duration::from_millis(self.timeout_ms),
            self.loader.load(key),
        )
        .await
        {
            Ok(result) => result,
            Err(_) => {
                debug!(
                    "Database load timed out after {}ms for key: {}",
                    self.timeout_ms, key
                );
                Err(CacheError::Timeout(format!(
                    "Fallback load timeout after {}ms for key: {}",
                    self.timeout_ms, key
                )))
            }
        }
    }

    /// Delay before the retry that follows failed attempt `attempt` (0-based).
    ///
    /// Computes `DEFAULT_RETRY_INTERVAL_MS * 2^attempt`, saturating at `u64::MAX`
    /// instead of overflowing (which would panic in debug builds) for large
    /// retry counts.
    fn retry_backoff_ms(attempt: u32) -> u64 {
        DEFAULT_RETRY_INTERVAL_MS.saturating_mul(2_u64.saturating_pow(attempt))
    }

    /// Returns whether database fallback is switched on.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
}
/// [`DbLoader`] that reads values from a single SQL table through a
/// [`DbConnectionPool`].
///
/// Build it with [`SqlDbLoader::new`], which rejects table/column names that are
/// not plain SQL identifiers; those names are later interpolated into query text.
#[derive(Debug)]
pub struct SqlDbLoader {
    /// Pool that executes the generated SQL text.
    pool: Arc<dyn DbConnectionPool>,
    /// Table holding the cache rows.
    table_name: String,
    /// Column compared against the requested key(s).
    key_column: String,
    /// Column selected as the value.
    value_column: String,
}
impl SqlDbLoader {
    /// Builds a loader that matches rows on `key_column` in `table_name` and
    /// returns `value_column`.
    ///
    /// All three names are spliced verbatim into SQL text, so each must pass
    /// [`validate_sql_identifier`].
    ///
    /// # Errors
    /// Returns `CacheError::InvalidInput` for the first invalid name. Names are
    /// checked in the order table, key column, value column.
    pub fn new(
        pool: Arc<dyn DbConnectionPool>,
        table_name: String,
        key_column: String,
        value_column: String,
    ) -> Result<Self> {
        let rejection = if !validate_sql_identifier(&table_name) {
            Some(format!(
                "Invalid table name: {}. Table name must be a valid SQL identifier.",
                table_name
            ))
        } else if !validate_sql_identifier(&key_column) {
            Some(format!(
                "Invalid key column name: {}. Column name must be a valid SQL identifier.",
                key_column
            ))
        } else if !validate_sql_identifier(&value_column) {
            Some(format!(
                "Invalid value column name: {}. Column name must be a valid SQL identifier.",
                value_column
            ))
        } else {
            None
        };

        match rejection {
            Some(message) => Err(CacheError::InvalidInput(message)),
            None => Ok(Self {
                pool,
                table_name,
                key_column,
                value_column,
            }),
        }
    }
}
#[async_trait]
impl DbLoader for SqlDbLoader {
    /// Fetches one value with
    /// `SELECT <value_column> FROM <table_name> WHERE <key_column> = '<key>'`.
    ///
    /// The pool only accepts raw SQL text (no bind parameters), so `key` must pass
    /// [`validate_cache_key`] and is additionally escaped before interpolation.
    ///
    /// # Errors
    /// Returns `CacheError::InvalidInput` for a rejected key; otherwise returns
    /// whatever the pool returns.
    #[instrument(skip(self), level = "debug")]
    async fn load(&self, key: &str) -> Result<Option<Vec<u8>>> {
        if validate_cache_key(key) {
            let sql = format!(
                "SELECT {} FROM {} WHERE {} = '{}'",
                self.value_column,
                self.table_name,
                self.key_column,
                escape_sql_string(key)
            );
            debug!("Executing database query: {}", sql);
            self.pool.execute_query(&sql).await
        } else {
            warn!("Invalid cache key format: {}", key);
            Err(CacheError::InvalidInput(format!(
                "Invalid cache key format: {}. Key must be alphanumeric or contain -_.:/ and be <= 1024 characters.",
                key
            )))
        }
    }

    /// Fetches many values with a single
    /// `SELECT <key_column>, <value_column> FROM <table_name> WHERE <key_column> IN (...)`.
    ///
    /// An empty `keys` short-circuits to an empty result without querying. The
    /// whole batch is rejected if any key fails [`validate_cache_key`].
    ///
    /// # Errors
    /// Returns `CacheError::InvalidInput` for the first rejected key or an invalid
    /// identifier; otherwise returns whatever the pool returns.
    #[instrument(skip(self), level = "debug")]
    async fn load_batch(&self, keys: Vec<String>) -> Result<Vec<(String, Vec<u8>)>> {
        if keys.is_empty() {
            return Ok(Vec::new());
        }

        // Reject the batch on the first key that fails validation.
        if let Some(bad) = keys.iter().find(|candidate| !validate_cache_key(candidate)) {
            warn!("Invalid cache key in batch: {}", bad);
            return Err(CacheError::InvalidInput(format!(
                "Invalid cache key format: {}. Key must be alphanumeric or contain -_.:/ and be <= 1024 characters.",
                bad
            )));
        }

        // Defensive re-check of the identifiers before they are interpolated,
        // in the order key column, value column, table.
        let identifier_problem = if !validate_sql_identifier(&self.key_column) {
            Some(format!(
                "Invalid key_column identifier: {}",
                self.key_column
            ))
        } else if !validate_sql_identifier(&self.value_column) {
            Some(format!(
                "Invalid value_column identifier: {}",
                self.value_column
            ))
        } else if !validate_sql_identifier(&self.table_name) {
            Some(format!(
                "Invalid table_name identifier: {}",
                self.table_name
            ))
        } else {
            None
        };
        if let Some(message) = identifier_problem {
            return Err(CacheError::InvalidInput(message));
        }

        // Quoted, escaped, comma-separated key list for the IN clause.
        let in_list = keys
            .iter()
            .map(|k| format!("'{}'", escape_sql_string(k)))
            .collect::<Vec<_>>()
            .join(",");
        let sql = format!(
            "SELECT {}, {} FROM {} WHERE {} IN ({})",
            self.key_column, self.value_column, self.table_name, self.key_column, in_list
        );
        debug!("Executing batch database query for {} keys", keys.len());
        self.pool.execute_batch_query(&sql).await
    }

    /// Delegates to the connection pool's health flag.
    fn is_healthy(&self) -> bool {
        self.pool.is_healthy()
    }
}
/// Executes fully rendered SQL text on behalf of [`SqlDbLoader`].
///
/// NOTE(review): the interface accepts only a query string, with no bind
/// parameters, so callers must validate/escape everything they interpolate.
#[async_trait]
pub trait DbConnectionPool: Send + Sync + std::fmt::Debug {
    /// Runs a single-value query (as built by `SqlDbLoader::load`) and returns
    /// the value, or `None` when there is no row.
    /// NOTE(review): behaviour with multiple matching rows is implementation-defined; confirm.
    async fn execute_query(&self, query: &str) -> Result<Option<Vec<u8>>>;
    /// Runs a `SELECT <key>, <value> ...` query (as built by
    /// `SqlDbLoader::load_batch`) and returns `(key, value)` pairs.
    async fn execute_batch_query(&self, query: &str) -> Result<Vec<(String, Vec<u8>)>>;
    /// Health flag; surfaced unchanged through `SqlDbLoader::is_healthy`.
    fn is_healthy(&self) -> bool;
}
/// Serializable settings for the database fallback layer.
///
/// `Debug` is implemented by hand instead of derived so that
/// `connection_string` is never printed. Connection strings commonly embed
/// credentials, and a derived `Debug` would leak them through any `{:?}` log
/// line or `#[instrument]` argument capture.
#[derive(Clone, Serialize, Deserialize)]
pub struct DbFallbackConfig {
    /// Whether database fallback is enabled (`false` by default).
    pub enabled: bool,
    /// Timeout in milliseconds (5000 by default).
    pub timeout_ms: u64,
    /// Retry budget (3 by default).
    pub max_retries: u32,
    /// Database connection string; may contain credentials, so it is redacted
    /// from `Debug` output.
    pub connection_string: String,
    /// Table holding cache rows.
    /// NOTE(review): presumably passed to `SqlDbLoader::new`, which requires a
    /// plain SQL identifier; confirm wiring.
    pub table_name: String,
    /// Column holding the cache key.
    pub key_column: String,
    /// Column holding the cached value.
    pub value_column: String,
}

impl std::fmt::Debug for DbFallbackConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a connection string is configured, never its contents.
        let connection_string: &str = if self.connection_string.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("DbFallbackConfig")
            .field("enabled", &self.enabled)
            .field("timeout_ms", &self.timeout_ms)
            .field("max_retries", &self.max_retries)
            .field("connection_string", &connection_string)
            .field("table_name", &self.table_name)
            .field("key_column", &self.key_column)
            .field("value_column", &self.value_column)
            .finish()
    }
}
impl Default for DbFallbackConfig {
fn default() -> Self {
Self {
enabled: false,
timeout_ms: 5000,
max_retries: 3,
connection_string: String::new(),
table_name: "cache_table".to_string(),
key_column: "cache_key".to_string(),
value_column: "cache_value".to_string(),
}
}
}