use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock, Weak};
use std::time::Duration;
use super::global_bandwidth::GlobalBandwidthManager;
use super::symmetric_message::SymmetricMessagePayload;
use crate::node::background_task_monitor::BackgroundTaskMonitor;
use crate::transport::TRANSPORT_METRICS;
/// Coarse traffic class assigned to an outbound symmetric message.
///
/// Produced by `classify` and used by `record_outbound` to pick which
/// cumulative byte counter to bump.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub(crate) enum OutboundClass {
    /// Control traffic: `NoOp`, `AckConnection`, `Ping` and `Pong` payloads.
    MustFlow,
    /// `ShortMessage` payloads.
    Short,
    /// `StreamFragment` payloads.
    Bulk,
}
/// Buckets an outbound payload into its [`OutboundClass`].
///
/// The match has no `_` arm, so adding a new `SymmetricMessagePayload`
/// variant is a compile error here until it is explicitly classified.
pub(crate) fn classify(payload: &SymmetricMessagePayload) -> OutboundClass {
    match payload {
        // Connection-maintenance traffic.
        SymmetricMessagePayload::NoOp
        | SymmetricMessagePayload::Ping { .. }
        | SymmetricMessagePayload::Pong { .. }
        | SymmetricMessagePayload::AckConnection { .. } => OutboundClass::MustFlow,
        SymmetricMessagePayload::ShortMessage { .. } => OutboundClass::Short,
        SymmetricMessagePayload::StreamFragment { .. } => OutboundClass::Bulk,
    }
}
/// Cumulative outbound byte totals, one counter per [`OutboundClass`].
///
/// Counters are only ever incremented (via `record_outbound`); per-interval
/// figures are obtained by differencing two snapshots.
struct OutboundCounters {
    must_flow_bytes: AtomicU64,
    short_bytes: AtomicU64,
    bulk_bytes: AtomicU64,
}

impl OutboundCounters {
    /// All counters at zero. `const` so it can initialise a `static`.
    const fn zeroed() -> Self {
        Self {
            must_flow_bytes: AtomicU64::new(0),
            short_bytes: AtomicU64::new(0),
            bulk_bytes: AtomicU64::new(0),
        }
    }
}

/// Process-wide outbound byte counters.
static OUTBOUND: OutboundCounters = OutboundCounters::zeroed();
/// Adds `bytes` to the cumulative outbound counter for `class`.
///
/// Uses `Relaxed` ordering: within this module the counters are only read
/// back as statistics by `snapshot_outbound`, never to synchronise other data.
pub(crate) fn record_outbound(class: OutboundClass, bytes: usize) {
    let amount = bytes as u64;
    match class {
        OutboundClass::MustFlow => OUTBOUND.must_flow_bytes.fetch_add(amount, Ordering::Relaxed),
        OutboundClass::Short => OUTBOUND.short_bytes.fetch_add(amount, Ordering::Relaxed),
        OutboundClass::Bulk => OUTBOUND.bulk_bytes.fetch_add(amount, Ordering::Relaxed),
    };
}
/// Point-in-time copy of the cumulative outbound byte counters.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
struct OutboundSnapshot {
    /// Cumulative `MustFlow` bytes when the snapshot was taken.
    must_flow_bytes: u64,
    /// Cumulative `Short` bytes when the snapshot was taken.
    short_bytes: u64,
    /// Cumulative `Bulk` bytes when the snapshot was taken.
    bulk_bytes: u64,
}
fn snapshot_outbound() -> OutboundSnapshot {
OutboundSnapshot {
must_flow_bytes: OUTBOUND.must_flow_bytes.load(Ordering::Relaxed),
short_bytes: OUTBOUND.short_bytes.load(Ordering::Relaxed),
bulk_bytes: OUTBOUND.bulk_bytes.load(Ordering::Relaxed),
}
}
/// Test-only accessor returning the cumulative `(must_flow, short, bulk)` byte
/// totals as a plain tuple. Presumably for tests elsewhere in the crate, since
/// `OutboundSnapshot` itself is private to this module.
#[cfg(test)]
pub(crate) fn outbound_counters_snapshot() -> (u64, u64, u64) {
    let OutboundSnapshot { must_flow_bytes, short_bytes, bulk_bytes } = snapshot_outbound();
    (must_flow_bytes, short_bytes, bulk_bytes)
}
/// Bytes recorded per class between two snapshots, as `(must_flow, short, bulk)`.
///
/// Subtraction saturates, so a counter that appears to go backwards yields `0`
/// rather than panicking (debug builds) or wrapping (release builds).
fn outbound_delta(prev: OutboundSnapshot, now: OutboundSnapshot) -> (u64, u64, u64) {
    let diff = |before: u64, after: u64| after.saturating_sub(before);
    (
        diff(prev.must_flow_bytes, now.must_flow_bytes),
        diff(prev.short_bytes, now.short_bytes),
        diff(prev.bulk_bytes, now.bulk_bytes),
    )
}
/// Most recently reported broadcast queue depth. This is a gauge
/// (last write wins), not a cumulative counter.
static BROADCAST_QUEUE_DEPTH: AtomicUsize = AtomicUsize::new(0);

