#![allow(clippy::indexing_slicing)]
use std::{sync::Arc, time::Duration};
use futures::StreamExt;
use tari_shutdown::Shutdown;
use tari_test_utils::unpack_enum;
use tari_utilities::hex::Hex;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::{RwLock, mpsc},
task,
time,
};
use tokio_stream::Stream;
use crate::{
NodeIdentity,
Substream,
framing,
multiplexing::{Control, Yamux},
peer_manager::NodeId,
protocol::{
ProtocolEvent,
ProtocolId,
ProtocolNotification,
rpc,
rpc::{
RpcError,
RpcServer,
RpcServerBuilder,
RpcStatusCode,
context::RpcCommsBackend,
error::HandshakeRejectReason,
handshake::RpcHandshakeError,
server::NamedProtocolService,
test::{
greeting_service::{
GreetingClient,
GreetingRpc,
GreetingServer,
GreetingService,
SayHelloRequest,
SlowGreetingService,
SlowStreamRequest,
},
mock::create_mocked_rpc_context,
},
},
},
test_utils::{node_identity::build_node_identity, transport::build_multiplexed_connections},
};
/// Spawns an RPC server, configured by `builder`, that hosts `service_impl` behind a `GreetingServer`.
///
/// Returns, in order:
/// - the sender used to feed protocol notifications (e.g. new inbound substreams) to the server,
/// - the join handle of the spawned server task,
/// - the mocked comms backend the server was started with,
/// - the `Shutdown` whose signal stops the server task.
///
/// The server task panics if `serve` resolves to an error before shutdown is signalled.
pub(super) async fn setup_service_with_builder<T>(
    service_impl: T,
    builder: RpcServerBuilder,
) -> (
    mpsc::Sender<ProtocolNotification<Substream>>,
    task::JoinHandle<()>,
    RpcCommsBackend,
    Shutdown,
)
where
    T: GreetingRpc,
{
    let shutdown = Shutdown::new();
    let (context, _) = create_mocked_rpc_context();
    let (notif_tx, notif_rx) = mpsc::channel(10);

    // The task gets its own copies; the originals are handed back to the caller.
    let signal = shutdown.to_signal();
    let server_context = context.clone();

    let server_hnd = task::spawn(async move {
        let serve_fut = builder
            .finish()
            .add_service(GreetingServer::new(service_impl))
            .serve(notif_rx, server_context);

        tokio::select! {
            // Poll in declaration order so that a pending shutdown always wins.
            biased;
            _ = signal => {},
            result = serve_fut => result.unwrap(),
        }
    });

    (notif_tx, server_hnd, context, shutdown)
}
/// Convenience wrapper around [`setup_service_with_builder`].
///
/// Builds a server limited to `num_concurrent_sessions` simultaneous sessions and with a minimum client
/// deadline of zero seconds, then returns the same tuple as [`setup_service_with_builder`].
pub(super) async fn setup_service<T>(
    service_impl: T,
    num_concurrent_sessions: usize,
) -> (
    mpsc::Sender<ProtocolNotification<Substream>>,
    task::JoinHandle<()>,
    RpcCommsBackend,
    Shutdown,
)
where
    T: GreetingRpc,
{
    setup_service_with_builder(
        service_impl,
        RpcServer::builder()
            .with_maximum_simultaneous_sessions(num_concurrent_sessions)
            .with_minimum_client_deadline(Duration::from_secs(0)),
    )
    .await
}
/// Spawns a task that forwards every substream yielded by `inbound` to `notif_tx`.
///
/// Each substream is wrapped in a `NewInboundSubstream` event attributed to `node_id` and tagged with the
/// greeting client's protocol name. The task ends when `inbound` is exhausted and panics if the
/// notification channel rejects a send.
fn spawn_inbound(
    mut inbound: impl Stream<Item = Substream> + Unpin + Send + 'static,
    notif_tx: mpsc::Sender<ProtocolNotification<Substream>>,
    node_id: NodeId,
) -> task::JoinHandle<()> {
    task::spawn(async move {
        while let Some(substream) = inbound.next().await {
            let protocol = ProtocolId::from_static(GreetingClient::PROTOCOL_NAME);
            let event = ProtocolEvent::NewInboundSubstream(node_id.clone(), substream);
            let notification = ProtocolNotification::new(protocol, event);
            notif_tx.send(notification).await.unwrap();
        }
    })
}
/// Full test fixture: starts the RPC server via [`setup_service`] and wires a multiplexed connection to it.
///
/// Substreams arriving on the inbound side of the connection are forwarded to the server as coming from a
/// freshly built node identity, which is also registered with the mocked peer manager.
///
/// Returns the inbound yamux control, the outbound multiplexer (used by tests to open client substreams),
/// the server task handle, the node identity and the shutdown handle.
pub(super) async fn setup<T>(
    service_impl: T,
    num_concurrent_sessions: usize,
) -> (Control, Yamux, task::JoinHandle<()>, Arc<NodeIdentity>, Shutdown)
where
    T: GreetingRpc,
{
    let (notif_tx, server_handle, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await;
    let (_, inbound, outbound) = build_multiplexed_connections().await;
    // Grab the control before `into_incoming` consumes the inbound multiplexer.
    let control = inbound.get_yamux_control();

    let node_identity = build_node_identity(Default::default());
    spawn_inbound(
        inbound.into_incoming(),
        notif_tx.clone(),
        node_identity.node_id().clone(),
    );

    let peer = node_identity.to_peer();
    context.peer_manager().add_or_update_peer(peer).await.unwrap();

    (control, outbound, server_handle, node_identity, shutdown)
}
/// Walks a single client session through unary calls, streaming calls, error replies (unary and
/// mid-stream), a context-dependent call, and finally checks that a closed client refuses requests.
#[tokio::test]
async fn request_response_errors_and_streaming() {
    let (_inbound, outbound, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await;

    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let mut client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .with_deadline_grace_period(Duration::from_secs(5))
        .with_handshake_timeout(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap();

    // A latency value is expected to be available right after connecting.
    assert!(client.get_last_request_latency().is_some());

    // Unary request/response
    let hello = SayHelloRequest {
        name: "Yathvan".to_string(),
        language: 1,
    };
    let reply = client.say_hello(hello).await.unwrap();
    assert_eq!(reply.greeting, "Jambo Yathvan");

    // Streaming response where every item is Ok
    let greetings = client
        .get_greetings(4)
        .await
        .unwrap()
        .map(|item| item.unwrap())
        .collect::<Vec<_>>()
        .await;
    assert_eq!(greetings, ["Sawubona", "Jambo", "Bonjour", "Hello"]);

    // Unary call that fails with a status
    let err = client.return_error().await.unwrap_err();
    unpack_enum!(RpcError::RequestFailed(status) = err);
    assert_eq!(status.as_status_code(), RpcStatusCode::NotImplemented);
    assert_eq!(status.details(), "I haven't gotten to this yet :(");

    // Streaming call that yields an error
    let items = client
        .streaming_error("Gurglesplurb".to_string())
        .await
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    let status = items.into_iter().collect::<Result<String, _>>().unwrap_err();
    assert_eq!(status.as_status_code(), RpcStatusCode::BadRequest);
    assert_eq!(status.details(), "What does 'Gurglesplurb' mean?");

    // Streaming call that yields one good item followed by an error
    let results = client.streaming_error2().await.unwrap().collect::<Vec<_>>().await;
    assert_eq!(results.len(), 2);
    let ok_item = results[0].as_ref().unwrap();
    assert_eq!(ok_item, "This is ok");
    let failed_item = results[1].as_ref().unwrap_err();
    assert_eq!(failed_item.as_status_code(), RpcStatusCode::BadRequest);
    assert_eq!(failed_item.details(), "This is a problem");

    // The service reports the public key of the identity registered by `setup`
    let reported_pk = client.get_public_key_hex().await.unwrap();
    assert_eq!(reported_pk, node_identity.public_key().to_hex());

    // Once closed, further requests must fail with one of the two "client gone" errors
    client.close().await;
    let err = client
        .say_hello(SayHelloRequest {
            name: String::new(),
            language: 0,
        })
        .await
        .unwrap_err();
    assert!(
        matches!(err, RpcError::ClientClosed | RpcError::RequestCancelled),
        "Unexpected error {err:?}"
    );

    // The server task must exit cleanly once shutdown is triggered
    shutdown.trigger();
    server_hnd.await.unwrap();
}
/// Issues three requests at once through clones of one client (two from spawned tasks, one inline) and
/// checks that each caller receives the reply to its own request.
#[tokio::test]
async fn concurrent_requests() {
    let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::default(), 1).await;

    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let mut client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap();

    // Task 1: unary request on a cloned client
    let mut hello_client = client.clone();
    let hello_task = task::spawn(async move {
        let request = SayHelloRequest {
            name: "Madeupington".to_string(),
            language: 2,
        };
        hello_client.say_hello(request).await.unwrap()
    });

    // Task 2: streaming request on another clone
    let mut greetings_client = client.clone();
    let greetings_task = task::spawn(async move {
        let stream = greetings_client.get_greetings(5).await.unwrap();
        let items = stream.collect::<Vec<_>>().await;
        items.into_iter().map(Result::unwrap).collect::<Vec<_>>()
    });

    // Meanwhile, a unary request on the original client
    let request = SayHelloRequest {
        name: "Yathvan".to_string(),
        language: 1,
    };
    let reply = client.say_hello(request).await.unwrap();
    assert_eq!(reply.greeting, "Jambo Yathvan");

    let spawned_reply = hello_task.await.unwrap();
    assert_eq!(spawned_reply.greeting, "Bonjour Madeupington");

    let spawned_greetings = greetings_task.await.unwrap();
    assert_eq!(spawned_greetings, GreetingService::DEFAULT_GREETINGS[..5]);
}
/// Probes the response size boundary: asking for a message 4 bytes below `max_response_payload_size`
/// fails with `MalformedResponse`, while 5 bytes below succeeds.
#[tokio::test]
async fn response_too_big() {
    let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await;

    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let mut client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, rpc::max_request_size()))
        .await
        .unwrap();

    let payload_limit = rpc::max_response_payload_size() as u64;

    // Just over the limit
    // NOTE(review): the 4/5 byte offsets presumably account for per-message overhead - confirm.
    let err = client.reply_with_msg_of_size(payload_limit - 4).await.unwrap_err();
    unpack_enum!(RpcError::RequestFailed(status) = err);
    unpack_enum!(RpcStatusCode::MalformedResponse = status.as_status_code());

    // Largest size that is still accepted; the same client must keep working after the failure above
    let _string = client.reply_with_msg_of_size(payload_limit - 5).await.unwrap();
}
#[tokio::test]
async fn ping_latency() {
let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await;
let socket = outbound.get_yamux_control().open_stream().await.unwrap();
let framed = framing::canonical(socket, 1024);
let mut client = GreetingClient::builder().connect(framed).await.unwrap();
let latency = client.ping().await.unwrap();
assert!(latency.as_secs() < 5);
}
/// If shutdown is triggered after the substream is opened but before the client handshake starts, the
/// connect attempt must fail with `ServerClosedRequest`.
#[tokio::test]
async fn server_shutdown_before_connect() {
    let (_inbound, outbound, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await;

    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let framed = framing::canonical(substream, 1024);

    // Signal shutdown before the client gets a chance to handshake.
    shutdown.trigger();

    let connect_err = GreetingClient::connect(framed).await.unwrap_err();
    assert!(matches!(
        connect_err,
        RpcError::HandshakeError(RpcHandshakeError::ServerClosedRequest)
    ));
}
/// A request against a service configured with a 10s delay must fail with a `Timeout` status when the
/// client deadline is 1s, and the same client must succeed again once the delay is set to zero.
#[tokio::test]
async fn timeout() {
    // Shared with the service so that the test can change the delay while the server is running.
    let delay = Arc::new(RwLock::new(Duration::from_secs(10)));
    let (_inbound, outbound, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await;

    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let mut client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(1))
        .with_deadline_grace_period(Duration::from_secs(1))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap();

    // Slow service: the request is expected to time out
    let slow_err = client.say_hello(Default::default()).await.unwrap_err();
    unpack_enum!(RpcError::RequestFailed(status) = slow_err);
    assert_eq!(status.as_status_code(), RpcStatusCode::Timeout);

    // Remove the delay: the session must still be usable
    *delay.write().await = Duration::from_secs(0);
    let reply = client.say_hello(Default::default()).await.unwrap();
    assert_eq!(reply.greeting, "took a while to load");
}
/// Notifies the server of a new substream under a protocol id it does not serve and checks that the
/// client handshake on the other end is rejected with `ProtocolNotSupported`.
#[tokio::test]
async fn unknown_protocol() {
    let (notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await;
    let (_, inbound, mut outbound) = build_multiplexed_connections().await;

    // Open the substream from the "server" side and write a few bytes on it before handing it over.
    let mut in_substream = inbound.get_yamux_control().open_stream().await.unwrap();
    in_substream.write_all(b"hello").await.unwrap();

    let node_identity = build_node_identity(Default::default());
    let junk_notification = ProtocolNotification::new(
        ProtocolId::from_static(b"this-is-junk"),
        ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), in_substream),
    );
    notif_tx.send(junk_notification).await.unwrap();

    // Accept the substream on the other end and consume the 5 bytes written above.
    let mut out_socket = outbound.incoming_mut().next().await.unwrap();
    let mut hello_buf = [0u8; 5];
    out_socket.read_exact(&mut hello_buf).await.unwrap();

    let connect_err = GreetingClient::connect(framing::canonical(out_socket, 1024))
        .await
        .unwrap_err();
    assert!(matches!(
        connect_err,
        RpcError::HandshakeError(RpcHandshakeError::Rejected(HandshakeRejectReason::ProtocolNotSupported))
    ));
}
/// With the session limit set to zero, the very first handshake must be rejected with
/// `NoServerSessionsAvailable`.
#[tokio::test]
async fn rejected_no_sessions_available() {
    let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await;

    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let connect_err = GreetingClient::builder()
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap_err();

    assert!(matches!(
        connect_err,
        RpcError::HandshakeError(RpcHandshakeError::Rejected(
            HandshakeRejectReason::NoServerSessionsAvailable(_)
        ))
    ));
}
/// Starts a slow stream, drops the response immediately, and then checks that an identical request on
/// the same client streams to completion with every item Ok.
#[tokio::test]
async fn stream_still_works_after_cancel() {
    let service_impl = GreetingService::default();
    let (_inbound, outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await;

    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let mut client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap();

    // Both requests use the same parameters.
    let slow_request = || SlowStreamRequest {
        num_items: 100,
        item_size: 100,
        delay_ms: 10,
    };

    // Request a stream and drop the response straight away without reading any item.
    drop(client.slow_stream(slow_request()).await.unwrap());
    assert_eq!(service_impl.call_count(), 1);

    // A second request on the same session must run to completion without errors.
    let stream = client.slow_stream(slow_request()).await.unwrap();
    for item in stream.collect::<Vec<_>>().await {
        item.unwrap();
    }
}
/// Abandons a long, slow stream after reading one item, then checks that the first item of a new stream
/// on the same client arrives within ten seconds.
#[tokio::test]
async fn stream_interruption_handling() {
    let service_impl = GreetingService::default();
    let (_inbound, outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await;

    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let mut client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap();

    // Start a stream that would take a long time to finish (10000 items, 100ms apart).
    let long_request = SlowStreamRequest {
        num_items: 10000,
        item_size: 100,
        delay_ms: 100,
    };
    let mut resp = client.slow_stream(long_request).await.unwrap();
    // Read a single item, then abandon the stream.
    let _first_item = resp.next().await.unwrap().unwrap();
    drop(resp);

    // The next request must not be held up by the abandoned stream.
    let short_request = SlowStreamRequest {
        num_items: 100,
        item_size: 100,
        delay_ms: 1,
    };
    let mut resp = client.slow_stream(short_request).await.unwrap();
    time::timeout(Duration::from_secs(10), resp.next())
        .await
        .unwrap() // the timeout did not elapse
        .unwrap() // the stream produced an item
        .unwrap(); // the item is Ok
}
/// With a global limit of one simultaneous session: the first client connects, a second is rejected with
/// "session limit reached", and after the first client is closed a new client can connect.
#[tokio::test]
async fn max_global_sessions() {
    let builder = RpcServer::builder().with_maximum_simultaneous_sessions(1);
    let (notif_tx, _server_hnd, context, _shutdown) =
        setup_service_with_builder(GreetingService::default(), builder).await;
    let (_, inbound, outbound) = build_multiplexed_connections().await;

    let node_identity = build_node_identity(Default::default());
    let peer = node_identity.to_peer();
    context.peer_manager().add_or_update_peer(peer).await.unwrap();
    spawn_inbound(
        inbound.into_incoming(),
        notif_tx.clone(),
        node_identity.node_id().clone(),
    );

    // First session: accepted
    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let mut first_client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap();

    // Second session while the first is still open: rejected
    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let err = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap_err();
    unpack_enum!(RpcError::HandshakeError(err) = err);
    unpack_enum!(
        RpcHandshakeError::Rejected(HandshakeRejectReason::NoServerSessionsAvailable(
            "session limit reached"
        )) = err
    );

    // Closing the first client makes room for a new session
    first_client.close().await;
    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let _client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap();
}
/// With a per-client limit of one session (global limit three, culling of the oldest connection
/// disabled): the first client connects, a second from the same peer is rejected with
/// "session limit reached", and after the first client is dropped a new client can connect.
#[tokio::test]
async fn max_per_client_sessions() {
    let builder = RpcServer::builder()
        .with_maximum_simultaneous_sessions(3)
        .with_maximum_sessions_per_client(1)
        .with_cull_oldest_peer_rpc_connection_on_full(false);
    let (notif_tx, _server_hnd, context, _shutdown) =
        setup_service_with_builder(GreetingService::default(), builder).await;
    let (_, inbound, outbound) = build_multiplexed_connections().await;

    let node_identity = build_node_identity(Default::default());
    let peer = node_identity.to_peer();
    context.peer_manager().add_or_update_peer(peer).await.unwrap();
    spawn_inbound(
        inbound.into_incoming(),
        notif_tx.clone(),
        node_identity.node_id().clone(),
    );

    // First session for this peer: accepted
    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let first_client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap();

    // Second session from the same peer: rejected even though the global limit is not reached
    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let err = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap_err();
    unpack_enum!(RpcError::HandshakeError(err) = err);
    unpack_enum!(
        RpcHandshakeError::Rejected(HandshakeRejectReason::NoServerSessionsAvailable(
            "session limit reached"
        )) = err
    );

    // Dropping the first client makes room for a new session from this peer
    drop(first_client);
    let substream = outbound.get_yamux_control().open_stream().await.unwrap();
    let _client = GreetingClient::builder()
        .with_deadline(Duration::from_secs(5))
        .connect(framing::canonical(substream, 1024))
        .await
        .unwrap();
}