use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use mainline::{Dht, Id};
use tokio::sync::RwLock;
use crate::error::{EngineError, ProtocolErrorKind, Result};
use crate::torrent::metainfo::Sha1Hash;
/// Well-known public DHT bootstrap routers, each written as a `host:port` string.
///
/// These can be converted to `String`s and handed to [`DhtClient::with_bootstrap`];
/// [`DhtClient::new`] does not consult this list.
pub const DEFAULT_BOOTSTRAP_NODES: &[&str] = &[
    "router.bittorrent.com:6881",
    "router.utorrent.com:6881",
    "dht.transmissionbt.com:6881",
    "dht.aelitis.com:6881",
];
/// Handle to a `mainline` DHT node, plus a per-info-hash peer cache and a
/// cooperative "running" flag.
pub struct DhtClient {
    /// Port passed to `announce_peer` when this client announces itself.
    listen_port: u16,
    /// Shared handle to the underlying DHT node.
    dht: Arc<Dht>,
    /// Cleared by `shutdown`; while false, lookups return nothing and
    /// announces fail.
    running: Arc<AtomicBool>,
    /// Last non-empty peer list obtained for each info hash.
    peer_cache: Arc<RwLock<std::collections::HashMap<Sha1Hash, Vec<SocketAddr>>>>,
}
impl DhtClient {
pub fn new(listen_port: u16) -> Result<Self> {
let dht = Dht::client()
.map_err(|e| EngineError::protocol(ProtocolErrorKind::DhtError, e.to_string()))?;
Ok(Self {
dht: Arc::new(dht),
listen_port,
running: Arc::new(AtomicBool::new(true)),
peer_cache: Arc::new(RwLock::new(std::collections::HashMap::new())),
})
}
pub fn with_bootstrap(listen_port: u16, bootstrap_nodes: &[String]) -> Result<Self> {
let dht = if bootstrap_nodes.is_empty() {
Dht::client()
.map_err(|e| EngineError::protocol(ProtocolErrorKind::DhtError, e.to_string()))?
} else {
Dht::builder()
.bootstrap(bootstrap_nodes)
.build()
.map_err(|e| EngineError::protocol(ProtocolErrorKind::DhtError, e.to_string()))?
};
Ok(Self {
dht: Arc::new(dht),
listen_port,
running: Arc::new(AtomicBool::new(true)),
peer_cache: Arc::new(RwLock::new(std::collections::HashMap::new())),
})
}
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
pub async fn find_peers(&self, info_hash: &Sha1Hash) -> Vec<SocketAddr> {
if !self.is_running() {
return vec![];
}
let id = match Id::from_bytes(info_hash) {
Ok(id) => id,
Err(e) => {
tracing::error!("Failed to convert info_hash to DHT Id: {}", e);
return vec![];
}
};
let dht = self.dht.clone();
let peers: Vec<SocketAddr> = tokio::task::spawn_blocking(move || {
dht.get_peers(id)
.flatten()
.map(SocketAddr::V4)
.collect::<Vec<_>>()
})
.await
.unwrap_or_default();
if !peers.is_empty() {
let mut cache = self.peer_cache.write().await;
cache.insert(*info_hash, peers.clone());
}
peers
}
pub async fn find_peers_timeout(
&self,
info_hash: &Sha1Hash,
timeout: Duration,
) -> Vec<SocketAddr> {
match tokio::time::timeout(timeout, async { self.find_peers(info_hash).await }).await {
Ok(peers) => peers,
Err(_) => {
let cache = self.peer_cache.read().await;
cache.get(info_hash).cloned().unwrap_or_default()
}
}
}
pub fn announce(&self, info_hash: &Sha1Hash) -> Result<()> {
if !self.is_running() {
return Err(EngineError::protocol(
ProtocolErrorKind::DhtError,
"DHT client is not running",
));
}
let id = Id::from_bytes(info_hash).map_err(|e| {
EngineError::protocol(
ProtocolErrorKind::DhtError,
format!("Failed to convert info_hash to DHT Id: {}", e),
)
})?;
self.dht
.announce_peer(id, Some(self.listen_port))
.map_err(|e| EngineError::protocol(ProtocolErrorKind::DhtError, e.to_string()))?;
Ok(())
}
pub async fn cached_peers(&self, info_hash: &Sha1Hash) -> Vec<SocketAddr> {
let cache = self.peer_cache.read().await;
cache.get(info_hash).cloned().unwrap_or_default()
}
pub async fn clear_cache(&self, info_hash: &Sha1Hash) {
let mut cache = self.peer_cache.write().await;
cache.remove(info_hash);
}
pub fn shutdown(&self) {
self.running.store(false, Ordering::SeqCst);
}
pub fn listen_port(&self) -> u16 {
self.listen_port
}
}
impl Drop for DhtClient {
fn drop(&mut self) {
self.shutdown();
}
}
/// Tracks a set of info hashes and runs DHT lookups/announces for all of them
/// through a shared [`DhtClient`].
pub struct DhtManager {
    /// Interval configured for periodic peer lookups (300 s by default in `new`).
    lookup_interval: Duration,
    /// Interval configured for periodic announces (1800 s by default in `new`).
    announce_interval: Duration,
    /// Client used for every lookup and announce.
    client: Arc<DhtClient>,
    /// Info hashes currently being tracked.
    tracked: Arc<RwLock<std::collections::HashSet<Sha1Hash>>>,
}
impl DhtManager {
pub fn new(client: Arc<DhtClient>) -> Self {
Self {
client,
tracked: Arc::new(RwLock::new(std::collections::HashSet::new())),
lookup_interval: Duration::from_secs(300), announce_interval: Duration::from_secs(1800), }
}
pub fn set_lookup_interval(&mut self, interval: Duration) {
self.lookup_interval = interval;
}
pub fn set_announce_interval(&mut self, interval: Duration) {
self.announce_interval = interval;
}
pub async fn track(&self, info_hash: Sha1Hash) {
let mut tracked = self.tracked.write().await;
tracked.insert(info_hash);
}
pub async fn untrack(&self, info_hash: &Sha1Hash) {
let mut tracked = self.tracked.write().await;
tracked.remove(info_hash);
self.client.clear_cache(info_hash).await;
}
pub async fn tracked_hashes(&self) -> Vec<Sha1Hash> {
let tracked = self.tracked.read().await;
tracked.iter().cloned().collect()
}
pub async fn discover_peers(&self) -> std::collections::HashMap<Sha1Hash, Vec<SocketAddr>> {
let tracked = self.tracked.read().await;
let mut results = std::collections::HashMap::new();
for info_hash in tracked.iter() {
let peers = self
.client
.find_peers_timeout(info_hash, Duration::from_secs(30))
.await;
if !peers.is_empty() {
results.insert(*info_hash, peers);
}
}
results
}
pub async fn announce_all(&self) -> Vec<Result<()>> {
let tracked = self.tracked.read().await;
let mut results = Vec::new();
for info_hash in tracked.iter() {
results.push(self.client.announce(info_hash));
}
results
}
pub fn client(&self) -> &Arc<DhtClient> {
&self.client
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_bootstrap_nodes() {
        assert!(!DEFAULT_BOOTSTRAP_NODES.is_empty());
        for node in DEFAULT_BOOTSTRAP_NODES {
            // Every entry must be `host:port` with a non-empty host and a
            // valid, non-zero port.
            let (host, port) = node
                .rsplit_once(':')
                .expect("Bootstrap node should have port");
            assert!(!host.is_empty(), "Bootstrap node should have a host: {}", node);
            let port: u16 = port
                .parse()
                .expect("Bootstrap node port should be a valid u16");
            assert_ne!(port, 0, "Bootstrap node port should be non-zero: {}", node);
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_dht_client_creation() {
        let client = DhtClient::new(6881);
        assert!(client.is_ok(), "Should create DHT client");
        let client = client.unwrap();
        assert!(client.is_running());
        assert_eq!(client.listen_port(), 6881);
    }

    #[tokio::test]
    #[ignore]
    async fn test_dht_find_peers() {
        let info_hash: Sha1Hash = [
            0x2c, 0x6b, 0x6a, 0x1e, 0x9c, 0x2f, 0x9f, 0x53, 0x4c, 0x8a, 0x9c, 0x7a, 0x1b, 0x2a,
            0x3c, 0x4d, 0x5e, 0x6f, 0x70, 0x81,
        ];
        let client = DhtClient::new(6881).unwrap();
        let peers = client
            .find_peers_timeout(&info_hash, Duration::from_secs(10))
            .await;
        println!("Found {} peers", peers.len());
    }

    // Exercises DhtManager itself (the previous version only tested HashSet).
    // Ignored like the other tests that need a live DHT node.
    #[tokio::test]
    #[ignore]
    async fn test_dht_manager_tracking() {
        let client = Arc::new(DhtClient::new(6881).expect("Should create DHT client"));
        let manager = DhtManager::new(client);
        let info_hash: Sha1Hash = [0u8; 20];

        manager.track(info_hash).await;
        assert_eq!(manager.tracked_hashes().await, vec![info_hash]);

        // Tracking the same hash twice must not create a duplicate entry.
        manager.track(info_hash).await;
        assert_eq!(manager.tracked_hashes().await.len(), 1);

        manager.untrack(&info_hash).await;
        assert!(manager.tracked_hashes().await.is_empty());
        assert!(manager.client().cached_peers(&info_hash).await.is_empty());
    }
}