use crate::error::{NetworkError, NetworkResult};
use crate::identity::{AgentId, MachineId};
use crate::trust::TrustDecision;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::{broadcast, mpsc, RwLock};
/// Tag byte that opens every direct-message frame.
///
/// `DirectMessaging::encode_message` writes it as the first byte and
/// `DirectMessaging::decode_message` rejects frames that do not start with it.
pub const DIRECT_MESSAGE_STREAM_TYPE: u8 = 0x10;
/// Largest direct-message payload in bytes (16 MiB); larger payloads are rejected by
/// `DirectMessaging::encode_message`.
pub const MAX_DIRECT_PAYLOAD_SIZE: usize = 16 << 20;
/// A message delivered through the direct-messaging layer, together with receipt metadata.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DirectMessage {
    /// Agent identified as the sender of this message.
    pub sender: AgentId,
    /// Machine the message is attributed to.
    pub machine_id: MachineId,
    /// Raw payload bytes; not necessarily UTF-8 (see `payload_str`).
    pub payload: Vec<u8>,
    /// Receipt time in milliseconds since the Unix epoch, captured when the value is built
    /// (0 if the system clock reads earlier than the epoch).
    pub received_at: u64,
    /// Whether the sender was verified; supplied by the constructor's caller (`new` uses `false`).
    pub verified: bool,
    /// Trust evaluation attached by the constructor's caller, if any (`new` uses `None`).
    pub trust_decision: Option<TrustDecision>,
}
impl DirectMessage {
    /// Creates an unverified message with no trust decision, stamped with the current time.
    #[must_use]
    pub fn new(sender: AgentId, machine_id: MachineId, payload: Vec<u8>) -> Self {
        Self::new_verified(sender, machine_id, payload, false, None)
    }

    /// Creates a message with explicit verification status and optional trust decision.
    ///
    /// `received_at` is set to the current wall-clock time in milliseconds since the Unix
    /// epoch. If the clock reads earlier than the epoch it is `0`; if the millisecond count
    /// does not fit in a `u64` it saturates to `u64::MAX` rather than wrapping.
    #[must_use]
    pub fn new_verified(
        sender: AgentId,
        machine_id: MachineId,
        payload: Vec<u8>,
        verified: bool,
        trust_decision: Option<TrustDecision>,
    ) -> Self {
        // `as_millis` yields a u128; convert checked instead of a silently truncating `as` cast.
        let received_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |elapsed| u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX));
        Self {
            sender,
            machine_id,
            payload,
            received_at,
            verified,
            trust_decision,
        }
    }

    /// Returns the payload as `&str` if it is valid UTF-8, otherwise `None`.
    #[must_use]
    pub fn payload_str(&self) -> Option<&str> {
        std::str::from_utf8(&self.payload).ok()
    }
}
/// Subscriber handle that yields every [`DirectMessage`] broadcast after it was created.
///
/// Obtained from `DirectMessaging::subscribe`; wraps a tokio broadcast receiver.
#[derive(Debug)]
pub struct DirectMessageReceiver {
    /// Underlying broadcast receiver.
    rx: broadcast::Receiver<DirectMessage>,
}
impl DirectMessageReceiver {
pub(crate) fn new(rx: broadcast::Receiver<DirectMessage>) -> Self {
Self { rx }
}
pub async fn recv(&mut self) -> Option<DirectMessage> {
loop {
match self.rx.recv().await {
Ok(msg) => return Some(msg),
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Direct message receiver lagged, skipped {} messages", n);
continue;
}
Err(broadcast::error::RecvError::Closed) => return None,
}
}
}
pub fn try_recv(&mut self) -> Option<DirectMessage> {
self.rx.try_recv().ok()
}
}
impl Clone for DirectMessageReceiver {
    /// Creates another receiver on the same broadcast channel.
    ///
    /// Uses `broadcast::Receiver::resubscribe`, which (per tokio's documentation) positions
    /// the new receiver at the channel's current tail: messages still buffered for `self`
    /// are not replayed to the clone, only messages sent afterwards are.
    fn clone(&self) -> Self {
        let rx = self.rx.resubscribe();
        Self { rx }
    }
}
/// Tracks agent/machine associations and fans incoming direct messages out to consumers.
///
/// Messages handed to `handle_incoming` go to every `subscribe` receiver and, on a
/// best-effort basis, to a bounded queue drained by the pull-style `recv`.
#[derive(Debug)]
pub struct DirectMessaging {
    /// Machine -> agent mappings recorded by `register_agent`.
    machine_to_agent: Arc<RwLock<HashMap<MachineId, AgentId>>>,
    /// Agents currently marked connected, with the machine each is connected from.
    connected_agents: Arc<RwLock<HashMap<AgentId, MachineId>>>,
    /// Broadcast sender feeding every `DirectMessageReceiver` created by `subscribe`.
    message_tx: broadcast::Sender<DirectMessage>,
    /// Sending half of the bounded queue behind the pull-style `recv`.
    internal_tx: mpsc::Sender<DirectMessage>,
    /// Receiving half of that queue; the async mutex serializes concurrent `recv` callers.
    internal_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<DirectMessage>>>,
}
impl DirectMessaging {
    /// Size in bytes of the sender [`AgentId`] embedded in every frame header.
    const AGENT_ID_LEN: usize = 32;
    /// Size in bytes of a frame header: the stream-type byte followed by the sender id.
    const HEADER_LEN: usize = 1 + Self::AGENT_ID_LEN;
    /// Capacity of both the broadcast channel and the pull-API queue.
    const CHANNEL_CAPACITY: usize = 256;

