use super::classification::DataType;
use super::QoSClass;
use crate::Result;
use std::collections::{BTreeMap, HashMap};
use std::time::{Duration, Instant};
/// One queued update: a raw payload plus the metadata used to prioritise and
/// expire it.
#[derive(Debug, Clone)]
pub struct UpdateBatch {
    /// Caller-supplied identifier (`SyncRecovery` hands out a running counter).
    pub id: u64,
    /// Raw payload bytes; [`UpdateBatch::size`] reports their length.
    pub data: Vec<u8>,
    /// Kind of update; `SyncRecovery` looks up its obsolescence window by this.
    pub data_type: DataType,
    /// Moment the batch was created; [`UpdateBatch::age`] is measured from here.
    pub created_at: Instant,
    /// Priority class the batch is queued under.
    pub qos_class: QoSClass,
}

impl UpdateBatch {
    /// Builds a batch whose creation time is the current instant.
    pub fn new(id: u64, data: Vec<u8>, data_type: DataType, qos_class: QoSClass) -> Self {
        Self::with_time(id, data, data_type, qos_class, Instant::now())
    }

    /// Builds a batch with an explicit creation time, e.g. to enqueue data that
    /// was produced earlier.
    pub fn with_time(
        id: u64,
        data: Vec<u8>,
        data_type: DataType,
        qos_class: QoSClass,
        created_at: Instant,
    ) -> Self {
        UpdateBatch {
            id,
            data,
            data_type,
            created_at,
            qos_class,
        }
    }

    /// Time that has passed since `created_at`.
    pub fn age(&self) -> Duration {
        Instant::now().duration_since(self.created_at)
    }

