use std::{
collections::{BTreeMap, BTreeSet, HashSet, VecDeque},
sync::Arc,
time::{Duration, UNIX_EPOCH},
};
use futures_buffered::FuturesUnordered;
use indexmap::IndexSet;
use iroh::NodeId;
use irpc::channel::{mpsc, oneshot};
use n0_future::{BufferedStreamExt, MaybeFuture, StreamExt, stream};
use rand::{Rng, SeedableRng, rngs::StdRng, seq::index::sample};
use serde::{Deserialize, Serialize};
use tokio::task::JoinSet;
#[cfg(test)]
mod tests;
/// Node-to-node DHT protocol: stored value types, request/response messages,
/// and a thin client wrapper around `irpc`.
pub mod rpc {
    use std::{
        fmt,
        num::NonZeroU64,
        ops::Deref,
        sync::{Arc, Weak},
    };
    use iroh::{Endpoint, NodeAddr, NodeId, PublicKey};
    use iroh_base::SignatureError;
    use irpc::{
        channel::{mpsc, oneshot},
        rpc_requests,
    };
    use serde::{Deserialize, Serialize};
    use serde_big_array::BigArray;

    /// ALPN identifier passed to `irpc_iroh::client` when dialing a remote node
    /// (see [`RpcClient::remote`]).
    pub const ALPN: &[u8] = b"iroh/dht/0";

