use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use hmac::digest::KeyInit;
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::config::{
ReplayBackend, ReplayConfig, ReplayHmacKeyProvider, ReplayIntegrityMode, ReplayKeyEncoding,
};
use crate::crypto::{ReplayState, merge_replay_states};
use crate::error::{ConfigError, Result, SrxError};
type HmacSha256 = Hmac<Sha256>;
/// Pluggable source of HMAC key material for replay-state integrity.
///
/// Implementations are stored as `Arc<dyn CustomHmacKeyProvider>` in the
/// process-wide registry (`CUSTOM_HMAC_PROVIDERS`) and are looked up by name
/// when the replay configuration selects `ReplayHmacKeyProvider::Custom`.
pub trait CustomHmacKeyProvider: Send + Sync {
/// Returns the raw HMAC key bytes to use for the given replay configuration.
///
/// # Errors
/// Returns an error when the key cannot be produced; the error is passed
/// through unchanged to the caller that asked for the key.
fn resolve_key(&self, replay: &ReplayConfig) -> Result<Vec<u8>>;
}
/// Process-wide registry of custom HMAC key providers, keyed by provider name.
/// Lazily initialised (to an empty map) by `custom_hmac_providers()`.
static CUSTOM_HMAC_PROVIDERS: OnceLock<RwLock<HashMap<String, Arc<dyn CustomHmacKeyProvider>>>> =
OnceLock::new();
/// Returns the global custom-provider registry, creating an empty one on
/// first access.
fn custom_hmac_providers() -> &'static RwLock<HashMap<String, Arc<dyn CustomHmacKeyProvider>>> {
    // `RwLock<HashMap<..>>::default()` is an unlocked lock around an empty map.
    CUSTOM_HMAC_PROVIDERS.get_or_init(Default::default)
}
/// Registers `provider` under `name` in the process-wide registry.
///
/// A provider previously registered under the same name is replaced.
///
/// # Panics
/// Panics if the registry lock has been poisoned.
pub fn register_custom_hmac_key_provider(
    name: impl Into<String>,
    provider: Arc<dyn CustomHmacKeyProvider>,
) {
    // The write guard is a temporary that lives only for this statement.
    custom_hmac_providers()
        .write()
        .expect("custom hmac provider registry poisoned")
        .insert(name.into(), provider);
}
/// Process-wide counters describing the compare-and-swap persistence loop in
/// `merge_and_persist_replay_state`, both in aggregate and per backend.
#[derive(Default)]
struct ReplayStoreMetrics {
/// Incremented once per loop iteration (every load + conditional save).
cas_attempts: AtomicU64,
/// Incremented when a conditional save reports that it wrote the value.
cas_successes: AtomicU64,
/// Incremented when a conditional save reports that the stored value changed.
cas_conflicts: AtomicU64,
/// Incremented together with `cas_conflicts` on every failed conditional save.
cas_retries: AtomicU64,
/// Counters for the `ReplayBackend::FileJson` backend.
file: ReplayStoreBackendCounters,
/// Counters for the `ReplayBackend::Sqlite` backend.
sqlite: ReplayStoreBackendCounters,
/// Counters for the `ReplayBackend::Redis` backend.
redis: ReplayStoreBackendCounters,
}
/// Per-backend mirror of the aggregate CAS counters in `ReplayStoreMetrics`.
#[derive(Default)]
struct ReplayStoreBackendCounters {
/// Conditional-save attempts made against this backend.
attempts: AtomicU64,
/// Attempts that ended with the new value being written.
successes: AtomicU64,
/// Attempts rejected because the stored value no longer matched.
conflicts: AtomicU64,
/// Incremented alongside `conflicts` on every rejected attempt.
retries: AtomicU64,
}
/// Plain-value copy of one backend's counters, as read by
/// `replay_store_metrics_snapshot`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReplayStoreBackendMetricsSnapshot {
/// Value of the backend's `attempts` counter at snapshot time.
pub attempts: u64,
/// Value of the backend's `successes` counter at snapshot time.
pub successes: u64,
/// Value of the backend's `conflicts` counter at snapshot time.
pub conflicts: u64,
/// Value of the backend's `retries` counter at snapshot time.
pub retries: u64,
}
/// Plain-value copy of all replay-store counters.
///
/// Each field is loaded independently with relaxed ordering, so the snapshot
/// is not guaranteed to be a single consistent point in time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReplayStoreMetricsSnapshot {
/// Aggregate conditional-save attempts.
pub cas_attempts: u64,
/// Aggregate successful conditional saves.
pub cas_successes: u64,
/// Aggregate rejected conditional saves.
pub cas_conflicts: u64,
/// Aggregate retry count (bumped together with `cas_conflicts`).
pub cas_retries: u64,
/// Counters for the file backend.
pub file: ReplayStoreBackendMetricsSnapshot,
/// Counters for the SQLite backend.
pub sqlite: ReplayStoreBackendMetricsSnapshot,
/// Counters for the Redis backend.
pub redis: ReplayStoreBackendMetricsSnapshot,
}
/// Lazily-initialised, process-wide replay-store counters; access through
/// `replay_store_metrics()`.
static REPLAY_STORE_METRICS: OnceLock<ReplayStoreMetrics> = OnceLock::new();
/// Returns the global counters, zero-initialising them on first access.
fn replay_store_metrics() -> &'static ReplayStoreMetrics {
    REPLAY_STORE_METRICS.get_or_init(Default::default)
}
pub fn replay_store_metrics_snapshot() -> ReplayStoreMetricsSnapshot {
let m = replay_store_metrics();
ReplayStoreMetricsSnapshot {
cas_attempts: m.cas_attempts.load(Ordering::Relaxed),
cas_successes: m.cas_successes.load(Ordering::Relaxed),
cas_conflicts: m.cas_conflicts.load(Ordering::Relaxed),
cas_retries: m.cas_retries.load(Ordering::Relaxed),
file: snapshot_backend(&m.file),
sqlite: snapshot_backend(&m.sqlite),
redis: snapshot_backend(&m.redis),
}
}
fn snapshot_backend(v: &ReplayStoreBackendCounters) -> ReplayStoreBackendMetricsSnapshot {
ReplayStoreBackendMetricsSnapshot {
attempts: v.attempts.load(Ordering::Relaxed),
successes: v.successes.load(Ordering::Relaxed),
conflicts: v.conflicts.load(Ordering::Relaxed),
retries: v.retries.load(Ordering::Relaxed),
}
}
/// Test-only helper: sets every aggregate and per-backend counter back to 0.
#[cfg(test)]
pub(crate) fn reset_replay_store_metrics_for_tests() {
    let metrics = replay_store_metrics();
    let aggregate = [
        &metrics.cas_attempts,
        &metrics.cas_successes,
        &metrics.cas_conflicts,
        &metrics.cas_retries,
    ];
    for counter in aggregate {
        counter.store(0, Ordering::Relaxed);
    }
    for backend in [&metrics.file, &metrics.sqlite, &metrics.redis] {
        reset_backend(backend);
    }
}
/// Test-only helper: zeroes the four counters of a single backend.
#[cfg(test)]
fn reset_backend(v: &ReplayStoreBackendCounters) {
    for counter in [&v.attempts, &v.successes, &v.conflicts, &v.retries] {
        counter.store(0, Ordering::Relaxed);
    }
}
/// Selects the per-backend counter group that corresponds to `backend`.
fn backend_counters<'a>(
    m: &'a ReplayStoreMetrics,
    backend: &ReplayBackend,
) -> &'a ReplayStoreBackendCounters {
    // No bindings are taken from the variant payloads, so matching on the
    // dereferenced place does not move anything out of `backend`.
    match *backend {
        ReplayBackend::FileJson => &m.file,
        ReplayBackend::Sqlite { .. } => &m.sqlite,
        ReplayBackend::Redis { .. } => &m.redis,
    }
}
/// The part of a replay envelope that is covered by the checksum / HMAC.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ReplaySignedPayload {
/// Identifier of the session the snapshot belongs to; on decode it is
/// compared against the caller's expected binding.
session_binding: String,
/// The persisted replay state itself.
state: ReplayState,
}
/// On-disk / on-wire JSON envelope (format version 1) wrapping a signed payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ReplayEnvelopeV1 {
/// Envelope format version; the encoder writes 1 and the decoder rejects
/// any other value.
version: u8,
/// The payload protected by `checksum_hex` and (optionally) `hmac_hex`.
payload: ReplaySignedPayload,
/// Lowercase hex SHA-256 of the JSON-serialised payload.
checksum_hex: String,
/// Lowercase hex HMAC-SHA256 of the JSON-serialised payload; written only
/// when the integrity mode is `HmacSha256`.
hmac_hex: Option<String>,
}
/// Byte-level storage backend for the encoded replay state.
pub trait ReplayStorage: Send + Sync {
/// Loads the stored bytes, or `Ok(None)` when nothing has been stored yet.
fn load_raw(&self, replay: &ReplayConfig) -> Result<Option<Vec<u8>>>;
/// Unconditionally stores `raw`, replacing any previous value.
fn save_raw(&self, replay: &ReplayConfig, raw: &[u8]) -> Result<()>;
/// Stores `new_raw` only if the currently stored value equals `expected_raw`
/// (`None` meaning "nothing stored").
///
/// Returns `Ok(true)` when the value was written and `Ok(false)` when the
/// stored value did not match the expectation.
fn save_raw_if_unchanged(
&self,
replay: &ReplayConfig,
expected_raw: Option<&[u8]>,
new_raw: &[u8],
) -> Result<bool>;
}
/// Replay storage backend that keeps the encoded state in the single file
/// named by `ReplayConfig::state_file`.
pub struct FileReplayStorage;

