use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use super::{CertificationConfig, CheckResult, ConsistencyResult};
#[derive(Clone)]
struct SimStore {
inner: Arc<Mutex<HashMap<String, String>>>,
}
impl SimStore {
fn new() -> Self {
SimStore {
inner: Arc::new(Mutex::new(HashMap::new())),
}
}
fn write(&self, key: &str, value: &str) -> Result<(), String> {
let mut guard = self
.inner
.lock()
.map_err(|e| format!("lock poisoned: {e}"))?;
guard.insert(key.to_string(), value.to_string());
Ok(())
}
fn read(&self, key: &str) -> Result<Option<String>, String> {
let guard = self
.inner
.lock()
.map_err(|e| format!("lock poisoned: {e}"))?;
Ok(guard.get(key).cloned())
}
}
pub fn certify(config: &CertificationConfig) -> ConsistencyResult {
let checks: Vec<CheckResult> = vec![
check_read_your_writes(config),
check_linearizability_probe(config),
check_convergence(config),
check_monotonic_read(config),
];
let passed = checks.iter().all(|c| c.passed);
let notes = if passed {
format!(
"All {} consistency checks passed ({} nodes, {} epochs).",
checks.len(),
config.node_count,
config.epochs
)
} else {
let failed: Vec<&str> = checks
.iter()
.filter(|c| !c.passed)
.map(|c| c.name.as_str())
.collect();
format!("Consistency violations: {}", failed.join(", "))
};
ConsistencyResult {
passed,
checks,
notes,
}
}
fn check_read_your_writes(config: &CertificationConfig) -> CheckResult {
let store = SimStore::new();
let iters = config.epochs.min(100);
let mut violation: Option<String> = None;
for i in 0..iters {
let key = format!("ryw-key-{i}");
let val = format!("ryw-val-{i}");
if let Err(e) = store.write(&key, &val) {
violation = Some(format!("Write failed: {e}"));
break;
}
match store.read(&key) {
Err(e) => {
violation = Some(format!("Read failed: {e}"));
break;
}
Ok(None) => {
violation = Some(format!("Key '{key}' not found after write"));
break;
}
Ok(Some(got)) if got != val => {
violation = Some(format!(
"Read-your-write violation at '{key}': wrote '{val}', read '{got}'"
));
break;
}
Ok(Some(_)) => {}
}
}
match violation {
Some(v) => CheckResult {
name: "read_your_writes".to_string(),
passed: false,
detail: v,
},
None => CheckResult {
name: "read_your_writes".to_string(),
passed: true,
detail: format!("Read-your-writes verified over {iters} key-value pairs"),
},
}
}
fn check_linearizability_probe(config: &CertificationConfig) -> CheckResult {
let n = config.node_count.max(2);
let root_store = SimStore::new();
let nodes: Vec<SimStore> = (0..n).map(|_| root_store.clone()).collect();
let iters = config.epochs.min(50);
let mut violation: Option<String> = None;
for i in 0..iters {
let key = format!("lin-key-{i}");
let val = format!("lin-val-{i}");
if let Err(e) = nodes[0].write(&key, &val) {
violation = Some(format!("Write to node 0 failed: {e}"));
break;
}
let reader_idx = 1 % n;
match nodes[reader_idx].read(&key) {
Err(e) => {
violation = Some(format!("Read from node {reader_idx} failed: {e}"));
break;
}
Ok(None) => {
violation = Some(format!(
"Linearizability violation: node {reader_idx} did not see key '{key}'"
));
break;
}
Ok(Some(got)) if got != val => {
violation = Some(format!(
"Linearizability violation at '{key}': expected '{val}', got '{got}'"
));
break;
}
Ok(Some(_)) => {}
}
}
match violation {
Some(v) => CheckResult {
name: "linearizability_probe".to_string(),
passed: false,
detail: v,
},
None => CheckResult {
name: "linearizability_probe".to_string(),
passed: true,
detail: format!("Linearizability probe: {iters} cross-node reads verified ({n} nodes)"),
},
}
}
fn check_convergence(config: &CertificationConfig) -> CheckResult {
let n = config.node_count.max(2);
let root_store = SimStore::new();
let nodes: Vec<SimStore> = (0..n).map(|_| root_store.clone()).collect();
let key = "convergence-key";
let mut last_written = String::new();
for i in 0..n {
let val = format!("version-{i}");
if let Err(e) = nodes[i].write(key, &val) {
return CheckResult {
name: "convergence".to_string(),
passed: false,
detail: format!("Write from node {i} failed: {e}"),
};
}
last_written = val;
}
let mut violation: Option<String> = None;
for (i, node) in nodes.iter().enumerate() {
match node.read(key) {
Err(e) => {
violation = Some(format!("Read from node {i} failed: {e}"));
break;
}
Ok(None) => {
violation = Some(format!("Node {i} lost the key after convergence"));
break;
}
Ok(Some(got)) if got != last_written => {
violation = Some(format!(
"Convergence failure at node {i}: expected '{last_written}', got '{got}'"
));
break;
}
Ok(Some(_)) => {}
}
}
match violation {
Some(v) => CheckResult {
name: "convergence".to_string(),
passed: false,
detail: v,
},
None => CheckResult {
name: "convergence".to_string(),
passed: true,
detail: format!(
"Convergence: {n} nodes agree on last-writer-wins value '{last_written}'"
),
},
}
}
fn check_monotonic_read(config: &CertificationConfig) -> CheckResult {
let store = SimStore::new();
let iters = config.epochs.min(50);
let key = "monotonic-key";
let mut last_version: i64 = -1;
let mut violation: Option<String> = None;
for i in 0..iters {
let val = format!("{i}");
if let Err(e) = store.write(key, &val) {
violation = Some(format!("Write failed at epoch {i}: {e}"));
break;
}
match store.read(key) {
Err(e) => {
violation = Some(format!("Read failed at epoch {i}: {e}"));
break;
}
Ok(None) => {
violation = Some(format!("Key missing at epoch {i}"));
break;
}
Ok(Some(got)) => {
let version: i64 = got.parse().unwrap_or(-1);
if version < last_version {
violation = Some(format!(
"Monotonic read violation: saw version {last_version} then {version}"
));
break;
}
last_version = version;
}
}
}
match violation {
Some(v) => CheckResult {
name: "monotonic_read".to_string(),
passed: false,
detail: v,
},
None => CheckResult {
name: "monotonic_read".to_string(),
passed: true,
detail: format!("Monotonic read: versions 0..{iters} read in non-decreasing order"),
},
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::certification::CertificationConfig;
fn default_config() -> CertificationConfig {
CertificationConfig::default()
}
#[test]
fn test_consistency_certify_passes_with_defaults() {
let cfg = default_config();
let result = certify(&cfg);
assert!(
result.passed,
"Consistency certification should pass: {:?}",
result.notes
);
}
#[test]
fn test_read_your_writes_passes() {
let cfg = default_config();
let check = check_read_your_writes(&cfg);
assert!(check.passed, "{}", check.detail);
}
#[test]
fn test_linearizability_probe_passes() {
let cfg = default_config();
let check = check_linearizability_probe(&cfg);
assert!(check.passed, "{}", check.detail);
}
#[test]
fn test_convergence_passes() {
let cfg = default_config();
let check = check_convergence(&cfg);
assert!(check.passed, "{}", check.detail);
}
#[test]
fn test_monotonic_read_passes() {
let cfg = default_config();
let check = check_monotonic_read(&cfg);
assert!(check.passed, "{}", check.detail);
}
#[test]
fn test_sim_store_basic() {
let store = SimStore::new();
store.write("k", "v").expect("write ok");
let got = store.read("k").expect("read ok");
assert_eq!(got.as_deref(), Some("v"));
}
#[test]
fn test_sim_store_missing_key() {
let store = SimStore::new();
let got = store.read("missing").expect("read ok");
assert!(got.is_none());
}
#[test]
fn test_consistency_result_has_checks() {
let cfg = default_config();
let result = certify(&cfg);
assert!(result.checks.len() >= 3);
}
}