use async_trait::async_trait;
use iroh::{Endpoint, EndpointId};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
pub use crate::network::peer_config::{PeerConfig, PeerInfo};
/// A change in the set of peers reported by a [`DiscoveryStrategy`].
#[derive(Clone, Debug)]
pub enum DiscoveryEvent {
    /// A peer was discovered; carries the details it advertised.
    PeerFound(PeerInfo),
    /// A previously reported peer went away, identified by its endpoint id.
    PeerLost(EndpointId),
}
/// A pluggable mechanism for finding peers (static configuration, mDNS, relay, ...).
///
/// Implementations are `Send + Sync` so they can be boxed and driven by
/// [`DiscoveryManager`].
#[async_trait]
pub trait DiscoveryStrategy: Send + Sync {
    /// Begins discovery. [`DiscoveryManager::start`] calls this before peers are queried.
    async fn start(&mut self) -> anyhow::Result<()>;

    /// Returns a snapshot of the peers this strategy currently knows about.
    async fn discovered_peers(&self) -> Vec<PeerInfo>;

    /// Returns a receiver of incremental [`DiscoveryEvent`]s.
    ///
    /// A strategy with nothing to report may hand back an already-closed receiver.
    fn event_stream(&self) -> mpsc::Receiver<DiscoveryEvent>;
}
/// Discovery backed by a fixed peer list, supplied directly or loaded from a
/// [`PeerConfig`] file.
pub struct StaticDiscovery {
    /// Peers reported as-is by `discovered_peers`.
    peers: Vec<PeerInfo>,
}
impl StaticDiscovery {
pub fn from_file(path: impl AsRef<Path>) -> anyhow::Result<Self> {
let config = PeerConfig::from_file(path)?;
Ok(Self {
peers: config.peers,
})
}
pub fn from_peers(peers: Vec<PeerInfo>) -> Self {
Self { peers }
}
}
#[async_trait]
impl DiscoveryStrategy for StaticDiscovery {
    /// Nothing to start because the list is already in memory; this only logs its size.
    async fn start(&mut self) -> anyhow::Result<()> {
        let count = self.peers.len();
        tracing::info!("Static: Loaded {} peers from configuration", count);
        Ok(())
    }

    /// Returns a copy of the configured peers.
    async fn discovered_peers(&self) -> Vec<PeerInfo> {
        self.peers.iter().cloned().collect()
    }

    /// The peer list never changes, so the returned receiver is already closed: its
    /// sender is dropped before the receiver is handed out.
    fn event_stream(&self) -> mpsc::Receiver<DiscoveryEvent> {
        let (sender, receiver) = mpsc::channel(1);
        drop(sender);
        receiver
    }
}
/// Zero-configuration LAN discovery: advertises this node over mDNS/DNS-SD and browses
/// for other nodes advertising the same service type.
pub struct MdnsDiscovery {
    /// Local iroh endpoint; its id is advertised in the TXT record.
    endpoint: Endpoint,
    /// mDNS instance name; also used to derive the advertised host name.
    node_name: String,
    /// Peers resolved so far, keyed by endpoint id; shared with the background browse task.
    discovered: Arc<RwLock<HashMap<EndpointId, PeerInfo>>>,
    /// Slot holding the sender half of the discovery-event channel, when one is installed.
    event_tx: Arc<RwLock<Option<mpsc::Sender<DiscoveryEvent>>>>,
    /// Running mDNS daemon handle; `None` before `start` succeeds and after `stop`.
    mdns_service: Arc<RwLock<Option<mdns_sd::ServiceDaemon>>>,
}
impl MdnsDiscovery {
    /// DNS-SD service type this node advertises and browses for.
    const SERVICE_TYPE: &'static str = "_peat-node._tcp.local.";

    /// Creates an idle mDNS discovery strategy for `endpoint`.
    ///
    /// `node_name` becomes the mDNS instance name and the prefix of the advertised host name
    /// (`<node_name>.local.`). No network activity happens until `start` is called.
    ///
    /// # Errors
    ///
    /// Never fails in this implementation; the `Result` is part of the public signature.
    pub fn new(endpoint: Endpoint, node_name: String) -> anyhow::Result<Self> {
        Ok(Self {
            endpoint,
            node_name,
            discovered: Arc::new(RwLock::new(HashMap::new())),
            event_tx: Arc::new(RwLock::new(None)),
            mdns_service: Arc::new(RwLock::new(None)),
        })
    }

