use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;
use crate::network::dispatch::DispatchMessageSender;
use crate::protos::network::{NetworkMessage, NetworkMessageType};
use crate::transport::matrix::{
ConnectionMatrixReceiver, ConnectionMatrixRecvError, ConnectionMatrixSender,
};
use super::connector::{PeerLookup, PeerLookupProvider};
use super::error::PeerInterconnectError;
/// A request consumed by the peer interconnect's send thread.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum SendRequest {
    /// Ask the send loop to stop.
    Shutdown,
    /// Deliver `payload` to the peer whose ID is `recipient`.
    Message {
        recipient: String,
        payload: Vec<u8>,
    },
}
/// Cloneable handle that queues outgoing messages for the peer interconnect's
/// send thread.
#[derive(Clone)]
pub struct NetworkMessageSender {
    sender: Sender<SendRequest>,
}

impl NetworkMessageSender {
    /// Wraps the channel that feeds the send thread.
    pub(crate) fn new(sender: Sender<SendRequest>) -> Self {
        Self { sender }
    }

    /// Queues `payload` for delivery to the peer identified by `recipient`.
    ///
    /// # Errors
    ///
    /// If the receiving side of the channel has been dropped, the recipient
    /// and payload are handed back to the caller unchanged.
    pub fn send(&self, recipient: String, payload: Vec<u8>) -> Result<(), (String, Vec<u8>)> {
        let request = SendRequest::Message { recipient, payload };
        match self.sender.send(request) {
            Ok(()) => Ok(()),
            Err(failed) => match failed.0 {
                SendRequest::Message { recipient, payload } => Err((recipient, payload)),
                // Only `Message` requests are sent from here, so only a
                // `Message` can ever be returned in the error.
                SendRequest::Shutdown => unreachable!(),
            },
        }
    }
}
/// A running peer interconnect.
///
/// It consists of a receive thread, which forwards messages from the
/// connection matrix to the network dispatcher, and a send thread, which
/// delivers queued messages to peers through the connection matrix.
pub struct PeerInterconnect {
    /// Feeds the send thread; cloned into every `NetworkMessageSender`.
    dispatched_sender: Sender<SendRequest>,
    recv_join_handle: thread::JoinHandle<()>,
    send_join_handle: thread::JoinHandle<()>,
    shutdown_signaler: ShutdownSignaler,
}

impl PeerInterconnect {
    /// Returns a new handle for queueing outgoing messages on the send thread.
    pub fn new_network_sender(&self) -> NetworkMessageSender {
        NetworkMessageSender::new(self.dispatched_sender.clone())
    }

    /// Returns a handle that can request shutdown of the send thread.
    pub fn shutdown_handle(&self) -> ShutdownHandle {
        ShutdownHandle::from(self.shutdown_signaler.clone())
    }

    /// Returns a signaler that can request shutdown of the send thread.
    pub fn shutdown_signaler(&self) -> ShutdownSignaler {
        self.shutdown_signaler.clone()
    }

    /// Blocks until both the send and receive threads have exited.
    ///
    /// A shutdown request only stops the send thread. The receive thread
    /// exits when the connection-matrix receiver reports shutdown or an
    /// error. The underlying connection matrix must therefore also be shut
    /// down for this call to return.
    pub fn await_shutdown(self) {
        // The log labels now match the thread actually being joined.
        // Previously "receiver" was logged while joining the send thread,
        // and vice versa.
        debug!("Shutting down peer interconnect sender...");
        if let Err(err) = self.send_join_handle.join() {
            error!(
                "Peer interconnect send thread did not shutdown correctly: {:?}",
                err
            );
        }
        debug!("Shutting down peer interconnect sender (complete)");

        debug!("Shutting down peer interconnect receiver...");
        if let Err(err) = self.recv_join_handle.join() {
            error!(
                "Peer interconnect recv thread did not shutdown correctly: {:?}",
                err
            );
        }
        debug!("Shutting down peer interconnect receiver (complete)");
    }

