use crate::circuit::handlers::create_message;
use crate::circuit::routing::{RoutingTableReader, ServiceId};
use crate::network::dispatch::{DispatchError, Handler, MessageContext, MessageSender, PeerId};
use crate::peer::PeerTokenPair;
use crate::protos::circuit::{
CircuitDirectMessage, CircuitError, CircuitError_Error, CircuitMessageType,
};
use protobuf::Message;
pub struct CircuitDirectMessageHandler {
node_id: String,
routing_table: Box<dyn RoutingTableReader>,
}
impl Handler for CircuitDirectMessageHandler {
type Source = PeerId;
type MessageType = CircuitMessageType;
type Message = CircuitDirectMessage;
fn match_type(&self) -> Self::MessageType {
CircuitMessageType::CIRCUIT_DIRECT_MESSAGE
}
fn handle(
&self,
msg: Self::Message,
context: &MessageContext<Self::Source, Self::MessageType>,
sender: &dyn MessageSender<Self::Source>,
) -> Result<(), DispatchError> {
debug!(
"Handle Circuit Direct Message {}on {} ({} => {}) [{} byte{}]",
if msg.get_correlation_id().is_empty() {
"".to_string()
} else {
format!("{} ", msg.get_correlation_id())
},
msg.get_circuit(),
msg.get_sender(),
msg.get_recipient(),
msg.get_payload().len(),
if msg.get_payload().len() == 1 {
""
} else {
"s"
}
);
let circuit_name = msg.get_circuit();
let msg_sender = msg.get_sender();
let recipient = msg.get_recipient();
let recipient_id = ServiceId::new(circuit_name.to_string(), recipient.to_string());
let (msg_bytes, msg_recipient) = {
if let Some(circuit) = self
.routing_table
.get_circuit(circuit_name)
.map_err(|err| DispatchError::HandleError(err.to_string()))?
{
if !circuit
.roster()
.iter()
.any(|service| service.service_id() == msg_sender)
{
let mut error_message = CircuitError::new();
error_message.set_correlation_id(msg.get_correlation_id().to_string());
error_message.set_service_id(msg_sender.into());
error_message.set_circuit_name(circuit_name.into());
error_message.set_error(CircuitError_Error::ERROR_SENDER_NOT_IN_CIRCUIT_ROSTER);
error_message.set_error_message(format!(
"Sender is not allowed in the Circuit: {}",
msg_sender
));
let msg_bytes = error_message.write_to_bytes()?;
let network_msg_bytes =
create_message(msg_bytes, CircuitMessageType::CIRCUIT_ERROR_MESSAGE)?;
(network_msg_bytes, context.source_peer_id().clone())
} else if circuit
.roster()
.iter()
.any(|service| service.service_id() == recipient)
{
if let Some(service) = self
.routing_table
.get_service(&recipient_id)
.map_err(|err| DispatchError::HandleError(err.to_string()))?
{
let node_id = service.node_id().to_string();
let msg_bytes = context.message_bytes().to_vec();
let network_msg_bytes =
create_message(msg_bytes, CircuitMessageType::CIRCUIT_DIRECT_MESSAGE)?;
if node_id != self.node_id {
let node_peer_id: PeerId = {
let peer_id = self
.routing_table
.get_node(&node_id)
.map_err(|err| DispatchError::HandleError(err.to_string()))?
.ok_or_else(|| {
DispatchError::HandleError(format!(
"Node {} not in routing table",
node_id
))
})?
.get_peer_auth_token(circuit.authorization_type())
.map_err(|err| DispatchError::HandleError(err.to_string()))?;
let local_peer_id = self
.routing_table
.get_node(&self.node_id)
.map_err(|err| DispatchError::HandleError(err.to_string()))?
.ok_or_else(|| {
DispatchError::HandleError(format!(
"Local Node {} not in routing table",
node_id
))
})?
.get_peer_auth_token(circuit.authorization_type())
.map_err(|err| DispatchError::HandleError(err.to_string()))?;
PeerTokenPair::new(peer_id, local_peer_id)
}
.into();
(network_msg_bytes, node_peer_id)
} else {
let peer_id: PeerId = match service.local_peer_id() {
Some(peer_id) => peer_id.clone().into(),
None => {
warn!("No peer id for service:{} ", service.service_id());
return Ok(());
}
};
(network_msg_bytes, peer_id)
}
} else {
let mut error_message = CircuitError::new();
error_message.set_correlation_id(msg.get_correlation_id().to_string());
error_message.set_service_id(msg_sender.into());
error_message.set_circuit_name(circuit_name.into());
error_message
.set_error(CircuitError_Error::ERROR_RECIPIENT_NOT_IN_DIRECTORY);
error_message.set_error_message(format!(
"Recipient is not in the service directory: {}",
recipient
));
let msg_bytes = error_message.write_to_bytes()?;
let network_msg_bytes =
create_message(msg_bytes, CircuitMessageType::CIRCUIT_ERROR_MESSAGE)?;
(network_msg_bytes, context.source_peer_id().clone())
}
} else {
let mut error_message = CircuitError::new();
error_message.set_correlation_id(msg.get_correlation_id().to_string());
error_message.set_service_id(msg_sender.into());
error_message.set_circuit_name(circuit_name.into());
error_message
.set_error(CircuitError_Error::ERROR_RECIPIENT_NOT_IN_CIRCUIT_ROSTER);
error_message.set_error_message(format!(
"Recipient is not allowed in the Circuit: {}",
recipient
));
let msg_bytes = error_message.write_to_bytes()?;
let network_msg_bytes =
create_message(msg_bytes, CircuitMessageType::CIRCUIT_ERROR_MESSAGE)?;
(network_msg_bytes, context.source_peer_id().clone())
}
} else {
let mut error_message = CircuitError::new();
error_message.set_correlation_id(msg.get_correlation_id().into());
error_message.set_service_id(msg_sender.into());
error_message.set_circuit_name(circuit_name.into());
error_message.set_error(CircuitError_Error::ERROR_CIRCUIT_DOES_NOT_EXIST);
error_message
.set_error_message(format!("Circuit does not exist: {}", circuit_name));
let msg_bytes = error_message.write_to_bytes()?;
let network_msg_bytes =
create_message(msg_bytes, CircuitMessageType::CIRCUIT_ERROR_MESSAGE)?;
(network_msg_bytes, context.source_peer_id().clone())
}
};
sender
.send(msg_recipient, msg_bytes)
.map_err(|(recipient, payload)| {
DispatchError::NetworkSendError((recipient.into(), payload))
})?;
Ok(())
}
}
impl CircuitDirectMessageHandler {
pub fn new(node_id: String, routing_table: Box<dyn RoutingTableReader>) -> Self {
CircuitDirectMessageHandler {
node_id,
routing_table,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use crate::circuit::routing::AuthorizationType;
use crate::circuit::routing::{
memory::RoutingTable, Circuit, CircuitNode, RoutingTableWriter, Service,
};
use crate::network::dispatch::Dispatcher;
use crate::peer::PeerAuthorizationToken;
use crate::protos::circuit::CircuitMessage;
use crate::protos::network::NetworkMessage;
#[test]
fn test_circuit_direct_message_handler_service() {
let mock_sender = MockSender::new();
let mut dispatcher = Dispatcher::new(Box::new(mock_sender.clone()));
let table = RoutingTable::default();
let reader: Box<dyn RoutingTableReader> = Box::new(table.clone());
let mut writer: Box<dyn RoutingTableWriter> = Box::new(table.clone());
let node_123 = CircuitNode::new("123".to_string(), vec!["123.0.0.1:0".to_string()], None);
let node_345 = CircuitNode::new("345".to_string(), vec!["123.0.0.1:1".to_string()], None);
let mut service_abc = Service::new(
"abc".to_string(),
"test".to_string(),
"123".to_string(),
vec![],
);
let mut service_def = Service::new(
"def".to_string(),
"test".to_string(),
"345".to_string(),
vec![],
);
service_abc.set_local_peer_id(PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("abc_network"),
PeerAuthorizationToken::from_peer_id("123"),
));
service_def.set_local_peer_id(PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def_network"),
PeerAuthorizationToken::from_peer_id("345"),
));
let circuit = Circuit::new(
"alpha".into(),
vec![service_abc.clone(), service_def.clone()],
vec!["123".into(), "345".into()],
AuthorizationType::Trust,
);
writer
.add_circuit(
circuit.circuit_id().into(),
circuit,
vec![node_123, node_345],
)
.expect("Unable to add circuits");
let handler = CircuitDirectMessageHandler::new("123".to_string(), reader);
dispatcher.set_handler(Box::new(handler));
let mut direct_message = CircuitDirectMessage::new();
direct_message.set_circuit("alpha".into());
direct_message.set_sender("def".into());
direct_message.set_recipient("abc".into());
direct_message.set_payload(b"test".to_vec());
direct_message.set_correlation_id("1234".into());
let direct_bytes = direct_message.write_to_bytes().unwrap();
dispatcher
.dispatch(
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def"),
PeerAuthorizationToken::from_peer_id("345"),
)
.into(),
&CircuitMessageType::CIRCUIT_DIRECT_MESSAGE,
direct_bytes.clone(),
)
.unwrap();
let (id, message) = mock_sender.next_outbound().expect("No message was sent");
assert_network_message(
message,
id.into(),
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("abc_network"),
PeerAuthorizationToken::from_peer_id("123"),
),
CircuitMessageType::CIRCUIT_DIRECT_MESSAGE,
|msg: CircuitDirectMessage| {
assert_eq!(msg.get_sender(), "def");
assert_eq!(msg.get_circuit(), "alpha");
assert_eq!(msg.get_recipient(), "abc");
assert_eq!(msg.get_payload().to_vec(), b"test".to_vec());
assert_eq!(msg.get_correlation_id(), "1234");
},
)
}
#[test]
fn test_circuit_direct_message_handler_node() {
let mock_sender = MockSender::new();
let mut dispatcher = Dispatcher::new(Box::new(mock_sender.clone()));
let table = RoutingTable::default();
let reader: Box<dyn RoutingTableReader> = Box::new(table.clone());
let mut writer: Box<dyn RoutingTableWriter> = Box::new(table.clone());
let node_123 = CircuitNode::new("123".to_string(), vec!["123.0.0.1:0".to_string()], None);
let node_345 = CircuitNode::new("345".to_string(), vec!["123.0.0.1:1".to_string()], None);
let mut service_abc = Service::new(
"abc".to_string(),
"test".to_string(),
"123".to_string(),
vec![],
);
let mut service_def = Service::new(
"def".to_string(),
"test".to_string(),
"345".to_string(),
vec![],
);
service_abc.set_local_peer_id(PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("abc_network"),
PeerAuthorizationToken::from_peer_id("123"),
));
service_def.set_local_peer_id(PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def_network"),
PeerAuthorizationToken::from_peer_id("345"),
));
let circuit = Circuit::new(
"alpha".into(),
vec![service_abc.clone(), service_def.clone()],
vec!["123".into(), "345".into()],
AuthorizationType::Trust,
);
writer
.add_circuit(
circuit.circuit_id().into(),
circuit,
vec![node_123, node_345],
)
.expect("Unable to add circuits");
let handler = CircuitDirectMessageHandler::new("345".to_string(), reader);
dispatcher.set_handler(Box::new(handler));
let mut direct_message = CircuitDirectMessage::new();
direct_message.set_circuit("alpha".into());
direct_message.set_sender("def".into());
direct_message.set_recipient("abc".into());
direct_message.set_payload(b"test".to_vec());
direct_message.set_correlation_id("1234".into());
let direct_bytes = direct_message.write_to_bytes().unwrap();
dispatcher
.dispatch(
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def"),
PeerAuthorizationToken::from_peer_id("345"),
)
.into(),
&CircuitMessageType::CIRCUIT_DIRECT_MESSAGE,
direct_bytes.clone(),
)
.unwrap();
let (id, message) = mock_sender.next_outbound().expect("No message was sent");
assert_network_message(
message,
id.into(),
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("123"),
PeerAuthorizationToken::from_peer_id("345"),
),
CircuitMessageType::CIRCUIT_DIRECT_MESSAGE,
|msg: CircuitDirectMessage| {
assert_eq!(msg.get_sender(), "def");
assert_eq!(msg.get_circuit(), "alpha");
assert_eq!(msg.get_recipient(), "abc");
assert_eq!(msg.get_payload().to_vec(), b"test".to_vec());
assert_eq!(msg.get_correlation_id(), "1234");
},
)
}
#[test]
fn test_circuit_direct_message_handler_sender_not_in_circuit_roster() {
let mock_sender = MockSender::new();
let mut dispatcher = Dispatcher::new(Box::new(mock_sender.clone()));
let table = RoutingTable::default();
let reader: Box<dyn RoutingTableReader> = Box::new(table.clone());
let mut writer: Box<dyn RoutingTableWriter> = Box::new(table.clone());
let node_123 = CircuitNode::new("123".to_string(), vec!["123.0.0.1:0".to_string()], None);
let node_345 = CircuitNode::new("345".to_string(), vec!["123.0.0.1:1".to_string()], None);
let mut service_abc = Service::new(
"abc".to_string(),
"test".to_string(),
"123".to_string(),
vec![],
);
service_abc.set_local_peer_id(PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("abc_network"),
PeerAuthorizationToken::from_peer_id("123"),
));
let circuit = Circuit::new(
"alpha".into(),
vec![service_abc.clone()],
vec!["123".into(), "345".into()],
AuthorizationType::Trust,
);
writer
.add_circuit(
circuit.circuit_id().into(),
circuit,
vec![node_123, node_345],
)
.expect("Unable to add circuits");
let handler = CircuitDirectMessageHandler::new("123".to_string(), reader);
dispatcher.set_handler(Box::new(handler));
let mut direct_message = CircuitDirectMessage::new();
direct_message.set_circuit("alpha".into());
direct_message.set_sender("def".into());
direct_message.set_recipient("abc".into());
direct_message.set_payload(b"test".to_vec());
direct_message.set_correlation_id("1234".into());
let direct_bytes = direct_message.write_to_bytes().unwrap();
dispatcher
.dispatch(
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def"),
PeerAuthorizationToken::from_peer_id("345"),
)
.into(),
&CircuitMessageType::CIRCUIT_DIRECT_MESSAGE,
direct_bytes.clone(),
)
.unwrap();
let (id, message) = mock_sender.next_outbound().expect("No message was sent");
assert_network_message(
message,
id.into(),
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def"),
PeerAuthorizationToken::from_peer_id("345"),
),
CircuitMessageType::CIRCUIT_ERROR_MESSAGE,
|msg: CircuitError| {
assert_eq!(msg.get_service_id(), "def");
assert_eq!(
msg.get_error(),
CircuitError_Error::ERROR_SENDER_NOT_IN_CIRCUIT_ROSTER
);
assert_eq!(msg.get_correlation_id(), "1234");
},
)
}
#[test]
fn test_circuit_direct_message_handler_recipient_not_in_circuit_roster() {
let mock_sender = MockSender::new();
let mut dispatcher = Dispatcher::new(Box::new(mock_sender.clone()));
let table = RoutingTable::default();
let reader: Box<dyn RoutingTableReader> = Box::new(table.clone());
let mut writer: Box<dyn RoutingTableWriter> = Box::new(table.clone());
let node_123 = CircuitNode::new("123".to_string(), vec!["123.0.0.1:0".to_string()], None);
let node_345 = CircuitNode::new("345".to_string(), vec!["123.0.0.1:1".to_string()], None);
let mut service_def = Service::new(
"def".to_string(),
"test".to_string(),
"345".to_string(),
vec![],
);
service_def.set_local_peer_id(PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def_network"),
PeerAuthorizationToken::from_peer_id("345"),
));
let circuit = Circuit::new(
"alpha".into(),
vec![service_def.clone()],
vec!["123".into(), "345".into()],
AuthorizationType::Trust,
);
writer
.add_circuit(
circuit.circuit_id().into(),
circuit,
vec![node_123, node_345],
)
.expect("Unable to add circuits");
let handler = CircuitDirectMessageHandler::new("345".to_string(), reader);
dispatcher.set_handler(Box::new(handler));
let mut direct_message = CircuitDirectMessage::new();
direct_message.set_circuit("alpha".into());
direct_message.set_sender("def".into());
direct_message.set_recipient("abc".into());
direct_message.set_payload(b"test".to_vec());
direct_message.set_correlation_id("1234".into());
let direct_bytes = direct_message.write_to_bytes().unwrap();
dispatcher
.dispatch(
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def"),
PeerAuthorizationToken::from_peer_id("345"),
)
.into(),
&CircuitMessageType::CIRCUIT_DIRECT_MESSAGE,
direct_bytes.clone(),
)
.unwrap();
let (id, message) = mock_sender.next_outbound().expect("No message was sent");
assert_network_message(
message,
id.into(),
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def"),
PeerAuthorizationToken::from_peer_id("345"),
),
CircuitMessageType::CIRCUIT_ERROR_MESSAGE,
|msg: CircuitError| {
assert_eq!(msg.get_service_id(), "def");
assert_eq!(
msg.get_error(),
CircuitError_Error::ERROR_RECIPIENT_NOT_IN_CIRCUIT_ROSTER
);
assert_eq!(msg.get_correlation_id(), "1234");
},
)
}
#[test]
fn test_circuit_direct_message_handler_no_circuit() {
let mock_sender = MockSender::new();
let mut dispatcher = Dispatcher::new(Box::new(mock_sender.clone()));
let table = RoutingTable::default();
let reader: Box<dyn RoutingTableReader> = Box::new(table.clone());
let handler = CircuitDirectMessageHandler::new("345".to_string(), reader);
dispatcher.set_handler(Box::new(handler));
let mut direct_message = CircuitDirectMessage::new();
direct_message.set_circuit("alpha".into());
direct_message.set_sender("def".into());
direct_message.set_recipient("abc".into());
direct_message.set_payload(b"test".to_vec());
direct_message.set_correlation_id("1234".into());
let direct_bytes = direct_message.write_to_bytes().unwrap();
dispatcher
.dispatch(
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def"),
PeerAuthorizationToken::from_peer_id("345"),
)
.into(),
&CircuitMessageType::CIRCUIT_DIRECT_MESSAGE,
direct_bytes.clone(),
)
.unwrap();
let (id, message) = mock_sender.next_outbound().expect("No message was sent");
assert_network_message(
message,
id.into(),
PeerTokenPair::new(
PeerAuthorizationToken::from_peer_id("def"),
PeerAuthorizationToken::from_peer_id("345"),
),
CircuitMessageType::CIRCUIT_ERROR_MESSAGE,
|msg: CircuitError| {
assert_eq!(msg.get_service_id(), "def");
assert_eq!(
msg.get_error(),
CircuitError_Error::ERROR_CIRCUIT_DOES_NOT_EXIST
);
assert_eq!(msg.get_correlation_id(), "1234");
},
)
}
fn assert_network_message<M: protobuf::Message, F: Fn(M)>(
message: Vec<u8>,
recipient: PeerTokenPair,
expected_recipient: PeerTokenPair,
expected_circuit_msg_type: CircuitMessageType,
detail_assertions: F,
) {
assert_eq!(expected_recipient, recipient);
let network_msg: NetworkMessage = Message::parse_from_bytes(&message).unwrap();
let circuit_msg: CircuitMessage =
Message::parse_from_bytes(network_msg.get_payload()).unwrap();
assert_eq!(expected_circuit_msg_type, circuit_msg.get_message_type(),);
let circuit_msg: M = Message::parse_from_bytes(circuit_msg.get_payload()).unwrap();
detail_assertions(circuit_msg);
}
#[derive(Clone)]
struct MockSender {
outbound: Arc<Mutex<VecDeque<(PeerId, Vec<u8>)>>>,
}
impl MockSender {
fn new() -> Self {
Self {
outbound: Arc::new(Mutex::new(VecDeque::new())),
}
}
fn next_outbound(&self) -> Option<(PeerId, Vec<u8>)> {
self.outbound.lock().expect("lock was poisoned").pop_front()
}
}
impl MessageSender<PeerId> for MockSender {
fn send(&self, id: PeerId, message: Vec<u8>) -> Result<(), (PeerId, Vec<u8>)> {
self.outbound
.lock()
.expect("lock was poisoned")
.push_back((id, message));
Ok(())
}
}
}