use async_trait::async_trait;
use iroh::EndpointId;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
pub use crate::network::peer_config::{PeerConfig, PeerInfo};
/// A change in the set of peers reported by a discovery strategy.
///
/// Values of this type are the items carried by the receiver returned from
/// [`DiscoveryStrategy::event_stream`].
#[derive(Debug, Clone)]
pub enum DiscoveryEvent {
    /// A peer was found; carries that peer's [`PeerInfo`].
    PeerFound(PeerInfo),
    /// A peer was lost; carries the [`EndpointId`] identifying it.
    PeerLost(EndpointId),
}
/// Common interface implemented by every source of peers.
///
/// Implementors must be `Send + Sync` so they can be stored as
/// `Box<dyn DiscoveryStrategy>` and shared across tasks.
#[async_trait]
pub trait DiscoveryStrategy: Send + Sync {
    /// Starts the strategy.
    ///
    /// # Errors
    ///
    /// Returns an error if the strategy fails to start.
    async fn start(&mut self) -> anyhow::Result<()>;

    /// Returns the peers this strategy currently reports.
    async fn discovered_peers(&self) -> Vec<PeerInfo>;

    /// Returns a receiver of [`DiscoveryEvent`]s produced by this strategy.
    fn event_stream(&self) -> mpsc::Receiver<DiscoveryEvent>;
}
/// Discovery strategy backed by a fixed, pre-supplied list of peers.
pub struct StaticDiscovery {
    /// The peers handed to the constructor; returned as-is by
    /// `discovered_peers`.
    peers: Vec<PeerInfo>,
}
impl StaticDiscovery {
pub fn from_file(path: impl AsRef<Path>) -> anyhow::Result<Self> {
let config = PeerConfig::from_file(path)?;
Ok(Self {
peers: config.peers,
})
}
pub fn from_peers(peers: Vec<PeerInfo>) -> Self {
Self { peers }
}
}
#[async_trait]
impl DiscoveryStrategy for StaticDiscovery {
    /// Logs how many peers were configured; performs no other work and
    /// always succeeds.
    async fn start(&mut self) -> anyhow::Result<()> {
        let peer_count = self.peers.len();
        tracing::info!("Static: Loaded {} peers from configuration", peer_count);
        Ok(())
    }

    /// Returns a copy of the configured peer list.
    async fn discovered_peers(&self) -> Vec<PeerInfo> {
        self.peers.to_vec()
    }

    /// Returns a receiver on which no event will ever arrive.
    ///
    /// A fresh channel is created on every call and its sender half is
    /// dropped before returning, so the receiver has no live sender: tokio's
    /// `Receiver::recv` resolves to `None` straight away instead of waiting.
    fn event_stream(&self) -> mpsc::Receiver<DiscoveryEvent> {
        let (sender, receiver) = mpsc::channel(1);
        // A static peer list never changes, so nothing is ever sent.
        drop(sender);
        receiver
    }
}
/// Aggregates the peers reported by any number of [`DiscoveryStrategy`]
/// implementations.
pub struct DiscoveryManager {
    /// Registered strategies, in the order they were added.
    strategies: Vec<Box<dyn DiscoveryStrategy>>,
    /// Peers merged in by `update_peers`, keyed by endpoint id.
    all_peers: Arc<RwLock<HashMap<EndpointId, PeerInfo>>>,
}
impl DiscoveryManager {
pub fn new() -> Self {
Self {
strategies: Vec::new(),
all_peers: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn add_strategy(&mut self, strategy: Box<dyn DiscoveryStrategy>) {
self.strategies.push(strategy);
}
pub async fn start(&mut self) -> anyhow::Result<()> {
for strategy in &mut self.strategies {
strategy.start().await?;
}
self.update_peers().await;
Ok(())
}
pub async fn update_peers(&self) {
let mut all = self.all_peers.write().await;
for strategy in &self.strategies {
for peer in strategy.discovered_peers().await {
if let Ok(endpoint_id) = peer.endpoint_id() {
all.insert(endpoint_id, peer);
}
}
}
}
pub async fn get_peers(&self) -> Vec<PeerInfo> {
let mut all_peers = HashMap::new();
for strategy in &self.strategies {
for peer in strategy.discovered_peers().await {
if let Ok(endpoint_id) = peer.endpoint_id() {
all_peers.insert(endpoint_id, peer);
}
}
}
all_peers.into_values().collect()
}
pub async fn discovered_peers(&self) -> anyhow::Result<Vec<PeerInfo>> {
Ok(self.get_peers().await)
}
pub async fn peer_count(&self) -> usize {
self.get_peers().await.len()
}
}
impl Default for DiscoveryManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a peer fixture whose node id is `id_unit` repeated 64 times,
    /// with a single address and no relay URL.
    fn make_peer(name: &str, id_unit: &str, address: &str) -> PeerInfo {
        PeerInfo {
            name: name.to_string(),
            node_id: id_unit.repeat(64),
            addresses: vec![address.to_string()],
            relay_url: None,
        }
    }

    /// A static strategy reports exactly the peers it was built from.
    #[tokio::test]
    async fn test_static_discovery() {
        let fixture = make_peer("Test Node", "a", "192.168.1.100:5000");
        let mut discovery = StaticDiscovery::from_peers(vec![fixture]);
        discovery.start().await.unwrap();

        let reported = discovery.discovered_peers().await;
        assert_eq!(reported.len(), 1);
        assert_eq!(reported[0].name, "Test Node");
    }

    /// The manager surfaces every distinct peer of a registered strategy.
    #[tokio::test]
    async fn test_discovery_manager() {
        let fixtures = vec![
            make_peer("Node 1", "a", "192.168.1.1:5000"),
            make_peer("Node 2", "b", "192.168.1.2:5000"),
        ];

        let mut manager = DiscoveryManager::new();
        manager.add_strategy(Box::new(StaticDiscovery::from_peers(fixtures)));
        manager.start().await.unwrap();

        let reported = manager.get_peers().await;
        assert_eq!(reported.len(), 2);
    }
}