use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{Result, bail};
use tokio::sync::{RwLock, broadcast};
use uuid::Uuid;
use crate::discovery::Discovery;
use crate::identity::AgentIdentity;
use crate::network::event::TransportType;
use crate::network::{MessageEnvelope, MessageTarget, NetworkEvent, Payload};
use crate::routing::{BroadcastRouter, ContentRouter, DirectRouter, PeerTable, Router};
use crate::transport::{Transport, TransportAddress};
/// Coordinates routing, transport delivery, peer discovery, and network-event
/// fan-out on behalf of a single local agent.
///
/// Built by [`NetworkManagerBuilder`].
pub struct NetworkManager {
    /// The local agent's identity; its `id` is the sender of outgoing messages.
    identity: AgentIdentity,
    /// Registered transports, at most one per [`TransportType`].
    transports: HashMap<TransportType, Box<dyn Transport>>,
    /// Built-in router for direct messages, used when `custom_router` is `None`.
    direct_router: DirectRouter,
    /// Built-in router for broadcast messages.
    broadcast_router: BroadcastRouter,
    /// Built-in router for topic (publish) messages.
    content_router: ContentRouter,
    /// Optional override for routing direct messages only.
    custom_router: Option<Box<dyn Router>>,
    /// Discovery backends used for peer discovery and self-(de)registration.
    discoveries: Vec<Box<dyn Discovery>>,
    /// Known peers and their delivery addresses, behind an async read/write lock.
    peer_table: Arc<RwLock<PeerTable>>,
    /// Sending half of the network-event broadcast channel.
    event_tx: broadcast::Sender<NetworkEvent>,
}
impl NetworkManager {
    /// Returns the identity this manager sends and registers as.
    pub fn identity(&self) -> &AgentIdentity {
        &self.identity
    }

    /// Subscribes to network events emitted by this manager.
    ///
    /// The receiver only observes events sent after this call.
    pub fn subscribe(&self) -> broadcast::Receiver<NetworkEvent> {
        self.event_tx.subscribe()
    }

