use crate::dm::DmPath;
use crate::error::{NetworkError, NetworkResult};
use crate::identity::{AgentId, MachineId};
use crate::trust::TrustDecision;
use serde::Serialize;
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::{broadcast, mpsc, Notify, RwLock};
/// Stream-type tag written as the first byte of every encoded direct message.
pub const DIRECT_MESSAGE_STREAM_TYPE: u8 = 0x10;
/// Largest payload, in bytes, that `DirectMessaging::encode_message` accepts (16 MiB).
pub const MAX_DIRECT_PAYLOAD_SIZE: usize = 16 << 20;
/// Default bound, in messages, of each subscriber queue created by `DirectMessaging::new`.
const DIRECT_SUBSCRIBER_BUFFER: usize = 8_192;
/// Idle time (24 hours, in milliseconds) after which a non-connected diagnostics or
/// lifecycle entry becomes eligible for pruning.
const DIRECT_DIAGNOSTICS_IDLE_TTL_MS: u64 = 86_400_000;
/// Floor on how many diagnostics / lifecycle entries are kept before pruning runs.
const DIRECT_DIAGNOSTICS_MIN_RETAIN: usize = 1_024;
/// Buffer size of the broadcast channel that announces replaced connection generations.
const LIFECYCLE_REPLACED_BROADCAST_CAPACITY: usize = 256;
/// Test-only rendezvous points used to force the raw-QUIC ACK race.
///
/// Tests install the hook through
/// `DirectMessaging::set_raw_quic_ack_race_test_hook_for_testing`. Each sync
/// point is a tokio `Notify` signalled with `notify_one`. That call stores a
/// permit when nobody is waiting yet, so a signal that fires before the
/// matching wait is not lost.
#[doc(hidden)]
pub struct RawQuicAckRaceTestHook {
    first_attempt_started: Notify,
    first_attempt_result_release: Notify,
    replaced_short_circuit: Notify,
}

impl std::fmt::Debug for RawQuicAckRaceTestHook {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The Notify fields carry no useful state to print.
        let mut out = f.debug_struct("RawQuicAckRaceTestHook");
        out.finish_non_exhaustive()
    }
}

impl Default for RawQuicAckRaceTestHook {
    fn default() -> Self {
        Self::new()
    }
}

impl RawQuicAckRaceTestHook {
    /// Creates a hook with all three sync points unsignalled.
    #[must_use]
    pub fn new() -> Self {
        RawQuicAckRaceTestHook {
            first_attempt_started: Notify::new(),
            first_attempt_result_release: Notify::new(),
            replaced_short_circuit: Notify::new(),
        }
    }

    // --- sync point 1: the first send attempt has started ---

    /// Test side: waits until the crate reports that the first attempt began.
    pub async fn wait_first_attempt_started(&self) {
        self.first_attempt_started.notified().await
    }

    /// Crate side: reports that the first attempt began.
    pub(crate) fn notify_first_attempt_started(&self) {
        self.first_attempt_started.notify_one()
    }

    // --- sync point 2: the first attempt's result is held back ---

    /// Test side: lets a parked `hold_first_attempt_result` proceed.
    pub fn release_first_attempt_result(&self) {
        self.first_attempt_result_release.notify_one()
    }

    /// Crate side: parks until the test calls `release_first_attempt_result`.
    pub(crate) async fn hold_first_attempt_result(&self) {
        self.first_attempt_result_release.notified().await
    }

    // --- sync point 3: the replaced-connection short circuit was taken ---

    /// Test side: waits until the crate reports the short circuit.
    pub async fn wait_replaced_short_circuit(&self) {
        self.replaced_short_circuit.notified().await
    }

    /// Crate side: reports the short circuit.
    pub(crate) fn notify_replaced_short_circuit(&self) {
        self.replaced_short_circuit.notify_one()
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DirectMessage {
pub sender: AgentId,
pub machine_id: MachineId,
pub payload: Vec<u8>,
pub received_at: u64,
pub verified: bool,
pub trust_decision: Option<TrustDecision>,
}
impl DirectMessage {
#[must_use]
pub fn new(sender: AgentId, machine_id: MachineId, payload: Vec<u8>) -> Self {
Self::new_verified(sender, machine_id, payload, false, None)
}
#[must_use]
pub fn new_verified(
sender: AgentId,
machine_id: MachineId,
payload: Vec<u8>,
verified: bool,
trust_decision: Option<TrustDecision>,
) -> Self {
let received_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
Self {
sender,
machine_id,
payload,
received_at,
verified,
trust_decision,
}
}
#[must_use]
pub fn payload_str(&self) -> Option<&str> {
std::str::from_utf8(&self.payload).ok()
}
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// "Lossy" has two meanings here:
/// - a clock set before 1970 yields 0 instead of an error;
/// - the `u128` millisecond count is narrowed to `u64` with `as`.
fn now_unix_ms_lossy() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
/// Number of diagnostics / lifecycle entries to keep: twice the connected count,
/// but never fewer than `DIRECT_DIAGNOSTICS_MIN_RETAIN`.
fn direct_diagnostics_retain_limit(connected_len: usize) -> usize {
    let scaled = connected_len.saturating_mul(2);
    std::cmp::max(DIRECT_DIAGNOSTICS_MIN_RETAIN, scaled)
}
/// Stable snake_case label for a delivery path, used in per-peer diagnostics.
fn dm_path_label(path: DmPath) -> &'static str {
    // Exhaustive on purpose: a new DmPath variant must get its own label.
    match path {
        DmPath::RawQuicAcked => "raw_quic_acked",
        DmPath::RawQuic => "raw_quic",
        DmPath::GossipInbox => "gossip_inbox",
        DmPath::Loopback => "loopback",
    }
}
/// Short, stable fingerprint of a DM payload for logs and diagnostics.
///
/// Returns the first 16 lowercase hex characters (the first 8 bytes) of the
/// payload's BLAKE3 hash. Only the bytes that survive into the result are
/// hex-encoded. The old code encoded all 32 bytes and then threw away
/// three quarters of the string.
#[must_use]
pub fn dm_payload_digest_hex(bytes: &[u8]) -> String {
    /// Hash bytes kept in the digest; each becomes two hex characters.
    const DIGEST_PREFIX_BYTES: usize = 8;
    let hash = blake3::hash(bytes);
    hex::encode(&hash.as_bytes()[..DIGEST_PREFIX_BYTES])
}
/// Outcome of offering one message to one subscriber queue.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DirectSubscriberPush {
    /// Queued without displacing anything.
    Delivered,
    /// Queued, but the queue was full, so its oldest entry was dropped first.
    DeliveredWithEviction,
    /// Not queued: the queue is closed or its lock is poisoned.
    Closed,
}

/// Bounded per-subscriber FIFO that drops its oldest entry instead of blocking
/// the producer.
#[derive(Debug)]
struct DirectSubscriberQueue {
    /// Buffered messages, oldest first; holds at most `capacity` entries once non-zero.
    queue: Mutex<VecDeque<DirectMessage>>,
    /// Wakes a receiver waiting in `DirectMessageReceiver::recv`.
    notify: Notify,
    /// Once set, pushes are refused and waiting receivers are woken.
    closed: AtomicBool,
    /// Maximum number of buffered messages.
    capacity: usize,
}

impl DirectSubscriberQueue {
    /// Creates an empty, open queue bounded at `capacity` messages.
    ///
    /// Storage grows on demand instead of being reserved up front. Queues are
    /// created on every `subscribe()` and every receiver clone, normally with
    /// the 8192-message default bound. Reserving `capacity` slots eagerly would
    /// allocate `capacity * size_of::<DirectMessage>()` bytes per subscriber
    /// even when the queue stays nearly empty.
    fn new(capacity: usize) -> Self {
        Self {
            queue: Mutex::new(VecDeque::new()),
            notify: Notify::new(),
            closed: AtomicBool::new(false),
            capacity,
        }
    }