    /// Returns the local IPv4 address the OS would use to reach `8.8.8.8`.
    ///
    /// `connect` on a UDP socket only selects a route; no packet is sent.
    ///
    /// # Errors
    ///
    /// Fails when no socket can be bound or no route exists (for example, on an offline host).
    fn get_local_ip() -> anyhow::Result<String> {
        use std::net::UdpSocket;
        let socket = UdpSocket::bind("0.0.0.0:0")?;
        socket.connect("8.8.8.8:80")?;
        let addr = socket.local_addr()?;
        Ok(addr.ip().to_string())
    }

    /// Stops mDNS discovery: withdraws this node's advertisement and shuts down the daemon.
    ///
    /// Once the daemon exits, the browse channel presumably closes and the background task
    /// spawned by `start` ends. Calling `stop` when discovery is not running only logs.
    /// Failures are logged rather than returned because this is a best-effort teardown.
    pub async fn stop(&mut self) {
        let daemon = self.mdns_service.write().await.take();
        let Some(daemon) = daemon else {
            tracing::debug!("mDNS: stop() called but discovery is not running");
            return;
        };

        // `ServiceInfo::new` registers the instance under `<instance name>.<service type>`.
        let fullname = format!("{}.{}", self.node_name, Self::SERVICE_TYPE);
        if let Err(e) = daemon.unregister(&fullname) {
            tracing::warn!("mDNS: Failed to unregister service '{}': {}", fullname, e);
        }
        // Dropping the handle is not enough: the daemon thread keeps running until it is
        // explicitly told to exit.
        if let Err(e) = daemon.shutdown() {
            tracing::warn!("mDNS: Failed to shut down daemon: {}", e);
        }
        tracing::info!("mDNS: Stopped discovery and unregistered service");
    }
}
#[async_trait]
impl DiscoveryStrategy for MdnsDiscovery {
    /// Advertises this node over mDNS and spawns a background task that browses for peers.
    ///
    /// The advertisement's TXT record carries the hex-encoded [`EndpointId`] (`node_id`) and
    /// a protocol `version`. Each resolved service with a valid `node_id` other than our own
    /// is stored in the shared peer map and announced as [`DiscoveryEvent::PeerFound`]. A
    /// removed service drops its entry and is announced as [`DiscoveryEvent::PeerLost`].
    /// Events go to the receiver most recently returned by `event_stream`, if any.
    ///
    /// Calling `start` again shuts down the previously running daemon first.
    ///
    /// # Errors
    ///
    /// Fails if the service info cannot be built, or if the daemon cannot be created,
    /// registered or browsed. A daemon created during a failed call is shut down before
    /// returning.
    async fn start(&mut self) -> anyhow::Result<()> {
        use mdns_sd::{ServiceDaemon, ServiceEvent, ServiceInfo};
        use std::collections::HashMap as StdHashMap;

        tracing::info!("mDNS: Starting zero-config discovery for local network");

        // A restart must not leak the previous daemon thread or double-advertise this node.
        let previous = self.mdns_service.write().await.take();
        if let Some(previous) = previous {
            if let Err(e) = previous.shutdown() {
                tracing::warn!("mDNS: Failed to shut down previous daemon: {}", e);
            }
        }

        let local_id = self.endpoint.id();
        let node_id_hex = hex::encode(local_id.as_bytes());
        // NOTE(review): port 0 is advertised, so peers resolve this node as `<ip>:0`, which is
        // not dialable. The endpoint's bound UDP port presumably belongs here — TODO confirm
        // against the iroh version in use.
        let port: u16 = 0;

        let mut properties = StdHashMap::new();
        properties.insert("node_id".to_string(), node_id_hex.clone());
        properties.insert("version".to_string(), "1".to_string());

        let local_ip = Self::get_local_ip().unwrap_or_else(|_| "127.0.0.1".to_string());
        tracing::debug!("mDNS: Using local IP address: {}", local_ip);
        let host_name = format!("{}.local.", self.node_name);

        // Built before the daemon exists, so a failure here leaves nothing to clean up.
        let service_info = ServiceInfo::new(
            Self::SERVICE_TYPE,
            &self.node_name,
            &host_name,
            &local_ip,
            port,
            properties,
        )
        .map_err(|e| anyhow::anyhow!("Failed to create mDNS service info: {}", e))?;

        let mdns = ServiceDaemon::new()
            .map_err(|e| anyhow::anyhow!("Failed to create mDNS daemon: {}", e))?;

        if let Err(e) = mdns.register(service_info) {
            // Best-effort cleanup; the registration error is the one worth reporting.
            let _ = mdns.shutdown();
            return Err(anyhow::anyhow!("Failed to register mDNS service: {}", e));
        }
        tracing::info!(
            "mDNS: Advertised node '{}' with ID {} on port {}",
            self.node_name,
            &node_id_hex[..16],
            port
        );

        // NOTE(review): short pause between registering and browsing. Its purpose is not
        // evident from this code — confirm whether it is still needed.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let receiver = match mdns.browse(Self::SERVICE_TYPE) {
            Ok(receiver) => receiver,
            Err(e) => {
                let _ = mdns.shutdown();
                return Err(anyhow::anyhow!("Failed to browse mDNS services: {}", e));
            }
        };
        tracing::debug!(
            "mDNS: Started browsing for service type: {}",
            Self::SERVICE_TYPE
        );
        *self.mdns_service.write().await = Some(mdns);

