use std::collections::{BTreeMap, BTreeSet};
use serde::{Deserialize, Serialize};
use crate::ProtocolError;
/// Delivery guarantee attached to an announced object.
///
/// Serialized in snake_case: `best_effort`, `eventual`, `receipt`, `quorum`. The receipt
/// threshold each profile implies is computed by the module's `required_receipts` helper.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum DeliveryProfile {
    /// No receipts are required; the object counts as delivered as soon as it is announced.
    BestEffort,
    /// Every recipient on the roster must acknowledge (at least one receipt is always needed).
    Eventual,
    /// A single receipt suffices.
    Receipt,
    /// A threshold taken from a `QuorumPolicy`, falling back to "every recipient" without one.
    Quorum,
}
/// Receipt threshold used with [`DeliveryProfile::Quorum`].
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct QuorumPolicy {
    /// Threshold exactly as configured / received on the wire (`requiredReceipts`).
    /// Prefer [`QuorumPolicy::required_receipts`], which never reports less than one.
    pub required_receipts: u32,
}

impl QuorumPolicy {
    /// Returns the effective threshold: the configured value, with 0 raised to 1.
    pub fn required_receipts(&self) -> u32 {
        std::cmp::max(self.required_receipts, 1)
    }
}
/// Announcement that `author_id` published object number `sequence` on `stream_id`.
///
/// Encoded with camelCase field names; `quorum` and `object_digest` are omitted from the
/// serialized form when they are `None`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DeliveryAnnouncement {
    /// Stream the object belongs to.
    pub stream_id: String,
    /// Author that produced the object.
    pub author_id: String,
    /// Position of the object within the author's sequence on this stream.
    pub sequence: u64,
    /// Location of the object payload (presumably built with `mesh_object_path` — confirm).
    pub object_path: String,
    /// Delivery guarantee requested for this object.
    pub delivery_profile: DeliveryProfile,
    /// Optional quorum threshold, meaningful alongside [`DeliveryProfile::Quorum`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quorum: Option<QuorumPolicy>,
    /// Optional digest of the object payload (format not fixed by this module).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub object_digest: Option<String>,
}
/// Cumulative acknowledgement from one recipient for a stream/author scope.
///
/// `DeliveryTracker::state_for` treats a receipt as covering every sequence less than or equal
/// to `delivered_through`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DeliveryReceipt {
    /// Stream being acknowledged.
    pub stream_id: String,
    /// Author whose sequence is being acknowledged.
    pub author_id: String,
    /// Recipient sending the acknowledgement.
    pub recipient_id: String,
    /// Highest sequence the recipient reports as delivered.
    pub delivered_through: u64,
}
/// Inclusive span of sequence numbers, `start..=end`.
///
/// A range whose `start` is greater than its `end` contains nothing.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SequenceRange {
    /// First sequence in the range.
    pub start: u64,
    /// Last sequence in the range (inclusive).
    pub end: u64,
}

impl SequenceRange {
    /// Returns `true` when `sequence` lies within `start..=end`.
    pub fn contains(&self, sequence: u64) -> bool {
        (self.start..=self.end).contains(&sequence)
    }
}
/// Compact summary of which sequences a peer holds for one stream/author scope.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DeliveryManifest {
    /// Stream the manifest describes.
    pub stream_id: String,
    /// Author whose sequences are listed.
    pub author_id: String,
    /// Held sequences as inclusive ranges. `manifest_from_sequences` emits them sorted and
    /// coalesced; manifests decoded from the wire may not be.
    pub ranges: Vec<SequenceRange>,
}
/// An announced object that a remote peer is missing and should be re-sent.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RepairTarget {
    /// Sequence of the missing object.
    pub sequence: u64,
    /// Object path copied from the corresponding announcement.
    pub object_path: String,
}
/// Snapshot of how far one announced sequence has progressed toward delivery.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeliveryState {
    /// The announced sequence this snapshot describes.
    pub sequence: u64,
    /// Receipts covering `sequence`, as counted by `DeliveryTracker::state_for`.
    pub receipt_count: u32,
    /// Receipts needed under the delivery profile that was evaluated.
    pub required_receipts: u32,
    /// `true` once `receipt_count >= required_receipts`.
    pub is_delivered: bool,
}
/// Tracks which announced sequences of one stream/author scope have been acknowledged by the
/// recipients on its roster.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeliveryTracker {
    stream_id: String,
    author_id: String,
    /// Roster whose receipts count toward delivery. When the roster is empty there is no
    /// membership to check, so receipts from any recipient are counted.
    recipients: BTreeSet<String>,
    /// Sequences recorded by `observe_announcement`; only these have a delivery state.
    announced_sequences: BTreeSet<u64>,
    /// Highest `delivered_through` reported by each recipient id (never decreases).
    highest_receipt_by_recipient: BTreeMap<String, u64>,
}

