use super::failover::NodeRole;
/// One record from the local node's divergent tail.
///
/// `lsn` is the record's log sequence number, `term` the term it carries, and
/// `payload` its raw bytes (this module never interprets them).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TailRecord {
    pub lsn: u64,
    pub term: u64,
    pub payload: Vec<u8>,
}

impl TailRecord {
    /// Builds a record, converting `payload` into an owned byte vector.
    pub fn new(lsn: u64, term: u64, payload: impl Into<Vec<u8>>) -> Self {
        let bytes: Vec<u8> = payload.into();
        TailRecord {
            lsn,
            term,
            payload: bytes,
        }
    }
}
/// The part of the local log that diverged from the new primary, captured so it
/// can be persisted before being rolled back.
///
/// As built by the coordinator, it covers the LSN range
/// `(common_point_lsn, to_lsn]`, and `records` holds whatever the transport
/// returned for that range.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DivergentTail {
    pub common_point_lsn: u64,
    pub to_lsn: u64,
    pub records: Vec<TailRecord>,
}

impl DivergentTail {
    /// Width of the LSN range, `to_lsn - common_point_lsn`, clamped at zero.
    ///
    /// This counts LSNs, not records: a sparse tail may hold fewer records
    /// than its span.
    pub fn span_lsns(&self) -> u64 {
        // Raising `to_lsn` to at least the common point makes the difference
        // non-negative without a wrapping subtraction.
        self.to_lsn.max(self.common_point_lsn) - self.common_point_lsn
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RollbackPlan {
pub recover_to_lsn: u64,
pub local_frontier: u64,
pub commit_watermark: u64,
pub tail_lsns: u64,
}
impl RollbackPlan {
pub fn compute(req: &RollbackRequest) -> Result<Self, RollbackError> {
if req.common_point < req.commit_watermark {
return Err(RollbackError::WatermarkViolation {
common_point: req.common_point,
commit_watermark: req.commit_watermark,
});
}
Ok(Self {
recover_to_lsn: req.common_point,
local_frontier: req.local_frontier,
commit_watermark: req.commit_watermark,
tail_lsns: req.local_frontier.saturating_sub(req.common_point),
})
}
pub fn has_divergent_tail(&self) -> bool {
self.local_frontier > self.recover_to_lsn
}
}
/// Input to an auto-rollback: where the local log ends, the point it is
/// rewound to, the commit floor that must not be crossed, and the primary and
/// term to rejoin under.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RollbackRequest {
/// Local log frontier; the divergent tail is read up to and including this LSN.
pub local_frontier: u64,
/// LSN treated as shared with the new primary; it becomes the recover-to target.
pub common_point: u64,
/// Inclusive floor for the recover target: a `common_point` below it is refused
/// because recovering there would roll back committed data.
pub commit_watermark: u64,
/// Primary address passed to `rejoin_as_replica` and stored in the resulting role.
pub new_primary_addr: String,
/// Term passed to `rejoin_as_replica` and recorded in the rollback event.
pub new_term: u64,
}
/// Notice emitted once a divergent tail has been persisted and the node has
/// been recovered to the common point.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RollbackEvent {
/// Recover-to target (the common point).
pub common_point_lsn: u64,
/// Local frontier at rollback time: the last LSN of the discarded tail.
pub tail_to_lsn: u64,
/// Width of the discarded LSN range (`tail_to_lsn - common_point_lsn`).
pub tail_lsns: u64,
/// Commit watermark the common point was checked against.
pub commit_watermark: u64,
/// Identifier returned by `persist_rollback_file` for the preserved tail.
pub rollback_file: String,
/// Primary the node rejoins as a replica.
pub new_primary_addr: String,
/// Term the node rejoins under.
pub new_term: u64,
}
/// Result of a successful [`RollbackCoordinator::run`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RollbackOutcome {
    /// The plan's recover-to target (the common point). It is reported even on a
    /// clean rejoin, where no recovery call was made.
    pub recovered_to_lsn: u64,
    /// Width of the rolled-back LSN range; zero on a clean rejoin.
    pub tail_lsns: u64,
    /// Identifier of the persisted rollback file, if a tail was preserved.
    pub rollback_file: Option<String>,
    /// Whether a [`RollbackEvent`] was emitted.
    pub event_fired: bool,
    /// Role the node holds afterwards: replica of the new primary.
    pub role: NodeRole,
}

