use std::io;
use crate::storage::error::{StorageError, StorageResult};
use walletkit_db::{params, Connection, DbError, Transaction};
/// Wraps a `walletkit_db` error as `StorageError::CacheDb`, keeping the
/// error's `Display` text as the message.
pub(super) fn map_db_err(err: &DbError) -> StorageError {
    let message = err.to_string();
    StorageError::CacheDb(message)
}
/// Wraps a `std::io::Error` as `StorageError::CacheDb`, keeping the error's
/// `Display` text as the message.
pub(super) fn map_io_err(err: &io::Error) -> StorageError {
    let message = err.to_string();
    StorageError::CacheDb(message)
}
/// Copies `bytes` into a fixed-size `[u8; N]` array.
///
/// `label` names the value in the error message.
///
/// # Errors
/// Returns `StorageError::CacheDb` when `bytes.len() != N`.
pub(super) fn parse_fixed_bytes<const N: usize>(
    bytes: &[u8],
    label: &str,
) -> StorageResult<[u8; N]> {
    // The std slice-to-array conversion fails exactly when the lengths differ.
    <[u8; N]>::try_from(bytes).map_err(|_| {
        StorageError::CacheDb(format!(
            "{label} length mismatch: expected {N}, got {}",
            bytes.len()
        ))
    })
}
/// Leading key byte for Merkle-related cache entries (presumably Merkle proofs;
/// no key builder for it is visible in this chunk — confirm at call sites).
pub(super) const CACHE_KEY_PREFIX_MERKLE: u8 = 0x01;
/// Leading key byte for session entries; prepended by `session_cache_key`.
pub(super) const CACHE_KEY_PREFIX_SESSION: u8 = 0x02;
/// Leading key byte for replay-nullifier entries; prepended by `replay_nullifier_key`.
pub(super) const CACHE_KEY_PREFIX_REPLAY_NULLIFIER: u8 = 0x03;
/// Timestamps stored with a cache row, already narrowed to `i64` so they can
/// be bound directly as SQL parameters.
#[derive(Debug, Clone, Copy)]
pub(super) struct CacheEntryTimes {
    /// Time the entry was written (same unit as the caller's `now`).
    pub inserted_at: i64,
    /// Time after which lookups stop returning the entry.
    pub expires_at: i64,
}
/// Builds the insertion/expiry timestamps for a new cache entry.
///
/// `expires_at` is `now + ttl_seconds`. The sum saturates (via
/// `expiry_timestamp`) and is then clamped to `i64::MAX`, so an arbitrarily
/// large TTL yields an entry that never expires instead of an error.
///
/// # Errors
/// Returns `StorageError::CacheDb` if `now` does not fit in an `i64`.
pub(super) fn cache_entry_times(
    now: u64,
    ttl_seconds: u64,
) -> StorageResult<CacheEntryTimes> {
    let inserted_at = to_i64(now, "now")?;
    // `expiry_timestamp` saturates at `u64::MAX`. Converting that with `to_i64`
    // would always fail, so pin the column value to the largest
    // representable i64 instead.
    let expires_at = i64::try_from(expiry_timestamp(now, ttl_seconds)).unwrap_or(i64::MAX);
    Ok(CacheEntryTimes {
        inserted_at,
        expires_at,
    })
}
/// Deletes every `cache_entries` row whose `expires_at` is at or before `now`.
///
/// NOTE(review): this treats `expires_at == now` as expired, while
/// `get_cache_entry` still returns rows with `expires_at == now` (it filters
/// on `>=`). Confirm which boundary is intended.
///
/// # Errors
/// Returns `StorageError::CacheDb` if `now` exceeds `i64::MAX` or the DELETE fails.
pub(super) fn prune_expired_entries(conn: &Connection, now: u64) -> StorageResult<()> {
    const DELETE_EXPIRED_SQL: &str = "DELETE FROM cache_entries WHERE expires_at <= ?1";
    let cutoff = to_i64(now, "now")?;
    conn.execute(DELETE_EXPIRED_SQL, params![cutoff])
        .map(|_| ())
        .map_err(|err| map_db_err(&err))
}
/// Transaction-scoped variant of `prune_expired_entries`: deletes every
/// `cache_entries` row whose `expires_at` is at or before `now`.
///
/// NOTE(review): lookups in this module keep rows with `expires_at == now`
/// (`>=`), but this DELETE removes them. Confirm the intended boundary.
///
/// # Errors
/// Returns `StorageError::CacheDb` if `now` exceeds `i64::MAX` or the DELETE fails.
pub(super) fn prune_expired_entries_tx(
    tx: &Transaction<'_>,
    now: u64,
) -> StorageResult<()> {
    const DELETE_EXPIRED_SQL: &str = "DELETE FROM cache_entries WHERE expires_at <= ?1";
    let cutoff = to_i64(now, "now")?;
    tx.execute(DELETE_EXPIRED_SQL, params![cutoff])
        .map(|_| ())
        .map_err(|err| map_db_err(&err))
}
/// Writes `value` under `key` with the given timestamps.
///
/// It uses `INSERT OR REPLACE`, so an existing row for the same key is
/// overwritten (presumably `key_bytes` is unique/primary — confirm against the
/// schema).
///
/// # Errors
/// Returns `StorageError::CacheDb` if the statement fails.
pub(super) fn upsert_cache_entry(
    conn: &Connection,
    key: &[u8],
    value: &[u8],
    times: CacheEntryTimes,
) -> StorageResult<()> {
    const UPSERT_SQL: &str = concat!(
        "INSERT OR REPLACE INTO cache_entries (\n",
        "key_bytes,\n",
        "value_bytes,\n",
        "inserted_at,\n",
        "expires_at\n",
        ") VALUES (?1, ?2, ?3, ?4)"
    );
    conn.execute(
        UPSERT_SQL,
        params![key, value, times.inserted_at, times.expires_at],
    )
    .map(|_| ())
    .map_err(|err| map_db_err(&err))
}
/// Inserts a new cache row inside `tx`.
///
/// Unlike `upsert_cache_entry`, this is a plain `INSERT`: an existing row for
/// the same key is not replaced. The database presumably reports a constraint
/// error if `key_bytes` is unique (confirm against the schema).
///
/// # Errors
/// Returns `StorageError::CacheDb` if the statement fails.
pub(super) fn insert_cache_entry_tx(
    tx: &Transaction<'_>,
    key: &[u8],
    value: &[u8],
    times: CacheEntryTimes,
) -> StorageResult<()> {
    const INSERT_SQL: &str = concat!(
        "INSERT INTO cache_entries (\n",
        "key_bytes,\n",
        "value_bytes,\n",
        "inserted_at,\n",
        "expires_at\n",
        ") VALUES (?1, ?2, ?3, ?4)"
    );
    tx.execute(
        INSERT_SQL,
        params![key, value, times.inserted_at, times.expires_at],
    )
    .map(|_| ())
    .map_err(|err| map_db_err(&err))
}
/// Returns the value stored under `key` if it is still live at `now`.
///
/// A row matches when `expires_at >= now`. When `insertion_before` is given,
/// `inserted_at < insertion_before` must also hold. Returns `Ok(None)` when
/// no row matches.
///
/// # Errors
/// Returns `StorageError::CacheDb` if `now` or `insertion_before` exceed
/// `i64::MAX`, or if the query fails.
pub(super) fn get_cache_entry(
    conn: &Connection,
    key: &[u8],
    now: u64,
    insertion_before: Option<u64>,
) -> StorageResult<Option<Vec<u8>>> {
    let now = to_i64(now, "now")?;
    let lookup = match insertion_before {
        Some(cutoff) => {
            let cutoff = to_i64(cutoff, "insertion_before")?;
            conn.query_row_optional(
                "SELECT value_bytes FROM cache_entries WHERE key_bytes = ?1 AND expires_at >= ?2 AND inserted_at < ?3",
                params![key, now, cutoff],
                |stmt| Ok(stmt.column_blob(0)),
            )
        }
        None => conn.query_row_optional(
            "SELECT value_bytes FROM cache_entries WHERE key_bytes = ?1 AND expires_at >= ?2",
            params![key, now],
            |stmt| Ok(stmt.column_blob(0)),
        ),
    };
    lookup.map_err(|err| map_db_err(&err))
}
/// Transaction-scoped variant of `get_cache_entry`.
///
/// It applies the same filters: `expires_at >= now`, plus
/// `inserted_at < insertion_before` when a cutoff is given. Only the first
/// matching row is read; `Ok(None)` means no row matched.
///
/// # Errors
/// Returns `StorageError::CacheDb` if `now` or `insertion_before` exceed
/// `i64::MAX`, or if preparing, binding, or stepping the statement fails.
pub(super) fn get_cache_entry_tx(
    tx: &Transaction<'_>,
    key: &[u8],
    now: u64,
    insertion_before: Option<u64>,
) -> StorageResult<Option<Vec<u8>>> {
    let now = to_i64(now, "now")?;
    let cutoff = match insertion_before {
        Some(raw) => Some(to_i64(raw, "insertion_before")?),
        None => None,
    };

    let sql = if cutoff.is_some() {
        "SELECT value_bytes FROM cache_entries WHERE key_bytes = ?1 AND expires_at >= ?2 AND inserted_at < ?3"
    } else {
        "SELECT value_bytes FROM cache_entries WHERE key_bytes = ?1 AND expires_at >= ?2"
    };
    let mut stmt = tx.prepare(sql).map_err(|err| map_db_err(&err))?;

    // The two query shapes differ only in their parameter list.
    let bound = match cutoff {
        Some(before) => stmt.bind_values(params![key, now, before]),
        None => stmt.bind_values(params![key, now]),
    };
    bound.map_err(|err| map_db_err(&err))?;

    // Bind the step outcome to a local so anything borrowing `stmt` is
    // dropped before `stmt` itself.
    let outcome = stmt.step().map_err(|err| map_db_err(&err))?;
    let blob = match outcome {
        walletkit_db::StepResult::Row(row) => Some(row.column_blob(0)),
        walletkit_db::StepResult::Done => None,
    };
    Ok(blob)
}
/// Builds a namespaced cache key: the single `prefix` byte followed by `payload`.
fn cache_key_with_prefix(prefix: u8, payload: &[u8]) -> Vec<u8> {
    // `concat` sizes the output exactly (1 + payload.len()) before copying.
    [std::slice::from_ref(&prefix), payload].concat()
}
/// Cache key for a session entry: `CACHE_KEY_PREFIX_SESSION` followed by the
/// 32 bytes of `rp_id` (33 bytes total).
pub(super) fn session_cache_key(rp_id: [u8; 32]) -> Vec<u8> {
    let payload: &[u8] = &rp_id;
    cache_key_with_prefix(CACHE_KEY_PREFIX_SESSION, payload)
}
/// Cache key for a replay-protection entry: `CACHE_KEY_PREFIX_REPLAY_NULLIFIER`
/// followed by the 32 bytes of `nullifier` (33 bytes total).
pub(super) fn replay_nullifier_key(nullifier: [u8; 32]) -> Vec<u8> {
    let payload: &[u8] = &nullifier;
    cache_key_with_prefix(CACHE_KEY_PREFIX_REPLAY_NULLIFIER, payload)
}
/// Returns `now + ttl_seconds`, pinned to `u64::MAX` instead of overflowing.
pub(super) const fn expiry_timestamp(now: u64, ttl_seconds: u64) -> u64 {
    match now.checked_add(ttl_seconds) {
        Some(deadline) => deadline,
        None => u64::MAX,
    }
}
/// Narrows `value` to `i64` so it can be bound as a SQL integer parameter.
///
/// `label` names the value in the error message.
///
/// # Errors
/// Returns `StorageError::CacheDb` when `value > i64::MAX`.
pub(super) fn to_i64(value: u64, label: &str) -> StorageResult<i64> {
    match i64::try_from(value) {
        Ok(converted) => Ok(converted),
        Err(_) => Err(StorageError::CacheDb(format!(
            "{label} out of range for i64: {value}"
        ))),
    }
}