    /// Appends `msg`. If the queue is full, the oldest buffered message is
    /// evicted first.
    ///
    /// Returns `Closed` without queueing when the queue was closed or its mutex
    /// is poisoned.
    fn push_drop_oldest(&self, msg: DirectMessage) -> DirectSubscriberPush {
        // Cheap fast-path rejection before taking the lock.
        if self.closed.load(Ordering::Relaxed) {
            return DirectSubscriberPush::Closed;
        }
        let mut queue = match self.queue.lock() {
            Ok(queue) => queue,
            Err(e) => {
                tracing::error!("direct subscriber queue poisoned: {e}");
                return DirectSubscriberPush::Closed;
            }
        };
        // Re-check: the receiver may have closed while we waited for the lock.
        if self.closed.load(Ordering::Relaxed) {
            return DirectSubscriberPush::Closed;
        }
        let evicted = if queue.len() >= self.capacity {
            queue.pop_front();
            true
        } else {
            false
        };
        queue.push_back(msg);
        // Release the lock before waking the receiver so it can pop immediately.
        drop(queue);
        self.notify.notify_one();
        if evicted {
            DirectSubscriberPush::DeliveredWithEviction
        } else {
            DirectSubscriberPush::Delivered
        }
    }

    /// Removes and returns the oldest buffered message.
    ///
    /// Returns `None` if the queue is empty or its mutex is poisoned.
    fn pop_front(&self) -> Option<DirectMessage> {
        match self.queue.lock() {
            Ok(mut queue) => queue.pop_front(),
            Err(e) => {
                tracing::error!("direct subscriber queue poisoned: {e}");
                None
            }
        }
    }

    /// Whether `close` has been called.
    fn is_closed(&self) -> bool {
        self.closed.load(Ordering::Relaxed)
    }