    /// Provider record. Both fields are private and this module defines no
    /// constructor for it.
    // NOTE(review): presumably `node_id` holds the provider's public key bytes — confirm.
    #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct Blake3Provider {
        timestamp: u64,
        node_id: [u8; 32],
    }

    /// Immutable content-addressed blob.
    #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct Blake3Immutable {
        /// Creation time; `ApiClient::put_immutable` fills it with `now()`.
        pub timestamp: u64,
        /// The stored bytes; readers check that their BLAKE3 hash equals the key.
        pub data: Vec<u8>,
    }

    /// Payload accompanied by an Ed25519 signature.
    // NOTE(review): nothing in this file verifies `signature` before storing.
    #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct ED25519SignedMessage {
        pub timestamp: u64,
        // `BigArray` is needed because serde's built-in array support stops at 32.
        #[serde(with = "BigArray")]
        pub signature: [u8; 64],
        pub data: Vec<u8>,
    }

    /// A value that can be stored in the DHT; one variant per [`Kind`].
    #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub enum Value {
        Blake3Provider(Blake3Provider),
        ED25519SignedMessage(ED25519SignedMessage),
        Blake3Immutable(Blake3Immutable),
    }

    impl Value {
        /// Returns the [`Kind`] tag matching this value's variant.
        pub fn kind(&self) -> Kind {
            match self {
                Value::Blake3Provider(_) => Kind::Blake3Provider,
                Value::ED25519SignedMessage(_) => Kind::ED25519SignedMessage,
                Value::Blake3Immutable(_) => Kind::Blake3Immutable,
            }
        }
    }

    /// Payload-free tag for [`Value`] variants; storage is partitioned by it.
    #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub enum Kind {
        Blake3Provider,
        ED25519SignedMessage,
        Blake3Immutable,
    }

    /// 256-bit identifier shared by node ids and storage keys.
    #[derive(Clone, Copy, Ord, PartialOrd, PartialEq, Eq, Hash, Serialize, Deserialize)]
    pub struct Id([u8; 32]);

    impl From<[u8; 32]> for Id {
        fn from(bytes: [u8; 32]) -> Self {
            Id(bytes)
        }
    }

    /// Exposes the raw bytes so an `&Id` can be passed where `&[u8; 32]` is expected.
    impl Deref for Id {
        type Target = [u8; 32];
        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    /// Formats as `Id(<hex>)`.
    impl fmt::Debug for Id {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "Id({})", hex::encode(self.0))
        }
    }

    /// Formats as bare lowercase hex.
    impl fmt::Display for Id {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}", hex::encode(self.0))
        }
    }

    impl From<PublicKey> for Id {
        fn from(pk: PublicKey) -> Self {
            Id(*pk.as_bytes())
        }
    }

    impl From<blake3::Hash> for Id {
        fn from(pk: blake3::Hash) -> Self {
            Id(*pk.as_bytes())
        }
    }

    impl Id {
        /// Id of `data`'s BLAKE3 hash.
        pub fn blake3_hash(data: &[u8]) -> Self {
            let hash = blake3::hash(data);
            Id(hash.into())
        }
        /// Id with the same bytes as the given node id.
        pub fn node_id(id: iroh::NodeId) -> Self {
            Id::from(*id.as_bytes())
        }
    }

    /// Request to store `value` under `key` on the receiving node.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct Set {
        pub key: Id,
        pub value: Value,
    }

    /// Outcome of a [`Set`] request.
    // NOTE(review): the actor's Set handler in this file only produces `Ok` and
    // `ErrDistance`; the other variants are currently unused.
    #[derive(Debug, Serialize, Deserialize)]
    pub enum SetResponse {
        Ok,
        ErrDistance,
        ErrExpired,
        ErrFull,
        ErrInvalid,
    }

    /// Request to stream values of `kind` stored under `key`.
    ///
    /// With `seed`, a random sample of size `n` (all values when `n` is
    /// `None`) is drawn with an RNG seeded from it; without `seed`, every
    /// stored value is streamed and `n` is not consulted.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct GetAll {
        pub key: Id,
        pub kind: Kind,
        pub seed: Option<NonZeroU64>,
        pub n: Option<NonZeroU64>,
    }

    /// Request for the receiver's closest known nodes to `id`.
    ///
    /// `requester` is set by non-transient nodes so the receiver can consider
    /// them as routing candidates.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct FindNode {
        pub id: Id,
        pub requester: Option<NodeId>,
    }

    // The node-to-node protocol. `rpc_requests` generates `RpcMessage`, which
    // the actor matches on; each `tx` is the response channel for that request.
    #[rpc_requests(message = RpcMessage)]
    #[derive(Debug, Serialize, Deserialize)]
    pub enum RpcProto {
        // Store a value; single response.
        #[rpc(tx = oneshot::Sender<SetResponse>)]
        Set(Set),
        // Stream stored values back.
        #[rpc(tx = mpsc::Sender<Value>)]
        GetAll(GetAll),
        // Return the closest known nodes as dialable addresses.
        #[rpc(tx = oneshot::Sender<Vec<NodeAddr>>)]
        FindNode(FindNode),
    }

    /// Cheaply cloneable handle to an `irpc` client speaking [`RpcProto`].
    #[derive(Debug, Clone)]
    pub struct RpcClient(pub(crate) Arc<irpc::Client<RpcProto>>);

    /// Non-owning counterpart of [`RpcClient`].
    #[derive(Debug, Clone)]
    pub struct WeakRpcClient(pub(crate) Weak<irpc::Client<RpcProto>>);

    impl WeakRpcClient {
        /// Returns a strong client if the underlying one is still alive.
        pub fn upgrade(&self) -> Option<RpcClient> {
            self.0.upgrade().map(RpcClient)
        }
    }

    impl RpcClient {
        /// Client for the remote node whose id bytes are `id`.
        ///
        /// # Errors
        /// Fails if `id` is not a valid node public key.
        pub fn remote(endpoint: Endpoint, id: Id) -> std::result::Result<Self, SignatureError> {
            let id = iroh::NodeId::from_bytes(&id)?;
            let client = irpc_iroh::client(endpoint, id, ALPN);
            Ok(Self::new(client))
        }
        pub fn new(client: irpc::Client<RpcProto>) -> Self {
            Self(Arc::new(client))
        }
        /// Sends a [`Set`] request.
        pub async fn set(&self, key: Id, value: Value) -> irpc::Result<SetResponse> {
            self.0.rpc(Set { key, value }).await
        }
        /// Sends a [`GetAll`] request and returns the value stream.
        // NOTE(review): `32` is presumably the response channel capacity — confirm against irpc.
        pub async fn get_all(
            &self,
            key: Id,
            kind: Kind,
            seed: Option<NonZeroU64>,
            n: Option<NonZeroU64>,
        ) -> irpc::Result<irpc::channel::mpsc::Receiver<Value>> {
            self.0
                .server_streaming(GetAll { key, kind, seed, n }, 32)
                .await
        }
        /// Sends a [`FindNode`] request.
        pub async fn find_node(
            &self,
            id: Id,
            requester: Option<NodeId>,
        ) -> irpc::Result<Vec<NodeAddr>> {
            self.0.rpc(FindNode { id, requester }).await
        }
        /// Non-owning handle to the same client.
        pub fn downgrade(&self) -> WeakRpcClient {
            WeakRpcClient(Arc::downgrade(&self.0))
        }
    }
}
pub mod api {
use std::{
collections::BTreeMap,
num::NonZeroU64,
sync::{Arc, Weak},
time::Duration,
};
use iroh::NodeId;
use irpc::{
channel::{mpsc, none::NoSender, oneshot},
rpc_requests,
};
use serde::{Deserialize, Serialize};
use crate::{
now,
routing::NodeInfo,
rpc::{Blake3Immutable, Id, Kind, Value},
};
#[rpc_requests(message = ApiMessage)]
#[derive(Debug, Serialize, Deserialize)]
pub enum ApiProto {
#[rpc(tx = NoSender)]
#[wrap(NodesSeen)]
NodesSeen { ids: Vec<NodeId> },
#[rpc(tx = NoSender)]
#[wrap(NodesDead)]
NodesDead { ids: Vec<NodeId> },
#[rpc(tx = oneshot::Sender<Vec<NodeId>>)]
#[wrap(Lookup)]
Lookup {
initial: Option<Vec<NodeId>>,
id: Id,
},
#[rpc(tx = mpsc::Sender<NodeId>)]
#[wrap(NetworkPut)]
NetworkPut { id: Id, value: Value },
#[rpc(tx = mpsc::Sender<(NodeId, Value)>)]
#[wrap(NetworkGet)]
NetworkGet {
id: Id,
kind: Kind,
seed: Option<NonZeroU64>,
n: Option<NonZeroU64>,
},
#[rpc(tx = oneshot::Sender<Vec<Vec<NodeInfo>>>)]
#[wrap(GetRoutingTable)]
GetRoutingTable,
#[rpc(tx = oneshot::Sender<BTreeMap<Id, BTreeMap<Kind, usize>>>)]
#[wrap(GetStorageStats)]
GetStorageStats,
#[rpc(tx = oneshot::Sender<()>)]
#[wrap(SelfLookup)]
SelfLookup,
#[rpc(tx = oneshot::Sender<()>)]
#[wrap(RandomLookup)]
RandomLookup,
#[rpc(tx = oneshot::Sender<()>)]
#[wrap(CandidateLookup)]
CandidateLookup,
}
#[derive(Debug, Clone)]
pub struct ApiClient(pub(crate) Arc<irpc::Client<ApiProto>>);
impl ApiClient {
pub async fn nodes_seen(&self, ids: &[NodeId]) -> irpc::Result<()> {
self.0.notify(NodesSeen { ids: ids.to_vec() }).await
}
pub async fn nodes_dead(&self, ids: &[NodeId]) -> irpc::Result<()> {
self.0.notify(NodesDead { ids: ids.to_vec() }).await
}
pub async fn get_storage_stats(&self) -> irpc::Result<BTreeMap<Id, BTreeMap<Kind, usize>>> {
self.0.rpc(GetStorageStats).await
}
pub async fn get_routing_table(&self) -> irpc::Result<Vec<Vec<NodeInfo>>> {
self.0.rpc(GetRoutingTable).await
}
pub async fn lookup(
&self,
id: Id,
initial: Option<Vec<NodeId>>,
) -> irpc::Result<Vec<NodeId>> {
self.0.rpc(Lookup { id, initial }).await
}
pub async fn get_immutable(&self, hash: blake3::Hash) -> irpc::Result<Option<Vec<u8>>> {
let id = Id::from(*hash.as_bytes());
let mut rx = self
.0
.server_streaming(
NetworkGet {
id,
kind: Kind::Blake3Immutable,
seed: None,
n: Some(NonZeroU64::new(1).unwrap()),
},
32,
)
.await?;
loop {
match rx.recv().await {
Ok(Some((_, value))) => {
let Value::Blake3Immutable(Blake3Immutable { data, .. }) = value else {
continue; };
if blake3::hash(&data) == hash {
return Ok(Some(data));
} else {
continue; }
}
Ok(None) => {
break Ok(None);
}
Err(e) => {
break Err(e.into());
}
}
}
}
pub async fn put_immutable(
&self,
value: &[u8],
) -> irpc::Result<(blake3::Hash, Vec<NodeId>)> {
let hash = blake3::hash(value);
let id = Id::from(*hash.as_bytes());
let mut rx = self
.0
.server_streaming(
NetworkPut {
id,
value: Value::Blake3Immutable(Blake3Immutable {
timestamp: now(),
data: value.to_vec(),
}),
},
32,
)
.await?;
let mut res = Vec::new();
loop {
match rx.recv().await {
Ok(Some(id)) => res.push(id),
Ok(None) => break,
Err(_) => {}
}
}
Ok((hash, res))
}
pub async fn self_lookup(&self) {
self.0.rpc(SelfLookup).await.ok();
}
pub async fn random_lookup(&self) {
self.0.rpc(RandomLookup).await.ok();
}
pub async fn candidate_lookup(&self) {
self.0.rpc(CandidateLookup).await.ok();
}
pub fn downgrade(&self) -> WeakApiClient {
WeakApiClient(Arc::downgrade(&self.0))
}
}
#[derive(Debug, Clone)]
pub struct WeakApiClient(pub(crate) Weak<irpc::Client<ApiProto>>);
impl WeakApiClient {
pub fn upgrade(&self) -> irpc::Result<ApiClient> {
self.0
.upgrade()
.map(ApiClient)
.ok_or(irpc::Error::Send(irpc::channel::SendError::ReceiverClosed))
}
pub async fn nodes_dead(&self, ids: &[NodeId]) -> irpc::Result<()> {
self.upgrade()?.nodes_dead(ids).await
}
pub async fn nodes_seen(&self, ids: &[NodeId]) -> irpc::Result<()> {
self.upgrade()?.nodes_seen(ids).await
}
pub(crate) async fn self_lookup_periodic(self, interval: Duration) {
loop {
tokio::time::sleep(interval).await;
let Ok(api) = self.upgrade() else {
return;
};
api.self_lookup().await;
}
}
pub(crate) async fn random_lookup_periodic(self, interval: Duration) {
loop {
tokio::time::sleep(interval).await;
let Ok(api) = self.upgrade() else {
return;
};
api.random_lookup().await;
}
}
pub(crate) async fn candidate_lookup_periodic(self, interval: Duration) {
loop {
tokio::time::sleep(interval).await;
let Ok(api) = self.upgrade() else {
return;
};
api.candidate_lookup().await;
}
}
}
}
pub use api::ApiClient;
use tracing::{error, info, warn};
/// Kademlia-style routing table keyed by XOR distance over 256-bit ids.
mod routing {
    use std::{
        fmt,
        ops::{Index, IndexMut},
    };
    use arrayvec::ArrayVec;
    use iroh::NodeId;
    use serde::{Deserialize, Serialize};
    use super::rpc::Id;

