use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use freenet_stdlib::prelude::ContractInstanceId;
use tokio::time::Instant;
use crate::contract::storages::Storage;
use crate::util::time_source::TimeSource;
/// How long a recorded broken-invariant flag stays active before it self-expires
/// (five minutes), so a false positive heals on its own.
pub(crate) const BROKEN_INVARIANT_TTL: Duration = Duration::from_secs(5 * 60);
/// Kinds of contract invariant violation the tracker can flag.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BrokenInvariant {
    /// Non-idempotent behaviour was detected (meaning assumed from the variant name;
    /// confirm at the detection site).
    NonIdempotent,
}

impl BrokenInvariant {
    /// Persisted byte for [`BrokenInvariant::NonIdempotent`].
    const NON_IDEMPOTENT_BYTE: u8 = 0;

    /// Encodes the kind as the single byte handed to storage.
    ///
    /// These values form a storage format: never renumber an existing variant.
    fn to_byte(self) -> u8 {
        match self {
            Self::NonIdempotent => Self::NON_IDEMPOTENT_BYTE,
        }
    }

    /// Decodes a byte produced by [`Self::to_byte`]; any unknown byte yields `None`.
    fn from_byte(b: u8) -> Option<Self> {
        match b {
            Self::NON_IDEMPOTENT_BYTE => Some(Self::NonIdempotent),
            _ => None,
        }
    }
}
/// In-memory record of one flagged contract: which invariant broke and the timestamp
/// that TTL checks measure from.
#[derive(Clone, Copy, Debug)]
struct FlagEntry {
    /// Which invariant was found broken.
    kind: BrokenInvariant,
    /// Time-source reading taken when the flag was last recorded (or loaded from storage).
    recorded_at: Instant,
}
/// Tracks contracts whose invariants were found broken, with TTL-based expiry and
/// optional persistence through a [`Storage`] attached after construction.
pub(crate) struct BrokenInvariantsTracker {
    /// Flags keyed by contract; expired entries linger until reclaimed.
    flags: Arc<DashMap<ContractInstanceId, FlagEntry>>,
    /// Backing store, settable at most once.
    storage: std::sync::OnceLock<Storage>,
    /// Clock used for every TTL decision; injectable so tests can drive it.
    time_source: Arc<dyn TimeSource + Send + Sync>,
}
impl BrokenInvariantsTracker {
pub fn new(time_source: Arc<dyn TimeSource + Send + Sync>) -> Self {
Self {
flags: Arc::new(DashMap::new()),
storage: std::sync::OnceLock::new(),
time_source,
}
}
pub fn is_broken(&self, id: &ContractInstanceId) -> bool {
let now = self.time_source.now();
match self.flags.get(id) {
Some(entry) => now.saturating_duration_since(entry.recorded_at) < BROKEN_INVARIANT_TTL,
None => false,
}
}
#[cfg(test)]
pub fn get(&self, id: &ContractInstanceId) -> Option<BrokenInvariant> {
let now = self.time_source.now();
self.flags.get(id).and_then(|entry| {
if now.saturating_duration_since(entry.recorded_at) < BROKEN_INVARIANT_TTL {
Some(entry.kind)
} else {
None
}
})
}
pub fn record(&self, id: ContractInstanceId, kind: BrokenInvariant) {
let recorded_at = self.time_source.now();
let was_new = self
.flags
.insert(id, FlagEntry { kind, recorded_at })
.is_none();
if was_new {
tracing::warn!(
contract = %id,
invariant = ?kind,
event = "broken_invariant_detected",
"Marking contract as broken — gating outbound broadcast and merge propagation"
);
#[cfg(feature = "redb")]
if let Some(storage) = self.storage.get() {
if let Err(e) = storage.store_broken_invariant(&id, kind.to_byte()) {
tracing::warn!(
contract = %id,
error = %e,
"Failed to persist broken-invariant flag (in-memory flag still active)"
);
}
}
}
}
#[allow(dead_code)] pub fn clear(&self, id: &ContractInstanceId) -> Option<BrokenInvariant> {
let previous = self.flags.remove(id).map(|(_, v)| v.kind);
if previous.is_some() {
self.remove_from_storage(id);
tracing::warn!(
contract = %id,
event = "broken_invariant_cleared",
"Operator cleared broken-invariant flag — outbound broadcast re-enabled"
);
}
previous
}
pub fn cleanup(&self) {
let now = self.time_source.now();
let candidates: Vec<ContractInstanceId> = self
.flags
.iter()
.filter(|e| {
now.saturating_duration_since(e.value().recorded_at) >= BROKEN_INVARIANT_TTL
})
.map(|e| *e.key())
.collect();
for id in candidates {
let removed = self.flags.remove_if(&id, |_, entry| {
now.saturating_duration_since(entry.recorded_at) >= BROKEN_INVARIANT_TTL
});
if removed.is_some() {
self.remove_from_storage(&id);
}
}
}
fn remove_from_storage(&self, id: &ContractInstanceId) {
#[cfg(feature = "redb")]
if let Some(storage) = self.storage.get() {
if let Err(e) = storage.remove_broken_invariant(id) {
tracing::warn!(
contract = %id,
error = %e,
"Failed to remove persisted broken-invariant flag (in-memory flag already cleared)"
);
}
}
#[cfg(not(feature = "redb"))]
let _ = id;
}
pub fn set_storage(&self, storage: Storage) {
if self.storage.set(storage.clone()).is_err() {
tracing::warn!("BrokenInvariantsTracker storage already set; ignoring re-init");
return;
}
#[cfg(feature = "redb")]
match storage.load_all_broken_invariants() {
Ok(entries) => {
let recorded_at = self.time_source.now();
for (id, byte) in entries {
if let Some(kind) = BrokenInvariant::from_byte(byte) {
self.flags.insert(id, FlagEntry { kind, recorded_at });
} else {
tracing::warn!(
contract = %id,
byte,
"Skipping unknown broken-invariant byte on load"
);
}
}
tracing::debug!(
count = self.flags.len(),
"Loaded broken-invariant flags from storage"
);
}
Err(e) => {
tracing::warn!(error = %e, "Failed to load broken-invariant flags from storage");
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::time_source::SharedMockTimeSource;

    /// Builds a deterministic contract id whose first byte is `seed`.
    fn fake_id(seed: u8) -> ContractInstanceId {
        let mut bytes = [0u8; 32];
        bytes[0] = seed;
        ContractInstanceId::new(bytes)
    }

    /// Fresh tracker plus a handle to the mock clock that drives it.
    fn mk_tracker() -> (BrokenInvariantsTracker, SharedMockTimeSource) {
        let ts = SharedMockTimeSource::new();
        (BrokenInvariantsTracker::new(Arc::new(ts.clone())), ts)
    }

    #[test]
    fn record_then_query_returns_true() {
        let (t, _ts) = mk_tracker();
        let id = fake_id(1);
        assert!(!t.is_broken(&id));
        t.record(id, BrokenInvariant::NonIdempotent);
        assert!(t.is_broken(&id));
        assert_eq!(t.get(&id), Some(BrokenInvariant::NonIdempotent));
    }

    #[test]
    fn record_is_idempotent() {
        let (t, _ts) = mk_tracker();
        let id = fake_id(2);
        t.record(id, BrokenInvariant::NonIdempotent);
        t.record(id, BrokenInvariant::NonIdempotent);
        assert!(t.is_broken(&id));
        assert_eq!(t.flags.len(), 1, "re-recording must not duplicate the entry");
        assert_eq!(t.get(&id), Some(BrokenInvariant::NonIdempotent));
    }

    #[test]
    fn unrelated_contracts_unaffected() {
        let (t, _ts) = mk_tracker();
        let broken = fake_id(3);
        let healthy = fake_id(4);
        t.record(broken, BrokenInvariant::NonIdempotent);
        assert!(t.is_broken(&broken));
        assert!(!t.is_broken(&healthy));
    }

    #[test]
    fn clear_returns_previous_and_unsets() {
        let (t, _ts) = mk_tracker();
        let id = fake_id(5);
        assert_eq!(t.clear(&id), None);
        t.record(id, BrokenInvariant::NonIdempotent);
        assert!(t.is_broken(&id));
        let prev = t.clear(&id);
        assert_eq!(prev, Some(BrokenInvariant::NonIdempotent));
        assert!(
            !t.is_broken(&id),
            "after clear the contract is no longer broken"
        );
        assert_eq!(t.clear(&id), None);
    }

    #[test]
    fn flag_expires_after_ttl() {
        let (t, ts) = mk_tracker();
        let id = fake_id(6);
        t.record(id, BrokenInvariant::NonIdempotent);
        assert!(t.is_broken(&id), "freshly recorded flag is active");
        ts.advance_time(BROKEN_INVARIANT_TTL - Duration::from_secs(1));
        assert!(t.is_broken(&id), "flag still active just before TTL");
        ts.advance_time(Duration::from_secs(2));
        assert!(
            !t.is_broken(&id),
            "flag must expire after TTL so a false positive self-heals"
        );
        assert_eq!(t.get(&id), None, "expired flag reports no kind");
    }

    /// The TTL boundary is exclusive, and `cleanup` must agree with `is_broken` on it.
    #[test]
    fn flag_expires_exactly_at_ttl() {
        let (t, ts) = mk_tracker();
        let id = fake_id(9);
        t.record(id, BrokenInvariant::NonIdempotent);
        ts.advance_time(BROKEN_INVARIANT_TTL);
        assert!(
            !t.is_broken(&id),
            "a flag exactly TTL old is already expired"
        );
        t.cleanup();
        assert_eq!(
            t.flags.len(),
            0,
            "cleanup reclaims at the same boundary is_broken uses"
        );
    }

    #[test]
    fn cleanup_reclaims_expired_entries() {
        let (t, ts) = mk_tracker();
        let id = fake_id(7);
        t.record(id, BrokenInvariant::NonIdempotent);
        assert_eq!(t.flags.len(), 1);
        ts.advance_time(BROKEN_INVARIANT_TTL / 2);
        t.cleanup();
        assert_eq!(t.flags.len(), 1, "non-expired entry retained by cleanup");
        ts.advance_time(BROKEN_INVARIANT_TTL);
        t.cleanup();
        assert_eq!(t.flags.len(), 0, "expired entry reclaimed by cleanup");
    }

    /// An entry refreshed by `record` must not be reclaimed on its original schedule.
    #[test]
    fn cleanup_keeps_refreshed_entry() {
        let (t, ts) = mk_tracker();
        let id = fake_id(10);
        t.record(id, BrokenInvariant::NonIdempotent);
        ts.advance_time(BROKEN_INVARIANT_TTL - Duration::from_secs(1));
        t.record(id, BrokenInvariant::NonIdempotent);
        ts.advance_time(Duration::from_secs(2));
        t.cleanup();
        assert_eq!(t.flags.len(), 1, "refreshed entry must survive cleanup");
        assert!(t.is_broken(&id));
    }

    #[test]
    fn record_refreshes_ttl() {
        let (t, ts) = mk_tracker();
        let id = fake_id(8);
        t.record(id, BrokenInvariant::NonIdempotent);
        ts.advance_time(BROKEN_INVARIANT_TTL - Duration::from_secs(1));
        t.record(id, BrokenInvariant::NonIdempotent);
        ts.advance_time(Duration::from_secs(2));
        assert!(
            t.is_broken(&id),
            "re-recording must refresh the expiry window"
        );
    }

    /// Re-detecting a breakage after expiry (before cleanup ran) reactivates the flag.
    #[test]
    fn record_after_expiry_reactivates() {
        let (t, ts) = mk_tracker();
        let id = fake_id(11);
        t.record(id, BrokenInvariant::NonIdempotent);
        ts.advance_time(BROKEN_INVARIANT_TTL + Duration::from_secs(1));
        assert!(!t.is_broken(&id), "flag expired but not yet reclaimed");
        t.record(id, BrokenInvariant::NonIdempotent);
        assert!(t.is_broken(&id), "re-recording an expired flag reactivates it");
        assert_eq!(t.flags.len(), 1);
    }

    #[test]
    fn byte_roundtrip_stable() {
        let kinds: &[BrokenInvariant] = &[BrokenInvariant::NonIdempotent];
        for kind in kinds {
            assert_eq!(BrokenInvariant::from_byte(kind.to_byte()), Some(*kind));
        }
        assert_eq!(BrokenInvariant::from_byte(255), None);
    }
}