#![allow(clippy::result_large_err)]
use crate::RejectReason;
use crate::combinator::retry::RetryPolicy;
use crate::decoding::{DecodingConfig, DecodingPipeline, SymbolAcceptResult};
use crate::error::{Error, ErrorKind};
use crate::security::AuthenticatedSymbol;
use crate::security::SecurityContext;
use crate::security::tag::AuthenticationTag;
use crate::types::symbol::{ObjectParams, Symbol};
use crate::types::{RegionId, Time};
use std::collections::{HashMap, HashSet};
use std::time::Duration;
use super::snapshot::RegionSnapshot;
#[derive(Debug, Clone)]
pub enum RecoveryTrigger {
QuorumLost {
region_id: RegionId,
available_replicas: Vec<String>,
required_quorum: u32,
},
NodeRestart {
region_id: RegionId,
last_known_sequence: u64,
},
ManualTrigger {
region_id: RegionId,
initiator: String,
reason: Option<String>,
},
InconsistencyDetected {
region_id: RegionId,
local_sequence: u64,
remote_sequence: u64,
},
}
impl RecoveryTrigger {
#[must_use]
pub fn region_id(&self) -> RegionId {
match self {
Self::QuorumLost { region_id, .. }
| Self::NodeRestart { region_id, .. }
| Self::ManualTrigger { region_id, .. }
| Self::InconsistencyDetected { region_id, .. } => *region_id,
}
}
#[must_use]
pub fn is_critical(&self) -> bool {
matches!(
self,
Self::QuorumLost { .. } | Self::InconsistencyDetected { .. }
)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CollectionConsistency {
Any,
Quorum,
All,
}
#[derive(Debug, Clone)]
pub struct RecoveryConfig {
pub min_symbols: u32,
pub recovery_timeout: Duration,
pub replica_timeout: Duration,
pub max_concurrent_requests: usize,
pub collection_consistency: CollectionConsistency,
pub allow_partial: bool,
pub retry_policy: RetryPolicy,
pub max_attempts: u32,
}
impl Default for RecoveryConfig {
fn default() -> Self {
Self {
min_symbols: 0,
recovery_timeout: Duration::from_mins(1),
replica_timeout: Duration::from_secs(5),
max_concurrent_requests: 10,
collection_consistency: CollectionConsistency::Quorum,
allow_partial: false,
retry_policy: RetryPolicy::new().with_max_attempts(3),
max_attempts: 3,
}
}
}
#[derive(Debug, Clone)]
pub struct CollectedSymbol {
pub symbol: Symbol,
pub tag: AuthenticationTag,
pub source_replica: String,
pub collected_at: Time,
pub verified: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryPhase {
Initializing,
Collecting,
Verifying,
Decoding,
Applying,
Complete,
Failed,
}
#[derive(Debug, Clone)]
pub struct RecoveryProgress {
pub started_at: Time,
pub symbols_needed: u32,
pub symbols_collected: u32,
pub replicas_queried: u32,
pub replicas_responded: u32,
pub phase: RecoveryPhase,
}
#[derive(Debug, Default)]
pub struct CollectionMetrics {
pub symbols_requested: u64,
pub symbols_received: u64,
pub symbols_duplicate: u64,
pub symbols_corrupt: u64,
pub requests_sent: u64,
pub requests_successful: u64,
pub requests_failed: u64,
pub requests_timeout: u64,
}
pub struct RecoveryCollector {
config: RecoveryConfig,
collected: Vec<CollectedSymbol>,
symbol_to_idx: HashMap<(u8, u32), usize>,
pub object_params: Option<ObjectParams>,
progress: RecoveryProgress,
pub metrics: CollectionMetrics,
cancelled: bool,
}
impl RecoveryCollector {
fn required_symbols_for_decode(&self, params: &ObjectParams) -> usize {
params
.min_symbols_for_decode()
.saturating_add(self.config.min_symbols) as usize
}
fn block_symbol_requirements(params: &ObjectParams) -> Vec<usize> {
if params.object_size == 0 || params.symbol_size == 0 || params.source_blocks == 0 {
return Vec::new();
}
let symbol_size = u64::from(params.symbol_size);
let max_block_size = u64::from(params.symbols_per_block) * symbol_size;
if max_block_size == 0 {
return Vec::new();
}
let mut requirements = Vec::with_capacity(usize::from(params.source_blocks));
for block in 0..params.source_blocks {
let start = u64::from(block) * max_block_size;
if start >= params.object_size {
break;
}
let remaining = params.object_size - start;
let block_size = remaining.min(max_block_size);
requirements.push(block_size.div_ceil(symbol_size) as usize);
}
requirements
}
#[must_use]
pub fn new(config: RecoveryConfig) -> Self {
Self {
config,
collected: Vec::with_capacity(64),
symbol_to_idx: HashMap::with_capacity(64),
object_params: None,
progress: RecoveryProgress {
started_at: Time::ZERO,
symbols_needed: 0,
symbols_collected: 0,
replicas_queried: 0,
replicas_responded: 0,
phase: RecoveryPhase::Initializing,
},
metrics: CollectionMetrics::default(),
cancelled: false,
}
}
#[must_use]
pub fn progress(&self) -> &RecoveryProgress {
&self.progress
}
#[must_use]
pub fn symbols(&self) -> &[CollectedSymbol] {
&self.collected
}
#[must_use]
pub fn can_decode(&self) -> bool {
let Some(params) = &self.object_params else {
return false;
};
if self.collected.len() < self.required_symbols_for_decode(params) {
return false;
}
let block_requirements = Self::block_symbol_requirements(params);
if block_requirements.is_empty() {
return true;
}
let mut block_counts = vec![0usize; block_requirements.len()];
for collected in &self.collected {
let block = usize::from(collected.symbol.sbn());
if let Some(count) = block_counts.get_mut(block) {
*count += 1;
}
}
block_counts
.iter()
.zip(block_requirements.iter())
.all(|(have, need)| *have >= *need)
}
pub fn cancel(&mut self) {
self.cancelled = true;
}
fn reset_for_attempt(&mut self, params: ObjectParams) {
self.collected.clear();
self.symbol_to_idx.clear();
let symbols_needed = self.required_symbols_for_decode(¶ms) as u32;
self.object_params = Some(params);
self.progress.symbols_needed = symbols_needed;
self.progress.symbols_collected = 0;
self.progress.replicas_queried = 0;
self.progress.replicas_responded = 0;
self.progress.phase = RecoveryPhase::Collecting;
self.metrics = CollectionMetrics::default();
}
#[inline]
pub fn add_collected(&mut self, cs: CollectedSymbol) -> bool {
let symbol_key = (cs.symbol.sbn(), cs.symbol.esi());
if let Some(&idx) = self.symbol_to_idx.get(&symbol_key) {
if !self.collected[idx].verified && cs.verified {
self.collected[idx] = cs;
return true;
}
self.metrics.symbols_duplicate += 1;
return false;
}
let idx = self.collected.len();
self.symbol_to_idx.insert(symbol_key, idx);
self.metrics.symbols_received += 1;
self.progress.symbols_collected += 1;
self.collected.push(cs);
true
}
#[inline]
pub fn add_collected_with_verify(&mut self, cs: CollectedSymbol) -> Result<bool, Error> {
if let Some(params) = &self.object_params {
if cs.symbol.object_id() != params.object_id {
self.metrics.symbols_corrupt += 1;
return Err(Error::new(ErrorKind::CorruptedSymbol).with_message(format!(
"symbol object {} does not match recovery object {}",
cs.symbol.object_id(),
params.object_id
)));
}
if u16::from(cs.symbol.sbn()) >= params.source_blocks {
self.metrics.symbols_corrupt += 1;
return Err(Error::new(ErrorKind::CorruptedSymbol).with_message(format!(
"SBN {} exceeds expected source block range for object",
cs.symbol.sbn()
)));
}
let max_expected = params
.total_source_symbols()
.saturating_add(self.config.min_symbols);
if cs.symbol.esi() > max_expected.saturating_add(50_000) {
self.metrics.symbols_corrupt += 1;
return Err(Error::new(ErrorKind::CorruptedSymbol).with_message(format!(
"ESI {} exceeds expected range for object",
cs.symbol.esi()
)));
}
}
Ok(self.add_collected(cs))
}
}
impl std::fmt::Debug for RecoveryCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RecoveryCollector")
.field("config", &self.config)
.field("collected", &self.collected.len())
.field("symbol_to_idx", &self.symbol_to_idx.len())
.field("object_params", &self.object_params)
.field("phase", &self.progress.phase)
.field("metrics", &self.metrics)
.field("cancelled", &self.cancelled)
.finish()
}
}
#[derive(Debug, Clone)]
pub struct RecoveryDecodingConfig {
pub verify_integrity: bool,
pub auth_context: Option<SecurityContext>,
pub max_decode_attempts: u32,
pub allow_partial_decode: bool,
}
impl Default for RecoveryDecodingConfig {
fn default() -> Self {
Self {
verify_integrity: false,
auth_context: None,
max_decode_attempts: 3,
allow_partial_decode: false,
}
}
}
#[derive(Debug)]
enum DecoderState {
Accumulating { received: u32 },
Complete,
Failed,
}
pub struct StateDecoder {
config: RecoveryDecodingConfig,
decoder_state: DecoderState,
symbols: Vec<AuthenticatedSymbol>,
seen_symbols: HashSet<(u8, u32)>,
}
#[inline]
fn saturating_symbol_count(value: usize) -> u32 {
u32::try_from(value).unwrap_or(u32::MAX)
}
impl StateDecoder {
#[must_use]
pub fn new(config: RecoveryDecodingConfig) -> Self {
Self {
config,
decoder_state: DecoderState::Accumulating { received: 0 },
symbols: Vec::with_capacity(64),
seen_symbols: HashSet::with_capacity(64),
}
}
pub fn add_symbol(&mut self, auth_symbol: &AuthenticatedSymbol) -> Result<(), Error> {
let symbol_key = (auth_symbol.symbol().sbn(), auth_symbol.symbol().esi());
if self.seen_symbols.contains(&symbol_key) {
return Ok(()); }
self.seen_symbols.insert(symbol_key);
self.symbols.push(auth_symbol.clone());
if let DecoderState::Accumulating { received } = &mut self.decoder_state {
*received = saturating_symbol_count(self.symbols.len());
}
Ok(())
}
#[must_use]
pub fn can_decode(&self) -> bool {
!self.symbols.is_empty()
}
#[must_use]
pub fn symbols_received(&self) -> u32 {
saturating_symbol_count(self.symbols.len())
}
#[must_use]
pub fn symbols_needed(&self, params: &ObjectParams) -> u32 {
params.min_symbols_for_decode()
}
pub fn reset(&mut self) {
self.symbols.clear();
self.seen_symbols.clear();
self.decoder_state = DecoderState::Accumulating { received: 0 };
}
pub fn decode(&mut self, params: &ObjectParams) -> Result<Vec<u8>, Error> {
let k = params.min_symbols_for_decode();
if self.symbols.len() < k as usize {
self.decoder_state = DecoderState::Failed;
return Err(Error::insufficient_symbols(
saturating_symbol_count(self.symbols.len()),
k,
));
}
let config = DecodingConfig {
symbol_size: params.symbol_size,
max_block_size: usize::from(params.symbols_per_block) * usize::from(params.symbol_size),
repair_overhead: 1.0,
min_overhead: 0,
max_buffered_symbols: 0,
block_timeout: Duration::from_secs(30),
verify_auth: self.config.verify_integrity,
};
let mut pipeline = if self.config.verify_integrity {
if let Some(ctx) = self.config.auth_context.clone() {
DecodingPipeline::with_auth(config, ctx)
} else {
DecodingPipeline::new(config)
}
} else {
DecodingPipeline::new(config)
};
if let Err(err) = pipeline.set_object_params(*params) {
self.decoder_state = DecoderState::Failed;
return Err(Error::from(err));
}
for symbol in &self.symbols {
match pipeline.feed(symbol.clone()).map_err(Error::from)? {
SymbolAcceptResult::Rejected(RejectReason::BlockAlreadyDecoded) => {
}
SymbolAcceptResult::Rejected(reason) => {
#[cfg(feature = "tracing-integration")]
tracing::warn!(reason = ?reason, "ignoring rejected symbol during recovery");
#[cfg(not(feature = "tracing-integration"))]
let _ = &reason;
}
_ => {}
}
}
match pipeline.into_data() {
Ok(data) => {
self.decoder_state = DecoderState::Complete;
Ok(data)
}
Err(err) => {
self.decoder_state = DecoderState::Failed;
Err(Error::from(err))
}
}
}
pub fn decode_snapshot(&mut self, params: &ObjectParams) -> Result<RegionSnapshot, Error> {
let data = self.decode(params)?;
RegionSnapshot::from_bytes(&data).map_err(|e| {
Error::new(ErrorKind::DecodingFailed)
.with_message(format!("snapshot deserialization failed: {e}"))
})
}
}
impl std::fmt::Debug for StateDecoder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StateDecoder")
.field("config", &self.config)
.field("symbols", &self.symbols.len())
.field("seen_symbols", &self.seen_symbols.len())
.field("state", &self.decoder_state)
.finish()
}
}
#[derive(Debug)]
pub struct RecoveryResult {
pub snapshot: RegionSnapshot,
pub symbols_used: u32,
pub contributing_replicas: Vec<String>,
pub duration: Duration,
pub attempt: u32,
pub verified: bool,
}
pub struct RecoveryOrchestrator {
config: RecoveryConfig,
collector: RecoveryCollector,
decoder: StateDecoder,
attempt: u32,
recovering: bool,
cancelled: bool,
}
impl RecoveryOrchestrator {
#[must_use]
pub fn new(recovery_config: RecoveryConfig, decoding_config: RecoveryDecodingConfig) -> Self {
let collector = RecoveryCollector::new(recovery_config.clone());
let decoder = StateDecoder::new(decoding_config);
Self {
config: recovery_config,
collector,
decoder,
attempt: 0,
recovering: false,
cancelled: false,
}
}
#[must_use]
pub fn progress(&self) -> &RecoveryProgress {
self.collector.progress()
}
#[must_use]
pub fn is_recovering(&self) -> bool {
self.recovering && !self.cancelled
}
pub fn cancel(&mut self, _reason: &str) {
self.cancelled = true;
self.recovering = false;
self.collector.cancel();
}
pub fn recover_from_symbols(
&mut self,
trigger: &RecoveryTrigger,
symbols: &[CollectedSymbol],
params: ObjectParams,
duration: Duration,
) -> Result<RecoveryResult, Error> {
if self.cancelled {
return Err(Error::new(ErrorKind::RecoveryFailed)
.with_message("recovery session was cancelled"));
}
let max_attempts = self.config.max_attempts.max(1);
if self.attempt >= max_attempts {
return Err(Error::new(ErrorKind::RecoveryFailed)
.with_message(format!("recovery attempts exhausted ({max_attempts})")));
}
self.recovering = true;
self.attempt += 1;
let _ = trigger.region_id();
self.collector.reset_for_attempt(params);
self.decoder.reset();
for cs in symbols {
match self.collector.add_collected_with_verify(cs.clone()) {
Ok(_) => {}
Err(e) if e.kind() == ErrorKind::CorruptedSymbol => {
}
Err(e) => {
self.recovering = false;
return Err(e);
}
}
}
if !self.collector.can_decode() {
self.recovering = false;
return Err(Error::new(ErrorKind::RecoveryFailed)
.with_message("insufficient symbols for recovery"));
}
for cs in self.collector.symbols() {
let trust_verified_flag = !self.decoder.config.verify_integrity;
let auth = if cs.verified && trust_verified_flag {
AuthenticatedSymbol::new_verified(cs.symbol.clone(), cs.tag)
} else {
AuthenticatedSymbol::from_parts(cs.symbol.clone(), cs.tag)
};
if let Err(e) = self.decoder.add_symbol(&auth) {
self.recovering = false;
return Err(e);
}
}
let snapshot = match self.decoder.decode_snapshot(¶ms) {
Ok(s) => s,
Err(e) => {
self.recovering = false;
return Err(e);
}
};
let mut seen_replicas = HashSet::new();
let contributing: Vec<String> = self
.collector
.symbols()
.iter()
.filter(|s| seen_replicas.insert(s.source_replica.as_str()))
.map(|s| s.source_replica.clone())
.collect();
self.recovering = false;
Ok(RecoveryResult {
snapshot,
symbols_used: self.decoder.symbols_received(),
contributing_replicas: contributing,
duration,
attempt: self.attempt,
verified: self.decoder.config.verify_integrity,
})
}
}
impl std::fmt::Debug for RecoveryOrchestrator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RecoveryOrchestrator")
.field("config", &self.config)
.field("collector", &self.collector)
.field("decoder", &self.decoder)
.field("attempt", &self.attempt)
.field("recovering", &self.recovering)
.field("cancelled", &self.cancelled)
.finish()
}
}
#[cfg(test)]
#[allow(clippy::similar_names)]
mod tests {
use super::*;
use crate::distributed::encoding::{EncodedState, EncodingConfig, StateEncoder};
use crate::distributed::snapshot::{BudgetSnapshot, TaskSnapshot, TaskState};
use crate::record::region::RegionState;
use crate::types::symbol::{ObjectId, SymbolId, SymbolKind};
use crate::types::{RegionId, TaskId};
use crate::util::DetRng;
#[test]
fn trigger_region_id_extraction() {
let trigger = RecoveryTrigger::QuorumLost {
region_id: RegionId::new_for_test(1, 0),
available_replicas: vec!["r1".to_string()],
required_quorum: 2,
};
assert_eq!(trigger.region_id(), RegionId::new_for_test(1, 0));
let trigger2 = RecoveryTrigger::NodeRestart {
region_id: RegionId::new_for_test(2, 0),
last_known_sequence: 5,
};
assert_eq!(trigger2.region_id(), RegionId::new_for_test(2, 0));
let trigger3 = RecoveryTrigger::InconsistencyDetected {
region_id: RegionId::new_for_test(3, 0),
local_sequence: 10,
remote_sequence: 15,
};
assert_eq!(trigger3.region_id(), RegionId::new_for_test(3, 0));
}
#[test]
fn trigger_critical_classification() {
let critical = RecoveryTrigger::QuorumLost {
region_id: RegionId::new_for_test(1, 0),
available_replicas: vec![],
required_quorum: 2,
};
assert!(critical.is_critical());
let also_critical = RecoveryTrigger::InconsistencyDetected {
region_id: RegionId::new_for_test(1, 0),
local_sequence: 10,
remote_sequence: 15,
};
assert!(also_critical.is_critical());
let non_critical = RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "admin".to_string(),
reason: None,
};
assert!(!non_critical.is_critical());
let also_non_critical = RecoveryTrigger::NodeRestart {
region_id: RegionId::new_for_test(1, 0),
last_known_sequence: 0,
};
assert!(!also_non_critical.is_critical());
}
#[test]
fn collector_deduplicates_by_esi() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
let sym1 = Symbol::new_for_test(1, 0, 5, &[1, 2, 3]);
let sym2 = Symbol::new_for_test(1, 0, 5, &[1, 2, 3]);
let added1 = collector.add_collected(CollectedSymbol {
symbol: sym1,
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
});
assert!(added1);
let added2 = collector.add_collected(CollectedSymbol {
symbol: sym2,
tag: AuthenticationTag::zero(),
source_replica: "r2".to_string(),
collected_at: Time::from_secs(1),
verified: false,
});
assert!(!added2);
assert_eq!(collector.symbols().len(), 1);
assert_eq!(collector.metrics.symbols_duplicate, 1);
}
#[test]
fn collector_accepts_same_esi_on_different_blocks() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.object_params =
Some(ObjectParams::new(ObjectId::new_for_test(1), 512, 128, 2, 2));
assert!(collector.add_collected(make_collected_symbol_with_block(0, 0)));
assert!(collector.add_collected(make_collected_symbol_with_block(0, 1)));
assert!(collector.add_collected(make_collected_symbol_with_block(1, 0)));
assert!(collector.add_collected(make_collected_symbol_with_block(1, 1)));
assert_eq!(collector.symbols().len(), 4);
assert_eq!(collector.metrics.symbols_duplicate, 0);
assert!(collector.can_decode());
}
#[test]
fn collector_progress_tracking() {
let collector = RecoveryCollector::new(RecoveryConfig::default());
let progress = collector.progress();
assert_eq!(progress.phase, RecoveryPhase::Initializing);
assert_eq!(progress.symbols_collected, 0);
}
#[test]
fn collector_can_decode_threshold() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.object_params = Some(ObjectParams::new(
ObjectId::new_for_test(1),
1000,
128,
1,
10, ));
for i in 0..7 {
collector.add_collected(make_collected_symbol(i));
}
assert!(!collector.can_decode());
collector.add_collected(make_collected_symbol(7));
assert!(collector.can_decode());
}
#[test]
fn collector_can_decode_uses_total_object_symbols_for_multi_block_objects() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.object_params = Some(ObjectParams::new(
ObjectId::new_for_test(1),
2560,
128,
2,
10,
));
for i in 0..10 {
collector.add_collected(make_collected_symbol_with_block(0, i));
}
for i in 0..9 {
collector.add_collected(make_collected_symbol_with_block(1, i));
}
assert!(!collector.can_decode());
collector.add_collected(make_collected_symbol_with_block(1, 9));
assert!(collector.can_decode());
}
#[test]
fn collector_reset_for_attempt_uses_total_object_symbols() {
let params = ObjectParams::new(ObjectId::new_for_test(1), 2560, 128, 2, 10);
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.reset_for_attempt(params);
assert_eq!(collector.progress().symbols_needed, 20);
assert_eq!(collector.progress().symbols_collected, 0);
assert_eq!(collector.progress().phase, RecoveryPhase::Collecting);
}
#[test]
fn collector_reset_for_attempt_includes_configured_extra_symbols() {
let config = RecoveryConfig {
min_symbols: 2,
..RecoveryConfig::default()
};
let params = ObjectParams::new(ObjectId::new_for_test(1), 1280, 128, 1, 10);
let mut collector = RecoveryCollector::new(config);
collector.reset_for_attempt(params);
assert_eq!(collector.progress().symbols_needed, 12);
}
#[test]
fn collector_respects_configured_extra_symbol_threshold() {
let config = RecoveryConfig {
min_symbols: 2,
..RecoveryConfig::default()
};
let mut collector = RecoveryCollector::new(config);
let params = ObjectParams::new(ObjectId::new_for_test(1), 1280, 128, 1, 10);
collector.object_params = Some(params);
for i in 0..10 {
collector.add_collected(make_collected_symbol(i));
}
assert!(!collector.can_decode());
collector.add_collected(make_collected_symbol(10));
assert!(!collector.can_decode());
collector.add_collected(make_collected_symbol(11));
assert!(collector.can_decode());
}
#[test]
fn collector_cancel() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
assert!(!collector.cancelled);
collector.cancel();
assert!(collector.cancelled);
}
#[test]
fn decoder_accumulates_symbols() {
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let sym = AuthenticatedSymbol::new_verified(
Symbol::new_for_test(1, 0, 0, &[1, 2, 3]),
AuthenticationTag::zero(),
);
decoder.add_symbol(&sym).unwrap();
assert_eq!(decoder.symbols_received(), 1);
}
#[test]
fn saturating_symbol_count_clamps_large_usize() {
assert_eq!(saturating_symbol_count(0), 0);
assert_eq!(saturating_symbol_count(u32::MAX as usize), u32::MAX);
assert_eq!(saturating_symbol_count(usize::MAX), u32::MAX);
if usize::BITS > u32::BITS {
assert_eq!(saturating_symbol_count((u32::MAX as usize) + 1), u32::MAX);
}
}
#[test]
fn decoder_deduplicates() {
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let sym = AuthenticatedSymbol::new_verified(
Symbol::new_for_test(1, 0, 0, &[1, 2, 3]),
AuthenticationTag::zero(),
);
decoder.add_symbol(&sym).unwrap();
decoder.add_symbol(&sym).unwrap();
assert_eq!(decoder.symbols_received(), 1);
}
#[test]
fn decoder_accepts_same_esi_on_different_blocks() {
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let block0 = AuthenticatedSymbol::new_verified(
Symbol::new_for_test(1, 0, 0, &[1, 2, 3]),
AuthenticationTag::zero(),
);
let block1 = AuthenticatedSymbol::new_verified(
Symbol::new_for_test(1, 1, 0, &[4, 5, 6]),
AuthenticationTag::zero(),
);
decoder.add_symbol(&block0).unwrap();
decoder.add_symbol(&block1).unwrap();
assert_eq!(decoder.symbols_received(), 2);
}
#[test]
fn decoder_rejects_insufficient_symbols() {
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let params = ObjectParams::new(ObjectId::new_for_test(1), 12800, 1280, 1, 10);
for i in 0..2 {
let sym = AuthenticatedSymbol::new_verified(
Symbol::new_for_test(1, 0, i, &[0u8; 1280]),
AuthenticationTag::zero(),
);
decoder.add_symbol(&sym).unwrap();
}
let result = decoder.decode(¶ms);
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), ErrorKind::InsufficientSymbols);
}
#[test]
fn decoder_rejects_insufficient_symbols_for_multi_block_object() {
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let params = ObjectParams::new(ObjectId::new_for_test(1), 2560, 128, 2, 10);
for i in 0..19 {
let sym = AuthenticatedSymbol::new_verified(
Symbol::new_for_test(1, 0, i, &[0u8; 128]),
AuthenticationTag::zero(),
);
decoder.add_symbol(&sym).unwrap();
}
let err = decoder.decode(¶ms).unwrap_err();
assert_eq!(err.kind(), ErrorKind::InsufficientSymbols);
assert!(err.to_string().contains("20"));
}
#[test]
fn decoder_successful_decode() {
let snapshot = create_test_snapshot();
let encoded = encode_test_snapshot(&snapshot);
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
for sym in &encoded.symbols {
let sym = AuthenticatedSymbol::new_verified(sym.clone(), AuthenticationTag::zero());
decoder.add_symbol(&sym).unwrap();
}
let recovered = decoder.decode_snapshot(&encoded.params).unwrap();
assert_eq!(recovered.region_id, snapshot.region_id);
assert_eq!(recovered.sequence, snapshot.sequence);
}
#[test]
fn decoder_verify_integrity_with_auth_context_verifies_unverified_symbols() {
let snapshot = create_test_snapshot();
let encoded = encode_test_snapshot(&snapshot);
let ctx = SecurityContext::for_testing(42);
let mut decoder = StateDecoder::new(RecoveryDecodingConfig {
verify_integrity: true,
auth_context: Some(ctx.clone()),
..Default::default()
});
for sym in &encoded.symbols {
let signed = ctx.sign_symbol(sym);
let unverified =
AuthenticatedSymbol::from_parts(signed.symbol().clone(), *signed.tag());
decoder.add_symbol(&unverified).unwrap();
}
let recovered = decoder.decode_snapshot(&encoded.params).unwrap();
assert_eq!(recovered.region_id, snapshot.region_id);
assert_eq!(recovered.sequence, snapshot.sequence);
}
#[test]
fn decoder_verify_integrity_without_auth_context_rejects_unverified_symbols() {
let snapshot = create_test_snapshot();
let encoded = encode_test_snapshot(&snapshot);
let mut decoder = StateDecoder::new(RecoveryDecodingConfig {
verify_integrity: true,
auth_context: None,
..Default::default()
});
for sym in &encoded.symbols {
let unverified =
AuthenticatedSymbol::from_parts(sym.clone(), AuthenticationTag::zero());
decoder.add_symbol(&unverified).unwrap();
}
let err = decoder.decode_snapshot(&encoded.params).unwrap_err();
assert_eq!(err.kind(), ErrorKind::InsufficientSymbols);
}
#[test]
fn decoder_reset() {
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let sym = AuthenticatedSymbol::new_verified(
Symbol::new_for_test(1, 0, 0, &[1, 2, 3]),
AuthenticationTag::zero(),
);
decoder.add_symbol(&sym).unwrap();
assert_eq!(decoder.symbols_received(), 1);
decoder.reset();
assert_eq!(decoder.symbols_received(), 0);
}
#[test]
fn decoder_symbols_needed() {
let decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let params = ObjectParams::new(ObjectId::new_for_test(1), 1000, 128, 1, 10);
assert_eq!(decoder.symbols_needed(¶ms), 8);
}
#[test]
fn decoder_symbols_needed_uses_total_object_symbols_for_multi_block_objects() {
let decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let params = ObjectParams::new(ObjectId::new_for_test(1), 2560, 128, 2, 10);
assert_eq!(decoder.symbols_needed(¶ms), 20);
}
#[test]
fn orchestrator_successful_recovery() {
let snapshot = create_test_snapshot();
let encoded = encode_test_snapshot(&snapshot);
let symbols: Vec<CollectedSymbol> = encoded
.symbols
.iter()
.enumerate()
.map(|(i, s)| CollectedSymbol {
symbol: s.clone(),
tag: AuthenticationTag::zero(),
source_replica: format!("r{}", i % 3),
collected_at: Time::ZERO,
verified: false,
})
.collect();
let trigger = RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "test".to_string(),
reason: None,
};
let mut orchestrator = RecoveryOrchestrator::new(
RecoveryConfig::default(),
RecoveryDecodingConfig {
verify_integrity: false,
..Default::default()
},
);
let result = orchestrator
.recover_from_symbols(
&trigger,
&symbols,
encoded.params,
Duration::from_millis(10),
)
.unwrap();
assert!(!result.verified); assert!(!result.contributing_replicas.is_empty());
assert_eq!(result.snapshot.region_id, snapshot.region_id);
assert_eq!(result.snapshot.sequence, snapshot.sequence);
}
#[test]
fn orchestrator_ignores_single_corrupt_symbol_when_valid_set_is_sufficient() {
let snapshot = create_test_snapshot();
let encoded = encode_test_snapshot(&snapshot);
let mut symbols: Vec<CollectedSymbol> = encoded
.symbols
.iter()
.enumerate()
.map(|(i, s)| CollectedSymbol {
symbol: s.clone(),
tag: AuthenticationTag::zero(),
source_replica: format!("r{}", i % 3),
collected_at: Time::ZERO,
verified: false,
})
.collect();
symbols.push(CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 60_000, &[0u8; 16]),
tag: AuthenticationTag::zero(),
source_replica: "faulty".to_string(),
collected_at: Time::ZERO,
verified: false,
});
let trigger = RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "test".to_string(),
reason: None,
};
let mut orchestrator = RecoveryOrchestrator::new(
RecoveryConfig::default(),
RecoveryDecodingConfig {
verify_integrity: false,
..Default::default()
},
);
let result = orchestrator.recover_from_symbols(
&trigger,
&symbols,
encoded.params,
Duration::from_millis(10),
);
assert!(
result.is_ok(),
"recovery should tolerate one corrupt symbol"
);
assert_eq!(orchestrator.collector.metrics.symbols_corrupt, 1);
let recovered = result.unwrap();
assert_eq!(recovered.snapshot.region_id, snapshot.region_id);
assert_eq!(recovered.snapshot.sequence, snapshot.sequence);
}
#[test]
fn orchestrator_cancellation() {
let mut orchestrator =
RecoveryOrchestrator::new(RecoveryConfig::default(), RecoveryDecodingConfig::default());
assert!(!orchestrator.is_recovering());
orchestrator.cancel("test cancellation");
assert!(!orchestrator.is_recovering());
}
#[test]
fn orchestrator_insufficient_symbols() {
let trigger = RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "test".to_string(),
reason: None,
};
let params = ObjectParams::new(ObjectId::new_for_test(1), 1000, 128, 1, 10);
let symbols: Vec<CollectedSymbol> = (0..2).map(make_collected_symbol).collect();
let mut orchestrator = RecoveryOrchestrator::new(
RecoveryConfig::default(),
RecoveryDecodingConfig {
verify_integrity: false,
..Default::default()
},
);
let result = orchestrator.recover_from_symbols(
&trigger,
&symbols,
params,
Duration::from_millis(10),
);
assert!(result.is_err());
}
#[test]
fn full_recovery_workflow() {
let original = RegionSnapshot {
region_id: RegionId::new_for_test(1, 0),
state: RegionState::Open,
timestamp: Time::from_secs(100),
sequence: 42,
tasks: vec![TaskSnapshot {
task_id: TaskId::new_for_test(1, 0),
state: TaskState::Running,
priority: 5,
}],
children: vec![RegionId::new_for_test(2, 0)],
finalizer_count: 3,
budget: BudgetSnapshot {
deadline_nanos: None,
polls_remaining: None,
cost_remaining: None,
},
cancel_reason: None,
parent: None,
metadata: vec![1, 2, 3, 4],
};
let encoded = encode_test_snapshot(&original);
let symbols: Vec<CollectedSymbol> = encoded
.symbols
.iter()
.enumerate()
.map(|(i, s)| CollectedSymbol {
symbol: s.clone(),
tag: AuthenticationTag::zero(),
source_replica: format!("r{}", i % 3),
collected_at: Time::ZERO,
verified: false,
})
.collect();
let trigger = RecoveryTrigger::NodeRestart {
region_id: RegionId::new_for_test(1, 0),
last_known_sequence: 41,
};
let mut orchestrator = RecoveryOrchestrator::new(
RecoveryConfig::default(),
RecoveryDecodingConfig {
verify_integrity: false,
..Default::default()
},
);
let result = orchestrator
.recover_from_symbols(
&trigger,
&symbols,
encoded.params,
Duration::from_millis(50),
)
.unwrap();
assert_eq!(result.snapshot.region_id, original.region_id);
assert_eq!(result.snapshot.sequence, original.sequence);
assert_eq!(result.snapshot.tasks.len(), original.tasks.len());
assert_eq!(result.snapshot.children, original.children);
assert_eq!(result.snapshot.metadata, original.metadata);
assert_eq!(result.snapshot.finalizer_count, original.finalizer_count);
}
fn create_test_snapshot() -> RegionSnapshot {
RegionSnapshot {
region_id: RegionId::new_for_test(1, 0),
state: RegionState::Open,
timestamp: Time::from_secs(100),
sequence: 1,
tasks: vec![TaskSnapshot {
task_id: TaskId::new_for_test(1, 0),
state: TaskState::Running,
priority: 5,
}],
children: vec![],
finalizer_count: 2,
budget: BudgetSnapshot {
deadline_nanos: Some(1_000_000_000),
polls_remaining: Some(100),
cost_remaining: None,
},
cancel_reason: None,
parent: None,
metadata: vec![],
}
}
fn encode_test_snapshot(snapshot: &RegionSnapshot) -> EncodedState {
let config = EncodingConfig {
symbol_size: 128,
min_repair_symbols: 4,
..Default::default()
};
let mut enc = StateEncoder::new(config, DetRng::new(42));
enc.encode(snapshot, Time::ZERO).unwrap()
}
fn encode_multi_block_test_snapshot(snapshot: &RegionSnapshot) -> EncodedState {
let config = EncodingConfig {
symbol_size: 16,
min_repair_symbols: 0,
max_source_blocks: 2,
..Default::default()
};
let mut enc = StateEncoder::new(config, DetRng::new(42));
let encoded = enc.encode(snapshot, Time::ZERO).unwrap();
assert!(
encoded.params.source_blocks > 1,
"test snapshot should span multiple source blocks"
);
encoded
}
fn make_collected_symbol(esi: u32) -> CollectedSymbol {
CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, esi, &[0u8; 128]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
}
}
fn make_collected_symbol_with_block(sbn: u8, esi: u32) -> CollectedSymbol {
CollectedSymbol {
symbol: Symbol::new_for_test(1, sbn, esi, &[0u8; 128]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
}
}
fn make_source_symbol(esi: u32, data: &[u8]) -> Symbol {
Symbol::new(
SymbolId::new(ObjectId::new_for_test(1), 0, esi),
data.to_vec(),
SymbolKind::Source,
)
}
#[test]
fn collector_duplicate_esi_from_same_replica() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
let sym1 = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 5, &[1, 2, 3]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
};
let sym2 = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 5, &[4, 5, 6]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::from_secs(1),
verified: false,
};
assert!(collector.add_collected(sym1));
assert!(!collector.add_collected(sym2));
assert_eq!(collector.symbols().len(), 1);
assert_eq!(collector.metrics.symbols_duplicate, 1);
assert_eq!(collector.metrics.symbols_received, 1);
}
#[test]
fn collector_verify_rejects_out_of_range_esi() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.object_params = Some(ObjectParams::new(
ObjectId::new_for_test(1),
1280,
128,
1,
10,
));
let cs = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 60_000, &[0u8; 128]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
};
let result = collector.add_collected_with_verify(cs);
assert!(result.is_err());
assert_eq!(collector.metrics.symbols_corrupt, 1);
}
#[test]
fn collector_verify_accepts_in_range_esi() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.object_params = Some(ObjectParams::new(
ObjectId::new_for_test(1),
1280,
128,
1,
10,
));
let cs = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 15, &[0u8; 128]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
};
let result = collector.add_collected_with_verify(cs);
assert!(result.is_ok());
assert!(result.unwrap()); }
#[test]
fn collector_verify_rejects_foreign_object_before_dedup() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.object_params = Some(ObjectParams::new(
ObjectId::new_for_test(1),
1280,
128,
1,
10,
));
let foreign = CollectedSymbol {
symbol: Symbol::new_for_test(2, 0, 15, &[0u8; 128]),
tag: AuthenticationTag::zero(),
source_replica: "foreign".to_string(),
collected_at: Time::ZERO,
verified: false,
};
let accepted = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 15, &[1u8; 128]),
tag: AuthenticationTag::zero(),
source_replica: "good".to_string(),
collected_at: Time::from_secs(1),
verified: false,
};
let foreign_result = collector.add_collected_with_verify(foreign);
assert!(foreign_result.is_err());
assert_eq!(collector.metrics.symbols_corrupt, 1);
assert_eq!(collector.symbols().len(), 0);
let accepted_result = collector.add_collected_with_verify(accepted);
assert!(accepted_result.is_ok());
assert!(accepted_result.unwrap());
assert_eq!(collector.symbols().len(), 1);
assert_eq!(collector.metrics.symbols_duplicate, 0);
assert_eq!(collector.symbols()[0].source_replica, "good");
}
#[test]
fn collector_verify_accepts_high_valid_sbn_at_256_block_boundary() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.object_params =
Some(ObjectParams::new(ObjectId::new_for_test(1), 256, 1, 256, 1));
let high_sbn = CollectedSymbol {
symbol: Symbol::new_for_test(1, 255, 0, &[7u8]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
};
let result = collector.add_collected_with_verify(high_sbn);
assert!(result.is_ok());
assert!(result.unwrap());
assert_eq!(collector.symbols().len(), 1);
assert_eq!(collector.metrics.symbols_corrupt, 0);
}
#[test]
fn collector_verify_no_params_accepts_any() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
let cs = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 999_999, &[0u8; 128]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
};
let result = collector.add_collected_with_verify(cs);
assert!(result.is_ok());
assert!(result.unwrap());
}
#[test]
fn collector_cancel_prevents_is_recovering() {
let mut orchestrator =
RecoveryOrchestrator::new(RecoveryConfig::default(), RecoveryDecodingConfig::default());
orchestrator.cancel("test");
assert!(!orchestrator.is_recovering());
assert!(orchestrator.cancelled);
}
#[test]
fn collector_metrics_accuracy() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.object_params = Some(ObjectParams::new(
ObjectId::new_for_test(1),
1280,
128,
1,
10,
));
for i in 0..5 {
collector.add_collected(make_collected_symbol(i));
}
for i in 0..3 {
collector.add_collected(make_collected_symbol(i));
}
assert_eq!(collector.metrics.symbols_received, 5);
assert_eq!(collector.metrics.symbols_duplicate, 3);
assert_eq!(collector.progress().symbols_collected, 5);
assert_eq!(collector.symbols().len(), 5);
}
#[test]
fn decoder_insufficient_symbols_error_kind() {
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let params = ObjectParams::new(ObjectId::new_for_test(1), 12800, 1280, 1, 10);
for i in 0..9 {
let sym = AuthenticatedSymbol::new_verified(
make_source_symbol(i, &[0u8; 1280]),
AuthenticationTag::zero(),
);
decoder.add_symbol(&sym).unwrap();
}
let err = decoder.decode(¶ms).unwrap_err();
assert_eq!(err.kind(), ErrorKind::InsufficientSymbols);
}
#[test]
fn decoder_zero_symbols_fails() {
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let params = ObjectParams::new(ObjectId::new_for_test(1), 1280, 128, 1, 10);
let result = decoder.decode(¶ms);
assert!(result.is_err());
}
#[test]
fn decoder_reset_allows_reuse() {
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
let sym = make_source_symbol(0, &[1, 2, 3]);
let auth = AuthenticatedSymbol::new_verified(sym.clone(), AuthenticationTag::zero());
decoder.add_symbol(&auth).unwrap();
assert_eq!(decoder.symbols_received(), 1);
decoder.reset();
assert_eq!(decoder.symbols_received(), 0);
assert!(!decoder.can_decode());
let auth = AuthenticatedSymbol::new_verified(sym, AuthenticationTag::zero());
decoder.add_symbol(&auth).unwrap();
assert_eq!(decoder.symbols_received(), 1);
assert!(decoder.can_decode());
}
#[test]
fn decoder_mixed_source_repair_boundary_decode() {
let snapshot = create_test_snapshot();
let encoded = encode_test_snapshot(&snapshot);
let k = encoded.params.min_symbols_for_decode() as usize;
assert!(
encoded.symbols.len() >= k,
"encoded should have at least K symbols"
);
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
for sym in encoded.symbols.iter().take(k) {
let sym = AuthenticatedSymbol::new_verified(sym.clone(), AuthenticationTag::zero());
decoder.add_symbol(&sym).unwrap();
}
let result = decoder.decode_snapshot(&encoded.params);
assert!(result.is_ok());
let recovered = result.unwrap();
assert_eq!(recovered.region_id, snapshot.region_id);
}
#[test]
fn orchestrator_recover_with_zero_symbols() {
let trigger = RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "test".to_string(),
reason: None,
};
let params = ObjectParams::new(ObjectId::new_for_test(1), 1000, 128, 1, 10);
let mut orchestrator =
RecoveryOrchestrator::new(RecoveryConfig::default(), RecoveryDecodingConfig::default());
let result =
orchestrator.recover_from_symbols(&trigger, &[], params, Duration::from_millis(1));
assert!(result.is_err());
assert!(!orchestrator.is_recovering());
}
#[test]
fn orchestrator_attempt_counter_increments() {
let trigger = RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "test".to_string(),
reason: None,
};
let params = ObjectParams::new(ObjectId::new_for_test(1), 1000, 128, 1, 10);
let mut orchestrator =
RecoveryOrchestrator::new(RecoveryConfig::default(), RecoveryDecodingConfig::default());
let _ = orchestrator.recover_from_symbols(&trigger, &[], params, Duration::ZERO);
assert_eq!(orchestrator.attempt, 1);
let _ = orchestrator.recover_from_symbols(&trigger, &[], params, Duration::ZERO);
assert_eq!(orchestrator.attempt, 2);
}
#[test]
fn orchestrator_attempts_are_isolated() {
let snapshot = create_test_snapshot();
let encoded = encode_test_snapshot(&snapshot);
let good_symbols: Vec<CollectedSymbol> = encoded
.symbols
.iter()
.map(|s| CollectedSymbol {
symbol: s.clone(),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
})
.collect();
let trigger = RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "test".to_string(),
reason: None,
};
let mut orchestrator = RecoveryOrchestrator::new(
RecoveryConfig::default(),
RecoveryDecodingConfig {
verify_integrity: false,
..Default::default()
},
);
let first = orchestrator.recover_from_symbols(
&trigger,
&good_symbols,
encoded.params,
Duration::from_millis(1),
);
assert!(first.is_ok());
let second = orchestrator.recover_from_symbols(
&trigger,
&[],
encoded.params,
Duration::from_millis(1),
);
assert!(second.is_err());
assert_eq!(orchestrator.attempt, 2);
}
#[test]
fn collector_preserves_multi_block_source_symbols_with_repeated_esi_values() {
let snapshot = create_test_snapshot();
let encoded = encode_multi_block_test_snapshot(&snapshot);
let symbols: Vec<CollectedSymbol> = encoded
.symbols
.iter()
.filter(|symbol| symbol.kind().is_source())
.map(|symbol| CollectedSymbol {
symbol: symbol.clone(),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
})
.collect();
let mut seen_esi = HashSet::new();
assert!(
symbols
.iter()
.any(|symbol| !seen_esi.insert(symbol.symbol.esi())),
"multi-block fixture should reuse ESI values across blocks"
);
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
collector.object_params = Some(encoded.params);
for symbol in &symbols {
assert!(
collector.add_collected(symbol.clone()),
"collector should preserve distinct (SBN, ESI) symbols"
);
}
assert_eq!(collector.symbols().len(), symbols.len());
assert_eq!(collector.metrics.symbols_duplicate, 0);
assert!(collector.can_decode());
}
#[test]
fn orchestrator_enforces_max_attempts() {
let trigger = RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "test".to_string(),
reason: None,
};
let params = ObjectParams::new(ObjectId::new_for_test(1), 1000, 128, 1, 10);
let mut orchestrator = RecoveryOrchestrator::new(
RecoveryConfig {
max_attempts: 1,
..RecoveryConfig::default()
},
RecoveryDecodingConfig::default(),
);
let first = orchestrator.recover_from_symbols(&trigger, &[], params, Duration::ZERO);
assert!(first.is_err());
assert_eq!(orchestrator.attempt, 1);
let second = orchestrator.recover_from_symbols(&trigger, &[], params, Duration::ZERO);
assert!(second.is_err());
assert_eq!(orchestrator.attempt, 1);
let err = second.unwrap_err();
assert_eq!(err.kind(), ErrorKind::RecoveryFailed);
assert!(
err.to_string().contains("attempts exhausted"),
"unexpected max-attempts error: {err}"
);
}
#[test]
fn orchestrator_cancel_after_start() {
let snapshot = create_test_snapshot();
let encoded = encode_test_snapshot(&snapshot);
let mut orchestrator =
RecoveryOrchestrator::new(RecoveryConfig::default(), RecoveryDecodingConfig::default());
orchestrator.cancel("pre-emptive cancel");
assert!(orchestrator.cancelled);
assert!(!orchestrator.is_recovering());
assert_eq!(orchestrator.attempt, 0);
let symbols: Vec<CollectedSymbol> = encoded
.symbols
.iter()
.map(|s| CollectedSymbol {
symbol: s.clone(),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
})
.collect();
let result = orchestrator.recover_from_symbols(
&RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "test".to_string(),
reason: None,
},
&symbols,
encoded.params,
Duration::ZERO,
);
assert!(result.is_err());
assert!(!orchestrator.is_recovering());
}
#[test]
fn recovery_config_default_values() {
let config = RecoveryConfig::default();
assert_eq!(config.min_symbols, 0);
assert_eq!(config.recovery_timeout, Duration::from_mins(1));
assert_eq!(config.replica_timeout, Duration::from_secs(5));
assert_eq!(config.max_concurrent_requests, 10);
assert_eq!(config.collection_consistency, CollectionConsistency::Quorum);
assert!(!config.allow_partial);
assert_eq!(config.max_attempts, 3);
}
#[test]
fn decoding_config_default_values() {
let config = RecoveryDecodingConfig::default();
assert!(!config.verify_integrity);
assert!(config.auth_context.is_none());
assert_eq!(config.max_decode_attempts, 3);
assert!(!config.allow_partial_decode);
}
#[test]
fn trigger_manual_with_reason() {
let trigger = RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(5, 0),
initiator: "admin".to_string(),
reason: Some("routine maintenance".to_string()),
};
assert_eq!(trigger.region_id(), RegionId::new_for_test(5, 0));
assert!(!trigger.is_critical());
}
#[test]
fn recovery_phase_equality() {
assert_eq!(RecoveryPhase::Initializing, RecoveryPhase::Initializing);
assert_ne!(RecoveryPhase::Collecting, RecoveryPhase::Verifying);
assert_ne!(RecoveryPhase::Complete, RecoveryPhase::Failed);
}
#[test]
fn collector_debug_format() {
let collector = RecoveryCollector::new(RecoveryConfig::default());
let debug = format!("{collector:?}");
assert!(debug.contains("RecoveryCollector"));
assert!(debug.contains("collected"));
}
#[test]
fn orchestrator_debug_format() {
let orchestrator =
RecoveryOrchestrator::new(RecoveryConfig::default(), RecoveryDecodingConfig::default());
let debug = format!("{orchestrator:?}");
assert!(debug.contains("RecoveryOrchestrator"));
assert!(debug.contains("attempt"));
}
#[test]
fn collector_upgrades_unverified_to_verified_same_esi() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
let esi = 7;
let unverified = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, esi, &[1, 2, 3]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
};
assert!(collector.add_collected(unverified));
assert!(!collector.symbols()[0].verified);
let verified = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, esi, &[1, 2, 3]),
tag: AuthenticationTag::zero(),
source_replica: "r2".to_string(),
collected_at: Time::from_secs(1),
verified: true,
};
assert!(
collector.add_collected(verified),
"verified symbol must replace unverified for same block + ESI"
);
assert_eq!(collector.symbols().len(), 1);
assert!(
collector.symbols()[0].verified,
"stored symbol must now be verified"
);
assert_eq!(collector.symbols()[0].source_replica, "r2");
}
#[test]
fn collector_rejects_downgrade_verified_to_unverified() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
let verified = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 7, &[1, 2, 3]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: true,
};
assert!(collector.add_collected(verified));
let unverified = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 7, &[4, 5, 6]),
tag: AuthenticationTag::zero(),
source_replica: "r2".to_string(),
collected_at: Time::from_secs(1),
verified: false,
};
assert!(
!collector.add_collected(unverified),
"unverified must not replace verified"
);
assert_eq!(collector.metrics.symbols_duplicate, 1);
assert!(collector.symbols()[0].verified);
assert_eq!(collector.symbols()[0].source_replica, "r1");
}
#[test]
fn collector_same_esi_different_replicas_is_duplicate() {
let mut collector = RecoveryCollector::new(RecoveryConfig::default());
let from_r1 = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 10, &[1, 2, 3]),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
};
let from_r2 = CollectedSymbol {
symbol: Symbol::new_for_test(1, 0, 10, &[1, 2, 3]),
tag: AuthenticationTag::zero(),
source_replica: "r2".to_string(),
collected_at: Time::from_secs(1),
verified: false,
};
assert!(collector.add_collected(from_r1));
assert!(
!collector.add_collected(from_r2),
"same block + ESI from different replica must be rejected as duplicate"
);
assert_eq!(collector.symbols().len(), 1);
assert_eq!(collector.metrics.symbols_duplicate, 1);
}
#[test]
fn orchestrator_recovery_rejects_foreign_object_symbol_poisoning() {
let snapshot = create_test_snapshot();
let config = EncodingConfig {
symbol_size: 128,
min_repair_symbols: 0,
..Default::default()
};
let mut enc = StateEncoder::new(config, DetRng::new(42));
let encoded = enc.encode(&snapshot, Time::ZERO).unwrap();
assert_eq!(encoded.params.source_blocks, 1);
let mut source_symbols: Vec<CollectedSymbol> = encoded
.symbols
.iter()
.filter(|symbol| symbol.kind().is_source())
.map(|symbol| CollectedSymbol {
symbol: symbol.clone(),
tag: AuthenticationTag::zero(),
source_replica: "good".to_string(),
collected_at: Time::ZERO,
verified: false,
})
.collect();
assert_eq!(
source_symbols.len(),
usize::try_from(encoded.params.total_source_symbols()).unwrap()
);
let collided = source_symbols
.first()
.expect("source symbol fixture")
.symbol
.clone();
let foreign = CollectedSymbol {
symbol: Symbol::new(
SymbolId::new(ObjectId::new_for_test(999), collided.sbn(), collided.esi()),
collided.data().to_vec(),
collided.kind(),
),
tag: AuthenticationTag::zero(),
source_replica: "foreign".to_string(),
collected_at: Time::ZERO,
verified: false,
};
let mut poisoned_inputs = Vec::with_capacity(source_symbols.len() + 1);
poisoned_inputs.push(foreign);
poisoned_inputs.append(&mut source_symbols);
let trigger = RecoveryTrigger::ManualTrigger {
region_id: snapshot.region_id,
initiator: "test".to_string(),
reason: None,
};
let mut orchestrator =
RecoveryOrchestrator::new(RecoveryConfig::default(), RecoveryDecodingConfig::default());
let result = orchestrator.recover_from_symbols(
&trigger,
&poisoned_inputs,
encoded.params,
Duration::from_millis(1),
);
assert!(
result.is_ok(),
"foreign-object symbol must be rejected before it can poison dedup"
);
}
#[test]
fn orchestrator_cancel_is_definitive_with_valid_data() {
let snapshot = create_test_snapshot();
let encoded = encode_test_snapshot(&snapshot);
let symbols: Vec<CollectedSymbol> = encoded
.symbols
.iter()
.map(|s| CollectedSymbol {
symbol: s.clone(),
tag: AuthenticationTag::zero(),
source_replica: "r1".to_string(),
collected_at: Time::ZERO,
verified: false,
})
.collect();
let mut orchestrator = RecoveryOrchestrator::new(
RecoveryConfig::default(),
RecoveryDecodingConfig {
verify_integrity: false,
..Default::default()
},
);
orchestrator.cancel("operator abort");
let result = orchestrator.recover_from_symbols(
&RecoveryTrigger::ManualTrigger {
region_id: RegionId::new_for_test(1, 0),
initiator: "test".to_string(),
reason: None,
},
&symbols,
encoded.params,
Duration::from_millis(1),
);
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.kind(), ErrorKind::RecoveryFailed);
assert!(
err.to_string().contains("cancelled"),
"error message must mention cancellation: {err}"
);
}
#[test]
fn decoder_reset_allows_reuse_for_different_object() {
let snapshot1 = create_test_snapshot();
let encoded1 = encode_test_snapshot(&snapshot1);
let snapshot2 = RegionSnapshot {
region_id: RegionId::new_for_test(99, 0),
sequence: 42,
..create_test_snapshot()
};
let encoded2 = encode_test_snapshot(&snapshot2);
let mut decoder = StateDecoder::new(RecoveryDecodingConfig::default());
for sym in &encoded1.symbols {
let auth = AuthenticatedSymbol::new_verified(sym.clone(), AuthenticationTag::zero());
decoder.add_symbol(&auth).unwrap();
}
let recovered1 = decoder.decode_snapshot(&encoded1.params).unwrap();
assert_eq!(recovered1.region_id, snapshot1.region_id);
decoder.reset();
for sym in &encoded2.symbols {
let auth = AuthenticatedSymbol::new_verified(sym.clone(), AuthenticationTag::zero());
decoder.add_symbol(&auth).unwrap();
}
let recovered2 = decoder.decode_snapshot(&encoded2.params).unwrap();
assert_eq!(recovered2.region_id, snapshot2.region_id);
assert_eq!(recovered2.sequence, 42);
}
}