use super::{DiscoveryError, DiscoveryEvent, DiscoveryStrategy, PeerInfo, Result};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info, warn};
/// Combines several named [`DiscoveryStrategy`] implementations behind a single
/// peer table and a single event channel.
///
/// Events produced by each registered strategy are merged into
/// `combined_peers` and re-emitted on one channel whose receiving half is
/// handed out once via `event_stream`.
pub struct HybridDiscovery {
/// Registered strategies, keyed by the name passed to `add_strategy`.
strategies: HashMap<String, Box<dyn DiscoveryStrategy>>,
/// Merged view of the peers reported by all strategies, keyed by `node_id`.
/// Shared with the forwarding tasks spawned in `start_all`.
combined_peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
/// Sending half of the aggregated event channel; cloned into every
/// forwarding task.
events_tx: mpsc::Sender<DiscoveryEvent>,
/// Receiving half of the aggregated event channel. `Some` until
/// `event_stream` takes it, `None` afterwards.
events_rx: Option<mpsc::Receiver<DiscoveryEvent>>,
/// `true` between a successful `start_all` and the next `stop_all`.
started: Arc<RwLock<bool>>,
}
impl HybridDiscovery {
pub fn new() -> Self {
let (events_tx, events_rx) = mpsc::channel(1000);
Self {
strategies: HashMap::new(),
combined_peers: Arc::new(RwLock::new(HashMap::new())),
events_tx,
events_rx: Some(events_rx),
started: Arc::new(RwLock::new(false)),
}
}
pub fn add_strategy(&mut self, name: impl Into<String>, strategy: Box<dyn DiscoveryStrategy>) {
let name = name.into();
info!("Adding discovery strategy: {}", name);
self.strategies.insert(name, strategy);
}
pub fn remove_strategy(&mut self, name: &str) -> Option<Box<dyn DiscoveryStrategy>> {
info!("Removing discovery strategy: {}", name);
self.strategies.remove(name)
}
pub async fn start_all(&mut self) -> Result<()> {
let mut started = self.started.write().await;
if *started {
warn!("Hybrid discovery already started");
return Ok(());
}
info!(
"Starting hybrid discovery with {} strategies",
self.strategies.len()
);
for (name, strategy) in self.strategies.iter_mut() {
info!("Starting discovery strategy: {}", name);
strategy.start().await.map_err(|e| {
DiscoveryError::ConfigError(format!("Failed to start strategy '{}': {}", name, e))
})?;
let mut events = strategy.event_stream().map_err(|e| {
DiscoveryError::ConfigError(format!(
"Failed to get event stream for strategy '{}': {}",
name, e
))
})?;
let combined_peers = self.combined_peers.clone();
let events_tx = self.events_tx.clone();
let strategy_name = name.clone();
tokio::spawn(async move {
while let Some(event) = events.recv().await {
debug!("Event from strategy '{}': {:?}", strategy_name, event);
match &event {
DiscoveryEvent::PeerFound(peer_info) => {
let mut peers = combined_peers.write().await;
let node_id = peer_info.node_id.clone();
let is_new = !peers.contains_key(&node_id);
peers.insert(node_id.clone(), peer_info.clone());
drop(peers);
let forward_event = if is_new {
info!("New peer discovered via '{}': {}", strategy_name, node_id);
DiscoveryEvent::PeerFound(peer_info.clone())
} else {
debug!("Updated peer via '{}': {}", strategy_name, node_id);
DiscoveryEvent::PeerUpdated(peer_info.clone())
};
if let Err(e) = events_tx.send(forward_event).await {
warn!("Failed to forward discovery event: {}", e);
}
}
DiscoveryEvent::PeerLost(node_id) => {
let mut peers = combined_peers.write().await;
if peers.remove(node_id).is_some() {
drop(peers);
info!("Peer lost via '{}': {}", strategy_name, node_id);
if let Err(e) = events_tx
.send(DiscoveryEvent::PeerLost(node_id.clone()))
.await
{
warn!("Failed to forward peer lost event: {}", e);
}
}
}
DiscoveryEvent::PeerUpdated(peer_info) => {
let mut peers = combined_peers.write().await;
let node_id = peer_info.node_id.clone();
if peers.contains_key(&node_id) {
peers.insert(node_id.clone(), peer_info.clone());
drop(peers);
debug!("Peer updated via '{}': {}", strategy_name, node_id);
if let Err(e) = events_tx
.send(DiscoveryEvent::PeerUpdated(peer_info.clone()))
.await
{
warn!("Failed to forward peer updated event: {}", e);
}
}
}
}
}
info!("Event forwarding task for '{}' stopped", strategy_name);
});
}
*started = true;
info!("Hybrid discovery started successfully");
Ok(())
}
pub async fn stop_all(&mut self) -> Result<()> {
let mut started = self.started.write().await;
if !*started {
return Ok(());
}
info!("Stopping hybrid discovery");
for (name, strategy) in self.strategies.iter_mut() {
info!("Stopping discovery strategy: {}", name);
if let Err(e) = strategy.stop().await {
warn!("Error stopping strategy '{}': {}", name, e);
}
}
*started = false;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
info!("Hybrid discovery stopped");
Ok(())
}
pub async fn all_discovered_peers(&self) -> Vec<PeerInfo> {
self.combined_peers.read().await.values().cloned().collect()
}
pub fn strategy_count(&self) -> usize {
self.strategies.len()
}
pub fn has_strategy(&self, name: &str) -> bool {
self.strategies.contains_key(name)
}
}
impl Default for HybridDiscovery {
fn default() -> Self {
Self::new()
}
}
/// Lets a `HybridDiscovery` be used wherever a single strategy is expected by
/// delegating to its aggregate operations.
#[async_trait]
impl DiscoveryStrategy for HybridDiscovery {
    /// Starts all registered strategies (see `start_all`).
    async fn start(&mut self) -> Result<()> {
        self.start_all().await
    }

