use bytes::Bytes;
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, instrument, trace, warn};
use crate::binary_protocol::{ClientMessage, PayloadType};
use crate::errors::{Error, ProtocolError, Result};
use crate::protocol::MessageType;
/// Initial RTT estimate, in milliseconds, used before any RTT sample exists.
/// The initial retransmission timeout is three times this value.
pub const DEFAULT_RTT_MS: u64 = 100;
/// Multiplier applied to the RTT variation when computing the retransmission
/// timeout (`rto = srtt + RTT_VARIATION_FACTOR * rttvar`).
pub const RTT_VARIATION_FACTOR: f64 = 4.0;
/// Lower bound, in milliseconds, on the computed retransmission timeout.
pub const MIN_TIMEOUT_MS: u64 = 50;
/// Upper bound, in milliseconds, on the computed retransmission timeout.
pub const MAX_TIMEOUT_MS: u64 = 30_000;
/// Maximum number of transmissions of a message tracked by `AckTracker`,
/// counting the initial send (a tracked message starts with one attempt).
pub const MAX_RETRANSMIT_ATTEMPTS: u32 = 5;
/// JSON payload of an acknowledgement message, identifying the message being acknowledged.
///
/// Field names are serialized in PascalCase through the `serde(rename)` attributes;
/// field order here is also the serialized order, so do not reorder fields.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AcknowledgeContent {
/// Message type of the acknowledged message (`AckTracker::create_ack` fills it
/// from `MessageType::as_str`).
#[serde(rename = "AcknowledgedMessageType")]
pub acknowledged_message_type: String,
/// Message ID of the acknowledged message, rendered as a string.
#[serde(rename = "AcknowledgedMessageId")]
pub acknowledged_message_id: String,
/// Sequence number of the acknowledged message; `AckTracker::process_ack` uses it
/// to find the entry in its pending set.
#[serde(rename = "AcknowledgedMessageSequenceNumber")]
pub acknowledged_message_sequence_number: i64,
/// Flag passed by the caller of `AckTracker::create_ack`.
/// NOTE(review): presumably marks ordered stream data; confirm against the protocol spec.
#[serde(rename = "IsSequentialMessage")]
pub is_sequential_message: bool,
}
/// A message that `AckTracker` has sent and is still waiting to see acknowledged.
#[derive(Debug, Clone)]
pub struct PendingMessage {
/// The original message.
pub message: ClientMessage,
/// Serialized bytes; retransmissions resend exactly these bytes.
pub serialized: Bytes,
/// Time of the first transmission; RTT is measured from here.
pub sent_at: Instant,
/// Number of transmissions so far (starts at 1 when tracked).
pub attempts: u32,
/// Time of the most recent transmission; retransmission timeouts are measured from here.
pub last_sent: Instant,
}
/// Smoothed round-trip-time estimator using RFC 6298-style update rules
/// (gain 1/8 for the smoothed RTT, 1/4 for the variation). The resulting
/// retransmission timeout is clamped to `[MIN_TIMEOUT_MS, MAX_TIMEOUT_MS]`.
#[derive(Debug, Clone)]
pub struct RttStats {
    /// Smoothed RTT estimate.
    pub srtt: Duration,
    /// RTT variation estimate.
    pub rttvar: Duration,
    /// Current retransmission timeout.
    pub rto: Duration,
    /// Number of samples folded in so far.
    pub sample_count: u64,
    /// Smallest sample seen; `Duration::from_secs(u64::MAX)` before the first sample.
    pub min_rtt: Duration,
    /// Largest sample seen; `Duration::ZERO` before the first sample.
    pub max_rtt: Duration,
}

impl Default for RttStats {
    /// Starts from a `DEFAULT_RTT_MS` guess: `srtt` equals it, `rttvar` is half of it,
    /// and `rto` is three times it. `min_rtt`/`max_rtt` are set so that the first
    /// sample replaces both.
    fn default() -> Self {
        let initial = Duration::from_millis(DEFAULT_RTT_MS);
        Self {
            sample_count: 0,
            srtt: initial,
            rttvar: initial / 2,
            rto: Duration::from_millis(DEFAULT_RTT_MS * 3),
            min_rtt: Duration::from_secs(u64::MAX),
            max_rtt: Duration::ZERO,
        }
    }
}