    /// Marks the queue closed and wakes every waiting receiver.
    fn close(&self) {
        self.closed.store(true, Ordering::Relaxed);
        self.notify.notify_waiters();
    }
}
#[derive(Debug)]
pub struct DirectMessageReceiver {
id: Option<u64>,
queue: Arc<DirectSubscriberQueue>,
subscribers: Arc<Mutex<HashMap<u64, Arc<DirectSubscriberQueue>>>>,
next_subscriber_id: Arc<AtomicU64>,
capacity: usize,
}
impl DirectMessageReceiver {
fn new(
subscribers: Arc<Mutex<HashMap<u64, Arc<DirectSubscriberQueue>>>>,
next_subscriber_id: Arc<AtomicU64>,
capacity: usize,
) -> Self {
let queue = Arc::new(DirectSubscriberQueue::new(capacity));
let id = next_subscriber_id.fetch_add(1, Ordering::Relaxed);
let registered = match subscribers.lock() {
Ok(mut guard) => {
guard.insert(id, Arc::clone(&queue));
Some(id)
}
Err(e) => {
tracing::error!("direct subscriber registry poisoned: {e}");
queue.close();
None
}
};
Self {
id: registered,
queue,
subscribers,
next_subscriber_id,
capacity,
}
}
pub async fn recv(&mut self) -> Option<DirectMessage> {
loop {
let notified = self.queue.notify.notified();
if let Some(msg) = self.queue.pop_front() {
return Some(msg);
}
if self.queue.is_closed() {
return None;
}
notified.await;
}
}
pub fn try_recv(&mut self) -> Option<DirectMessage> {
self.queue.pop_front()
}
}
impl Clone for DirectMessageReceiver {
fn clone(&self) -> Self {
Self::new(
Arc::clone(&self.subscribers),
Arc::clone(&self.next_subscriber_id),
self.capacity,
)
}
}
impl Drop for DirectMessageReceiver {
fn drop(&mut self) {
let Some(id) = self.id.take() else {
return;
};
match self.subscribers.lock() {
Ok(mut guard) => {
guard.remove(&id);
}
Err(e) => tracing::error!("direct subscriber registry poisoned on drop: {e}"),
}
self.queue.close();
}
}
/// Process-wide DM counters, updated with relaxed atomics and copied into
/// `DmDiagnosticsStats` by `DirectMessaging::diagnostics_snapshot`.
#[derive(Debug, Default)]
struct DirectDiagnosticsCounters {
    /// Sends started (`record_outgoing_started`).
    outgoing_send_total: AtomicU64,
    /// Sends that completed on some path (`record_outgoing_succeeded`).
    outgoing_send_succeeded: AtomicU64,
    /// Sends reported as failed (`record_outgoing_failed`).
    outgoing_send_failed: AtomicU64,
    /// Successful sends over the loopback path.
    outgoing_path_loopback: AtomicU64,
    /// Successful sends over raw QUIC, ACKed or not.
    outgoing_path_raw_quic: AtomicU64,
    /// Successful sends through the gossip inbox.
    outgoing_path_gossip_inbox: AtomicU64,
    /// Envelopes passed to `handle_incoming`.
    incoming_envelopes_total: AtomicU64,
    /// Incoming envelopes that failed to decode, as reported by callers.
    incoming_decode_failed: AtomicU64,
    /// Incoming envelopes whose signature check failed, as reported by callers.
    incoming_signature_failed: AtomicU64,
    /// Incoming envelopes rejected by trust evaluation, as reported by callers.
    incoming_trust_rejected: AtomicU64,
    /// Envelopes delivered to at least one subscriber queue (once per envelope).
    incoming_delivered_to_subscribe: AtomicU64,
    /// Incremented together with `subscriber_events_evicted`, once per eviction.
    subscriber_channel_lagged: AtomicU64,
    /// Messages dropped from a full subscriber queue to make room for a new one.
    subscriber_events_evicted: AtomicU64,
    /// Pushes that found a subscriber queue closed.
    subscriber_channel_closed: AtomicU64,
}
/// Mutable per-peer diagnostics kept in `DirectMessaging::peer_diagnostics`.
#[derive(Debug, Clone, Default)]
struct DirectPeerDiagnosticsState {
    /// Most recent non-zero RTT passed to `record_outgoing_started`; it is
    /// overwritten, not averaged, here (the caller presumably supplies an average).
    avg_rtt_ms: Option<u32>,
    /// Unix-ms time the last send to this peer started.
    last_send_at_ms: Option<u64>,
    /// Unix-ms time the last envelope from this peer was handled.
    last_recv_at_ms: Option<u64>,
    /// Successful sends to this peer.
    send_succeeded: u64,
    /// Failed sends to this peer.
    send_failed: u64,
    /// Envelopes received from this peer.
    recv_count: u64,
    /// Label of the path used by the most recent successful send.
    preferred_path: Option<&'static str>,
}
impl DirectPeerDiagnosticsState {
    /// Later of the last send and last receive times, or `None` if neither happened.
    /// (`Option`'s ordering puts `None` below every `Some`, so `max` picks the
    /// latest timestamp present.)
    fn last_activity_ms(&self) -> Option<u64> {
        self.last_send_at_ms.max(self.last_recv_at_ms)
    }
}
/// Per-machine connection lifecycle state kept in `DirectMessaging::lifecycle`.
#[derive(Debug, Clone, Default)]
struct DirectLifecycleState {
    /// Latest known connection generation for the machine, if any.
    generation: Option<u64>,
    /// Why sends to the machine are currently blocked; `None` when not blocked.
    blocked_reason: Option<String>,
    /// Unix-ms time of the last lifecycle update; drives idle pruning.
    last_updated_at_ms: Option<u64>,
}
/// Point-in-time copy of the global DM counters returned by
/// `DirectMessaging::diagnostics_snapshot`.
///
/// Field declaration order is also the serialized field order.
#[derive(Debug, Clone, Default, Serialize)]
pub struct DmDiagnosticsStats {
    /// Sends started.
    pub outgoing_send_total: u64,
    /// Sends that completed on some path.
    pub outgoing_send_succeeded: u64,
    /// Sends reported as failed.
    pub outgoing_send_failed: u64,
    /// Successful sends over loopback.
    pub outgoing_path_loopback: u64,
    /// Successful sends over raw QUIC (ACKed or not).
    pub outgoing_path_raw_quic: u64,
    /// Successful sends through the gossip inbox.
    pub outgoing_path_gossip_inbox: u64,
    /// Envelopes handled by `handle_incoming`.
    pub incoming_envelopes_total: u64,
    /// Incoming decode failures reported by callers.
    pub incoming_decode_failed: u64,
    /// Incoming signature failures reported by callers.
    pub incoming_signature_failed: u64,
    /// Incoming trust rejections reported by callers.
    pub incoming_trust_rejected: u64,
    /// Envelopes delivered to at least one subscriber (counted once per envelope).
    pub incoming_delivered_to_subscribe: u64,
    /// Messages evicted from full subscriber queues.
    pub subscriber_events_evicted: u64,
    /// Incremented alongside `subscriber_events_evicted`, once per eviction.
    pub subscriber_channel_lagged: u64,
    /// Pushes that found a subscriber queue closed.
    pub subscriber_channel_closed: u64,
}
/// Per-peer view derived from the internal peer diagnostics at snapshot time.
#[derive(Debug, Clone, Default, Serialize)]
pub struct DmPeerDiagnostics {
    /// Most recent non-zero RTT reported when a send started.
    pub avg_rtt_ms: Option<u32>,
    /// Milliseconds since the last send started; `None` if never.
    pub last_send_ms_ago: Option<u64>,
    /// Milliseconds since the last envelope from this peer; `None` if never.
    pub last_recv_ms_ago: Option<u64>,
    /// Successful sends to this peer.
    pub send_succeeded: u64,
    /// Failed sends to this peer.
    pub send_failed: u64,
    /// Envelopes received from this peer.
    pub recv_count: u64,
    /// Path label of the last successful send, or `"unknown"`.
    pub preferred_path: String,
}
/// Full diagnostics snapshot returned by `DirectMessaging::diagnostics_snapshot`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct DmDiagnosticsSnapshot {
    /// Global counters.
    pub stats: DmDiagnosticsStats,
    /// Per-peer diagnostics keyed by hex-encoded agent id.
    pub per_peer: BTreeMap<String, DmPeerDiagnostics>,
    /// Registered subscriber queues at snapshot time.
    pub subscriber_count: usize,
    /// Per-subscriber queue bound, in messages.
    pub subscriber_capacity: usize,
}
/// Hub for direct agent-to-agent messages.
///
/// Responsibilities:
/// - tracks which agents are connected and on which machine;
/// - fans incoming messages out to bounded subscriber queues and to a
///   pull-style channel;
/// - keeps per-peer diagnostics and per-machine lifecycle state.
#[derive(Debug)]
pub struct DirectMessaging {
    /// Machine → agent mapping populated by `register_agent`.
    machine_to_agent: Arc<RwLock<HashMap<MachineId, AgentId>>>,
    /// Currently connected agents and the machine each is connected through.
    connected_agents: Arc<RwLock<HashMap<AgentId, MachineId>>>,
    /// Registry of live subscriber queues, keyed by subscriber id.
    subscribers: Arc<Mutex<HashMap<u64, Arc<DirectSubscriberQueue>>>>,
    /// Allocator for subscriber ids (starts at 1).
    next_subscriber_id: Arc<AtomicU64>,
    /// Per-subscriber queue bound; also the internal pull-channel capacity.
    subscriber_capacity: usize,
    /// Global DM counters.
    diagnostics: Arc<DirectDiagnosticsCounters>,
    /// Per-peer diagnostics, pruned once larger than the retain limit.
    peer_diagnostics: Arc<Mutex<HashMap<AgentId, DirectPeerDiagnosticsState>>>,
    /// Per-machine lifecycle state, pruned once larger than the retain limit.
    lifecycle: Arc<Mutex<HashMap<MachineId, DirectLifecycleState>>>,
    /// Announces `(machine, new_generation)` whenever a connection is replaced.
    lifecycle_replaced_tx: broadcast::Sender<(MachineId, u64)>,
    /// Optional test hook for the raw-QUIC ACK race.
    raw_quic_ack_race_test_hook: Arc<Mutex<Option<Arc<RawQuicAckRaceTestHook>>>>,
    /// Sender side of the pull-API channel; copies are dropped when it is full.
    internal_tx: mpsc::Sender<DirectMessage>,
    /// Receiver side of the pull-API channel, shared by `recv` callers.
    internal_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<DirectMessage>>>,
}
impl DirectMessaging {
    /// Length in bytes of the sender agent id carried in the wire header.
    const AGENT_ID_LEN: usize = 32;
    /// Wire header length: the stream-type byte followed by the sender agent id.
    const HEADER_LEN: usize = 1 + Self::AGENT_ID_LEN;

    /// Creates a hub whose subscriber queues hold at most
    /// `DIRECT_SUBSCRIBER_BUFFER` messages each.
    #[must_use]
    pub fn new() -> Self {
        Self::with_subscriber_capacity(DIRECT_SUBSCRIBER_BUFFER)
    }

    /// Creates a hub with a custom per-subscriber queue bound.
    ///
    /// The bound is clamped to at least 1, because tokio's `mpsc::channel`
    /// panics on a zero capacity. It also sizes the internal pull-API channel.
    fn with_subscriber_capacity(subscriber_capacity: usize) -> Self {
        let subscriber_capacity = subscriber_capacity.max(1);
        let (internal_tx, internal_rx) = mpsc::channel(subscriber_capacity);
        let (lifecycle_replaced_tx, _) = broadcast::channel(LIFECYCLE_REPLACED_BROADCAST_CAPACITY);
        Self {
            machine_to_agent: Arc::new(RwLock::new(HashMap::new())),
            connected_agents: Arc::new(RwLock::new(HashMap::new())),
            subscribers: Arc::new(Mutex::new(HashMap::new())),
            next_subscriber_id: Arc::new(AtomicU64::new(1)),
            subscriber_capacity,
            diagnostics: Arc::new(DirectDiagnosticsCounters::default()),
            peer_diagnostics: Arc::new(Mutex::new(HashMap::new())),
            lifecycle: Arc::new(Mutex::new(HashMap::new())),
            lifecycle_replaced_tx,
            raw_quic_ack_race_test_hook: Arc::new(Mutex::new(None)),
            internal_tx,
            internal_rx: Arc::new(tokio::sync::Mutex::new(internal_rx)),
        }
    }