/// Overwrites the broadcast-queue-depth gauge with `depth`.
///
/// The demand snapshot reads whatever value was stored last. The `allow`
/// presumably covers builds with the `simulation_tests` feature, where this
/// function has no caller.
#[cfg_attr(feature = "simulation_tests", allow(dead_code))]
pub(crate) fn record_broadcast_queue_depth(depth: usize) {
    BROADCAST_QUEUE_DEPTH.store(depth, Ordering::Relaxed)
}
/// Process-wide, set-once weak handle to the `GlobalBandwidthManager` whose
/// figures are included in demand telemetry. Held weakly so this registry
/// never keeps the manager alive.
static GLOBAL_BANDWIDTH: OnceLock<Weak<GlobalBandwidthManager>> = OnceLock::new();

/// Registers `manager` as the source of global bandwidth figures.
///
/// Only the first call takes effect; later calls are logged at debug level and
/// ignored. NOTE: this is true even if the first manager has since been
/// dropped, in which case `global_bandwidth()` keeps returning `None`.
pub(crate) fn register_global_bandwidth(manager: &Arc<GlobalBandwidthManager>) {
    let handle = Arc::downgrade(manager);
    let already_registered = GLOBAL_BANDWIDTH.set(handle).is_err();
    if already_registered {
        tracing::debug!(
            target: "freenet::transport::shadow_demand",
            "global bandwidth handle already registered; keeping the first"
        );
    }
}