impl RttStats {
    /// Folds one RTT `sample` into the estimator and recomputes `rto`.
    pub fn update(&mut self, sample: Duration) {
        self.sample_count += 1;
        self.min_rtt = self.min_rtt.min(sample);
        self.max_rtt = self.max_rtt.max(sample);

        match self.sample_count {
            // The first measurement seeds both estimators directly.
            1 => {
                self.srtt = sample;
                self.rttvar = sample / 2;
            }
            // Later measurements: the variation must be computed against the
            // *previous* smoothed RTT, so it is updated first.
            _ => {
                let deviation = self.srtt.abs_diff(sample);
                self.rttvar = (self.rttvar * 3 + deviation) / 4;
                self.srtt = (self.srtt * 7 + sample) / 8;
            }
        }

        let floor = Duration::from_millis(MIN_TIMEOUT_MS);
        let ceiling = Duration::from_millis(MAX_TIMEOUT_MS);
        let unclamped = self.srtt + self.rttvar.mul_f64(RTT_VARIATION_FACTOR);
        self.rto = unclamped.clamp(floor, ceiling);
    }

    /// Returns the current retransmission timeout.
    pub fn timeout(&self) -> Duration {
        self.rto
    }
}
/// Tracks sent-but-unacknowledged messages and maintains the RTT estimate
/// that drives their retransmission.
pub struct AckTracker {
    /// Unacknowledged messages keyed by sequence number.
    pending: RwLock<BTreeMap<i64, PendingMessage>>,
    /// RTT estimator; its RTO decides when a pending message is retransmitted.
    rtt_stats: RwLock<RttStats>,
    /// Next outgoing sequence number to hand out.
    next_sequence: AtomicI64,
    /// Next in-order incoming sequence number expected from the peer.
    expected_sequence: AtomicI64,
    /// Flow-control limit on the number of pending messages.
    max_pending: usize,
}

impl AckTracker {
    /// Creates an empty tracker that allows at most `max_pending` unacknowledged messages.
    pub fn new(max_pending: usize) -> Self {
        Self {
            pending: RwLock::new(BTreeMap::new()),
            rtt_stats: RwLock::new(RttStats::default()),
            next_sequence: AtomicI64::new(0),
            expected_sequence: AtomicI64::new(0),
            max_pending,
        }
    }

    /// Returns the next outgoing sequence number and advances the counter.
    pub fn next_sequence(&self) -> i64 {
        self.next_sequence.fetch_add(1, Ordering::SeqCst)
    }

    /// Returns the next in-order incoming sequence number expected from the peer.
    pub fn expected_sequence(&self) -> i64 {
        self.expected_sequence.load(Ordering::SeqCst)
    }

    /// Records `message` (with its already-serialized bytes) as sent and awaiting an ACK.
    ///
    /// # Errors
    /// Returns `ProtocolError::InvalidMessage` when `max_pending` messages are
    /// already outstanding (flow control).
    #[instrument(skip(self, message, serialized))]
    pub async fn track_sent(&self, message: ClientMessage, serialized: Bytes) -> Result<()> {
        let mut pending = self.pending.write().await;
        if pending.len() >= self.max_pending {
            warn!(
                pending = pending.len(),
                max = self.max_pending,
                "Flow control: too many pending messages"
            );
            return Err(Error::Protocol(ProtocolError::InvalidMessage(
                "Too many pending messages".to_string(),
            )));
        }
        let seq = message.sequence_number;
        let now = Instant::now();
        pending.insert(
            seq,
            PendingMessage {
                message,
                serialized,
                sent_at: now,
                attempts: 1,
                last_sent: now,
            },
        );
        trace!(sequence = seq, pending = pending.len(), "Tracking message");
        Ok(())
    }

