#[macro_use]
extern crate proptest_state_machine;
use std::collections::{HashMap, HashSet};
use std::thread;
use proptest::prelude::*;
use proptest::test_runner::Config;
use proptest_state_machine::{ReferenceStateMachine, StateMachineTest};
use system_under_test::{
init_client, init_server, run_client, run_server, ClientDialer, Msg,
ServerDialer, Transport,
};
// Generates `run_echo_server_test`, which drives `EchoServerTest` through
// transition sequences produced by the `RefState` reference model.
prop_state_machine! {
    #![proptest_config(Config {
        // Turn failure persistence off for demonstration. This means that no
        // regression file will be captured.
        failure_persistence: None,
        // Enable verbose mode to make the state machine test print the
        // transitions for each case.
        verbose: 1,
        // Only run 10 cases by default to avoid running out of system resources
        // and taking too long to finish.
        cases: 10,
        .. Config::default()
    })]
    // Each case applies, in order (`sequential`), a sequence of transitions
    // whose length is drawn from `1..20`.
    fn run_echo_server_test(
        sequential
        1..20
        =>
        EchoServerTest
    );
}
/// Program entry point: runs the generated state machine property test.
fn main() {
    run_echo_server_test()
}
/// Reference (model) state that the real system is checked against.
#[derive(Clone, Debug)]
struct RefState {
    /// Whether the model considers the echo server to be running.
    is_server_up: bool,
    /// Ids of the clients the model considers started; emptied when the
    /// server is stopped.
    clients: HashSet<ClientId>,
    /// Transport picked once in `init_state`; no transition changes it.
    transport: Transport,
}
/// Operations the state machine test can perform on the echo system.
#[derive(Clone, Debug)]
enum Transition {
    /// Start the echo server (only generated while it is down).
    StartServer,
    /// Stop the echo server together with every client.
    StopServer,
    /// Start a client with the given id and connect it to the server.
    StartClient(ClientId),
    /// Stop the client with the given id.
    StopClient(ClientId),
    /// Send a message from the given client and expect it to be echoed back.
    ClientMsg(ClientId, Msg),
}
/// The real system under test: an optional running server plus every
/// running client, keyed by the id the reference model assigned to it.
#[derive(Default)]
struct EchoServerTest {
    server: Option<TestServer>,
    clients: HashMap<ClientId, TestClient>,
}

impl Drop for EchoServerTest {
    /// Signals every node that is still running to stop.
    ///
    /// A case can end (or panic mid-sequence) with the server and clients
    /// still up. Without this, their event-loop threads and sockets would
    /// outlive the case and pile up across cases and shrinking runs. The
    /// threads are deliberately not joined here, so a misbehaving node cannot
    /// hang the drop (possibly during unwinding).
    fn drop(&mut self) {
        for client in self.clients.values() {
            client.dialer.handler.stop();
        }
        if let Some(server) = &self.server {
            server.dialer.handler.stop();
        }
    }
}
/// A running server: the dialer used to stop it and to read its bound
/// address, plus the thread running its event loop.
struct TestServer {
    dialer: ServerDialer,
    listener_handle: thread::JoinHandle<()>,
}
struct TestClient {
dialer: ClientDialer,
listener_handle: std::thread::JoinHandle<()>,
msgs_recv: std::sync::mpsc::Receiver<Msg>,
}
/// Identifier the reference model assigns to a client (generated from `0..32`).
type ClientId = usize;
impl ReferenceStateMachine for RefState {
    type State = RefState;
    type Transition = Transition;

    /// Starts with the server down, no clients, and a randomly chosen
    /// transport that stays fixed for the whole case.
    fn init_state() -> BoxedStrategy<Self::State> {
        prop_oneof![
            Just(Transport::Tcp),
            Just(Transport::FramedTcp),
            Just(Transport::Udp),
            Just(Transport::Ws),
        ]
        .prop_map(|transport| Self {
            is_server_up: false,
            clients: HashSet::default(),
            transport,
        })
        .boxed()
    }