/// Returns a strong handle to the registered manager. Returns `None` if none
/// was registered or it has been dropped.
fn global_bandwidth() -> Option<Arc<GlobalBandwidthManager>> {
    let weak = GLOBAL_BANDWIDTH.get()?;
    weak.upgrade()
}
/// Local active-connection count reported in demand snapshots, taken as the
/// number of entries in the shadow RTT registry.
fn active_connections() -> usize {
    let registry = &super::rolling_rtt_stats::SHADOW_RTT_REGISTRY;
    registry.len()
}
/// Period between successive demand and outbound-class telemetry snapshots.
const AGGREGATOR_INTERVAL: Duration = Duration::from_millis(1_000);
/// Spawns a background task that emits a `shadow_rate_demand` snapshot once
/// per [`AGGREGATOR_INTERVAL`]. Each snapshot carries the bytes sent since the
/// previous tick, derived from `TRANSPORT_METRICS.cumulative_bytes_sent()`.
///
/// The task loops forever and is registered with `monitor` under the name
/// `"shadow_demand_aggregator"`.
pub(crate) fn spawn_demand_aggregator(local_peer_id: String, monitor: &BackgroundTaskMonitor) {
    let task = async move {
        let mut ticker = tokio::time::interval(AGGREGATOR_INTERVAL);
        // If a tick runs late, schedule the next one a full period later
        // instead of firing back-to-back to catch up.
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // An interval's first tick completes immediately. Consume it so the
        // baseline is taken at the start of the first full period.
        ticker.tick().await;

        let mut last_total = TRANSPORT_METRICS.cumulative_bytes_sent();
        loop {
            ticker.tick().await;
            let total = TRANSPORT_METRICS.cumulative_bytes_sent();
            // Saturate in case the cumulative figure ever moves backwards.
            let sent_since_last = total.saturating_sub(last_total);
            last_total = total;
            emit_demand_snapshot(&local_peer_id, sent_since_last);
        }
    };
    monitor.register("shadow_demand_aggregator", tokio::spawn(task));
}
/// Emits one `shadow_rate_demand` sample, both as a debug trace and as a
/// telemetry event attributed to `local_peer_id`.
///
/// `sent_bytes_last_interval` is reported under the telemetry key
/// `sent_bytes_per_sec`. That label assumes the caller samples once per
/// ~1 s interval. The `global_*` fields are `None` when no
/// `GlobalBandwidthManager` is registered or the registered one was dropped.
fn emit_demand_snapshot(local_peer_id: &str, sent_bytes_last_interval: u64) {
    let active_connections = active_connections();
    let broadcast_queue_depth = BROADCAST_QUEUE_DEPTH.load(Ordering::Relaxed);
    let gbm = global_bandwidth();
    let global_total_limit_bytes = gbm.as_ref().map(|m| m.total_limit() as u64);
    let global_per_connection_rate_bytes =
        gbm.as_ref().map(|m| m.current_per_connection_rate() as u64);
    let global_active_connections = gbm.as_ref().map(|m| m.connection_count() as u64);
    // Keep the trace fields in sync with the telemetry payload below so the
    // two can be cross-checked; this includes `global_active_connections`.
    tracing::debug!(
        target: "freenet::transport::shadow_demand",
        sent_bytes_last_interval,
        active_connections,
        broadcast_queue_depth,
        global_total_limit_bytes,
        global_per_connection_rate_bytes,
        global_active_connections,
        "shadow_rate_demand"
    );
    crate::tracing::telemetry::send_standalone_event_with_peer_id(
        "shadow_rate_demand",
        local_peer_id,
        serde_json::json!({
            "sent_bytes_per_sec": sent_bytes_last_interval,
            "active_connections": active_connections,
            "broadcast_queue_depth": broadcast_queue_depth,
            "global_total_limit_bytes": global_total_limit_bytes,
            "global_per_connection_rate_bytes": global_per_connection_rate_bytes,
            "global_active_connections": global_active_connections,
        }),
    );
}
/// Spawns a background task that emits a `shadow_outbound_class` snapshot once
/// per [`AGGREGATOR_INTERVAL`]. Each snapshot carries the bytes recorded per
/// [`OutboundClass`] since the previous tick.
///
/// The task loops forever and is registered with `monitor` under the name
/// `"shadow_outbound_class_aggregator"`.
pub(crate) fn spawn_outbound_class_aggregator(
    local_peer_id: String,
    monitor: &BackgroundTaskMonitor,
) {
    let task = async move {
        let mut ticker = tokio::time::interval(AGGREGATOR_INTERVAL);
        // Late ticks push the schedule back rather than bursting to catch up.
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // Swallow the immediate first tick so the baseline starts a full period.
        ticker.tick().await;

        let mut baseline = snapshot_outbound();
        loop {
            ticker.tick().await;
            let current = snapshot_outbound();
            let (must_flow, short, bulk) = outbound_delta(baseline, current);
            baseline = current;
            emit_class_snapshot(&local_peer_id, must_flow, short, bulk);
        }
    };
    monitor.register("shadow_outbound_class_aggregator", tokio::spawn(task));
}
/// Emits one `shadow_outbound_class` sample (per-class bytes for the last
/// interval), both as a debug trace and as a telemetry event attributed to
/// `local_peer_id`.
fn emit_class_snapshot(
    local_peer_id: &str,
    must_flow_bytes: u64,
    short_bytes: u64,
    bulk_bytes: u64,
) {
    tracing::debug!(
        target: "freenet::transport::shadow_demand",
        must_flow_bytes,
        short_bytes,
        bulk_bytes,
        "shadow_outbound_class"
    );
    let event = serde_json::json!({
        "must_flow_bytes_per_sec": must_flow_bytes,
        "short_message_bytes_per_sec": short_bytes,
        "bulk_bytes_per_sec": bulk_bytes,
    });
    crate::tracing::telemetry::send_standalone_event_with_peer_id(
        "shadow_outbound_class",
        local_peer_id,
        event,
    );
}
#[cfg(test)]
mod tests {
    // Unit tests for payload classification, snapshot/delta arithmetic, the
    // broadcast-queue gauge, and liveness of both aggregator tasks.
    use super::*;

    // Builds a `ShortMessage` payload with an empty body.
    fn short_message() -> SymmetricMessagePayload {
        SymmetricMessagePayload::ShortMessage {
            payload: bytes::Bytes::new(),
        }
    }

    // Builds fragment 0 of a 1000-byte stream, with a fresh stream id and an
    // empty body.
    fn stream_fragment() -> SymmetricMessagePayload {
        SymmetricMessagePayload::StreamFragment {
            stream_id: crate::transport::peer_connection::StreamId::next(),
            total_length_bytes: 1000,
            fragment_number: 0,
            payload: bytes::Bytes::new(),
            metadata_bytes: None,
        }
    }

    // One assertion per payload variant, mirroring the arms of `classify`.
    #[test]
    fn classify_maps_every_variant_to_its_bucket() {
        assert_eq!(classify(&stream_fragment()), OutboundClass::Bulk);
        assert_eq!(classify(&short_message()), OutboundClass::Short);
        assert_eq!(
            classify(&SymmetricMessagePayload::NoOp),
            OutboundClass::MustFlow
        );
        assert_eq!(
            classify(&SymmetricMessagePayload::Ping { sequence: 1 }),
            OutboundClass::MustFlow
        );
        assert_eq!(
            classify(&SymmetricMessagePayload::Pong { sequence: 1 }),
            OutboundClass::MustFlow
        );
        assert_eq!(
            classify(&SymmetricMessagePayload::AckConnection {
                result: Err(std::borrow::Cow::Borrowed("rejected")),
            }),
            OutboundClass::MustFlow
        );
    }