impl FileReplayStorage {
    /// Location of the state file for the given configuration.
    fn path(replay: &ReplayConfig) -> &Path {
        &replay.state_file
    }
}
impl ReplayStorage for FileReplayStorage {
    /// Reads the whole state file.
    ///
    /// Returns `Ok(None)` when the file does not exist; any other I/O error
    /// is propagated.
    fn load_raw(&self, replay: &ReplayConfig) -> Result<Option<Vec<u8>>> {
        match fs::read(Self::path(replay)) {
            Ok(raw) => Ok(Some(raw)),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
            Err(e) => Err(e.into()),
        }
    }

    /// Writes `raw` to a sibling `.tmp` file and renames it over the state
    /// file, creating the parent directory first when needed.
    ///
    /// The previous state file is never deleted up front: `fs::rename`
    /// replaces an existing destination, so readers observe either the old
    /// or the new content and a crash mid-save cannot leave the state file
    /// missing.
    fn save_raw(&self, replay: &ReplayConfig, raw: &[u8]) -> Result<()> {
        let target = Self::path(replay);
        if let Some(parent) = target.parent()
            && !parent.as_os_str().is_empty()
        {
            fs::create_dir_all(parent)?;
        }
        let tmp_path = target.with_extension("tmp");
        fs::write(&tmp_path, raw)?;
        if let Err(e) = fs::rename(&tmp_path, target) {
            // Best-effort cleanup so a failed save does not leave a stray
            // temp file behind; the rename error is what gets reported.
            let _ = fs::remove_file(&tmp_path);
            return Err(e.into());
        }
        Ok(())
    }