    /// Records (or overwrites) which agent lives on `machine_id`.
    pub async fn register_agent(&self, agent_id: AgentId, machine_id: MachineId) {
        let mut map = self.machine_to_agent.write().await;
        map.insert(machine_id, agent_id);
        tracing::debug!(
            "Registered agent mapping: {:?} -> {:?}",
            machine_id,
            agent_id
        );
    }

    /// Returns the agent registered for `machine_id`, if any.
    pub async fn lookup_agent(&self, machine_id: &MachineId) -> Option<AgentId> {
        let map = self.machine_to_agent.read().await;
        map.get(machine_id).copied()
    }

    /// Marks `agent_id` as connected through `machine_id`.
    ///
    /// Also registers the machine → agent mapping and clears any recorded
    /// lifecycle block for the machine. The machine's generation is left
    /// unchanged.
    pub async fn mark_connected(&self, agent_id: AgentId, machine_id: MachineId) {
        self.register_agent(agent_id, machine_id).await;
        let mut connected = self.connected_agents.write().await;
        connected.insert(agent_id, machine_id);
        // NOTE: this lifecycle update runs while the write guard is still held.
        // `connected_machine_snapshot` uses `try_read`, so a lifecycle prune
        // triggered from here is skipped (not deadlocked); a later update prunes.
        self.record_lifecycle_established(machine_id, None);
        tracing::info!("Agent connected: {:?}", agent_id);
    }

    /// Removes `agent_id` from the connected set.
    ///
    /// The machine → agent mapping, diagnostics and lifecycle state are left
    /// untouched.
    pub async fn mark_disconnected(&self, agent_id: &AgentId) {
        let mut connected = self.connected_agents.write().await;
        connected.remove(agent_id);
        tracing::info!("Agent disconnected: {:?}", agent_id);
    }

    /// Records that a connection to `machine_id` is established.
    ///
    /// Clears any block reason. The generation is updated only when one is
    /// supplied.
    pub fn record_lifecycle_established(&self, machine_id: MachineId, generation: Option<u64>) {
        self.update_lifecycle(machine_id, |state| {
            if let Some(generation) = generation {
                state.generation = Some(generation);
            }
            state.blocked_reason = None;
        });
    }

    /// Records that the connection to `machine_id` was replaced by `new_generation`.
    ///
    /// Clears any block reason and broadcasts `(machine_id, new_generation)`.
    /// A send error only means there are no live subscribers, so it is ignored.
    pub fn record_lifecycle_replaced(&self, machine_id: MachineId, new_generation: u64) {
        self.update_lifecycle(machine_id, |state| {
            state.generation = Some(new_generation);
            state.blocked_reason = None;
        });
        let _ = self
            .lifecycle_replaced_tx
            .send((machine_id, new_generation));
    }

    /// Latest recorded connection generation for `machine_id`.
    ///
    /// Returns `None` if no generation is known or the registry is poisoned.
    #[must_use]
    pub fn current_generation(&self, machine_id: &MachineId) -> Option<u64> {
        match self.lifecycle.lock() {
            Ok(guard) => guard.get(machine_id).and_then(|state| state.generation),
            Err(e) => {
                tracing::error!("direct lifecycle registry poisoned: {e}");
                None
            }
        }
    }

    /// Subscribes to `(machine, new_generation)` replacement announcements.
    #[must_use]
    pub fn subscribe_lifecycle_replaced(&self) -> broadcast::Receiver<(MachineId, u64)> {
        self.lifecycle_replaced_tx.subscribe()
    }

    /// Installs (or clears, with `None`) the raw-QUIC ACK race test hook.
    #[doc(hidden)]
    pub fn set_raw_quic_ack_race_test_hook_for_testing(
        &self,
        hook: Option<Arc<RawQuicAckRaceTestHook>>,
    ) {
        match self.raw_quic_ack_race_test_hook.lock() {
            Ok(mut guard) => *guard = hook,
            Err(e) => tracing::error!("raw QUIC ACK race test hook poisoned: {e}"),
        }
    }

    /// Returns the installed test hook, if any.
    pub(crate) fn raw_quic_ack_race_test_hook(&self) -> Option<Arc<RawQuicAckRaceTestHook>> {
        match self.raw_quic_ack_race_test_hook.lock() {
            Ok(guard) => guard.clone(),
            Err(e) => {
                tracing::error!("raw QUIC ACK race test hook poisoned: {e}");
                None
            }
        }
    }

    /// Records that sends to `machine_id` are blocked for `reason`.
    ///
    /// When `generation` is supplied and differs from the recorded generation,
    /// the block is stale and ignored. If no generation was recorded yet, the
    /// supplied one is adopted.
    pub fn record_lifecycle_blocked(
        &self,
        machine_id: MachineId,
        generation: Option<u64>,
        reason: impl Into<String>,
    ) {
        let reason = reason.into();
        self.update_lifecycle(machine_id, |state| {
            if let Some(generation) = generation {
                match state.generation {
                    // Stale block from a superseded connection: keep current state.
                    Some(current) if current != generation => return,
                    Some(_) => {}
                    None => state.generation = Some(generation),
                }
            }
            state.blocked_reason = Some(reason);
        });
    }

    /// Current block reason for `machine_id`.
    ///
    /// Returns `None` if the machine is not blocked or the registry is poisoned.
    #[must_use]
    pub fn lifecycle_block_reason(&self, machine_id: &MachineId) -> Option<String> {
        match self.lifecycle.lock() {
            Ok(guard) => guard
                .get(machine_id)
                .and_then(|state| state.blocked_reason.clone()),
            Err(e) => {
                tracing::error!("direct lifecycle registry poisoned: {e}");
                None
            }
        }
    }

    /// Whether `agent_id` is currently marked connected.
    pub async fn is_connected(&self, agent_id: &AgentId) -> bool {
        let connected = self.connected_agents.read().await;
        connected.contains_key(agent_id)
    }

    /// Machine that `agent_id` is connected through, if connected.
    pub async fn get_machine_id(&self, agent_id: &AgentId) -> Option<MachineId> {
        let connected = self.connected_agents.read().await;
        connected.get(agent_id).copied()
    }

    /// All currently connected agents, in unspecified order.
    pub async fn connected_agents(&self) -> Vec<AgentId> {
        let connected = self.connected_agents.read().await;
        connected.keys().copied().collect()
    }

    /// Registers a new subscriber with its own bounded queue.
    pub fn subscribe(&self) -> DirectMessageReceiver {
        DirectMessageReceiver::new(
            Arc::clone(&self.subscribers),
            Arc::clone(&self.next_subscriber_id),
            self.subscriber_capacity,
        )
    }

    /// Number of registered subscriber queues (0 if the registry is poisoned).
    pub fn subscriber_count(&self) -> usize {
        match self.subscribers.lock() {
            Ok(guard) => guard.len(),
            Err(e) => {
                tracing::error!("direct subscriber registry poisoned: {e}");
                0
            }
        }
    }

    /// Counts a started send and stamps the peer's last-send time.
    ///
    /// A non-zero `avg_rtt_ms` replaces the peer's stored RTT; zero or `None`
    /// keeps the previous value.
    pub(crate) fn record_outgoing_started(&self, agent_id: AgentId, avg_rtt_ms: Option<u32>) {
        self.diagnostics
            .outgoing_send_total
            .fetch_add(1, Ordering::Relaxed);
        let now_ms = now_unix_ms_lossy();
        self.with_peer_diagnostics(agent_id, |peer| {
            peer.last_send_at_ms = Some(now_ms);
            if let Some(rtt) = avg_rtt_ms.filter(|rtt| *rtt > 0) {
                peer.avg_rtt_ms = Some(rtt);
            }
        });
    }