    /// Stops all registered strategies (see `stop_all`).
    async fn stop(&mut self) -> Result<()> {
        self.stop_all().await
    }

    /// Advertises `node_id` / `port` through every registered strategy.
    ///
    /// Best effort: a failing strategy is logged and skipped, and the call
    /// still returns `Ok(())`.
    async fn advertise(&self, node_id: &str, port: u16) -> Result<()> {
        for (name, strategy) in self.strategies.iter() {
            match strategy.advertise(node_id, port).await {
                Ok(_) => {}
                Err(e) => {
                    warn!(strategy = %name, error = %e, "Failed to advertise via strategy");
                }
            }
        }
        Ok(())
    }

    /// Returns a snapshot of the combined peer table.
    async fn discovered_peers(&self) -> Vec<PeerInfo> {
        let snapshot = self.all_discovered_peers().await;
        snapshot
    }

    /// Hands out the receiving half of the aggregated event channel.
    ///
    /// The receiver can be taken only once; later calls return
    /// `DiscoveryError::EventStreamConsumed`.
    fn event_stream(&mut self) -> Result<mpsc::Receiver<DiscoveryEvent>> {
        match self.events_rx.take() {
            Some(receiver) => Ok(receiver),
            None => Err(DiscoveryError::EventStreamConsumed),
        }
    }
}
// Unit tests for `HybridDiscovery`. Event sources are either `StaticDiscovery`
// instances or small strategies defined inside the individual tests.
// Several tests sleep for a fixed time before asserting, to let the spawned
// forwarding tasks process events; they are therefore timing-dependent.
#[cfg(test)]
mod tests {
use super::*;
use crate::discovery::static_config::{DiscoveryConfig, StaticPeerConfig};
use crate::discovery::StaticDiscovery;
use std::collections::HashMap;
// One static strategy with one configured peer: after start the peer must be
// listed and the first forwarded event must be `PeerFound`.
#[tokio::test]
async fn test_hybrid_discovery_basic() {
let mut hybrid = HybridDiscovery::new();
let config = DiscoveryConfig {
peers: vec![StaticPeerConfig {
node_id: "test-peer".to_string(),
addresses: vec!["10.0.0.1:5000".to_string()],
relay_url: None,
priority: 128,
metadata: HashMap::new(),
}],
};
let static_disc = StaticDiscovery::from_config(config).unwrap();
hybrid.add_strategy("static", Box::new(static_disc));
assert_eq!(hybrid.strategy_count(), 1);
assert!(hybrid.has_strategy("static"));
// Take the aggregated stream before starting so forwarded events can be read.
let mut events = hybrid.event_stream().unwrap();
hybrid.start_all().await.unwrap();
// Give the forwarding task time to process the strategy's events.
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let peers = hybrid.all_discovered_peers().await;
assert_eq!(peers.len(), 1);
assert_eq!(peers[0].node_id, "test-peer");
let event = events.try_recv().unwrap();
assert!(matches!(event, DiscoveryEvent::PeerFound(_)));
hybrid.stop_all().await.unwrap();
}
// Two static strategies with distinct peers: both peers must end up in the
// combined table. The aggregated event stream is not consumed here.
#[tokio::test]
async fn test_hybrid_discovery_multiple_strategies() {
let mut hybrid = HybridDiscovery::new();
let config1 = DiscoveryConfig {
peers: vec![StaticPeerConfig {
node_id: "peer-1".to_string(),
addresses: vec!["10.0.0.1:5000".to_string()],
relay_url: None,
priority: 128,
metadata: HashMap::new(),
}],
};
let config2 = DiscoveryConfig {
peers: vec![StaticPeerConfig {
node_id: "peer-2".to_string(),
addresses: vec!["10.0.0.2:5000".to_string()],
relay_url: None,
priority: 128,
metadata: HashMap::new(),
}],
};
hybrid.add_strategy(
"static-1",
Box::new(StaticDiscovery::from_config(config1).unwrap()),
);
hybrid.add_strategy(
"static-2",
Box::new(StaticDiscovery::from_config(config2).unwrap()),
);
assert_eq!(hybrid.strategy_count(), 2);
hybrid.start_all().await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let peers = hybrid.all_discovered_peers().await;
assert_eq!(peers.len(), 2);
// Order of the snapshot is unspecified, so check membership only.
let peer_ids: Vec<&str> = peers.iter().map(|p| p.node_id.as_str()).collect();
assert!(peer_ids.contains(&"peer-1"));
assert!(peer_ids.contains(&"peer-2"));
hybrid.stop_all().await.unwrap();
}
// Adding and removing a strategy is reflected by `strategy_count` and
// `has_strategy`.
#[test]
fn test_strategy_management() {
let mut hybrid = HybridDiscovery::new();
let config = DiscoveryConfig { peers: vec![] };
let static_disc = StaticDiscovery::from_config(config).unwrap();
hybrid.add_strategy("test", Box::new(static_disc));
assert_eq!(hybrid.strategy_count(), 1);
assert!(hybrid.has_strategy("test"));
hybrid.remove_strategy("test");
assert_eq!(hybrid.strategy_count(), 0);
assert!(!hybrid.has_strategy("test"));
}
// `Default` yields an instance with no strategies.
#[test]
fn test_default_impl() {
let hybrid = HybridDiscovery::default();
assert_eq!(hybrid.strategy_count(), 0);
}
// Stopping an instance that was never started succeeds.
#[tokio::test]
async fn test_stop_all_when_not_started() {
let mut hybrid = HybridDiscovery::new();
hybrid.stop_all().await.unwrap();
}
// A second `start_all` on a started instance must succeed without error.
#[tokio::test]
async fn test_start_all_twice_idempotent() {
let mut hybrid = HybridDiscovery::new();
let config = DiscoveryConfig {
peers: vec![StaticPeerConfig {
node_id: "p1".to_string(),
addresses: vec!["10.0.0.1:5000".to_string()],
relay_url: None,
priority: 128,
metadata: HashMap::new(),
}],
};
hybrid.add_strategy(
"static",
Box::new(StaticDiscovery::from_config(config).unwrap()),
);
let _events = hybrid.event_stream().unwrap();
hybrid.start_all().await.unwrap();
hybrid.start_all().await.unwrap();
hybrid.stop_all().await.unwrap();
}
// The aggregated event stream can be taken only once; the second call must
// fail with `EventStreamConsumed`.
#[tokio::test]
async fn test_event_stream_consumed_error() {
let mut hybrid = HybridDiscovery::new();
let _stream = hybrid.event_stream().unwrap();
let result = hybrid.event_stream();
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
DiscoveryError::EventStreamConsumed
));
}
// A fresh instance reports no peers.
#[tokio::test]
async fn test_all_discovered_peers_initially_empty() {
let hybrid = HybridDiscovery::new();
let peers = hybrid.all_discovered_peers().await;
assert!(peers.is_empty());
}
// `remove_strategy` returns the strategy the first time and `None` afterwards.
#[tokio::test]
async fn test_remove_strategy_returns_value() {
let mut hybrid = HybridDiscovery::new();
let config = DiscoveryConfig { peers: vec![] };
let static_disc = StaticDiscovery::from_config(config).unwrap();
hybrid.add_strategy("test", Box::new(static_disc));
let removed = hybrid.remove_strategy("test");
assert!(removed.is_some());
let removed_again = hybrid.remove_strategy("test");
assert!(removed_again.is_none());
}
// `has_strategy` is false for a name that was never registered.
#[tokio::test]
async fn test_has_strategy_false_for_missing() {
let hybrid = HybridDiscovery::new();
assert!(!hybrid.has_strategy("nonexistent"));
}
// Exercises the `DiscoveryStrategy` trait methods (`start`,
// `discovered_peers`, `stop`) instead of the inherent ones.
#[tokio::test]
async fn test_discovery_strategy_trait_implementation() {
let mut hybrid = HybridDiscovery::new();
let config = DiscoveryConfig {
peers: vec![StaticPeerConfig {
node_id: "trait-peer".to_string(),
addresses: vec!["10.0.0.1:5000".to_string()],
relay_url: None,
priority: 128,
metadata: HashMap::new(),
}],
};
hybrid.add_strategy(
"static",
Box::new(StaticDiscovery::from_config(config).unwrap()),
);
let _events = hybrid.event_stream().unwrap();
hybrid.start().await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let peers = hybrid.discovered_peers().await;
assert_eq!(peers.len(), 1);
assert_eq!(peers[0].node_id, "trait-peer");
hybrid.stop().await.unwrap();
}
// A peer that is found and then lost must be removed from the combined table,
// and both `PeerFound` and `PeerLost` must be forwarded.
#[tokio::test]
async fn test_peer_lost_event_forwarding() {
// Test strategy whose event stream emits `PeerFound` and, 20 ms later,
// `PeerLost` for the same node id.
struct LostPeerStrategy;
#[async_trait]
impl DiscoveryStrategy for LostPeerStrategy {
async fn start(&mut self) -> Result<()> {
Ok(())
}
async fn stop(&mut self) -> Result<()> {
Ok(())
}
async fn discovered_peers(&self) -> Vec<PeerInfo> {
vec![]
}
fn event_stream(&mut self) -> Result<mpsc::Receiver<DiscoveryEvent>> {
let (tx, rx) = mpsc::channel(100);
// Events are produced from a background task after short delays; `tx` is
// dropped when the task ends, which closes the stream.
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let peer = PeerInfo {
node_id: "lost-peer".to_string(),
addresses: vec!["10.0.0.1:5000".parse().unwrap()],
relay_url: None,
last_seen: std::time::Instant::now(),
metadata: HashMap::new(),
};
let _ = tx.send(DiscoveryEvent::PeerFound(peer)).await;
tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
let _ = tx
.send(DiscoveryEvent::PeerLost("lost-peer".to_string()))
.await;
});
Ok(rx)
}
}
let mut hybrid = HybridDiscovery::new();
hybrid.add_strategy("lossy", Box::new(LostPeerStrategy));
let mut events = hybrid.event_stream().unwrap();
hybrid.start_all().await.unwrap();
// Wait long enough for both delayed events to be produced and forwarded.
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
let peers = hybrid.all_discovered_peers().await;
assert_eq!(peers.len(), 0);
let mut found = false;
let mut lost = false;
// Drain whatever has been forwarded so far without blocking.
while let Ok(event) = events.try_recv() {
match event {
DiscoveryEvent::PeerFound(p) if p.node_id == "lost-peer" => found = true,
DiscoveryEvent::PeerLost(id) if id == "lost-peer" => lost = true,
_ => {}
}
}
assert!(found);
assert!(lost);
hybrid.stop_all().await.unwrap();
}
// A `PeerUpdated` for an already known peer must be forwarded and the peer
// must remain in the combined table.
#[tokio::test]
async fn test_peer_updated_event_forwarding() {
// Test strategy whose event stream emits `PeerFound` and then `PeerUpdated`
// (with a different address) for the same node id.
struct UpdateStrategy;
#[async_trait]
impl DiscoveryStrategy for UpdateStrategy {
async fn start(&mut self) -> Result<()> {
Ok(())
}
async fn stop(&mut self) -> Result<()> {
Ok(())
}
async fn discovered_peers(&self) -> Vec<PeerInfo> {
vec![]
}
fn event_stream(&mut self) -> Result<mpsc::Receiver<DiscoveryEvent>> {
let (tx, rx) = mpsc::channel(100);
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let peer = PeerInfo {
node_id: "update-peer".to_string(),
addresses: vec!["10.0.0.1:5000".parse().unwrap()],
relay_url: None,
last_seen: std::time::Instant::now(),
metadata: HashMap::new(),
};
let _ = tx.send(DiscoveryEvent::PeerFound(peer.clone())).await;
tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
let mut updated = peer;
updated.addresses = vec!["10.0.0.2:5000".parse().unwrap()];
let _ = tx.send(DiscoveryEvent::PeerUpdated(updated)).await;
});
Ok(rx)
}
}
let mut hybrid = HybridDiscovery::new();
hybrid.add_strategy("updater", Box::new(UpdateStrategy));
let mut events = hybrid.event_stream().unwrap();
hybrid.start_all().await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
let peers = hybrid.all_discovered_peers().await;
assert_eq!(peers.len(), 1);
assert_eq!(peers[0].node_id, "update-peer");
let mut found = false;
let mut updated = false;
while let Ok(event) = events.try_recv() {
match event {
DiscoveryEvent::PeerFound(p) if p.node_id == "update-peer" => found = true,
DiscoveryEvent::PeerUpdated(p) if p.node_id == "update-peer" => updated = true,
_ => {}
}
}
assert!(found);
assert!(updated);
hybrid.stop_all().await.unwrap();
}
// A `PeerLost` for a node id that was never found must not be forwarded.
#[tokio::test]
async fn test_peer_lost_unknown_not_forwarded() {
// Test strategy whose event stream emits only a `PeerLost` for an unknown id.
struct LostUnknownStrategy;
#[async_trait]
impl DiscoveryStrategy for LostUnknownStrategy {
async fn start(&mut self) -> Result<()> {
Ok(())
}
async fn stop(&mut self) -> Result<()> {
Ok(())
}
async fn discovered_peers(&self) -> Vec<PeerInfo> {
vec![]
}
fn event_stream(&mut self) -> Result<mpsc::Receiver<DiscoveryEvent>> {
let (tx, rx) = mpsc::channel(100);
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let _ = tx
.send(DiscoveryEvent::PeerLost("never-found".to_string()))
.await;
});
Ok(rx)
}
}
let mut hybrid = HybridDiscovery::new();
hybrid.add_strategy("lost-unknown", Box::new(LostUnknownStrategy));
let mut events = hybrid.event_stream().unwrap();
hybrid.start_all().await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
assert!(hybrid.all_discovered_peers().await.is_empty());
let mut lost_forwarded = false;
while let Ok(event) = events.try_recv() {
if matches!(event, DiscoveryEvent::PeerLost(_)) {
lost_forwarded = true;
}
}
assert!(!lost_forwarded);
hybrid.stop_all().await.unwrap();
}
// A `PeerUpdated` for a node id that was never found must neither be stored
// nor forwarded.
#[tokio::test]
async fn test_peer_updated_unknown_not_forwarded() {
// Test strategy whose event stream emits only a `PeerUpdated` for an unknown id.
struct UpdateUnknownStrategy;
#[async_trait]
impl DiscoveryStrategy for UpdateUnknownStrategy {
async fn start(&mut self) -> Result<()> {
Ok(())
}
async fn stop(&mut self) -> Result<()> {
Ok(())
}
async fn discovered_peers(&self) -> Vec<PeerInfo> {
vec![]
}
fn event_stream(&mut self) -> Result<mpsc::Receiver<DiscoveryEvent>> {
let (tx, rx) = mpsc::channel(100);
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let peer = PeerInfo {
node_id: "never-found".to_string(),
addresses: vec!["10.0.0.1:5000".parse().unwrap()],
relay_url: None,
last_seen: std::time::Instant::now(),
metadata: HashMap::new(),
};
let _ = tx.send(DiscoveryEvent::PeerUpdated(peer)).await;
});
Ok(rx)
}
}
let mut hybrid = HybridDiscovery::new();
hybrid.add_strategy("update-unknown", Box::new(UpdateUnknownStrategy));
let mut events = hybrid.event_stream().unwrap();
hybrid.start_all().await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
assert!(hybrid.all_discovered_peers().await.is_empty());
let mut updated_forwarded = false;
while let Ok(event) = events.try_recv() {
if matches!(event, DiscoveryEvent::PeerUpdated(_)) {
updated_forwarded = true;
}
}
assert!(!updated_forwarded);
hybrid.stop_all().await.unwrap();
}
// The same node id reported by two strategies must appear once in the
// combined table; the first report is forwarded as `PeerFound` and the
// second as `PeerUpdated`.
#[tokio::test]
async fn test_duplicate_peer_from_two_strategies() {
let mut hybrid = HybridDiscovery::new();
let config1 = DiscoveryConfig {
peers: vec![StaticPeerConfig {
node_id: "shared-peer".to_string(),
addresses: vec!["10.0.0.1:5000".to_string()],
relay_url: None,
priority: 128,
metadata: HashMap::new(),
}],
};
let config2 = DiscoveryConfig {
peers: vec![StaticPeerConfig {
node_id: "shared-peer".to_string(),
addresses: vec!["10.0.0.2:5000".to_string()],
relay_url: None,
priority: 200,
metadata: HashMap::new(),
}],
};
hybrid.add_strategy(
"static-1",
Box::new(StaticDiscovery::from_config(config1).unwrap()),
);
hybrid.add_strategy(
"static-2",
Box::new(StaticDiscovery::from_config(config2).unwrap()),
);
let mut events = hybrid.event_stream().unwrap();
hybrid.start_all().await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let peers = hybrid.all_discovered_peers().await;
assert_eq!(peers.len(), 1);
assert_eq!(peers[0].node_id, "shared-peer");
let mut found_count = 0;
let mut updated_count = 0;
while let Ok(event) = events.try_recv() {
match event {
DiscoveryEvent::PeerFound(_) => found_count += 1,
DiscoveryEvent::PeerUpdated(_) => updated_count += 1,
_ => {}
}
}
// Which strategy reports first is not fixed, so only lower bounds are checked.
assert!(found_count >= 1);
assert!(updated_count >= 1);
hybrid.stop_all().await.unwrap();
}
}