use std::future::Future;
use std::net::SocketAddr;
use either::Either;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{self, Receiver, Sender};
use crate::message::{NetMessage, NodeEvent};
pub(crate) mod broadcast_queue;
mod handshake;
pub(crate) mod in_memory;
pub(crate) mod p2p_protoc;
pub(crate) mod priority_select;
pub use in_memory::{FaultInjectorState, NetworkStats, get_fault_injector, set_fault_injector};
pub use p2p_protoc::EventLoopExitReason;
/// Result type returned by connection-level operations such as those on [`NetworkBridge`].
pub(crate) type ConnResult<T> = Result<T, ConnectionError>;
/// Operations for exchanging data with peers addressed by `SocketAddr`.
///
/// Every method returns a `Send` future that resolves to [`ConnResult<()>`].
pub(crate) trait NetworkBridge: Send + Sync {
/// Requests that the connection to `peer_addr` be dropped (exact semantics are
/// implementation-defined).
// NOTE(review): `dead_code` is allowed here without explanation; presumably this method has
// no caller in some build configurations — confirm and record the reason.
#[allow(dead_code)]
fn drop_connection(
&mut self,
peer_addr: SocketAddr,
) -> impl Future<Output = ConnResult<()>> + Send;
/// Sends a single `NetMessage` to `target_addr`.
fn send(
&self,
target_addr: SocketAddr,
msg: NetMessage,
) -> impl Future<Output = ConnResult<()>> + Send;
/// Sends `data` to `target_addr` on outbound stream `stream_id`, with optional `metadata`.
fn send_stream(
&self,
target_addr: SocketAddr,
stream_id: crate::transport::peer_connection::StreamId,
data: bytes::Bytes,
metadata: Option<bytes::Bytes>,
) -> impl Future<Output = ConnResult<()>> + Send;
/// Forwards the inbound stream behind `inbound_handle` to `target_addr` as outbound stream
/// `outbound_stream_id`, with optional `metadata`.
// NOTE(review): presumably this relays data as it arrives rather than buffering the whole
// stream first — confirm against the implementations.
fn pipe_stream(
&self,
target_addr: SocketAddr,
outbound_stream_id: crate::transport::peer_connection::StreamId,
inbound_handle: crate::transport::peer_connection::streaming::StreamHandle,
metadata: Option<bytes::Bytes>,
) -> impl Future<Output = ConnResult<()>> + Send;
}
/// Errors produced by connection-level operations (see [`ConnResult`]).
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`. With index-based formats such
/// as bincode, reordering or inserting variants would change the encoding — confirm whether
/// it ever crosses a process boundary before reshuffling variants.
#[derive(Debug, thiserror::Error, Serialize, Deserialize)]
pub(crate) enum ConnectionError {
/// The local node's location is not known.
#[error("location unknown for this node")]
LocationUnknown,
/// A message addressed to the given socket address could not be sent.
#[error("unable to send message to {0}")]
SendNotCompleted(SocketAddr),
/// A connection request arrived that was not expected.
#[error("Unexpected connection req")]
UnexpectedReq,
/// A (de)serialization failure, optionally carrying the bincode error.
///
/// `#[from]` makes `Option<Box<bincode::ErrorKind>>` convertible via `?`. `#[serde(skip)]`
/// means serde never serializes or deserializes this variant. The manual `Clone` impl
/// replaces the inner error with `None`.
#[error("error while de/serializing message")]
#[serde(skip)]
Serialization(#[from] Option<Box<bincode::ErrorKind>>),
/// A transport-layer failure, stored only as its rendered message.
#[error("{0}")]
TransportError(String),
/// A connect operation failed.
#[error("failed connect")]
FailedConnectOp,
/// A connection was rejected as unwanted.
#[error("unwanted connection")]
UnwantedConnection,
/// Local policy refused a connection to/from this address.
#[error("connection to/from address {0} blocked by local policy")]
AddressBlocked(std::net::SocketAddr),
/// An I/O failure, stored only as the `io::Error`'s display text.
#[error("IO error: {0}")]
IOError(String),
/// Timed out while waiting for a message.
#[error("timeout error while waiting for a message")]
Timeout,
}
/// Converts an I/O error into [`ConnectionError::IOError`].
///
/// Only the error's `Display` text is kept; the original `io::Error` (including its
/// `ErrorKind`) is not retained.
impl From<std::io::Error> for ConnectionError {
    fn from(err: std::io::Error) -> Self {
        // `to_string()` renders through `Display` exactly like `format!("{err}")`, and matches
        // the `TransportError` conversion in this file.
        Self::IOError(err.to_string())
    }
}
impl From<crate::transport::TransportError> for ConnectionError {
fn from(err: crate::transport::TransportError) -> Self {
Self::TransportError(err.to_string())
}
}
/// Manual `Clone` for [`ConnectionError`].
///
/// Every variant is copied as-is except `Serialization`, whose boxed bincode error is
/// dropped and replaced with `None`. Presumably this is because that error type is not
/// cloneable.
impl Clone for ConnectionError {
    fn clone(&self) -> Self {
        match *self {
            // The inner bincode error is intentionally not carried over.
            Self::Serialization(_) => Self::Serialization(None),
            // Variants owning a `String` need an explicit clone of the message.
            Self::IOError(ref message) => Self::IOError(message.clone()),
            Self::TransportError(ref message) => Self::TransportError(message.clone()),
            // `SocketAddr` is `Copy`, so these bind by value.
            Self::SendNotCompleted(addr) => Self::SendNotCompleted(addr),
            Self::AddressBlocked(addr) => Self::AddressBlocked(addr),
            // Field-less variants.
            Self::LocationUnknown => Self::LocationUnknown,
            Self::UnexpectedReq => Self::UnexpectedReq,
            Self::FailedConnectOp => Self::FailedConnectOp,
            Self::UnwantedConnection => Self::UnwantedConnection,
            Self::Timeout => Self::Timeout,
        }
    }
}
use std::cell::Cell;
const CHANNEL_ID_BLOCK: u64 = 1_000_000;
thread_local! {
static CHANNEL_ID_COUNTER: Cell<u64> = {
let idx = crate::config::GlobalRng::thread_index();
Cell::new(idx * CHANNEL_ID_BLOCK)
};
}
pub fn reset_channel_id_counter() {
let idx = crate::config::GlobalRng::thread_index();
CHANNEL_ID_COUNTER.with(|c| c.set(idx * CHANNEL_ID_BLOCK));
}
/// Bounded capacity of each of the two event-loop channels.
const EVENT_LOOP_CHANNEL_CAPACITY: usize = 2048;

