// namaste 0.3.0
//
// Simple locks between processes
// License: see LICENSE file at root directory of `master` branch

//! # Client

use std::{
    io::{self, Error, ErrorKind, Read, Write},
    net::{SocketAddr, TcpListener, TcpStream},
    sync::{
        Arc,
        atomic::{self, ATOMIC_BOOL_INIT, AtomicBool},
    },
    thread,
    time::Duration,
};

use crate::Uid;

/// # Stop flag
///
/// Shared between a client and its local server thread. Storing `true` asks the thread to stop looping.
type StopFlag = Arc<atomic::AtomicBool>;

/// # Client
///
/// A seat booked on a Namaste server.
///
/// Get one from [`book()`](Client::book). Once you are done with the seat, call [`check_out()`](Client::check_out).
pub struct Client {

    // -- Namaste server connection --

    /// # Port of the Namaste server (reached on `crate::DEFAULT_IP`)
    server_port: u16,

    /// # Timeout used when connecting, reading and writing
    rw_timeout: Duration,

    // -- Identity --

    /// # Constant UID that owns the seat
    uid: Uid,

    /// # Handshake UID, regenerated for each session of the program
    handshake_uid: Uid,

    // -- Local server started by `book()` --

    /// # Port the local server is listening on
    local_server_port: u16,

    /// # Flag asking the local server thread to stop
    stopped_flag: StopFlag,

}

impl Client {

    /// # Books a seat
    ///
    /// Sends a booking request for `uid` to the Namaste server listening on `server_port` (at `crate::DEFAULT_IP`).
    ///
    /// Returns:
    ///
    /// - `Ok(Some(client))` if the server granted the seat;
    /// - `Ok(None)` if the server refused it;
    /// - `Err(_)` on invalid arguments or I/O failure.
    ///
    /// ## Notes
    ///
    /// - An error is returned if UID and handshake UID are the same.
    /// - UID _must_ be a constant. But you should generate new handshake UID for each session of your program. That means handshake UID should
    ///   _only_ be valid within lifetime of the process.
    /// - An error (`ErrorKind::InvalidInput`) is returned if read/write timeout is zero.
    /// - A TCP server will be started to communicate with Namaste server. Server port will be granted by system. That server is stopped
    ///   again if the booking is refused or fails with an error.
    pub fn book(uid: Uid, handshake_uid: Uid, server_port: u16, rw_timeout: Duration) -> io::Result<Option<Self>> {
        if crate::cmp_uids(&uid, &handshake_uid) {
            return Err(Error::new(ErrorKind::InvalidData, "UID and handshake UID must be different"));
        }
        // Validate before starting the local server. Otherwise `connect_timeout()` would reject the zero
        // duration only after the server thread had already been spawned.
        if rw_timeout == Duration::from_secs(0) {
            return Err(Error::new(ErrorKind::InvalidInput, "read/write timeout must not be zero"));
        }

        // Request layout: command byte, UID, handshake UID, local server port (2 bytes, big-endian).
        let mut request = Vec::with_capacity(uid.len().saturating_add(handshake_uid.len()).saturating_add(3));
        request.push(crate::server::CMD_BOOK);
        request.extend_from_slice(&uid);
        request.extend_from_slice(&handshake_uid);

        let (local_server_port, stopped_flag) = start_server(uid, handshake_uid, rw_timeout)?;

        request.extend_from_slice(&local_server_port.to_be_bytes());

        let granted = match Self::send_request(server_port, rw_timeout, &request) {
            Ok(granted) => granted,
            Err(err) => {
                // Don't leak the local server thread. Stopping it is best effort: the original I/O error is what
                // the caller needs to see.
                let _ = stop_server(&stopped_flag, local_server_port);
                return Err(err);
            },
        };

        if granted {
            Ok(Some(Self {
                server_port,
                uid,
                handshake_uid,
                rw_timeout,
                stopped_flag,
                local_server_port,
            }))
        } else {
            stop_server(&stopped_flag, local_server_port).map(|()| None)
        }
    }

    /// # Checks out
    ///
    /// Stops the local server, then asks the Namaste server to release the seat.
    ///
    /// Returns the server's answer: `true` if its one-byte reply was non-zero.
    pub fn check_out(&self) -> io::Result<bool> {
        // Best effort: failing to stop the local server must not prevent checking out.
        let _ = stop_server(&self.stopped_flag, self.local_server_port);

        // Request layout: command byte, UID, handshake UID.
        let mut request = Vec::with_capacity(self.uid.len().saturating_add(self.handshake_uid.len()).saturating_add(1));
        request.push(crate::server::CMD_CHECK_OUT);
        request.extend_from_slice(&self.uid);
        request.extend_from_slice(&self.handshake_uid);

        Self::send_request(self.server_port, self.rw_timeout, &request)
    }

    /// # Sends a request to Namaste server
    ///
    /// Connects to `server_port` on `crate::DEFAULT_IP`, applies `rw_timeout` to the connection, writes `request` and
    /// reads a single-byte reply. Returns `true` if that byte is non-zero.
    fn send_request(server_port: u16, rw_timeout: Duration, request: &[u8]) -> io::Result<bool> {
        let mut stream = TcpStream::connect_timeout(&SocketAddr::from((crate::DEFAULT_IP, server_port)), rw_timeout)?;
        stream.set_read_timeout(Some(rw_timeout))?;
        stream.set_write_timeout(Some(rw_timeout))?;
        stream.write_all(request)?;

        let mut reply = [0_u8];
        stream.read_exact(&mut reply)?;
        Ok(reply[0] != 0)
    }

}

/// # Starts server
fn start_server(uid: Uid, handshake_uid: Uid, rw_timeout: Duration) -> io::Result<(u16, StopFlag)> {
    let server = TcpListener::bind(SocketAddr::from((crate::DEFAULT_IP, 0)))?;
    let port = server.local_addr()?.port();

    let stopped_flag = Arc::new(ATOMIC_BOOL_INIT);
    thread::spawn({
        let stopped_flag = stopped_flag.clone();
        move || for stream in server.incoming() {
            if stopped_flag.load(atomic::Ordering::Relaxed) == true {
                break;
            }
            let mut stream = match stream {
                Ok(stream) => stream,
                Err(_) => continue,
            };
            if stream.set_read_timeout(Some(rw_timeout)).is_err() {
                continue;
            }
            if stream.set_write_timeout(Some(rw_timeout)).is_err() {
                continue;
            }

            let stream_handshake_uid = match crate::read_uid(&mut stream) {
                Ok(uid) => uid,
                Err(_) => continue,
            };
            if crate::cmp_uids(&handshake_uid, &stream_handshake_uid) == false {
                continue;
            }
            if stream.write_all(&uid).is_err() {
                continue;
            }
        }
    });

    Ok((port, stopped_flag))
}

/// # Stops server
///
/// Raises `stopped_flag`, then opens a throwaway connection to `port` on `crate::DEFAULT_IP`. The server thread only
/// looks at the flag when its listener yields an incoming connection, so it needs this nudge to notice the request.
///
/// ## Errors
///
/// Returns the error from connecting to `port`. The flag is set in every case.
fn stop_server(stopped_flag: &StopFlag, port: u16) -> io::Result<()> {
    stopped_flag.store(true, atomic::Ordering::Relaxed);

    let server_addr = SocketAddr::from((crate::DEFAULT_IP, port));
    let wake_up = TcpStream::connect(server_addr)?;
    drop(wake_up);
    Ok(())
}