    /// Counts a successful send, attributes it to `path`, and records the path
    /// as the peer's preferred one.
    pub(crate) fn record_outgoing_succeeded(&self, agent_id: AgentId, path: DmPath) {
        self.diagnostics
            .outgoing_send_succeeded
            .fetch_add(1, Ordering::Relaxed);
        let path_counter = match path {
            DmPath::Loopback => &self.diagnostics.outgoing_path_loopback,
            // ACKed and un-ACKed raw QUIC share one counter.
            DmPath::RawQuic | DmPath::RawQuicAcked => &self.diagnostics.outgoing_path_raw_quic,
            DmPath::GossipInbox => &self.diagnostics.outgoing_path_gossip_inbox,
        };
        path_counter.fetch_add(1, Ordering::Relaxed);
        let path_label = dm_path_label(path);
        self.with_peer_diagnostics(agent_id, |peer| {
            peer.send_succeeded = peer.send_succeeded.saturating_add(1);
            peer.preferred_path = Some(path_label);
        });
    }

    /// Counts a failed send globally and for the peer.
    pub(crate) fn record_outgoing_failed(&self, agent_id: AgentId) {
        self.diagnostics
            .outgoing_send_failed
            .fetch_add(1, Ordering::Relaxed);
        self.with_peer_diagnostics(agent_id, |peer| {
            peer.send_failed = peer.send_failed.saturating_add(1);
        });
    }

    /// Counts an incoming envelope that failed to decode.
    pub(crate) fn record_incoming_decode_failed(&self) {
        self.diagnostics
            .incoming_decode_failed
            .fetch_add(1, Ordering::Relaxed);
    }

    /// Counts an incoming envelope whose signature check failed.
    pub(crate) fn record_incoming_signature_failed(&self) {
        self.diagnostics
            .incoming_signature_failed
            .fetch_add(1, Ordering::Relaxed);
    }

    /// Counts a trust rejection.
    ///
    /// Also ensures a per-peer entry exists for the sender (no fields of the
    /// entry are modified).
    pub(crate) fn record_incoming_trust_rejected(&self, agent_id: AgentId) {
        self.diagnostics
            .incoming_trust_rejected
            .fetch_add(1, Ordering::Relaxed);
        self.with_peer_diagnostics(agent_id, |_| {});
    }

    /// Copies all counters and per-peer state into a serializable snapshot.
    ///
    /// Counters are read individually with relaxed loads. The snapshot is
    /// therefore not an atomic cut across counters.
    #[must_use]
    pub fn diagnostics_snapshot(&self) -> DmDiagnosticsSnapshot {
        let counters = &self.diagnostics;
        let load = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        let stats = DmDiagnosticsStats {
            outgoing_send_total: load(&counters.outgoing_send_total),
            outgoing_send_succeeded: load(&counters.outgoing_send_succeeded),
            outgoing_send_failed: load(&counters.outgoing_send_failed),
            outgoing_path_loopback: load(&counters.outgoing_path_loopback),
            outgoing_path_raw_quic: load(&counters.outgoing_path_raw_quic),
            outgoing_path_gossip_inbox: load(&counters.outgoing_path_gossip_inbox),
            incoming_envelopes_total: load(&counters.incoming_envelopes_total),
            incoming_decode_failed: load(&counters.incoming_decode_failed),
            incoming_signature_failed: load(&counters.incoming_signature_failed),
            incoming_trust_rejected: load(&counters.incoming_trust_rejected),
            incoming_delivered_to_subscribe: load(&counters.incoming_delivered_to_subscribe),
            subscriber_channel_lagged: load(&counters.subscriber_channel_lagged),
            subscriber_events_evicted: load(&counters.subscriber_events_evicted),
            subscriber_channel_closed: load(&counters.subscriber_channel_closed),
        };
        let now_ms = now_unix_ms_lossy();
        let per_peer = match self.peer_diagnostics.lock() {
            Ok(guard) => guard
                .iter()
                .map(|(agent_id, peer)| {
                    (
                        hex::encode(agent_id.as_bytes()),
                        DmPeerDiagnostics {
                            avg_rtt_ms: peer.avg_rtt_ms,
                            last_send_ms_ago: peer
                                .last_send_at_ms
                                .map(|ts| now_ms.saturating_sub(ts)),
                            last_recv_ms_ago: peer
                                .last_recv_at_ms
                                .map(|ts| now_ms.saturating_sub(ts)),
                            send_succeeded: peer.send_succeeded,
                            send_failed: peer.send_failed,
                            recv_count: peer.recv_count,
                            preferred_path: peer.preferred_path.unwrap_or("unknown").to_string(),
                        },
                    )
                })
                .collect(),
            Err(e) => {
                tracing::error!("direct peer diagnostics registry poisoned: {e}");
                BTreeMap::new()
            }
        };
        DmDiagnosticsSnapshot {
            stats,
            per_peer,
            subscriber_count: self.subscriber_count(),
            subscriber_capacity: self.subscriber_capacity,
        }
    }

    /// Delivers a locally originated message as verified and trusted (`Accept`).
    pub(crate) async fn handle_loopback(
        &self,
        machine_id: MachineId,
        agent_id: AgentId,
        payload: Vec<u8>,
    ) -> u64 {
        self.handle_incoming(
            machine_id,
            agent_id,
            payload,
            true,
            Some(TrustDecision::Accept),
        )
        .await
    }

    /// Fans an incoming message out to every subscriber queue and offers a copy
    /// to the pull API (`recv`).
    ///
    /// Full subscriber queues evict their oldest message. Closed queues are
    /// unregistered. The pull-API copy is dropped silently if that channel is
    /// full or closed.
    ///
    /// Returns the number of subscribers that accepted the message, including
    /// those that had to evict.
    pub async fn handle_incoming(
        &self,
        machine_id: MachineId,
        sender_agent_id: AgentId,
        payload: Vec<u8>,
        verified: bool,
        trust_decision: Option<TrustDecision>,
    ) -> u64 {
        self.diagnostics
            .incoming_envelopes_total
            .fetch_add(1, Ordering::Relaxed);
        let now_ms = now_unix_ms_lossy();
        self.with_peer_diagnostics(sender_agent_id, |peer| {
            peer.last_recv_at_ms = Some(now_ms);
            peer.recv_count = peer.recv_count.saturating_add(1);
        });
        let msg = DirectMessage::new_verified(
            sender_agent_id,
            machine_id,
            payload,
            verified,
            trust_decision,
        );
        // Push to a snapshot so the registry lock is not held while pushing.
        let subscribers = self.subscriber_snapshot();
        let mut delivered = 0_u64;
        let mut remove_ids = Vec::new();
        for (id, queue) in subscribers {
            match queue.push_drop_oldest(msg.clone()) {
                DirectSubscriberPush::Delivered => {
                    delivered = delivered.saturating_add(1);
                }
                DirectSubscriberPush::DeliveredWithEviction => {
                    self.diagnostics
                        .subscriber_channel_lagged
                        .fetch_add(1, Ordering::Relaxed);
                    self.diagnostics
                        .subscriber_events_evicted
                        .fetch_add(1, Ordering::Relaxed);
                    tracing::warn!(
                        subscriber_id = id,
                        capacity = self.subscriber_capacity,
                        "direct subscriber queue full; evicted oldest buffered event"
                    );
                    delivered = delivered.saturating_add(1);
                }
                DirectSubscriberPush::Closed => {
                    self.diagnostics
                        .subscriber_channel_closed
                        .fetch_add(1, Ordering::Relaxed);
                    remove_ids.push(id);
                }
            }
        }
        if delivered > 0 {
            self.diagnostics
                .incoming_delivered_to_subscribe
                .fetch_add(1, Ordering::Relaxed);
        }
        if !remove_ids.is_empty() {
            self.remove_subscribers(&remove_ids);
        }
        if self.internal_tx.try_send(msg).is_err() {
            tracing::trace!("direct internal_tx full or closed, skipping pull-API copy");
        }
        delivered
    }