        let discovered = Arc::clone(&self.discovered);
        // Capture the subscriber slot itself, not a sender, so a later `event_stream` call
        // is picked up by the running task.
        let event_tx = Arc::clone(&self.event_tx);
        tokio::spawn(async move {
            tracing::debug!("mDNS: Background task started, listening for service events");
            while let Ok(event) = receiver.recv_async().await {
                tracing::debug!("mDNS: Received event: {:?}", event);
                match event {
                    ServiceEvent::ServiceResolved(info) => {
                        tracing::info!("mDNS: Service resolved: {}", info.get_fullname());
                        let Some(peer_id_hex) = info.get_property_val_str("node_id") else {
                            tracing::debug!("mDNS: No node_id property found in TXT records");
                            continue;
                        };
                        tracing::debug!("mDNS: Found node_id in TXT record: {}", peer_id_hex);
                        let endpoint_id = match parse_mdns_node_id(peer_id_hex) {
                            Ok(endpoint_id) => endpoint_id,
                            Err(e) => {
                                tracing::warn!("mDNS: {}", e);
                                continue;
                            }
                        };
                        // We browse the same service type we advertise, so our own record
                        // resolves as well. It must not be reported as a peer.
                        if endpoint_id == local_id {
                            tracing::debug!("mDNS: Ignoring our own advertisement");
                            continue;
                        }

                        let peer_port = info.get_port();
                        let addresses: Vec<String> = info
                            .get_addresses()
                            .iter()
                            .map(|addr| format_peer_addr(addr, peer_port))
                            .collect();
                        let peer_info = PeerInfo {
                            name: info.get_fullname().to_string(),
                            node_id: peer_id_hex.to_string(),
                            addresses: addresses.clone(),
                            relay_url: None,
                        };
                        let total_peers = {
                            let mut peers = discovered.write().await;
                            peers.insert(endpoint_id, peer_info.clone());
                            peers.len()
                        };
                        send_discovery_event(&event_tx, DiscoveryEvent::PeerFound(peer_info))
                            .await;
                        tracing::info!(
                            "mDNS: Discovered peer '{}' at {:?} (total peers: {})",
                            info.get_fullname(),
                            addresses,
                            total_peers
                        );
                    }
                    ServiceEvent::ServiceRemoved(_, fullname) => {
                        let removed = {
                            let mut peers = discovered.write().await;
                            let found = peers
                                .iter()
                                .find(|(_, peer)| peer.name == fullname)
                                .map(|(id, _)| *id);
                            if let Some(id) = found {
                                peers.remove(&id);
                            }
                            found
                        };
                        if let Some(endpoint_id) = removed {
                            send_discovery_event(&event_tx, DiscoveryEvent::PeerLost(endpoint_id))
                                .await;
                            tracing::info!("mDNS: Peer '{}' left the network", fullname);
                        }
                    }
                    other_event => {
                        tracing::debug!("mDNS: Received event (ignored): {:?}", other_event);
                    }
                }
            }
            tracing::warn!("mDNS: Background task ended - receiver closed");
        });
        Ok(())
    }

    /// Returns a snapshot of the peers resolved so far.
    async fn discovered_peers(&self) -> Vec<PeerInfo> {
        self.discovered.read().await.values().cloned().collect()
    }

    /// Subscribes to this strategy's [`DiscoveryEvent`]s.
    ///
    /// There is a single subscriber slot. Each call installs a fresh channel (capacity 100)
    /// and replaces the previous sender, which ends the previously returned stream.
    /// Delivery is best-effort: when the buffer is full, events are dropped rather than
    /// stalling discovery, so `discovered_peers` remains authoritative. Works before or
    /// after `start`.
    fn event_stream(&self) -> mpsc::Receiver<DiscoveryEvent> {
        // The browse task holds the read lock only long enough to clone the sender, so a
        // short bounded retry is enough to get past any contention.
        const INSTALL_ATTEMPTS: usize = 64;
        let (tx, rx) = mpsc::channel(100);
        for _ in 0..INSTALL_ATTEMPTS {
            if let Ok(mut slot) = self.event_tx.try_write() {
                *slot = Some(tx);
                return rx;
            }
            std::thread::yield_now();
        }
        tracing::warn!("mDNS: Could not install event subscriber; returned stream is closed");
        rx
    }
}