    /// Removes the acknowledged message from the pending set and returns the time
    /// elapsed since its first transmission.
    ///
    /// Only ACKs for messages sent exactly once update the RTT estimate (Karn's
    /// algorithm): the ACK of a retransmitted message cannot be matched to a specific
    /// transmission. Unknown or duplicate ACKs return `Duration::ZERO`.
    #[instrument(skip(self))]
    pub async fn process_ack(&self, ack: &AcknowledgeContent) -> Result<Duration> {
        let seq = ack.acknowledged_message_sequence_number;
        // The pending-map guard is a temporary, released at the end of this
        // statement, so the RTT-stats lock below is never held together with it.
        let removed = self.pending.write().await.remove(&seq);
        let Some(msg) = removed else {
            trace!(sequence = seq, "Ack for unknown or already-acked message");
            return Ok(Duration::ZERO);
        };
        let rtt = msg.sent_at.elapsed();
        if msg.attempts == 1 {
            let mut stats = self.rtt_stats.write().await;
            stats.update(rtt);
            debug!(
                sequence = seq,
                rtt_ms = rtt.as_millis(),
                srtt_ms = stats.srtt.as_millis(),
                rto_ms = stats.rto.as_millis(),
                "Acknowledgment received"
            );
        }
        Ok(rtt)
    }

    /// Advances the expected incoming sequence number if `received_seq` is exactly
    /// the value currently expected; otherwise does nothing.
    ///
    /// A single compare-and-swap is used so concurrent callers cannot interleave a
    /// stale load/store pair and move the counter backwards.
    pub fn update_expected_sequence(&self, received_seq: i64) {
        // A failed exchange means `received_seq` was not the expected value
        // (out of order, duplicate, or already advanced by another caller).
        let _ = self.expected_sequence.compare_exchange(
            received_seq,
            received_seq.wrapping_add(1),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
    }

    /// Returns `(sequence, bytes)` for every pending message whose last transmission
    /// is at least one RTO old and that still has attempts left. Each returned
    /// message's attempt count and `last_sent` are updated.
    ///
    /// Messages that have exhausted `MAX_RETRANSMIT_ATTEMPTS` are logged and skipped;
    /// `prune_failed` removes them.
    #[instrument(skip(self))]
    pub async fn get_retransmit_candidates(&self) -> Vec<(i64, Bytes)> {
        // Read the timeout first and release that lock before taking the pending map.
        let timeout = self.rtt_stats.read().await.timeout();
        let mut pending = self.pending.write().await;
        let mut candidates = Vec::new();
        let now = Instant::now();
        for (seq, msg) in pending.iter_mut() {
            if now.duration_since(msg.last_sent) < timeout {
                continue;
            }
            if msg.attempts >= MAX_RETRANSMIT_ATTEMPTS {
                warn!(
                    sequence = seq,
                    attempts = msg.attempts,
                    "Max retransmit attempts reached"
                );
                continue;
            }
            msg.attempts += 1;
            msg.last_sent = now;
            candidates.push((*seq, msg.serialized.clone()));
            debug!(
                sequence = seq,
                attempt = msg.attempts,
                timeout_ms = timeout.as_millis(),
                "Retransmitting message"
            );
        }
        candidates
    }

    /// Removes pending messages that have given up, and returns their sequence numbers.
    ///
    /// A message has given up when it has used all `MAX_RETRANSMIT_ATTEMPTS`
    /// transmissions *and* its final transmission has gone unacknowledged for at
    /// least the current retransmission timeout. The final retransmission therefore
    /// gets the same timeout window to be acknowledged as every earlier one.
    pub async fn prune_failed(&self) -> Vec<i64> {
        // Same lock order as `get_retransmit_candidates`: stats released before pending is taken.
        let timeout = self.rtt_stats.read().await.timeout();
        let mut pending = self.pending.write().await;
        let now = Instant::now();
        let mut failed = Vec::new();
        pending.retain(|seq, msg| {
            let exhausted = msg.attempts >= MAX_RETRANSMIT_ATTEMPTS
                && now.duration_since(msg.last_sent) >= timeout;
            if exhausted {
                failed.push(*seq);
            }
            !exhausted
        });
        failed
    }

    /// Returns a snapshot of the current RTT statistics.
    pub async fn rtt_stats(&self) -> RttStats {
        self.rtt_stats.read().await.clone()
    }

    /// Returns the number of messages awaiting acknowledgement.
    pub async fn pending_count(&self) -> usize {
        self.pending.read().await.len()
    }

    /// Builds an `Acknowledge` message for `message`, with a JSON `AcknowledgeContent` payload.
    ///
    /// # Errors
    /// Returns `ProtocolError::Framing` if the payload cannot be serialized.
    pub fn create_ack(message: &ClientMessage, is_sequential: bool) -> Result<ClientMessage> {
        use crate::binary_protocol::flags;
        let ack_content = AcknowledgeContent {
            acknowledged_message_type: message.message_type.as_str().to_string(),
            acknowledged_message_id: message.message_id.to_string(),
            acknowledged_message_sequence_number: message.sequence_number,
            is_sequential_message: is_sequential,
        };
        let payload = serde_json::to_vec(&ack_content).map_err(|e| {
            Error::Protocol(ProtocolError::Framing(format!(
                "Failed to serialize ack: {}",
                e
            )))
        })?;
        // The ACK itself uses sequence number 0 and an undefined payload type.
        let mut ack_msg = ClientMessage::new(
            MessageType::Acknowledge,
            0,
            PayloadType::Undefined,
            Bytes::from(payload),
        );
        ack_msg.flags = flags::SYN | flags::FIN;
        Ok(ack_msg)
    }

    /// Decodes the JSON `AcknowledgeContent` carried in `message`'s payload.
    ///
    /// # Errors
    /// Returns `ProtocolError::Framing` if the payload is not a valid acknowledgement.
    pub fn parse_ack(message: &ClientMessage) -> Result<AcknowledgeContent> {
        serde_json::from_slice(&message.payload).map_err(|e| {
            Error::Protocol(ProtocolError::Framing(format!(
                "Failed to parse ack: {}",
                e
            )))
        })
    }
}
// Type-erased async transport callback: takes serialized bytes and returns a boxed,
// `Send` future that resolves to the send result. The callback itself is `Send + Sync`,
// so a `ReliableSender` can be shared across tasks.
type AsyncSendFn = Box<
dyn Fn(Bytes) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
+ Send
+ Sync,
>;
/// Sends messages through a caller-supplied transport function and registers each
/// one with a shared [`AckTracker`], so unacknowledged messages can be retransmitted.
pub struct ReliableSender {
    ack_tracker: Arc<AckTracker>,
    send_fn: AsyncSendFn,
}

impl ReliableSender {
    /// Wraps `send_fn` (any async `Fn(Bytes) -> Result<()>`) behind a type-erased callback.
    pub fn new<F, Fut>(ack_tracker: Arc<AckTracker>, send_fn: F) -> Self
    where
        F: Fn(Bytes) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
    {
        let erased: AsyncSendFn = Box::new(move |data| Box::pin(send_fn(data)));
        Self {
            ack_tracker,
            send_fn: erased,
        }
    }