    /// Payload length in bytes.
    pub fn size(&self) -> usize {
        let payload = &self.data;
        payload.len()
    }
}
#[derive(Debug)]
pub struct SyncRecovery {
queued_updates: BTreeMap<QoSClass, Vec<UpdateBatch>>,
obsolescence_windows: HashMap<DataType, Duration>,
total_bytes: usize,
obsolete_filtered: usize,
next_batch_id: u64,
recovery_in_progress: bool,
}
impl SyncRecovery {
pub fn new() -> Self {
Self {
queued_updates: BTreeMap::new(),
obsolescence_windows: HashMap::new(),
total_bytes: 0,
obsolete_filtered: 0,
next_batch_id: 0,
recovery_in_progress: false,
}
}
pub fn default_military() -> Self {
let mut recovery = Self::new();
recovery.set_obsolescence(DataType::PositionUpdate, Duration::from_secs(300)); recovery.set_obsolescence(DataType::Heartbeat, Duration::from_secs(60)); recovery.set_obsolescence(DataType::SensorTelemetry, Duration::from_secs(30)); recovery.set_obsolescence(DataType::EnvironmentData, Duration::from_secs(600));
recovery.set_obsolescence(DataType::HealthStatus, Duration::from_secs(600)); recovery.set_obsolescence(DataType::CapabilityChange, Duration::from_secs(1800)); recovery.set_obsolescence(DataType::FormationUpdate, Duration::from_secs(1800)); recovery.set_obsolescence(DataType::TaskAssignment, Duration::from_secs(3600));
recovery.set_obsolescence(DataType::TargetImage, Duration::from_secs(3600)); recovery.set_obsolescence(DataType::AudioIntercept, Duration::from_secs(3600)); recovery.set_obsolescence(DataType::MissionRetasking, Duration::from_secs(7200)); recovery.set_obsolescence(DataType::FormationChange, Duration::from_secs(3600));
recovery.set_obsolescence(DataType::DebugLog, Duration::from_secs(86400)); recovery.set_obsolescence(DataType::HistoricalTrack, Duration::from_secs(604800));
recovery
}
pub fn set_obsolescence(&mut self, data_type: DataType, window: Duration) {
self.obsolescence_windows.insert(data_type, window);
}
pub fn get_obsolescence(&self, data_type: &DataType) -> Option<Duration> {
self.obsolescence_windows.get(data_type).copied()
}
pub fn is_obsolete(&self, data_type: DataType, age: Duration) -> bool {
self.obsolescence_windows
.get(&data_type)
.map(|window| age > *window)
.unwrap_or(false) }
pub fn queue_update(&mut self, data: Vec<u8>, data_type: DataType, qos_class: QoSClass) -> u64 {
let id = self.next_batch_id;
self.next_batch_id += 1;
let batch = UpdateBatch::new(id, data.clone(), data_type, qos_class);
self.total_bytes += batch.size();
self.queued_updates
.entry(qos_class)
.or_default()
.push(batch);
id
}
pub fn queue_update_with_time(
&mut self,
data: Vec<u8>,
data_type: DataType,
qos_class: QoSClass,
created_at: Instant,
) -> u64 {
let id = self.next_batch_id;
self.next_batch_id += 1;
let batch = UpdateBatch::with_time(id, data.clone(), data_type, qos_class, created_at);
self.total_bytes += batch.size();
self.queued_updates
.entry(qos_class)
.or_default()
.push(batch);
id
}
pub fn apply_obsolescence_filter(&mut self) -> usize {
let mut filtered = 0;
let mut bytes_removed = 0;
let windows = self.obsolescence_windows.clone();
for batches in self.queued_updates.values_mut() {
let before_len = batches.len();
batches.retain(|batch| {
let is_obsolete = windows
.get(&batch.data_type)
.map(|window| batch.age() > *window)
.unwrap_or(false);
if is_obsolete {
bytes_removed += batch.size();
}
!is_obsolete
});
let removed = before_len - batches.len();
filtered += removed;
}
self.total_bytes = self.total_bytes.saturating_sub(bytes_removed);
self.obsolete_filtered += filtered;
filtered
}
pub async fn recover_from_partition(&mut self) -> Result<RecoveryIterator<'_>> {
self.recovery_in_progress = true;
self.apply_obsolescence_filter();
Ok(RecoveryIterator::new(self))
}
pub fn next_recovery_batch(&mut self) -> Option<UpdateBatch> {
for class in QoSClass::all_by_priority() {
if let Some(batches) = self.queued_updates.get_mut(class) {
if !batches.is_empty() {
let batch = batches.remove(0);
self.total_bytes = self.total_bytes.saturating_sub(batch.size());
return Some(batch);
}
}
}
None
}
pub fn complete_recovery(&mut self) {
self.recovery_in_progress = false;
}
pub fn is_recovering(&self) -> bool {
self.recovery_in_progress
}
pub fn total_bytes_queued(&self) -> usize {
self.total_bytes
}
pub fn queued_count_by_class(&self, class: QoSClass) -> usize {
self.queued_updates
.get(&class)
.map(|v| v.len())
.unwrap_or(0)
}
pub fn total_queued(&self) -> usize {
self.queued_updates.values().map(|v| v.len()).sum()
}
pub fn obsolete_filtered_count(&self) -> usize {
self.obsolete_filtered
}
pub fn stats(&self) -> RecoveryStats {
let mut by_class = HashMap::new();
for class in QoSClass::all_by_priority() {
by_class.insert(*class, self.queued_count_by_class(*class));
}
RecoveryStats {
total_queued: self.total_queued(),
total_bytes: self.total_bytes,
by_class,
obsolete_filtered: self.obsolete_filtered,
recovery_in_progress: self.recovery_in_progress,
}
}
pub fn clear(&mut self) {
self.queued_updates.clear();
self.total_bytes = 0;
}
}
impl Default for SyncRecovery {
fn default() -> Self {
Self::default_military()
}
}
/// Draining view over a [`SyncRecovery`] queue, created by
/// [`SyncRecovery::recover_from_partition`].
///
/// Batches come out in `QoSClass::all_by_priority()` order, earliest-queued
/// first within a class. The queue stays mutably borrowed for the iterator's
/// whole lifetime.
pub struct RecoveryIterator<'a> {
    recovery: &'a mut SyncRecovery,
}