    /// Creates an instance with empty agent tables and fresh channels.
    ///
    /// The broadcast channel behind [`subscribe`](Self::subscribe) and the bounded queue
    /// behind [`recv`](Self::recv) are both created with capacity 256.
    #[must_use]
    pub fn new() -> Self {
        // The initial broadcast receiver is dropped; receivers are created via `subscribe`.
        let (message_tx, _) = broadcast::channel(Self::CHANNEL_CAPACITY);
        let (internal_tx, internal_rx) = mpsc::channel(Self::CHANNEL_CAPACITY);
        Self {
            machine_to_agent: Arc::new(RwLock::new(HashMap::new())),
            connected_agents: Arc::new(RwLock::new(HashMap::new())),
            message_tx,
            internal_tx,
            internal_rx: Arc::new(tokio::sync::Mutex::new(internal_rx)),
        }
    }

    /// Records that `machine_id` belongs to `agent_id`, replacing any earlier mapping for
    /// that machine.
    pub async fn register_agent(&self, agent_id: AgentId, machine_id: MachineId) {
        // The write guard is a temporary, so it is released before logging.
        self.machine_to_agent
            .write()
            .await
            .insert(machine_id, agent_id);
        tracing::debug!(
            "Registered agent mapping: {:?} -> {:?}",
            machine_id,
            agent_id
        );
    }

    /// Returns the agent registered for `machine_id`, if any.
    pub async fn lookup_agent(&self, machine_id: &MachineId) -> Option<AgentId> {
        self.machine_to_agent.read().await.get(machine_id).copied()
    }

    /// Registers the machine -> agent mapping and marks `agent_id` as connected from
    /// `machine_id`.
    pub async fn mark_connected(&self, agent_id: AgentId, machine_id: MachineId) {
        self.register_agent(agent_id, machine_id).await;
        self.connected_agents
            .write()
            .await
            .insert(agent_id, machine_id);
        tracing::info!("Agent connected: {:?}", agent_id);
    }

    /// Marks `agent_id` as no longer connected.
    ///
    /// The machine -> agent mapping from [`register_agent`](Self::register_agent) is left
    /// in place, so [`lookup_agent`](Self::lookup_agent) still resolves it.
    pub async fn mark_disconnected(&self, agent_id: &AgentId) {
        self.connected_agents.write().await.remove(agent_id);
        tracing::info!("Agent disconnected: {:?}", agent_id);
    }

    /// Returns `true` if `agent_id` is currently marked connected.
    pub async fn is_connected(&self, agent_id: &AgentId) -> bool {
        self.connected_agents.read().await.contains_key(agent_id)
    }

    /// Returns the machine `agent_id` is connected from, or `None` if it is not connected.
    pub async fn get_machine_id(&self, agent_id: &AgentId) -> Option<MachineId> {
        self.connected_agents.read().await.get(agent_id).copied()
    }

    /// Returns a snapshot of all connected agents, in unspecified (hash-map) order.
    pub async fn connected_agents(&self) -> Vec<AgentId> {
        self.connected_agents.read().await.keys().copied().collect()
    }

    /// Creates a receiver for messages passed to
    /// [`handle_incoming`](Self::handle_incoming) after this call.
    pub fn subscribe(&self) -> DirectMessageReceiver {
        DirectMessageReceiver::new(self.message_tx.subscribe())
    }

