use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::sync::oneshot;
use crate::error::{Error, ErrorKind};
use super::krpc::{KrpcMessage, TransactionId};
/// Default time to wait for a reply to a query; this is the timeout [`DhtRpc::new`] uses.
const RPC_TIMEOUT: Duration = Duration::from_secs(15);
/// Shared callback that the receive loop invokes for every incoming query message.
///
/// It is given the parsed message and the sender's address. Returning `Some(bytes)` causes
/// `bytes` to be sent back to that address; returning `None` sends nothing.
pub type QueryHandler = Arc<dyn Fn(&KrpcMessage, SocketAddr) -> Option<Vec<u8>> + Send + Sync>;
/// UDP request/response endpoint: sends queries and matches replies to them by transaction ID,
/// and hands incoming queries to an optional [`QueryHandler`].
pub struct DhtRpc {
/// Socket used for every send and receive.
socket: UdpSocket,
/// Reply channels of in-flight queries, keyed by transaction ID.
pending: Mutex<HashMap<TransactionId, oneshot::Sender<KrpcMessage>>>,
/// Callback for incoming queries; `None` until `set_query_handler` is called.
query_handler: Mutex<Option<QueryHandler>>,
/// How long `query` waits for a reply before giving up.
timeout: Duration,
}
impl DhtRpc {
/// Creates an endpoint bound to `bind_addr` that uses the default [`RPC_TIMEOUT`] for queries.
///
/// # Errors
/// Returns whatever error [`DhtRpc::with_timeout`] reports.
pub async fn new(bind_addr: SocketAddr) -> Result<Arc<Self>, Error> {
    let rpc = Self::with_timeout(bind_addr, RPC_TIMEOUT).await?;
    Ok(rpc)
}
pub async fn with_timeout(
bind_addr: SocketAddr, timeout: Duration,
) -> Result<Arc<Self>, Error> {
let socket = UdpSocket::bind(bind_addr).await?;
let rpc = Arc::new(DhtRpc {
socket,
pending: Mutex::new(HashMap::new()),
query_handler: Mutex::new(None),
timeout,
});
rpc.clone().start_recv_loop();
Ok(rpc)
}
/// Installs `handler` as the callback for incoming queries, replacing any previous one.
///
/// # Panics
/// Panics if the handler mutex is poisoned.
pub fn set_query_handler(&self, handler: QueryHandler) {
    let mut slot = self.query_handler.lock().unwrap();
    *slot = Some(handler);
}
pub fn local_addr(&self) -> Result<SocketAddr, Error> {
self.socket.local_addr().map_err(Error::protocol)
}
pub async fn query(
&self, addr: SocketAddr, tid: TransactionId, data: &[u8],
) -> Result<KrpcMessage, Error> {
let (tx, rx) = oneshot::channel();
self.pending.lock().unwrap().insert(tid, tx);
tracing::debug!("DHT query to {}", addr);
if let Err(e) = self.socket.send_to(data, addr).await {
self.pending.lock().unwrap().remove(&tid);
return Err(Error::with_source(ErrorKind::Protocol, e));
}
tokio::time::timeout(self.timeout, rx)
.await
.map_err(|_| {
self.pending.lock().unwrap().remove(&tid);
Error::new(ErrorKind::Protocol)
})?
.map_err(|_| {
self.pending.lock().unwrap().remove(&tid);
Error::new(ErrorKind::Protocol)
})
}
/// Builds a ping message for `node_id` with transaction ID `tid`, sends it to `addr` via
/// [`DhtRpc::query`], and returns the reply.
///
/// # Errors
/// Returns whatever error [`DhtRpc::query`] reports.
pub async fn ping(
    &self, addr: SocketAddr, tid: TransactionId, node_id: &[u8; 20],
) -> Result<KrpcMessage, Error> {
    let ping_packet = super::krpc::build_ping(tid, node_id);
    self.query(addr, tid, &ping_packet).await
}
fn start_recv_loop(self: Arc<Self>) {
tokio::spawn(async move {
let mut buf = [0u8; 8192];
loop {
let (len, src_addr) = match self.socket.recv_from(&mut buf).await {
Ok(r) => r,
Err(e) => {
tracing::warn!("DHT recv error: {e}");
continue;
}
};
let msg = match KrpcMessage::from_bytes(&buf[..len]) {
Ok(m) => m,
Err(_) => continue,
};
match &msg {
KrpcMessage::Response { transaction_id, .. }
| KrpcMessage::Error { transaction_id, .. } => {
if let Some(tx) = self.pending.lock().unwrap().remove(transaction_id) {
let _ = tx.send(msg);
}
}
KrpcMessage::Query { .. } => {
let handler = self.query_handler.lock().unwrap().clone();
if let Some(handler) = handler {
if let Some(response_bytes) = handler(&msg, src_addr) {
let _ = self.socket.send_to(&response_bytes, src_addr).await;
}
}
}
}
}
});
}
}