    /// Acquires a shared read lock on the peer table.
    ///
    /// While the guard is alive, writers such as [`Self::discover_peers`] and
    /// [`Self::peer_table_mut`] wait for it to be dropped.
    pub async fn peer_table(&self) -> tokio::sync::RwLockReadGuard<'_, PeerTable> {
        self.peer_table.read().await
    }

    /// Acquires an exclusive write lock on the peer table.
    ///
    /// While the guard is alive, every other peer-table access waits for it,
    /// including message routing in [`Self::send_envelope`].
    pub async fn peer_table_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, PeerTable> {
        self.peer_table.write().await
    }

    /// Returns a snapshot (cloned identities) of every peer in the peer table.
    pub async fn peers(&self) -> Vec<AgentIdentity> {
        self.peer_table.read().await.all_peers().cloned().collect()
    }

    /// Sends `payload` directly to the agent whose id is `target`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::send_envelope`].
    pub async fn send(&self, target: Uuid, payload: impl Into<Payload>) -> Result<()> {
        let envelope = MessageEnvelope::direct(self.identity.id, target, payload);
        self.send_envelope(envelope).await
    }

    /// Sends `payload` to every destination chosen by the broadcast router.
    ///
    /// # Errors
    ///
    /// Same as [`Self::send_envelope`].
    pub async fn broadcast(&self, payload: impl Into<Payload>) -> Result<()> {
        let envelope = MessageEnvelope::broadcast(self.identity.id, payload);
        self.send_envelope(envelope).await
    }

    /// Publishes `payload` on `topic`, routed by the content router.
    ///
    /// # Errors
    ///
    /// Same as [`Self::send_envelope`].
    pub async fn publish(
        &self,
        topic: impl Into<String>,
        payload: impl Into<Payload>,
    ) -> Result<()> {
        let envelope = MessageEnvelope::topic(self.identity.id, topic, payload);
        self.send_envelope(envelope).await
    }

    /// Resolves delivery addresses for `envelope` and hands it to the matching
    /// transports.
    ///
    /// Direct messages go through the custom router when one is set, otherwise
    /// through the built-in direct router. Broadcast and topic messages always
    /// use the built-in broadcast and content routers.
    ///
    /// [`Transport::send`] is not told which address to use, so each transport
    /// receives the envelope at most once, even when several resolved addresses
    /// map to it. Addresses with no registered transport are logged and skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - routing fails;
    /// - no address is resolved;
    /// - none of the resolved addresses has a registered transport;
    /// - a transport send fails. Transports later in the list are then not
    ///   attempted.
    pub async fn send_envelope(&self, envelope: MessageEnvelope) -> Result<()> {
        let peer_table = self.peer_table.read().await;
        let addresses = match &envelope.recipient {
            MessageTarget::Direct(_) => match &self.custom_router {
                Some(router) => router.route(&envelope, &peer_table).await?,
                None => self.direct_router.route(&envelope, &peer_table).await?,
            },
            MessageTarget::Broadcast => self.broadcast_router.route(&envelope, &peer_table).await?,
            MessageTarget::Topic(_) => self.content_router.route(&envelope, &peer_table).await?,
        };
        // Release the read lock before doing transport I/O.
        drop(peer_table);

        if addresses.is_empty() {
            bail!("No delivery addresses resolved for message");
        }

        let mut used_transports = std::collections::HashSet::new();
        for addr in &addresses {
            let transport_type = transport_type_for_address(addr);
            let Some(transport) = self.transports.get(&transport_type) else {
                tracing::warn!(
                    "No transport available for address {addr} (type {transport_type:?})"
                );
                continue;
            };
            // `Transport::send` does not receive `addr`; calling it again on the
            // same transport would only re-send the identical envelope.
            if used_transports.insert(transport_type) {
                transport.send(&envelope).await?;
            }
        }

        // Without this check, a message whose every address lacked a transport
        // would be reported as sent even though nothing was delivered.
        if used_transports.is_empty() {
            bail!(
                "No transport available for any of the {} resolved address(es)",
                addresses.len()
            );
        }
        Ok(())
    }

    /// Registers `transport`, replacing any previously registered transport of
    /// the same [`TransportType`].
    pub fn add_transport(&mut self, transport: Box<dyn Transport>) {
        let t = transport.transport_type();
        self.transports.insert(t, transport);
    }

    /// Routes subsequent direct messages through `router` instead of the
    /// built-in direct router. Broadcast and topic routing are unaffected.
    pub fn set_custom_router(&mut self, router: Box<dyn Router>) {
        self.custom_router = Some(router);
    }

    /// Appends a discovery backend. Backends are consulted in insertion order.
    pub fn add_discovery(&mut self, discovery: Box<dyn Discovery>) {
        self.discoveries.push(discovery);
    }

    /// Registers this manager's identity with every discovery backend, in order.
    ///
    /// # Errors
    ///
    /// Stops at the first backend that fails. Backends earlier in the list
    /// remain registered.
    pub async fn register_self(&self) -> Result<()> {
        for d in &self.discoveries {
            d.register(&self.identity).await?;
        }
        Ok(())
    }

    /// Removes this manager's identity from every discovery backend, in order.
    ///
    /// # Errors
    ///
    /// Stops at the first backend that fails. Backends earlier in the list
    /// have already deregistered.
    pub async fn deregister_self(&self) -> Result<()> {
        for d in &self.discoveries {
            d.deregister(&self.identity.id).await?;
        }
        Ok(())
    }

    /// Queries every discovery backend and records newly seen peers.
    ///
    /// A backend that fails is logged and skipped; it does not abort the round.
    /// Results are de-duplicated by peer id, keeping the first occurrence.
    /// This manager's own identity is dropped, which matters because a backend
    /// we registered with can report us back.
    ///
    /// Each peer not yet in the peer table is inserted with addresses derived
    /// from its agent-card endpoint, and a [`NetworkEvent::PeerJoined`] event is
    /// emitted for it. Peers that are already known are left untouched.
    ///
    /// Returns every distinct peer reported in this round, whether new or
    /// already known, excluding this manager itself. Individual backend
    /// failures are not propagated as errors.
    pub async fn discover_peers(&self) -> Result<Vec<AgentIdentity>> {
        let mut all_peers = Vec::new();
        for d in &self.discoveries {
            match d.discover().await {
                Ok(peers) => all_peers.extend(peers),
                Err(e) => {
                    tracing::warn!("Discovery via {:?} failed: {e}", d.protocol());
                }
            }
        }

        let self_id = self.identity.id;
        let mut seen = std::collections::HashSet::new();
        all_peers.retain(|p| p.id != self_id && seen.insert(p.id));

        let mut table = self.peer_table.write().await;
        for peer in &all_peers {
            if table.get(&peer.id).is_none() {
                table.upsert(peer.clone(), endpoint_to_addresses(peer));
                // A send error only means nobody is subscribed right now.
                let _ = self.event_tx.send(NetworkEvent::PeerJoined(peer.clone()));
            }
        }
        Ok(all_peers)
    }

    /// Broadcasts `event` to current subscribers.
    ///
    /// The event is dropped silently when there are no subscribers.
    pub fn emit(&self, event: NetworkEvent) {
        let _ = self.event_tx.send(event);
    }
}
/// Step-by-step configuration for a [`NetworkManager`]; finish with `build()`.
pub struct NetworkManagerBuilder {
    /// Identity the built manager will act as.
    identity: AgentIdentity,
    /// Transports to install, keyed by type; a later addition of the same type
    /// replaces the earlier one.
    transports: HashMap<TransportType, Box<dyn Transport>>,
    /// Optional router that takes over direct-message routing.
    custom_router: Option<Box<dyn Router>>,
    /// Discovery backends, in the order they were added.
    discoveries: Vec<Box<dyn Discovery>>,
    /// Requested capacity of the network-event broadcast channel.
    event_buffer: usize,
}
impl NetworkManagerBuilder {
    /// Starts a builder for a manager acting as `identity`.
    ///
    /// The builder starts with no transports, no custom router, no discovery
    /// backends, and an event-buffer capacity of 256.
    pub fn new(identity: AgentIdentity) -> Self {
        Self {
            identity,
            transports: HashMap::new(),
            custom_router: None,
            discoveries: Vec::new(),
            event_buffer: 256,
        }
    }