    /// Builds a [`DirectMessage`] and delivers it to all subscribers and to the pull queue.
    ///
    /// Delivery to the pull queue is best-effort: if the queue is full (it is never drained
    /// unless someone calls [`recv`](Self::recv)), the copy is dropped and a trace event is
    /// logged. The payload is cloned only when both consumers actually take a copy.
    pub async fn handle_incoming(
        &self,
        machine_id: MachineId,
        sender_agent_id: AgentId,
        payload: Vec<u8>,
        verified: bool,
        trust_decision: Option<TrustDecision>,
    ) {
        let msg = DirectMessage::new_verified(
            sender_agent_id,
            machine_id,
            payload,
            verified,
            trust_decision,
        );
        let has_subscribers = self.message_tx.receiver_count() > 0;
        // Reserve the queue slot first so a full queue does not cost a payload clone that
        // would immediately be thrown away.
        match self.internal_tx.try_reserve() {
            Ok(permit) => {
                if has_subscribers {
                    // A send error only means every subscriber dropped meanwhile.
                    let _ = self.message_tx.send(msg.clone());
                }
                permit.send(msg);
            }
            Err(_) => {
                if has_subscribers {
                    let _ = self.message_tx.send(msg);
                }
                tracing::trace!("direct internal_tx full or closed, skipping pull-API copy");
            }
        }
    }

    /// Waits for the next message in the pull queue.
    ///
    /// Concurrent callers are serialized by an async mutex, so each message is returned
    /// to exactly one caller.
    pub async fn recv(&self) -> Option<DirectMessage> {
        self.internal_rx.lock().await.recv().await
    }

    /// Encodes a frame: the stream-type byte, the 32-byte sender id, then the payload.
    ///
    /// # Errors
    ///
    /// Returns [`NetworkError::PayloadTooLarge`] if `payload` exceeds
    /// [`MAX_DIRECT_PAYLOAD_SIZE`].
    pub fn encode_message(sender_agent_id: &AgentId, payload: &[u8]) -> NetworkResult<Vec<u8>> {
        if payload.len() > MAX_DIRECT_PAYLOAD_SIZE {
            return Err(NetworkError::PayloadTooLarge {
                size: payload.len(),
                max: MAX_DIRECT_PAYLOAD_SIZE,
            });
        }
        let mut buf = Vec::with_capacity(Self::HEADER_LEN + payload.len());
        buf.push(DIRECT_MESSAGE_STREAM_TYPE);
        buf.extend_from_slice(&sender_agent_id.0);
        buf.extend_from_slice(payload);
        Ok(buf)
    }

    /// Decodes a frame produced by [`encode_message`](Self::encode_message) into the sender
    /// id and an owned copy of the payload. An empty payload is accepted.
    ///
    /// # Errors
    ///
    /// - [`NetworkError::InvalidMessage`] if `data` is shorter than the 33-byte header or
    ///   does not start with [`DIRECT_MESSAGE_STREAM_TYPE`].
    /// - [`NetworkError::PayloadTooLarge`] if the payload exceeds
    ///   [`MAX_DIRECT_PAYLOAD_SIZE`]; `data` comes from the network, so the limit enforced on
    ///   the sending side is enforced here too.
    pub fn decode_message(data: &[u8]) -> NetworkResult<(AgentId, Vec<u8>)> {
        if data.len() < Self::HEADER_LEN {
            return Err(NetworkError::InvalidMessage(
                "Direct message too short".to_string(),
            ));
        }
        let (header, payload) = data.split_at(Self::HEADER_LEN);
        let stream_type = header[0];
        if stream_type != DIRECT_MESSAGE_STREAM_TYPE {
            return Err(NetworkError::InvalidMessage(format!(
                "Invalid stream type byte: expected {}, got {}",
                DIRECT_MESSAGE_STREAM_TYPE, stream_type
            )));
        }
        if payload.len() > MAX_DIRECT_PAYLOAD_SIZE {
            return Err(NetworkError::PayloadTooLarge {
                size: payload.len(),
                max: MAX_DIRECT_PAYLOAD_SIZE,
            });
        }
        let mut agent_id_bytes = [0u8; Self::AGENT_ID_LEN];
        agent_id_bytes.copy_from_slice(&header[1..]);
        Ok((AgentId(agent_id_bytes), payload.to_vec()))
    }
}
impl Default for DirectMessaging {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_roundtrip() {
        let agent_id = AgentId([42u8; 32]);
        let payload = b"hello world".to_vec();
        let encoded = DirectMessaging::encode_message(&agent_id, &payload).unwrap();
        assert_eq!(encoded[0], DIRECT_MESSAGE_STREAM_TYPE);
        assert_eq!(encoded.len(), 1 + 32 + payload.len());
        let (decoded_agent, decoded_payload) = DirectMessaging::decode_message(&encoded).unwrap();
        assert_eq!(decoded_agent, agent_id);
        assert_eq!(decoded_payload, payload);
    }