    /// Saves `new_raw` only when the file's current content equals
    /// `expected_raw` (`None` meaning "file absent").
    ///
    /// NOTE: the comparison and the write are separate file operations with
    /// no locking in between, so this check is not atomic with respect to
    /// other writers of the same file.
    fn save_raw_if_unchanged(
        &self,
        replay: &ReplayConfig,
        expected_raw: Option<&[u8]>,
        new_raw: &[u8],
    ) -> Result<bool> {
        let current = self.load_raw(replay)?;
        if current.as_deref() != expected_raw {
            return Ok(false);
        }
        self.save_raw(replay, new_raw)?;
        Ok(true)
    }
}
/// Replay storage backend that keeps the encoded state as a BLOB row in a
/// SQLite table (only available with the `replay-sqlite` feature).
#[cfg(feature = "replay-sqlite")]
pub struct SqliteReplayStorage {
/// Path of the SQLite database file.
path: std::path::PathBuf,
/// Name of the key/value table; it is interpolated verbatim into SQL text.
table: String,
}
#[cfg(feature = "replay-sqlite")]
impl SqliteReplayStorage {
    /// Creates a storage handle for the database at `path` using `table`.
    ///
    /// No connection is opened here; every operation opens its own.
    pub fn new(path: std::path::PathBuf, table: String) -> Self {
        Self { path, table }
    }

    /// Wraps a rusqlite error as a configuration error prefixed by `context`.
    fn sqlite_error(context: &str, err: rusqlite::Error) -> SrxError {
        SrxError::Config(ConfigError::Invalid(format!("{context}: {err}")))
    }

    /// Opens a fresh connection to the configured database file.
    fn open(&self) -> Result<rusqlite::Connection> {
        rusqlite::Connection::open(&self.path).map_err(|e| {
            Self::sqlite_error(&format!("sqlite open '{}'", self.path.display()), e)
        })
    }

    /// Creates the key/value table on `conn` if it does not exist yet.
    fn ensure_table(&self, conn: &rusqlite::Connection) -> Result<()> {
        let sql = format!(
            "CREATE TABLE IF NOT EXISTS {} (k TEXT PRIMARY KEY, v BLOB NOT NULL)",
            self.table
        );
        conn.execute(&sql, [])
            .map_err(|e| Self::sqlite_error("sqlite create table", e))?;
        Ok(())
    }
}

#[cfg(feature = "replay-sqlite")]
impl ReplayStorage for SqliteReplayStorage {
    /// Loads the BLOB stored under `replay.storage_key`, or `Ok(None)` when
    /// no such row exists. The table is created first if it is missing.
    fn load_raw(&self, replay: &ReplayConfig) -> Result<Option<Vec<u8>>> {
        use rusqlite::params;

        let conn = self.open()?;
        self.ensure_table(&conn)?;
        let sql = format!("SELECT v FROM {} WHERE k = ?1", self.table);
        let mut stmt = conn
            .prepare(&sql)
            .map_err(|e| Self::sqlite_error("sqlite prepare select", e))?;
        let mut rows = stmt
            .query(params![replay.storage_key.as_str()])
            .map_err(|e| Self::sqlite_error("sqlite query", e))?;
        let Some(row) = rows
            .next()
            .map_err(|e| Self::sqlite_error("sqlite next row", e))?
        else {
            return Ok(None);
        };
        let blob: Vec<u8> = row
            .get(0)
            .map_err(|e| Self::sqlite_error("sqlite decode blob", e))?;
        Ok(Some(blob))
    }

    /// Inserts or overwrites the row for `replay.storage_key` with `raw`.
    fn save_raw(&self, replay: &ReplayConfig, raw: &[u8]) -> Result<()> {
        use rusqlite::params;

        let conn = self.open()?;
        self.ensure_table(&conn)?;
        let sql = format!(
            "INSERT INTO {} (k, v) VALUES (?1, ?2)
ON CONFLICT(k) DO UPDATE SET v = excluded.v",
            self.table
        );
        conn.execute(&sql, params![replay.storage_key.as_str(), raw])
            .map_err(|e| Self::sqlite_error("sqlite upsert replay state", e))?;
        Ok(())
    }