impl DeliveryTracker {
    /// Creates a tracker for `stream_id`/`author_id` whose delivery targets are `recipients`.
    ///
    /// Duplicate recipient ids collapse into a single roster entry.
    pub fn new(
        stream_id: impl Into<String>,
        author_id: impl Into<String>,
        recipients: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        Self {
            stream_id: stream_id.into(),
            author_id: author_id.into(),
            recipients: recipients.into_iter().map(Into::into).collect(),
            announced_sequences: BTreeSet::new(),
            highest_receipt_by_recipient: BTreeMap::new(),
        }
    }

    /// Records that `announcement.sequence` exists in this tracker's scope.
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::InvalidEnvelope` if the announcement belongs to a different
    /// stream/author; the tracker is left unchanged in that case.
    pub fn observe_announcement(
        &mut self,
        announcement: &DeliveryAnnouncement,
    ) -> Result<(), ProtocolError> {
        validate_scope(
            &self.stream_id,
            &self.author_id,
            &announcement.stream_id,
            &announcement.author_id,
        )?;
        self.announced_sequences.insert(announcement.sequence);
        Ok(())
    }

    /// Records a cumulative receipt: `recipient_id` holds every sequence through
    /// `delivered_through`.
    ///
    /// A receipt lower than one already seen for the same recipient does not lower its
    /// watermark. Receipts from ids outside a non-empty roster are still stored (and visible via
    /// [`Self::highest_receipt_for`]) but do not count toward delivery in [`Self::state_for`].
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::InvalidEnvelope` if the receipt belongs to a different
    /// stream/author; the tracker is left unchanged in that case.
    pub fn observe_receipt(&mut self, receipt: &DeliveryReceipt) -> Result<(), ProtocolError> {
        validate_scope(
            &self.stream_id,
            &self.author_id,
            &receipt.stream_id,
            &receipt.author_id,
        )?;
        self.highest_receipt_by_recipient
            .entry(receipt.recipient_id.clone())
            .and_modify(|highest| *highest = (*highest).max(receipt.delivered_through))
            .or_insert(receipt.delivered_through);
        Ok(())
    }

    /// Evaluates delivery of `sequence` under `profile` (and `quorum`, for quorum profiles).
    ///
    /// Returns `None` if `sequence` was never announced to this tracker. Otherwise
    /// `receipt_count` is the number of counted recipients (see [`Self::observe_receipt`])
    /// whose watermark is at least `sequence`. The requirement is derived from the roster size.
    pub fn state_for(
        &self,
        sequence: u64,
        profile: &DeliveryProfile,
        quorum: Option<&QuorumPolicy>,
    ) -> Option<DeliveryState> {
        if !self.announced_sequences.contains(&sequence) {
            return None;
        }
        // Only roster members may satisfy the requirement; otherwise receipts from unknown
        // ids could mark a sequence delivered that no real recipient acknowledged.
        let receipt_count = self
            .highest_receipt_by_recipient
            .iter()
            .filter(|&(recipient_id, &through)| {
                through >= sequence && self.is_counted_recipient(recipient_id)
            })
            .count() as u32;
        let required_receipts = required_receipts(profile, quorum, self.recipients.len() as u32);
        Some(DeliveryState {
            sequence,
            receipt_count,
            required_receipts,
            is_delivered: receipt_count >= required_receipts,
        })
    }

    /// Returns the highest `delivered_through` seen from `recipient_id`, if any receipt arrived.
    pub fn highest_receipt_for(&self, recipient_id: &str) -> Option<u64> {
        self.highest_receipt_by_recipient.get(recipient_id).copied()
    }