    /// Applies `update` to the lifecycle entry for `machine_id`, creating it if
    /// needed and stamping `last_updated_at_ms`.
    ///
    /// Prunes the map once it exceeds `DIRECT_DIAGNOSTICS_MIN_RETAIN` entries.
    /// The prune is skipped whenever the connected-agent set cannot be read
    /// without blocking.
    fn update_lifecycle(
        &self,
        machine_id: MachineId,
        update: impl FnOnce(&mut DirectLifecycleState),
    ) {
        match self.lifecycle.lock() {
            Ok(mut guard) => {
                let state = guard.entry(machine_id).or_default();
                state.last_updated_at_ms = Some(now_unix_ms_lossy());
                update(state);
                if guard.len() > DIRECT_DIAGNOSTICS_MIN_RETAIN {
                    if let Some(connected) = self.connected_machine_snapshot() {
                        Self::prune_lifecycle_locked(&mut guard, &connected);
                    }
                }
            }
            Err(e) => tracing::error!("direct lifecycle registry poisoned: {e}"),
        }
    }

    /// Applies `update` to the diagnostics entry for `agent_id`, creating it if
    /// needed.
    ///
    /// Prunes on the same terms as `update_lifecycle`.
    fn with_peer_diagnostics(
        &self,
        agent_id: AgentId,
        update: impl FnOnce(&mut DirectPeerDiagnosticsState),
    ) {
        match self.peer_diagnostics.lock() {
            Ok(mut guard) => {
                let peer = guard.entry(agent_id).or_default();
                update(peer);
                if guard.len() > DIRECT_DIAGNOSTICS_MIN_RETAIN {
                    if let Some(connected) = self.connected_agent_snapshot() {
                        Self::prune_peer_diagnostics_locked(&mut guard, &connected);
                    }
                }
            }
            Err(e) => tracing::error!("direct peer diagnostics registry poisoned: {e}"),
        }
    }

    /// Connected agent ids, or `None` if the connected map is write-locked.
    ///
    /// Uses `try_read` because callers hold a std mutex and must not block on
    /// the async lock.
    fn connected_agent_snapshot(&self) -> Option<HashSet<AgentId>> {
        self.connected_agents
            .try_read()
            .ok()
            .map(|guard| guard.keys().copied().collect())
    }

    /// Machines of connected agents, or `None` if the connected map is write-locked.
    fn connected_machine_snapshot(&self) -> Option<HashSet<MachineId>> {
        self.connected_agents
            .try_read()
            .ok()
            .map(|guard| guard.values().copied().collect())
    }

    /// Shrinks peer diagnostics toward `direct_diagnostics_retain_limit`.
    ///
    /// Connected peers are never removed. The prune runs in two passes:
    /// 1. drop idle entries older than the TTL (entries with no activity count
    ///    as expired);
    /// 2. if the map is still too large, evict the least recently active
    ///    non-connected entries.
    fn prune_peer_diagnostics_locked(
        guard: &mut HashMap<AgentId, DirectPeerDiagnosticsState>,
        connected: &HashSet<AgentId>,
    ) {
        let limit = direct_diagnostics_retain_limit(connected.len());
        if guard.len() <= limit {
            return;
        }
        let now = now_unix_ms_lossy();
        guard.retain(|agent_id, state| {
            connected.contains(agent_id)
                || state
                    .last_activity_ms()
                    .is_some_and(|last| now.saturating_sub(last) <= DIRECT_DIAGNOSTICS_IDLE_TTL_MS)
        });
        if guard.len() <= limit {
            return;
        }
        let mut idle: Vec<(AgentId, u64)> = guard
            .iter()
            .filter(|(agent_id, _)| !connected.contains(agent_id))
            .map(|(agent_id, state)| (*agent_id, state.last_activity_ms().unwrap_or(0)))
            .collect();
        idle.sort_by_key(|(_, last)| *last);
        let remove_count = guard.len().saturating_sub(limit).min(idle.len());
        for (agent_id, _) in idle.into_iter().take(remove_count) {
            guard.remove(&agent_id);
        }
    }

    /// Same two-pass policy as `prune_peer_diagnostics_locked`, applied to
    /// lifecycle entries keyed by machine and aged by `last_updated_at_ms`.
    fn prune_lifecycle_locked(
        guard: &mut HashMap<MachineId, DirectLifecycleState>,
        connected: &HashSet<MachineId>,
    ) {
        let limit = direct_diagnostics_retain_limit(connected.len());
        if guard.len() <= limit {
            return;
        }
        let now = now_unix_ms_lossy();
        guard.retain(|machine_id, state| {
            connected.contains(machine_id)
                || state
                    .last_updated_at_ms
                    .is_some_and(|last| now.saturating_sub(last) <= DIRECT_DIAGNOSTICS_IDLE_TTL_MS)
        });
        if guard.len() <= limit {
            return;
        }
        let mut idle: Vec<(MachineId, u64)> = guard
            .iter()
            .filter(|(machine_id, _)| !connected.contains(machine_id))
            .map(|(machine_id, state)| (*machine_id, state.last_updated_at_ms.unwrap_or(0)))
            .collect();
        idle.sort_by_key(|(_, last)| *last);
        let remove_count = guard.len().saturating_sub(limit).min(idle.len());
        for (machine_id, _) in idle.into_iter().take(remove_count) {
            guard.remove(&machine_id);
        }
    }

    /// Copies the subscriber registry so pushes happen without holding its lock.
    fn subscriber_snapshot(&self) -> Vec<(u64, Arc<DirectSubscriberQueue>)> {
        match self.subscribers.lock() {
            Ok(guard) => guard
                .iter()
                .map(|(id, queue)| (*id, Arc::clone(queue)))
                .collect(),
            Err(e) => {
                tracing::error!("direct subscriber registry poisoned: {e}");
                Vec::new()
            }
        }
    }

    /// Unregisters the given subscriber ids.
    fn remove_subscribers(&self, ids: &[u64]) {
        match self.subscribers.lock() {
            Ok(mut guard) => {
                for id in ids {
                    guard.remove(id);
                }
            }
            Err(e) => tracing::error!("direct subscriber registry poisoned: {e}"),
        }
    }

    /// Pull-style receive of the next message copied by `handle_incoming`.
    ///
    /// Concurrent callers are serialized by an async mutex. Each message is
    /// returned to exactly one caller.
    pub async fn recv(&self) -> Option<DirectMessage> {
        let mut rx = self.internal_rx.lock().await;
        rx.recv().await
    }