impl<'a> RecoveryIterator<'a> {
    fn new(recovery: &'a mut SyncRecovery) -> Self {
        RecoveryIterator { recovery }
    }

    /// Pops the next batch, but only if it is at most `max_bytes` long.
    ///
    /// Returns `None` both when the queue is empty and when the head batch is
    /// too large. In the latter case nothing is removed, so an oversized head
    /// batch keeps blocking every batch behind it until a larger budget is offered.
    pub fn next_with_limit(&mut self, max_bytes: usize) -> Option<UpdateBatch> {
        match self.peek_size() {
            Some(head_size) if head_size <= max_bytes => self.recovery.next_recovery_batch(),
            _ => None,
        }
    }

    /// Size in bytes of the batch the next pop would return, or `None` if the
    /// queue is empty.
    pub fn peek_size(&self) -> Option<usize> {
        let queues = &self.recovery.queued_updates;
        for class in QoSClass::all_by_priority() {
            if let Some(head) = queues.get(class).and_then(|batches| batches.first()) {
                return Some(head.size());
            }
        }
        None
    }

    /// Whether any batch is still queued.
    pub fn has_more(&self) -> bool {
        self.recovery.total_queued() != 0
    }
}
/// Yields batches exactly as [`SyncRecovery::next_recovery_batch`] does, with no
/// byte budget (use `RecoveryIterator::next_with_limit` for that).
impl Iterator for RecoveryIterator<'_> {
    type Item = UpdateBatch;

    fn next(&mut self) -> Option<Self::Item> {
        self.recovery.next_recovery_batch()
    }

    /// Reports the number of queued batches as an exact bound, which lets
    /// `collect`/`extend` preallocate.
    ///
    /// Every `next()` pops exactly one batch, and the iterator holds the only
    /// (mutable) borrow of the queue. This assumes `QoSClass::all_by_priority()`
    /// lists every class in use, the same assumption `has_more` makes.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.recovery.total_queued();
        (remaining, Some(remaining))
    }
}
/// Point-in-time snapshot of a [`SyncRecovery`] queue, as returned by
/// `SyncRecovery::stats`.
#[derive(Debug, Clone)]
pub struct RecoveryStats {
/// Number of batches queued across all classes.
pub total_queued: usize,
/// Total payload bytes queued across all classes.
pub total_bytes: usize,
/// Queued batch count for each class in `QoSClass::all_by_priority()`;
/// a class with nothing queued maps to 0.
pub by_class: HashMap<QoSClass, usize>,
/// Cumulative number of batches dropped by the obsolescence filter.
pub obsolete_filtered: usize,
/// Whether a recovery pass was marked in progress when the snapshot was taken.
pub recovery_in_progress: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns an `Instant` `age` in the past, or `None` if the monotonic clock
    /// cannot represent it (e.g. shortly after boot on some platforms).
    /// Plain `Instant - Duration` panics in that case, which made aged-batch
    /// tests flaky on young hosts.
    fn instant_ago(age: Duration) -> Option<Instant> {
        Instant::now().checked_sub(age)
    }