/// Decodes the hex `node_id` TXT value advertised by a peer into an [`EndpointId`].
///
/// # Errors
///
/// Fails if the value is not valid hex, is not exactly 32 bytes, or is rejected by
/// `EndpointId::from_bytes`.
fn parse_mdns_node_id(node_id_hex: &str) -> anyhow::Result<EndpointId> {
    let bytes = hex::decode(node_id_hex)
        .map_err(|e| anyhow::anyhow!("Failed to decode node_id hex: {}", e))?;
    let array = <[u8; 32]>::try_from(bytes.as_slice()).map_err(|_| {
        anyhow::anyhow!("node_id wrong length: {} bytes, expected 32", bytes.len())
    })?;
    EndpointId::from_bytes(&array)
        .map_err(|e| anyhow::anyhow!("Failed to create EndpointId: {}", e))
}

/// Joins an IP address and port into a `host:port` string. IPv6 literals are bracketed
/// (`[fe80::1]:5000`) so the result is a well-formed socket address.
fn format_peer_addr(ip: impl std::fmt::Display, port: u16) -> String {
    let ip = ip.to_string();
    if ip.contains(':') {
        format!("[{}]:{}", ip, port)
    } else {
        format!("{}:{}", ip, port)
    }
}

/// Delivers `event`, best-effort, to the sender currently stored in `slot`, if any.
///
/// Uses `try_send` so a slow or departed consumer can never block the mDNS browse loop.
/// Dropped events are only logged, because the shared peer map remains the source of truth.
async fn send_discovery_event(
    slot: &RwLock<Option<mpsc::Sender<DiscoveryEvent>>>,
    event: DiscoveryEvent,
) {
    // Clone the sender so the read guard is released before sending.
    let sender = slot.read().await.clone();
    if let Some(sender) = sender {
        if let Err(e) = sender.try_send(event) {
            tracing::debug!("mDNS: Dropped discovery event: {}", e);
        }
    }
}
/// Relay-based discovery.
///
/// NOTE(review): currently a stub. Nothing in this module writes to `discovered`, so it
/// always reports no peers.
pub struct RelayDiscovery {
    /// Endpoint handle; not read yet, hence the leading underscore.
    _endpoint: Endpoint,
    /// Peers learned via the relay, keyed by endpoint id.
    discovered: Arc<RwLock<HashMap<EndpointId, PeerInfo>>>,
}
impl RelayDiscovery {
pub fn new(endpoint: Endpoint) -> Self {
Self {
_endpoint: endpoint,
discovered: Arc::new(RwLock::new(HashMap::new())),
}
}
}
#[async_trait]
impl DiscoveryStrategy for RelayDiscovery {
    /// Logs that relay discovery started; no discovery work is performed yet.
    async fn start(&mut self) -> anyhow::Result<()> {
        tracing::info!("Relay: Starting relay-based discovery");
        Ok(())
    }

    /// Returns a snapshot of the peers in the shared map.
    async fn discovered_peers(&self) -> Vec<PeerInfo> {
        let peers = self.discovered.read().await;
        peers.values().cloned().collect()
    }