    /// Encodes `[stream type][32-byte sender agent id][payload]`.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::PayloadTooLarge` when `payload` exceeds
    /// `MAX_DIRECT_PAYLOAD_SIZE`.
    pub fn encode_message(sender_agent_id: &AgentId, payload: &[u8]) -> NetworkResult<Vec<u8>> {
        if payload.len() > MAX_DIRECT_PAYLOAD_SIZE {
            return Err(NetworkError::PayloadTooLarge {
                size: payload.len(),
                max: MAX_DIRECT_PAYLOAD_SIZE,
            });
        }
        let mut buf = Vec::with_capacity(Self::HEADER_LEN + payload.len());
        buf.push(DIRECT_MESSAGE_STREAM_TYPE);
        buf.extend_from_slice(&sender_agent_id.0);
        buf.extend_from_slice(payload);
        Ok(buf)
    }

    /// Decodes the layout written by `encode_message`.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::InvalidMessage` when `data` is shorter than the
    /// header or does not start with `DIRECT_MESSAGE_STREAM_TYPE`.
    pub fn decode_message(data: &[u8]) -> NetworkResult<(AgentId, Vec<u8>)> {
        if data.len() < Self::HEADER_LEN {
            return Err(NetworkError::InvalidMessage(
                "Direct message too short".to_string(),
            ));
        }
        if data[0] != DIRECT_MESSAGE_STREAM_TYPE {
            return Err(NetworkError::InvalidMessage(format!(
                "Invalid stream type byte: expected {}, got {}",
                DIRECT_MESSAGE_STREAM_TYPE, data[0]
            )));
        }
        let mut agent_id_bytes = [0u8; Self::AGENT_ID_LEN];
        agent_id_bytes.copy_from_slice(&data[1..Self::HEADER_LEN]);
        let sender = AgentId(agent_id_bytes);
        let payload = data[Self::HEADER_LEN..].to_vec();
        Ok((sender, payload))
    }
}
impl Default for DirectMessaging {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dm_payload_digest_is_stable_and_short() {
        let payload = b"hello world".to_vec();
        let digest = dm_payload_digest_hex(&payload);
        assert_eq!(digest.len(), 16);
        assert!(digest.chars().all(|c| c.is_ascii_hexdigit()));
        assert_eq!(dm_payload_digest_hex(&payload), digest);
        let other = dm_payload_digest_hex(b"different");
        assert_ne!(other, digest);
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let agent_id = AgentId([42u8; 32]);
        let payload = b"hello world".to_vec();
        let encoded = DirectMessaging::encode_message(&agent_id, &payload).unwrap();
        assert_eq!(encoded[0], DIRECT_MESSAGE_STREAM_TYPE);
        assert_eq!(encoded.len(), 1 + 32 + payload.len());
        let (decoded_agent, decoded_payload) = DirectMessaging::decode_message(&encoded).unwrap();
        assert_eq!(decoded_agent, agent_id);
        assert_eq!(decoded_payload, payload);
    }

