use std::collections::{BTreeMap, HashSet};
use std::fmt::Debug;
use std::io::Write;
use std::marker::PhantomData;
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use rand::{RngCore, rng};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::address_book::{AddressBookStore, NodeInfo};
use crate::traits::{DiscoveryProtocol, DiscoveryResult, LocalTopics};
/// Last byte of the combined salt used for the topic hashes Alice sends to Bob
/// (`topics_for_bob`). It differs from [`BOB_SALT_BYTE`], so the two directions
/// of the exchange are hashed with different salts.
const ALICE_SALT_BYTE: u8 = 0;
/// Last byte of the combined salt used for the topic hashes Bob sends to Alice
/// (`topics_for_alice`).
const BOB_SALT_BYTE: u8 = 1;
/// Messages exchanged between the two roles ("Alice" and "Bob") of the protocol.
///
/// The variants are listed in the order in which they are sent during one run:
/// Alice sends `AliceSaltHalf`, Bob answers with `BobSaltHalfAndHashedData`, Alice
/// sends `AliceHashedData`, and finally each side sends one `Nodes` message.
#[derive(Serialize, Deserialize)]
pub enum PsiHashMessage<ID, N>
where
N: NodeInfo<ID>,
for<'a> N::Transports: Serialize + Deserialize<'a>,
ID: Ord,
{
// First message, sent by Alice: her randomly generated half of the salt.
AliceSaltHalf { alice_salt_half: [u8; 32] },
// Sent by Bob: his randomly generated salt half plus the hashes of his topics.
BobSaltHalfAndHashedData {
bob_salt_half: [u8; 32],
topics_for_alice: HashSet<[u8; 32]>,
},
// Sent by Alice: the hashes of her topics.
AliceHashedData { topics_for_bob: HashSet<[u8; 32]> },
// Sent once by each side: transport information keyed by node id.
Nodes {
transport_infos: BTreeMap<ID, N::Transports>,
},
}
/// Options controlling which node information is handed to the remote peer.
#[derive(Clone, Debug, Default)]
pub struct Config {
/// When `true`, only nodes the store returns for the topics found in the
/// intersection are shared (plus our own node info when the store has it).
/// When `false` (the derived default), every node info in the store is shared.
pub share_nodes_with_common_topics: bool,
}
/// State needed to run one side of the hashed topic-intersection exchange with a
/// single remote node.
pub struct PsiHashDiscoveryProtocol<S, P, ID, N> {
// Address book queried for the node infos that get shared with the remote.
store: S,
// Source of the local topics (read through `topics()`).
subscription: P,
// Our own node id; used to make sure our own info is part of what we share.
my_node_id: ID,
// Id of the peer; copied into the returned `DiscoveryResult`.
remote_node_id: ID,
// Behaviour switches, see `Config`.
config: Config,
// Ties the otherwise unused node-info type parameter `N` to the struct.
_marker: PhantomData<N>,
}
impl<S, P, ID, N> PsiHashDiscoveryProtocol<S, P, ID, N> {
    /// Creates a protocol instance that uses [`Config::default`].
    pub fn new(store: S, subscription: P, my_node_id: ID, remote_node_id: ID) -> Self {
        let config = Config::default();
        Self::with_config(store, subscription, my_node_id, remote_node_id, config)
    }

    /// Creates a protocol instance with an explicit [`Config`].
    pub fn with_config(
        store: S,
        subscription: P,
        my_node_id: ID,
        remote_node_id: ID,
        config: Config,
    ) -> Self {
        Self {
            _marker: PhantomData,
            config,
            remote_node_id,
            my_node_id,
            subscription,
            store,
        }
    }

