use crate::error::{Error, Result};
use crate::XorName;
use bytes::Bytes;
use futures::stream::{FuturesUnordered, StreamExt};
use hex_fmt::HexFmt;
use qp2p::{Endpoint, QuicP2p};
use sn_messaging::MessageType;
use std::{
fmt::{self, Debug, Formatter},
net::SocketAddr,
sync::RwLock,
};
use tokio::{sync::mpsc, task};
/// Network front-end wrapping a single qp2p endpoint.
///
/// Messages received on the endpoint and disconnection notifications are forwarded as
/// [`ConnectionEvent`]s to the channel supplied when the `Comm` is created.
pub(crate) struct Comm {
    /// Never read (leading underscore); held alongside the endpoint it created.
    /// NOTE(review): presumably required to keep the transport alive — confirm.
    _quic_p2p: QuicP2p,
    /// Endpoint used for all sends, connects and address queries.
    endpoint: Endpoint,
    /// Sender given at construction; cleared by `terminate`. The background forwarding
    /// tasks hold their own clones, so clearing this does not close the channel.
    event_tx: RwLock<Option<mpsc::Sender<ConnectionEvent>>>,
}
impl Comm {
    /// Creates a `Comm` on a new qp2p endpoint without bootstrapping to any peer.
    ///
    /// Background tasks forward two things to `event_tx` as [`ConnectionEvent`]s:
    /// messages received on the endpoint, and disconnection notifications.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidConfig`] if qp2p rejects `transport_config`.
    /// * [`Error::CannotConnectEndpoint`] if the endpoint cannot be created.
    pub async fn new(
        transport_config: qp2p::Config,
        event_tx: mpsc::Sender<ConnectionEvent>,
    ) -> Result<Self> {
        let quic_p2p = QuicP2p::with_config(Some(transport_config), &[], true)
            .map_err(|err| Error::InvalidConfig { err })?;

        // The incoming-connections stream is not consumed; it is dropped when this
        // function returns.
        let (endpoint, _incoming_connections, incoming_messages, disconnections) = quic_p2p
            .new_endpoint()
            .await
            .map_err(|err| Error::CannotConnectEndpoint { err })?;

        Self::spawn_event_forwarders(incoming_messages, disconnections, &event_tx);

        Ok(Self {
            _quic_p2p: quic_p2p,
            endpoint,
            event_tx: RwLock::new(Some(event_tx)),
        })
    }

