use std::{net::SocketAddr, thread, time::Duration};
use bytes::Bytes;
use flume::{Receiver, Sender};
use tracing::info;
use crate::{
common::{
hash_immutable, AnnouncePeerRequestArguments, GetPeersRequestArguments,
GetValueRequestArguments, Id, MutableItem, PutImmutableRequestArguments,
PutMutableRequestArguments, PutRequestSpecific, RequestTypeSpecific,
},
rpc::{PutResult, ReceivedFrom, ReceivedMessage, ResponseSender, Rpc},
server::{DhtServer, Server},
Result,
};
/// Handle to a Mainline DHT node whose networking runs on a background actor thread.
///
/// Every clone shares the same command channel to the same actor (see [`Dht::new`]).
#[derive(Debug, Clone)]
pub struct Dht {
    /// Command channel to the actor thread spawned by [`Dht::new`].
    pub(crate) sender: Sender<ActorMessage>,
    /// Address the node is bound to; set to `None` by [`Dht::shutdown`] on this handle only.
    pub(crate) address: Option<SocketAddr>,
}
/// Builder for [`Dht`]; obtain one with [`Dht::builder`] and finish with [`Builder::build`].
pub struct Builder {
    /// Settings accumulated so far; handed to [`Dht::new`] by `build`.
    settings: DhtSettings,
}
impl Builder {
pub fn build(self) -> Result<Dht> {
Dht::new(self.settings)
}
pub fn server(mut self) -> Self {
self.settings.server = Some(Box::<DhtServer>::default());
self
}
pub fn custom_server(mut self, custom_server: Box<dyn Server>) -> Self {
self.settings.server = Some(custom_server);
self
}
pub fn bootstrap(mut self, bootstrap: &[String]) -> Self {
self.settings.bootstrap = Some(bootstrap.to_vec());
self
}
pub fn port(mut self, port: u16) -> Self {
self.settings.port = Some(port);
self
}
pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
self.settings.request_timeout = Some(request_timeout);
self
}
}
/// Configuration consumed by [`Dht::new`]; usually filled in through [`Builder`].
#[derive(Debug, Default)]
pub struct DhtSettings {
    /// Bootstrap node addresses (the tests pass `"host:port"` strings). With `None` the
    /// choice is left to `Rpc::new` — presumably a built-in default list; TODO confirm.
    pub bootstrap: Option<Vec<String>>,
    /// Handler for incoming requests. With `None` the actor thread never answers
    /// requests, so the node behaves as a client only.
    pub server: Option<Box<dyn Server>>,
    /// Local port to bind; `None` leaves the choice to `Rpc::new` (assumed to mean an
    /// OS-assigned port — confirm).
    pub port: Option<u16>,
    /// Request timeout, presumably read by `Rpc::new`; `None` uses whatever default it applies.
    pub request_timeout: Option<Duration>,
}
impl Dht {
pub fn builder() -> Builder {
Builder {
settings: DhtSettings::default(),
}
}
pub fn client() -> Result<Self> {
Dht::builder().build()
}
pub fn server() -> Result<Self> {
Dht::builder().server().build()
}
pub fn new(settings: DhtSettings) -> Result<Self> {
let (sender, receiver) = flume::bounded(32);
let rpc = Rpc::new(&settings)?;
let address = rpc.local_addr();
info!(?address, "Mainline DHT listening");
let mut server = settings.server;
thread::spawn(move || run(rpc, &mut server, receiver));
Ok(Dht {
sender,
address: Some(address),
})
}
pub fn local_addr(&self) -> Option<SocketAddr> {
self.address
}
pub fn shutdown(&mut self) -> Result<()> {
let (sender, receiver) = flume::bounded::<()>(1);
self.sender.send(ActorMessage::Shutdown(sender))?;
receiver.recv()?;
self.address = None;
Ok(())
}
pub fn get_peers(&self, info_hash: Id) -> Result<flume::IntoIter<Vec<SocketAddr>>> {
let (sender, receiver) = flume::unbounded::<Vec<SocketAddr>>();
let request = RequestTypeSpecific::GetPeers(GetPeersRequestArguments { info_hash });
self.sender.send(ActorMessage::Get(
info_hash,
request,
ResponseSender::Peers(sender),
))?;
Ok(receiver.into_iter())
}
pub fn announce_peer(&self, info_hash: Id, port: Option<u16>) -> Result<Id> {
let (sender, receiver) = flume::bounded::<PutResult>(1);
let (port, implied_port) = match port {
Some(port) => (port, None),
None => (0, Some(true)),
};
let request = PutRequestSpecific::AnnouncePeer(AnnouncePeerRequestArguments {
info_hash,
port,
implied_port,
});
self.sender
.send(ActorMessage::Put(info_hash, request, sender))?;
receiver.recv()?
}
pub fn get_immutable(&self, target: Id) -> Result<Bytes> {
let (sender, receiver) = flume::unbounded::<Bytes>();
let request = RequestTypeSpecific::GetValue(GetValueRequestArguments {
target,
seq: None,
salt: None,
});
self.sender.send(ActorMessage::Get(
target,
request,
ResponseSender::Immutable(sender),
))?;
Ok(receiver.recv()?)
}
pub fn put_immutable(&self, value: Bytes) -> Result<Id> {
let target = Id::from_bytes(hash_immutable(&value)).unwrap();
let (sender, receiver) = flume::bounded::<PutResult>(1);
let request = PutRequestSpecific::PutImmutable(PutImmutableRequestArguments {
target,
v: value.clone().into(),
});
self.sender
.send(ActorMessage::Put(target, request, sender))?;
receiver.recv()?
}
pub fn get_mutable(
&self,
public_key: &[u8; 32],
salt: Option<Bytes>,
seq: Option<i64>,
) -> Result<flume::IntoIter<MutableItem>> {
let target = MutableItem::target_from_key(public_key, &salt);
let (sender, receiver) = flume::unbounded::<MutableItem>();
let request = RequestTypeSpecific::GetValue(GetValueRequestArguments { target, seq, salt });
let _ = self.sender.send(ActorMessage::Get(
target,
request,
ResponseSender::Mutable(sender),
));
Ok(receiver.into_iter())
}
pub fn put_mutable(&self, item: MutableItem) -> Result<Id> {
let (sender, receiver) = flume::bounded::<PutResult>(1);
let request = PutRequestSpecific::PutMutable(PutMutableRequestArguments {
target: *item.target(),
v: item.value().clone().into(),
k: item.key().to_vec(),
seq: *item.seq(),
sig: item.signature().to_vec(),
salt: item.salt().clone().map(|s| s.to_vec()),
cas: *item.cas(),
});
let _ = self
.sender
.send(ActorMessage::Put(*item.target(), request, sender));
receiver.recv()?
}
}
fn run(mut rpc: Rpc, server: &mut Option<Box<dyn Server>>, receiver: Receiver<ActorMessage>) {
loop {
if let Ok(actor_message) = receiver.try_recv() {
match actor_message {
ActorMessage::Shutdown(sender) => {
drop(receiver);
let _ = sender.send(());
break;
}
ActorMessage::Put(target, request, sender) => {
rpc.put(target, request, Some(sender));
}
ActorMessage::Get(target, request, sender) => {
rpc.get(target, request, Some(sender), None)
}
}
}
let report = rpc.tick();
if let Some(ReceivedFrom {
from,
message: ReceivedMessage::Request((transaction_id, request_specific)),
}) = report.received_from
{
if let Some(server) = server.as_mut() {
server.handle_request(&mut rpc, from, transaction_id, &request_specific);
}
};
}
}
/// Commands sent from [`Dht`] handles to the actor thread.
pub enum ActorMessage {
    /// Start a put towards the given `Id`; the outcome is sent on the `PutResult` channel.
    Put(Id, PutRequestSpecific, Sender<PutResult>),
    /// Start a query for the given `Id`; responses are delivered through the `ResponseSender`.
    Get(Id, RequestTypeSpecific, ResponseSender),
    /// Stop the actor; it replies with `()` after dropping its receiver.
    Shutdown(Sender<()>),
}
/// An in-process network of server nodes for tests, plus the bootstrap list to join it.
#[derive(Debug)]
pub struct Testnet {
    /// Bootstrap addresses (`"127.0.0.1:<port>"`) pointing at the first node.
    pub bootstrap: Vec<String>,
    /// All nodes of the network, in creation order.
    pub nodes: Vec<Dht>,
}
impl Testnet {
    /// Spins up `count` server nodes in this process.
    ///
    /// The first node starts with an empty bootstrap list. Its port then becomes the only
    /// bootstrap entry for every node created after it.
    ///
    /// # Panics
    /// Panics if any node fails to start or reports no local address.
    pub fn new(count: usize) -> Self {
        let mut bootstrap: Vec<String> = Vec::new();

        let nodes: Vec<Dht> = (0..count)
            .map(|index| {
                // For the first node `bootstrap` is still empty, i.e. the same as `&[]`.
                let node = Dht::builder()
                    .server()
                    .bootstrap(&bootstrap)
                    .build()
                    .unwrap();

                if index == 0 {
                    let port = node.local_addr().unwrap().port();
                    bootstrap.push(format!("127.0.0.1:{}", port));
                }

                node
            })
            .collect();

        Self { bootstrap, nodes }
    }
}
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use ed25519_dalek::SigningKey;

    use super::*;
    use crate::Error;

    /// Fixed ed25519 secret key so the mutable-item tests are reproducible.
    const SECRET_KEY: [u8; 32] = [
        56, 171, 62, 85, 105, 58, 155, 209, 189, 8, 59, 109, 137, 84, 84, 201, 221, 115, 7,
        228, 127, 70, 4, 204, 182, 64, 77, 98, 92, 215, 27, 103,
    ];

    /// Builds a client-only node bootstrapped from `testnet`.
    fn bootstrapped_client(testnet: &Testnet) -> Dht {
        Dht::builder()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap()
    }

    #[test]
    fn shutdown() {
        let mut dht = Dht::client().unwrap();
        assert!(dht.local_addr().is_some());

        let a = dht.clone();
        dht.shutdown().unwrap();

        // Only the handle that called `shutdown` forgets its address...
        assert!(dht.local_addr().is_none());
        assert!(a.local_addr().is_some());

        // ...but every clone loses access to the stopped actor.
        let result = a.get_immutable(Id::random());
        assert!(matches!(result, Err(Error::DhtIsShutdown(_))))
    }

    #[test]
    fn bind_twice() {
        let a = Dht::client().unwrap();
        let result = Dht::builder()
            .port(a.local_addr().unwrap().port())
            .server()
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn announce_get_peer() {
        let testnet = Testnet::new(10);
        let a = bootstrapped_client(&testnet);
        let b = bootstrapped_client(&testnet);

        let info_hash = Id::random();
        a.announce_peer(info_hash, Some(45555))
            .expect("failed to announce");

        let peers = b.get_peers(info_hash).unwrap().next().expect("No peers");
        assert_eq!(peers.first().unwrap().port(), 45555);
    }

    #[test]
    fn put_get_immutable() {
        let testnet = Testnet::new(10);
        let a = bootstrapped_client(&testnet);
        let b = bootstrapped_client(&testnet);

        let value: Bytes = "Hello World!".into();
        let expected_target = Id::from_str("e5f96f6f38320f0f33959cb4d3d656452117aadb").unwrap();

        let target = a.put_immutable(value.clone()).unwrap();
        assert_eq!(target, expected_target);

        let response = b.get_immutable(target).unwrap();
        assert_eq!(response, value);
    }

    #[test]
    fn put_get_mutable() {
        let testnet = Testnet::new(10);
        let a = bootstrapped_client(&testnet);
        let b = bootstrapped_client(&testnet);

        let signer = SigningKey::from_bytes(&SECRET_KEY);
        let seq = 1000;
        let value: Bytes = "Hello World!".into();
        let item = MutableItem::new(signer.clone(), value, seq, None);

        a.put_mutable(item.clone()).unwrap();

        let response = b
            .get_mutable(signer.verifying_key().as_bytes(), None, None)
            .unwrap()
            .next()
            .expect("No mutable values");
        assert_eq!(&response, &item);
    }

    #[test]
    fn put_get_mutable_no_more_recent_value() {
        let testnet = Testnet::new(10);
        let a = bootstrapped_client(&testnet);
        let b = bootstrapped_client(&testnet);

        let signer = SigningKey::from_bytes(&SECRET_KEY);
        let seq = 1000;
        let value: Bytes = "Hello World!".into();
        let item = MutableItem::new(signer.clone(), value, seq, None);

        a.put_mutable(item.clone()).unwrap();

        // Asking for items newer than `seq` must yield nothing, since `seq` is the latest.
        let response = b
            .get_mutable(signer.verifying_key().as_bytes(), None, Some(seq))
            .unwrap()
            .next();
        assert!(response.is_none());
    }

    #[test]
    fn repeated_put_query() {
        let testnet = Testnet::new(10);
        let a = bootstrapped_client(&testnet);

        let id = a.put_immutable(vec![1, 2, 3].into()).unwrap();
        assert_eq!(a.put_immutable(vec![1, 2, 3].into()).unwrap(), id);
    }
}