    /// Conditionally stores `new_raw` inside a transaction.
    ///
    /// With `expected_raw = Some(..)` the row is updated only when its
    /// current BLOB equals the expectation; with `None` the row is inserted
    /// only when the key is absent. Returns whether exactly one row changed.
    fn save_raw_if_unchanged(
        &self,
        replay: &ReplayConfig,
        expected_raw: Option<&[u8]>,
        new_raw: &[u8],
    ) -> Result<bool> {
        use rusqlite::params;

        let mut conn = self.open()?;
        let tx = conn
            .transaction()
            .map_err(|e| Self::sqlite_error("sqlite begin transaction", e))?;
        // `Transaction` dereferences to `Connection`, so the table is created
        // inside the same transaction as the conditional write.
        self.ensure_table(&tx)?;
        let key = replay.storage_key.as_str();
        let affected = match expected_raw {
            Some(exp) => {
                let sql = format!("UPDATE {} SET v = ?1 WHERE k = ?2 AND v = ?3", self.table);
                tx.execute(&sql, params![new_raw, key, exp])
                    .map_err(|e| Self::sqlite_error("sqlite cas update", e))?
            }
            None => {
                let sql = format!(
                    "INSERT INTO {} (k, v) VALUES (?1, ?2) ON CONFLICT(k) DO NOTHING",
                    self.table
                );
                tx.execute(&sql, params![key, new_raw])
                    .map_err(|e| Self::sqlite_error("sqlite cas insert", e))?
            }
        };
        tx.commit()
            .map_err(|e| Self::sqlite_error("sqlite commit transaction", e))?;
        Ok(affected == 1)
    }
}
/// Replay storage backend that keeps the encoded state in a Redis string key
/// (only available with the `replay-redis` feature).
#[cfg(feature = "replay-redis")]
pub struct RedisReplayStorage {
/// Connection URL handed to `redis::Client::open`.
url: String,
/// Prefix prepended to `ReplayConfig::storage_key` to form the Redis key.
key_prefix: String,
}
#[cfg(feature = "replay-redis")]
impl RedisReplayStorage {
    /// Creates a storage handle for the Redis server at `url`.
    ///
    /// No connection is opened here; every operation opens its own.
    pub fn new(url: String, key_prefix: String) -> Self {
        Self { url, key_prefix }
    }

    /// Redis key for this configuration: `key_prefix` followed by the
    /// configured storage key.
    fn key(&self, replay: &ReplayConfig) -> String {
        format!("{}{}", self.key_prefix, replay.storage_key)
    }

    /// Opens a new synchronous connection to the configured server.
    // NOTE(review): the client error message embeds the full URL, which may
    // contain credentials — confirm this is acceptable for logs.
    fn connect(&self) -> Result<redis::Connection> {
        let client = redis::Client::open(self.url.as_str()).map_err(|e| {
            SrxError::Config(ConfigError::Invalid(format!(
                "redis client '{}': {e}",
                self.url
            )))
        })?;
        client.get_connection().map_err(|e| {
            SrxError::Config(ConfigError::Invalid(format!("redis connection: {e}")))
        })
    }
}

#[cfg(feature = "replay-redis")]
impl ReplayStorage for RedisReplayStorage {
    /// Fetches the value stored under the replay key, or `Ok(None)` when the
    /// key does not exist.
    fn load_raw(&self, replay: &ReplayConfig) -> Result<Option<Vec<u8>>> {
        use redis::Commands;

        let mut conn = self.connect()?;
        conn.get(self.key(replay)).map_err(|e| {
            SrxError::Config(ConfigError::Invalid(format!("redis get replay state: {e}")))
        })
    }

    /// Unconditionally sets the replay key to `raw`.
    fn save_raw(&self, replay: &ReplayConfig, raw: &[u8]) -> Result<()> {
        use redis::Commands;

        let mut conn = self.connect()?;
        conn.set::<_, _, ()>(self.key(replay), raw).map_err(|e| {
            SrxError::Config(ConfigError::Invalid(format!("redis set replay state: {e}")))
        })?;
        Ok(())
    }