    /// Capacity of each bucket, and the default number of closest nodes returned.
    pub const K: usize = 20;
    /// Default number of peers queried per round of an iterative lookup.
    pub const ALPHA: usize = 3;
    /// One bucket per bit of a 256-bit id.
    pub const BUCKET_COUNT: usize = 256;

    /// Bytewise XOR of two 256-bit ids.
    fn xor(a: &[u8; 32], b: &[u8; 32]) -> [u8; 32] {
        std::array::from_fn(|i| a[i] ^ b[i])
    }

    /// Number of leading zero bits, treating byte 0 as most significant.
    /// Returns 256 for an all-zero input.
    fn leading_zeros(data: &[u8; 32]) -> usize {
        data.iter()
            .position(|&byte| byte != 0)
            .map_or(256, |idx| idx * 8 + data[idx].leading_zeros() as usize)
    }

    /// XOR distance between two ids.
    ///
    /// Ordering is lexicographic over the bytes, so byte 0 is the most
    /// significant.
    #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
    pub struct Distance([u8; 32]);

    impl Distance {
        /// Distance between `a` and `b`; symmetric.
        pub fn between(a: &[u8; 32], b: &[u8; 32]) -> Self {
            Self(xor(a, b))
        }
        /// Recovers the other endpoint: `between(a, b).inverse(b) == *a`.
        pub fn inverse(&self, b: &[u8; 32]) -> [u8; 32] {
            xor(&self.0, b)
        }
        /// The largest possible distance.
        pub const MAX: Self = Self([u8::MAX; 32]);
    }

    impl fmt::Debug for Distance {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "Distance({self})")
        }
    }

    impl fmt::Display for Distance {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "{}", hex::encode(self.0))
        }
    }

    /// A routing-table entry: a node id and when it was last seen.
    #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
    pub struct NodeInfo {
        pub id: NodeId,
        /// Timestamp as produced by the crate's `now()` (seconds since the Unix epoch).
        pub last_seen: u64,
    }

    impl NodeInfo {
        pub fn new(id: NodeId, last_seen: u64) -> Self {
            Self { id, last_seen }
        }
    }

    /// A bucket holding at most [`K`] nodes, stored inline.
    #[derive(Debug, Clone, Default)]
    pub struct KBucket {
        nodes: ArrayVec<NodeInfo, K>,
    }

    impl KBucket {
        fn new() -> Self {
            Self {
                nodes: ArrayVec::new(),
            }
        }

        /// Inserts `node`, or refreshes `last_seen` if it is already present.
        ///
        /// Returns `false` only when the node is new and the bucket is full.
        pub fn add_node(&mut self, node: NodeInfo) -> bool {
            if let Some(existing) = self.nodes.iter_mut().find(|n| n.id == node.id) {
                existing.last_seen = node.last_seen;
                return true;
            }
            self.nodes.try_push(node).is_ok()
        }

        fn remove_node(&mut self, id: &NodeId) {
            self.nodes.retain(|n| n.id != *id);
        }

        pub fn nodes(&self) -> &[NodeInfo] {
            &self.nodes
        }
    }

    /// All buckets of a routing table plus the local node id they are relative to.
    #[derive(Debug)]
    pub struct RoutingTable {
        pub buckets: Box<Buckets>,
        pub local_id: NodeId,
    }

    /// Fixed array of [`BUCKET_COUNT`] buckets.
    #[derive(Debug, Clone)]
    pub struct Buckets(pub [KBucket; BUCKET_COUNT]);

    impl Buckets {
        pub fn iter(&self) -> std::slice::Iter<'_, KBucket> {
            self.0.iter()
        }
    }

    impl Index<usize> for Buckets {
        type Output = KBucket;
        fn index(&self, index: usize) -> &Self::Output {
            &self.0[index]
        }
    }

    impl IndexMut<usize> for Buckets {
        fn index_mut(&mut self, index: usize) -> &mut Self::Output {
            &mut self.0[index]
        }
    }

    impl Default for Buckets {
        fn default() -> Self {
            Self(std::array::from_fn(|_| KBucket::new()))
        }
    }

    impl RoutingTable {
        /// Creates a table for `local_id`, optionally reusing saved `buckets`.
        /// Any entry for the local node itself is dropped from the saved buckets.
        pub fn new(local_id: NodeId, buckets: Option<Box<Buckets>>) -> Self {
            let buckets = buckets
                .map(|mut buckets| {
                    for bucket in buckets.0.iter_mut() {
                        bucket.nodes.retain(|n| n.id != local_id);
                    }
                    buckets
                })
                .unwrap_or_default();
            Self { buckets, local_id }
        }

        /// Bucket for `target`: `255 - leading_zeros(distance)`, so farther ids
        /// land in higher buckets. Distance 0 (our own id) maps to bucket 0.
        fn bucket_index(&self, target: &[u8; 32]) -> usize {
            let distance = xor(self.local_id.as_bytes(), target);
            let zeros = leading_zeros(&distance);
            if zeros >= BUCKET_COUNT {
                0
            } else {
                BUCKET_COUNT - 1 - zeros
            }
        }

        pub(crate) fn contains(&self, id: &NodeId) -> bool {
            let bucket_idx = self.bucket_index(id.as_bytes());
            self.buckets[bucket_idx]
                .nodes()
                .iter()
                .any(|node| node.id == *id)
        }

        /// Adds or refreshes `node`. The local node is never added.
        /// Returns `false` if rejected (local id, or the bucket is full).
        pub fn add_node(&mut self, node: NodeInfo) -> bool {
            if node.id == self.local_id {
                return false;
            }
            let bucket_idx = self.bucket_index(node.id.as_bytes());
            self.buckets[bucket_idx].add_node(node)
        }

        pub(crate) fn remove_node(&mut self, id: &NodeId) {
            let bucket_idx = self.bucket_index(id.as_bytes());
            self.buckets[bucket_idx].remove_node(id);
        }

        /// All nodes in the table, bucket by bucket.
        pub fn nodes(&self) -> impl Iterator<Item = &NodeInfo> {
            self.buckets.iter().flat_map(|bucket| bucket.nodes())
        }

        /// Returns up to `k` node ids sorted by ascending XOR distance to `target`.
        ///
        /// Returns an empty list for `k == 0` (this used to compute `k - 1` and panic).
        pub fn find_closest_nodes(&self, target: &Id, k: usize) -> Vec<NodeId> {
            if k == 0 {
                return Vec::new();
            }
            let mut candidates: Vec<Distance> = self
                .nodes()
                .map(|node| Distance::between(target, node.id.as_bytes()))
                .collect();
            if k < candidates.len() {
                // Partition so the k smallest distances come first; only those get sorted.
                candidates.select_nth_unstable(k - 1);
                candidates.truncate(k);
            }
            candidates.sort_unstable();
            candidates
                .into_iter()
                .map(|dist| {
                    NodeId::from_bytes(&dist.inverse(target))
                        .expect("inverse called with different target than between")
                })
                .collect()
        }
    }
}
#[doc(hidden)]
pub mod bench_exports {
pub use crate::{
routing::{Buckets, KBucket, NodeInfo, RoutingTable},
rpc::Id,
};
}
use crate::{
api::{ApiMessage, Lookup, NetworkGet, NetworkPut, WeakApiClient},
pool::ClientPool,
routing::{ALPHA, BUCKET_COUNT, Buckets, Distance, K, NodeInfo, RoutingTable},
rpc::{Id, Kind, RpcClient, RpcMessage, SetResponse, Value},
u256::U256,
};
/// State owned by the actor: the routing table (which records our own id)
/// and the in-memory value store.
struct Node {
    routing_table: RoutingTable,
    storage: MemStorage,
}