impl RollbackOutcome {
    /// `true` when a non-empty LSN range was rolled back.
    pub fn rolled_back_tail(&self) -> bool {
        self.tail_lsns != 0
    }
}
/// Reasons an auto-rollback stops before the node rejoins as a replica.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RollbackError {
    /// The common point lies below the commit watermark. Recovering to it would
    /// roll back committed data, so the request is refused up front.
    WatermarkViolation {
        common_point: u64,
        commit_watermark: u64,
    },
    /// The divergent tail could not be written to a rollback file, so recovery
    /// was not attempted.
    TailPersistFailed { reason: String },
    /// The tail was persisted, but recovering to `target_lsn` failed.
    RecoverFailed { target_lsn: u64, reason: String },
}

impl std::fmt::Display for RollbackError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The trailing `\` continues the literal and drops the next line's
        // leading whitespace, so each message renders on a single line.
        match self {
            Self::WatermarkViolation {
                common_point: cp,
                commit_watermark: wm,
            } => write!(
                f,
                "auto-rollback refused: common point {cp} is below the commit watermark \
                 {wm}; recovering to it would roll back committed data"
            ),
            Self::TailPersistFailed { reason: why } => write!(
                f,
                "auto-rollback aborted: could not persist divergent tail to a rollback file \
                 ({why}); nothing was rolled back"
            ),
            Self::RecoverFailed {
                target_lsn: lsn,
                reason: why,
            } => write!(
                f,
                "auto-rollback failed: recover-to-LSN {lsn} over the MVCC history store \
                 failed ({why}); the divergent tail was preserved but the timeline was not \
                 rolled back"
            ),
        }
    }
}

impl std::error::Error for RollbackError {}
/// Side-effecting operations the rollback coordinator needs from the node.
///
/// `RollbackCoordinator::run` calls them in the order read -> persist ->
/// recover -> emit -> rejoin, and stops at the first failure.
pub trait RollbackTransport {
/// Returns the local records in `(from_exclusive, to_inclusive]`, per the
/// parameter names. This call cannot report failure.
fn read_divergent_tail(&mut self, from_exclusive: u64, to_inclusive: u64) -> Vec<TailRecord>;
/// Writes `tail` to a rollback file and returns its identifier (the test fake
/// returns a path), or an error reason. On `Err` the coordinator aborts
/// without recovering.
fn persist_rollback_file(&mut self, tail: &DivergentTail) -> Result<String, String>;
/// Rewinds the local timeline to `target_lsn`, or returns an error reason.
/// The coordinator's error text describes this as running over the MVCC
/// history store.
fn recover_to_lsn(&mut self, target_lsn: u64) -> Result<(), String>;
/// Publishes the rollback event; called only after a successful recovery.
fn emit_rollback_event(&mut self, event: RollbackEvent);
/// Final step: rejoin `primary_addr` as a replica under `term`.
fn rejoin_as_replica(&mut self, primary_addr: &str, term: u64);
}
pub struct RollbackCoordinator;
impl RollbackCoordinator {
pub fn run(
req: &RollbackRequest,
tx: &mut dyn RollbackTransport,
) -> Result<RollbackOutcome, RollbackError> {
let plan = RollbackPlan::compute(req)?;
let role = NodeRole::Replica {
primary_addr: req.new_primary_addr.clone(),
term: req.new_term,
};
if !plan.has_divergent_tail() {
tx.rejoin_as_replica(&req.new_primary_addr, req.new_term);
return Ok(RollbackOutcome {
recovered_to_lsn: plan.recover_to_lsn,
tail_lsns: 0,
rollback_file: None,
event_fired: false,
role,
});
}
let records = tx.read_divergent_tail(plan.recover_to_lsn, plan.local_frontier);
let tail = DivergentTail {
common_point_lsn: plan.recover_to_lsn,
to_lsn: plan.local_frontier,
records,
};
let rollback_file = tx
.persist_rollback_file(&tail)
.map_err(|reason| RollbackError::TailPersistFailed { reason })?;
tx.recover_to_lsn(plan.recover_to_lsn)
.map_err(|reason| RollbackError::RecoverFailed {
target_lsn: plan.recover_to_lsn,
reason,
})?;
tx.emit_rollback_event(RollbackEvent {
common_point_lsn: plan.recover_to_lsn,
tail_to_lsn: plan.local_frontier,
tail_lsns: plan.tail_lsns,
commit_watermark: plan.commit_watermark,
rollback_file: rollback_file.clone(),
new_primary_addr: req.new_primary_addr.clone(),
new_term: req.new_term,
});
tx.rejoin_as_replica(&req.new_primary_addr, req.new_term);
Ok(RollbackOutcome {
recovered_to_lsn: plan.recover_to_lsn,
tail_lsns: plan.tail_lsns,
rollback_file: Some(rollback_file),
event_fired: true,
role,
})
}
}
// Unit tests for the plan arithmetic and for the coordinator's call ordering
// and error paths, driven by an in-memory `FakeTransport`.
#[cfg(test)]
mod tests {
use super::*;
/// In-memory `RollbackTransport` that records every call in `order` and can be
/// told to fail persisting or recovering.
struct FakeTransport {
/// Records that `read_divergent_tail` filters by LSN range.
available_tail: Vec<TailRecord>,
/// When true, `persist_rollback_file` returns `Err("disk full")`.
persist_should_fail: bool,
/// When true, `recover_to_lsn` returns `Err("history truncated")`.
recover_should_fail: bool,
/// Tail handed to a successful `persist_rollback_file`.
persisted: Option<DivergentTail>,
/// Target of a successful `recover_to_lsn`.
recovered_to: Option<u64>,
/// Event passed to `emit_rollback_event`.
emitted: Option<RollbackEvent>,
/// `(primary_addr, term)` passed to `rejoin_as_replica`.
rejoined: Option<(String, u64)>,
/// Names of the transport calls, in the order they were made.
order: Vec<&'static str>,
}
impl FakeTransport {
/// Builds a transport that succeeds at every step and serves `available_tail`.
fn new(available_tail: Vec<TailRecord>) -> Self {
Self {
available_tail,
persist_should_fail: false,
recover_should_fail: false,
persisted: None,
recovered_to: None,
emitted: None,
rejoined: None,
order: Vec::new(),
}
}
}
impl RollbackTransport for FakeTransport {
fn read_divergent_tail(
&mut self,
from_exclusive: u64,
to_inclusive: u64,
) -> Vec<TailRecord> {
self.order.push("read");
// Lower bound exclusive, upper bound inclusive, matching the trait's
// parameter names.
self.available_tail
.iter()
.filter(|r| r.lsn > from_exclusive && r.lsn <= to_inclusive)
.cloned()
.collect()
}
fn persist_rollback_file(&mut self, tail: &DivergentTail) -> Result<String, String> {
self.order.push("persist");
if self.persist_should_fail {
return Err("disk full".to_string());
}
self.persisted = Some(tail.clone());
// Deterministic path derived from the tail bounds so tests can assert on it.
Ok(format!(
"/data/rollback/lsn-{}-{}.rbk",
tail.common_point_lsn, tail.to_lsn
))
}
fn recover_to_lsn(&mut self, target_lsn: u64) -> Result<(), String> {
self.order.push("recover");
if self.recover_should_fail {
return Err("history truncated".to_string());
}
self.recovered_to = Some(target_lsn);
Ok(())
}
fn emit_rollback_event(&mut self, event: RollbackEvent) {
self.order.push("emit");
self.emitted = Some(event);
}
fn rejoin_as_replica(&mut self, primary_addr: &str, term: u64) {
self.order.push("rejoin");
self.rejoined = Some((primary_addr.to_string(), term));
}
}
/// Request fixture: fixed new primary `http://node-b:50051` at term 8.
fn request(local_frontier: u64, common_point: u64, watermark: u64) -> RollbackRequest {
RollbackRequest {
local_frontier,
common_point,
commit_watermark: watermark,
new_primary_addr: "http://node-b:50051".to_string(),
new_term: 8,
}
}
/// Builds one record per LSN under `term`; the payload is the LSN truncated to a byte.
fn tail(lsns: &[u64], term: u64) -> Vec<TailRecord> {
lsns.iter()
.map(|lsn| TailRecord::new(*lsn, term, vec![*lsn as u8]))
.collect()
}
// Frontier 230, common point 200: the tail spans 30 LSNs.
#[test]
fn plan_recovers_to_common_point_and_sizes_the_tail() {
let plan = RollbackPlan::compute(&request(230, 200, 200)).expect("valid plan");
assert_eq!(
plan.recover_to_lsn, 200,
"recover target is the common point"
);
assert_eq!(plan.tail_lsns, 30, "tail spans common_point..frontier");
assert!(plan.has_divergent_tail());
}
#[test]
fn plan_with_common_point_above_watermark_is_allowed() {
let plan = RollbackPlan::compute(&request(300, 250, 200)).expect("valid plan");
assert_eq!(plan.recover_to_lsn, 250);
assert_eq!(plan.tail_lsns, 50);
}
#[test]
fn plan_refuses_common_point_below_watermark() {
let err = RollbackPlan::compute(&request(300, 150, 200)).expect_err("must refuse");
assert_eq!(
err,
RollbackError::WatermarkViolation {
common_point: 150,
commit_watermark: 200,
}
);
}
// A common point exactly at the watermark is accepted (inclusive floor).
#[test]
fn plan_at_watermark_is_the_inclusive_floor() {
let plan = RollbackPlan::compute(&request(220, 200, 200)).expect("valid at floor");
assert_eq!(plan.recover_to_lsn, 200);
assert_eq!(plan.tail_lsns, 20);
}
// Happy path: checks every side effect and the exact call order.
#[test]
fn run_preserves_tail_then_recovers_then_emits_then_rejoins() {
let mut tx = FakeTransport::new(tail(&[201, 210, 230], 7));
let outcome =
RollbackCoordinator::run(&request(230, 200, 200), &mut tx).expect("rollback succeeds");
assert_eq!(outcome.recovered_to_lsn, 200);
assert_eq!(outcome.tail_lsns, 30);
assert!(outcome.rolled_back_tail());
let persisted = tx.persisted.as_ref().expect("tail persisted");
assert_eq!(persisted.common_point_lsn, 200);
assert_eq!(persisted.to_lsn, 230);
assert_eq!(persisted.records, tail(&[201, 210, 230], 7));
assert_eq!(
outcome.rollback_file.as_deref(),
Some("/data/rollback/lsn-200-230.rbk")
);
assert_eq!(tx.recovered_to, Some(200));
assert!(outcome.event_fired);
let ev = tx.emitted.as_ref().expect("event emitted");
assert_eq!(ev.common_point_lsn, 200);
assert_eq!(ev.tail_to_lsn, 230);
assert_eq!(ev.tail_lsns, 30);
assert_eq!(ev.commit_watermark, 200);
assert_eq!(ev.rollback_file, "/data/rollback/lsn-200-230.rbk");
assert_eq!(ev.new_term, 8);
assert_eq!(tx.rejoined, Some(("http://node-b:50051".to_string(), 8)));
assert_eq!(
outcome.role,
NodeRole::Replica {
primary_addr: "http://node-b:50051".to_string(),
term: 8,
}
);
assert_eq!(
tx.order,
vec!["read", "persist", "recover", "emit", "rejoin"]
);
}
// Frontier equals the common point: nothing to discard, so only rejoin is called.
#[test]
fn run_with_no_tail_just_rejoins() {
let mut tx = FakeTransport::new(vec![]);
let outcome =
RollbackCoordinator::run(&request(200, 200, 200), &mut tx).expect("clean rejoin");
assert_eq!(outcome.tail_lsns, 0);
assert!(!outcome.rolled_back_tail());
assert!(!outcome.event_fired, "no event when nothing is discarded");
assert_eq!(outcome.rollback_file, None);
assert!(tx.persisted.is_none(), "nothing persisted");
assert!(tx.recovered_to.is_none(), "nothing recovered");
assert!(tx.emitted.is_none(), "no operator event");
assert_eq!(tx.rejoined, Some(("http://node-b:50051".to_string(), 8)));
assert_eq!(tx.order, vec!["rejoin"]);
}
// Frontier below the common point is also a clean rejoin; the outcome still
// reports the common point as `recovered_to_lsn`.
#[test]
fn run_with_frontier_below_common_point_is_a_clean_rejoin() {
let mut tx = FakeTransport::new(vec![]);
let outcome =
RollbackCoordinator::run(&request(180, 200, 150), &mut tx).expect("clean rejoin");
assert_eq!(outcome.recovered_to_lsn, 200);
assert_eq!(outcome.tail_lsns, 0);
assert!(!outcome.event_fired);
assert_eq!(tx.order, vec!["rejoin"]);
}
// The watermark check runs before any transport call.
#[test]
fn run_refuses_when_common_point_below_watermark_and_touches_nothing() {
let mut tx = FakeTransport::new(tail(&[160, 200, 300], 7));
let err = RollbackCoordinator::run(&request(300, 150, 200), &mut tx)
.expect_err("must refuse to cross the watermark");
assert!(matches!(err, RollbackError::WatermarkViolation { .. }));
assert!(tx.persisted.is_none());
assert!(tx.recovered_to.is_none());
assert!(tx.emitted.is_none());
assert!(tx.rejoined.is_none());
assert!(tx.order.is_empty());
}
// A persist failure must stop before recover, so no data is discarded.
#[test]
fn run_aborts_without_recovering_when_tail_cannot_be_persisted() {
let mut tx = FakeTransport::new(tail(&[210, 230], 7));
tx.persist_should_fail = true;
let err = RollbackCoordinator::run(&request(230, 200, 200), &mut tx)
.expect_err("must abort when persist fails");
assert!(matches!(err, RollbackError::TailPersistFailed { .. }));
assert!(tx.recovered_to.is_none(), "must not roll back the timeline");
assert!(tx.emitted.is_none());
assert!(tx.rejoined.is_none());
assert_eq!(tx.order, vec!["read", "persist"]);
}
// A recover failure surfaces after the tail is persisted; no event and no rejoin follow.
#[test]
fn run_surfaces_recover_failure_after_preserving_the_tail() {
let mut tx = FakeTransport::new(tail(&[210, 230], 7));
tx.recover_should_fail = true;
let err = RollbackCoordinator::run(&request(230, 200, 200), &mut tx)
.expect_err("recover failure surfaces");
match err {
RollbackError::RecoverFailed { target_lsn, .. } => assert_eq!(target_lsn, 200),
other => panic!("expected RecoverFailed, got {other:?}"),
}
assert!(tx.persisted.is_some(), "tail preserved before recover");
assert!(
tx.emitted.is_none(),
"no completion event on failed recover"
);
assert!(
tx.rejoined.is_none(),
"must not rejoin after a failed recover"
);
assert_eq!(tx.order, vec!["read", "persist", "recover"]);
}
// The span counts LSNs (30), not records (2).
#[test]
fn span_lsns_counts_the_removed_range() {
let t = DivergentTail {
common_point_lsn: 200,
to_lsn: 230,
records: tail(&[210, 230], 7),
};
assert_eq!(t.span_lsns(), 30);
}
}