use core::future::{ready, Ready};
use core::pin::Pin;
use core::str::FromStr;
use core::sync::atomic::{AtomicBool, Ordering};
use core::task::{Context, Poll};
use core::time::Duration;
use std::collections::VecDeque;
use std::io;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::vec::Vec;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::sleep;
use tokio::time::Instant;
use tracing::trace;
use crate::base::MessageBuilder;
use crate::base::Name;
use crate::base::Rtype;
use crate::base::StaticCompressor;
use crate::base::StreamTarget;
use crate::logging::init_logging;
use crate::net::server::buf::BufSource;
use crate::net::server::message::Request;
use crate::net::server::middleware::mandatory::MandatoryMiddlewareSvc;
use crate::net::server::service::{
CallResult, Service, ServiceError, ServiceFeedback,
};
use crate::net::server::sock::AsyncAccept;
use crate::net::server::stream::StreamServer;
/// In-memory stand-in for a client-side TCP connection to the server.
///
/// The server reads the queued DNS messages from it; anything the server
/// writes back is discarded.
struct MockStream {
    /// Messages still to be delivered to the server, front first.
    messages_to_read: Mutex<VecDeque<Vec<u8>>>,
    /// Time that must pass after a length prefix is handed out before the
    /// stream produces more data.
    new_message_every: Duration,
    /// When the last length prefix was handed out; `None` before the first.
    last_ready: Mutex<Option<Instant>>,
    /// Number of writes (responses) still expected from the server.
    pending_responses: usize,
    /// When set, the stream fails as soon as its message queue is drained,
    /// even if responses are still outstanding.
    disconnect_with_pending_responses: bool,
}
impl MockStream {
    /// Creates a mock stream that will feed `messages_to_read` to the server,
    /// pacing them by `new_message_every`.
    ///
    /// One response is expected per queued message, so the pending-response
    /// counter starts at the queue length.
    fn new(
        messages_to_read: VecDeque<Vec<u8>>,
        new_message_every: Duration,
        disconnect_with_pending_responses: bool,
    ) -> Self {
        // Struct-literal fields are evaluated in the order written, so the
        // length is taken before the queue is moved into its mutex.
        Self {
            pending_responses: messages_to_read.len(),
            last_ready: Mutex::default(),
            new_message_every,
            disconnect_with_pending_responses,
            messages_to_read: Mutex::new(messages_to_read),
        }
    }
}
impl AsyncRead for MockStream {
    /// Delivers the queued messages to the server as TCP-framed DNS messages.
    ///
    /// A read with exactly two bytes of space is answered with the next
    /// message's big-endian `u16` length. A read of any other size receives
    /// the next message body. Nothing is delivered until `new_message_every`
    /// has elapsed since the last length prefix was produced; the very first
    /// read is served immediately.
    ///
    /// Once the queue is empty, a two-byte read fails with
    /// `io::ErrorKind::ConnectionAborted`:
    /// - immediately, if `disconnect_with_pending_responses` is set;
    /// - otherwise, only after every expected response has been written.
    ///
    /// In every other case the read is pending and the task is woken again
    /// after 500ms.
    ///
    /// # Panics
    ///
    /// Panics if a lock is poisoned, if a message is longer than `u16::MAX`
    /// bytes, or if `buf` has less room than the message body.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let mut last_ready = self.last_ready.lock().unwrap();
        let due = last_ready
            .map(|instant| instant.elapsed() > self.new_message_every)
            .unwrap_or(true);

        if due {
            let mut messages_to_read = self.messages_to_read.lock().unwrap();
            if buf.remaining() == 2 {
                // Length-prefix read: announce the next message, or end the
                // connection if there is nothing left to send.
                if let Some(next_msg) = messages_to_read.front() {
                    let next_msg_len = u16::try_from(next_msg.len()).unwrap();
                    buf.put_slice(&next_msg_len.to_be_bytes());
                    last_ready.replace(Instant::now());
                    return Poll::Ready(Ok(()));
                } else if self.disconnect_with_pending_responses {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::ConnectionAborted,
                        "mock connection premature disconnect",
                    )));
                } else if self.pending_responses == 0 {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::ConnectionAborted,
                        "mock connection normal disconnect",
                    )));
                }
            } else if let Some(msg) = messages_to_read.pop_front() {
                // Body read: hand over the message announced by the prefix.
                buf.put_slice(&msg);
                return Poll::Ready(Ok(()));
            }
        }

        // Nothing to deliver yet (pacing interval not elapsed, or waiting for
        // outstanding responses): ask to be polled again after a fixed delay.
        let waker = cx.waker().clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(500)).await;
            waker.wake();
        });
        Poll::Pending
    }
}
impl AsyncWrite for MockStream {
    /// Accepts the whole buffer at once and counts the write as one delivered
    /// response. The counter never goes below zero.
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        self.pending_responses = self.pending_responses.saturating_sub(1);
        Poll::Ready(Ok(buf.len()))
    }

    /// Nothing is buffered, so flushing always succeeds immediately.
    fn poll_flush(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Shutting down the mock is a no-op that always succeeds.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}