    /// Returns an already-closed receiver, because no relay events are produced yet.
    fn event_stream(&self) -> mpsc::Receiver<DiscoveryEvent> {
        let (sender, receiver) = mpsc::channel(100);
        drop(sender);
        receiver
    }
}
/// Aggregates peers from several [`DiscoveryStrategy`] implementations, de-duplicating
/// them by endpoint id.
pub struct DiscoveryManager {
    /// Registered strategies, in the order they were added.
    strategies: Vec<Box<dyn DiscoveryStrategy>>,
    /// Cached merge of every strategy's peers, refreshed by `update_peers`.
    all_peers: Arc<RwLock<HashMap<EndpointId, PeerInfo>>>,
}
impl DiscoveryManager {
pub fn new() -> Self {
Self {
strategies: Vec::new(),
all_peers: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn add_strategy(&mut self, strategy: Box<dyn DiscoveryStrategy>) {
self.strategies.push(strategy);
}
pub async fn start(&mut self) -> anyhow::Result<()> {
for strategy in &mut self.strategies {
strategy.start().await?;
}
self.update_peers().await;
Ok(())
}
pub async fn update_peers(&self) {
let mut all = self.all_peers.write().await;
for strategy in &self.strategies {
for peer in strategy.discovered_peers().await {
if let Ok(endpoint_id) = peer.endpoint_id() {
all.insert(endpoint_id, peer);
}
}
}
}
pub async fn get_peers(&self) -> Vec<PeerInfo> {
let mut all_peers = HashMap::new();
for strategy in &self.strategies {
for peer in strategy.discovered_peers().await {
if let Ok(endpoint_id) = peer.endpoint_id() {
all_peers.insert(endpoint_id, peer);
}
}
}
all_peers.into_values().collect()
}
pub async fn discovered_peers(&self) -> anyhow::Result<Vec<PeerInfo>> {
Ok(self.get_peers().await)
}
pub async fn peer_count(&self) -> usize {
self.get_peers().await.len()
}
}
impl Default for DiscoveryManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PeerInfo` with a single address and no relay.
    fn peer(name: &str, node_id: String, address: &str) -> PeerInfo {
        PeerInfo {
            name: name.to_string(),
            node_id,
            addresses: vec![address.to_string()],
            relay_url: None,
        }
    }

    #[tokio::test]
    async fn test_static_discovery() {
        let mut discovery = StaticDiscovery::from_peers(vec![peer(
            "Test Node",
            "a".repeat(64),
            "192.168.1.100:5000",
        )]);
        discovery.start().await.unwrap();
        let peers = discovery.discovered_peers().await;
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].name, "Test Node");
        assert_eq!(peers[0].addresses, vec!["192.168.1.100:5000".to_string()]);
    }

    #[tokio::test]
    async fn test_static_event_stream_is_closed() {
        // A static peer list never changes, so its event stream ends immediately.
        let discovery = StaticDiscovery::from_peers(Vec::new());
        let mut events = discovery.event_stream();
        assert!(events.recv().await.is_none());
    }

    #[tokio::test]
    async fn test_discovery_manager() {
        let mut manager = DiscoveryManager::new();
        manager.add_strategy(Box::new(StaticDiscovery::from_peers(vec![
            peer("Node 1", "a".repeat(64), "192.168.1.1:5000"),
            peer("Node 2", "b".repeat(64), "192.168.1.2:5000"),
        ])));
        manager.start().await.unwrap();
        let peers = manager.get_peers().await;
        assert_eq!(peers.len(), 2);
        assert_eq!(manager.peer_count().await, 2);
    }

    #[tokio::test]
    async fn test_discovery_manager_dedupes_by_endpoint_id() {
        // Two strategies report the same node id; the later-added strategy wins.
        let mut manager = DiscoveryManager::new();
        manager.add_strategy(Box::new(StaticDiscovery::from_peers(vec![peer(
            "First",
            "a".repeat(64),
            "192.168.1.1:5000",
        )])));
        manager.add_strategy(Box::new(StaticDiscovery::from_peers(vec![peer(
            "Second",
            "a".repeat(64),
            "192.168.1.2:5000",
        )])));
        manager.start().await.unwrap();
        let peers = manager.get_peers().await;
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].name, "Second");
    }

    #[tokio::test]
    async fn test_mdns_service_registration() {
        let endpoint = iroh::Endpoint::builder(iroh::endpoint::presets::N0)
            .bind()
            .await
            .expect("Failed to create endpoint");
        let mut mdns = MdnsDiscovery::new(endpoint, "test-node".to_string())
            .expect("Failed to create mDNS discovery");
        mdns.start().await.expect("Failed to start mDNS discovery");
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        // Tear down so the daemon and its advertisement do not outlive the test.
        mdns.stop().await;
    }
}