use super::super::read_write::ReadWrite;
use crate::{libp2p::read_write, util::leb128};
use alloc::{collections::VecDeque, string::String};
use core::{cmp, fmt, str};
/// Configuration of a multistream-select negotiation.
///
/// `P` is the type holding the protocol name requested when acting as a dialer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Config<P> {
    /// The local side is the dialer: it sends a single protocol request and waits for the
    /// listener's answer.
    Dialer {
        /// Name of the protocol to request, without the trailing `\n` (added on the wire).
        requested_protocol: P,
    },
    /// The local side is the listener: it answers the protocol requests sent by the dialer.
    Listener {
        /// Length, in bytes, of the longest protocol name the listener expects to be asked
        /// for. Used to compute the maximum accepted incoming frame length.
        max_protocol_name_len: usize,
    },
}
/// State of a multistream-select negotiation, as returned by [`InProgress::read_write`].
#[derive(Debug)]
pub enum Negotiation<P> {
    /// The negotiation hasn't finished yet; keep calling [`InProgress::read_write`].
    InProgress(InProgress<P>),
    /// Listener only: the dialer has requested a protocol, which must now be either accepted
    /// or rejected.
    ListenerAcceptOrDeny(ListenerAcceptOrDeny<P>),
    /// The negotiation has finished and both sides agreed on a protocol.
    Success,
    /// Dialer only: the listener answered `na`, meaning the requested protocol is not
    /// available.
    NotAvailable,
}
impl<P: AsRef<str>> Negotiation<P> {
    /// Starts a new negotiation with the given configuration.
    ///
    /// Equivalent to wrapping [`InProgress::new`] in [`Negotiation::InProgress`].
    pub fn new(config: Config<P>) -> Self {
        let in_progress = InProgress::new(config);
        Self::InProgress(in_progress)
    }
}
/// Listener-side negotiation paused on a protocol request from the dialer.
///
/// Call [`ListenerAcceptOrDeny::accept`] or [`ListenerAcceptOrDeny::reject`] to continue.
#[derive(Debug)]
pub struct ListenerAcceptOrDeny<P> {
    /// Negotiation to resume once the decision has been made.
    inner: InProgress<P>,
    /// Requested protocol name, UTF-8 decoded and stripped of its trailing `\n`.
    protocol: String,
}
impl<P> ListenerAcceptOrDeny<P> {
pub fn requested_protocol(&self) -> &str {
&self.protocol
}
pub fn accept(mut self) -> InProgress<P> {
debug_assert!(matches!(self.inner.state, InProgressState::CommandExpected));
write_message(
Message::ProtocolOk(self.protocol.into_bytes()),
&mut self.inner.data_send_out,
);
self.inner.state = InProgressState::Finishing;
self.inner
}
pub fn reject(mut self) -> InProgress<P> {
debug_assert!(matches!(self.inner.state, InProgressState::CommandExpected));
write_message(
Message::<&'static [u8]>::ProtocolNa,
&mut self.inner.data_send_out,
);
self.inner
}
}
/// Negotiation in progress. Drive it by calling [`InProgress::read_write`].
pub struct InProgress<P> {
    /// Configuration the negotiation was created with.
    config: Config<P>,
    /// Encoded negotiation messages queued for sending to the remote.
    data_send_out: VecDeque<u8>,
    /// Current position in the negotiation state machine.
    state: InProgressState,
    /// Limit passed when decoding the LEB128 length prefix of incoming frames.
    max_in_frame_len: usize,
    /// Length of the next incoming frame, once its LEB128 prefix has been decoded but the
    /// payload has not yet been fully received.
    next_in_frame_len: Option<usize>,
}
/// Internal state of an [`InProgress`] negotiation.
#[derive(Clone, Copy, Debug)]
enum InProgressState {
    /// Listener only: a protocol has been accepted and the confirmation is being flushed.
    Finishing,
    /// Waiting for the remote's handshake frame.
    HandshakeExpected,
    /// Listener only: waiting for a protocol request from the dialer.
    CommandExpected,
    /// Dialer only: waiting for the listener's answer to the protocol request.
    ProtocolRequestAnswerExpected,
}
impl<P> InProgress<P>
where
    P: AsRef<str>,
{
    /// Initializes a new negotiation.
    ///
    /// The local handshake is queued immediately. When acting as a dialer, the protocol
    /// request is queued right after it, without waiting for the remote's handshake. Nothing
    /// is written out until [`InProgress::read_write`] is called.
    pub fn new(config: Config<P>) -> Self {
        // Length, in bytes, of the longest protocol name that can legitimately be received.
        let max_proto_name_len = match &config {
            Config::Dialer { requested_protocol } => requested_protocol.as_ref().len(),
            Config::Listener {
                max_protocol_name_len,
            } => *max_protocol_name_len,
        };

        // `max_frame_len` is the limit given to the incoming frame-length decoder. A floor is
        // applied so that a dialer requesting a reasonably-sized protocol that the listener
        // doesn't know about isn't met with a frame-length error instead of a clean `na`.
        const MIN_PROTO_LEN_NO_ERR: usize = 512;
        // `+ 1` leaves room for the `\n` that terminates protocol-name frames.
        let max_frame_len = cmp::max(
            cmp::max(max_proto_name_len, MIN_PROTO_LEN_NO_ERR),
            HANDSHAKE.len(),
        ) + 1;

        InProgress {
            data_send_out: {
                let mut data = VecDeque::new();
                write_message(Message::<&'static [u8]>::Handshake, &mut data);
                if let Config::Dialer { requested_protocol } = &config {
                    write_message(
                        Message::ProtocolRequest(requested_protocol.as_ref()),
                        &mut data,
                    );
                }
                data
            },
            config,
            state: InProgressState::HandshakeExpected,
            max_in_frame_len: max_frame_len,
            next_in_frame_len: None,
        }
    }

    /// Returns `true` if the dialer may start sending data of the requested protocol
    /// optimistically, before the listener has confirmed it.
    ///
    /// This requires two conditions:
    /// - the listener's handshake has been received and the dialer is waiting for the answer
    ///   to its protocol request;
    /// - all of the dialer's own negotiation bytes (handshake and protocol request) have
    ///   already been handed to the writer.
    ///
    /// Without the second condition, protocol data written by the caller could end up on the
    /// wire before the remainder of the protocol request. Any leftover negotiation bytes are
    /// only flushed at the start of the next [`InProgress::read_write`] call.
    pub fn can_write_protocol_data(&self) -> bool {
        matches!(self.state, InProgressState::ProtocolRequestAnswerExpected)
            && self.data_send_out.is_empty()
    }

    /// Advances the negotiation.
    ///
    /// Each step does three things:
    /// - queued outgoing negotiation messages are moved into `read_write`;
    /// - complete incoming frames (LEB128 length prefix followed by the payload) are consumed
    ///   from it;
    /// - the state machine is advanced for each consumed frame.
    ///
    /// Returns the next state:
    /// - [`Negotiation::InProgress`] if more data is needed;
    /// - [`Negotiation::ListenerAcceptOrDeny`] (listener only) once a protocol request has
    ///   been received;
    /// - [`Negotiation::NotAvailable`] (dialer only) if the listener answered `na`;
    /// - [`Negotiation::Success`] once a protocol has been agreed on.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] in any of these cases:
    /// - an incoming frame length or frame can't be decoded;
    /// - the handshake doesn't match;
    /// - a protocol request isn't valid `\n`-terminated UTF-8;
    /// - the dialer receives an answer that is neither `na` nor its requested protocol.
    pub fn read_write<TNow>(
        mut self,
        read_write: &mut ReadWrite<TNow>,
    ) -> Result<Negotiation<P>, Error> {
        loop {
            // Hand as much queued outgoing data to `read_write` as it accepts.
            read_write.write_from_vec_deque(&mut self.data_send_out);

            // After accepting a protocol, the listener only needs its confirmation flushed.
            if let InProgressState::Finishing = self.state {
                debug_assert!(matches!(self.config, Config::Listener { .. }));
                if self.data_send_out.is_empty() {
                    return Ok(Negotiation::Success);
                } else {
                    break;
                }
            }

            // Frames are read in two steps. First the LEB128 length prefix is decoded and
            // remembered. Then, on a later iteration, the whole payload is taken.
            let mut frame = if let Some(next_frame_len) = self.next_in_frame_len {
                match read_write.incoming_bytes_take(next_frame_len) {
                    Ok(None) => return Ok(Negotiation::InProgress(self)),
                    Ok(Some(frame)) => {
                        self.next_in_frame_len = None;
                        frame
                    }
                    Err(err) => return Err(Error::Frame(err)),
                }
            } else {
                match read_write.incoming_bytes_take_leb128(self.max_in_frame_len) {
                    Ok(None) => return Ok(Negotiation::InProgress(self)),
                    Ok(Some(size)) => {
                        self.next_in_frame_len = Some(size);
                        continue;
                    }
                    Err(err) => return Err(Error::FrameLength(err)),
                }
            };

            match (self.state, &self.config) {
                (InProgressState::HandshakeExpected, Config::Dialer { .. }) => {
                    if &*frame != HANDSHAKE {
                        return Err(Error::BadHandshake);
                    }
                    self.state = InProgressState::ProtocolRequestAnswerExpected;
                }
                (InProgressState::HandshakeExpected, Config::Listener { .. }) => {
                    if &*frame != HANDSHAKE {
                        return Err(Error::BadHandshake);
                    }
                    self.state = InProgressState::CommandExpected;
                }
                (InProgressState::CommandExpected, Config::Listener { .. }) => {
                    // Protocol requests are `\n`-terminated; strip the terminator.
                    if frame.pop() != Some(b'\n') {
                        return Err(Error::InvalidCommand);
                    }
                    let protocol = String::from_utf8(frame).map_err(|_| Error::InvalidCommand)?;
                    return Ok(Negotiation::ListenerAcceptOrDeny(ListenerAcceptOrDeny {
                        inner: self,
                        protocol,
                    }));
                }
                (
                    InProgressState::ProtocolRequestAnswerExpected,
                    Config::Dialer { requested_protocol },
                ) => {
                    if frame.pop() != Some(b'\n') {
                        return Err(Error::UnexpectedProtocolRequestAnswer);
                    }
                    if &*frame == b"na" {
                        return Ok(Negotiation::NotAvailable);
                    }
                    // A positive answer must echo back exactly the requested protocol name.
                    if frame != requested_protocol.as_ref().as_bytes() {
                        return Err(Error::UnexpectedProtocolRequestAnswer);
                    }
                    return Ok(Negotiation::Success);
                }
                // No transition leads to these combinations:
                // - the dialer never enters `CommandExpected`;
                // - the listener never enters `ProtocolRequestAnswerExpected`;
                // - `Finishing` is handled before any frame is read.
                (InProgressState::CommandExpected, Config::Dialer { .. })
                | (InProgressState::ProtocolRequestAnswerExpected, Config::Listener { .. })
                | (InProgressState::Finishing, _) => {
                    unreachable!();
                }
            };
        }

        Ok(Negotiation::InProgress(self))
    }
}
impl<P> fmt::Debug for InProgress<P> {
    // Only the type name is printed. This keeps the impl free of any `P: Debug` bound and
    // avoids dumping internal buffers.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("InProgress")
    }
}
/// Error that can happen during a multistream-select negotiation.
#[derive(Debug, Clone, derive_more::Display, derive_more::Error)]
pub enum Error {
    /// Reading side of the connection is closed.
    /// NOTE(review): not constructed by `InProgress::read_write`; confirm it is still used.
    ReadClosed,
    /// Writing side of the connection is closed.
    /// NOTE(review): not constructed by `InProgress::read_write`; confirm it is still used.
    WriteClosed,
    /// Failed to decode the LEB128 length prefix of an incoming frame, or the frame length
    /// limit was exceeded.
    #[display("LEB128 frame error: {_0}")]
    FrameLength(read_write::IncomingBytesTakeLeb128Error),
    /// Failed to read the payload of an incoming frame.
    #[display("LEB128 frame error: {_0}")]
    Frame(read_write::IncomingBytesTakeError),
    /// The remote's handshake frame isn't `/multistream/1.0.0\n`.
    BadHandshake,
    /// A protocol request from the dialer isn't `\n`-terminated or isn't valid UTF-8.
    InvalidCommand,
    /// The listener's answer is neither `na` nor the requested protocol name.
    UnexpectedProtocolRequestAnswer,
}
/// Payload of the handshake frame sent and expected by both sides, including its trailing `\n`.
const HANDSHAKE: &[u8] = "/multistream/1.0.0\n".as_bytes();
/// A multistream-select message, as serialized by `write_message`.
///
/// `P` holds the raw bytes of a protocol name, without the trailing `\n`.
#[derive(Clone, Copy, Debug)]
enum Message<P> {
    /// The `/multistream/1.0.0\n` handshake.
    Handshake,
    /// Dialer asking the listener to use protocol `P`.
    ProtocolRequest(P),
    /// Listener confirming protocol `P` by echoing its name.
    ProtocolOk(P),
    /// Listener refusing the requested protocol (`na\n` on the wire).
    ProtocolNa,
}
/// Appends the wire encoding of `message` to `out`.
///
/// Every message is a single frame: an unsigned LEB128 length prefix followed by the payload.
/// Protocol-name payloads get a trailing `\n` appended.
fn write_message(message: Message<impl AsRef<[u8]>>, out: &mut VecDeque<u8>) {
    match message {
        Message::Handshake => {
            let body = HANDSHAKE;
            // A LEB128-encoded `usize` prefix of this size fits in a few bytes.
            out.reserve(body.len() + 4);
            out.extend(leb128::encode_usize(body.len()));
            out.extend(body.iter().copied());
        }
        Message::ProtocolRequest(name) | Message::ProtocolOk(name) => {
            let name: &[u8] = name.as_ref();
            // The length prefix covers the name plus its `\n` terminator.
            let framed_len = name.len() + 1;
            out.reserve(framed_len + 4);
            out.extend(leb128::encode_usize(framed_len));
            out.extend(name.iter().copied());
            out.push_back(b'\n');
        }
        Message::ProtocolNa => {
            const NA: &[u8] = b"na\n";
            out.reserve(8);
            out.extend(leb128::encode_usize(NA.len()));
            out.extend(NA.iter().copied());
        }
    }
}
#[cfg(test)]
mod tests {
    use alloc::collections::VecDeque;
    use core::{cmp, mem};
    use super::{super::super::read_write::ReadWrite, Config, Message, Negotiation, write_message};