/// Describes one mock client that the listener will hand to the server.
struct MockClientConfig {
    /// Source port reported as the client's peer address.
    pub client_port: u16,
    /// Queries the client sends, in order.
    pub messages: VecDeque<Vec<u8>>,
    /// Pacing interval between the client's messages.
    pub new_message_every: Duration,
    /// Whether the client drops the connection once its queries are sent,
    /// without waiting for all responses.
    pub disconnect_with_pending_responses: bool,
}
/// Mock connection source that hands out a fixed list of mock clients.
struct MockListener {
    /// Client configurations not yet accepted, front first.
    streams_to_read: Mutex<VecDeque<MockClientConfig>>,
    /// Minimum time between two accepted clients.
    new_client_every: Duration,
    /// When the last client was accepted; `None` before the first.
    last_accept: Mutex<Option<Instant>>,
    /// Shared flag; no client is accepted while it is `false`.
    ready: Arc<AtomicBool>,
}
impl MockListener {
    /// Creates a listener over `streams_to_read`. Its ready flag starts out
    /// unset.
    fn new(
        streams_to_read: VecDeque<MockClientConfig>,
        new_client_every: Duration,
    ) -> Self {
        let ready = Arc::new(AtomicBool::new(false));
        Self {
            ready,
            streams_to_read: Mutex::new(streams_to_read),
            last_accept: Mutex::default(),
            new_client_every,
        }
    }

    /// Returns a shared handle to the ready flag, so a test can enable
    /// accepting.
    fn get_ready_flag(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.ready)
    }

    /// Reports whether the ready flag is currently set.
    fn _ready(&self) -> bool {
        self.ready.load(Ordering::Relaxed)
    }

    /// Returns when the most recent client was accepted, if any.
    fn _last_accept(&self) -> Option<Instant> {
        let guard = self.last_accept.lock().unwrap();
        *guard
    }

    /// Returns how many client configurations have not been accepted yet.
    fn streams_remaining(&self) -> usize {
        let pending = self.streams_to_read.lock().unwrap();
        pending.len()
    }
}
impl AsyncAccept for MockListener {
    type Error = io::Error;
    type StreamType = MockStream;
    type Future = std::future::Ready<Result<Self::StreamType, io::Error>>;

    /// Hands out the next queued mock client, if one is due.
    ///
    /// A client is accepted only when all of these hold:
    /// - the ready flag is set;
    /// - at least `new_client_every` has elapsed since the previous accept
    ///   (the first accept is immediate);
    /// - a client configuration remains.
    ///
    /// The client's peer address is `192.168.0.1:<client_port>`. Otherwise
    /// the call is pending and the task is woken again after 100ms.
    ///
    /// # Panics
    ///
    /// Panics if one of the listener's locks is poisoned.
    fn poll_accept(
        &self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(Self::Future, SocketAddr), io::Error>> {
        if self.ready.load(Ordering::Relaxed) {
            let mut last_accept = self.last_accept.lock().unwrap();
            let due = last_accept
                .map(|instant| instant.elapsed() > self.new_client_every)
                .unwrap_or(true);
            if due {
                // Pop in its own statement so the queue lock is released
                // before the stream is built.
                let next = self.streams_to_read.lock().unwrap().pop_front();
                if let Some(MockClientConfig {
                    new_message_every,
                    messages,
                    client_port,
                    disconnect_with_pending_responses,
                }) = next
                {
                    last_accept.replace(Instant::now());
                    let stream = MockStream::new(
                        messages,
                        new_message_every,
                        disconnect_with_pending_responses,
                    );
                    // Build the address directly instead of formatting and
                    // re-parsing a string.
                    let peer = SocketAddr::from(([192, 168, 0, 1], client_port));
                    return Poll::Ready(Ok((std::future::ready(Ok(stream)), peer)));
                }
            }
        }

        // Not ready, not due, or no clients left: poll again later.
        let waker = cx.waker().clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
            waker.wake();
        });
        Poll::Pending
    }
}
/// Buffer source that hands out zero-filled `Vec<u8>` buffers.
#[derive(Clone)]
struct MockBufSource;

impl BufSource for MockBufSource {
    type Output = Vec<u8>;

    /// Returns a zero-filled buffer of the default size (1024 bytes).
    fn create_buf(&self) -> Self::Output {
        self.create_sized(1024)
    }

