use molecule::prelude::{Builder, Entity, Reader};
use bytes::Bytes;
use futures::prelude::*;
use log::{debug, trace};
use std::cmp::Ordering;
use std::{collections::HashMap, io};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::{Framed, length_delimited::LengthDelimitedCodec};
// Molecule schema types (`String`, `StringVec`, `ProtocolInfo` and their
// builders/readers) used by `ProtocolInfo::encode` / `ProtocolInfo::decode`.
// NOTE(review): formatting and lints are disabled here, presumably because the
// module is machine-generated from a schema — confirm before hand-editing it.
#[rustfmt::skip]
#[allow(clippy::all)]
#[allow(dead_code)]
mod protocol_select_mol;
/// Custom version-selection callback.
///
/// `server_select` calls it with the locally supported versions as the first
/// slice and the versions proposed by the remote as the second slice; it
/// returns the chosen version, or `None` when no version is acceptable.
pub type SelectFn<T> = Box<dyn Fn(&[T], &[T]) -> Option<T> + Send + 'static>;
/// Message exchanged during protocol selection: a protocol name together with
/// a list of version strings.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ProtocolInfo {
/// Name of the protocol being negotiated.
pub name: String,
/// Version strings offered for this protocol.
///
/// `select_version` scans its inputs from the last element towards the
/// first, so the list is expected to be sorted in ascending order when that
/// default selection is used.
pub support_versions: Vec<String>,
}
impl ProtocolInfo {
    /// Creates a `ProtocolInfo` from a protocol name and its version list.
    pub fn new(name: &str, support_versions: Vec<String>) -> Self {
        Self {
            name: name.to_owned(),
            support_versions,
        }
    }

    /// Serializes this message into its molecule wire representation.
    pub fn encode(self) -> Bytes {
        /// Converts an owned Rust string into the molecule `String` entity.
        fn to_mol(text: String) -> protocol_select_mol::String {
            protocol_select_mol::String::new_builder()
                .set(text.into_bytes().into_iter().map(Into::into).collect())
                .build()
        }

        let Self {
            name,
            support_versions,
        } = self;

        let mol_name = to_mol(name);
        let mol_versions = protocol_select_mol::StringVec::new_builder()
            .set(support_versions.into_iter().map(to_mol).collect())
            .build();

        protocol_select_mol::ProtocolInfo::new_builder()
            .name(mol_name)
            .support_versions(mol_versions)
            .build()
            .as_bytes()
    }