    /// Creates a `Comm` by bootstrapping a qp2p endpoint.
    ///
    /// Returns the new `Comm` together with the address reported by
    /// `QuicP2p::bootstrap` (presumably the peer that was bootstrapped to).
    /// Events are forwarded to `event_tx` exactly as in [`Comm::new`].
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidConfig`] if qp2p rejects `transport_config`.
    /// * [`Error::CannotConnectEndpoint`] if bootstrapping fails.
    pub async fn bootstrap(
        transport_config: qp2p::Config,
        event_tx: mpsc::Sender<ConnectionEvent>,
    ) -> Result<(Self, SocketAddr)> {
        let quic_p2p = QuicP2p::with_config(Some(transport_config), &[], true)
            .map_err(|err| Error::InvalidConfig { err })?;

        // As in `new`, the incoming-connections stream is not consumed.
        let (endpoint, _incoming_connections, incoming_messages, disconnections, bootstrap_addr) =
            quic_p2p
                .bootstrap()
                .await
                .map_err(|err| Error::CannotConnectEndpoint { err })?;

        Self::spawn_event_forwarders(incoming_messages, disconnections, &event_tx);

        Ok((
            Self {
                _quic_p2p: quic_p2p,
                endpoint,
                event_tx: RwLock::new(Some(event_tx)),
            },
            bootstrap_addr,
        ))
    }

    /// Closes the endpoint and drops the event sender held by this `Comm`.
    ///
    /// The background forwarding tasks own their own clones of the sender, so this
    /// does not by itself close the event channel. A poisoned lock is recovered
    /// rather than causing a panic.
    pub fn terminate(&self) {
        self.endpoint.close();
        let _ = self
            .event_tx
            .write()
            .unwrap_or_else(|err| err.into_inner())
            .take();
    }

    /// Returns the socket address of our endpoint.
    pub fn our_connection_info(&self) -> SocketAddr {
        self.endpoint.socket_addr()
    }

    /// Sends `msg` to a single `recipient` (name, address) over the endpoint.
    ///
    /// The message's destination name is set to the recipient's name before
    /// serialization. Unlike [`Comm::send`], no reconnect is attempted on failure.
    ///
    /// # Errors
    ///
    /// Serialization errors are propagated. A failed send is logged and returned as
    /// [`Error::FailedSend`].
    pub async fn send_on_existing_connection(
        &self,
        recipient: (XorName, SocketAddr),
        mut msg: MessageType,
    ) -> Result<(), Error> {
        msg.update_dest_info(None, Some(recipient.0));
        let bytes = msg.serialize()?;
        self.endpoint
            .send_message(bytes, &recipient.1)
            .await
            .map_err(|err| {
                error!("Sending to {:?} failed with {}", recipient, err);
                Error::FailedSend(recipient.1, recipient.0)
            })?;
        Ok(())
    }

    /// Checks whether `peer` is reachable from a fresh, temporary endpoint.
    ///
    /// The temporary endpoint is bound to our local IP on an OS-assigned port
    /// (`local_port: Some(0)`), with port forwarding disabled.
    ///
    /// # Errors
    ///
    /// * [`Error::AddressNotReachable`] if the reachability check fails.
    /// * [`Error::InvalidConfig`] / [`Error::CannotConnectEndpoint`] if the temporary
    ///   endpoint cannot be created.
    pub async fn is_reachable(&self, peer: &SocketAddr) -> Result<(), Error> {
        let qp2p_config = qp2p::Config {
            local_ip: Some(self.endpoint.local_addr().ip()),
            local_port: Some(0),
            forward_port: false,
            ..Default::default()
        };
        let qp2p = QuicP2p::with_config(Some(qp2p_config), &[], false)
            .map_err(|err| Error::InvalidConfig { err })?;
        let (connectivity_endpoint, _, _, _) = qp2p
            .new_endpoint()
            .await
            .map_err(|err| Error::CannotConnectEndpoint { err })?;
        connectivity_endpoint
            .is_reachable(peer)
            .await
            .map_err(|err| {
                info!("Peer {} is NOT externally reachable: {}", peer, err);
                Error::AddressNotReachable { err }
            })
            .map(|()| {
                info!("Peer {} is externally reachable.", peer);
            })
    }

    /// Sends `msg` to up to `delivery_group_size` of `recipients`.
    ///
    /// `delivery_group_size` is clamped to `recipients.len()`. The first that-many
    /// recipients are tried concurrently. Each failed recipient is replaced by the
    /// next untried one in `recipients`, if any remain.
    ///
    /// The message's destination name is set to the *first* recipient's name, and the
    /// same serialized bytes are sent to every recipient.
    ///
    /// Returns [`SendStatus::AllRecipients`] when the clamped group size was reached
    /// with no failures. Returns [`SendStatus::MinDeliveryGroupSizeReached`] when it
    /// was reached despite some failures. Returns
    /// [`SendStatus::MinDeliveryGroupSizeFailed`] otherwise.
    ///
    /// # Errors
    ///
    /// * [`Error::EmptyRecipientList`] if `recipients` is empty.
    /// * [`Error::Messaging`] if the message cannot be serialized.
    /// * [`Error::ConnectionClosed`] as soon as any send reports the connection was
    ///   closed locally; the remaining in-flight sends are dropped.
    pub async fn send(
        &self,
        recipients: &[(XorName, SocketAddr)],
        delivery_group_size: usize,
        mut msg: MessageType,
    ) -> Result<SendStatus> {
        trace!(
            "Sending message to {} of {:?}",
            delivery_group_size,
            recipients
        );

        // Reject an empty list before warning about it being too short: the warning
        // would be redundant noise right before the error.
        if recipients.is_empty() {
            return Err(Error::EmptyRecipientList);
        }

        if recipients.len() < delivery_group_size {
            warn!(
                "Less than delivery_group_size valid recipients - delivery_group_size: {}, recipients: {:?}",
                delivery_group_size,
                recipients,
            );
        }

        let delivery_group_size = delivery_group_size.min(recipients.len());

        msg.update_dest_info(None, Some(recipients[0].0));
        let msg_bytes = msg.serialize().map_err(Error::Messaging)?;

        // One send attempt; yields the outcome together with the address so failures
        // can be attributed once the futures complete out of order.
        let send = |recipient: (XorName, SocketAddr), msg_bytes: Bytes| async move {
            trace!(
                "Sending message ({} bytes) to {} of {:?}",
                msg_bytes.len(),
                delivery_group_size,
                recipient.1
            );

            let result = self
                .send_to(&recipient.1, msg_bytes)
                .await
                .map_err(|err| match err {
                    qp2p::Error::Connection(qp2p::ConnectionError::LocallyClosed) => {
                        Error::ConnectionClosed
                    }
                    _ => {
                        trace!("during sending, received error {:?}", err);
                        Error::AddressNotReachable { err }
                    }
                });

            (result, recipient.1)
        };

        let mut tasks: FuturesUnordered<_> = recipients[0..delivery_group_size]
            .iter()
            .map(|(name, recipient)| send((*name, *recipient), msg_bytes.clone()))
            .collect();

        // Index of the next recipient to fall back to when a send fails.
        let mut next = delivery_group_size;
        let mut successes = 0;
        let mut failed_recipients = vec![];

        while let Some((result, addr)) = tasks.next().await {
            match result {
                Ok(()) => successes += 1,
                Err(Error::ConnectionClosed) => {
                    return Err(Error::ConnectionClosed);
                }
                Err(_) => {
                    failed_recipients.push(addr);

                    if next < recipients.len() {
                        tasks.push(send(recipients[next], msg_bytes.clone()));
                        next += 1;
                    }
                }
            }
        }

        trace!(
            "Sending message {:?} finished to {}/{} recipients (failed: {:?})",
            msg,
            successes,
            delivery_group_size,
            failed_recipients
        );

        if successes == delivery_group_size {
            if failed_recipients.is_empty() {
                Ok(SendStatus::AllRecipients)
            } else {
                Ok(SendStatus::MinDeliveryGroupSizeReached(failed_recipients))
            }
        } else {
            Ok(SendStatus::MinDeliveryGroupSizeFailed(failed_recipients))
        }
    }

    /// Spawns the detached background tasks that forward qp2p messages and
    /// disconnections to `event_tx`. Each task receives its own clone of the sender.
    fn spawn_event_forwarders(
        incoming_messages: qp2p::IncomingMessages,
        disconnections: qp2p::DisconnectionEvents,
        event_tx: &mpsc::Sender<ConnectionEvent>,
    ) {
        let _ = task::spawn(handle_incoming_messages(
            incoming_messages,
            event_tx.clone(),
        ));
        let _ = task::spawn(handle_disconnection_events(
            disconnections,
            event_tx.clone(),
        ));
    }

    /// Sends `msg` to `recipient` over any existing connection. If that fails, it
    /// connects to the recipient and sends once more.
    ///
    /// Only the error of the connect or of the second send is returned.
    async fn send_to(&self, recipient: &SocketAddr, msg: Bytes) -> Result<(), qp2p::Error> {
        trace!("Low level send for msg over qp2p");

        match self.endpoint.send_message(msg.clone(), recipient).await {
            Ok(()) => return Ok(()),
            Err(err) => {
                // Not fatal: fall back to (re)connecting below, but leave a trace of why.
                trace!(
                    "Send to {} failed ({}), connecting and retrying",
                    recipient,
                    err
                );
            }
        }

        self.endpoint.connect_to(recipient).await?;
        self.endpoint.send_message(msg, recipient).await
    }
}
impl Drop for Comm {
    /// Closes the underlying qp2p endpoint when the `Comm` goes out of scope.
    fn drop(&mut self) {
        let Self { endpoint, .. } = self;
        endpoint.close();
    }
}
/// Event forwarded from the qp2p endpoint to the channel given to [`Comm`].
pub(crate) enum ConnectionEvent {
    /// A message arrived: the sender's address and the raw message bytes.
    Received((SocketAddr, Bytes)),
    /// qp2p reported a disconnection from the peer at this address.
    Disconnected(SocketAddr),
}
impl Debug for ConnectionEvent {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
Self::Received((src, msg)) => write!(f, "Received(src: {}, msg: {})", src, HexFmt(msg)),
Self::Disconnected(addr) => write!(f, "Disconnected({})", addr),
}
}
}
/// Forwards every disconnection reported by qp2p to `event_tx` as
/// [`ConnectionEvent::Disconnected`], until the disconnection stream ends.
///
/// Errors from `event_tx.send` are discarded; forwarding is best-effort.
async fn handle_disconnection_events(
    mut disconnections: qp2p::DisconnectionEvents,
    event_tx: mpsc::Sender<ConnectionEvent>,
) {
    while let Some(addr) = disconnections.next().await {
        let event = ConnectionEvent::Disconnected(addr);
        let _ = event_tx.send(event).await;
    }
}
/// Forwards every message received by qp2p to `event_tx` as
/// [`ConnectionEvent::Received`], until the incoming-message stream ends.
///
/// Errors from `event_tx.send` are discarded; forwarding is best-effort.
async fn handle_incoming_messages(
    mut incoming_msgs: qp2p::IncomingMessages,
    event_tx: mpsc::Sender<ConnectionEvent>,
) {
    while let Some((sender, payload)) = incoming_msgs.next().await {
        let event = ConnectionEvent::Received((sender, payload));
        let _ = event_tx.send(event).await;
    }
}
/// Outcome of a successful call to `Comm::send`.
///
/// The "delivery group size" is the requested size clamped to the number of recipients.
#[derive(Debug, Clone)]
pub enum SendStatus {
    /// The delivery group size was reached and no attempted recipient failed.
    AllRecipients,
    /// The delivery group size was reached, but these recipients failed along the way.
    MinDeliveryGroupSizeReached(Vec<SocketAddr>),
    /// Fewer successful deliveries than the delivery group size; these recipients failed.
    MinDeliveryGroupSizeFailed(Vec<SocketAddr>),
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use assert_matches::assert_matches;
    use futures::future;
    use qp2p::Config;
    use sn_data_types::PublicKey;
    use sn_messaging::{section_info::SectionInfoMsg, DestInfo, WireMsg};
    use std::{net::Ipv4Addr, slice, time::Duration};
    use tokio::{net::UdpSocket, sync::mpsc, time};

    /// How long tests wait for a message or event to arrive (or to confirm it does not).
    const TIMEOUT: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn successful_send() -> Result<()> {
        let (tx, _rx) = mpsc::channel(1);
        let comm = Comm::new(transport_config(), tx).await?;

        let mut peer0 = Peer::new().await?;
        let mut peer1 = Peer::new().await?;

        let mut original_message = new_section_info_message();

        let status = comm
            .send(
                &[(peer0._name, peer0.addr), (peer1._name, peer1.addr)],
                2,
                original_message.clone(),
            )
            .await?;

        assert_matches!(status, SendStatus::AllRecipients);

        // `send` stamps the first recipient's name on the message sent to everyone.
        if let Some(bytes) = peer0.rx.recv().await {
            original_message.update_dest_info(None, Some(peer0._name));
            assert_eq!(WireMsg::deserialize(bytes)?, original_message.clone());
        }

        if let Some(bytes) = peer1.rx.recv().await {
            assert_eq!(WireMsg::deserialize(bytes)?, original_message);
        }

        Ok(())
    }

    #[tokio::test]
    async fn successful_send_to_subset() -> Result<()> {
        let (tx, _rx) = mpsc::channel(1);
        let comm = Comm::new(transport_config(), tx).await?;

        let mut peer0 = Peer::new().await?;
        let mut peer1 = Peer::new().await?;

        let mut original_message = new_section_info_message();

        let status = comm
            .send(
                &[(peer0._name, peer0.addr), (peer1._name, peer1.addr)],
                1,
                original_message.clone(),
            )
            .await?;

        assert_matches!(status, SendStatus::AllRecipients);

        if let Some(bytes) = peer0.rx.recv().await {
            original_message.update_dest_info(None, Some(peer0._name));
            assert_eq!(WireMsg::deserialize(bytes)?, original_message);
        }

        // Only one recipient was requested, so peer1 must receive nothing.
        assert!(time::timeout(TIMEOUT, peer1.rx.recv())
            .await
            .unwrap_or_default()
            .is_none());

        Ok(())
    }

    #[tokio::test]
    async fn failed_send() -> Result<()> {
        let (tx, _rx) = mpsc::channel(1);
        let comm = Comm::new(
            Config {
                idle_timeout_msec: Some(1),
                ..transport_config()
            },
            tx,
        )
        .await?;

        let invalid_addr = get_invalid_addr().await?;

        let status = comm
            .send(
                &[(XorName::random(), invalid_addr)],
                1,
                new_section_info_message(),
            )
            .await?;

        // The `=>` arm must itself assert; its value is otherwise discarded.
        assert_matches!(
            status,
            SendStatus::MinDeliveryGroupSizeFailed(failed) => assert_eq!(failed, vec![invalid_addr])
        );

        Ok(())
    }

    #[tokio::test]
    async fn successful_send_after_failed_attempts() -> Result<()> {
        let (tx, _rx) = mpsc::channel(1);
        let comm = Comm::new(
            Config {
                idle_timeout_msec: Some(1),
                ..transport_config()
            },
            tx,
        )
        .await?;

        let mut peer = Peer::new().await?;
        let invalid_addr = get_invalid_addr().await?;
        let name = XorName::random();

        let mut message = new_section_info_message();

        let status = comm
            .send(
                &[(name, invalid_addr), (peer._name, peer.addr)],
                1,
                message.clone(),
            )
            .await?;

        // The invalid address fails first, then the fallback peer succeeds.
        assert_matches!(
            status,
            SendStatus::MinDeliveryGroupSizeReached(failed) => assert_eq!(failed, vec![invalid_addr])
        );

        if let Some(bytes) = peer.rx.recv().await {
            message.update_dest_info(None, Some(name));
            assert_eq!(WireMsg::deserialize(bytes)?, message);
        }

        Ok(())
    }

    #[tokio::test]
    async fn partially_successful_send() -> Result<()> {
        let (tx, _rx) = mpsc::channel(1);
        let comm = Comm::new(
            Config {
                idle_timeout_msec: Some(1),
                ..transport_config()
            },
            tx,
        )
        .await?;

        let mut peer = Peer::new().await?;
        let invalid_addr = get_invalid_addr().await?;
        let name = XorName::random();

        let mut message = new_section_info_message();

        let status = comm
            .send(
                &[(name, invalid_addr), (peer._name, peer.addr)],
                2,
                message.clone(),
            )
            .await?;

        // The `=>` arm must itself assert; its value is otherwise discarded.
        assert_matches!(
            status,
            SendStatus::MinDeliveryGroupSizeFailed(failed) => assert_eq!(failed, vec![invalid_addr])
        );

        if let Some(bytes) = peer.rx.recv().await {
            message.update_dest_info(None, Some(name));
            assert_eq!(WireMsg::deserialize(bytes)?, message);
        }

        Ok(())
    }

    #[tokio::test]
    async fn send_after_reconnect() -> Result<()> {
        let (tx, _rx) = mpsc::channel(1);
        let send_comm = Comm::new(transport_config(), tx).await?;

        let recv_transport = QuicP2p::with_config(Some(transport_config()), &[], false)?;
        let (recv_endpoint, _, mut incoming_msgs, _) = recv_transport.new_endpoint().await?;
        let recv_addr = recv_endpoint.socket_addr();
        let name = XorName::random();

        let key0 = bls::SecretKey::random().public_key();
        let msg0 = MessageType::SectionInfo {
            msg: SectionInfoMsg::GetSectionQuery(PublicKey::Bls(key0)),
            dest_info: DestInfo {
                dest: name,
                dest_section_pk: key0,
            },
        };
        let _ = send_comm
            .send(slice::from_ref(&(name, recv_addr)), 1, msg0.clone())
            .await?;

        // Receive the first message, then drop the connection from the receiving side.
        let mut msg0_received = false;
        if let Some((src, msg)) = time::timeout(TIMEOUT, incoming_msgs.next()).await? {
            assert_eq!(WireMsg::deserialize(msg)?, msg0);
            msg0_received = true;
            recv_endpoint.disconnect_from(&src).await?;
        }
        assert!(msg0_received);

        // The next send must transparently reconnect.
        let key1 = bls::SecretKey::random().public_key();
        let msg1 = MessageType::SectionInfo {
            msg: SectionInfoMsg::GetSectionQuery(PublicKey::Bls(key1)),
            dest_info: DestInfo {
                dest: name,
                dest_section_pk: key1,
            },
        };
        let _ = send_comm
            .send(slice::from_ref(&(name, recv_addr)), 1, msg1.clone())
            .await?;

        let mut msg1_received = false;
        if let Some((_src, msg)) = time::timeout(TIMEOUT, incoming_msgs.next()).await? {
            assert_eq!(WireMsg::deserialize(msg)?, msg1);
            msg1_received = true;
        }
        assert!(msg1_received);

        Ok(())
    }

    #[tokio::test]
    async fn incoming_connection_lost() -> Result<()> {
        let (tx, mut rx0) = mpsc::channel(1);
        let comm0 = Comm::new(transport_config(), tx).await?;
        let addr0 = comm0.our_connection_info();

        let (tx, _rx) = mpsc::channel(1);
        let comm1 = Comm::new(transport_config(), tx).await?;
        let addr1 = comm1.our_connection_info();

        // Send a message to establish the connection.
        let _ = comm1
            .send(
                slice::from_ref(&(XorName::random(), addr0)),
                1,
                new_section_info_message(),
            )
            .await?;
        assert_matches!(rx0.recv().await, Some(ConnectionEvent::Received(_)));

        // Dropping `comm1` closes its endpoint, which comm0 should see as a disconnect.
        drop(comm1);

        assert_matches!(
            time::timeout(TIMEOUT, rx0.recv()).await?,
            Some(ConnectionEvent::Disconnected(addr)) => assert_eq!(addr, addr1)
        );

        Ok(())
    }

    /// Default transport config bound to localhost.
    fn transport_config() -> Config {
        Config {
            local_ip: Some(Ipv4Addr::LOCALHOST.into()),
            ..Default::default()
        }
    }

    /// Builds a `SectionInfo` query message with random keys and destination.
    fn new_section_info_message() -> MessageType {
        let random_bls_pk = bls::SecretKey::random().public_key();
        MessageType::SectionInfo {
            msg: SectionInfoMsg::GetSectionQuery(PublicKey::Bls(random_bls_pk)),
            dest_info: DestInfo {
                dest: XorName::random(),
                dest_section_pk: bls::SecretKey::random().public_key(),
            },
        }
    }

    /// A bare qp2p endpoint acting as a message recipient in tests.
    struct Peer {
        addr: SocketAddr,
        // Not read; held so these qp2p streams are not dropped while the peer exists.
        _incoming_connections: qp2p::IncomingConnections,
        _disconnections: qp2p::DisconnectionEvents,
        _name: XorName,
        /// Bytes of each message received by the endpoint, in arrival order.
        rx: mpsc::Receiver<Bytes>,
    }

    impl Peer {
        async fn new() -> Result<Self> {
            let transport = QuicP2p::with_config(Some(transport_config()), &[], false)?;

            let (endpoint, incoming_connections, mut incoming_messages, disconnections) =
                transport.new_endpoint().await?;

            let addr = endpoint.socket_addr();

            // Relay incoming message payloads into `rx`.
            let (tx, rx) = mpsc::channel(1);

            let _ = tokio::spawn(async move {
                while let Some((_src, msg)) = incoming_messages.next().await {
                    let _ = tx.send(msg).await;
                }
            });

            Ok(Self {
                addr,
                rx,
                _incoming_connections: incoming_connections,
                _disconnections: disconnections,
                _name: XorName::random(),
            })
        }
    }

    /// Returns the address of a bound UDP socket that nothing reads from or writes to,
    /// so QUIC connection attempts to it cannot complete. A never-finishing task owns
    /// the socket to keep it bound.
    async fn get_invalid_addr() -> Result<SocketAddr> {
        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = socket.local_addr()?;

        let _ = tokio::spawn(async move {
            future::pending::<()>().await;
            let _ = socket;
        });

        Ok(addr)
    }
}