    /// Whether receipts from `recipient_id` count toward delivery: roster members always do,
    /// and everyone does when the tracker was built with an empty roster.
    fn is_counted_recipient(&self, recipient_id: &str) -> bool {
        self.recipients.is_empty() || self.recipients.contains(recipient_id)
    }
}
/// Builds the mesh path of one object:
/// `{namespace}/stream/{stream_id}/obj/{author_id}/{sequence}`.
///
/// Components are inserted verbatim, without escaping.
// NOTE(review): a component containing `/` would make the path ambiguous; confirm callers
// never pass one.
pub fn mesh_object_path(
    namespace: &str,
    stream_id: &str,
    author_id: &str,
    sequence: u64,
) -> String {
    let sequence = sequence.to_string();
    [namespace, "stream", stream_id, "obj", author_id, sequence.as_str()].join("/")
}
pub fn encode_announcement(announcement: &DeliveryAnnouncement) -> Result<Vec<u8>, ProtocolError> {
serde_json::to_vec(announcement).map_err(|e| ProtocolError::InvalidEncoding(e.to_string()))
}
pub fn decode_announcement(bytes: &[u8]) -> Result<DeliveryAnnouncement, ProtocolError> {
serde_json::from_slice(bytes).map_err(|e| ProtocolError::InvalidEncoding(e.to_string()))
}
pub fn encode_receipt(receipt: &DeliveryReceipt) -> Result<Vec<u8>, ProtocolError> {
serde_json::to_vec(receipt).map_err(|e| ProtocolError::InvalidEncoding(e.to_string()))
}
pub fn decode_receipt(bytes: &[u8]) -> Result<DeliveryReceipt, ProtocolError> {
serde_json::from_slice(bytes).map_err(|e| ProtocolError::InvalidEncoding(e.to_string()))
}
pub fn encode_manifest(manifest: &DeliveryManifest) -> Result<Vec<u8>, ProtocolError> {
serde_json::to_vec(manifest).map_err(|e| ProtocolError::InvalidEncoding(e.to_string()))
}
pub fn decode_manifest(bytes: &[u8]) -> Result<DeliveryManifest, ProtocolError> {
serde_json::from_slice(bytes).map_err(|e| ProtocolError::InvalidEncoding(e.to_string()))
}
pub fn manifest_from_sequences(
stream_id: impl Into<String>,
author_id: impl Into<String>,
sequences: impl IntoIterator<Item = u64>,
) -> DeliveryManifest {
let mut sequences = sequences.into_iter().collect::<Vec<_>>();
sequences.sort_unstable();
sequences.dedup();
let mut ranges = Vec::new();
let mut iter = sequences.into_iter();
if let Some(first) = iter.next() {
let mut start = first;
let mut end = first;
for sequence in iter {
if sequence == end + 1 {
end = sequence;
} else {
ranges.push(SequenceRange { start, end });
start = sequence;
end = sequence;
}
}
ranges.push(SequenceRange { start, end });
}
DeliveryManifest {
stream_id: stream_id.into(),
author_id: author_id.into(),
ranges,
}
}
pub fn diff_manifest(
local: &DeliveryManifest,
remote: &DeliveryManifest,
) -> Result<Vec<SequenceRange>, ProtocolError> {
validate_scope(
&local.stream_id,
&local.author_id,
&remote.stream_id,
&remote.author_id,
)?;
let local_sequences = expand_ranges(&local.ranges);
let remote_sequences = expand_ranges(&remote.ranges);
let missing = local_sequences
.difference(&remote_sequences)
.copied()
.collect::<Vec<_>>();
Ok(manifest_from_sequences(local.stream_id.clone(), local.author_id.clone(), missing).ranges)
}
pub fn selective_repair_targets(
announcements: &[DeliveryAnnouncement],
remote: &DeliveryManifest,
) -> Result<Vec<RepairTarget>, ProtocolError> {
let Some(first) = announcements.first() else {
return Ok(Vec::new());
};
validate_scope(
&first.stream_id,
&first.author_id,
&remote.stream_id,
&remote.author_id,
)?;
if announcements.iter().any(|announcement| {
announcement.stream_id != first.stream_id || announcement.author_id != first.author_id
}) {
return Err(ProtocolError::InvalidEnvelope(
"announcements must all share the same stream/author".into(),
));
}
let remote_sequences = expand_ranges(&remote.ranges);
let mut targets = announcements
.iter()
.filter(|announcement| !remote_sequences.contains(&announcement.sequence))
.map(|announcement| RepairTarget {
sequence: announcement.sequence,
object_path: announcement.object_path.clone(),
})
.collect::<Vec<_>>();
targets.sort_by_key(|target| target.sequence);
Ok(targets)
}
/// Number of receipts `profile` demands for a roster of `recipient_count` recipients.
///
/// - `BestEffort`: 0
/// - `Receipt`: 1
/// - `Eventual`: every recipient, but never fewer than 1
/// - `Quorum`: the policy's effective threshold, or "every recipient" (min 1) when no policy
///   is supplied
fn required_receipts(
    profile: &DeliveryProfile,
    quorum: Option<&QuorumPolicy>,
    recipient_count: u32,
) -> u32 {
    // "All recipients" still means at least one receipt, even for an empty roster.
    let everyone = recipient_count.max(1);
    match (profile, quorum) {
        (DeliveryProfile::BestEffort, _) => 0,
        (DeliveryProfile::Receipt, _) => 1,
        (DeliveryProfile::Eventual, _) => everyone,
        (DeliveryProfile::Quorum, Some(policy)) => policy.required_receipts(),
        (DeliveryProfile::Quorum, None) => everyone,
    }
}
/// Checks that an incoming message's stream and author match the expected scope.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` if either the stream id or the author id differs.
fn validate_scope(
    expected_stream_id: &str,
    expected_author_id: &str,
    actual_stream_id: &str,
    actual_author_id: &str,
) -> Result<(), ProtocolError> {
    let same_stream = expected_stream_id == actual_stream_id;
    let same_author = expected_author_id == actual_author_id;
    if same_stream && same_author {
        Ok(())
    } else {
        Err(ProtocolError::InvalidEnvelope(
            "stream_id/author_id mismatch".into(),
        ))
    }
}
/// Expands inclusive ranges into the set of every sequence they cover.
///
/// Overlapping ranges are merged by the set; inverted ranges (`start > end`) contribute
/// nothing. This allocates one entry per covered sequence, so the cost grows with the ranges'
/// total width. Only use it on ranges whose width is known to be small.
fn expand_ranges(ranges: &[SequenceRange]) -> BTreeSet<u64> {
    ranges
        .iter()
        .flat_map(|range| range.start..=range.end)
        .collect()
}