use std::{
io::{self, Error, ErrorKind, Read, Write},
net::{SocketAddr, TcpListener, TcpStream},
sync::{
Arc,
atomic::{self, ATOMIC_BOOL_INIT, AtomicBool},
},
thread,
time::Duration,
};
use crate::Uid;
/// Flag shared between a [`Client`] and its local listener thread; `true` asks the thread to
/// leave its accept loop.
type StopFlag = Arc<AtomicBool>;

/// A UID booked with the server, together with the local listener that answers handshakes
/// for it.
pub struct Client {
    // --- identity ---
    /// UID that was booked; the local listener sends it back to handshaking peers.
    uid: Uid,
    /// UID a peer must present to the local listener to receive `uid`.
    handshake_uid: Uid,

    // --- remote server ---
    /// Port of the server on `crate::DEFAULT_IP`.
    server_port: u16,
    /// Timeout used for connecting to, reading from and writing to the server.
    rw_timeout: Duration,

    // --- local listener ---
    /// Port the local listener thread is bound to.
    local_server_port: u16,
    /// Stop flag observed by the local listener thread.
    stopped_flag: StopFlag,
}
impl Client {
    /// Books `uid` with the server listening on `crate::DEFAULT_IP:server_port`.
    ///
    /// A local listener is started first on an ephemeral port. It answers any connection that
    /// presents `handshake_uid` by sending back `uid`. The request then sent to the server is
    /// `CMD_BOOK`, `uid`, `handshake_uid` and the listener port as a big-endian `u16`. The
    /// server answers with a single byte.
    ///
    /// # Returns
    /// * `Ok(Some(client))` if the server's reply byte is non-zero.
    /// * `Ok(None)` if the reply is `0`; the local listener is stopped again in that case.
    ///
    /// # Errors
    /// * `ErrorKind::InvalidData` if `uid` and `handshake_uid` compare equal.
    /// * Any I/O error from starting the listener or from talking to the server. On a server
    ///   I/O error the already-started listener is stopped (best effort) before the error is
    ///   returned, so its thread is not leaked.
    pub fn book(
        uid: Uid,
        handshake_uid: Uid,
        server_port: u16,
        rw_timeout: Duration,
    ) -> io::Result<Option<Self>> {
        if crate::cmp_uids(&uid, &handshake_uid) {
            return Err(Error::new(ErrorKind::InvalidData, "UID and handshake UID must be different"));
        }
        let (local_server_port, stopped_flag) = start_server(uid, handshake_uid, rw_timeout)?;

        // 1 command byte + both UIDs + 2 bytes of port.
        let mut request =
            Vec::with_capacity(uid.len().saturating_add(handshake_uid.len()).saturating_add(3));
        request.push(crate::server::CMD_BOOK);
        request.extend_from_slice(&uid);
        request.extend_from_slice(&handshake_uid);
        request.extend_from_slice(&local_server_port.to_be_bytes());

        let reply = match Self::send_request(server_port, rw_timeout, &request) {
            Ok(reply) => reply,
            Err(err) => {
                // The listener thread is already running. Stop it so it is not left blocked in
                // `accept` forever. The server error is what the caller needs to see, so a
                // failure to stop is deliberately ignored here.
                let _ = stop_server(&stopped_flag, local_server_port);
                return Err(err);
            }
        };

        match reply {
            0 => stop_server(&stopped_flag, local_server_port).map(|()| None),
            _ => Ok(Some(Self {
                server_port,
                uid,
                handshake_uid,
                rw_timeout,
                stopped_flag,
                local_server_port,
            })),
        }
    }

    /// Stops the local listener and sends `CMD_CHECK_OUT`, `uid` and `handshake_uid` to the
    /// server.
    ///
    /// Returns `true` if the server's one-byte reply is non-zero.
    ///
    /// Stopping the listener is best effort: a failure there is ignored and the check-out
    /// request is sent regardless.
    ///
    /// # Errors
    /// Any I/O error from connecting to, writing to, or reading from the server.
    pub fn check_out(&self) -> io::Result<bool> {
        // Deliberately ignored; see the doc comment above.
        let _ = stop_server(&self.stopped_flag, self.local_server_port);

        // 1 command byte + both UIDs.
        let mut request = Vec::with_capacity(
            self.uid.len().saturating_add(self.handshake_uid.len()).saturating_add(1),
        );
        request.push(crate::server::CMD_CHECK_OUT);
        request.extend_from_slice(&self.uid);
        request.extend_from_slice(&self.handshake_uid);

        Self::send_request(self.server_port, self.rw_timeout, &request).map(|reply| reply != 0)
    }

    /// Sends `request` to the server at `crate::DEFAULT_IP:server_port` and returns the single
    /// reply byte.
    ///
    /// `rw_timeout` bounds the connect as well as every read and write on the stream.
    fn send_request(server_port: u16, rw_timeout: Duration, request: &[u8]) -> io::Result<u8> {
        let mut stream = TcpStream::connect_timeout(
            &SocketAddr::from((crate::DEFAULT_IP, server_port)),
            rw_timeout,
        )?;
        stream.set_read_timeout(Some(rw_timeout))?;
        stream.set_write_timeout(Some(rw_timeout))?;
        stream.write_all(request)?;
        let mut reply = [0u8; 1];
        stream.read_exact(&mut reply)?;
        Ok(reply[0])
    }
}
fn start_server(uid: Uid, handshake_uid: Uid, rw_timeout: Duration) -> io::Result<(u16, StopFlag)> {
let server = TcpListener::bind(SocketAddr::from((crate::DEFAULT_IP, 0)))?;
let port = server.local_addr()?.port();
let stopped_flag = Arc::new(ATOMIC_BOOL_INIT);
thread::spawn({
let stopped_flag = stopped_flag.clone();
move || for stream in server.incoming() {
if stopped_flag.load(atomic::Ordering::Relaxed) == true {
break;
}
let mut stream = match stream {
Ok(stream) => stream,
Err(_) => continue,
};
if stream.set_read_timeout(Some(rw_timeout)).is_err() {
continue;
}
if stream.set_write_timeout(Some(rw_timeout)).is_err() {
continue;
}
let stream_handshake_uid = match crate::read_uid(&mut stream) {
Ok(uid) => uid,
Err(_) => continue,
};
if crate::cmp_uids(&handshake_uid, &stream_handshake_uid) == false {
continue;
}
if stream.write_all(&uid).is_err() {
continue;
}
}
});
Ok((port, stopped_flag))
}
/// Asks the listener thread started by `start_server` on `port` to exit.
///
/// Setting the flag alone is not enough, because the thread only checks it after `accept`
/// returns. A throw-away connection is therefore opened to wake it up; it is closed again
/// immediately.
///
/// # Errors
/// Any error from connecting to `crate::DEFAULT_IP:port`. The flag has already been set by
/// then.
fn stop_server(stopped_flag: &StopFlag, port: u16) -> io::Result<()> {
    let wake_addr = SocketAddr::from((crate::DEFAULT_IP, port));
    stopped_flag.store(true, atomic::Ordering::Relaxed);
    let _wake_up = TcpStream::connect(wake_addr)?;
    Ok(())
}