impl Node {
    /// Our own node id, as recorded in the routing table.
    fn id(&self) -> &NodeId {
        let table = &self.routing_table;
        &table.local_id
    }
}
/// In-memory value store: key → kind → insertion-ordered set of distinct values.
struct MemStorage {
    data: BTreeMap<Id, BTreeMap<Kind, IndexSet<Value>>>,
}

impl MemStorage {
    fn new() -> Self {
        let data = BTreeMap::new();
        Self { data }
    }

    /// Adds `value` under `key`, filed by its kind. Duplicates are ignored.
    fn set(&mut self, key: Id, value: Value) {
        let by_kind = self.data.entry(key).or_default();
        let values = by_kind.entry(value.kind()).or_default();
        values.insert(value);
    }

    /// All values of `kind` stored under `key`, if any were ever stored.
    fn get_all(&self, key: &Id, kind: &Kind) -> Option<&IndexSet<Value>> {
        let by_kind = self.data.get(key)?;
        by_kind.get(kind)
    }
}
mod u256 {
    //! Minimal 256-bit unsigned integer stored as 32 little-endian bytes
    //! (byte 0 is least significant), with just the bitwise and shift
    //! operators needed by `blend`.
    use std::ops::{BitAnd, BitOr, BitXor, Deref, Not, Shl, Shr};

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct U256([u8; 32]);

    impl Deref for U256 {
        type Target = [u8; 32];
        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl U256 {
        pub const MIN: U256 = U256([0u8; 32]);
        pub const MAX: U256 = U256([0xffu8; 32]);

        pub fn from_le_bytes(bytes: [u8; 32]) -> Self {
            U256(bytes)
        }

        #[allow(clippy::wrong_self_convention)]
        pub fn to_le_bytes(&self) -> [u8; 32] {
            self.0
        }

        /// Leading zero bits counted from the most significant end (byte 31).
        #[allow(dead_code)]
        pub fn leading_zeros(&self) -> u32 {
            let mut count = 0;
            for &byte in self.0.iter().rev() {
                if byte == 0 {
                    count += 8;
                } else {
                    count += byte.leading_zeros();
                    break;
                }
            }
            count
        }

        /// Splits into `(low, high)` 128-bit halves.
        fn halves(self) -> (u128, u128) {
            // Slices of exactly 16 bytes always convert to [u8; 16].
            let low = u128::from_le_bytes(self.0[..16].try_into().unwrap());
            let high = u128::from_le_bytes(self.0[16..].try_into().unwrap());
            (low, high)
        }

        /// Inverse of [`U256::halves`].
        fn from_halves(low: u128, high: u128) -> Self {
            let mut bytes = [0u8; 32];
            bytes[..16].copy_from_slice(&low.to_le_bytes());
            bytes[16..].copy_from_slice(&high.to_le_bytes());
            U256(bytes)
        }

        /// Applies `op` byte by byte to `self` and `rhs`.
        fn zip_bytes(self, rhs: Self, op: impl Fn(u8, u8) -> u8) -> Self {
            U256(std::array::from_fn(|i| op(self.0[i], rhs.0[i])))
        }
    }

    impl BitXor for U256 {
        type Output = Self;
        fn bitxor(self, rhs: Self) -> Self::Output {
            self.zip_bytes(rhs, |a, b| a ^ b)
        }
    }

    impl BitAnd for U256 {
        type Output = Self;
        fn bitand(self, rhs: Self) -> Self::Output {
            self.zip_bytes(rhs, |a, b| a & b)
        }
    }

    impl BitOr for U256 {
        type Output = Self;
        fn bitor(self, rhs: Self) -> Self::Output {
            self.zip_bytes(rhs, |a, b| a | b)
        }
    }

    impl Not for U256 {
        type Output = Self;
        fn not(self) -> Self::Output {
            U256(self.0.map(|b| !b))
        }
    }

    impl Shl<u32> for U256 {
        type Output = Self;
        /// Logical left shift; shifts of 256 or more yield zero.
        fn shl(self, rhs: u32) -> Self::Output {
            // `rhs == 0` needs its own branch: the carry term below would
            // compute `low >> 128`, which overflows a u128 (a panic in debug,
            // and a wrong result in release).
            if rhs == 0 {
                return self;
            }
            if rhs >= 256 {
                return U256::MIN;
            }
            let (low, high) = self.halves();
            let (new_low, new_high) = if rhs >= 128 {
                (0, low << (rhs - 128))
            } else {
                (low << rhs, (high << rhs) | (low >> (128 - rhs)))
            };
            U256::from_halves(new_low, new_high)
        }
    }