/// Creates the paired notification and op-execution channels.
///
/// Each call takes the next id from this thread's `CHANNEL_ID_COUNTER` and logs it at
/// `info` level, so individual channel pairs can be told apart in traces. Both channels are
/// bounded at [`EVENT_LOOP_CHANNEL_CAPACITY`].
///
/// Returns `(receivers, senders)`. The sender half is `Clone`, so it can be shared between
/// producers.
pub(crate) fn event_loop_notification_channel()
-> (EventLoopNotificationsReceiver, EventLoopNotificationsSender) {
    // Post-increment: `Cell::replace` stores `id + 1` and hands back the previous value.
    // The id is only used for the log line below.
    let channel_id = CHANNEL_ID_COUNTER.with(|counter| counter.replace(counter.get() + 1));
    let (notification_tx, notification_rx) = mpsc::channel(EVENT_LOOP_CHANNEL_CAPACITY);
    let (op_execution_tx, op_execution_rx) = mpsc::channel(EVENT_LOOP_CHANNEL_CAPACITY);
    tracing::info!(channel_id, "Created event loop notification channel pair");
    (
        EventLoopNotificationsReceiver {
            notifications_receiver: notification_rx,
            op_execution_receiver: op_execution_rx,
        },
        EventLoopNotificationsSender {
            notifications_sender: notification_tx,
            op_execution_sender: op_execution_tx,
        },
    )
}
/// Item carried by the op-execution channel: a `Sender<NetMessage>`, a `NetMessage`, and an
/// optional `SocketAddr`.
// NOTE(review): presumably (reply channel, message to execute, target/source peer address) —
// confirm against the event loop that consumes it.
pub(crate) type OpExecutionPayload = (Sender<NetMessage>, NetMessage, Option<SocketAddr>);
/// Receiving halves of the channel pair built by `event_loop_notification_channel`.
pub(crate) struct EventLoopNotificationsReceiver {
/// Receives network messages (`Left`) and node events (`Right`).
pub(crate) notifications_receiver: Receiver<Either<NetMessage, NodeEvent>>,
/// Receives op-execution requests.
pub(crate) op_execution_receiver: Receiver<OpExecutionPayload>,
}
/// Cloneable sending halves of the channel pair built by `event_loop_notification_channel`.
#[derive(Clone)]
pub(crate) struct EventLoopNotificationsSender {
/// Sends network messages (`Left`) or node events (`Right`) to the notifications receiver.
pub(crate) notifications_sender: Sender<Either<NetMessage, NodeEvent>>,
/// Sends op-execution requests to the op-execution receiver.
pub(crate) op_execution_sender: Sender<OpExecutionPayload>,
}
impl EventLoopNotificationsSender {
    /// Borrows the sender half of the notification channel.
    pub(crate) fn notifications_sender(&self) -> &Sender<Either<NetMessage, NodeEvent>> {
        &self.notifications_sender
    }

