use crate::client_id::ClientId;
use crate::{data::MessageDataInternal, prelude::*};
mod transport;
use pipenet::{NonBlockStream, Packs};
use std::{
collections::{HashMap, HashSet},
io::{Error, ErrorKind},
net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream},
sync::{
Arc, Mutex,
mpsc::{Receiver, Sender, channel},
},
thread::JoinHandle,
time::{Duration, Instant},
};
/// Module-local result alias: any error type is boxed and type-erased so that `?` can propagate
/// `std::io::Error`, string errors and codec errors through the same signature.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
/// A networked session that acts either as a server (accepting clients) or as a client
/// (connected to one server), with background threads doing the actual I/O.
///
/// `MS` is forwarded to `pipenet::NonBlockStream<MS>`; its meaning is defined by `pipenet`
/// (NOTE(review): presumably a message-size parameter — confirm against pipenet docs).
pub struct Session<const MS: usize = 0> {
/// Current role; may flip when a reconnect request is applied.
kind: SessionKind,
/// Connection settings; `address` and `port` are overwritten when a reconnect is applied.
config: Config,
/// Handle of the server's accept thread (server role only).
accept_routine: Option<JoinHandle<()>>,
/// Handle of the thread running the server or client transport loop.
io_routine: Option<JoinHandle<()>>,
/// This session's own id, generated at construction.
uuid: ClientId,
/// Server id received during the client handshake (client role only).
server_uuid: Option<ClientId>,
/// Peer id → stream map shared with the transport threads. In the client role it holds a
/// single entry keyed by this session's own id.
clients: Arc<Mutex<HashMap<ClientId, NonBlockStream<MS>>>>,
/// Outgoing commands for the transport loop.
tx_writer: Option<Sender<MessageDataInternal>>,
/// Incoming messages produced by the transport loop.
rx_reader: Option<Arc<Mutex<Receiver<MessageData>>>>,
/// Reconnect request slot shared with the transport loops and consumed by `check_reconnect_to`.
reconnect_to: Arc<Mutex<Option<ReconnectTo>>>,
}
impl<const MS: usize> Session<MS> {
#[must_use]
pub fn new_server(config: Config) -> Self {
Self {
kind: SessionKind::Server,
config,
accept_routine: None,
io_routine: None,
uuid: ClientId::new(),
server_uuid: None,
clients: Arc::default(),
tx_writer: None,
rx_reader: None,
reconnect_to: Mutex::new(None).into(),
}
}
#[must_use]
pub fn new_client(config: Config) -> Self {
Self {
kind: SessionKind::Client,
config,
accept_routine: None,
io_routine: None,
uuid: ClientId::new(),
server_uuid: None,
clients: Arc::default(),
tx_writer: None,
rx_reader: None,
reconnect_to: Mutex::new(None).into(),
}
}
pub fn start(&mut self) -> Result<()> {
match self.kind {
SessionKind::Server => self.start_server(),
SessionKind::Client => self.start_client(),
}
}
pub fn stop(&mut self) {
self.accept_routine = None;
self.io_routine = None;
self.rx_reader = None;
self.tx_writer = None;
}
#[must_use]
pub fn is_server(&self) -> bool {
self.kind == SessionKind::Server
}
#[must_use]
pub fn is_connected(&self) -> bool {
match self.kind {
SessionKind::Server => {
if let Some(h) = &self.accept_routine
&& !h.is_finished()
&& let Some(h) = &self.io_routine
&& !h.is_finished()
{
return true;
}
false
}
SessionKind::Client => {
if let Some(h) = &self.io_routine
&& !h.is_finished()
{
return true;
}
false
}
}
}
#[must_use]
pub fn server_uuid(&self) -> Option<ClientId> {
if self.is_server() {
Some(self.uuid)
} else {
self.server_uuid
}
}
#[must_use]
pub fn uuid(&self) -> ClientId {
self.uuid
}
#[must_use]
pub fn clients(&self) -> HashSet<ClientId> {
let Ok(lock) = self.clients.lock() else {
return HashSet::new();
};
lock.keys().copied().collect::<HashSet<ClientId>>()
}
pub fn read(&mut self) -> Result<Option<MessageData>> {
if !self.is_connected() {
return Err(Error::new(ErrorKind::NotConnected, "not connected").into());
}
let Some(c) = self.rx_reader.as_mut() else {
return Ok(None);
};
let Ok(c) = c.lock() else {
return Ok(None);
};
let msg = c.try_recv().ok();
drop(c);
if msg.is_none() {
self.check_reconnect_to()?;
}
Ok(msg)
}
pub fn send_to(&mut self, uuid: ClientId, m: Vec<u8>) -> Result<()> {
if !self.is_connected() {
return Err(Error::new(ErrorKind::NotConnected, "not connected").into());
}
self.check_reconnect_to()?;
let Some(c) = self.tx_writer.as_mut() else {
return Ok(());
};
let _ = c.send(MessageDataInternal::Send(self.uuid, uuid, m));
Ok(())
}
pub fn broadcast(&mut self, m: Vec<u8>) -> Result<()> {
if !self.is_connected() {
return Err(Error::new(ErrorKind::NotConnected, "not connected").into());
}
self.check_reconnect_to()?;
let Some(c) = self.tx_writer.as_mut() else {
return Ok(());
};
let _ = c.send(MessageDataInternal::Broadcast(self.uuid, m));
Ok(())
}
pub fn promote_to_host(&mut self, uuid: ClientId, port: Option<u16>) {
if !self.is_server() {
return;
}
let Some(c) = self.tx_writer.as_mut() else {
return;
};
let Ok(map) = self.clients.lock() else {
return;
};
let Some(client) = map.get(&uuid) else {
return;
};
let addr = client.remote_addr().ip();
let port = port.unwrap_or(self.config.port);
let msg = MessageDataInternal::PromoteToHost(uuid, addr, port);
let _ = c.send(msg);
}
#[must_use]
pub fn total_read(&self) -> usize {
let mut t = 0;
let Ok(map) = self.clients.lock() else {
return 0;
};
for (_, c) in map.iter() {
t += c.total_read();
}
t
}
#[must_use]
pub fn total_sent(&self) -> usize {
let mut t = 0;
let Ok(map) = self.clients.lock() else {
return 0;
};
for (_, c) in map.iter() {
t += c.total_sent();
}
t
}
fn start_server(&mut self) -> Result<()> {
if self.is_connected() {
return Ok(());
}
let addr = self
.config
.address
.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED));
let addr = SocketAddr::from((addr, self.config.port));
let server_uuid = self.uuid;
let listener = TcpListener::bind(addr)?;
let accept_timeout = self.config.accept_timeout;
let client_list = self.clients.clone();
let config = self.config.clone();
self.accept_routine = Some(std::thread::spawn(move || {
transport::accept_loop(server_uuid, &config, &client_list, &listener, accept_timeout);
}));
let server_uuid = self.uuid;
let client_list = self.clients.clone();
let reconnect_to = self.reconnect_to.clone();
let (tx_reader, rx_reader) = channel();
let (tx_writer, rx_writer) = channel();
self.tx_writer = Some(tx_writer);
self.rx_reader = Some(Mutex::new(rx_reader).into());
self.io_routine = Some(std::thread::spawn(move || {
transport::server_loop::<MS>(
server_uuid,
&client_list,
&reconnect_to,
&rx_writer,
&tx_reader,
);
}));
Ok(())
}
fn start_client(&mut self) -> Result<()> {
if self.is_connected() {
return Ok(());
}
let Some(addr) = self.config.address else {
return Ok(());
};
let addr = SocketAddr::from((addr, self.config.port));
let socket = connect_with_retry_and_wait(addr)?;
socket.set_nonblocking(true)?;
let mut socket = to_pipenet::<MS>(socket, &self.config);
socket.write((MessageDataInternal::ClientJoined(self.uuid)).try_into()?)?;
let Some(server_uuid) =
wait_for_server_uuid_message(self.config.accept_timeout, &mut socket)?
else {
return Err("Could not connect to server: did not receive server uuid.".into());
};
self.server_uuid = Some(server_uuid);
if let Ok(mut map) = self.clients.lock() {
map.clear();
map.insert(self.uuid, socket.clone());
}
let reconnect_to = self.reconnect_to.clone();
let (tx_reader, rx_reader) = channel();
let (tx_writer, rx_writer) = channel();
self.tx_writer = Some(tx_writer);
self.rx_reader = Some(Mutex::new(rx_reader).into());
self.io_routine = Some(std::thread::spawn(move || {
transport::client_loop::<MS>(socket, &reconnect_to, &rx_writer, &tx_reader);
}));
Ok(())
}
fn check_reconnect_to(&mut self) -> Result<()> {
let Ok(mut reconnect_to) = self.reconnect_to.lock() else {
return Ok(());
};
let Some(ref to) = *reconnect_to else {
return Ok(());
};
let server = to.become_server;
let address = to.address;
let port = to.port;
*reconnect_to = None;
drop(reconnect_to);
self.stop();
self.config.address = Some(address);
self.config.port = port;
self.kind = if server {
SessionKind::Server
} else {
SessionKind::Client
};
self.start()
}
}
impl<const MS: usize> Drop for Session<MS> {
    /// Tells the I/O thread (if one was started) that this session is leaving, then tears the
    /// session down via `stop`.
    fn drop(&mut self) {
        let id = self.uuid;
        if let Some(writer) = &self.tx_writer {
            // Ignored on purpose: if the I/O thread is already gone there is nobody to notify.
            let _ = writer.send(MessageDataInternal::ClientLeft(id));
        }
        self.stop();
    }
}
/// Role a [`Session`] plays: accepting clients (`Server`) or attached to one server (`Client`).
///
/// Fieldless, so it is cheap to copy, compare and print.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum SessionKind {
    #[default]
    Server,
    Client,
}
/// Opens a TCP connection to `addr`, retrying while the peer is not reachable yet.
///
/// Makes up to 12 connection attempts, sleeping 100 ms after each of the first 11 failures.
///
/// # Errors
/// The error from the final attempt if every attempt fails.
fn connect_with_retry_and_wait(addr: SocketAddr) -> Result<TcpStream> {
    const RETRIES_BEFORE_LAST_ATTEMPT: u32 = 11;
    const RETRY_DELAY: Duration = Duration::from_millis(100);

    for _ in 0..RETRIES_BEFORE_LAST_ATTEMPT {
        if let Ok(stream) = TcpStream::connect(addr) {
            return Ok(stream);
        }
        std::thread::sleep(RETRY_DELAY);
    }
    // Final attempt: its error is the one reported to the caller.
    TcpStream::connect(addr).map_err(Into::into)
}
/// Request to tear a session down and restart it against a new endpoint.
///
/// Stored in the session's shared `reconnect_to` slot, which the transport loops receive
/// (presumably they are the writers — confirm in `transport`). It is consumed by
/// `Session::check_reconnect_to`. All fields are `Copy`, so the request is cheap to copy
/// and log.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ReconnectTo {
    /// `true` to restart as a server, `false` to restart as a client.
    pub(crate) become_server: bool,
    /// Replaces the session's configured address.
    pub(crate) address: IpAddr,
    /// Replaces the session's configured port.
    pub(crate) port: u16,
}
/// Wraps a connected `TcpStream` into a `pipenet::NonBlockStream`.
///
/// The packing pipeline is built from `config`. With the `compression` feature and
/// `config.compress` set, a compression layer is added. With the `encryption` feature and a
/// key present, an encryption layer using that key is added. `config.versions` is always
/// passed through.
pub(crate) fn to_pipenet<const MS: usize>(
    stream: TcpStream,
    config: &Config,
) -> NonBlockStream<MS> {
    // Each optional layer shadows the previous value, so no `mut` (and no lint allowance)
    // is needed when both features are disabled.
    let packs = Packs::default();
    #[cfg(feature = "compression")]
    let packs = if config.compress {
        packs.compress()
    } else {
        packs
    };
    #[cfg(feature = "encryption")]
    let packs = match config.key.as_ref() {
        Some(key) => packs.encrypt(key),
        None => packs,
    };
    NonBlockStream::<MS>::from_version_packs(config.versions, packs, stream)
}
/// Polls `client` until the server announces its id or `timeout` elapses.
///
/// Each packet read from the (non-blocking) stream is decoded as a `MessageDataInternal`. The
/// id carried by the first `ServerUuid` message is returned; any other message is skipped.
///
/// # Returns
/// `Ok(Some(uuid))` once the server id arrives, `Ok(None)` if `timeout` elapses first.
///
/// # Errors
/// Propagates read errors from the stream and decode errors from
/// `MessageDataInternal::try_from`.
fn wait_for_server_uuid_message<const MS: usize>(
    timeout: Duration,
    client: &mut NonBlockStream<MS>,
) -> Result<Option<ClientId>> {
    // Short pause between empty polls so waiting does not spin a CPU core.
    const POLL_INTERVAL: Duration = Duration::from_millis(1);

    let started = Instant::now();
    loop {
        match client.read()? {
            Some(packet) => {
                let msg = MessageDataInternal::try_from(packet.as_slice())?;
                if let MessageDataInternal::ServerUuid(uuid) = msg {
                    return Ok(Some(uuid));
                }
            }
            // Nothing available yet. Unlike the previous `continue`, fall through to the
            // deadline check below; otherwise a silent server would block the caller forever.
            None => std::thread::sleep(POLL_INTERVAL),
        }
        if started.elapsed() > timeout {
            return Ok(None);
        }
    }
}