use std::{
io::{self, Error, ErrorKind, Read, Write},
net::{SocketAddr, TcpListener, TcpStream},
sync::{
Arc,
atomic::{self, AtomicBool},
},
thread,
time::Duration,
};
use crate::Id;
/// Shared shutdown flag for the local handshake listener thread.
type StopFlag = Arc<AtomicBool>;

/// A booking on the rendezvous server, together with the local listener that answers
/// handshake requests.
///
/// Created by `Client::book`; `Client::check_out` stops the local listener and sends a
/// check-out request to the server.
pub struct Client {
    // --- remote server settings ---
    /// Port of the server on `crate::DEFAULT_IP`.
    server_port: u16,
    /// Used as connect, read and write timeout for every socket this client opens.
    rw_timeout: Duration,

    // --- identity sent with every request ---
    id: Id,
    handshake_id: Id,

    // --- local listener started during booking ---
    /// Port the local listener is bound to.
    local_server_port: u16,
    /// Set to ask the local listener thread to exit.
    stopped_flag: StopFlag,
}
impl Client {
    /// Books `id` on the rendezvous server and starts a local handshake listener.
    ///
    /// A local TCP listener is started first (see `start_server`). Its port is sent to the
    /// server together with `id` and `handshake_id` in a `CMD_BOOK` request, and the
    /// server's one-byte reply decides the outcome.
    ///
    /// # Parameters
    /// - `id`, `handshake_id`: must not compare equal under `crate::cmp_ids`.
    /// - `server_port`: server port on `crate::DEFAULT_IP`; `None` uses
    ///   `crate::server::DEFAULT_PORT`.
    /// - `rw_timeout`: used as the connect, read and write timeout.
    ///
    /// # Returns
    /// - `Ok(Some(client))` if the server replies with a non-zero byte.
    /// - `Ok(None)` if the server replies with `0`; the local listener is stopped first.
    ///
    /// # Errors
    /// - `ErrorKind::InvalidData` if `id` and `handshake_id` compare equal.
    /// - Any I/O error from starting the listener or from talking to the server. On a
    ///   server-communication error the local listener is stopped (best effort) before the
    ///   error is returned, so its thread is not leaked.
    pub fn book(
        id: Id,
        handshake_id: Id,
        server_port: Option<u16>,
        rw_timeout: Duration,
    ) -> io::Result<Option<Self>> {
        if crate::cmp_ids(&id, &handshake_id) {
            return Err(Error::new(ErrorKind::InvalidData, "ID and handshake ID must be different"));
        }
        let (local_server_port, stopped_flag) = start_server(id, handshake_id, rw_timeout)?;

        // Request layout: [CMD_BOOK][id][handshake_id][local port as 2 big-endian bytes].
        let mut bytes =
            Vec::with_capacity(id.len().saturating_add(handshake_id.len()).saturating_add(3));
        bytes.push(crate::server::CMD_BOOK);
        bytes.extend(id.iter());
        bytes.extend(handshake_id.iter());
        bytes.extend_from_slice(&local_server_port.to_be_bytes());

        let server_port = server_port.unwrap_or(crate::server::DEFAULT_PORT);
        let reply = match Self::send_request(server_port, rw_timeout, &bytes) {
            Ok(reply) => reply,
            Err(err) => {
                // Without this, the listener thread keeps waiting for connections forever.
                // The request error is the one the caller needs, so a failure to stop the
                // listener here is deliberately dropped.
                let _ = stop_server(&stopped_flag, local_server_port);
                return Err(err);
            }
        };

        match reply {
            0 => stop_server(&stopped_flag, local_server_port).map(|()| None),
            _ => Ok(Some(Self {
                server_port,
                id,
                handshake_id,
                rw_timeout,
                stopped_flag,
                local_server_port,
            })),
        }
    }

    /// Stops the local handshake listener and sends a `CMD_CHECK_OUT` request to the server.
    ///
    /// Stopping the listener is best effort: if it fails, the check-out request is still
    /// sent.
    ///
    /// # Returns
    /// `Ok(true)` if the server's one-byte reply is non-zero, `Ok(false)` if it is `0`.
    ///
    /// # Errors
    /// Any I/O error from connecting to the server, writing the request or reading the reply.
    pub fn check_out(&self) -> io::Result<bool> {
        // Best effort: checking out with the server matters more than the local listener.
        let _ = stop_server(&self.stopped_flag, self.local_server_port);

        // Request layout: [CMD_CHECK_OUT][id][handshake_id].
        let mut bytes = Vec::with_capacity(
            self.id.len().saturating_add(self.handshake_id.len()).saturating_add(1),
        );
        bytes.push(crate::server::CMD_CHECK_OUT);
        bytes.extend(self.id.iter());
        bytes.extend(self.handshake_id.iter());

        Self::send_request(self.server_port, self.rw_timeout, &bytes).map(|reply| reply != 0)
    }

    /// Sends `bytes` to the server on `server_port` and returns its one-byte reply.
    ///
    /// Connects to `crate::DEFAULT_IP` using `rw_timeout` as the connect timeout and sets the
    /// same value as the read and write timeout. It then writes the whole request and reads
    /// exactly one byte back.
    fn send_request(server_port: u16, rw_timeout: Duration, bytes: &[u8]) -> io::Result<u8> {
        let addr = SocketAddr::from((crate::DEFAULT_IP, server_port));
        let mut stream = TcpStream::connect_timeout(&addr, rw_timeout)?;
        stream.set_read_timeout(Some(rw_timeout))?;
        stream.set_write_timeout(Some(rw_timeout))?;
        stream.write_all(bytes)?;
        stream.flush()?;
        let mut reply = [0u8; 1];
        stream.read_exact(&mut reply)?;
        Ok(reply[0])
    }
}
fn start_server(id: Id, handshake_id: Id, rw_timeout: Duration) -> io::Result<(u16, StopFlag)> {
let server = TcpListener::bind(SocketAddr::from((crate::DEFAULT_IP, 0)))?;
let port = server.local_addr()?.port();
let stopped_flag = Arc::new(AtomicBool::new(false));
thread::spawn({
let stopped_flag = stopped_flag.clone();
move || for stream in server.incoming() {
if stopped_flag.load(atomic::Ordering::Relaxed) == true {
break;
}
let job = move || {
let mut stream = stream?;
stream.set_read_timeout(Some(rw_timeout))?;
stream.set_write_timeout(Some(rw_timeout))?;
let stream_handshake_id = crate::read_id(&mut stream)?;
match crate::cmp_ids(&handshake_id, &stream_handshake_id) {
true => {
stream.write_all(&id)?;
stream.flush()
},
false => Err(Error::new(ErrorKind::Other, "Handshake ID does not match")),
}
};
if job().is_err() {
}
}
});
Ok((port, stopped_flag))
}
/// Signals the listener started by `start_server` to exit.
///
/// Sets `stopped_flag`, then opens a connection to `port` on `crate::DEFAULT_IP` and drops
/// it immediately. That connection makes the listener's blocking accept return, so the
/// listener can see the flag.
///
/// # Errors
/// Returns the connect error if the wake-up connection cannot be made (for example, when the
/// listener has already exited). The flag is set in either case.
fn stop_server(stopped_flag: &StopFlag, port: u16) -> io::Result<()> {
    stopped_flag.store(true, atomic::Ordering::SeqCst);
    TcpStream::connect(SocketAddr::from((crate::DEFAULT_IP, port)))?;
    Ok(())
}