    /// Number of occupied slots in the notification channel.
    ///
    /// This is computed as `max_capacity() - capacity()`. Tokio's `capacity()` drops both
    /// for queued messages and for outstanding reserved permits, so this counts both. The
    /// subtraction saturates at zero.
    pub(crate) fn notification_channel_pending(&self) -> usize {
        let tx = &self.notifications_sender;
        let total_slots = tx.max_capacity();
        let free_slots = tx.capacity();
        total_slots.saturating_sub(free_slots)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::GlobalExecutor;
    use either::Either;
    use freenet_stdlib::prelude::*;
    use tokio::time::{Duration, timeout};

    /// A notification sent from a separate task must be observed by a `biased` `select!`
    /// loop that also polls an idle channel.
    #[tokio::test]
    async fn test_notification_channel_with_biased_select() {
        let (notification_channel, notification_tx) = event_loop_notification_channel();
        let mut rx = notification_channel.notifications_receiver;
        let test_event = crate::message::NodeEvent::Disconnect { cause: None };
        let sender = notification_tx.clone();
        GlobalExecutor::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            tracing::info!("Sending notification");
            sender
                .notifications_sender()
                .send(Either::Right(test_event))
                .await
                .expect("Failed to send notification");
            tracing::info!("Notification sent successfully");
        });
        // Binding `_dummy_tx` (not `_`) keeps `dummy_rx` open but idle for the whole test.
        let (_dummy_tx, mut dummy_rx) = tokio::sync::mpsc::channel::<String>(10);
        let mut received = false;
        tracing::info!("Starting event loop simulation");
        for i in 0..50 {
            tracing::debug!("Loop iteration {}", i);
            let result = timeout(Duration::from_millis(100), async {
                tokio::select! {
                    biased;
                    msg = rx.recv() => {
                        tracing::info!("Received notification: {:?}", msg);
                        Some(msg)
                    }
                    _ = dummy_rx.recv() => {
                        tracing::debug!("Received dummy message");
                        None
                    }
                }
            })
            .await;
            match result {
                Ok(Some(Some(_msg))) => {
                    tracing::info!("Successfully received notification!");
                    received = true;
                    break;
                }
                Ok(Some(None)) => {
                    tracing::error!("Channel closed unexpectedly");
                    break;
                }
                Ok(None) => {
                    tracing::debug!("Dummy channel activity");
                }
                Err(_) => {
                    tracing::debug!("Timeout, continuing...");
                }
            }
        }
        assert!(received, "Notification was never received by event loop");
        tracing::info!("Test passed!");
    }

    /// Several queued notifications are all delivered.
    #[tokio::test]
    async fn test_multiple_notifications() {
        let (notification_channel, notification_tx) = event_loop_notification_channel();
        let mut rx = notification_channel.notifications_receiver;
        for _i in 0..3 {
            let test_event = crate::message::NodeEvent::Disconnect { cause: None };
            notification_tx
                .notifications_sender()
                .send(Either::Right(test_event))
                .await
                .expect("Failed to send notification");
        }
        let mut count = 0;
        while count < 3 {
            match timeout(Duration::from_secs(1), rx.recv()).await {
                Ok(Some(_)) => count += 1,
                Ok(None) => panic!("Channel closed unexpectedly"),
                Err(_) => panic!("Timeout waiting for notification {}", count + 1),
            }
        }
        assert_eq!(count, 3, "Should receive all 3 notifications");
    }

    /// Sending after the receivers are dropped must fail rather than hang.
    #[tokio::test]
    async fn test_send_fails_when_receiver_dropped() {
        let (notification_channel, notification_tx) = event_loop_notification_channel();
        drop(notification_channel);
        let test_event = crate::message::NodeEvent::Disconnect { cause: None };
        let result = notification_tx
            .notifications_sender()
            .send(Either::Right(test_event))
            .await;
        assert!(result.is_err(), "Send should fail when receiver is dropped");
    }

    /// 50 sends fit well within `EVENT_LOOP_CHANNEL_CAPACITY`, so none should wait on the
    /// receiver, and all of them must be delivered.
    #[tokio::test]
    async fn test_channel_capacity() {
        let (notification_channel, notification_tx) = event_loop_notification_channel();
        let mut rx = notification_channel.notifications_receiver;
        for _ in 0..50 {
            let test_event = crate::message::NodeEvent::Disconnect { cause: None };
            notification_tx
                .notifications_sender()
                .send(Either::Right(test_event))
                .await
                .expect("Should not block: 50 messages fit within EVENT_LOOP_CHANNEL_CAPACITY");
        }
        let mut count = 0;
        while count < 50 {
            match timeout(Duration::from_millis(10), rx.recv()).await {
                Ok(Some(_)) => count += 1,
                _ => break,
            }
        }
        assert_eq!(count, 50, "Should receive all 50 messages");
    }

    /// Messages sent through a cloned sender reach the same receiver.
    #[tokio::test]
    async fn test_sender_clone() {
        let (notification_channel, notification_tx) = event_loop_notification_channel();
        let mut rx = notification_channel.notifications_receiver;
        let cloned_tx = notification_tx.clone();
        let test_event1 = crate::message::NodeEvent::Disconnect { cause: None };
        notification_tx
            .notifications_sender()
            .send(Either::Right(test_event1))
            .await
            .expect("Should send from original");
        let test_event2 = crate::message::NodeEvent::Disconnect {
            cause: Some("cloned".into()),
        };
        cloned_tx
            .notifications_sender()
            .send(Either::Right(test_event2))
            .await
            .expect("Should send from clone");
        let mut received = 0;
        for _ in 0..2 {
            // Count only real messages: `Ok(None)` means the channel closed, not a delivery.
            if matches!(
                timeout(Duration::from_millis(100), rx.recv()).await,
                Ok(Some(_))
            ) {
                received += 1;
            }
        }
        assert_eq!(received, 2, "Should receive both messages");
    }
}
#[cfg(test)]
mod connection_error_tests {
    use super::*;