    #[test]
    fn test_decode_too_short() {
        let short_data = vec![DIRECT_MESSAGE_STREAM_TYPE; 10];
        let result = DirectMessaging::decode_message(&short_data);
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_header_length_boundary() {
        // Exactly a header (type byte + 32-byte agent id) decodes to an empty payload.
        let mut exact = vec![0u8; 33];
        exact[0] = DIRECT_MESSAGE_STREAM_TYPE;
        let (agent, payload) = DirectMessaging::decode_message(&exact).unwrap();
        assert_eq!(agent, AgentId([0u8; 32]));
        assert!(payload.is_empty());
        // One byte short of a header is rejected.
        assert!(DirectMessaging::decode_message(&exact[..32]).is_err());
    }

    #[test]
    fn test_decode_wrong_type() {
        let mut data = vec![0x00; 50];
        data[0] = 0x01;
        let result = DirectMessaging::decode_message(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_payload_too_large() {
        let agent_id = AgentId([1u8; 32]);
        let payload = vec![0u8; MAX_DIRECT_PAYLOAD_SIZE + 1];
        let result = DirectMessaging::encode_message(&agent_id, &payload);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_register_and_lookup() {
        let dm = DirectMessaging::new();
        let agent_id = AgentId([1u8; 32]);
        let machine_id = MachineId([2u8; 32]);
        dm.register_agent(agent_id, machine_id).await;
        let lookup = dm.lookup_agent(&machine_id).await;
        assert_eq!(lookup, Some(agent_id));
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let dm = DirectMessaging::new();
        let agent_id = AgentId([1u8; 32]);
        let machine_id = MachineId([2u8; 32]);
        assert!(!dm.is_connected(&agent_id).await);
        dm.mark_connected(agent_id, machine_id).await;
        assert!(dm.is_connected(&agent_id).await);
        assert_eq!(dm.get_machine_id(&agent_id).await, Some(machine_id));
        let connected = dm.connected_agents().await;
        assert_eq!(connected, vec![agent_id]);
        dm.mark_disconnected(&agent_id).await;
        assert!(!dm.is_connected(&agent_id).await);
    }

    #[tokio::test]
    async fn test_message_subscription() {
        let dm = DirectMessaging::new();
        let mut rx = dm.subscribe();
        let sender = AgentId([1u8; 32]);
        let machine_id = MachineId([2u8; 32]);
        let payload = b"test message".to_vec();
        dm.handle_incoming(machine_id, sender, payload.clone(), true, None)
            .await;
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.sender, sender);
        assert_eq!(msg.machine_id, machine_id);
        assert_eq!(msg.payload, payload);
        assert!(msg.verified);
        assert!(msg.trust_decision.is_none());
        let snap = dm.diagnostics_snapshot();
        assert_eq!(snap.stats.incoming_envelopes_total, 1);
        assert_eq!(snap.stats.incoming_delivered_to_subscribe, 1);
        assert_eq!(snap.stats.subscriber_channel_lagged, 0);
    }

    #[tokio::test]
    async fn test_message_subscription_clone_gets_independent_queue() {
        let dm = DirectMessaging::new();
        let mut rx1 = dm.subscribe();
        let mut rx2 = rx1.clone();
        let sender = AgentId([3u8; 32]);
        let machine_id = MachineId([4u8; 32]);
        let payload = b"fanout".to_vec();
        dm.handle_incoming(machine_id, sender, payload.clone(), true, None)
            .await;
        assert_eq!(rx1.recv().await.unwrap().payload, payload);
        assert_eq!(rx2.recv().await.unwrap().payload, payload);
        assert_eq!(dm.subscriber_count(), 2);
    }

    #[test]
    fn dropped_receiver_unregisters_subscriber() {
        let dm = DirectMessaging::new();
        let rx = dm.subscribe();
        let rx_clone = rx.clone();
        assert_eq!(dm.subscriber_count(), 2);
        drop(rx);
        assert_eq!(dm.subscriber_count(), 1);
        drop(rx_clone);
        assert_eq!(dm.subscriber_count(), 0);
    }

    #[tokio::test]
    async fn try_recv_is_non_blocking() {
        let dm = DirectMessaging::new();
        let mut rx = dm.subscribe();
        assert!(rx.try_recv().is_none());
        dm.handle_incoming(MachineId([1u8; 32]), AgentId([2u8; 32]), b"x".to_vec(), true, None)
            .await;
        assert_eq!(rx.try_recv().map(|m| m.payload), Some(b"x".to_vec()));
        assert!(rx.try_recv().is_none());
    }

    #[tokio::test]
    async fn handle_incoming_without_subscribers_feeds_pull_api() {
        let dm = DirectMessaging::new();
        let sender = AgentId([8u8; 32]);
        let machine_id = MachineId([9u8; 32]);
        let delivered = dm
            .handle_incoming(machine_id, sender, b"pull".to_vec(), false, None)
            .await;
        assert_eq!(delivered, 0);
        let snap = dm.diagnostics_snapshot();
        assert_eq!(snap.stats.incoming_envelopes_total, 1);
        assert_eq!(snap.stats.incoming_delivered_to_subscribe, 0);
        let peer = &snap.per_peer[&hex::encode(sender.as_bytes())];
        assert_eq!(peer.recv_count, 1);
        assert_eq!(peer.preferred_path, "unknown");
        let msg = dm.recv().await.expect("pull API copy");
        assert_eq!(msg.payload, b"pull".to_vec());
        assert!(!msg.verified);
    }

    #[tokio::test]
    async fn dm_subscriber_bounded_drop_oldest_keeps_stream_alive() {
        let dm = DirectMessaging::with_subscriber_capacity(2);
        let mut lagging_rx = dm.subscribe();
        let sender = AgentId([5u8; 32]);
        let machine_id = MachineId([6u8; 32]);
        for idx in 0_u64..=2 {
            dm.handle_incoming(machine_id, sender, idx.to_be_bytes().to_vec(), true, None)
                .await;
        }
        let snap = dm.diagnostics_snapshot();
        assert_eq!(snap.stats.subscriber_channel_lagged, 1);
        assert_eq!(snap.stats.subscriber_events_evicted, 1);
        assert_eq!(snap.subscriber_count, 1);
        let first = lagging_rx.recv().await.unwrap();
        assert_eq!(first.payload, 1_u64.to_be_bytes().to_vec());
        let second = lagging_rx.recv().await.unwrap();
        assert_eq!(second.payload, 2_u64.to_be_bytes().to_vec());
    }

    #[test]
    fn outgoing_diagnostics_track_paths_and_rtt() {
        let dm = DirectMessaging::new();
        let agent = AgentId([0x11; 32]);
        // A zero RTT is ignored; a later non-zero RTT is recorded.
        dm.record_outgoing_started(agent, Some(0));
        dm.record_outgoing_succeeded(agent, DmPath::RawQuicAcked);
        dm.record_outgoing_started(agent, Some(42));
        dm.record_outgoing_failed(agent);
        let snap = dm.diagnostics_snapshot();
        assert_eq!(snap.stats.outgoing_send_total, 2);
        assert_eq!(snap.stats.outgoing_send_succeeded, 1);
        assert_eq!(snap.stats.outgoing_send_failed, 1);
        assert_eq!(snap.stats.outgoing_path_raw_quic, 1);
        assert_eq!(snap.stats.outgoing_path_loopback, 0);
        let peer = &snap.per_peer[&hex::encode(agent.as_bytes())];
        assert_eq!(peer.avg_rtt_ms, Some(42));
        assert_eq!(peer.send_succeeded, 1);
        assert_eq!(peer.send_failed, 1);
        assert_eq!(peer.preferred_path, "raw_quic_acked");
    }

    #[test]
    fn x0x_0041_current_generation_tracks_established_and_replaced() {
        let dm = DirectMessaging::new();
        let machine_id = MachineId([0xAB; 32]);
        assert_eq!(dm.current_generation(&machine_id), None);
        dm.record_lifecycle_established(machine_id, Some(7));
        assert_eq!(dm.current_generation(&machine_id), Some(7));
        dm.record_lifecycle_replaced(machine_id, 9);
        assert_eq!(dm.current_generation(&machine_id), Some(9));
    }

    #[tokio::test]
    async fn x0x_0041_subscribe_lifecycle_replaced_broadcasts_supersede() {
        let dm = DirectMessaging::new();
        let mut rx = dm.subscribe_lifecycle_replaced();
        let machine_a = MachineId([0xA1; 32]);
        let machine_b = MachineId([0xB2; 32]);
        dm.record_lifecycle_established(machine_a, Some(1));
        match rx.try_recv() {
            Err(tokio::sync::broadcast::error::TryRecvError::Empty) => {}
            other => panic!("expected Empty for Established, got {other:?}"),
        }
        dm.record_lifecycle_replaced(machine_a, 2);
        // `gen` is a reserved keyword in Rust 2024, so use `generation`.
        let (m, generation) = rx.recv().await.expect("Replaced event");
        assert_eq!(m, machine_a);
        assert_eq!(generation, 2);
        dm.record_lifecycle_replaced(machine_b, 5);
        let (m, generation) = rx.recv().await.expect("Replaced event");
        assert_eq!(m, machine_b);
        assert_eq!(generation, 5);
    }

    #[test]
    fn test_lifecycle_blocks_only_current_generation() {
        let dm = DirectMessaging::new();
        let machine_id = MachineId([7u8; 32]);
        dm.record_lifecycle_established(machine_id, Some(1));
        assert!(dm.lifecycle_block_reason(&machine_id).is_none());
        dm.record_lifecycle_replaced(machine_id, 2);
        dm.record_lifecycle_blocked(machine_id, Some(1), "closed: superseded");
        assert!(dm.lifecycle_block_reason(&machine_id).is_none());
        dm.record_lifecycle_blocked(machine_id, Some(2), "closed: timed out");
        assert_eq!(
            dm.lifecycle_block_reason(&machine_id).as_deref(),
            Some("closed: timed out")
        );
        dm.record_lifecycle_established(machine_id, Some(3));
        assert!(dm.lifecycle_block_reason(&machine_id).is_none());
    }

    #[test]
    fn lifecycle_block_on_unknown_machine_adopts_generation() {
        let dm = DirectMessaging::new();
        let machine_id = MachineId([0xCD; 32]);
        dm.record_lifecycle_blocked(machine_id, Some(4), "closed: reset");
        assert_eq!(dm.current_generation(&machine_id), Some(4));
        assert_eq!(
            dm.lifecycle_block_reason(&machine_id).as_deref(),
            Some("closed: reset")
        );
    }

    #[test]
    fn direct_diagnostics_prune_idle_entries_to_scaled_bound() {
        fn agent_id_from_u32(id: u32) -> AgentId {
            let mut bytes = [0u8; 32];
            bytes[..4].copy_from_slice(&id.to_be_bytes());
            AgentId(bytes)
        }
        let now = now_unix_ms_lossy();
        let connected = agent_id_from_u32(1);
        let mut connected_set = HashSet::new();
        connected_set.insert(connected);
        let mut guard = HashMap::new();
        guard.insert(
            connected,
            DirectPeerDiagnosticsState {
                last_recv_at_ms: Some(0),
                ..DirectPeerDiagnosticsState::default()
            },
        );
        for id in 2..1100 {
            guard.insert(
                agent_id_from_u32(id),
                DirectPeerDiagnosticsState {
                    last_recv_at_ms: Some(now),
                    ..DirectPeerDiagnosticsState::default()
                },
            );
        }
        DirectMessaging::prune_peer_diagnostics_locked(&mut guard, &connected_set);
        assert!(guard.len() <= DIRECT_DIAGNOSTICS_MIN_RETAIN);
        assert!(guard.contains_key(&connected));
    }

    #[test]
    fn test_direct_message_payload_str() {
        let msg = DirectMessage::new(AgentId([1u8; 32]), MachineId([2u8; 32]), b"hello".to_vec());
        assert_eq!(msg.payload_str(), Some("hello"));
        let binary_msg =
            DirectMessage::new(AgentId([1u8; 32]), MachineId([2u8; 32]), vec![0xff, 0xfe]);
        assert!(binary_msg.payload_str().is_none());
    }
}