    /// Serializes `message`, registers it with the tracker, then transmits it.
    ///
    /// # Errors
    /// Propagates serialization errors, flow-control rejection from
    /// `AckTracker::track_sent`, and transport errors from the send function.
    /// If the transport fails, the message stays tracked and is retransmitted later.
    pub async fn send(&self, message: ClientMessage) -> Result<()> {
        let frame = message.serialize()?;
        // Register before transmitting so an ACK can never arrive for an untracked message.
        self.ack_tracker.track_sent(message, frame.clone()).await?;
        (self.send_fn)(frame).await
    }

    /// Runs forever. Every `interval` it resends timed-out messages and then
    /// drops messages that have exhausted their retransmit attempts. Transport
    /// failures are logged, not propagated.
    ///
    /// # Panics
    /// `tokio::time::interval` panics if `interval` is zero.
    pub async fn retransmit_loop(&self, interval: Duration) {
        let mut ticker = tokio::time::interval(interval);
        loop {
            ticker.tick().await;
            for (seq, frame) in self.ack_tracker.get_retransmit_candidates().await {
                if let Err(e) = (self.send_fn)(frame).await {
                    warn!(sequence = seq, error = ?e, "Retransmit failed");
                }
            }
            let failed = self.ack_tracker.prune_failed().await;
            if failed.is_empty() {
                continue;
            }
            warn!(failed = ?failed, "Messages failed after max retries");
        }
    }
}
/// Bounded holding area for incoming messages keyed by sequence number
/// (used for messages that arrive out of order).
pub struct IncomingMessageBuffer {
    messages: RwLock<std::collections::HashMap<i64, BufferedMessage>>,
    capacity: usize,
}

/// An incoming message retained by [`IncomingMessageBuffer`].
#[derive(Debug, Clone)]
pub struct BufferedMessage {
    /// Raw bytes as received.
    pub raw: Bytes,
    /// The decoded message.
    pub message: ClientMessage,
    /// When the message was buffered.
    pub received_at: Instant,
}

impl IncomingMessageBuffer {
    /// Creates an empty buffer holding at most `capacity` distinct sequence numbers.
    pub fn new(capacity: usize) -> Self {
        Self {
            messages: RwLock::new(std::collections::HashMap::new()),
            capacity,
        }
    }