    /// Every variant clones into a value that renders identically.
    #[test]
    fn test_connection_error_clone() {
        let errors = [
            ConnectionError::LocationUnknown,
            ConnectionError::SendNotCompleted("127.0.0.1:8080".parse().unwrap()),
            ConnectionError::UnexpectedReq,
            ConnectionError::Serialization(None),
            ConnectionError::TransportError("test error".to_string()),
            ConnectionError::FailedConnectOp,
            ConnectionError::UnwantedConnection,
            ConnectionError::AddressBlocked("127.0.0.1:8080".parse().unwrap()),
            ConnectionError::IOError("io error".to_string()),
            ConnectionError::Timeout,
        ];
        for error in errors {
            let cloned = error.clone();
            assert_eq!(error.to_string(), cloned.to_string());
        }
    }

    /// `io::Error` converts into `IOError` carrying the original message.
    #[test]
    fn test_connection_error_from_io_error() {
        let io_error = std::io::Error::other("test io error");
        let conn_error: ConnectionError = io_error.into();
        // Exhaustive on purpose: adding a variant forces this test to be revisited.
        match conn_error {
            ConnectionError::IOError(msg) => {
                assert!(msg.contains("test io error"));
            }
            ConnectionError::LocationUnknown
            | ConnectionError::SendNotCompleted(_)
            | ConnectionError::UnexpectedReq
            | ConnectionError::Serialization(_)
            | ConnectionError::TransportError(_)
            | ConnectionError::FailedConnectOp
            | ConnectionError::UnwantedConnection
            | ConnectionError::AddressBlocked(_)
            | ConnectionError::Timeout => panic!("Expected IOError variant"),
        }
    }

    /// Display output includes the key context for representative variants.
    #[test]
    fn test_connection_error_display() {
        let error = ConnectionError::LocationUnknown;
        let display = format!("{}", error);
        assert!(!display.is_empty());
        assert!(display.contains("location unknown"));
        let error = ConnectionError::SendNotCompleted("127.0.0.1:8080".parse().unwrap());
        let display = format!("{}", error);
        assert!(display.contains("127.0.0.1:8080"));
        let error = ConnectionError::AddressBlocked("192.168.1.1:9000".parse().unwrap());
        let display = format!("{}", error);
        assert!(display.contains("192.168.1.1:9000"));
        let error = ConnectionError::Timeout;
        let display = format!("{}", error);
        assert!(display.contains("timeout"));
    }

    /// Cloning `Serialization(Some(_))` drops the inner bincode error but keeps the message.
    #[test]
    fn test_serialization_error_clone_loses_inner() {
        let inner = Box::new(bincode::ErrorKind::SizeLimit);
        let original = ConnectionError::Serialization(Some(inner));
        let cloned = original.clone();
        assert_eq!(original.to_string(), cloned.to_string());
        // Exhaustive on purpose: adding a variant forces this test to be revisited.
        match cloned {
            ConnectionError::Serialization(None) => {}
            ConnectionError::LocationUnknown
            | ConnectionError::SendNotCompleted(_)
            | ConnectionError::UnexpectedReq
            | ConnectionError::Serialization(_)
            | ConnectionError::TransportError(_)
            | ConnectionError::FailedConnectOp
            | ConnectionError::UnwantedConnection
            | ConnectionError::AddressBlocked(_)
            | ConnectionError::IOError(_)
            | ConnectionError::Timeout => panic!("Expected Serialization(None) after clone"),
        }
    }

    /// `Debug` output names the variant.
    #[test]
    fn test_connection_error_debug() {
        let error = ConnectionError::FailedConnectOp;
        let debug = format!("{:?}", error);
        assert!(!debug.is_empty());
        assert!(debug.contains("FailedConnectOp"));
    }
}