    impl Shr<u32> for U256 {
        type Output = Self;
        /// Logical right shift; shifts of 256 or more yield zero.
        fn shr(self, rhs: u32) -> Self::Output {
            // See `shl`: `high << 128` would overflow for `rhs == 0`.
            if rhs == 0 {
                return self;
            }
            if rhs >= 256 {
                return U256::MIN;
            }
            let (low, high) = self.halves();
            let (new_low, new_high) = if rhs >= 128 {
                (high >> (rhs - 128), 0)
            } else {
                ((low >> rhs) | (high << (128 - rhs)), high >> rhs)
            };
            U256::from_halves(new_low, new_high)
        }
    }
}
/// How the DHT obtains [`RpcClient`]s for peers (and for itself).
pub mod pool {
    use std::sync::{Arc, RwLock};
    use iroh::{
        Endpoint, NodeAddr, NodeId,
        endpoint::{RecvStream, SendStream},
    };
    use iroh_blobs::util::connection_pool::{ConnectionPool, ConnectionRef};
    use snafu::Snafu;
    use tracing::error;
    use crate::rpc::{RpcClient, WeakRpcClient};

    /// Source of RPC clients for DHT peers, plus address bookkeeping.
    pub trait ClientPool: Send + Sync + Clone + Sized + 'static {
        /// Id of the local node.
        fn id(&self) -> NodeId;

        /// Address handed out to peers for `node_id`; by default the bare id.
        fn node_addr(&self, node_id: NodeId) -> NodeAddr {
            node_id.into()
        }

        /// Records addressing info learned from peers; no-op by default.
        fn add_node_addr(&self, _addr: NodeAddr) {}

        /// Returns a client for talking to node `id`.
        fn client(&self, id: NodeId) -> impl Future<Output = Result<RpcClient, String>> + Send;
    }

    // Plain error carrying a message. (Deliberately not doc-commented: snafu
    // would turn a doc comment into the Display text.)
    #[derive(Debug, Snafu)]
    pub struct PoolError {
        pub message: String,
    }

    /// [`ClientPool`] backed by an iroh endpoint and a shared connection pool.
    #[derive(Debug, Clone)]
    pub struct IrohPool {
        endpoint: Endpoint,
        inner: ConnectionPool,
        // Client used for requests addressed to our own id; see `set_self_client`.
        self_client: Arc<RwLock<Option<WeakRpcClient>>>,
    }

    impl IrohPool {
        pub fn new(endpoint: Endpoint, inner: ConnectionPool) -> Self {
            let self_client = Arc::new(RwLock::new(None));
            Self {
                endpoint,
                inner,
                self_client,
            }
        }

        /// Sets (or clears) the client used when a request targets our own id.
        ///
        /// # Panics
        /// Panics if the lock is poisoned.
        pub fn set_self_client(&self, client: Option<WeakRpcClient>) {
            *self.self_client.write().unwrap() = client;
        }
    }

    /// Adapter letting an irpc client open streams on a pooled iroh connection.
    #[derive(Debug, Clone)]
    struct IrohConnection(Arc<ConnectionRef>);

    impl irpc::rpc::RemoteConnection for IrohConnection {
        fn clone_boxed(&self) -> Box<dyn irpc::rpc::RemoteConnection> {
            Box::new(self.clone())
        }

        fn open_bi(
            &self,
        ) -> n0_future::future::Boxed<
            std::result::Result<(SendStream, RecvStream), irpc::RequestError>,
        > {
            let pooled = Arc::clone(&self.0);
            Box::pin(async move {
                let (tx_stream, rx_stream) = pooled.open_bi().await?;
                Ok((tx_stream, rx_stream))
            })
        }
    }

    impl ClientPool for IrohPool {
        fn id(&self) -> NodeId {
            self.endpoint.node_id()
        }

        /// Prefers the endpoint's known remote info; falls back to the bare id.
        fn node_addr(&self, node_id: NodeId) -> NodeAddr {
            if let Some(info) = self.endpoint.remote_info(node_id) {
                info.into()
            } else {
                node_id.into()
            }
        }

        /// Registers `addr` with the endpoint unless it is our own id or
        /// carries neither a relay URL nor direct addresses.
        fn add_node_addr(&self, addr: NodeAddr) {
            let is_self = addr.node_id == self.id();
            let has_no_route = addr.relay_url.is_none() && addr.direct_addresses.is_empty();
            if !is_self && !has_no_route {
                self.endpoint.add_node_addr_with_source(addr, "").ok();
            }
        }

        /// For our own id, upgrades the registered self client; otherwise
        /// gets or opens a pooled connection to `node_id`.
        async fn client(&self, node_id: NodeId) -> Result<RpcClient, String> {
            if node_id != self.id() {
                let pooled = self
                    .inner
                    .get_or_connect(node_id)
                    .await
                    .map_err(|e| format!("Failed to connect: {e}"))?;
                let conn = IrohConnection(Arc::new(pooled));
                return Ok(RpcClient::new(irpc::Client::boxed(conn)));
            }
            let registered = self.self_client.read().unwrap().clone();
            match registered {
                Some(weak) => weak
                    .upgrade()
                    .ok_or_else(|| "Self client is no longer available".to_string()),
                None => {
                    error!("Self client not set");
                    Err("Self client not set".to_string())
                }
            }
        }
    }
}
/// Cloneable context handed to the background tasks the actor spawns
/// (lookups, network put/get, periodic triggers).
#[derive(Debug, Clone)]
struct State<P> {
    /// Weak handle to the actor's own API; tasks upgrade it when they need to report back.
    api: WeakApiClient,
    /// Source of RPC clients for peers.
    pool: P,
    /// Node configuration.
    config: Config,
}
/// Bounded, most-recent-first queue of distinct node ids that recently sent
/// us a FindNode request. Used to seed candidate lookups.
struct Candidates {
    ids: VecDeque<NodeId>,
    max_size: usize,
}

impl Candidates {
    fn new(max_size: usize) -> Self {
        Self {
            max_size,
            ids: VecDeque::new(),
        }
    }

    /// Moves `id` to the front (inserting it if new) and drops the oldest
    /// entries beyond `max_size`.
    fn add(&mut self, id: NodeId) {
        self.ids.retain(|existing| *existing != id);
        self.ids.push_front(id);
        self.ids.truncate(self.max_size);
    }