    /// Candidate transitions for `state`. Client-targeted transitions are
    /// only offered when at least one client exists. Invalid candidates are
    /// filtered out by `preconditions`.
    fn transitions(state: &Self::State) -> BoxedStrategy<Self::Transition> {
        use Transition::*;
        if state.clients.is_empty() {
            prop_oneof![
                Just(StartServer),
                Just(StopServer),
                (0..32_usize).prop_map(StartClient),
            ]
            .boxed()
        } else {
            // `HashSet` iteration order is randomized per instance and per
            // process, while `select` picks by index. Sorting gives a stable
            // order, so the same RNG seed always produces the same transitions.
            let mut ids: Vec<ClientId> = state.clients.iter().copied().collect();
            ids.sort_unstable();
            let arb_id = proptest::sample::select(ids);
            prop_oneof![
                Just(StartServer),
                Just(StopServer),
                (0..32_usize).prop_map(StartClient),
                arb_id.clone().prop_map(StopClient),
                arb_id.prop_flat_map(|id| {
                    arb_msg_from_client().prop_map(move |msg| ClientMsg(id, msg))
                }),
            ]
            .boxed()
        }
    }

    /// Updates the model to reflect `transition`.
    fn apply(
        mut state: Self::State,
        transition: &Self::Transition,
    ) -> Self::State {
        match transition {
            Transition::StartServer => {
                state.is_server_up = true;
            }
            Transition::StopServer => {
                state.is_server_up = false;
                // Stopping the server also stops every client.
                state.clients.clear();
            }
            Transition::StartClient(id) => {
                state.clients.insert(*id);
            }
            Transition::StopClient(id) => {
                state.clients.remove(id);
            }
            // Echoing a message does not change the model.
            Transition::ClientMsg(_id, _msg) => {}
        }
        state
    }

    /// Whether `transition` is valid in `state`.
    ///
    /// `StopClient` does not check `is_server_up` because `apply` empties
    /// `clients` whenever the server stops, so a known client implies a
    /// running server.
    fn preconditions(
        state: &Self::State,
        transition: &Self::Transition,
    ) -> bool {
        match transition {
            Transition::StartServer => !state.is_server_up,
            Transition::StopServer => state.is_server_up,
            Transition::StartClient(id) => {
                state.is_server_up && !state.clients.contains(id)
            }
            Transition::StopClient(id) => state.clients.contains(id),
            Transition::ClientMsg(id, _) => {
                state.is_server_up && state.clients.contains(id)
            }
        }
    }
}
/// Strategy for a client's message payload.
///
/// The payload is generated from a regex: 1 to 8 characters, each a
/// lowercase ASCII letter or a digit.
fn arb_msg_from_client() -> impl Strategy<Value = Msg> {
    const MSG_PATTERN: &str = "[a-z0-9]{1,8}";
    MSG_PATTERN
}
impl StateMachineTest for EchoServerTest {
    type SystemUnderTest = Self;
    type Reference = RefState;

    /// Every case starts from an empty system: no server and no clients.
    fn init_test(
        _ref_state: &<Self::Reference as ReferenceStateMachine>::State,
    ) -> Self::SystemUnderTest {
        Self::default()
    }