    #[test]
    fn outbound_delta_is_per_interval_difference() {
        let prev = OutboundSnapshot {
            must_flow_bytes: 100,
            short_bytes: 50,
            bulk_bytes: 1000,
        };
        let now = OutboundSnapshot {
            must_flow_bytes: 175,
            short_bytes: 50,
            bulk_bytes: 4000,
        };
        assert_eq!(outbound_delta(prev, now), (75, 0, 3000));
    }

    // A counter that is lower than in the previous snapshot must clamp to 0
    // instead of underflowing.
    #[test]
    fn outbound_delta_saturates_on_reset() {
        let prev = OutboundSnapshot {
            must_flow_bytes: 5000,
            short_bytes: 5000,
            bulk_bytes: 5000,
        };
        let now = OutboundSnapshot {
            must_flow_bytes: 10,
            short_bytes: 5000,
            bulk_bytes: 6000,
        };
        assert_eq!(outbound_delta(prev, now), (0, 0, 1000));
    }

    // `OUTBOUND` is a process-wide static, and the test harness runs tests on
    // parallel threads, so other code in the process may bump the counters
    // concurrently. Only lower bounds on the deltas are asserted.
    #[test]
    fn record_outbound_accumulates_into_the_right_counter() {
        let before = snapshot_outbound();
        record_outbound(OutboundClass::MustFlow, 11);
        record_outbound(OutboundClass::Short, 22);
        record_outbound(OutboundClass::Bulk, 33);
        let after = snapshot_outbound();
        assert!(after.must_flow_bytes - before.must_flow_bytes >= 11);
        assert!(after.short_bytes - before.short_bytes >= 22);
        assert!(after.bulk_bytes - before.bulk_bytes >= 33);
    }

    #[test]
    fn outbound_delta_is_zero_for_a_quiet_interval() {
        let snap = OutboundSnapshot {
            must_flow_bytes: 12_345,
            short_bytes: 67_890,
            bulk_bytes: 1_000_000,
        };
        assert_eq!(outbound_delta(snap, snap), (0, 0, 0));
    }

    // The gauge is last-write-wins: each store replaces the previous value.
    #[test]
    fn broadcast_queue_depth_gauge_round_trips() {
        record_broadcast_queue_depth(42);
        assert_eq!(BROADCAST_QUEUE_DEPTH.load(Ordering::Relaxed), 42);
        record_broadcast_queue_depth(0);
        assert_eq!(BROADCAST_QUEUE_DEPTH.load(Ordering::Relaxed), 0);
    }

    // With the tokio clock paused, advance past three aggregator intervals,
    // yield once so the spawned task gets polled, then require that no task
    // exit is observed within a short timeout.
    #[tokio::test(start_paused = true)]
    async fn demand_aggregator_survives_multiple_ticks() {
        let monitor = BackgroundTaskMonitor::new();
        spawn_demand_aggregator("test-peer".to_string(), &monitor);
        tokio::time::advance(AGGREGATOR_INTERVAL * 3 + Duration::from_millis(100)).await;
        tokio::task::yield_now().await;
        // NOTE(review): assumes `wait_for_any_exit` resolves once any
        // registered task finishes. A timeout (`Err`) therefore means the
        // aggregator is still running.
        let exit = monitor.wait_for_any_exit();
        tokio::pin!(exit);
        let still_running = tokio::time::timeout(Duration::from_millis(50), &mut exit)
            .await
            .is_err();
        assert!(
            still_running,
            "demand aggregator task should still be alive after a few ticks"
        );
    }

    // Same liveness check as `demand_aggregator_survives_multiple_ticks`,
    // applied to the outbound-class aggregator.
    #[tokio::test(start_paused = true)]
    async fn class_aggregator_survives_multiple_ticks() {
        let monitor = BackgroundTaskMonitor::new();
        spawn_outbound_class_aggregator("test-peer".to_string(), &monitor);
        tokio::time::advance(AGGREGATOR_INTERVAL * 3 + Duration::from_millis(100)).await;
        tokio::task::yield_now().await;
        let exit = monitor.wait_for_any_exit();
        tokio::pin!(exit);
        let still_running = tokio::time::timeout(Duration::from_millis(50), &mut exit)
            .await
            .is_err();
        assert!(
            still_running,
            "class aggregator task should still be alive after a few ticks"
        );
    }
}