    /// Empties the queue, returning its contents newest first.
    fn clear_and_take(&mut self) -> Vec<NodeId> {
        self.ids.drain(..).collect()
    }
}
/// The DHT actor: owns the node state and serves both the node-to-node RPC
/// channel and the local API channel from a single `run` loop.
struct Actor<P> {
    node: Node,
    /// Incoming node-to-node requests.
    rpc_rx: tokio::sync::mpsc::Receiver<RpcMessage>,
    /// Incoming local API requests.
    api_rx: tokio::sync::mpsc::Receiver<ApiMessage>,
    /// Background work spawned by the actor (lookups, puts/gets, periodic triggers).
    tasks: JoinSet<()>,
    state: State<P>,
    /// Recent FindNode requesters; `None` when no candidate lookup strategy is configured.
    candidates: Option<Candidates>,
    /// RNG for random lookup targets; seeded from `Config::rng_seed` when present.
    rng: rand::rngs::StdRng,
}
/// DHT node configuration. Start from [`Config::transient`] or
/// [`Config::persistent`] and add lookup strategies with the builder methods.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Config {
    /// Number of closest nodes tracked and returned by lookups (default `K`).
    k: usize,
    /// Peers queried per round of an iterative lookup (default `ALPHA`).
    alpha: usize,
    /// Maximum concurrent per-node RPCs during network put/get (default 4).
    parallelism: usize,
    /// When true, the node does not name itself as requester in FindNode and
    /// does not record requesters as candidates.
    transient: bool,
    /// Fixed seed for the actor's RNG; `None` seeds from OS entropy.
    rng_seed: Option<[u8; 32]>,
    /// Which periodic lookups to run.
    lookup_strategies: LookupStrategies,
}
/// Periodic lookup strategy settings referenced by [`Config`].
pub mod config {
    use super::*;