    /// Returns `true` if another distinct sequence number can be buffered.
    pub async fn has_capacity(&self) -> bool {
        self.messages.read().await.len() < self.capacity
    }

    /// Buffers `message` under its sequence number and returns `true` if it is
    /// buffered afterwards.
    ///
    /// If the sequence number is already buffered, the stored copy is replaced, even
    /// when the buffer is full, because this does not grow the buffer. Otherwise a
    /// full buffer rejects the message and returns `false`.
    pub async fn add(&self, message: ClientMessage, raw: Bytes) -> bool {
        let seq = message.sequence_number;
        let mut messages = self.messages.write().await;
        if messages.len() >= self.capacity && !messages.contains_key(&seq) {
            debug!(
                capacity = self.capacity,
                size = messages.len(),
                seq,
                "IncomingMessageBuffer full, dropping out-of-order message"
            );
            return false;
        }
        messages.insert(
            seq,
            BufferedMessage {
                raw,
                message,
                received_at: Instant::now(),
            },
        );
        debug!(
            seq,
            buffer_size = messages.len(),
            "Added to IncomingMessageBuffer"
        );
        true
    }

    /// Returns a clone of the message buffered under `seq`, if any.
    pub async fn get(&self, seq: i64) -> Option<BufferedMessage> {
        self.messages.read().await.get(&seq).cloned()
    }

    /// Removes and returns the message buffered under `seq`, if any.
    pub async fn remove(&self, seq: i64) -> Option<BufferedMessage> {
        self.messages.write().await.remove(&seq)
    }

    /// Returns `true` if a message is buffered under `seq`.
    pub async fn contains(&self, seq: i64) -> bool {
        self.messages.read().await.contains_key(&seq)
    }

    /// Returns the number of buffered messages.
    pub async fn len(&self) -> usize {
        self.messages.read().await.len()
    }

    /// Returns `true` if nothing is buffered.
    pub async fn is_empty(&self) -> bool {
        self.messages.read().await.is_empty()
    }

    /// Returns the buffered sequence numbers in ascending order.
    ///
    /// The backing `HashMap` has no defined iteration order, so the keys are sorted
    /// to give callers a deterministic, delivery-friendly order.
    pub async fn buffered_sequences(&self) -> Vec<i64> {
        let mut seqs: Vec<i64> = self.messages.read().await.keys().copied().collect();
        seqs.sort_unstable();
        seqs
    }
}
/// Bounded, send-ordered buffer of streamed messages awaiting acknowledgement.
///
/// It keeps its own RTT estimate (from ACKs of never-retransmitted messages) and
/// derives the retransmission timeout from that estimate.
pub struct OutgoingMessageBuffer {
    messages: RwLock<std::collections::VecDeque<StreamingMessage>>,
    capacity: usize,
    rtt_stats: RwLock<RttStats>,
    retransmission_timeout: RwLock<Duration>,
}

/// A message held by [`OutgoingMessageBuffer`] until it is acknowledged.
#[derive(Debug, Clone)]
pub struct StreamingMessage {
    /// Bytes to (re)send.
    pub content: Bytes,
    /// Sequence number that ACKs refer to.
    pub sequence_number: i64,
    /// Time of the most recent transmission.
    pub last_sent_time: Instant,
    /// Number of retransmissions so far (0 for a message sent once).
    pub resend_attempt: u32,
}

impl OutgoingMessageBuffer {
    /// Creates an empty buffer holding at most `capacity` messages.
    /// At most 1000 slots are preallocated up front.
    pub fn new(capacity: usize) -> Self {
        let initial_slots = capacity.min(1000);
        let initial_timeout = Duration::from_millis(DEFAULT_RTT_MS * 3);
        Self {
            messages: RwLock::new(std::collections::VecDeque::with_capacity(initial_slots)),
            capacity,
            rtt_stats: RwLock::new(RttStats::default()),
            retransmission_timeout: RwLock::new(initial_timeout),
        }
    }