    /// Compare-and-set implemented as a server-side Lua script.
    ///
    /// The "expect no stored value" case is signalled through a dedicated
    /// flag argument rather than a magic byte string in the expected-value
    /// slot, so no stored value can ever be mistaken for the marker.
    fn save_raw_if_unchanged(
        &self,
        replay: &ReplayConfig,
        expected_raw: Option<&[u8]>,
        new_raw: &[u8],
    ) -> Result<bool> {
        // ARGV[1]: '1' when the key must be absent, '0' otherwise.
        // ARGV[2]: expected current value (ignored when ARGV[1] == '1').
        // ARGV[3]: new value to store.
        const CAS_SCRIPT: &str = r#"
local cur = redis.call('GET', KEYS[1])
if ARGV[1] == '1' then
  if cur then
    return 0
  end
  redis.call('SET', KEYS[1], ARGV[3])
  return 1
end
if cur and cur == ARGV[2] then
  redis.call('SET', KEYS[1], ARGV[3])
  return 1
end
return 0
"#;
        let key = self.key(replay);
        let mut conn = self.connect()?;
        let expect_absent = if expected_raw.is_none() { "1" } else { "0" };
        let expected: &[u8] = expected_raw.unwrap_or_default();
        let res: i32 = redis::cmd("EVAL")
            .arg(CAS_SCRIPT)
            .arg(1)
            .arg(key)
            .arg(expect_absent)
            .arg(expected)
            .arg(new_raw)
            .query(&mut conn)
            .map_err(|e| SrxError::Config(ConfigError::Invalid(format!("redis cas eval: {e}"))))?;
        Ok(res == 1)
    }
}
/// Builds the storage backend selected by `replay.backend`.
///
/// # Errors
/// Returns a configuration error when the SQLite or Redis backend is
/// requested but the crate was built without the matching cargo feature
/// (`replay-sqlite` / `replay-redis`).
pub fn storage_from_config(replay: &ReplayConfig) -> Result<Box<dyn ReplayStorage>> {
match &replay.backend {
ReplayBackend::FileJson => Ok(Box::new(FileReplayStorage)),
ReplayBackend::Sqlite { path, table } => {
// Exactly one of the two cfg-gated blocks below is compiled in and
// becomes the value of this match arm.
#[cfg(feature = "replay-sqlite")]
{
Ok(Box::new(SqliteReplayStorage::new(
path.clone(),
table.clone(),
)))
}
#[cfg(not(feature = "replay-sqlite"))]
{
// Touch the bindings so they do not trigger unused-variable warnings.
let _ = (path, table);
Err(SrxError::Config(ConfigError::Invalid(
"replay backend 'sqlite' requires feature 'replay-sqlite'".to_string(),
)))
}
}
ReplayBackend::Redis { url, key_prefix } => {
// Same pattern as the SQLite arm: one cfg-gated block survives.
#[cfg(feature = "replay-redis")]
{
Ok(Box::new(RedisReplayStorage::new(
url.clone(),
key_prefix.clone(),
)))
}
#[cfg(not(feature = "replay-redis"))]
{
// Touch the bindings so they do not trigger unused-variable warnings.
let _ = (url, key_prefix);
Err(SrxError::Config(ConfigError::Invalid(
"replay backend 'redis' requires feature 'replay-redis'".to_string(),
)))
}
}
}
}
/// Serialises `state` together with `session_binding` into a version-1 JSON
/// envelope.
///
/// The envelope always carries a SHA-256 checksum of the JSON payload; an
/// HMAC-SHA256 tag is added only when the integrity mode is `HmacSha256`.
///
/// # Errors
/// Fails when JSON serialisation fails or when the HMAC key cannot be
/// resolved or used.
pub fn encode_replay_envelope(
    replay: &ReplayConfig,
    session_binding: String,
    state: ReplayState,
) -> Result<Vec<u8>> {
    let invalid = |message: String| SrxError::Config(ConfigError::Invalid(message));

    let payload = ReplaySignedPayload {
        session_binding,
        state,
    };
    let payload_bytes = serde_json::to_vec(&payload)
        .map_err(|e| invalid(format!("failed to serialize replay payload: {e}")))?;
    let checksum_hex = hex_encode(&Sha256::digest(&payload_bytes));

    let hmac_hex = match replay.integrity.mode {
        ReplayIntegrityMode::HmacSha256 => {
            let key = resolve_hmac_key(replay)?;
            let tag = hmac_sha256(&key, &payload_bytes)?;
            Some(hex_encode(&tag))
        }
        ReplayIntegrityMode::None | ReplayIntegrityMode::ChecksumSha256 => None,
    };

    let envelope = ReplayEnvelopeV1 {
        version: 1,
        payload,
        checksum_hex,
        hmac_hex,
    };
    serde_json::to_vec(&envelope)
        .map_err(|e| invalid(format!("failed to serialize replay envelope: {e}")))
}
/// Compares two byte strings without short-circuiting on the first
/// difference, so the time taken does not reveal how many leading bytes of an
/// authentication tag were correct. Only the lengths are compared early.
fn replay_tag_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    // `black_box` discourages the optimiser from reintroducing an early exit.
    std::hint::black_box(diff) == 0
}