    /// Which periodic lookups are enabled; `None` disables that strategy.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    pub struct LookupStrategies {
        pub random: Option<RandomLookupStrategy>,
        pub self_id: Option<SelfLookupStrategy>,
        pub candidate: Option<CandidateLookupStrategy>,
    }

    impl LookupStrategies {
        /// All strategies disabled.
        pub fn none() -> Self {
            Self {
                random: None,
                self_id: None,
                candidate: None,
            }
        }
    }

    /// Periodically look up ids seeded with nodes that recently queried us.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
    pub struct CandidateLookupStrategy {
        /// The candidate queue holds `max_lookups * k` ids, looked up in groups of `k`.
        pub max_lookups: usize,
        /// Time between candidate lookup rounds.
        pub interval: Duration,
    }

    /// Periodically look up our own id.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
    pub struct SelfLookupStrategy {
        /// Time between self lookups.
        pub interval: Duration,
    }

    /// Periodically look up a random id.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
    pub struct RandomLookupStrategy {
        /// Time between random lookups.
        pub interval: Duration,
        /// Whether random targets are made by blending our own id with random
        /// bits (see `blend`) instead of being fully random.
        pub blended: bool,
    }
}
use config::*;
impl Config {
pub fn transient() -> Self {
Self {
transient: true,
..Default::default()
}
}
pub fn persistent() -> Self {
Self {
transient: false,
..Default::default()
}
}
pub fn candidate_lookup_strategy(mut self, value: CandidateLookupStrategy) -> Self {
self.lookup_strategies.candidate = Some(value);
self
}
pub fn random_lookup_strategy(mut self, value: RandomLookupStrategy) -> Self {
self.lookup_strategies.random = Some(value);
self
}
pub fn self_lookup_strategy(mut self, value: SelfLookupStrategy) -> Self {
self.lookup_strategies.self_id = Some(value);
self
}
}
impl Default for Config {
fn default() -> Self {
Self {
k: K,
alpha: ALPHA,
parallelism: 4,
transient: true,
lookup_strategies: LookupStrategies {
random: None,
self_id: None,
candidate: None,
},
rng_seed: None,
}
}
}
impl<P> Actor<P>
where
    P: ClientPool,
{
    /// Builds the actor and the [`ApiClient`] that talks to it.
    ///
    /// Also spawns `State::notify_self`, the periodic lookup driver, into the
    /// actor's task set.
    fn new(
        node: Node,
        rx: tokio::sync::mpsc::Receiver<RpcMessage>,
        pool: P,
        config: Config,
    ) -> (Self, ApiClient) {
        let (api_tx, internal_rx) = tokio::sync::mpsc::channel(32);
        let api = ApiClient(Arc::new(api_tx.into()));
        let mut tasks = JoinSet::new();
        let state = State {
            api: api.downgrade(),
            pool,
            config: config.clone(),
        };
        tasks.spawn(state.clone().notify_self());
        let candidates = config
            .lookup_strategies
            .candidate
            .as_ref()
            .map(|s| Candidates::new(s.max_lookups * config.k));
        // `unwrap_or_else`: only consult OS entropy when no fixed seed is set
        // (`unwrap_or` evaluated `from_entropy()` unconditionally).
        let rng = config
            .rng_seed
            .map(StdRng::from_seed)
            .unwrap_or_else(StdRng::from_entropy);
        (
            Self {
                node,
                rpc_rx: rx,
                api_rx: internal_rx,
                tasks,
                state,
                candidates,
                rng,
            },
            api,
        )
    }

    /// Serves RPC and API messages and reaps finished tasks. Exits when either
    /// request channel closes.
    async fn run(mut self) {
        loop {
            tokio::select! {
                msg = self.rpc_rx.recv() => {
                    if let Some(msg) = msg {
                        self.handle_rpc(msg).await;
                    } else {
                        break;
                    }
                }
                msg = self.api_rx.recv() => {
                    if let Some(msg) = msg {
                        self.handle_api(msg).await;
                    } else {
                        break;
                    }
                }
                Some(res) = self.tasks.join_next(), if !self.tasks.is_empty() => {
                    if let Err(e) = res {
                        error!("Task failed: {:?}", e);
                    }
                }
            }
        }
    }

    /// Handles one local API request. Network-facing work is spawned onto
    /// `self.tasks` so the actor loop is not blocked.
    async fn handle_api(&mut self, message: ApiMessage) {
        match message {
            ApiMessage::NodesSeen(msg) => {
                let now = now();
                for id in msg.ids.iter().copied() {
                    self.node
                        .routing_table
                        .add_node(NodeInfo { id, last_seen: now });
                }
            }
            ApiMessage::NodesDead(msg) => {
                for id in msg.ids.iter() {
                    self.node.routing_table.remove_node(id);
                }
            }
            ApiMessage::Lookup(msg) => {
                let k = self.state.config.k;
                let initial = msg
                    .initial
                    .clone()
                    .unwrap_or_else(|| self.node.routing_table.find_closest_nodes(&msg.id, k));
                self.tasks
                    .spawn(self.state.clone().lookup(initial, msg.inner, msg.tx));
            }
            ApiMessage::NetworkGet(msg) => {
                let initial = self
                    .node
                    .routing_table
                    .find_closest_nodes(&msg.id, self.state.config.k);
                self.tasks
                    .spawn(self.state.clone().network_get(initial, msg.inner, msg.tx));
            }
            ApiMessage::NetworkPut(msg) => {
                let initial = self
                    .node
                    .routing_table
                    .find_closest_nodes(&msg.id, self.state.config.k);
                self.tasks
                    .spawn(self.state.clone().network_put(initial, msg.inner, msg.tx));
            }
            ApiMessage::GetRoutingTable(msg) => {
                let table = self
                    .node
                    .routing_table
                    .buckets
                    .iter()
                    .map(|bucket| bucket.nodes().to_vec())
                    .collect();
                msg.tx.send(table).await.ok();
            }
            ApiMessage::GetStorageStats(msg) => {
                let mut stats = BTreeMap::new();
                for (key, kinds) in &self.node.storage.data {
                    let kind_stats = kinds
                        .iter()
                        .map(|(kind, values)| (*kind, values.len()))
                        .collect();
                    stats.insert(*key, kind_stats);
                }
                msg.tx.send(stats).await.ok();
            }
            ApiMessage::SelfLookup(msg) => {
                let id = self.state.pool.id().into();
                let api = self.state.api.clone();
                self.tasks.spawn(async move {
                    let Ok(api) = api.upgrade() else {
                        return;
                    };
                    api.lookup(id, None).await.ok();
                    msg.tx.send(()).await.ok();
                });
            }
            ApiMessage::RandomLookup(msg) => {
                // Honour the configured strategy (this used to be hard-coded to `false`).
                let blended = self
                    .state
                    .config
                    .lookup_strategies
                    .random
                    .as_ref()
                    .is_some_and(|s| s.blended);
                let id = if blended {
                    let bucket = self.rng.gen_range::<u32, _>(0..BUCKET_COUNT as u32 + 2);
                    // The routing table compares ids big-endian (byte 0 most
                    // significant) but `U256` is little-endian. Reverse on the way
                    // in and out so `bucket` lines up with routing-table bucket
                    // indices: `blend(this, r, n)` lands in bucket `n - 1`.
                    let mut this_bytes = *self.node.id().as_bytes();
                    this_bytes.reverse();
                    let this = U256::from_le_bytes(this_bytes);
                    let random = U256::from_le_bytes(self.rng.r#gen());
                    let mut target = blend(this, random, bucket).to_le_bytes();
                    target.reverse();
                    Id::from(target)
                } else {
                    Id::from(self.rng.r#gen::<[u8; 32]>())
                };
                let api = self.state.api.clone();
                self.tasks.spawn(async move {
                    let Ok(api) = api.upgrade() else {
                        return;
                    };
                    api.lookup(id, None).await.ok();
                    msg.tx.send(()).await.ok();
                });
            }
            ApiMessage::CandidateLookup(msg) => {
                if self.state.config.lookup_strategies.candidate.is_none() {
                    warn!(
                        "Received CandidateLookup request, but no candidate lookup strategy is configured"
                    );
                    return;
                };
                let Some(candidates) = self.candidates.as_mut() else {
                    warn!("Received CandidateLookup request, but no candidates are being tracked");
                    return;
                };
                let chosen = candidates.clear_and_take();
                let api = self.state.api.clone();
                // `chunks(0)` panics, so guard against a configured k of 0.
                let groups = chosen
                    .chunks(self.state.config.k.max(1))
                    .map(|chunk| {
                        let id = Id::from(self.rng.r#gen::<[u8; 32]>());
                        (id, chunk.to_vec())
                    })
                    .collect::<Vec<_>>();
                self.tasks.spawn(async move {
                    let Ok(api) = api.upgrade() else {
                        return;
                    };
                    for (id, ids) in groups {
                        api.lookup(id, Some(ids)).await.ok();
                    }
                    msg.tx.send(()).await.ok();
                });
            }
        }
    }

    /// Handles one node-to-node request inline on the actor loop.
    async fn handle_rpc(&mut self, message: RpcMessage) {
        match message {
            RpcMessage::Set(msg) => {
                let k = self.state.config.k;
                let closest = self.node.routing_table.find_closest_nodes(&msg.key, k);
                let self_dist = Distance::between(self.node.id().as_bytes(), &msg.key);
                // Reject when we already know k nodes that are all closer to the
                // *key* than we are. Previously each peer's distance to *us* was
                // compared with our distance to the key, which is meaningless.
                let too_far = closest.len() >= k
                    && closest
                        .iter()
                        .all(|id| Distance::between(&msg.key, id.as_bytes()) < self_dist);
                if too_far {
                    msg.tx.send(SetResponse::ErrDistance).await.ok();
                    return;
                }
                self.node.storage.set(msg.key, msg.value.clone());
                msg.tx.send(SetResponse::Ok).await.ok();
            }
            RpcMessage::GetAll(msg) => {
                let Some(values) = self.node.storage.get_all(&msg.key, &msg.kind) else {
                    return;
                };
                if let Some(seed) = msg.seed {
                    let mut rng = rand::rngs::StdRng::seed_from_u64(seed.get());
                    // `sample` panics when asked for more items than exist, and `n`
                    // comes from a remote peer, so clamp it to what we store.
                    let n = msg
                        .n
                        .map_or(values.len(), |x| {
                            usize::try_from(x.get()).unwrap_or(usize::MAX)
                        })
                        .min(values.len());
                    for i in sample(&mut rng, values.len(), n) {
                        if let Some(value) = values.get_index(i)
                            && msg.tx.send(value.clone()).await.is_err()
                        {
                            break;
                        }
                    }
                } else {
                    for value in values {
                        if msg.tx.send(value.clone()).await.is_err() {
                            break;
                        }
                    }
                }
            }
            RpcMessage::FindNode(msg) => {
                // `find_closest_nodes` already returns at most k ids.
                let addrs = self
                    .node
                    .routing_table
                    .find_closest_nodes(&msg.id, self.state.config.k)
                    .into_iter()
                    .map(|id| self.state.pool.node_addr(id))
                    .collect();
                if let Some(requester) = msg.requester {
                    self.add_candidate(requester);
                }
                msg.tx.send(addrs).await.ok();
            }
        }
    }

    /// Records a FindNode requester: refreshes it if already in the routing
    /// table and queues it for candidate lookups. Ignored on transient nodes.
    fn add_candidate(&mut self, id: NodeId) {
        if self.state.config.transient {
            warn!("Received FindNode request for transient node");
            return;
        }
        if self.node.routing_table.contains(&id) {
            self.node.routing_table.add_node(NodeInfo {
                id,
                last_seen: now(),
            });
        }
        let Some(candidates) = &mut self.candidates else {
            return;
        };
        candidates.add(id);
    }
}
impl<P: ClientPool> State<P> {
async fn lookup(self, initial: Vec<NodeId>, msg: Lookup, tx: oneshot::Sender<Vec<NodeId>>) {
let ids = self.clone().iterative_find_node(msg.id, initial).await;
tx.send(ids).await.ok();
}
async fn network_put(self, initial: Vec<NodeId>, msg: NetworkPut, tx: mpsc::Sender<NodeId>) {
let ids = self.clone().iterative_find_node(msg.id, initial).await;
stream::iter(ids)
.for_each_concurrent(self.config.parallelism, |id| {
let pool = self.pool.clone();
let value = msg.value.clone();
let tx = tx.clone();
async move {
let Ok(client) = pool.client(id).await else {
return;
};
if client.set(msg.id, value).await.is_ok() {
tx.send(id).await.ok();
}
drop(client);
}
})
.await;
}
async fn network_get(
self,
initial: Vec<NodeId>,
msg: NetworkGet,
tx: mpsc::Sender<(NodeId, Value)>,
) {
let ids = self.clone().iterative_find_node(msg.id, initial).await;
stream::iter(ids)
.for_each_concurrent(self.config.parallelism, |id| {
let pool = self.pool.clone();
let tx = tx.clone();
let msg = NetworkGet {
id: msg.id,
kind: msg.kind,
seed: msg.seed,
n: msg.n,
};
async move {
let Ok(client) = pool.client(id).await else {
return;
};
let Ok(mut rx) = client.get_all(msg.id, msg.kind, msg.seed, msg.n).await else {
return;
};
while let Ok(Some(value)) = rx.recv().await {
if tx.send((id, value)).await.is_err() {
break;
}
}
drop(client);
}
})
.await;
}
async fn query_one(&self, id: NodeId, target: Id) -> Result<Vec<NodeId>, &'static str> {
let requester = if self.config.transient {
None
} else {
Some(self.pool.id())
};
let client = self
.pool
.client(id)
.await
.map_err(|_| "Error getting client")?;
let infos = client
.find_node(target, requester)
.await
.map_err(|_| "Failed to query node");
if let Err(e) = &infos {
info!(%id, "Failed to query node: {e}");
return Err("Failed to query node");
}
let infos = infos?;
drop(client);
let ids = infos.iter().map(|info| info.node_id).collect();
for info in infos {
self.pool.add_node_addr(info);
}
Ok(ids)
}
async fn iterative_find_node(self, target: Id, initial: Vec<NodeId>) -> Vec<NodeId> {
let mut candidates = initial
.into_iter()
.filter(|addr| *addr != self.pool.id())
.map(|id| (Distance::between(&target, id.as_bytes()), id))
.collect::<BTreeSet<_>>();
let mut queried = HashSet::new();
let mut tasks = FuturesUnordered::new();
let mut result = BTreeSet::new();
queried.insert(self.pool.id());
result.insert((
Distance::between(self.pool.id().as_bytes(), &target),
self.pool.id(),
));
loop {
for _ in 0..self.config.alpha {
let Some(pair @ (_, id)) = candidates.pop_first() else {
break;
};
queried.insert(id);
let fut = self.query_one(id, target);
tasks.push(async move { (pair, fut.await) });
}
while let Some((pair @ (_, id), cands)) = tasks.next().await {
let Ok(cands) = cands else {
self.api.nodes_dead(&[id]).await.ok();
continue;
};
for cand in cands {
let dist = Distance::between(&target, cand.as_bytes());
if !queried.contains(&cand) {
candidates.insert((dist, cand));
}
}
self.api.nodes_seen(&[id]).await.ok();
result.insert(pair);
}
while result.len() > self.config.k {
result.pop_last();
}
let kth_best_distance = result
.iter()
.nth(self.config.k - 1)
.map(|(dist, _)| *dist)
.unwrap_or(Distance::MAX);
let has_closer_candidates = candidates
.first()
.map(|(dist, _)| *dist < kth_best_distance)
.unwrap_or_default();
if !has_closer_candidates {
break;
}
}
result.into_iter().map(|(_, id)| id).collect()
}
async fn notify_self(self) {
let mut self_lookup = MaybeFuture::None;
let mut random_lookup = MaybeFuture::None;
let mut candidate_lookup = MaybeFuture::None;
if let Some(strategy) = &self.config.lookup_strategies.self_id {
let api = self.api.clone();
self_lookup = MaybeFuture::Some(api.self_lookup_periodic(strategy.interval));
}
if let Some(strategy) = &self.config.lookup_strategies.random {
let api = self.api.clone();
random_lookup = MaybeFuture::Some(api.random_lookup_periodic(strategy.interval));
}
if let Some(strategy) = &self.config.lookup_strategies.candidate {
let api = self.api.clone();
candidate_lookup = MaybeFuture::Some(api.candidate_lookup_periodic(strategy.interval));
}
tokio::pin!(self_lookup, random_lookup, candidate_lookup);
loop {
tokio::select! {
_ = &mut self_lookup => {
}
_ = &mut random_lookup => {
}
_ = &mut candidate_lookup => {
}
}
}
}
}
/// Builds a value that keeps `a`'s bits at positions `n..256`, flips `a`'s
/// bit `n - 1`, and takes bits `0..n-1` from `b`.
///
/// Bit positions use `U256`'s little-endian numbering (bit 0 is the least
/// significant bit of byte 0). `n == 0` returns `a`; `n >= 256` returns `b`.
fn blend(a: U256, b: U256, n: u32) -> U256 {
    if n >= 256 {
        return b;
    }
    if n == 0 {
        // Nothing below bit 0 to flip or take from `b`.
        return a;
    }
    let a_mask = U256::MAX << n; // bits n..256 come from `a`
    let b_mask = (!a_mask) >> 1; // bits 0..n-1 come from `b`
    let xor_mask = !(a_mask | b_mask); // bit n-1 only: `a`'s bit, flipped
    // Explicit grouping: `&` binds tighter than `^`, which binds tighter than
    // `|`. The old `a ^ xor_mask` term therefore leaked all of `a`'s low bits
    // into the result and ORed them with `b`'s, biasing them toward 1.
    (a & a_mask) | (!a & xor_mask) | (b & b_mask)
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reads earlier than the Unix epoch.
fn now() -> u64 {
    let since_epoch = UNIX_EPOCH.elapsed().unwrap();
    since_epoch.as_secs()
}
pub fn create_node<P: ClientPool>(
id: NodeId,
pool: P,
bootstrap: Vec<NodeId>,
config: Config,
) -> (RpcClient, ApiClient) {
create_node_impl(id, pool, bootstrap, None, config)
}
fn create_node_impl<P: ClientPool>(
id: NodeId,
pool: P,
bootstrap: Vec<NodeId>,
buckets: Option<Box<Buckets>>,
config: Config,
) -> (RpcClient, ApiClient) {
let mut node = Node {
routing_table: RoutingTable::new(id, buckets),
storage: MemStorage::new(),
};
for bootstrap_id in bootstrap {
if bootstrap_id != id {
node.routing_table.add_node(NodeInfo {
id: bootstrap_id,
last_seen: now(),
});
}
}
let (tx, rx) = tokio::sync::mpsc::channel(32);
let (actor, api) = Actor::<P>::new(node, rx, pool, config);
tokio::spawn(actor.run());
(RpcClient::new(irpc::Client::local(tx)), api)
}