    /// Appends a freshly sent message. If the buffer is full, the oldest
    /// unacknowledged message is evicted first.
    pub async fn add(&self, content: Bytes, sequence_number: i64) {
        let mut queue = self.messages.write().await;
        let at_capacity = queue.len() >= self.capacity;
        if at_capacity {
            if let Some(evicted) = queue.pop_front() {
                warn!(
                    seq = evicted.sequence_number,
                    "OutgoingMessageBuffer full, dropping oldest unACKed message"
                );
            }
        }
        queue.push_back(StreamingMessage {
            content,
            sequence_number,
            last_sent_time: Instant::now(),
            resend_attempt: 0,
        });
        trace!(
            seq = sequence_number,
            buffer_size = queue.len(),
            "Added to OutgoingMessageBuffer"
        );
    }

    /// Removes the message with `sequence_number` and returns `true` if it was present.
    ///
    /// Only messages that were never retransmitted contribute an RTT sample, since the
    /// ACK of a retransmitted message cannot be matched to a specific transmission.
    pub async fn process_ack(&self, sequence_number: i64) -> bool {
        let mut queue = self.messages.write().await;
        let Some(index) = queue
            .iter()
            .position(|m| m.sequence_number == sequence_number)
        else {
            trace!(seq = sequence_number, "ACK for unknown sequence number");
            return false;
        };
        let acked = queue
            .remove(index)
            .expect("index returned by position() is in bounds");
        if acked.resend_attempt > 0 {
            debug!(
                seq = sequence_number,
                resend_attempts = acked.resend_attempt,
                "ACK processed for retransmitted message"
            );
            return true;
        }
        let rtt = acked.last_sent_time.elapsed();
        // Release the queue before taking the RTT locks.
        drop(queue);
        self.update_rtt(rtt).await;
        debug!(
            seq = sequence_number,
            rtt_ms = rtt.as_millis(),
            "ACK processed, RTT measured"
        );
        true
    }

    /// Folds `rtt` into the estimator and publishes the resulting timeout.
    async fn update_rtt(&self, rtt: Duration) {
        let refreshed = {
            let mut stats = self.rtt_stats.write().await;
            stats.update(rtt);
            stats.timeout()
        };
        *self.retransmission_timeout.write().await = refreshed;
    }

    /// Checks only the oldest (front) message for a retransmission timeout.
    ///
    /// Returns one of three results:
    /// - empty, if nothing has timed out;
    /// - `[(seq, content, false)]` when the front message should be resent (its
    ///   attempt count and send time are updated);
    /// - the sentinel `[(-1, empty, true)]` when it has already been resent
    ///   `max_attempts` times.
    pub async fn get_retransmit_candidates(&self, max_attempts: u32) -> Vec<(i64, Bytes, bool)> {
        let timeout = *self.retransmission_timeout.read().await;
        let mut queue = self.messages.write().await;
        let now = Instant::now();
        let Some(head) = queue.front_mut() else {
            return Vec::new();
        };
        if now.duration_since(head.last_sent_time) <= timeout {
            return Vec::new();
        }
        if head.resend_attempt >= max_attempts {
            warn!(
                seq = head.sequence_number,
                attempts = head.resend_attempt,
                "Max retransmit attempts reached"
            );
            return vec![(-1, Bytes::new(), true)];
        }
        head.resend_attempt += 1;
        head.last_sent_time = now;
        debug!(
            seq = head.sequence_number,
            attempt = head.resend_attempt,
            timeout_ms = timeout.as_millis(),
            "Retransmitting message"
        );
        vec![(head.sequence_number, head.content.clone(), false)]
    }

    /// Returns the number of unacknowledged messages.
    pub async fn len(&self) -> usize {
        self.messages.read().await.len()
    }

    /// Returns `true` if no messages are awaiting acknowledgement.
    pub async fn is_empty(&self) -> bool {
        self.messages.read().await.is_empty()
    }