    /// Returns a zero-filled buffer of exactly `size` bytes.
    fn create_sized(&self, size: usize) -> Self::Output {
        vec![0u8; size]
    }
}
/// Response stream that yields exactly one item and then ends.
struct MySingle {
    /// Set once the single item has been produced.
    done: bool,
}

impl MySingle {
    /// Creates a stream that has not yet produced its item.
    fn new() -> MySingle {
        let done = false;
        MySingle { done }
    }
}
impl futures_util::stream::Stream for MySingle {
    type Item = Result<CallResult<Vec<u8>>, ServiceError>;

    /// Produces one empty response, carrying feedback that reconfigures the
    /// connection's idle timeout to 5000ms. Every later poll ends the stream.
    fn poll_next(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        if self.done {
            return Poll::Ready(None);
        }

        let reply = MessageBuilder::new_stream_vec().additional();
        let feedback = ServiceFeedback::Reconfigure {
            idle_timeout: Some(Duration::from_millis(5000)),
        };
        let item = CallResult::new(reply).with_feedback(feedback);
        self.done = true;
        Poll::Ready(Some(Ok(item)))
    }
}
/// Stateless test service that answers every request with `MySingle`.
#[derive(Clone)]
struct MyService;

impl MyService {
    /// Creates the (zero-sized) service.
    fn new() -> Self {
        MyService
    }
}
impl Service<Vec<u8>, ()> for MyService
where
    Self: Clone + Send + Sync + 'static,
{
    type Target = Vec<u8>;
    type Stream = MySingle;
    type Future = Ready<Self::Stream>;

    /// Traces the request's message ID and returns an already-completed
    /// future that resolves to a fresh single-response stream.
    fn call(&self, request: Request<Vec<u8>, ()>) -> Self::Future {
        trace!("Processing request id {}", request.message().header().id());
        let stream = MySingle::new();
        ready(stream)
    }
}
/// Builds a recursion-desired `example.com. A` query with a random message
/// ID and an OPT record advertising a 4096-byte UDP payload size.
///
/// The query is returned in a `StreamTarget`.
///
/// # Panics
///
/// Panics if any message-building step fails.
fn mk_query() -> StreamTarget<Vec<u8>> {
    let target = StaticCompressor::new(StreamTarget::new_vec());
    let mut builder = MessageBuilder::from_target(target).unwrap();
    {
        let header = builder.header_mut();
        header.set_rd(true);
        header.set_random_id();
    }

    let mut question = builder.question();
    let qname = Name::<Vec<u8>>::from_str("example.com.").unwrap();
    question.push((qname, Rtype::A)).unwrap();

    let mut additional = question.additional();
    additional
        .opt(|opt| {
            opt.set_udp_payload_size(4096);
            Ok(())
        })
        .unwrap();

    additional.finish().into_target()
}
/// Runs a `StreamServer` against two mock clients that never disconnect
/// early, and checks that every received query is answered.
///
/// - The fast client queues five queries paced 100ms apart.
/// - The slow client queues two queries paced 3000ms apart.
/// - Once its ready flag is set, the mock listener accepts at most one
///   client per 2000ms.
///
/// The runtime's clock starts paused (`start_paused = true`), so the sleeps
/// below are measured on tokio's clock rather than waiting in real time.
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn tcp_service_test() {
init_logging();
let (srv_handle, server_status_printer_handle) = {
let fast_client = MockClientConfig {
new_message_every: Duration::from_millis(100),
messages: VecDeque::from([
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
]),
client_port: 1,
disconnect_with_pending_responses: false,
};
let slow_client = MockClientConfig {
new_message_every: Duration::from_millis(3000),
messages: VecDeque::from([
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
]),
client_port: 2,
disconnect_with_pending_responses: false,
};
// Total queries across both clients; each should be received and answered.
let num_messages =
fast_client.messages.len() + slow_client.messages.len();
let streams_to_read = VecDeque::from([fast_client, slow_client]);
let new_client_every = Duration::from_millis(2000);
let listener = MockListener::new(streams_to_read, new_client_every);
// Keep a handle to the flag; the listener accepts nobody until it is set.
let ready_flag = listener.get_ready_flag();
let buf = MockBufSource;
let my_service =
Arc::new(MandatoryMiddlewareSvc::new(MyService::new()));
let srv =
Arc::new(StreamServer::new(listener, buf, my_service.clone()));
let metrics = srv.metrics();
// Print server metrics every 250ms for debugging. This task loops forever
// and is aborted at the end of the test.
let server_status_printer_handle = tokio::spawn(async move {
loop {
sleep(Duration::from_millis(250)).await;
eprintln!(
"Server status: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}",
metrics.num_connections(),
metrics.num_inflight_requests(),
metrics.num_pending_writes(),
metrics.num_received_requests(),
metrics.num_sent_responses(),
);
}
});
let spawned_srv = srv.clone();
let srv_handle = tokio::spawn(async move { spawned_srv.run().await });
// Let the server start running before any client can be accepted.
eprintln!("Clients sleeping");
sleep(Duration::from_secs(1)).await;
eprintln!("Clients connecting");
ready_flag.store(true, Ordering::Relaxed);
// Give both clients time to send all queries, get their responses and
// disconnect.
sleep(Duration::from_secs(20)).await;
// All clients were accepted and have gone away, nothing is left in flight,
// and every query got exactly one response.
assert_eq!(0, srv.source().streams_remaining());
assert_eq!(srv.metrics().num_connections(), 0);
assert_eq!(srv.metrics().num_inflight_requests(), 0);
assert_eq!(srv.metrics().num_pending_writes(), 0);
assert_eq!(srv.metrics().num_received_requests(), num_messages);
assert_eq!(srv.metrics().num_sent_responses(), num_messages);
eprintln!("Shutting down");
srv.shutdown().unwrap();
eprintln!("Shutdown command sent");
(srv_handle, server_status_printer_handle)
};
// Wait for the server task to finish after the shutdown request, then stop
// the metrics printer.
eprintln!("Waiting for service to shutdown");
let _ = srv_handle.await;
server_status_printer_handle.abort();
}
/// Runs a `StreamServer` against one client that disconnects with responses
/// still outstanding and one well-behaved client.
///
/// Every query should still be received, but fewer responses than queries
/// are sent. The fast client has `disconnect_with_pending_responses: true`:
/// its mock stream fails its next length-prefix read as soon as its message
/// queue is empty, whether or not its responses have been written.
///
/// The runtime's clock starts paused (`start_paused = true`), so the sleeps
/// below are measured on tokio's clock rather than waiting in real time.
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn tcp_client_disconnect_test() {
init_logging();
let (srv_handle, server_status_printer_handle) = {
let fast_client = MockClientConfig {
new_message_every: Duration::from_millis(100),
messages: VecDeque::from([
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
]),
client_port: 1,
disconnect_with_pending_responses: true,
};
let slow_client = MockClientConfig {
new_message_every: Duration::from_millis(3000),
messages: VecDeque::from([
mk_query().as_dgram_slice().to_vec(),
mk_query().as_dgram_slice().to_vec(),
]),
client_port: 2,
disconnect_with_pending_responses: false,
};
// Total queries across both clients; all should be received.
let num_messages =
fast_client.messages.len() + slow_client.messages.len();
let streams_to_read = VecDeque::from([fast_client, slow_client]);
let new_client_every = Duration::from_millis(2000);
let listener = MockListener::new(streams_to_read, new_client_every);
// Keep a handle to the flag; the listener accepts nobody until it is set.
let ready_flag = listener.get_ready_flag();
let buf = MockBufSource;
// NOTE(review): unlike tcp_service_test, the service is not wrapped in
// MandatoryMiddlewareSvc here. Confirm that this difference is intentional.
let my_service = Arc::new(MyService::new());
let srv =
Arc::new(StreamServer::new(listener, buf, my_service.clone()));
let metrics = srv.metrics();
// Print server metrics every 250ms for debugging. This task loops forever
// and is aborted at the end of the test.
let server_status_printer_handle = tokio::spawn(async move {
loop {
sleep(Duration::from_millis(250)).await;
eprintln!(
"Server status: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}",
metrics.num_connections(),
metrics.num_inflight_requests(),
metrics.num_pending_writes(),
metrics.num_received_requests(),
metrics.num_sent_responses(),
);
}
});
let spawned_srv = srv.clone();
let srv_handle = tokio::spawn(async move { spawned_srv.run().await });
// Let the server start running before any client can be accepted.
eprintln!("Clients sleeping");
sleep(Duration::from_secs(1)).await;
eprintln!("Clients connecting");
ready_flag.store(true, Ordering::Relaxed);
// Give both clients time to send all queries and disconnect.
sleep(Duration::from_secs(20)).await;
// All clients were accepted and have gone away, and nothing is left in
// flight. Every query was received, but the early disconnect means not
// every response was sent.
assert_eq!(0, srv.source().streams_remaining());
assert_eq!(srv.metrics().num_connections(), 0);
assert_eq!(srv.metrics().num_inflight_requests(), 0);
assert_eq!(srv.metrics().num_pending_writes(), 0);
assert_eq!(srv.metrics().num_received_requests(), num_messages);
assert!(srv.metrics().num_sent_responses() < num_messages);
eprintln!("Shutting down");
srv.shutdown().unwrap();
eprintln!("Shutdown command sent");
(srv_handle, server_status_printer_handle)
};
// Wait for the server task to finish after the shutdown request, then stop
// the metrics printer.
eprintln!("Waiting for service to shutdown");
let _ = srv_handle.await;
server_status_printer_handle.abort();
}