use std::collections::Bound;
use std::fmt::Debug;
use std::ops::RangeBounds;
use async_trait::async_trait;
use crate::raft::Entry;
use crate::raft_types::LogIdOptionExt;
use crate::AppData;
use crate::AppDataResponse;
use crate::DefensiveError;
use crate::ErrorSubject;
use crate::HardState;
use crate::LogId;
use crate::RaftStorage;
use crate::StorageError;
use crate::Violation;
use crate::Wrapper;
/// Opt-in sanity checks for a [`RaftStorage`] implementation.
///
/// Implemented by storage wrappers (see [`Wrapper`]) that inspect the inner store through
/// `self.inner()` to verify invariants of log appends, purges, applies and range reads.
/// Every check returns `Ok(())` without doing anything unless
/// [`is_defensive`](DefensiveCheck::is_defensive) returns `true`.
///
/// Violations are reported as a [`DefensiveError`] converted into a [`StorageError`];
/// errors from reading the inner store are propagated unchanged.
#[async_trait]
pub trait DefensiveCheck<D, R, T>
where
    D: AppData,
    R: AppDataResponse,
    T: RaftStorage<D, R>,
    Self: Wrapper<D, R, T>,
{
    /// Turns the defensive checks on (`true`) or off (`false`).
    fn set_defensive(&self, v: bool);

    /// Returns whether the defensive checks are currently enabled.
    fn is_defensive(&self) -> bool;

    /// Checks that the log holds no entry that conflicts with the applied state.
    ///
    /// Fails with `Violation::DirtyLog` when the last log id has a higher index than the last
    /// applied log id yet compares lower than it under `LogId`'s ordering.
    async fn defensive_no_dirty_log(&self) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        let last_log_id = self.inner().get_log_state().await?.last_log_id;
        let (last_applied, _) = self.inner().last_applied_state().await?;

        // A dirty log needs both ids to exist; matching on them avoids the `unwrap()`s.
        if let (Some(last), Some(applied)) = (last_log_id, last_applied) {
            if last.index > applied.index && last < applied {
                return Err(DefensiveError::new(ErrorSubject::Log(last), Violation::DirtyLog {
                    higher_index_log_id: last,
                    lower_index_log_id: applied,
                })
                .into());
            }
        }
        Ok(())
    }

    /// Checks that saving `hs` does not move the stored hard state backwards.
    ///
    /// A missing stored hard state is treated as `HardState::default()`. Fails with
    /// `Violation::TermNotAscending` if `hs.current_term` is below the stored term, and with
    /// `Violation::VotedForChanged` if the terms are equal, a vote is already recorded and
    /// `hs.voted_for` differs from it.
    async fn defensive_incremental_hard_state(&self, hs: &HardState) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        let curr = self.inner().read_hard_state().await?.unwrap_or_default();

        if hs.current_term < curr.current_term {
            return Err(DefensiveError::new(ErrorSubject::HardState, Violation::TermNotAscending {
                curr: curr.current_term,
                to: hs.current_term,
            })
            .into());
        }

        if hs.current_term == curr.current_term && curr.voted_for.is_some() && hs.voted_for != curr.voted_for {
            return Err(DefensiveError::new(ErrorSubject::HardState, Violation::VotedForChanged {
                curr,
                to: hs.clone(),
            })
            .into());
        }
        Ok(())
    }

    /// Checks that each entry's index is exactly one more than the previous entry's index.
    ///
    /// Empty and single-entry input pass. Only indexes are compared. An entry that would have to
    /// follow index `u64::MAX` is reported as non-consecutive instead of overflowing.
    async fn defensive_consecutive_input(&self, entries: &[&Entry<D>]) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        for pair in entries.windows(2) {
            let (prev, next) = (pair[0].log_id, pair[1].log_id);
            if prev.index.checked_add(1) != Some(next.index) {
                return Err(DefensiveError::new(ErrorSubject::Logs, Violation::LogsNonConsecutive {
                    prev: Some(prev),
                    next,
                })
                .into());
            }
        }
        Ok(())
    }

    /// Checks that `entries` is not empty; fails with `Violation::LogsEmpty` otherwise.
    async fn defensive_nonempty_input(&self, entries: &[&Entry<D>]) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        if entries.is_empty() {
            return Err(DefensiveError::new(ErrorSubject::Logs, Violation::LogsEmpty).into());
        }
        Ok(())
    }

    /// Checks that the first entry's index equals the next index after the store's last log id
    /// (as computed by `LogIdOptionExt::next_index`).
    ///
    /// Empty `entries` are accepted here; use `defensive_nonempty_input` to reject them.
    async fn defensive_append_log_index_is_last_plus_one(&self, entries: &[&Entry<D>]) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        // Indexing `entries[0]` would panic on an empty batch.
        let first_id = match entries.first() {
            Some(first) => first.log_id,
            None => return Ok(()),
        };
        let last_id = self.inner().get_log_state().await?.last_log_id;

        if last_id.next_index() != first_id.index {
            return Err(DefensiveError::new(ErrorSubject::Log(first_id), Violation::LogsNonConsecutive {
                prev: last_id,
                next: first_id,
            })
            .into());
        }
        Ok(())
    }

    /// Checks that the first entry's log id is greater than the store's last log id.
    ///
    /// Passes when the store has no last log id. Empty `entries` are accepted.
    async fn defensive_append_log_id_gt_last(&self, entries: &[&Entry<D>]) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        let first_id = match entries.first() {
            Some(first) => first.log_id,
            None => return Ok(()),
        };
        let last_id = self.inner().get_log_state().await?.last_log_id;

        if last_id.is_some() && Some(first_id) <= last_id {
            return Err(DefensiveError::new(ErrorSubject::Log(first_id), Violation::LogsNonConsecutive {
                prev: last_id,
                next: first_id,
            })
            .into());
        }
        Ok(())
    }

    /// Checks that purging logs up to `upto` only removes applied entries.
    ///
    /// Fails with `Violation::PurgeNonApplied` when `upto.index` is beyond the last applied index,
    /// or when nothing has been applied yet.
    async fn defensive_purge_applied_le_last_applied(&self, upto: LogId) -> Result<(), StorageError> {
        // Like every other check in this trait, this is a no-op when defensive mode is off.
        if !self.is_defensive() {
            return Ok(());
        }

        let (last_applied, _) = self.inner().last_applied_state().await?;
        if Some(upto.index) > last_applied.index() {
            return Err(DefensiveError::new(ErrorSubject::Log(upto), Violation::PurgeNonApplied {
                last_applied,
                purge_upto: upto,
            })
            .into());
        }
        Ok(())
    }

    /// Checks that deleting conflicting logs from `since` onward touches no applied entry.
    ///
    /// Fails with `Violation::AppliedWontConflict` when `since.index` is at or below the last
    /// applied index. Passes when nothing has been applied.
    async fn defensive_delete_conflict_gt_last_applied(&self, since: LogId) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        let (last_applied, _) = self.inner().last_applied_state().await?;
        if Some(since.index) <= last_applied.index() {
            return Err(DefensiveError::new(ErrorSubject::Log(since), Violation::AppliedWontConflict {
                last_applied,
                first_conflict_log_id: since,
            })
            .into());
        }
        Ok(())
    }

    /// Checks that the first entry to apply has the next index after the last applied log id
    /// (as computed by `LogIdOptionExt::next_index`).
    ///
    /// Empty `entries` are accepted.
    async fn defensive_apply_index_is_last_applied_plus_one(&self, entries: &[&Entry<D>]) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        // Indexing `entries[0]` would panic on an empty batch.
        let first_id = match entries.first() {
            Some(first) => first.log_id,
            None => return Ok(()),
        };
        let (last_id, _) = self.inner().last_applied_state().await?;

        if last_id.next_index() != first_id.index {
            return Err(DefensiveError::new(ErrorSubject::Apply(first_id), Violation::ApplyNonConsecutive {
                prev: last_id,
                next: first_id,
            })
            .into());
        }
        Ok(())
    }

    /// Checks that a range bounded on both sides selects at least one index.
    ///
    /// A range open on either side is always accepted. Fails with `Violation::RangeEmpty`,
    /// reporting the inclusive start/end, when the range is empty.
    async fn defensive_nonempty_range<RB: RangeBounds<u64> + Clone + Debug + Send>(
        &self,
        range: RB,
    ) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        // Normalize to inclusive indexes. `checked_*` yields `None` for `Excluded(u64::MAX)` as
        // start or `Excluded(0)` as end: bounds with no inclusive equivalent, which make the range
        // empty. Plain `+ 1` / `- 1` would panic there in debug builds and wrap in release builds.
        let first = match range.start_bound() {
            Bound::Included(i) => Some(*i),
            Bound::Excluded(i) => i.checked_add(1),
            Bound::Unbounded => return Ok(()),
        };
        let last = match range.end_bound() {
            Bound::Included(i) => Some(*i),
            Bound::Excluded(i) => i.checked_sub(1),
            Bound::Unbounded => return Ok(()),
        };

        let is_empty = match (first, last) {
            (Some(first), Some(last)) => first > last,
            _ => true,
        };
        if is_empty {
            // When a bound has no inclusive value, report the raw bound value instead.
            return Err(DefensiveError::new(ErrorSubject::Logs, Violation::RangeEmpty {
                start: first.or(Some(u64::MAX)),
                end: last.or(Some(0)),
            })
            .into());
        }
        Ok(())
    }

    /// Checks that at least one side of `range` is unbounded.
    ///
    /// A range bounded on both sides fails with `Violation::RangeNotHalfOpen`.
    async fn defensive_half_open_range<RB: RangeBounds<u64> + Clone + Debug + Send>(
        &self,
        range: RB,
    ) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        if let Bound::Unbounded = range.start_bound() {
            return Ok(());
        }
        if let Bound::Unbounded = range.end_bound() {
            return Ok(());
        }

        Err(DefensiveError::new(ErrorSubject::Logs, Violation::RangeNotHalfOpen {
            start: range.start_bound().cloned(),
            end: range.end_bound().cloned(),
        })
        .into())
    }

    /// Checks that `logs` starts and ends at the indexes selected by `range`.
    ///
    /// Delegates to [`check_range_matches_entries`] when defensive mode is on.
    async fn defensive_range_hits_logs<RB: RangeBounds<u64> + Debug + Send>(
        &self,
        range: RB,
        logs: &[Entry<D>],
    ) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        check_range_matches_entries(range, logs)?;
        Ok(())
    }

    /// Checks that the first entry to apply has a log id greater than the last applied one.
    ///
    /// Passes when nothing has been applied. Empty `entries` are accepted.
    async fn defensive_apply_log_id_gt_last(&self, entries: &[&Entry<D>]) -> Result<(), StorageError> {
        if !self.is_defensive() {
            return Ok(());
        }

        let first_id = match entries.first() {
            Some(first) => first.log_id,
            None => return Ok(()),
        };
        let (last_id, _) = self.inner().last_applied_state().await?;

        if Some(first_id) <= last_id {
            return Err(DefensiveError::new(ErrorSubject::Apply(first_id), Violation::ApplyNonConsecutive {
                prev: last_id,
                next: first_id,
            })
            .into());
        }
        Ok(())
    }
}
/// Checks that `entries` begins and ends at the log indexes selected by `range`.
///
/// Only the boundary entries are inspected:
/// - when `range` has a lower bound, the first entry's index must equal it;
/// - when `range` has an upper bound, the last entry's index must equal it.
///
/// Interior entries are not checked for gaps.
///
/// An empty range is accepted without looking at `entries`. This includes `x..0` and ranges
/// starting after `Excluded(u64::MAX)`.
///
/// # Errors
///
/// Returns a `Violation::LogIndexNotFound` defensive error for the first boundary index (lower
/// bound checked before upper) that does not match.
pub fn check_range_matches_entries<D: AppData, RB: RangeBounds<u64> + Debug + Send>(
    range: RB,
    entries: &[Entry<D>],
) -> Result<(), StorageError> {
    // Normalize both bounds to inclusive indexes. `checked_*` avoids the overflow panic (debug)
    // or wrap-around (release) that `i + 1` / `i - 1` hit on `Excluded(u64::MAX)` /
    // `Excluded(0)`. Such a bound selects no index, so the range is empty and trivially matched.
    let want_first = match range.start_bound() {
        Bound::Included(i) => Some(*i),
        Bound::Excluded(i) => match i.checked_add(1) {
            Some(first) => Some(first),
            None => return Ok(()),
        },
        Bound::Unbounded => None,
    };
    let want_last = match range.end_bound() {
        Bound::Included(i) => Some(*i),
        Bound::Excluded(i) => match i.checked_sub(1) {
            Some(last) => Some(last),
            None => return Ok(()),
        },
        Bound::Unbounded => None,
    };

    if let (Some(first), Some(last)) = (want_first, want_last) {
        if first > last {
            return Ok(());
        }
    }

    let got_first = entries.first().map(|e| e.log_id.index);
    let got_last = entries.last().map(|e| e.log_id.index);

    for (want, got) in [(want_first, got_first), (want_last, got_last)] {
        if let Some(want) = want {
            if got != Some(want) {
                return Err(DefensiveError::new(ErrorSubject::LogIndex(want), Violation::LogIndexNotFound {
                    want,
                    got,
                })
                .into());
            }
        }
    }
    Ok(())
}