use std::{collections::HashSet, fmt::Debug, string::FromUtf8Error, thread::sleep, time::Duration};
use tubes::ClientId;
use tubes::prelude::*;
/// String payload used by the tests, convertible to and from raw byte buffers.
#[derive(Debug, Clone, PartialEq)]
pub struct Msg(pub String);

impl From<String> for Msg {
    /// Wraps an owned string without copying it.
    fn from(text: String) -> Self {
        Msg(text)
    }
}

impl TryFrom<&[u8]> for Msg {
    type Error = FromUtf8Error;

    /// Copies `bytes` and decodes them as UTF-8; fails on invalid UTF-8.
    fn try_from(bytes: &[u8]) -> std::result::Result<Self, Self::Error> {
        String::from_utf8(bytes.to_vec()).map(Msg)
    }
}

impl TryFrom<Msg> for Vec<u8> {
    type Error = ();

    /// Returns the UTF-8 bytes of the message; never fails.
    fn try_from(msg: Msg) -> std::result::Result<Self, Self::Error> {
        Ok(msg.0.into_bytes())
    }
}
/// Test harness owning one server `Session` plus the other `Session`s in the test.
///
/// NOTE(review): after `Drop::drop` runs, fields are dropped in declaration order
/// (`server` before `nodes`); keep this order unless changing it is intentional.
pub struct TestSession {
    /// The session currently acting as server; `promote_to_host` swaps it with a promoted node.
    server: Session,
    /// The non-server sessions of the test.
    nodes: Vec<Session>,
    /// Node being promoted to host, if any; `exhaust` drains that node first and then clears this.
    promoting: Option<ClientId>,
}
impl TestSession {
    /// Starts a server on `port` with one connected client.
    pub fn new_pair(port: u16) -> Self {
        Self::launch(Self::new_server(port), vec![Self::new_client(port)])
    }

    /// Starts a server on `port` with two connected clients.
    pub fn new_triple(port: u16) -> Self {
        Self::launch(
            Self::new_server(port),
            vec![Self::new_client(port), Self::new_client(port)],
        )
    }

    /// Starts a server on `port` with three connected clients.
    pub fn new_quad(port: u16) -> Self {
        Self::launch(
            Self::new_server(port),
            vec![
                Self::new_client(port),
                Self::new_client(port),
                Self::new_client(port),
            ],
        )
    }

    /// Like [`Self::new_pair`], but every session uses compression and an encryption key.
    #[cfg(all(feature = "encryption", feature = "compression"))]
    pub fn new_pair_pack(port: u16) -> Self {
        Self::launch(Self::new_server_pack(port), vec![Self::new_client_pack(port)])
    }

    /// Like [`Self::new_triple`], but every session uses compression and an encryption key.
    #[cfg(all(feature = "encryption", feature = "compression"))]
    pub fn new_triple_pack(port: u16) -> Self {
        Self::launch(
            Self::new_server_pack(port),
            vec![Self::new_client_pack(port), Self::new_client_pack(port)],
        )
    }

    /// Like [`Self::new_quad`], but every session uses compression and an encryption key.
    #[cfg(all(feature = "encryption", feature = "compression"))]
    pub fn new_quad_pack(port: u16) -> Self {
        Self::launch(
            Self::new_server_pack(port),
            vec![
                Self::new_client_pack(port),
                Self::new_client_pack(port),
                Self::new_client_pack(port),
            ],
        )
    }

    /// Stops every node, then the server.
    pub fn stop(&mut self) {
        for c in self.nodes.iter_mut() {
            c.stop();
        }
        self.server.stop();
    }

    /// UUID of the session currently acting as server.
    pub fn server_uuid(&self) -> ClientId {
        self.server.uuid()
    }

    /// The set of clients currently known to the server.
    pub fn clients(&self) -> HashSet<ClientId> {
        self.server.clients()
    }

    /// Removes (and drops) every node whose UUID is `uuid`.
    pub fn client_leave(&mut self, uuid: ClientId) {
        self.nodes.retain(|c| c.uuid() != uuid);
    }