    #[test]
    fn test_encode_decode_empty_payload() {
        // A header-only frame (33 bytes) is the minimum valid message.
        let agent_id = AgentId([7u8; 32]);
        let encoded = DirectMessaging::encode_message(&agent_id, &[]).unwrap();
        assert_eq!(encoded.len(), 1 + 32);
        let (decoded_agent, decoded_payload) = DirectMessaging::decode_message(&encoded).unwrap();
        assert_eq!(decoded_agent, agent_id);
        assert!(decoded_payload.is_empty());
    }

    #[test]
    fn test_decode_too_short() {
        let short_data = vec![DIRECT_MESSAGE_STREAM_TYPE; 10];
        let result = DirectMessaging::decode_message(&short_data);
        assert!(matches!(result, Err(NetworkError::InvalidMessage(_))));
    }

    #[test]
    fn test_decode_wrong_type() {
        let mut data = vec![0x00; 50];
        data[0] = 0x01;
        let result = DirectMessaging::decode_message(&data);
        assert!(matches!(result, Err(NetworkError::InvalidMessage(_))));
    }

    #[test]
    fn test_encode_payload_too_large() {
        let agent_id = AgentId([1u8; 32]);
        let payload = vec![0u8; MAX_DIRECT_PAYLOAD_SIZE + 1];
        let result = DirectMessaging::encode_message(&agent_id, &payload);
        assert!(matches!(
            result,
            Err(NetworkError::PayloadTooLarge { size, max })
                if size == MAX_DIRECT_PAYLOAD_SIZE + 1 && max == MAX_DIRECT_PAYLOAD_SIZE
        ));
    }

    #[tokio::test]
    async fn test_register_and_lookup() {
        let dm = DirectMessaging::new();
        let agent_id = AgentId([1u8; 32]);
        let machine_id = MachineId([2u8; 32]);
        dm.register_agent(agent_id, machine_id).await;
        let lookup = dm.lookup_agent(&machine_id).await;
        assert_eq!(lookup, Some(agent_id));
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let dm = DirectMessaging::new();
        let agent_id = AgentId([1u8; 32]);
        let machine_id = MachineId([2u8; 32]);
        assert!(!dm.is_connected(&agent_id).await);
        dm.mark_connected(agent_id, machine_id).await;
        assert!(dm.is_connected(&agent_id).await);
        assert_eq!(dm.get_machine_id(&agent_id).await, Some(machine_id));
        let connected = dm.connected_agents().await;
        assert_eq!(connected, vec![agent_id]);
        dm.mark_disconnected(&agent_id).await;
        assert!(!dm.is_connected(&agent_id).await);
    }

    #[tokio::test]
    async fn test_message_subscription() {
        let dm = DirectMessaging::new();
        let mut rx = dm.subscribe();
        let sender = AgentId([1u8; 32]);
        let machine_id = MachineId([2u8; 32]);
        let payload = b"test message".to_vec();
        dm.handle_incoming(machine_id, sender, payload.clone(), true, None)
            .await;
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.sender, sender);
        assert_eq!(msg.machine_id, machine_id);
        assert_eq!(msg.payload, payload);
        assert!(msg.verified);
        assert!(msg.trust_decision.is_none());
    }

    #[tokio::test]
    async fn test_pull_api_recv() {
        let dm = DirectMessaging::new();
        let sender = AgentId([3u8; 32]);
        let machine_id = MachineId([4u8; 32]);
        // No subscribers: the message must still reach the pull queue.
        dm.handle_incoming(machine_id, sender, b"pull".to_vec(), false, None)
            .await;
        let msg = dm.recv().await.unwrap();
        assert_eq!(msg.sender, sender);
        assert_eq!(msg.machine_id, machine_id);
        assert_eq!(msg.payload, b"pull".to_vec());
        assert!(!msg.verified);
    }

    #[test]
    fn test_try_recv_empty() {
        let dm = DirectMessaging::new();
        let mut rx = dm.subscribe();
        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn test_direct_message_payload_str() {
        let msg = DirectMessage::new(AgentId([1u8; 32]), MachineId([2u8; 32]), b"hello".to_vec());
        assert_eq!(msg.payload_str(), Some("hello"));
        let binary_msg =
            DirectMessage::new(AgentId([1u8; 32]), MachineId([2u8; 32]), vec![0xff, 0xfe]);
        assert!(binary_msg.payload_str().is_none());
    }
}