    /// Parses a message previously produced by [`ProtocolInfo::encode`].
    ///
    /// Returns `None` when `data` is not a valid molecule `ProtocolInfo`, or
    /// when the name or any version is not valid UTF-8.
    pub fn decode(data: &[u8]) -> Option<Self> {
        let reader = protocol_select_mol::ProtocolInfoReader::from_compatible_slice(data).ok()?;

        // Collecting into `Option<Vec<_>>` stops at the first non-UTF-8 entry.
        let support_versions = reader
            .support_versions()
            .iter()
            .map(|version| String::from_utf8(version.raw_data().to_owned()).ok())
            .collect::<Option<Vec<_>>>()?;
        let name = String::from_utf8(reader.name().raw_data().to_owned()).ok()?;

        Some(Self {
            name,
            support_versions,
        })
    }
}
pub(crate) async fn client_select<T: AsyncWrite + AsyncRead + Send + Unpin>(
handle: T,
proto_info: ProtocolInfo,
) -> Result<(Framed<T, LengthDelimitedCodec>, String, Option<String>), io::Error> {
let mut socket = Framed::new(handle, LengthDelimitedCodec::new());
let data = proto_info.encode();
trace!("client_select send_proto(len={}): {:#x}", data.len(), data);
socket.send(data).await?;
let (raw_remote_info, socket) = socket.into_future().await;
let mut remote_info = match raw_remote_info.transpose()? {
Some(info) => {
trace!("client_select recv_proto(len={}): {:#x}", info.len(), info);
ProtocolInfo::decode(&info)
.ok_or_else(|| io::Error::from(io::ErrorKind::InvalidData))?
}
None => {
debug!("client_select unexpected eof while waiting for remote's protocol proposition");
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected eof",
));
}
};
Ok((
socket,
remote_info.name,
remote_info.support_versions.pop(),
))
}
/// Performs the listening side of protocol selection.
///
/// Wraps `handle` in a length-delimited frame codec and reads the remote's
/// proposal. The proposed protocol name is looked up (and removed) in
/// `proto_infos`; if found, the version is chosen by the entry's custom
/// selector when one is present, otherwise by [`select_version`]. A reply
/// carrying the proposed name and the chosen version (an empty list when
/// nothing matched) is always sent back.
///
/// On success returns the framed stream, the proposed protocol name and the
/// chosen version, which is `None` when the protocol is unknown or no version
/// was selected.
///
/// # Errors
///
/// * any I/O error raised while receiving or sending,
/// * `UnexpectedEof` if the stream ends before a proposal arrives,
/// * `InvalidData` if the proposal cannot be decoded as a [`ProtocolInfo`].
pub(crate) async fn server_select<T: AsyncWrite + AsyncRead + Send + Unpin>(
    handle: T,
    mut proto_infos: HashMap<String, (ProtocolInfo, Option<SelectFn<String>>)>,
) -> Result<(Framed<T, LengthDelimitedCodec>, String, Option<String>), io::Error> {
    let (first_frame, mut framed) = Framed::new(handle, LengthDelimitedCodec::new())
        .into_future()
        .await;

    let remote_info = {
        // The raw frame is only needed for decoding; keep it scoped to this block.
        let proposal = first_frame.transpose()?.ok_or_else(|| {
            debug!("server_select unexpected eof while waiting for remote's protocol proposition");
            io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected eof")
        })?;
        trace!("server_select recv_proto(len={}): {:#x}", proposal.len(), proposal);
        ProtocolInfo::decode(&proposal)
            .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidData))?
    };

    let version = match proto_infos.remove(&remote_info.name) {
        // A user-supplied selector takes precedence over the default algorithm.
        Some((local_info, Some(custom_select))) => {
            custom_select(&local_info.support_versions, &remote_info.support_versions)
        }
        Some((local_info, None)) => {
            select_version(&local_info.support_versions, &remote_info.support_versions)
        }
        // Protocol name not registered locally.
        None => None,
    };

    // The reply echoes the proposed name with zero or one version.
    let reply = ProtocolInfo::new(&remote_info.name, version.iter().cloned().collect()).encode();
    trace!("server_select send_proto(len={}): {:#x}", reply.len(), reply);
    framed.send(reply).await?;

    Ok((framed, remote_info.name, version))
}
/// Picks the greatest element present in both `local` and `remote`.
///
/// Both slices are scanned from their last element towards the first, always
/// discarding the larger of the two current tails, so the result is the
/// highest common element only when both inputs are sorted in ascending
/// order. Returns `None` when either slice is exhausted before a match.
#[inline]
pub fn select_version<T: Ord + Clone>(local: &[T], remote: &[T]) -> Option<T> {
    let mut local_rest = local;
    let mut remote_rest = remote;
    loop {
        // `?` ends the search as soon as either side has nothing left to offer.
        let (local_tail, local_init) = local_rest.split_last()?;
        let (remote_tail, remote_init) = remote_rest.split_last()?;
        match local_tail.cmp(remote_tail) {
            Ordering::Equal => return Some(local_tail.clone()),
            // The remote candidate is too large to match: drop it.
            Ordering::Less => remote_rest = remote_init,
            // The local candidate is too large to match: drop it.
            Ordering::Greater => local_rest = local_init,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{ProtocolInfo, client_select, select_version, server_select};
    use futures::channel;
    use std::collections::HashMap;
    use tokio::net::{TcpListener, TcpStream};

    /// Builds an owned version list from string literals.
    fn versions(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|v| (*v).to_owned()).collect()
    }

    /// A message must survive an encode/decode round trip unchanged.
    #[test]
    fn protocol_message_decode_encode() {
        let original = ProtocolInfo::new("test", versions(&["1.0.0", "1.1.1"]));
        let wire = original.clone().encode();
        let decoded = ProtocolInfo::decode(&wire).unwrap();
        assert_eq!(original, decoded)
    }

    /// Table of (local, remote, expected) cases for the default selector.
    #[test]
    fn test_select_version() {
        let a = versions(&["1.0.0", "1.1.1", "2.0.0"]);
        let b = versions(&["1.0.0", "2.0.0", "3.0.0"]);
        let empty: Vec<String> = Vec::new();
        let d = versions(&["5.0.0"]);
        let e = versions(&["1.0.0"]);

        let cases = [
            (b.as_slice(), a.as_slice(), Some("2.0.0")),
            (b.as_slice(), e.as_slice(), Some("1.0.0")),
            (b.as_slice(), empty.as_slice(), None),
            (b.as_slice(), d.as_slice(), None),
            (d.as_slice(), a.as_slice(), None),
            (d.as_slice(), e.as_slice(), None),
            (e.as_slice(), d.as_slice(), None),
        ];
        for (local, remote, expected) in cases {
            assert_eq!(
                select_version(local, remote),
                expected.map(str::to_owned),
                "local={:?} remote={:?}",
                local,
                remote
            );
        }
    }

    /// Runs a full negotiation over a loopback TCP connection and checks that
    /// both the listening and the dialing side end up with `result`.
    fn select_protocol(server: Vec<String>, client: Vec<String>, result: Option<String>) {
        let (server_tx, server_rx) = channel::oneshot::channel::<Option<String>>();
        let (client_tx, client_rx) = channel::oneshot::channel::<Option<String>>();
        let (addr_tx, addr_rx) = channel::oneshot::channel::<::std::net::SocketAddr>();
        let rt = tokio::runtime::Runtime::new().unwrap();

        // Listening side: bind to an ephemeral port and publish the address.
        rt.spawn(async move {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let _res = addr_tx.send(listener.local_addr().unwrap());
            let (stream, _) = listener.accept().await.unwrap();

            let mut supported = HashMap::new();
            supported.insert(
                "test".to_owned(),
                (ProtocolInfo::new("test", server), None),
            );
            let (_, _, negotiated) = server_select(stream, supported).await.unwrap();
            let _res = server_tx.send(negotiated);
        });

        // Dialing side: wait for the address, connect and propose.
        rt.spawn(async move {
            let addr = addr_rx.await.unwrap();
            let stream = TcpStream::connect(&addr).await.unwrap();
            let proposal = ProtocolInfo::new("test", client);
            let (_, _, negotiated) = client_select(stream, proposal).await.unwrap();
            let _res = client_tx.send(negotiated);
        });

        rt.block_on(async move {
            assert_eq!(server_rx.await.unwrap(), result);
            assert_eq!(client_rx.await.unwrap(), result);
        });
    }

    #[test]
    fn test_select_success_same() {
        select_protocol(
            versions(&["1.0.0", "1.1.1"]),
            versions(&["1.0.0", "1.1.1"]),
            Some("1.1.1".to_owned()),
        )
    }

    #[test]
    fn test_select_success_different() {
        select_protocol(
            versions(&["1.0.0", "2.1.1"]),
            versions(&["1.0.0", "1.1.1"]),
            Some("1.0.0".to_owned()),
        )
    }

    #[test]
    fn test_select_fail() {
        select_protocol(
            versions(&["1.0.0", "1.1.1"]),
            versions(&["2.0.0", "2.1.1"]),
            None,
        )
    }
}