    /// Sleeps 100 ms, then discards all pending messages.
    ///
    /// While a promotion is pending, the node being promoted is drained first and the
    /// pending marker is cleared. After that, the server and every node are drained.
    ///
    /// # Panics
    /// Panics if any `read` call returns an error.
    pub fn exhaust(&mut self) {
        sleep(Duration::from_millis(100));
        if let Some(p) = self.promoting {
            if let Some(c) = self.nodes.iter_mut().find(|c| c.uuid() == p) {
                Self::drain(c);
                self.promoting = None;
            }
        }
        Self::drain(&mut self.server);
        for c in self.nodes.iter_mut() {
            Self::drain(c);
        }
    }

    /// Broadcasts `m` from the server.
    pub fn server_broadcast(&mut self, m: Msg) {
        self.server.broadcast(m.try_into().unwrap()).unwrap();
    }

    /// Sends `m` from the server to the node `uuid`.
    pub fn server_send_to(&mut self, uuid: ClientId, m: Msg) {
        self.server.send_to(uuid, m.try_into().unwrap()).unwrap();
    }

    /// Sends `m` from the first node to the server.
    ///
    /// Returns the sender's UUID, or `None` if there are no nodes.
    pub fn first_client_send_to_server(&mut self, m: Msg) -> Option<ClientId> {
        let c = self.nodes.first_mut()?;
        c.send_to(self.server.uuid(), m.try_into().unwrap())
            .unwrap();
        Some(c.uuid())
    }

    /// Broadcasts `m` from the first node.
    ///
    /// Returns the sender's UUID, or `None` if there are no nodes.
    pub fn first_client_broadcast(&mut self, m: Msg) -> Option<ClientId> {
        let c = self.nodes.first_mut()?;
        c.broadcast(m.try_into().unwrap()).unwrap();
        Some(c.uuid())
    }

    /// Asks the server to promote `client` to host on `port`, then swaps the roles locally.
    ///
    /// After the call, `self.server` holds the promoted session and the former server
    /// session sits in `self.nodes`.
    ///
    /// # Panics
    /// Panics if `client` never reports itself as server or is not one of the nodes.
    pub fn promote_to_host(&mut self, client: ClientId, port: u16) {
        println!(
            "TEST /// Before migration, server={}, clients={:?}",
            self.server.uuid(),
            self.server.clients()
        );
        self.promoting = Some(client);
        self.server.promote_to_host(client, Some(port));
        self.exhaust();
        self.assert_is_server(client);
        let promoted_client = self
            .nodes
            .iter_mut()
            .find(|c| c.uuid() == client)
            .expect("promoted client must be one of the session's nodes");
        std::mem::swap(&mut self.server, promoted_client);
        sleep(Duration::from_millis(100));
        println!(
            "TEST /// After migration, server={}, clients={:?}",
            self.server.uuid(),
            self.server.clients()
        );
    }

    /// Asserts (with retries) that the server's next message is a broadcast of `m` from `from`.
    pub fn assert_server_received_broacast(&mut self, from: ClientId, m: Msg) {
        let m = MessageData::Broadcast {
            from,
            data: m.try_into().unwrap(),
        };
        retry(|| Self::assert_received_message(&mut self.server, &m)).unwrap();
    }

    /// Asserts (with retries) that every node's next message is a broadcast of `m` from `from`.
    pub fn assert_all_clients_received_broacast(&mut self, from: ClientId, m: Msg) {
        let m = MessageData::Broadcast {
            from,
            data: m.try_into().unwrap(),
        };
        for c in self.nodes.iter_mut() {
            retry(|| Self::assert_received_message(c, &m)).unwrap();
        }
    }

    /// Asserts that node `uuid` received `m` from `from` and no other node has a pending message.
    ///
    /// The server is not checked. If no node has UUID `uuid`, only the "nobody else
    /// received anything" half of the check runs.
    pub fn assert_only_client_received(&mut self, from: ClientId, uuid: ClientId, m: Msg) {
        let m = MessageData::Send {
            from,
            to: uuid,
            data: m.try_into().unwrap(),
        };
        for c in self.nodes.iter_mut().filter(|c| c.uuid() == uuid) {
            retry(|| Self::assert_received_message(c, &m)).unwrap();
        }
        // Pause before the negative check so a wrongly delivered copy has time to arrive.
        sleep(Duration::from_millis(100));
        for c in self.nodes.iter_mut().filter(|c| c.uuid() != uuid) {
            assert!(c.read().unwrap().is_none());
        }
    }