    #[test]
    fn test_update_batch_creation() {
        let batch = UpdateBatch::new(
            1,
            vec![1, 2, 3],
            DataType::ContactReport,
            QoSClass::Critical,
        );
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size(), 3);
        assert_eq!(batch.qos_class, QoSClass::Critical);
    }

    #[test]
    fn test_recovery_creation() {
        let recovery = SyncRecovery::new();
        assert_eq!(recovery.total_queued(), 0);
        assert_eq!(recovery.total_bytes_queued(), 0);
        assert!(!recovery.is_recovering());
    }

    #[test]
    fn test_obsolescence_windows() {
        let recovery = SyncRecovery::default_military();
        assert!(recovery.is_obsolete(DataType::PositionUpdate, Duration::from_secs(400)));
        assert!(!recovery.is_obsolete(DataType::PositionUpdate, Duration::from_secs(200)));
        assert!(!recovery.is_obsolete(DataType::ContactReport, Duration::from_secs(86400)));
    }

    #[test]
    fn test_queue_update() {
        let mut recovery = SyncRecovery::new();
        let id = recovery.queue_update(vec![1, 2, 3], DataType::HealthStatus, QoSClass::Normal);
        assert_eq!(id, 0);
        assert_eq!(recovery.total_queued(), 1);
        assert_eq!(recovery.total_bytes_queued(), 3);
        assert_eq!(recovery.queued_count_by_class(QoSClass::Normal), 1);
    }

    #[test]
    fn test_priority_ordering() {
        let mut recovery = SyncRecovery::new();
        recovery.queue_update(vec![5], DataType::DebugLog, QoSClass::Bulk);
        recovery.queue_update(vec![1], DataType::ContactReport, QoSClass::Critical);
        recovery.queue_update(vec![3], DataType::HealthStatus, QoSClass::Normal);
        let batch1 = recovery.next_recovery_batch().unwrap();
        assert_eq!(batch1.qos_class, QoSClass::Critical);
        let batch2 = recovery.next_recovery_batch().unwrap();
        assert_eq!(batch2.qos_class, QoSClass::Normal);
        let batch3 = recovery.next_recovery_batch().unwrap();
        assert_eq!(batch3.qos_class, QoSClass::Bulk);
    }

    #[test]
    fn test_obsolescence_filter() {
        let mut recovery = SyncRecovery::default_military();
        let old_time = match instant_ago(Duration::from_secs(600)) {
            Some(t) => t,
            // The clock cannot go back far enough to fabricate an aged batch.
            None => return,
        };
        recovery.queue_update_with_time(
            vec![1],
            DataType::PositionUpdate,
            QoSClass::Low,
            old_time,
        );
        recovery.queue_update(vec![2], DataType::ContactReport, QoSClass::Critical);
        assert_eq!(recovery.total_queued(), 2);
        let filtered = recovery.apply_obsolescence_filter();
        assert_eq!(filtered, 1);
        assert_eq!(recovery.total_queued(), 1);
        let batch = recovery.next_recovery_batch().unwrap();
        assert_eq!(batch.data_type, DataType::ContactReport);
    }

    #[test]
    fn test_obsolescence_filter_updates_byte_count() {
        let mut recovery = SyncRecovery::default_military();
        let old_time = match instant_ago(Duration::from_secs(600)) {
            Some(t) => t,
            None => return,
        };
        recovery.queue_update_with_time(
            vec![0; 40],
            DataType::PositionUpdate,
            QoSClass::Low,
            old_time,
        );
        recovery.queue_update(vec![0; 10], DataType::ContactReport, QoSClass::Critical);
        assert_eq!(recovery.total_bytes_queued(), 50);
        recovery.apply_obsolescence_filter();
        assert_eq!(recovery.total_bytes_queued(), 10);
    }

    #[test]
    fn test_stats() {
        let mut recovery = SyncRecovery::new();
        recovery.queue_update(vec![0; 100], DataType::ContactReport, QoSClass::Critical);
        recovery.queue_update(vec![0; 200], DataType::HealthStatus, QoSClass::Normal);
        recovery.queue_update(vec![0; 50], DataType::DebugLog, QoSClass::Bulk);
        let stats = recovery.stats();
        assert_eq!(stats.total_queued, 3);
        assert_eq!(stats.total_bytes, 350);
        assert_eq!(*stats.by_class.get(&QoSClass::Critical).unwrap(), 1);
        assert_eq!(*stats.by_class.get(&QoSClass::Normal).unwrap(), 1);
        assert_eq!(*stats.by_class.get(&QoSClass::Bulk).unwrap(), 1);
    }

    #[tokio::test]
    async fn test_recover_from_partition() {
        let mut recovery = SyncRecovery::default_military();
        recovery.queue_update(vec![1], DataType::ContactReport, QoSClass::Critical);
        recovery.queue_update(vec![2], DataType::HealthStatus, QoSClass::Normal);
        let mut iter = recovery.recover_from_partition().await.unwrap();
        assert!(iter.has_more());
        let batch1 = iter.next();
        assert!(batch1.is_some());
        assert_eq!(batch1.unwrap().qos_class, QoSClass::Critical);
        let batch2 = iter.next();
        assert!(batch2.is_some());
        assert!(!iter.has_more());
        assert!(iter.next().is_none());
    }

    #[tokio::test]
    async fn test_next_with_limit_respects_budget() {
        let mut recovery = SyncRecovery::new();
        recovery.queue_update(vec![0; 100], DataType::ContactReport, QoSClass::Critical);
        recovery.queue_update(vec![0; 10], DataType::HealthStatus, QoSClass::Normal);
        let mut iter = recovery.recover_from_partition().await.unwrap();
        assert_eq!(iter.peek_size(), Some(100));
        // An oversized head is not consumed and blocks the smaller batch behind it.
        assert!(iter.next_with_limit(50).is_none());
        assert_eq!(iter.peek_size(), Some(100));
        let first = iter.next_with_limit(100).unwrap();
        assert_eq!(first.qos_class, QoSClass::Critical);
        let second = iter.next_with_limit(50).unwrap();
        assert_eq!(second.qos_class, QoSClass::Normal);
        assert!(iter.next_with_limit(usize::MAX).is_none());
        assert!(!iter.has_more());
    }

    #[test]
    fn test_clear() {
        let mut recovery = SyncRecovery::new();
        recovery.queue_update(vec![0; 100], DataType::ContactReport, QoSClass::Critical);
        recovery.queue_update(vec![0; 200], DataType::HealthStatus, QoSClass::Normal);
        recovery.clear();
        assert_eq!(recovery.total_queued(), 0);
        assert_eq!(recovery.total_bytes_queued(), 0);
    }

    #[test]
    fn test_custom_obsolescence() {
        let mut recovery = SyncRecovery::new();
        recovery.set_obsolescence(DataType::HealthStatus, Duration::from_secs(60));
        assert!(recovery.is_obsolete(DataType::HealthStatus, Duration::from_secs(120)));
        assert!(!recovery.is_obsolete(DataType::HealthStatus, Duration::from_secs(30)));
    }

    #[test]
    fn test_batch_age() {
        let batch = UpdateBatch::new(1, vec![1], DataType::ContactReport, QoSClass::Critical);
        assert!(batch.age() < Duration::from_secs(1));
    }

    #[test]
    fn test_multiple_batches_same_class() {
        let mut recovery = SyncRecovery::new();
        recovery.queue_update(vec![1], DataType::ContactReport, QoSClass::Critical);
        recovery.queue_update(vec![2], DataType::EmergencyAlert, QoSClass::Critical);
        recovery.queue_update(vec![3], DataType::AbortCommand, QoSClass::Critical);
        let batch1 = recovery.next_recovery_batch().unwrap();
        assert_eq!(batch1.data, vec![1]);
        let batch2 = recovery.next_recovery_batch().unwrap();
        assert_eq!(batch2.data, vec![2]);
        let batch3 = recovery.next_recovery_batch().unwrap();
        assert_eq!(batch3.data, vec![3]);
    }

    #[test]
    fn test_get_obsolescence() {
        let recovery = SyncRecovery::default_military();
        assert!(recovery
            .get_obsolescence(&DataType::PositionUpdate)
            .is_some());
        assert!(recovery
            .get_obsolescence(&DataType::ContactReport)
            .is_none());
    }

    #[test]
    fn test_obsolete_filtered_count() {
        let mut recovery = SyncRecovery::default_military();
        let old_time = match instant_ago(Duration::from_secs(600)) {
            Some(t) => t,
            None => return,
        };
        recovery.queue_update_with_time(
            vec![1],
            DataType::PositionUpdate,
            QoSClass::Low,
            old_time,
        );
        assert_eq!(recovery.obsolete_filtered_count(), 0);
        recovery.apply_obsolescence_filter();
        assert_eq!(recovery.obsolete_filtered_count(), 1);
    }
}