use futures::prelude::*;
use multistream_select::{dialer_select_proto, listener_select_proto, NegotiationError, Version};
use std::time::Duration;
/// Checks that a dialer and a listener settle on the one protocol both support
/// (`/proto2`) and can then exchange application data over the negotiated stream,
/// once with `Version::V1` and once with `Version::V1Lazy`.
#[test]
fn select_proto_basic() {
    // Runs one listener/dialer pair over an in-memory connection using `version`.
    async fn run(version: Version) {
        let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

        let server = async_std::task::spawn(async move {
            let protos = vec!["/proto1", "/proto2"];
            let (proto, mut io) = listener_select_proto(server_connection, protos)
                .await
                .unwrap();
            assert_eq!(proto, "/proto2");

            // A single `read` may legally return fewer bytes than were sent;
            // `read_exact` keeps the check independent of transport chunking.
            let mut out = [0u8; 4];
            io.read_exact(&mut out).await.unwrap();
            assert_eq!(&out, b"ping");

            io.write_all(b"pong").await.unwrap();
            io.flush().await.unwrap();
        });

        let client = async_std::task::spawn(async move {
            let protos = vec!["/proto3", "/proto2"];
            let (proto, mut io) =
                dialer_select_proto(client_connection, protos.into_iter(), version)
                    .await
                    .unwrap();
            assert_eq!(proto, "/proto2");

            io.write_all(b"ping").await.unwrap();
            io.flush().await.unwrap();

            // Same reasoning as on the listener side: wait for the full reply.
            let mut out = [0u8; 4];
            io.read_exact(&mut out).await.unwrap();
            assert_eq!(&out, b"pong");
        });

        server.await;
        client.await;
    }

    async_std::task::block_on(run(Version::V1));
    async_std::task::block_on(run(Version::V1Lazy));
}
/// Checks that negotiation fails cleanly when dialer and listener share no protocol.
///
/// "Cleanly" means each side reports `NegotiationError::Failed`, not a protocol error
/// and not a hang. The failure may come either from the initial select call or from
/// `Negotiated::complete`. The check runs for every combination of protocol lists,
/// trailing dialer payload, and `Version`.
#[test]
fn negotiation_failed() {
    let _ = env_logger::try_init();

    // Runs one dialer/listener pair described by `Test` and asserts that both sides
    // end in `NegotiationError::Failed`.
    async fn run(
        Test {
            version,
            listen_protos,
            dial_protos,
            dial_payload,
        }: Test,
    ) {
        let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

        let server = async_std::task::spawn(async move {
            let io = match listener_select_proto(server_connection, listen_protos).await {
                Ok((_, io)) => io,
                Err(NegotiationError::Failed) => return,
                Err(NegotiationError::ProtocolError(e)) => {
                    panic!("Unexpected protocol error {e}")
                }
            };
            match io.complete().await {
                Err(NegotiationError::Failed) => {}
                Err(NegotiationError::ProtocolError(e)) => {
                    panic!("listener: unexpected protocol error on complete: {e}")
                }
                Ok(_) => panic!("listener: negotiation unexpectedly succeeded"),
            }
        });

        let client = async_std::task::spawn(async move {
            let mut io =
                match dialer_select_proto(client_connection, dial_protos.into_iter(), version)
                    .await
                {
                    Ok((_, io)) => io,
                    Err(NegotiationError::Failed) => return,
                    Err(NegotiationError::ProtocolError(e)) => {
                        panic!("Unexpected protocol error {e}")
                    }
                };
            // Getting `Ok` here despite there being no common protocol means the dialer
            // has not yet heard the listener's answer (e.g. `V1Lazy` does not wait for
            // confirmation). It may already write application data; the failed
            // agreement must then surface from `complete()`.
            io.write_all(&dial_payload).await.unwrap();
            match io.complete().await {
                Err(NegotiationError::Failed) => {}
                Err(NegotiationError::ProtocolError(e)) => {
                    panic!("dialer: unexpected protocol error on complete: {e}")
                }
                Ok(_) => panic!("dialer: negotiation unexpectedly succeeded"),
            }
        });

        // Bound each side so that a negotiation which never resolves fails the test
        // with a message instead of hanging the whole suite.
        async_std::future::timeout(Duration::from_secs(10), server)
            .await
            .expect("listener task timed out");
        async_std::future::timeout(Duration::from_secs(10), client)
            .await
            .expect("dialer task timed out");
    }

    // Parameters for a single `run`.
    #[derive(Clone)]
    struct Test {
        version: Version,
        listen_protos: Vec<&'static str>,
        dial_protos: Vec<&'static str>,
        dial_payload: Vec<u8>,
    }

    // (listener protocols, dialer protocols) pairs that have no protocol in common.
    let protos = [
        (vec!["/proto1"], vec!["/proto2"]),
        (vec!["/proto1", "/proto2"], vec!["/proto3", "/proto4"]),
    ];

    // Application data the dialer writes once its select call returns `Ok`: nothing,
    // or one of two short byte sequences. NOTE(review): presumably chosen so that a
    // `V1Lazy` listener may misread them as negotiation frames — confirm.
    let payloads: [Vec<u8>; 3] = [vec![], vec![1, 1], vec![42, 1]];

    for (listen_protos, dial_protos) in protos {
        for dial_payload in &payloads {
            for version in [Version::V1, Version::V1Lazy] {
                async_std::task::block_on(run(Test {
                    version,
                    listen_protos: listen_protos.clone(),
                    dial_protos: dial_protos.clone(),
                    dial_payload: dial_payload.clone(),
                }));
            }
        }
    }
}
/// With `Version::V1Lazy`, the dialer must be able to obtain `/proto1` and close its
/// stream without ever hearing back from a listener.
///
/// Nothing in this test reads from or writes to the other endpoint. If `close` waited
/// for the negotiation to be confirmed, it would never return, and the 10-second
/// timeout below would fail the test.
#[async_std::test]
async fn v1_lazy_do_not_wait_for_negotiation_on_poll_close() {
    // Bind (rather than discard) the listener-side endpoint so it stays alive, but
    // silent, for the whole test. NOTE(review): the asymmetric buffer sizes are
    // presumably deliberate — confirm which direction each one sizes.
    let (client_connection, _server_connection) = futures_ringbuf::Endpoint::pair(1024 * 1024, 1);

    let dialer = async move {
        let (negotiated_proto, mut stream) = dialer_select_proto(
            client_connection,
            vec!["/proto1"].into_iter(),
            Version::V1Lazy,
        )
        .await
        .unwrap();
        assert_eq!(negotiated_proto, "/proto1");
        stream.close().await.unwrap();
    };
    let handle = async_std::task::spawn(dialer);

    let bounded = async_std::future::timeout(Duration::from_secs(10), handle);
    bounded.await.unwrap();
}