    /// Asserts that the server received `m` from `from` and no node has a pending message.
    pub fn assert_only_server_received(&mut self, from: ClientId, m: Msg) {
        let m = MessageData::Send {
            from,
            to: self.server.uuid(),
            data: m.try_into().unwrap(),
        };
        retry(|| Self::assert_received_message(&mut self.server, &m)).unwrap();
        // Pause before the negative check so a wrongly delivered copy has time to arrive.
        sleep(Duration::from_millis(100));
        for c in self.nodes.iter_mut() {
            assert!(c.read().unwrap().is_none());
        }
    }

    /// Asserts that the server and every node were told `uuid` left.
    ///
    /// Also asserts that the server eventually drops `uuid` from its client list.
    pub fn assert_client_left(&mut self, uuid: ClientId) {
        let m = MessageData::ClientLeft(uuid);
        retry(|| Self::assert_received_message(&mut self.server, &m)).unwrap();
        for c in self.nodes.iter_mut() {
            retry(|| Self::assert_received_message(c, &m)).unwrap();
        }
        retry(|| {
            if self.server.clients().contains(&uuid) {
                return Err(format!("Client has not left {}", uuid));
            }
            Ok(())
        })
        .unwrap();
    }

    /// Asserts (with retries) that `uuid` is the server, or a node that reports `is_server()`.
    pub fn assert_is_server(&self, uuid: ClientId) {
        retry(|| {
            let server_is_server = uuid == self.server.uuid();
            let client_is_server = self
                .nodes
                .iter()
                .any(|c| c.uuid() == uuid && c.is_server());
            if server_is_server || client_is_server {
                Ok(())
            } else {
                Err(format!("{} is not server", uuid))
            }
        })
        .unwrap();
    }

    /// Builds a session from unstarted sessions, then runs [`Self::connect`].
    fn launch(server: Session, nodes: Vec<Session>) -> Self {
        let mut session = Self {
            server,
            nodes,
            promoting: None,
        };
        session.connect();
        session
    }

    /// Starts all sessions and waits until the topology is fully established.
    ///
    /// "Established" means three things:
    /// - the server lists exactly the nodes;
    /// - every node knows the server UUID;
    /// - every node saw a `ClientJoined` for each other client.
    fn connect(&mut self) {
        self.server.start().unwrap();
        assert!(self.server.is_connected());
        for c in self.nodes.iter_mut() {
            c.start().unwrap();
            assert!(c.is_connected());
        }
        retry(|| {
            // One snapshot for both checks so they agree with each other.
            let clients = self.server.clients();
            let (len1, len2) = (clients.len(), self.nodes.len());
            if len1 != len2 {
                return Err(format!("Client lengths not matching: {} != {}", len1, len2));
            }
            match self
                .nodes
                .iter()
                .map(|n| n.uuid())
                .find(|uuid| !clients.contains(uuid))
            {
                Some(uuid) => Err(format!("Client not in server list: {}", uuid)),
                None => Ok(()),
            }
        })
        .unwrap();
        let server_uuid = self.server_uuid();
        for c in self.nodes.iter_mut() {
            retry(|| {
                let Some(suuid) = c.server_uuid() else {
                    return Err(format!(
                        "Client {} did not receive the server uuid",
                        c.uuid()
                    ));
                };
                if server_uuid != suuid {
                    return Err(format!(
                        "Client {} did receive the wrong server uuid {} != {}",
                        c.uuid(),
                        server_uuid,
                        suuid
                    ));
                }
                Ok(())
            })
            .unwrap();
        }
        for uuid in self.clients() {
            let m = MessageData::ClientJoined(uuid);
            for c in self.nodes.iter_mut().filter(|c| c.uuid() != uuid) {
                retry(|| Self::assert_received_message(c, &m)).unwrap();
            }
        }
    }