    /// Applies `transition` to the real server and clients.
    ///
    /// Transitions are only generated when `RefState::preconditions` holds.
    /// That is what the `expect`s below rely on.
    fn apply(
        mut state: Self::SystemUnderTest,
        ref_state: &<Self::Reference as ReferenceStateMachine>::State,
        transition: <Self::Reference as ReferenceStateMachine>::Transition,
    ) -> Self::SystemUnderTest {
        // How long a client waits for its echo. A bounded wait turns a lost
        // message into a reported (and shrinkable) failure instead of a hang.
        const ECHO_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

        match transition {
            Transition::StartServer => {
                // Port 0: the OS picks a free port for every server instance.
                let (dialer, listener) = init_server(ref_state.transport, "127.0.0.1:0");
                let listener_handle = thread::spawn(move || run_server(listener));
                state.server = Some(TestServer {
                    dialer,
                    listener_handle,
                });
            }
            Transition::StopServer => {
                let server = state
                    .server
                    .take()
                    .expect("precondition: the server is running");
                server.dialer.handler.stop();
                server.listener_handle.join().unwrap();
                if !state.clients.is_empty() {
                    println!(
                        "The server is waiting for all the clients to \
                         stop..."
                    );
                    for (id, client) in std::mem::take(&mut state.clients) {
                        client.dialer.handler.stop();
                        println!("Asking client {} listener to stop.", id);
                        client.listener_handle.join().unwrap();
                        println!("Client {} listener stopped.", id);
                    }
                    println!("All clients have stopped.");
                }
            }
            Transition::StartClient(id) => {
                let server_addr = state
                    .server
                    .as_ref()
                    .expect("precondition: the server is running")
                    .dialer
                    .address;
                let (listener, dialer) = init_client(ref_state.transport, server_addr);
                // The client's event-loop thread forwards every received
                // message to the test through this channel.
                let (msgs_send, msgs_recv) = std::sync::mpsc::channel();
                let listener_handle = thread::spawn(move || {
                    run_client(listener, |msg| {
                        msgs_send.send(msg).unwrap();
                    })
                });
                state.clients.insert(
                    id,
                    TestClient {
                        dialer,
                        listener_handle,
                        msgs_recv,
                    },
                );
            }
            Transition::StopClient(id) => {
                let client = state
                    .clients
                    .remove(&id)
                    .expect("precondition: the client is running");
                client.dialer.handler.stop();
                client.listener_handle.join().unwrap();
            }
            Transition::ClientMsg(id, msg) => {
                let client = state
                    .clients
                    .get_mut(&id)
                    .expect("precondition: the client is running");
                // `msg_server_wrong` sends without waiting for `is_connected`
                // (unlike `system_under_test::msg_server`), so the message may
                // be lost. The timeout below reports that as a failure.
                system_under_test::msg_server_wrong(&mut client.dialer, &msg);
                println!("Waiting for server response.");
                let recv_msg = match client.msgs_recv.recv_timeout(ECHO_TIMEOUT) {
                    Ok(recv_msg) => recv_msg,
                    Err(err) => panic!(
                        "Client {} got no echo of {:?} within {:?} ({}); the \
                         message was probably lost.",
                        id, msg, ECHO_TIMEOUT, err
                    ),
                };
                assert_eq!(recv_msg, msg)
            }
        }
        state
    }
}
/// Minimal echo server and client built on `message_io`: the system under
/// test.
mod system_under_test {
    pub use message_io::network::Transport;
    use message_io::network::{Endpoint, NetEvent, ResourceId, ToRemoteAddr};
    use message_io::node::{self, NodeEvent, NodeHandler, NodeListener};
    use std::net::{SocketAddr, ToSocketAddrs};
    use std::sync::atomic::{self, AtomicBool};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// Memory ordering for the `is_connected` flag, which is shared between
    /// a client's event loop and its dialer.
    const ATOMIC_ORDER: atomic::Ordering = atomic::Ordering::SeqCst;

    /// How often `msg_server` re-checks whether the client is connected.
    const CONNECT_POLL_INTERVAL: Duration = Duration::from_millis(1);

    /// Messages are plain strings, sent as their UTF-8 bytes.
    pub type Msg = String;

    /// Server half that runs the event loop (consumed by [`run_server`]).
    pub struct ServerListener {
        pub listener: NodeListener<()>,
        pub handler: NodeHandler<()>,
    }

    /// Server half kept by the caller: the bound address and a handler that
    /// can stop the server.
    pub struct ServerDialer {
        pub address: SocketAddr,
        pub resource_id: ResourceId,
        pub handler: NodeHandler<()>,
    }

    /// Client half that runs the event loop (consumed by [`run_client`]).
    pub struct ClientListener {
        pub address: SocketAddr,
        pub listener: NodeListener<()>,
        pub server: Endpoint,
        pub handler: NodeHandler<()>,
        /// Shared with the matching [`ClientDialer`]; written by the event loop.
        pub is_connected: Arc<AtomicBool>,
    }

    /// Client half kept by the caller for sending messages and stopping.
    pub struct ClientDialer {
        pub address: SocketAddr,
        pub server: Endpoint,
        pub handler: NodeHandler<()>,
        pub is_connected: Arc<AtomicBool>,
    }

    /// Starts listening on `addr` using `transport`.
    ///
    /// Returns the caller's half and the event-loop half.
    ///
    /// # Panics
    /// If the listener cannot be created.
    pub fn init_server(
        transport: Transport,
        addr: impl ToSocketAddrs,
    ) -> (ServerDialer, ServerListener) {
        let (handler, listener) = node::split::<()>();
        let (resource_id, address) = handler.network().listen(transport, addr).unwrap();
        println!("Server is running at {address} with {transport}.");
        (
            ServerDialer {
                address,
                resource_id,
                handler: handler.clone(),
            },
            ServerListener { listener, handler },
        )
    }