    /// Adds `transport`, replacing any earlier transport of the same type.
    #[must_use]
    pub fn add_transport(mut self, transport: Box<dyn Transport>) -> Self {
        let t = transport.transport_type();
        self.transports.insert(t, transport);
        self
    }

    /// Routes direct messages through `router` instead of the built-in
    /// direct router.
    #[must_use]
    pub fn with_router(mut self, router: Box<dyn Router>) -> Self {
        self.custom_router = Some(router);
        self
    }

    /// Appends a discovery backend.
    #[must_use]
    pub fn add_discovery(mut self, discovery: Box<dyn Discovery>) -> Self {
        self.discoveries.push(discovery);
        self
    }

    /// Sets the capacity of the network-event broadcast channel.
    ///
    /// A value of `0` is treated as `1` when the manager is built.
    #[must_use]
    pub fn event_buffer(mut self, size: usize) -> Self {
        self.event_buffer = size;
        self
    }

    /// Consumes the builder and creates the [`NetworkManager`] with built-in
    /// routers and an empty peer table.
    #[must_use]
    pub fn build(self) -> NetworkManager {
        // `tokio::sync::broadcast::channel` panics on a zero capacity, so clamp
        // `event_buffer(0)` to the smallest usable channel.
        let (event_tx, _) = broadcast::channel(self.event_buffer.max(1));
        NetworkManager {
            identity: self.identity,
            transports: self.transports,
            direct_router: DirectRouter::new(),
            broadcast_router: BroadcastRouter::new(),
            content_router: ContentRouter::new(),
            custom_router: self.custom_router,
            discoveries: self.discoveries,
            peer_table: Arc::new(RwLock::new(PeerTable::new())),
            event_tx,
        }
    }
}
/// Chooses the transport kind able to carry a given delivery address.
///
/// `Unix` addresses use IPC, `Tcp` addresses use TCP, `Url` addresses use the
/// remote transport, and `Channel` addresses use pub/sub.
fn transport_type_for_address(addr: &TransportAddress) -> TransportType {
    match *addr {
        TransportAddress::Tcp(..) => TransportType::Tcp,
        TransportAddress::Unix(..) => TransportType::Ipc,
        TransportAddress::Channel(..) => TransportType::PubSub,
        TransportAddress::Url(..) => TransportType::Remote,
    }
}
/// Derives delivery addresses from the `endpoint` on `identity`'s agent card.
///
/// - `unix://<path>` becomes [`TransportAddress::Unix`] holding `<path>`.
/// - `tcp://<addr>` becomes [`TransportAddress::Tcp`] when `<addr>` parses
///   (e.g. `127.0.0.1:9090`). Otherwise the whole endpoint string is kept as a
///   [`TransportAddress::Url`].
/// - Any other endpoint is kept verbatim as a [`TransportAddress::Url`].
///
/// Returns an empty list when there is no endpoint, and exactly one address
/// otherwise.
// NOTE(review): if the `Tcp` payload is `SocketAddr`, host names such as
// `tcp://localhost:9090` do not parse and fall back to `Url` — confirm intended.
fn endpoint_to_addresses(identity: &AgentIdentity) -> Vec<TransportAddress> {
    let endpoint = match identity.agent_card.endpoint.as_ref() {
        Some(e) => e,
        None => return Vec::new(),
    };

    let address = if let Some(path) = endpoint.strip_prefix("unix://") {
        TransportAddress::Unix(path.into())
    } else {
        endpoint
            .strip_prefix("tcp://")
            .and_then(|host_port| host_port.parse().ok())
            .map_or_else(|| TransportAddress::Url(endpoint.clone()), TransportAddress::Tcp)
    };
    vec![address]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::ManualDiscovery;

    /// A freshly built manager reports its own name and knows no peers.
    #[tokio::test]
    async fn builder_creates_manager() {
        let manager = NetworkManagerBuilder::new(AgentIdentity::new("test-agent"))
            .add_discovery(Box::new(ManualDiscovery::new()))
            .build();

        assert_eq!(manager.identity().name, "test-agent");
        assert!(manager.peers().await.is_empty());
    }

    /// Discovery inserts an unseen peer into the table and announces it.
    #[tokio::test]
    async fn discover_peers_populates_table() {
        let agent_b = AgentIdentity::new("agent-b");
        let manager = NetworkManagerBuilder::new(AgentIdentity::new("agent-a"))
            .add_discovery(Box::new(ManualDiscovery::with_peers(vec![agent_b.clone()])))
            .build();
        let mut events = manager.subscribe();

        let found = manager.discover_peers().await.unwrap();
        assert_eq!(found.len(), 1);

        let peers = manager.peers().await;
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].name, "agent-b");

        let NetworkEvent::PeerJoined(joined) = events.try_recv().unwrap() else {
            panic!("expected PeerJoined");
        };
        assert_eq!(joined.id, agent_b.id);
    }

    /// Self-registration is visible through the backend and undone by deregistration.
    #[tokio::test]
    async fn register_and_deregister_self() {
        let discovery = ManualDiscovery::new();
        let manager = NetworkManagerBuilder::new(AgentIdentity::new("self"))
            .add_discovery(Box::new(discovery.clone()))
            .build();

        manager.register_self().await.unwrap();
        assert_eq!(discovery.discover().await.unwrap().len(), 1);

        manager.deregister_self().await.unwrap();
        assert!(discovery.discover().await.unwrap().is_empty());
    }

    /// Each address variant maps to its transport kind.
    #[test]
    fn transport_type_inference() {
        let cases = vec![
            (TransportAddress::Unix("/tmp/test.sock".into()), TransportType::Ipc),
            (TransportAddress::Tcp("127.0.0.1:9090".parse().unwrap()), TransportType::Tcp),
            (TransportAddress::Url("https://example.com".into()), TransportType::Remote),
            (TransportAddress::Channel("events".into()), TransportType::PubSub),
        ];
        for (addr, expected) in cases {
            assert_eq!(transport_type_for_address(&addr), expected, "address {addr}");
        }
    }

    /// Agent-card endpoints translate into the expected delivery addresses.
    #[test]
    fn endpoint_parsing() {
        let mut identity = AgentIdentity::new("test");
        let cases: Vec<(Option<&str>, Vec<TransportAddress>)> = vec![
            (
                Some("unix:///tmp/agent.sock"),
                vec![TransportAddress::Unix("/tmp/agent.sock".into())],
            ),
            (
                Some("tcp://127.0.0.1:9090"),
                vec![TransportAddress::Tcp("127.0.0.1:9090".parse().unwrap())],
            ),
            (
                Some("https://api.example.com"),
                vec![TransportAddress::Url("https://api.example.com".into())],
            ),
            (None, Vec::new()),
        ];
        for (endpoint, expected) in cases {
            identity.agent_card.endpoint = endpoint.map(Into::into);
            assert_eq!(endpoint_to_addresses(&identity), expected, "endpoint {endpoint:?}");
        }
    }
}