    /// Requests shutdown of the send thread, then waits for both threads
    /// (see `await_shutdown`).
    pub fn shutdown_and_wait(self) {
        self.shutdown_signaler().shutdown();
        self.await_shutdown();
    }
}
#[derive(Default)]
pub struct PeerInterconnectBuilder<T: 'static, U: 'static, P>
where
T: ConnectionMatrixReceiver,
U: ConnectionMatrixSender,
P: PeerLookupProvider + 'static,
{
peer_lookup_provider: Option<P>,
message_receiver: Option<T>,
message_sender: Option<U>,
network_dispatcher_sender: Option<DispatchMessageSender<NetworkMessageType>>,
}
impl<T, U, P> PeerInterconnectBuilder<T, U, P>
where
T: ConnectionMatrixReceiver,
U: ConnectionMatrixSender,
P: PeerLookupProvider + 'static,
{
pub fn new() -> Self {
PeerInterconnectBuilder {
peer_lookup_provider: None,
message_receiver: None,
message_sender: None,
network_dispatcher_sender: None,
}
}
pub fn with_peer_connector(mut self, peer_lookup_provider: P) -> Self {
self.peer_lookup_provider = Some(peer_lookup_provider);
self
}
pub fn with_message_receiver(mut self, message_receiver: T) -> Self {
self.message_receiver = Some(message_receiver);
self
}
pub fn with_message_sender(mut self, message_sender: U) -> Self {
self.message_sender = Some(message_sender);
self
}
pub fn with_network_dispatcher_sender(
mut self,
network_dispatcher_sender: DispatchMessageSender<NetworkMessageType>,
) -> Self {
self.network_dispatcher_sender = Some(network_dispatcher_sender);
self
}
pub fn build(&mut self) -> Result<PeerInterconnect, PeerInterconnectError> {
let (dispatched_sender, dispatched_receiver) = channel();
let peer_lookup_provider = self.peer_lookup_provider.take().ok_or_else(|| {
PeerInterconnectError::StartUpError("Peer lookup provider missing".to_string())
})?;
let network_dispatcher_sender = self.network_dispatcher_sender.take().ok_or_else(|| {
PeerInterconnectError::StartUpError("Network dispatcher sender missing".to_string())
})?;
let message_receiver = self.message_receiver.take().ok_or_else(|| {
PeerInterconnectError::StartUpError("Message receiver missing".to_string())
})?;
let recv_peer_lookup = peer_lookup_provider.peer_lookup();
debug!("Starting peer interconnect receiver");
let recv_join_handle = thread::Builder::new()
.name("PeerInterconnect Receiver".into())
.spawn(move || {
if let Err(err) = run_recv_loop(
&*recv_peer_lookup,
message_receiver,
network_dispatcher_sender,
) {
error!("Shutting down peer interconnect recevier: {}", err);
}
})
.map_err(|err| {
PeerInterconnectError::StartUpError(format!(
"Unable to start PeerInterconnect receiver thread {}",
err
))
})?;
let send_peer_lookup = peer_lookup_provider.peer_lookup();
let message_sender = self
.message_sender
.take()
.ok_or_else(|| PeerInterconnectError::StartUpError("Already started".to_string()))?;
debug!("Starting peer interconnect sender");
let send_join_handle = thread::Builder::new()
.name("PeerInterconnect Sender".into())
.spawn(move || {
if let Err(err) =
run_send_loop(&*send_peer_lookup, dispatched_receiver, message_sender)
{
error!("Shutting down peer interconnect sender: {}", err);
}
})
.map_err(|err| {
PeerInterconnectError::StartUpError(format!(
"Unable to start PeerInterconnect sender thread {}",
err
))
})?;
Ok(PeerInterconnect {
dispatched_sender: dispatched_sender.clone(),
recv_join_handle,
send_join_handle,
shutdown_signaler: ShutdownSignaler {
sender: dispatched_sender,
},
})
}
}
fn run_recv_loop<R>(
peer_connector: &dyn PeerLookup,
message_receiver: R,
dispatch_msg_sender: DispatchMessageSender<NetworkMessageType>,
) -> Result<(), String>
where
R: ConnectionMatrixReceiver + 'static,
{
let mut connection_id_to_peer_id: HashMap<String, String> = HashMap::new();
loop {
let envelope = match message_receiver.recv() {
Ok(envelope) => envelope,
Err(ConnectionMatrixRecvError::Shutdown) => {
info!("ConnectionMatrix has shutdown");
break Ok(());
}
Err(ConnectionMatrixRecvError::Disconnected) => {
break Err("Unable to receive message: disconnected".into());
}
Err(ConnectionMatrixRecvError::InternalError { context, .. }) => {
break Err(format!("Unable to receive message: {}", context));
}
};
let connection_id = envelope.id();
let peer_id = if let Some(peer_id) = connection_id_to_peer_id.get(connection_id) {
Some(peer_id.to_owned())
} else if let Some(peer_id) = peer_connector
.peer_id(connection_id)
.map_err(|err| format!("Unable to get peer ID for {}: {}", connection_id, err))?
{
connection_id_to_peer_id.insert(connection_id.to_string(), peer_id.clone());
Some(peer_id)
} else {
None
};
if let Some(peer_id) = peer_id {
let mut network_msg: NetworkMessage =
match protobuf::parse_from_bytes(&envelope.payload()) {
Ok(msg) => msg,
Err(err) => {
error!("Unable to dispatch message: {}", err);
continue;
}
};
trace!(
"Received message from {}: {:?}",
peer_id,
network_msg.get_message_type()
);
match dispatch_msg_sender.send(
network_msg.get_message_type(),
network_msg.take_payload(),
peer_id.into(),
) {
Ok(()) => (),
Err((message_type, _, _)) => {
error!("Unable to dispatch message of type {:?}", message_type)
}
}
} else {
error!(
"Received message from removed or unknown peer with connection_id {}",
connection_id
);
}
}
}
fn run_send_loop<S>(
peer_connector: &dyn PeerLookup,
receiver: Receiver<SendRequest>,
message_sender: S,
) -> Result<(), String>
where
S: ConnectionMatrixSender + 'static,
{
let mut peer_id_to_connection_id: HashMap<String, String> = HashMap::new();
loop {
let (recipient, payload) = match receiver.recv() {
Ok(SendRequest::Message { recipient, payload }) => (recipient, payload),
Ok(SendRequest::Shutdown) => {
info!("Received Shutdown");
break Ok(());
}
Err(err) => {
break Err(format!("Unable to receive message from handlers: {}", err));
}
};
let connection_id = if let Some(connection_id) = peer_id_to_connection_id.get(&recipient) {
Some(connection_id.to_owned())
} else if let Some(connection_id) = peer_connector
.connection_id(&recipient)
.map_err(|err| format!("Unable to get connection ID for {}: {}", recipient, err))?
{
peer_id_to_connection_id.insert(recipient.clone(), connection_id.clone());
Some(connection_id)
} else {
None
};
if let Some(connection_id) = connection_id {
if let Err(err) = message_sender.send(connection_id.to_string(), payload.to_vec()) {
if let Some(new_connection_id) =
peer_connector.connection_id(&recipient).map_err(|err| {
format!("Unable to get connection ID for {}: {}", recipient, err)
})?
{
if new_connection_id != connection_id {
peer_id_to_connection_id
.insert(recipient.clone(), new_connection_id.clone());
if let Err(err) = message_sender.send(new_connection_id, payload) {
error!("Unable to send message to {}: {}", recipient, err);
}
}
} else {
error!("Unable to send message to {}: {}", recipient, err);
peer_id_to_connection_id.remove(&recipient);
}
}
} else {
error!("Cannot send message, unknown peer: {}", recipient);
}
}
}
#[derive(Clone)]
pub struct ShutdownHandle {
sender: Sender<SendRequest>,
}
impl ShutdownHandle {
pub fn shutdown(&self) {
if self.sender.send(SendRequest::Shutdown).is_err() {
warn!("Peer Interconnect is no longer running");
}
}
}
impl From<ShutdownSignaler> for ShutdownHandle {
fn from(signaler: ShutdownSignaler) -> Self {
ShutdownHandle {
sender: signaler.sender,
}
}
}
/// Cloneable signal that asks the peer interconnect's send thread to stop.
#[derive(Clone)]
pub struct ShutdownSignaler {
    sender: Sender<SendRequest>,
}