/// Parses and verifies stored replay bytes.
///
/// Versioned envelopes are tried first: the SHA-256 checksum is always
/// verified, the HMAC tag is verified when the integrity mode is
/// `HmacSha256`, and finally the session binding is compared. If the bytes
/// do not parse as an envelope they are interpreted as a legacy bare
/// `ReplayState`.
///
/// # Returns
/// * `Ok(Some(state))` — a valid snapshot for `expected_session_binding`, or
///   a legacy snapshot when the integrity mode is `None`.
/// * `Ok(None)` — the snapshot belongs to another session, or it is a legacy
///   snapshot and an integrity mode is enabled (both cases log a warning).
///
/// # Errors
/// Unsupported envelope version, checksum or HMAC mismatch, missing HMAC in
/// `HmacSha256` mode, HMAC key resolution failure, or bytes that parse as
/// neither an envelope nor a legacy state.
pub fn decode_replay_envelope(
    replay: &ReplayConfig,
    expected_session_binding: &str,
    raw: &[u8],
) -> Result<Option<ReplayState>> {
    if let Ok(envelope) = serde_json::from_slice::<ReplayEnvelopeV1>(raw) {
        if envelope.version != 1 {
            return Err(SrxError::Config(ConfigError::Invalid(format!(
                "unsupported replay envelope version: {}",
                envelope.version
            ))));
        }
        // The digest input is the re-serialised payload, mirroring what the
        // encoder hashed.
        let payload_bytes = serde_json::to_vec(&envelope.payload).map_err(|e| {
            SrxError::Config(ConfigError::Invalid(format!(
                "failed to serialize replay payload for verification: {e}"
            )))
        })?;
        let expected_checksum = hex_encode(&Sha256::digest(&payload_bytes));
        if envelope.checksum_hex != expected_checksum {
            return Err(SrxError::Config(ConfigError::Invalid(
                "replay state checksum mismatch (file tamper/corruption suspected)".to_string(),
            )));
        }
        if replay.integrity.mode == ReplayIntegrityMode::HmacSha256 {
            let got = envelope.hmac_hex.as_deref().ok_or_else(|| {
                SrxError::Config(ConfigError::Invalid(
                    "missing replay envelope hmac".to_string(),
                ))
            })?;
            let expected = hex_encode(&hmac_sha256(&resolve_hmac_key(replay)?, &payload_bytes)?);
            // The tag is a secret-derived value: compare it in constant time.
            if !replay_tag_eq(got.as_bytes(), expected.as_bytes()) {
                return Err(SrxError::Config(ConfigError::Invalid(
                    "replay state HMAC mismatch (file tamper suspected)".to_string(),
                )));
            }
        }
        if envelope.payload.session_binding != expected_session_binding {
            tracing::warn!(
                "replay state session binding mismatch; ignoring stale snapshot for a new handshake"
            );
            return Ok(None);
        }
        return Ok(Some(envelope.payload.state));
    }

    // Not an envelope: fall back to the legacy bare-state format.
    let legacy = serde_json::from_slice::<ReplayState>(raw).map_err(|e| {
        SrxError::Config(ConfigError::Invalid(format!(
            "invalid replay state payload: {e}"
        )))
    })?;
    if replay.integrity.mode != ReplayIntegrityMode::None {
        tracing::warn!(
            "legacy replay snapshot without integrity/session binding ignored (enable one-time migration by setting replay.integrity.mode=None)"
        );
        return Ok(None);
    }
    Ok(Some(legacy))
}
pub fn merge_and_persist_replay_state(
replay: &ReplayConfig,
storage: &dyn ReplayStorage,
session_binding: &str,
local_state: ReplayState,
) -> Result<()> {
let mut candidate = local_state;
let max_retries = 12usize;
let metrics = replay_store_metrics();
for _ in 0..max_retries {
let by_backend = backend_counters(metrics, &replay.backend);
metrics.cas_attempts.fetch_add(1, Ordering::Relaxed);
by_backend.attempts.fetch_add(1, Ordering::Relaxed);
let observed_raw = storage.load_raw(replay)?;
if let Some(raw) = observed_raw.as_deref()
&& let Some(remote) = decode_replay_envelope(replay, session_binding, raw)?
{
candidate = merge_replay_states(&candidate, &remote).ok_or_else(|| {
SrxError::Config(ConfigError::Invalid(
"failed to merge replay snapshots: invalid bitmap size".to_string(),
))
})?;
}
let encoded =
encode_replay_envelope(replay, session_binding.to_string(), candidate.clone())?;
if storage.save_raw_if_unchanged(replay, observed_raw.as_deref(), &encoded)? {
metrics.cas_successes.fetch_add(1, Ordering::Relaxed);
by_backend.successes.fetch_add(1, Ordering::Relaxed);
return Ok(());
}
metrics.cas_conflicts.fetch_add(1, Ordering::Relaxed);
metrics.cas_retries.fetch_add(1, Ordering::Relaxed);
by_backend.conflicts.fetch_add(1, Ordering::Relaxed);
by_backend.retries.fetch_add(1, Ordering::Relaxed);
}
Err(SrxError::Config(ConfigError::Invalid(
"failed to persist replay state after concurrent update retries".to_string(),
)))
}
/// Resolves the HMAC key bytes according to `replay.integrity.hmac_key_provider`.
///
/// * `StaticConfig` — clones `replay.integrity.hmac_key`.
/// * `EnvVar` — reads the named environment variable and decodes it.
/// * `File` — reads the file as text, trims surrounding whitespace, decodes it.
/// * `Custom` — looks the provider up by name in the global registry and
///   delegates to it.
///
/// # Errors
/// Missing static key, missing environment variable, unreadable key file,
/// undecodable key material, poisoned registry lock, unregistered custom
/// provider, or any error returned by the custom provider itself.
fn resolve_hmac_key(replay: &ReplayConfig) -> Result<Vec<u8>> {
    match &replay.integrity.hmac_key_provider {
        ReplayHmacKeyProvider::StaticConfig => replay.integrity.hmac_key.clone().ok_or_else(|| {
            SrxError::Config(ConfigError::MissingField(
                "replay.integrity.hmac_key for StaticConfig provider".to_string(),
            ))
        }),
        ReplayHmacKeyProvider::EnvVar { name, encoding } => {
            let val = std::env::var(name).map_err(|_| {
                SrxError::Config(ConfigError::MissingField(format!(
                    "missing replay hmac env var: {name}"
                )))
            })?;
            decode_key_material(&val, *encoding)
        }
        ReplayHmacKeyProvider::File { path, encoding } => {
            let raw = fs::read_to_string(path).map_err(|e| {
                SrxError::Config(ConfigError::Invalid(format!(
                    "failed reading replay hmac key file '{}': {e}",
                    path.display()
                )))
            })?;
            decode_key_material(raw.trim(), *encoding)
        }
        ReplayHmacKeyProvider::Custom { provider } => {
            let guard = custom_hmac_providers().read().map_err(|_| {
                SrxError::Config(ConfigError::Invalid(
                    "custom hmac provider registry poisoned".to_string(),
                ))
            })?;
            let resolver = guard.get(provider).cloned();
            // Release the registry lock before running user code: a provider
            // that registers another provider (or is simply slow) must not
            // deadlock or stall writers of the registry.
            drop(guard);
            let Some(resolver) = resolver else {
                return Err(SrxError::Config(ConfigError::MissingField(format!(
                    "custom hmac provider '{provider}' is not registered"
                ))));
            };
            resolver.resolve_key(replay)
        }
    }
}
fn decode_key_material(input: &str, enc: ReplayKeyEncoding) -> Result<Vec<u8>> {
match enc {
ReplayKeyEncoding::RawUtf8 => Ok(input.as_bytes().to_vec()),
ReplayKeyEncoding::Hex => hex_decode(input),
ReplayKeyEncoding::Base64 => {
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, input.trim())
.map_err(|e| {
SrxError::Config(ConfigError::Invalid(format!(
"invalid base64 replay hmac key: {e}"
)))
})
}
}
}
/// Computes the HMAC-SHA256 tag of `msg` under `key`.
///
/// # Errors
/// Returns a configuration error if the MAC cannot be initialised from `key`.
fn hmac_sha256(key: &[u8], msg: &[u8]) -> Result<[u8; 32]> {
    // `KeyInit` is named explicitly to pick the constructor unambiguously.
    let mut mac = match <HmacSha256 as KeyInit>::new_from_slice(key) {
        Ok(mac) => mac,
        Err(e) => {
            return Err(SrxError::Config(ConfigError::Invalid(format!(
                "invalid HMAC key: {e}"
            ))));
        }
    };
    mac.update(msg);
    let tag = mac.finalize().into_bytes();
    Ok(tag.into())
}
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&byte| {
            // High nibble first, then low nibble.
            [
                DIGITS[usize::from(byte >> 4)],
                DIGITS[usize::from(byte & 0x0f)],
            ]
        })
        .map(char::from)
        .collect()
}
/// Decodes a hexadecimal string (upper or lower case) into bytes.
///
/// Leading and trailing whitespace is ignored.
///
/// # Errors
/// Fails when the trimmed input has an odd number of bytes or contains a
/// non-hex character.
fn hex_decode(s: &str) -> Result<Vec<u8>> {
    let digits = s.trim().as_bytes();
    if !digits.len().is_multiple_of(2) {
        return Err(SrxError::Config(ConfigError::Invalid(
            "invalid hex replay hmac key: odd length".to_string(),
        )));
    }
    digits
        .chunks_exact(2)
        .map(|pair| -> Result<u8> {
            // Decode the high nibble first so the first bad character wins.
            let hi = decode_hex_nibble(pair[0])?;
            let lo = decode_hex_nibble(pair[1])?;
            Ok((hi << 4) | lo)
        })
        .collect()
}
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// # Errors
/// Returns a configuration error for any other byte.
fn decode_hex_nibble(b: u8) -> Result<u8> {
    // Bytes >= 0x80 map to non-ASCII chars, for which `to_digit` yields `None`.
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
        .ok_or_else(|| {
            SrxError::Config(ConfigError::Invalid(
                "invalid hex replay hmac key: non-hex character".to_string(),
            ))
        })
}
// Unit tests for envelope encoding/decoding, HMAC key providers, the
// merge-and-persist loop and its metrics.
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{
ReplayBackend, ReplayConfig, ReplayHmacKeyProvider, ReplayIntegrityConfig,
};
use std::path::PathBuf;
use std::sync::Arc;
// Builds a file-backed replay configuration with a fixed 32-byte static HMAC
// key and the requested integrity mode. Tests that touch the filesystem
// override `state_file` with a temp path.
fn replay_cfg(mode: ReplayIntegrityMode) -> ReplayConfig {
ReplayConfig {
persist_enabled: true,
backend: ReplayBackend::FileJson,
state_file: PathBuf::from(".srx/test_replay_state.json"),
storage_key: "test".to_string(),
integrity: ReplayIntegrityConfig {
mode,
hmac_key: Some(vec![7u8; 32]),
hmac_key_provider: ReplayHmacKeyProvider::StaticConfig,
},
}
}
// An envelope encoded in HMAC mode decodes back to the same state when the
// session binding matches.
#[test]
fn encode_decode_roundtrip_with_hmac() {
let cfg = replay_cfg(ReplayIntegrityMode::HmacSha256);
let state = ReplayState {
top: 42,
bitmap_words: vec![0u64; 16],
};
let raw = encode_replay_envelope(&cfg, "sess-a".to_string(), state.clone()).unwrap();
let restored = decode_replay_envelope(&cfg, "sess-a", &raw)
.unwrap()
.expect("state");
assert_eq!(restored, state);
}
// A snapshot bound to a different session is ignored (Ok(None)), not an error.
#[test]
fn decode_ignores_session_mismatch() {
let cfg = replay_cfg(ReplayIntegrityMode::ChecksumSha256);
let state = ReplayState {
top: 9,
bitmap_words: vec![0u64; 16],
};
let raw = encode_replay_envelope(&cfg, "sess-a".to_string(), state).unwrap();
let restored = decode_replay_envelope(&cfg, "sess-b", &raw).unwrap();
assert!(restored.is_none());
}
// Persisting a local state on top of an existing stored state keeps the
// higher `top` of the two.
// NOTE(review): local and remote set the same bitmap bit and only `top` is
// asserted, so the bitmap union itself is not verified here.
#[test]
fn merge_and_persist_unions_remote_state() {
let tmp = tempfile::tempdir().unwrap();
let mut cfg = replay_cfg(ReplayIntegrityMode::ChecksumSha256);
cfg.state_file = tmp.path().join("replay_merge_state.json");
let storage = FileReplayStorage;
let remote = ReplayState {
top: 100,
bitmap_words: {
let mut v = vec![0u64; 16];
v[0] |= 1; v
},
};
let remote_raw = encode_replay_envelope(&cfg, "sess-a".to_string(), remote).unwrap();
storage.save_raw(&cfg, &remote_raw).unwrap();
let mut local = ReplayState {
top: 99,
bitmap_words: vec![0u64; 16],
};
local.bitmap_words[0] |= 1 << 0;
merge_and_persist_replay_state(&cfg, &storage, "sess-a", local).unwrap();
let raw = storage.load_raw(&cfg).unwrap().expect("stored");
let merged = decode_replay_envelope(&cfg, "sess-a", &raw)
.unwrap()
.expect("decoded");
assert_eq!(merged.top, 100);
}
// The File provider reads a hex-encoded key from disk ("test-key" in hex)
// and the resulting envelope round-trips.
#[test]
fn hmac_key_from_file_provider_works() {
let tmp = tempfile::NamedTempFile::new().unwrap();
fs::write(tmp.path(), "746573742d6b6579").unwrap();
let mut cfg = replay_cfg(ReplayIntegrityMode::HmacSha256);
cfg.integrity.hmac_key = None;
cfg.integrity.hmac_key_provider = ReplayHmacKeyProvider::File {
path: tmp.path().to_path_buf(),
encoding: ReplayKeyEncoding::Hex,
};
let state = ReplayState {
top: 3,
bitmap_words: vec![0u64; 16],
};
let raw = encode_replay_envelope(&cfg, "sess-z".to_string(), state.clone()).unwrap();
let got = decode_replay_envelope(&cfg, "sess-z", &raw)
.unwrap()
.expect("state");
assert_eq!(got, state);
}
// A provider registered in the global registry is used when the config
// selects it by name.
#[test]
fn custom_provider_resolves_hmac_key() {
struct TestProvider;
impl CustomHmacKeyProvider for TestProvider {
fn resolve_key(&self, _replay: &ReplayConfig) -> Result<Vec<u8>> {
Ok(b"test-key".to_vec())
}
}
register_custom_hmac_key_provider("test-provider", Arc::new(TestProvider));
let mut cfg = replay_cfg(ReplayIntegrityMode::HmacSha256);
cfg.integrity.hmac_key = None;
cfg.integrity.hmac_key_provider = ReplayHmacKeyProvider::Custom {
provider: "test-provider".to_string(),
};
let state = ReplayState {
top: 4,
bitmap_words: vec![0u64; 16],
};
let raw = encode_replay_envelope(&cfg, "sess-custom".to_string(), state.clone()).unwrap();
let got = decode_replay_envelope(&cfg, "sess-custom", &raw)
.unwrap()
.expect("state");
assert_eq!(got, state);
}
// A successful persist bumps the attempt and success counters. The
// assertions use `>=` because the counters are process-wide and other tests
// may update them concurrently.
#[test]
fn merge_and_persist_updates_cas_metrics() {
reset_replay_store_metrics_for_tests();
let tmp = tempfile::tempdir().unwrap();
let mut cfg = replay_cfg(ReplayIntegrityMode::ChecksumSha256);
cfg.state_file = tmp.path().join("replay_metrics_state.json");
let storage = FileReplayStorage;
let local = ReplayState {
top: 5,
bitmap_words: vec![0u64; 16],
};
merge_and_persist_replay_state(&cfg, &storage, "sess-m", local).unwrap();
let snap = replay_store_metrics_snapshot();
assert!(snap.cas_attempts >= 1);
assert!(snap.cas_successes >= 1);
}
}