    /// Runs the server event loop, echoing every message back to its sender.
    ///
    /// Returns once the node is stopped through its handler.
    ///
    /// # Panics
    /// If a received message is not valid UTF-8.
    pub fn run_server(listener: ServerListener) {
        let ServerListener { listener, handler } = listener;
        listener.for_each(move |event| match event.network() {
            // A listening server never initiates connections itself.
            NetEvent::Connected(_, _) => (),
            NetEvent::Accepted(endpoint, _resource_id) => {
                println!("Client ({}) connected.", endpoint.addr());
            }
            NetEvent::Message(endpoint, msg_bytes) => {
                let message: Msg = String::from_utf8(msg_bytes.to_vec()).unwrap();
                println!("Server received a message \"{message}\".");
                handler.network().send(endpoint, msg_bytes);
            }
            NetEvent::Disconnected(endpoint) => {
                println!("Client ({}) disconnected.", endpoint.addr());
            }
        });
    }

    /// Starts connecting to `remote_addr` using `transport`.
    ///
    /// Returns the event-loop half and the caller's half. Both halves share
    /// the same `is_connected` flag, which starts out `false`.
    ///
    /// # Panics
    /// If the connection cannot be initiated.
    pub fn init_client(
        transport: Transport,
        remote_addr: impl ToRemoteAddr,
    ) -> (ClientListener, ClientDialer) {
        let (handler, listener) = node::split();
        let (server, address) = handler.network().connect(transport, remote_addr).unwrap();
        let is_connected = Arc::new(AtomicBool::new(false));
        (
            ClientListener {
                address,
                server,
                handler: handler.clone(),
                listener,
                is_connected: is_connected.clone(),
            },
            ClientDialer {
                address,
                server,
                handler,
                is_connected,
            },
        )
    }

    /// Runs the client event loop and passes every received message to
    /// `on_msg`.
    ///
    /// Tracks connection state in `is_connected`. Stops the node when the
    /// server disconnects.
    ///
    /// # Panics
    /// If a received message is not valid UTF-8.
    pub fn run_client(listener: ClientListener, mut on_msg: impl FnMut(Msg)) {
        let ClientListener {
            address,
            server,
            handler,
            listener,
            is_connected,
        } = listener;
        listener.for_each(move |event| match event {
            NodeEvent::Network(net_event) => match net_event {
                NetEvent::Connected(_, established) => {
                    if established {
                        println!("Client identified by local port: {}.", address.port());
                    } else {
                        println!("Cannot connect to server at {server}.")
                    }
                    is_connected.store(established, ATOMIC_ORDER);
                }
                // A client never listens, so it cannot accept connections.
                NetEvent::Accepted(_, _) => unreachable!(),
                NetEvent::Message(_, msg_bytes) => {
                    let message: Msg = String::from_utf8(msg_bytes.to_vec()).unwrap();
                    on_msg(message);
                }
                NetEvent::Disconnected(_) => {
                    println!("Server is disconnected.");
                    is_connected.store(false, ATOMIC_ORDER);
                    // Ends this event loop so the listener thread can exit.
                    handler.stop();
                }
            },
            // Signals are not used by this example.
            NodeEvent::Signal(()) => {}
        });
    }

    /// Sends `msg` to the server without checking `is_connected`.
    ///
    /// If the connection is not yet established, the message may be lost.
    /// Compare [`msg_server`], which waits for the connection first.
    pub fn msg_server_wrong(dialer: &mut ClientDialer, msg: &Msg) {
        let output_data = msg.as_bytes();
        dialer.handler.network().send(dialer.server, output_data);
    }

    /// Sends `msg` to the server after waiting until the client is connected.
    ///
    /// The wait polls every `CONNECT_POLL_INTERVAL` instead of busy-spinning,
    /// and the wait message is printed only once. NOTE(review): if the
    /// connection attempt fails, `is_connected` never becomes `true`, and this
    /// waits forever.
    // Correct counterpart of `msg_server_wrong`; kept for reference but not
    // called by the test, hence the lint allowance.
    #[allow(dead_code)]
    pub fn msg_server(dialer: &mut ClientDialer, msg: &Msg) {
        let output_data = msg.as_bytes();
        if !dialer.is_connected.load(ATOMIC_ORDER) {
            println!("Waiting for the server to be ready.");
            while !dialer.is_connected.load(ATOMIC_ORDER) {
                thread::sleep(CONNECT_POLL_INTERVAL);
            }
        }
        dialer.handler.network().send(dialer.server, output_data);
    }
}