impl ShutdownSignaler {
    /// Queues a `Shutdown` request.
    ///
    /// Logs a warning if the receiving end of the channel has been dropped.
    pub fn shutdown(&self) {
        let delivered = self.sender.send(SendRequest::Shutdown).is_ok();
        if !delivered {
            warn!("Peer Interconnect is no longer running");
        }
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use protobuf::Message;
    use std::sync::mpsc::{self, Sender};
    use std::time::Duration;
    use crate::mesh::{Envelope, Mesh};
    use crate::network::connection_manager::{
        AuthorizationResult, Authorizer, AuthorizerError, ConnectionManager,
    };
    use crate::network::dispatch::{
        dispatch_channel, DispatchError, DispatchLoopBuilder, Dispatcher, Handler, MessageContext,
        MessageSender, PeerId,
    };
    use crate::peer::{PeerManager, PeerManagerNotification};
    use crate::protos::network::NetworkEcho;
    use crate::transport::{inproc::InprocTransport, Connection, Transport};

    /// End-to-end round trip through the peer interconnect.
    ///
    /// The spawned "remote" thread does the following on `mesh2`:
    ///
    /// 1. Accepts a connection on `inproc://test`.
    /// 2. Sends a `test_retrieve` echo and asserts it is echoed back.
    /// 3. Sends a `shutdown_string` echo.
    /// 4. Waits on `rx` before shutting `mesh2` down.
    ///
    /// The local side wires `mesh1` through a ConnectionManager, a
    /// PeerManager and a PeerInterconnect into a dispatcher running
    /// `NetworkTestHandler`.
    #[test]
    fn test_peer_interconnect() {
        let mut transport = Box::new(InprocTransport::default());
        let mut listener = transport
            .listen("inproc://test")
            .expect("Cannot listen for connections");
        let mesh1 = Mesh::new(512, 128);
        let mesh2 = Mesh::new(512, 128);
        // Used by the main test thread to tell the remote thread it may shut down.
        let (tx, rx) = mpsc::channel();
        thread::spawn(move || {
            let conn = listener.accept().expect("Cannot accept connection");
            mesh2
                .add(conn, "test_id".to_string())
                .expect("Cannot add connection to mesh");
            let message_bytes = echo_to_network_message_bytes(b"test_retrieve".to_vec());
            let envelope = Envelope::new("test_id".to_string(), message_bytes);
            mesh2.send(envelope).expect("Unable to send message");
            // Expect NetworkTestHandler to echo the message back unchanged.
            let envelope = mesh2.recv().expect("Cannot receive message");
            let network_msg: NetworkMessage = protobuf::parse_from_bytes(&envelope.payload())
                .expect("Cannot parse NetworkMessage");
            assert_eq!(
                network_msg.get_message_type(),
                NetworkMessageType::NETWORK_ECHO
            );
            let echo: NetworkEcho = protobuf::parse_from_bytes(network_msg.get_payload()).unwrap();
            assert_eq!(echo.get_payload().to_vec(), b"test_retrieve".to_vec());
            // This payload makes NetworkTestHandler signal the main thread instead of echoing.
            let message_bytes =
                echo_to_network_message_bytes("shutdown_string".as_bytes().to_vec());
            let envelope = Envelope::new("test_id".to_string(), message_bytes);
            mesh2.send(envelope).expect("Cannot send message");
            rx.recv().unwrap();
            mesh2.shutdown_signaler().shutdown();
        });
        let cm = ConnectionManager::builder()
            .with_authorizer(Box::new(NoopAuthorizer::new("test_peer")))
            .with_matrix_life_cycle(mesh1.get_life_cycle())
            .with_matrix_sender(mesh1.get_sender())
            .with_transport(transport)
            .start()
            .expect("Unable to start Connection Manager");
        let connector = cm.connector();
        let peer_manager = PeerManager::builder()
            .with_connector(connector)
            .with_retry_interval(1)
            .with_identity("my_id".to_string())
            .with_strict_ref_counts(true)
            .start()
            .expect("Cannot start peer_manager");
        let peer_connector = peer_manager.connector();
        // `recv` receives a `Shutdown` from the handler once "shutdown_string" arrives.
        let (send, recv) = channel();
        let (dispatcher_sender, dispatcher_receiver) = dispatch_channel();
        let interconnect = PeerInterconnectBuilder::new()
            .with_peer_connector(peer_connector.clone())
            .with_message_receiver(mesh1.get_receiver())
            .with_message_sender(mesh1.get_sender())
            .with_network_dispatcher_sender(dispatcher_sender.clone())
            .build()
            .expect("Unable to build PeerInterconnect");
        // Handler replies go back out through the interconnect's send thread.
        let mut dispatcher = Dispatcher::new(Box::new(interconnect.new_network_sender()));
        let handler = NetworkTestHandler::new(send);
        dispatcher.set_handler(Box::new(handler));
        let network_dispatch_loop = DispatchLoopBuilder::new()
            .with_dispatcher(dispatcher)
            .with_thread_name("NetworkDispatchLoop".to_string())
            .with_dispatch_channel((dispatcher_sender, dispatcher_receiver))
            .build()
            .expect("Unable to create network dispatch loop");
        let dispatch_shutdown = network_dispatch_loop.shutdown_signaler();
        let (notification_tx, notification_rx): (
            Sender<PeerManagerNotification>,
            mpsc::Receiver<PeerManagerNotification>,
        ) = channel();
        peer_connector
            .subscribe_sender(notification_tx)
            .expect("Unable to get subscriber");
        // Request a peer "test_peer" with endpoint list ["test"]; NoopAuthorizer
        // reports "test_peer" as the identity of any accepted connection.
        let peer_ref = peer_connector
            .add_peer_ref("test_peer".to_string(), vec!["test".to_string()])
            .expect("Unable to add peer");
        assert_eq!(peer_ref.peer_id(), "test_peer");
        let timeout = Duration::from_secs(60);
        let notification = notification_rx
            .recv_timeout(timeout)
            .expect("Unable to get new notifications");
        assert_eq!(
            notification,
            PeerManagerNotification::Connected {
                peer: "test_peer".to_string()
            }
        );
        // Wait for the handler to see "shutdown_string" (sent after the echo round trip).
        let test_timeout = std::time::Duration::from_secs(60);
        recv.recv_timeout(test_timeout)
            .expect("Failed to receive message");
        // Release the remote thread so it can shut mesh2 down.
        tx.send(()).unwrap();
        peer_manager.shutdown_signaler().shutdown();
        cm.shutdown_signaler().shutdown();
        peer_manager.await_shutdown();
        cm.await_shutdown();
        dispatch_shutdown.shutdown();
        // mesh1 is shut down before awaiting the interconnect: the receive
        // thread only exits when the mesh receiver reports shutdown.
        mesh1.shutdown_signaler().shutdown();
        interconnect.shutdown_signaler().shutdown();
        interconnect.await_shutdown();
    }

    /// Builds a PeerInterconnect with no traffic and checks that shutting down the
    /// surrounding components plus the interconnect lets `await_shutdown` return.
    #[test]
    fn test_peer_interconnect_shutdown() {
        let transport = Box::new(InprocTransport::default());
        let mesh = Mesh::new(512, 128);
        let cm = ConnectionManager::builder()
            .with_authorizer(Box::new(NoopAuthorizer::new("test_peer")))
            .with_matrix_life_cycle(mesh.get_life_cycle())
            .with_matrix_sender(mesh.get_sender())
            .with_transport(transport)
            .start()
            .expect("Unable to start Connection Manager");
        let connector = cm.connector();
        let peer_manager = PeerManager::builder()
            .with_connector(connector)
            .with_retry_interval(1)
            .with_identity("my_id".to_string())
            .with_strict_ref_counts(true)
            .start()
            .expect("Cannot start peer_manager");
        let peer_connector = peer_manager.connector();
        let (dispatcher_sender, _dispatched_receiver) = dispatch_channel();
        let interconnect = PeerInterconnectBuilder::new()
            .with_peer_connector(peer_connector)
            .with_message_receiver(mesh.get_receiver())
            .with_message_sender(mesh.get_sender())
            .with_network_dispatcher_sender(dispatcher_sender)
            .build()
            .expect("Unable to build PeerInterconnect");
        peer_manager.shutdown_signaler().shutdown();
        cm.shutdown_signaler().shutdown();
        peer_manager.await_shutdown();
        cm.await_shutdown();
        mesh.shutdown_signaler().shutdown();
        interconnect.shutdown_signaler().shutdown();
        interconnect.await_shutdown();
    }

    /// Marker sent by `NetworkTestHandler` when it receives "shutdown_string".
    struct Shutdown {}

    /// Echo handler used by `test_peer_interconnect`.
    ///
    /// A payload of "shutdown_string" sends a `Shutdown` marker on
    /// `shutdown_sender`. Any other echo payload is checked to come from
    /// "test_peer" and is sent back to its source wrapped in a
    /// `NETWORK_ECHO` `NetworkMessage`.
    struct NetworkTestHandler {
        shutdown_sender: Sender<Shutdown>,
    }

    impl NetworkTestHandler {
        fn new(shutdown_sender: Sender<Shutdown>) -> Self {
            NetworkTestHandler { shutdown_sender }
        }
    }

    impl Handler for NetworkTestHandler {
        type Source = PeerId;
        type MessageType = NetworkMessageType;
        type Message = NetworkEcho;

        fn match_type(&self) -> Self::MessageType {
            NetworkMessageType::NETWORK_ECHO
        }

        fn handle(
            &self,
            message: NetworkEcho,
            message_context: &MessageContext<Self::Source, NetworkMessageType>,
            network_sender: &dyn MessageSender<Self::Source>,
        ) -> Result<(), DispatchError> {
            let echo_string = String::from_utf8(message.get_payload().to_vec()).unwrap();
            if &echo_string == "shutdown_string" {
                self.shutdown_sender
                    .send(Shutdown {})
                    .expect("Cannot send shutdown");
            } else {
                assert_eq!(message_context.source_peer_id(), "test_peer");
                // Re-wrap the received echo unchanged and send it back to the sender.
                let echo_bytes = message.write_to_bytes().unwrap();
                let mut network_msg = NetworkMessage::new();
                network_msg.set_message_type(NetworkMessageType::NETWORK_ECHO);
                network_msg.set_payload(echo_bytes);
                let network_msg_bytes = network_msg.write_to_bytes().unwrap();
                network_sender
                    .send(message_context.source_id().clone(), network_msg_bytes)
                    .expect("Cannot send message");
            }
            Ok(())
        }
    }

    /// Serializes `echo_bytes` as a `NetworkEcho` payload inside a
    /// `NETWORK_ECHO` `NetworkMessage`.
    fn echo_to_network_message_bytes(echo_bytes: Vec<u8>) -> Vec<u8> {
        let mut echo_message = NetworkEcho::new();
        echo_message.set_payload(echo_bytes);
        let echo_message_bytes = echo_message.write_to_bytes().unwrap();
        let mut network_message = NetworkMessage::new();
        network_message.set_message_type(NetworkMessageType::NETWORK_ECHO);
        network_message.set_payload(echo_message_bytes);
        network_message.write_to_bytes().unwrap()
    }

    /// Authorizer that immediately authorizes every connection, reporting
    /// `authorized_id` as the remote identity.
    struct NoopAuthorizer {
        authorized_id: String,
    }

    impl NoopAuthorizer {
        fn new(id: &str) -> Self {
            Self {
                authorized_id: id.to_string(),
            }
        }
    }

    impl Authorizer for NoopAuthorizer {
        fn authorize_connection(
            &self,
            connection_id: String,
            connection: Box<dyn Connection>,
            callback: Box<
                dyn Fn(AuthorizationResult) -> Result<(), Box<dyn std::error::Error>> + Send,
            >,
        ) -> Result<(), AuthorizerError> {
            // Invoke the callback synchronously with an `Authorized` result.
            (*callback)(AuthorizationResult::Authorized {
                connection_id,
                connection,
                identity: self.authorized_id.clone(),
            })
            .map_err(|err| AuthorizerError(format!("Unable to return result: {}", err)))
        }
    }
}