use std::convert::TryInto;
use bitflags::bitflags;
use bytes::{Bytes, BytesMut};
use log::*;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use super::{ProtocolError, ProtocolId};
/// Upper bound on the number of proposals the inbound side will answer before giving up.
const MAX_ROUNDS_ALLOWED: u8 = 5;
/// Size of the frame read buffer. A frame's length prefix is a single byte, so no payload can
/// be longer than `u8::MAX` bytes.
const BUF_CAPACITY: usize = u8::MAX as usize;
/// `log` target used for frame-level trace output.
const LOG_TARGET: &str = "comms::connection_manager::protocol";
/// Drives the protocol-negotiation handshake over a borrowed socket.
///
/// Every frame on the wire is laid out as `[len: u8][flags: u8][payload; len]`.
pub struct ProtocolNegotiation<'a, TSocket> {
    /// Borrowed transport that frames are read from and written to.
    socket: &'a mut TSocket,
    /// Scratch buffer that incoming frame payloads are read into.
    buf: BytesMut,
}
bitflags! {
    /// Control bits carried in the second byte of every negotiation frame.
    #[derive(Debug)]
    struct Flags: u8 {
        /// No control bits set: a plain protocol proposal or acceptance.
        const NONE = 0;
        /// The sender has already committed to the protocol and does not wait for a reply.
        const OPTIMISTIC = 1 << 0;
        /// The sender is ending the negotiation.
        const TERMINATE = 1 << 1;
        /// The replying side does not support the protocol that was proposed.
        const NOT_SUPPORTED = 1 << 2;
    }
}
impl<'a, TSocket> ProtocolNegotiation<'a, TSocket>
where TSocket: AsyncRead + AsyncWrite + Unpin
{
    /// Creates a negotiator that exchanges frames over `socket`.
    ///
    /// The read buffer is allocated and zero-filled to `BUF_CAPACITY` bytes up front. That is the
    /// largest payload a single-byte length prefix can describe.
    pub fn new(socket: &'a mut TSocket) -> Self {
        let mut buf = BytesMut::with_capacity(BUF_CAPACITY);
        buf.resize(BUF_CAPACITY, 0);
        Self { socket, buf }
    }

    /// Negotiates a protocol as the initiating side.
    ///
    /// Each entry of `selected_protocols` is proposed in order. The first one the peer echoes
    /// back is returned. If every proposal is declined, a `TERMINATE` frame is sent to the peer.
    ///
    /// # Errors
    /// - `ProtocolNegotiationTerminatedByPeer` if the peer replies with the `TERMINATE` flag.
    /// - `ProtocolOutboundNegotiationFailed` if no proposal was accepted.
    /// - Any error raised while reading or writing a frame.
    pub async fn negotiate_protocol_outbound(
        &mut self,
        selected_protocols: &[ProtocolId],
    ) -> Result<ProtocolId, ProtocolError> {
        for protocol in selected_protocols {
            self.write_frame_flush(protocol, Flags::NONE).await?;
            let (proto, flags) = self.read_frame().await?;
            if flags.contains(Flags::TERMINATE) {
                return Err(ProtocolError::ProtocolNegotiationTerminatedByPeer);
            }
            if flags.contains(Flags::NOT_SUPPORTED) {
                continue;
            }
            if proto.as_ref() == protocol {
                return Ok(protocol.clone());
            }
            // NOTE(review): a flagless reply that does not echo `protocol` is a peer protocol
            // violation. It is currently treated like NOT_SUPPORTED and the next proposal is tried.
        }

        self.write_frame_flush(&[], Flags::TERMINATE).await?;
        Err(ProtocolError::ProtocolOutboundNegotiationFailed {
            protocols: selected_protocols
                .iter()
                .map(|p| String::from_utf8_lossy(p).into_owned())
                .collect::<Vec<_>>()
                .join(", "),
        })
    }

    /// Commits to `protocol` without waiting for the peer's answer.
    ///
    /// Sends one frame flagged `OPTIMISTIC | TERMINATE` and returns `protocol` immediately.
    /// This side does not learn whether the peer actually supports it.
    pub async fn negotiate_protocol_outbound_optimistic(
        &mut self,
        protocol: &ProtocolId,
    ) -> Result<ProtocolId, ProtocolError> {
        self.write_frame_flush(protocol, Flags::OPTIMISTIC | Flags::TERMINATE)
            .await?;
        Ok(protocol.clone())
    }

    /// Negotiates a protocol as the responding side.
    ///
    /// Reads up to `MAX_ROUNDS_ALLOWED` proposals:
    /// - An `OPTIMISTIC` proposal is accepted or rejected locally, and no reply is sent.
    /// - A supported proposal is echoed back and returned.
    /// - Any other proposal is answered with `NOT_SUPPORTED`. On the final round the answer also
    ///   carries `TERMINATE`.
    ///
    /// # Errors
    /// - `ProtocolOptimisticNegotiationFailed` if an optimistic proposal is unsupported.
    /// - `ProtocolNegotiationTerminatedByPeer` if the peer sends `TERMINATE`.
    /// - `ProtocolInboundNegotiationFailed` once the round limit is exhausted.
    /// - Any error raised while reading or writing a frame.
    pub async fn negotiate_protocol_inbound(
        &mut self,
        supported_protocols: &[ProtocolId],
    ) -> Result<ProtocolId, ProtocolError> {
        for round in 0..MAX_ROUNDS_ALLOWED {
            let (proto, flags) = self.read_frame().await?;

            // Optimistic proposals are checked first: the peer also sets TERMINATE on them.
            if flags.contains(Flags::OPTIMISTIC) {
                return if supported_protocols.contains(&proto) {
                    Ok(proto)
                } else {
                    Err(ProtocolError::ProtocolOptimisticNegotiationFailed)
                };
            }
            if flags.contains(Flags::TERMINATE) {
                return Err(ProtocolError::ProtocolNegotiationTerminatedByPeer);
            }

            if let Some(selected) = supported_protocols.iter().find(|p| proto == p) {
                self.write_frame_flush(selected, Flags::NONE).await?;
                return Ok(selected.clone());
            }

            // On the last allowed round, tell the peer to stop as well as declining.
            let reply_flags = if round + 1 == MAX_ROUNDS_ALLOWED {
                Flags::NOT_SUPPORTED | Flags::TERMINATE
            } else {
                Flags::NOT_SUPPORTED
            };
            self.write_frame_flush(&[], reply_flags).await?;
        }

        Err(ProtocolError::ProtocolInboundNegotiationFailed)
    }

    /// Reads one `[len: u8][flags: u8][payload; len]` frame and returns `(payload, flags)`.
    ///
    /// # Errors
    /// - `InvalidFlag` if the flags byte contains bits that are not defined in `Flags`.
    /// - Any I/O error from the socket.
    async fn read_frame(&mut self) -> Result<(Bytes, Flags), ProtocolError> {
        let mut header = [0u8; 2];
        self.socket.read_exact(&mut header).await?;
        let [len, raw_flags] = header;
        let len = usize::from(len);

        // Build the error message lazily so valid frames do not allocate a String.
        let flags = Flags::from_bits(raw_flags).ok_or_else(|| {
            ProtocolError::InvalidFlag(format!("Does not match any flags ({})", raw_flags))
        })?;

        // `len` is at most u8::MAX == BUF_CAPACITY, so this slice always exists.
        let payload = self.buf.get_mut(..len).ok_or(ProtocolError::ExpectedReadyBytes)?;
        self.socket.read_exact(payload).await?;
        let frame = Bytes::copy_from_slice(payload);

        trace!(
            target: LOG_TARGET,
            "Read frame '{}' ({} byte(s) Flags={:?})",
            String::from_utf8_lossy(&frame),
            len,
            flags,
        );
        Ok((frame, flags))
    }

    /// Writes one `[len: u8][flags: u8][payload; len]` frame and flushes the socket.
    ///
    /// The frame is assembled into one buffer so it goes out in a single `write_all` rather than
    /// three small writes.
    ///
    /// # Errors
    /// - `ProtocolIdTooLong` if `protocol` is longer than `u8::MAX` bytes.
    /// - Any I/O error from the socket.
    async fn write_frame_flush(&mut self, protocol: &[u8], flags: Flags) -> Result<(), ProtocolError> {
        let len: u8 = protocol
            .len()
            .try_into()
            .map_err(|_| ProtocolError::ProtocolIdTooLong)?;

        let mut frame = Vec::with_capacity(2 + protocol.len());
        frame.push(len);
        frame.push(flags.bits());
        frame.extend_from_slice(protocol);

        self.socket.write_all(&frame).await?;
        self.socket.flush().await?;
        trace!(
            target: LOG_TARGET,
            "Wrote frame '{}' ({} byte(s) Flags={:?})",
            String::from_utf8_lossy(protocol),
            len,
            flags
        );
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use futures::future;
    use tari_test_utils::unpack_enum;

    use super::*;
    use crate::memsocket::MemorySocket;

    /// Turns a list of static byte-string literals into `ProtocolId`s, preserving their order.
    fn protocol_ids(ids: &[&'static [u8]]) -> Vec<ProtocolId> {
        ids.iter().copied().map(ProtocolId::from_static).collect()
    }

    #[tokio::test]
    async fn negotiate_success() {
        let (mut outbound_socket, mut inbound_socket) = MemorySocket::new_pair();
        let mut outbound = ProtocolNegotiation::new(&mut outbound_socket);
        let mut inbound = ProtocolNegotiation::new(&mut inbound_socket);
        let supported = protocol_ids(&[b"B", b"A"]);
        let selected = protocol_ids(&[b"C", b"D", b"E", b"F", b"A"]);

        let (inbound_result, outbound_result) = future::join(
            inbound.negotiate_protocol_inbound(&supported),
            outbound.negotiate_protocol_outbound(&selected),
        )
        .await;

        assert_eq!(inbound_result.unwrap(), ProtocolId::from_static(b"A"));
        assert_eq!(outbound_result.unwrap(), ProtocolId::from_static(b"A"));
    }

    #[tokio::test]
    async fn negotiate_fail() {
        let (mut outbound_socket, mut inbound_socket) = MemorySocket::new_pair();
        let mut outbound = ProtocolNegotiation::new(&mut outbound_socket);
        let mut inbound = ProtocolNegotiation::new(&mut inbound_socket);
        let supported = protocol_ids(&[b"A", b"B"]);
        let selected = protocol_ids(&[b"C", b"D", b"E"]);

        let (inbound_result, outbound_result) = future::join(
            inbound.negotiate_protocol_inbound(&supported),
            outbound.negotiate_protocol_outbound(&selected),
        )
        .await;

        unpack_enum!(ProtocolError::ProtocolNegotiationTerminatedByPeer = inbound_result.unwrap_err());
        unpack_enum!(ProtocolError::ProtocolOutboundNegotiationFailed { .. } = outbound_result.unwrap_err());
    }

    #[tokio::test]
    async fn negotiate_fail_max_rounds() {
        let (mut outbound_socket, mut inbound_socket) = MemorySocket::new_pair();
        let mut outbound = ProtocolNegotiation::new(&mut outbound_socket);
        let mut inbound = ProtocolNegotiation::new(&mut inbound_socket);
        let supported = protocol_ids(&[b"A", b"B"]);
        // "A" sits beyond the inbound round limit, so it is never reached.
        let selected = protocol_ids(&[b"C", b"D", b"E", b"F", b"G", b"A"]);

        let (inbound_result, outbound_result) = future::join(
            inbound.negotiate_protocol_inbound(&supported),
            outbound.negotiate_protocol_outbound(&selected),
        )
        .await;

        unpack_enum!(ProtocolError::ProtocolInboundNegotiationFailed = inbound_result.unwrap_err());
        unpack_enum!(ProtocolError::ProtocolNegotiationTerminatedByPeer = outbound_result.unwrap_err());
    }

    #[tokio::test]
    async fn negotiate_success_optimistic() {
        let (mut outbound_socket, mut inbound_socket) = MemorySocket::new_pair();
        let mut outbound = ProtocolNegotiation::new(&mut outbound_socket);
        let mut inbound = ProtocolNegotiation::new(&mut inbound_socket);
        let supported = protocol_ids(&[b"B", b"A"]);
        let proposal = ProtocolId::from_static(b"A");

        let (inbound_result, outbound_result) = future::join(
            inbound.negotiate_protocol_inbound(&supported),
            outbound.negotiate_protocol_outbound_optimistic(&proposal),
        )
        .await;

        assert_eq!(inbound_result.unwrap(), ProtocolId::from_static(b"A"));
        outbound_result.unwrap();
    }

    #[tokio::test]
    async fn negotiate_fail_optimistic() {
        let (mut outbound_socket, mut inbound_socket) = MemorySocket::new_pair();
        let mut outbound = ProtocolNegotiation::new(&mut outbound_socket);
        let mut inbound = ProtocolNegotiation::new(&mut inbound_socket);
        let supported = protocol_ids(&[b"A", b"B"]);
        let proposal = ProtocolId::from_static(b"C");

        let (inbound_result, outbound_result) = future::join(
            inbound.negotiate_protocol_inbound(&supported),
            outbound.negotiate_protocol_outbound_optimistic(&proposal),
        )
        .await;

        unpack_enum!(ProtocolError::ProtocolOptimisticNegotiationFailed = inbound_result.unwrap_err());
        // The optimistic side never waits for a reply, so it always succeeds locally.
        outbound_result.unwrap();
    }
}