    /// Returns the current retransmission timeout.
    pub async fn retransmission_timeout(&self) -> Duration {
        *self.retransmission_timeout.read().await
    }

    /// Returns a snapshot of the RTT statistics.
    pub async fn rtt_stats(&self) -> RttStats {
        self.rtt_stats.read().await.clone()
    }

    /// Returns the *smoothed* RTT (`srtt`), or `None` before any sample was taken.
    ///
    /// Despite the name, this is not the most recent raw sample.
    pub async fn last_rtt(&self) -> Option<Duration> {
        let stats = self.rtt_stats.read().await;
        (stats.sample_count > 0).then_some(stats.srtt)
    }

    /// Discards every buffered message. RTT statistics are kept.
    pub async fn clear(&self) {
        self.messages.write().await.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small `InputStreamData` test message with the given sequence number.
    fn test_message(sequence_number: i64) -> ClientMessage {
        ClientMessage::new(
            MessageType::InputStreamData,
            sequence_number,
            PayloadType::Output,
            Bytes::from("test"),
        )
    }

    /// Builds an acknowledgement for `sequence_number` with placeholder type/id fields.
    fn test_ack(sequence_number: i64) -> AcknowledgeContent {
        AcknowledgeContent {
            acknowledged_message_type: "input_stream_data".to_string(),
            acknowledged_message_id: "test-id".to_string(),
            acknowledged_message_sequence_number: sequence_number,
            is_sequential_message: true,
        }
    }

    #[test]
    fn test_rtt_stats_initial() {
        let stats = RttStats::default();
        assert_eq!(stats.sample_count, 0);
        assert!(stats.rto >= Duration::from_millis(MIN_TIMEOUT_MS));
    }

    #[test]
    fn test_rtt_stats_first_sample() {
        let mut stats = RttStats::default();
        let sample = Duration::from_millis(50);
        stats.update(sample);
        assert_eq!(stats.sample_count, 1);
        assert_eq!(stats.srtt, sample);
        assert_eq!(stats.min_rtt, sample);
        assert_eq!(stats.max_rtt, sample);
    }

    #[test]
    fn test_rtt_stats_multiple_samples() {
        let mut stats = RttStats::default();
        let samples = [50, 60, 55, 70, 45, 65, 50, 55, 60, 52];
        for sample_ms in samples {
            stats.update(Duration::from_millis(sample_ms));
        }
        assert_eq!(stats.sample_count, 10);
        assert_eq!(stats.min_rtt, Duration::from_millis(45));
        assert_eq!(stats.max_rtt, Duration::from_millis(70));
        let srtt_ms = stats.srtt.as_millis();
        assert!((50..=65).contains(&srtt_ms));
    }

    #[test]
    fn test_rtt_stats_exact_values() {
        // 1s then 3s samples: every intermediate value is exactly representable,
        // so the smoothing and RTO formulas can be checked exactly.
        let mut stats = RttStats::default();
        stats.update(Duration::from_secs(1));
        assert_eq!(stats.rttvar, Duration::from_millis(500));
        assert_eq!(stats.rto, Duration::from_secs(3));
        stats.update(Duration::from_secs(3));
        assert_eq!(stats.srtt, Duration::from_millis(1250));
        assert_eq!(stats.rttvar, Duration::from_millis(875));
        assert_eq!(stats.rto, Duration::from_millis(4750));
    }

    #[test]
    fn test_rtt_stats_timeout_bounds() {
        let mut stats = RttStats::default();
        stats.update(Duration::from_micros(100));
        assert_eq!(stats.rto, Duration::from_millis(MIN_TIMEOUT_MS));
        let mut stats2 = RttStats::default();
        stats2.update(Duration::from_secs(60));
        assert_eq!(stats2.rto, Duration::from_millis(MAX_TIMEOUT_MS));
    }

    #[tokio::test]
    async fn test_ack_tracker_basic() {
        let tracker = AckTracker::new(100);
        assert_eq!(tracker.next_sequence(), 0);
        assert_eq!(tracker.next_sequence(), 1);
        assert_eq!(tracker.next_sequence(), 2);
        let msg = test_message(0);
        let serialized = msg.serialize().unwrap();
        tracker.track_sent(msg, serialized).await.unwrap();
        assert_eq!(tracker.pending_count().await, 1);
    }

    #[tokio::test]
    async fn test_ack_tracker_process_ack() {
        let tracker = AckTracker::new(100);
        let msg = test_message(42);
        let serialized = msg.serialize().unwrap();
        tracker.track_sent(msg, serialized).await.unwrap();
        tracker.process_ack(&test_ack(42)).await.unwrap();
        assert_eq!(tracker.pending_count().await, 0);
        // An ACK for a message sent exactly once contributes one RTT sample.
        assert_eq!(tracker.rtt_stats().await.sample_count, 1);
        // A duplicate ACK is ignored and reports zero.
        assert_eq!(
            tracker.process_ack(&test_ack(42)).await.unwrap(),
            Duration::ZERO
        );
        assert_eq!(tracker.rtt_stats().await.sample_count, 1);
    }

    #[tokio::test]
    async fn test_ack_tracker_flow_control() {
        let tracker = AckTracker::new(1);
        let first = test_message(1);
        let first_bytes = first.serialize().unwrap();
        tracker.track_sent(first, first_bytes).await.unwrap();
        let second = test_message(2);
        let second_bytes = second.serialize().unwrap();
        assert!(tracker.track_sent(second, second_bytes).await.is_err());
        assert_eq!(tracker.pending_count().await, 1);
    }

    #[test]
    fn test_update_expected_sequence() {
        let tracker = AckTracker::new(10);
        assert_eq!(tracker.expected_sequence(), 0);
        // Out-of-order arrivals do not advance the counter.
        tracker.update_expected_sequence(1);
        assert_eq!(tracker.expected_sequence(), 0);
        tracker.update_expected_sequence(0);
        assert_eq!(tracker.expected_sequence(), 1);
        // Duplicates do not advance it either.
        tracker.update_expected_sequence(0);
        assert_eq!(tracker.expected_sequence(), 1);
        tracker.update_expected_sequence(1);
        assert_eq!(tracker.expected_sequence(), 2);
    }

    #[test]
    fn test_create_and_parse_ack_roundtrip() {
        let msg = test_message(7);
        let ack_msg = AckTracker::create_ack(&msg, true).unwrap();
        let parsed = AckTracker::parse_ack(&ack_msg).unwrap();
        assert_eq!(parsed.acknowledged_message_sequence_number, 7);
        assert!(parsed.is_sequential_message);
        assert_eq!(parsed.acknowledged_message_id, msg.message_id.to_string());
    }

    #[tokio::test]
    async fn test_outgoing_buffer_evicts_and_acks() {
        let buffer = OutgoingMessageBuffer::new(2);
        for seq in 1..=3 {
            buffer.add(Bytes::from("data"), seq).await;
        }
        // Capacity 2: the oldest message (seq 1) was evicted.
        assert_eq!(buffer.len().await, 2);
        assert!(!buffer.process_ack(1).await);
        assert!(buffer.process_ack(2).await);
        assert!(!buffer.process_ack(2).await);
        assert_eq!(buffer.len().await, 1);
        // Acknowledging a message that was never retransmitted yields an RTT sample.
        assert!(buffer.last_rtt().await.is_some());
    }

    #[tokio::test]
    async fn test_outgoing_buffer_no_early_retransmit() {
        let buffer = OutgoingMessageBuffer::new(4);
        assert!(buffer.get_retransmit_candidates(3).await.is_empty());
        buffer.add(Bytes::from("data"), 1).await;
        // Freshly sent: well inside the default timeout.
        assert!(buffer.get_retransmit_candidates(3).await.is_empty());
    }

    #[test]
    fn test_ack_content_serialization() {
        let ack = AcknowledgeContent {
            acknowledged_message_type: "input_stream_data".to_string(),
            acknowledged_message_id: "550e8400-e29b-41d4-a716-446655440000".to_string(),
            acknowledged_message_sequence_number: 42,
            is_sequential_message: true,
        };
        let json = serde_json::to_string(&ack).unwrap();
        assert!(json.contains("AcknowledgedMessageType"));
        assert!(json.contains("AcknowledgedMessageSequenceNumber"));
        let parsed: AcknowledgeContent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.acknowledged_message_sequence_number, 42);
    }
}