    /// Collects the transport information that is sent to the remote peer.
    ///
    /// With `share_nodes_with_common_topics` enabled only the nodes the store
    /// returns for `topics` are considered, and our own node info is added when it
    /// is not among them and the store knows it. Otherwise all node infos in the
    /// store are considered. Nodes without transport information are left out.
    ///
    /// # Errors
    ///
    /// Returns [`PsiHashError::Store`] when any store query fails.
    async fn gather_transport_infos(
        &self,
        topics: Vec<[u8; 32]>,
    ) -> Result<BTreeMap<ID, N::Transports>, PsiHashError<S, P, ID, N>>
    where
        S: AddressBookStore<ID, N>,
        P: LocalTopics,
        ID: PartialEq + Ord,
        N: NodeInfo<ID>,
    {
        let node_infos = if !self.config.share_nodes_with_common_topics {
            self.store
                .all_node_infos()
                .await
                .map_err(PsiHashError::Store)?
        } else {
            let mut candidates = self
                .store
                .node_infos_by_topics(&topics)
                .await
                .map_err(PsiHashError::Store)?;

            // Our own info is only looked up when the topic query did not
            // already return it.
            let already_listed = candidates.iter().any(|info| info.id() == self.my_node_id);
            if !already_listed {
                let own_info = self
                    .store
                    .node_info(&self.my_node_id)
                    .await
                    .map_err(PsiHashError::Store)?;
                if let Some(own_info) = own_info {
                    candidates.extend([own_info]);
                }
            }
            candidates
        };

        let mut transport_infos = BTreeMap::new();
        transport_infos.extend(node_infos.into_iter().filter_map(|info| {
            info.transports()
                .map(|transports| (info.id(), transports))
        }));
        Ok(transport_infos)
    }
}
impl<S, P, ID, N> DiscoveryProtocol<ID, N> for PsiHashDiscoveryProtocol<S, P, ID, N>
where
S: AddressBookStore<ID, N>,
P: LocalTopics,
ID: Clone + Ord,
N: NodeInfo<ID>,
for<'a> N::Transports: Serialize + Deserialize<'a>,
{
type Error = PsiHashError<S, P, ID, N>;
type Message = PsiHashMessage<ID, N>;
async fn alice(
&self,
tx: &mut (impl Sink<Self::Message, Error = impl Debug> + Unpin),
rx: &mut (impl Stream<Item = Result<Self::Message, impl Debug>> + Unpin),
) -> Result<DiscoveryResult<ID, N>, Self::Error> {
let alice_salt_half = generate_salt_half();
let message_1 = PsiHashMessage::AliceSaltHalf { alice_salt_half };
tx.send(message_1).await.map_err(|_| PsiHashError::Sink)?;
let message_2 = match rx.next().await {
Some(val) => val.map_err(|_| PsiHashError::Stream)?,
None => {
return Err(PsiHashError::Stream);
}
};
let PsiHashMessage::BobSaltHalfAndHashedData {
bob_salt_half,
topics_for_alice,
} = message_2
else {
return Err(PsiHashError::UnexpectedMessage);
};
let my_topics: Vec<[u8; 32]> = self
.subscription
.topics()
.await
.map_err(PsiHashError::Subscription)?
.into_iter()
.collect();
let alice_final_salt = combine_salt(&alice_salt_half, &bob_salt_half, &ALICE_SALT_BYTE);
let bob_final_salt = combine_salt(&alice_salt_half, &bob_salt_half, &BOB_SALT_BYTE);
let topics_intersection =
compute_intersection(&my_topics, &topics_for_alice, &bob_final_salt)?;
let topics_for_bob: HashSet<[u8; 32]> =
HashSet::from_iter(hash_vector(&my_topics, &alice_final_salt)?.into_iter());
tx.send(PsiHashMessage::AliceHashedData { topics_for_bob })
.await
.map_err(|_| PsiHashError::Sink)?;
let message_4 = match rx.next().await {
Some(val) => val.map_err(|_| PsiHashError::Stream)?,
None => {
return Err(PsiHashError::Stream);
}
};
let PsiHashMessage::Nodes { transport_infos } = message_4 else {
return Err(PsiHashError::UnexpectedMessage);
};
tx.send(PsiHashMessage::Nodes {
transport_infos: self
.gather_transport_infos(topics_intersection.clone().into_iter().collect::<Vec<_>>())
.await?,
})
.await
.map_err(|_| PsiHashError::Sink)?;
Ok(DiscoveryResult {
remote_node_id: self.remote_node_id.clone(),
transport_infos,
topics: topics_intersection,
})
}
async fn bob(
&self,
tx: &mut (impl Sink<Self::Message, Error = impl Debug> + Unpin),
rx: &mut (impl Stream<Item = Result<Self::Message, impl Debug>> + Unpin),
) -> Result<DiscoveryResult<ID, N>, Self::Error> {
let Some(Ok(message_1)) = rx.next().await else {
return Err(PsiHashError::Stream);
};
let PsiHashMessage::AliceSaltHalf { alice_salt_half } = message_1 else {
return Err(PsiHashError::UnexpectedMessage);
};
let bob_salt_half = generate_salt_half();
let alice_final_salt = combine_salt(&alice_salt_half, &bob_salt_half, &ALICE_SALT_BYTE);
let bob_final_salt = combine_salt(&alice_salt_half, &bob_salt_half, &BOB_SALT_BYTE);
let my_topics: Vec<[u8; 32]> = self
.subscription
.topics()
.await
.map_err(PsiHashError::Subscription)?
.into_iter()
.collect();
let topics_for_alice: HashSet<[u8; 32]> =
HashSet::from_iter(hash_vector(&my_topics, &bob_final_salt)?.into_iter());
tx.send(PsiHashMessage::BobSaltHalfAndHashedData {
bob_salt_half,
topics_for_alice,
})
.await
.map_err(|_| PsiHashError::Sink)?;
let Some(Ok(message_3)) = rx.next().await else {
return Err(PsiHashError::Stream);
};
let PsiHashMessage::AliceHashedData { topics_for_bob } = message_3 else {
return Err(PsiHashError::UnexpectedMessage);
};
let topics_intersection =
compute_intersection(&my_topics, &topics_for_bob, &alice_final_salt)?;
tx.send(PsiHashMessage::Nodes {
transport_infos: self
.gather_transport_infos(topics_intersection.clone().into_iter().collect::<Vec<_>>())
.await?,
})
.await
.map_err(|_| PsiHashError::Sink)?;
let Some(Ok(message_5)) = rx.next().await else {
return Err(PsiHashError::Stream);
};
let PsiHashMessage::Nodes { transport_infos } = message_5 else {
return Err(PsiHashError::UnexpectedMessage);
};
Ok(DiscoveryResult {
remote_node_id: self.remote_node_id.clone(),
transport_infos,
topics: topics_intersection,
})
}
}
/// Returns the local topics whose salted hash is contained in `remote_hashes`.
///
/// Every entry of `local_topics` is hashed with `salt`; a topic is part of the
/// result when its hash appears in the remote set. The result holds the plain
/// (unhashed) topics.
///
/// # Errors
///
/// Propagates the I/O error reported while hashing.
fn compute_intersection(
    local_topics: &[[u8; 32]],
    remote_hashes: &HashSet<[u8; 32]>,
    salt: &[u8; 65],
) -> Result<HashSet<[u8; 32]>, std::io::Error> {
    let hashed_local = hash_vector(local_topics, salt)?;

    // `hashed_local` has one entry per local topic, in the same order, so the
    // two sequences can be walked in lockstep.
    let intersection: HashSet<[u8; 32]> = local_topics
        .iter()
        .zip(&hashed_local)
        .filter(|(_, hashed)| remote_hashes.contains(*hashed))
        .map(|(topic, _)| *topic)
        .collect();

    Ok(intersection)
}
/// Hashes every topic with `salt`, keeping the input order.
///
/// # Errors
///
/// Stops at and returns the first hashing error.
fn hash_vector(topics: &[[u8; 32]], salt: &[u8; 65]) -> Result<Vec<[u8; 32]>, std::io::Error> {
    let mut hashed = Vec::with_capacity(topics.len());
    for topic in topics {
        hashed.push(hash(topic, salt)?);
    }
    Ok(hashed)
}
/// Computes the BLAKE3 hash of `data` followed by `salt`.
///
/// # Errors
///
/// Returns the I/O error produced when writing into the hasher.
fn hash(data: &[u8; 32], salt: &[u8; 65]) -> Result<[u8; 32], std::io::Error> {
    let mut hasher = blake3::Hasher::new();
    // The order matters: the data is fed first, then the salt.
    for chunk in [data.as_slice(), salt.as_slice()] {
        hasher.write_all(chunk)?;
    }
    let digest = hasher.finalize();
    Ok(*digest.as_bytes())
}
/// Produces 32 random bytes from the generator returned by `rng()`; each side
/// contributes one such half to the combined salt.
fn generate_salt_half() -> [u8; 32] {
    let mut salt_half = [0u8; 32];
    rng().fill_bytes(&mut salt_half);
    salt_half
}
/// Builds the 65-byte salt: Alice's half (bytes 0..32), Bob's half (bytes 32..64)
/// and `pair_byte` as the final byte.
fn combine_salt(alice_salt_half: &[u8; 32], bob_salt_half: &[u8; 32], pair_byte: &u8) -> [u8; 65] {
    let mut combined = [0u8; 65];
    // 32 + 32 + 1 source bytes fill the 65 slots exactly.
    let source = alice_salt_half
        .iter()
        .chain(bob_salt_half)
        .chain(std::iter::once(pair_byte));
    for (slot, byte) in combined.iter_mut().zip(source) {
        *slot = *byte;
    }
    combined
}
/// Errors that can end a run of the protocol.
#[derive(Debug, Error)]
pub enum PsiHashError<S, P, ID, N>
where
S: AddressBookStore<ID, N>,
P: LocalTopics,
{
/// A query against the address book store failed.
#[error("{0}")]
Store(S::Error),
/// Reading the local topics from the subscription failed.
#[error("{0}")]
Subscription(P::Error),
/// A message arrived that is not the variant expected at this step.
#[error("received unexpected message")]
UnexpectedMessage,
/// The incoming stream ended or yielded an error item.
#[error("stream closed unexpectedly")]
Stream,
/// Sending a message into the sink failed.
#[error("sink closed unexpectedly")]
Sink,
/// Writing into the hasher reported an I/O error.
#[error(transparent)]
Hash(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use futures_channel::mpsc;
    use futures_util::{SinkExt, StreamExt};
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    use crate::address_book::AddressBookStore;
    use crate::test_utils::{TestInfo, TestStore, TestSubscription};
    use crate::traits::DiscoveryProtocol;

    use super::{Config, PsiHashDiscoveryProtocol, PsiHashError, PsiHashMessage};

    /// Both sides end up with exactly the topics they have in common.
    #[tokio::test]
    async fn topic_discovery() {
        let rng = ChaCha20Rng::from_seed([1; 32]);

        let mut alice_subscription = TestSubscription::default();
        for topic in [[1; 32], [2; 32], [98; 32], [99; 32]] {
            alice_subscription.topics.insert(topic);
        }
        let alice_store = TestStore::new(rng.clone());

        let mut bob_subscription = TestSubscription::default();
        for topic in [[2; 32], [3; 32], [99; 32], [100; 32]] {
            bob_subscription.topics.insert(topic);
        }
        let bob_store = TestStore::new(rng.clone());

        let alice_protocol = PsiHashDiscoveryProtocol::new(alice_store, alice_subscription, 0, 1);
        let bob_protocol = PsiHashDiscoveryProtocol::new(bob_store, bob_subscription, 1, 0);

        // One channel per direction: what Alice sends, Bob receives, and vice versa.
        let (mut alice_tx, alice_rx) = mpsc::channel(16);
        let (mut bob_tx, bob_rx) = mpsc::channel(16);

        let bob_handle = tokio::task::spawn(async move {
            let mut alice_rx = alice_rx.map(Ok::<_, ()>);
            match bob_protocol.bob(&mut bob_tx, &mut alice_rx).await {
                Ok(result) => result,
                Err(_) => panic!("running bob protocol failed"),
            }
        });

        let mut bob_rx = bob_rx.map(Ok::<_, ()>);
        let alice_result = match alice_protocol.alice(&mut alice_tx, &mut bob_rx).await {
            Ok(result) => result,
            Err(_) => panic!("running alice protocol failed"),
        };
        let bob_result = bob_handle.await.expect("local task failure");

        let expected = HashSet::from_iter([[2; 32], [99; 32]]);
        assert_eq!(alice_result.topics, expected);
        assert_eq!(bob_result.topics, expected);
    }

    /// Alice rejects a reply that is not Bob's salt half and hashed data.
    #[tokio::test]
    async fn topic_out_of_order_alice() {
        let rng = ChaCha20Rng::from_seed([1; 32]);

        let mut alice_subscription = TestSubscription::default();
        for topic in [[1; 32], [99; 32]] {
            alice_subscription.topics.insert(topic);
        }
        let alice_store = TestStore::new(rng.clone());
        let alice_protocol = PsiHashDiscoveryProtocol::new(alice_store, alice_subscription, 0, 1);

        let (mut alice_tx, _alice_rx) = mpsc::channel(16);
        let (mut bob_tx, bob_rx) = mpsc::channel(16);

        // The fake Bob answers with a message only Alice is supposed to send.
        let bob_handle = tokio::task::spawn(async move {
            let wrong_message = PsiHashMessage::AliceSaltHalf {
                alice_salt_half: [0; 32],
            };
            let _result = bob_tx.send(wrong_message).await;
        });

        let mut bob_rx = bob_rx.map(Ok::<_, ()>);
        let alice_result = alice_protocol.alice(&mut alice_tx, &mut bob_rx).await;
        let _bob_result = bob_handle.await;

        assert!(matches!(alice_result, Err(PsiHashError::UnexpectedMessage)));
    }

    /// Bob rejects a first message that is not Alice's salt half.
    #[tokio::test]
    async fn topic_out_of_order_bob() {
        let rng = ChaCha20Rng::from_seed([1; 32]);

        let mut bob_subscription = TestSubscription::default();
        for topic in [[1; 32], [99; 32]] {
            bob_subscription.topics.insert(topic);
        }
        let bob_store = TestStore::new(rng.clone());
        let bob_protocol = PsiHashDiscoveryProtocol::new(bob_store, bob_subscription, 0, 1);

        let (mut bob_tx, _) = mpsc::channel(16);
        let (mut alice_tx, alice_rx) = mpsc::channel(16);

        // The fake Alice skips her salt half and starts with her hashed data.
        tokio::task::spawn(async move {
            let wrong_message = PsiHashMessage::AliceHashedData {
                topics_for_bob: HashSet::new(),
            };
            let _result = alice_tx.send(wrong_message).await;
        });

        let mut alice_rx = alice_rx.map(Ok::<_, ()>);
        let bob_result = bob_protocol.bob(&mut bob_tx, &mut alice_rx).await;

        assert!(matches!(bob_result, Err(PsiHashError::UnexpectedMessage)));
    }

    /// With `share_nodes_with_common_topics` set, only nodes for common topics
    /// (and the sender itself) are exchanged.
    #[tokio::test]
    async fn transport_info() {
        let mut rng = ChaCha20Rng::from_seed([1; 32]);

        // Alice knows only herself (node 0), interested in topic 1.
        let mut alice_subscription = TestSubscription::default();
        alice_subscription.topics.insert([1; 32]);
        let alice_store = TestStore::new(rng.clone());
        alice_store
            .insert_node_info(TestInfo::new(0).with_random_address(&mut rng))
            .await
            .unwrap();
        alice_store
            .set_topics(0, HashSet::from_iter([[1; 32]]))
            .await
            .unwrap();

        // Bob knows himself (node 1, topics 1 and 2), node 2 (topic 1) and
        // node 3 (topic 2).
        let mut bob_subscription = TestSubscription::default();
        for topic in [[1; 32], [2; 32]] {
            bob_subscription.topics.insert(topic);
        }
        let bob_store = TestStore::new(rng.clone());
        bob_store
            .insert_node_info(TestInfo::new(1).with_random_address(&mut rng))
            .await
            .unwrap();
        bob_store
            .set_topics(1, HashSet::from_iter([[1; 32], [2; 32]]))
            .await
            .unwrap();
        bob_store
            .insert_node_info(TestInfo::new(2).with_random_address(&mut rng))
            .await
            .unwrap();
        bob_store
            .set_topics(2, HashSet::from_iter([[1; 32]]))
            .await
            .unwrap();
        bob_store
            .insert_node_info(TestInfo::new(3).with_random_address(&mut rng))
            .await
            .unwrap();
        bob_store
            .set_topics(3, HashSet::from_iter([[2; 32]]))
            .await
            .unwrap();

        let config = Config {
            share_nodes_with_common_topics: true,
        };
        let alice_protocol = PsiHashDiscoveryProtocol::with_config(
            alice_store,
            alice_subscription,
            0,
            1,
            config.clone(),
        );
        let bob_protocol =
            PsiHashDiscoveryProtocol::with_config(bob_store, bob_subscription, 1, 0, config);

        let (mut alice_tx, alice_rx) = mpsc::channel(16);
        let (mut bob_tx, bob_rx) = mpsc::channel(16);

        let bob_handle = tokio::task::spawn(async move {
            let mut alice_rx = alice_rx.map(Ok::<_, ()>);
            match bob_protocol.bob(&mut bob_tx, &mut alice_rx).await {
                Ok(result) => result,
                Err(_) => panic!("running bob protocol failed"),
            }
        });

        let mut bob_rx = bob_rx.map(Ok::<_, ()>);
        let alice_result = match alice_protocol.alice(&mut alice_tx, &mut bob_rx).await {
            Ok(result) => result,
            Err(_) => panic!("running alice protocol failed"),
        };
        let bob_result = bob_handle.await.expect("local task failure");

        // Alice learns about nodes 1 and 2 (topic 1) but not node 3 (topic 2 only).
        assert!(alice_result.transport_infos.contains_key(&1));
        assert!(alice_result.transport_infos.contains_key(&2));
        assert_eq!(alice_result.transport_infos.len(), 2);
        assert!(!alice_result.transport_infos.contains_key(&3));

        // Bob learns about Alice only.
        assert!(bob_result.transport_infos.contains_key(&0));
        assert_eq!(bob_result.transport_infos.len(), 1);
    }
}