    /// Creates an unstarted plain server session bound to `127.0.0.1:port`.
    fn new_server(port: u16) -> Session {
        let s = Session::new_server(format!("127.0.0.1:{}", port).as_str().into());
        assert!(s.is_server());
        assert!(!s.is_connected());
        s
    }

    /// Creates an unstarted plain client session targeting `127.0.0.1:port`.
    fn new_client(port: u16) -> Session {
        let s = Session::new_client(format!("127.0.0.1:{}", port).as_str().into());
        assert!(!s.is_server());
        assert!(!s.is_connected());
        s
    }

    /// Shared config for packed sessions: localhost, compression on, fixed test key.
    #[cfg(all(feature = "encryption", feature = "compression"))]
    fn pack_config(port: u16) -> Config {
        Config {
            address: "127.0.0.1".parse().ok(),
            port,
            versions: Default::default(),
            accept_timeout: Default::default(),
            compress: true,
            key: Some(Self::key()),
        }
    }

    /// Creates an unstarted server session using [`Self::pack_config`].
    #[cfg(all(feature = "encryption", feature = "compression"))]
    fn new_server_pack(port: u16) -> Session {
        let s = Session::new_server(Self::pack_config(port));
        assert!(s.is_server());
        assert!(!s.is_connected());
        s
    }

    /// Creates an unstarted client session using [`Self::pack_config`].
    #[cfg(all(feature = "encryption", feature = "compression"))]
    fn new_client_pack(port: u16) -> Session {
        let s = Session::new_client(Self::pack_config(port));
        assert!(!s.is_server());
        assert!(!s.is_connected());
        s
    }

    /// Fixed 32-byte all-zero key; test use only.
    #[cfg(feature = "encryption")]
    fn key() -> Vec<u8> {
        vec![0; 32]
    }

    /// Reads one message from `s` and compares it to `m` via [`Self::debug_message`].
    ///
    /// Returns `Err` with a description if nothing is pending or the message differs.
    /// Intended to be wrapped in `retry`.
    fn assert_received_message(s: &mut Session, m: &MessageData) -> Result<(), String> {
        let expected = Self::debug_message(m);
        // Build the error text lazily: only when no message is available.
        let received = s.read().unwrap().ok_or_else(|| {
            format!(
                "No message available on {}, expected {}",
                s.uuid(),
                expected
            )
        })?;
        let received = Self::debug_message(&received);
        if received != expected {
            return Err(format!(
                "Wrong message, expected: \n{}\n returned: \n{}",
                expected, received
            ));
        }
        Ok(())
    }

    /// Renders a message as a comparable, human-readable string.
    fn debug_message(m: &MessageData) -> String {
        match m {
            MessageData::Broadcast { from, data } => format!("Broadcast({from}, {data:?})"),
            MessageData::Send {
                from,
                to: uuid,
                data,
            } => format!("Send({from}, {uuid}, {data:?})"),
            MessageData::ClientJoined(uuid) => format!("ClientJoined({uuid})"),
            MessageData::ClientLeft(uuid) => format!("ClientLeft({uuid})"),
        }
    }

    /// Reads and discards messages from `s` until none are pending.
    ///
    /// # Panics
    /// Panics if `read` returns an error.
    fn drain(s: &mut Session) {
        while s.read().unwrap().is_some() {}
    }
}
impl Drop for TestSession {
fn drop(&mut self) {
self.stop();
}
}
/// Calls `f` until it returns `Ok`, sleeping 100 ms between attempts.
///
/// `f` is called at most 21 times (the first try plus 20 retries). The error from
/// the last attempt is returned if none succeeds.
fn retry<E, F: FnMut() -> Result<(), E>>(mut f: F) -> Result<(), E> {
    const MAX_RETRIES: u32 = 20;
    let mut attempt: u32 = 0;
    loop {
        match f() {
            Ok(()) => return Ok(()),
            Err(e) if attempt >= MAX_RETRIES => return Err(e),
            Err(_) => {}
        }
        sleep(Duration::from_millis(100));
        attempt += 1;
    }
}