    /// Checks the exact wire encoding produced by `write_message`.
    ///
    /// Each frame is a LEB128 length prefix followed by its payload:
    /// - `0x13` = 19 bytes for the handshake;
    /// - `0x07` = `/hello` plus `\n`;
    /// - `0x03` = `na\n`.
    #[test]
    fn encode() {
        let mut message = VecDeque::new();
        write_message(Message::<&'static [u8]>::Handshake, &mut message);
        assert_eq!(
            message.drain(..).collect::<Vec<_>>(),
            b"\x13/multistream/1.0.0\n".to_vec()
        );
        write_message(Message::ProtocolRequest("/hello"), &mut message);
        assert_eq!(
            message.drain(..).collect::<Vec<_>>(),
            b"\x07/hello\n".to_vec()
        );
        write_message(Message::<&'static [u8]>::ProtocolNa, &mut message);
        assert_eq!(message.drain(..).collect::<Vec<_>>(), b"\x03na\n".to_vec());
    }

    /// Runs a dialer (side 1) and a listener (side 2) against each other through in-memory
    /// buffers until both report `Negotiation::Success`.
    #[test]
    fn negotiation_basic_works() {
        // `size1` / `size2` cap how many bytes side 1 / side 2 may have queued towards the
        // other side. Each cap is raised to the peer's reported `expected_incoming_bytes`,
        // so that even tiny initial sizes eventually make progress.
        fn test_with_buffer_sizes(mut size1: usize, mut size2: usize) {
            let mut negotiation1 = Negotiation::new(Config::Dialer {
                requested_protocol: "/foo",
            });
            let mut negotiation2 = Negotiation::new(Config::<String>::Listener {
                max_protocol_name_len: 4,
            });
            let mut buf_1_to_2 = Vec::new();
            let mut buf_2_to_1 = Vec::new();
            let mut num_iterations = 0;
            while !matches!(
                (&negotiation1, &negotiation2),
                (Negotiation::Success, Negotiation::Success)
            ) {
                // Guard against the negotiation getting stuck.
                num_iterations += 1;
                assert!(num_iterations <= 5000);
                match negotiation1 {
                    Negotiation::InProgress(nego) => {
                        // Side 1 reads what side 2 wrote. Its write buffer starts with the
                        // bytes it already queued that side 2 hasn't consumed yet.
                        let mut read_write = ReadWrite {
                            now: 0,
                            incoming_buffer: buf_2_to_1,
                            expected_incoming_bytes: Some(0),
                            read_bytes: 0,
                            write_bytes_queued: buf_1_to_2.len(),
                            write_bytes_queueable: Some(size1 - buf_1_to_2.len()),
                            write_buffers: vec![mem::take(&mut buf_1_to_2)],
                            wake_up_after: None,
                        };
                        negotiation1 = nego.read_write(&mut read_write).unwrap();
                        // Give back whatever side 1 did not consume, and collect its output.
                        buf_2_to_1 = read_write.incoming_buffer;
                        buf_1_to_2.extend(
                            read_write
                                .write_buffers
                                .drain(..)
                                .flat_map(|b| b.into_iter()),
                        );
                        size2 = cmp::max(size2, read_write.expected_incoming_bytes.unwrap_or(0));
                    }
                    Negotiation::Success => {}
                    // The dialer never receives accept/deny requests, and "/foo" is accepted.
                    Negotiation::ListenerAcceptOrDeny(_) => unreachable!(),
                    Negotiation::NotAvailable => panic!(),
                }
                match negotiation2 {
                    Negotiation::InProgress(nego) => {
                        // Mirror image of the side 1 step above, with the roles swapped.
                        let mut read_write = ReadWrite {
                            now: 0,
                            incoming_buffer: buf_1_to_2,
                            expected_incoming_bytes: Some(0),
                            read_bytes: 0,
                            write_bytes_queued: buf_2_to_1.len(),
                            write_bytes_queueable: Some(size2 - buf_2_to_1.len()),
                            write_buffers: vec![mem::take(&mut buf_2_to_1)],
                            wake_up_after: None,
                        };
                        negotiation2 = nego.read_write(&mut read_write).unwrap();
                        buf_1_to_2 = read_write.incoming_buffer;
                        buf_2_to_1.extend(
                            read_write
                                .write_buffers
                                .drain(..)
                                .flat_map(|b| b.into_iter()),
                        );
                        size1 = cmp::max(size1, read_write.expected_incoming_bytes.unwrap_or(0));
                    }
                    // The listener accepts "/foo" and rejects anything else.
                    Negotiation::ListenerAcceptOrDeny(accept_reject)
                        if accept_reject.requested_protocol() == "/foo" =>
                    {
                        negotiation2 = Negotiation::InProgress(accept_reject.accept());
                    }
                    Negotiation::ListenerAcceptOrDeny(accept_reject) => {
                        negotiation2 = Negotiation::InProgress(accept_reject.reject());
                    }
                    Negotiation::Success => {}
                    Negotiation::NotAvailable => panic!(),
                }
            }
        }
        // Balanced, minimal and asymmetric buffer sizes.
        test_with_buffer_sizes(256, 256);
        test_with_buffer_sizes(1, 1);
        test_with_buffer_sizes(1, 2048);
        test_with_buffer_sizes(2048, 1);
    }
}