// quad-net 0.1.2
//
// Miniquad-friendly network abstractions.
// Documentation
use std::net::ToSocketAddrs;
use std::net::{TcpListener, TcpStream};
use std::time::{Duration, Instant};

use std::sync::{Arc, Mutex};

use super::protocol::MessageReader;

/// Callbacks and timer configuration handed to [`listen`].
///
/// `S` is the per-connection state type; `listen` gives every connection its
/// own `S::default()` and passes it to the callbacks below.
pub struct Settings<
    F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
    F1: Fn(&mut SocketHandle, &S) + Send + 'static,
    F2: Fn(&S) + Send + 'static,
    S: Default + Send,
> {
    /// Invoked for each received message, with mutable access to the connection state.
    pub on_message: F,
    /// Invoked every `timer` interval, with read-only access to the connection state.
    pub on_timer: F1,
    /// Invoked when a connection ends, with its final state.
    pub on_disconnect: F2,
    /// Period of `on_timer`; `None` disables the timer entirely.
    pub timer: Option<Duration>,

    /// Binds the state type `S`, which no field stores directly.
    pub _marker: std::marker::PhantomData<S>,
}

/// Transport-specific outgoing half of a connection, borrowed for the duration
/// of a single callback.
enum Sender<'conn> {
    /// Outgoing handle of a WebSocket connection.
    WebSocket(&'conn ws::Sender),
    /// Raw TCP stream; messages are written to it with a length prefix.
    Tcp(&'conn mut TcpStream),
}

/// Handle given to user callbacks for replying on, or closing, the connection
/// that triggered the callback.
pub struct SocketHandle<'conn> {
    /// Set by `SocketHandle::disconnect`; the server inspects it once the
    /// callback returns.
    disconnect: bool,
    /// Where replies sent through this handle are written.
    sender: Sender<'conn>,
}

impl<'a> Sender<'a> {
    /// Sends one message over the underlying connection.
    ///
    /// * WebSocket: `data` is sent as a single WebSocket message.
    /// * TCP: `data` is written as one frame, a single length byte followed by
    ///   the payload (see [`write_tcp_frame`]). Payloads longer than 255 bytes
    ///   cannot be described by a one-byte prefix and are rejected, rather than
    ///   being sent with a wrapped length that would corrupt the stream.
    ///
    /// Returns `None` if the message could not be sent.
    fn send(&mut self, data: &[u8]) -> Option<()> {
        match self {
            Sender::WebSocket(out) => out.send(data).ok(),
            Sender::Tcp(stream) => write_tcp_frame(&mut **stream, data).ok(),
        }
    }
}

/// Writes `data` to `writer` as a single frame: one length byte, then the payload.
///
/// The TCP streams served by this module may be in non-blocking mode, where a
/// single `write` can accept only part of the buffer or fail with `WouldBlock`.
/// If nothing of the frame has been written yet, such a failure is returned
/// and the stream is left untouched. Once part of the frame has been written,
/// the rest must follow, or the peer would mis-parse every later frame. So the
/// remaining bytes are retried until they are accepted or a hard error occurs.
///
/// # Errors
///
/// * `InvalidInput` if `data` is longer than 255 bytes.
/// * `WriteZero` if the writer stops accepting bytes.
/// * Any other I/O error reported by `writer`.
fn write_tcp_frame<W: std::io::Write + ?Sized>(
    writer: &mut W,
    data: &[u8],
) -> std::io::Result<()> {
    use std::io::{Error, ErrorKind};

    if data.len() > usize::from(u8::MAX) {
        return Err(Error::new(
            ErrorKind::InvalidInput,
            "quad-net: message longer than 255 bytes cannot be framed for TCP",
        ));
    }

    // One contiguous buffer means one write syscall in the common case.
    let mut frame = Vec::with_capacity(data.len() + 1);
    frame.push(data.len() as u8); // lossless: length checked against u8::MAX above
    frame.extend_from_slice(data);

    let mut written = 0;
    while written < frame.len() {
        match writer.write(&frame[written..]) {
            Ok(0) => {
                return Err(Error::new(
                    ErrorKind::WriteZero,
                    "quad-net: connection stopped accepting data mid-frame",
                ))
            }
            Ok(n) => written += n,
            Err(err) if err.kind() == ErrorKind::Interrupted => {}
            // Part of the frame is already out: wait for buffer space
            // instead of abandoning a half-written frame.
            Err(err) if err.kind() == ErrorKind::WouldBlock && written > 0 => {
                std::thread::yield_now()
            }
            Err(err) => return Err(err),
        }
    }
    Ok(())
}

impl<'h> SocketHandle<'h> {
    /// Wraps `sender` in a handle whose disconnect flag starts cleared.
    fn new(sender: Sender<'h>) -> SocketHandle<'h> {
        let disconnect = false;
        SocketHandle { disconnect, sender }
    }

    /// Sends `data` to the peer of this connection.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when the underlying transport reports a failure.
    pub fn send(&mut self, data: &[u8]) -> Result<(), ()> {
        match self.sender.send(data) {
            Some(()) => Ok(()),
            None => Err(()),
        }
    }

    /// Serializes `data` with nanoserde's binary format and sends the bytes
    /// via [`SocketHandle::send`].
    #[cfg(feature = "nanoserde")]
    pub fn send_bin<T: nanoserde::SerBin>(&mut self, data: &T) -> Result<(), ()> {
        let encoded = nanoserde::SerBin::serialize_bin(data);
        self.send(&encoded)
    }

    /// Asks the server to close this connection once the current callback returns.
    pub fn disconnect(&mut self) {
        self.disconnect = true;
    }
}

/// Pause taken by a TCP connection thread after a loop iteration that did no
/// work (no complete message read and no timer tick). The sockets are
/// non-blocking, so without this pause every connection thread would spin at
/// full CPU while idle.
const TCP_IDLE_POLL_INTERVAL: Duration = Duration::from_millis(1);

/// Serves the same callbacks over two transports: WebSocket on `ws_addr` and
/// length-prefixed TCP on `tcp_addr`.
///
/// The WebSocket server runs on a background thread. The TCP accept loop runs
/// on the calling thread and spawns one thread per accepted connection, so
/// this function keeps the calling thread busy accepting connections.
///
/// Every connection owns its own `S::default()` state. Each callback is shared
/// by all connections behind its own `Mutex`, so a given callback never runs
/// concurrently with itself. Lock poisoning is deliberately ignored. The
/// mutexes guard only the `Fn` callbacks, not any connection state, so one
/// panicking callback must not take down every other connection.
///
/// # Panics
///
/// Panics if the TCP listener cannot be bound to `tcp_addr`. A failure to build
/// or start the WebSocket server panics the background WebSocket thread only.
pub fn listen<A, A1, F, F1, F2, S>(tcp_addr: A, ws_addr: A1, settings: Settings<F, F1, F2, S>)
where
    A: ToSocketAddrs + std::fmt::Debug + Send,
    A1: ToSocketAddrs + std::fmt::Debug + Send + 'static,
    F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
    F1: Fn(&mut SocketHandle, &S) + Send + 'static,
    F2: Fn(&S) + Send + 'static,
    S: Default + Send + 'static,
{
    let on_message = Arc::new(Mutex::new(settings.on_message));
    let on_timer = Arc::new(Mutex::new(settings.on_timer));
    let on_disconnect = Arc::new(Mutex::new(settings.on_disconnect));
    let timer = settings.timer;

    /// Per-connection WebSocket handler: owns that connection's state and
    /// shares the user callbacks with every other connection.
    struct WsHandler<
        S: Default,
        F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
        F1: Fn(&mut SocketHandle, &S) + Send + 'static,
        F2: Fn(&S) + Send + 'static,
    > {
        out: ws::Sender,
        state: S,
        on_message: Arc<Mutex<F>>,
        on_timer: Arc<Mutex<F1>>,
        on_disconnect: Arc<Mutex<F2>>,
        /// Period of the `on_timer` callback; `None` disables it.
        timeout: Option<Duration>,
    }

    impl<
            S: Default,
            F: Fn(&mut SocketHandle, &mut S, Vec<u8>) + Send + 'static,
            F1: Fn(&mut SocketHandle, &S) + Send + 'static,
            F2: Fn(&S) + Send + 'static,
        > ws::Handler for WsHandler<S, F, F1, F2>
    {
        fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
            let data = msg.into_data();
            let mut handle = SocketHandle::new(Sender::WebSocket(&self.out));
            (lock_callback(&self.on_message))(&mut handle, &mut self.state, data);
            if handle.disconnect {
                self.out.close(ws::CloseCode::Normal)?;
            }
            Ok(())
        }

        fn on_open(&mut self, _: ws::Handshake) -> ws::Result<()> {
            // Arm the first timer tick; `on_timeout` re-arms it after each tick.
            if let Some(timeout) = self.timeout {
                self.out
                    .timeout(timeout.as_millis() as _, ws::util::Token(1))?;
            }
            Ok(())
        }

        fn on_timeout(&mut self, _: ws::util::Token) -> ws::Result<()> {
            if let Some(timeout) = self.timeout {
                let mut handle = SocketHandle::new(Sender::WebSocket(&self.out));
                (lock_callback(&self.on_timer))(&mut handle, &self.state);
                if handle.disconnect {
                    self.out.close(ws::CloseCode::Normal)?;
                } else {
                    self.out
                        .timeout(timeout.as_millis() as _, ws::util::Token(1))?;
                }
            }
            Ok(())
        }

        fn on_close(&mut self, _code: ws::CloseCode, _reason: &str) {
            (lock_callback(&self.on_disconnect))(&self.state);
        }
    }

    std::thread::spawn({
        let on_message = Arc::clone(&on_message);
        let on_timer = Arc::clone(&on_timer);
        let on_disconnect = Arc::clone(&on_disconnect);

        move || {
            // Captured up front so the panic message can name the address after
            // `ws_addr` has been moved into `listen`.
            let ws_addr_text = format!("{:?}", ws_addr);

            ws::Builder::new()
                .with_settings(ws::Settings {
                    timer_tick_millis: 10,
                    tcp_nodelay: true,
                    ..ws::Settings::default()
                })
                .build(move |out| WsHandler {
                    out,
                    state: S::default(),
                    on_message: Arc::clone(&on_message),
                    on_timer: Arc::clone(&on_timer),
                    on_disconnect: Arc::clone(&on_disconnect),
                    timeout: timer,
                })
                .expect("quad-net: failed to build the WebSocket server")
                .listen(ws_addr)
                .unwrap_or_else(|err| {
                    panic!(
                        "quad-net: WebSocket server on {} failed: {:?}",
                        ws_addr_text, err
                    )
                });
        }
    });

    let listener = TcpListener::bind(&tcp_addr).unwrap_or_else(|err| {
        panic!(
            "quad-net: failed to bind TCP listener on {:?}: {}",
            tcp_addr, err
        )
    });

    for stream in listener.incoming() {
        // A failed accept concerns only that one pending connection; keep serving.
        let stream = match stream {
            Ok(stream) => stream,
            Err(_) => continue,
        };

        let on_message = Arc::clone(&on_message);
        let on_timer = Arc::clone(&on_timer);
        let on_disconnect = Arc::clone(&on_disconnect);

        std::thread::spawn(move || {
            serve_tcp_connection::<F, F1, F2, S>(
                stream,
                &*on_message,
                &*on_timer,
                &*on_disconnect,
                timer,
            )
        });
    }
}

/// Locks a shared callback mutex, ignoring poisoning.
///
/// The mutex protects only an immutable `Fn` callback. A panic in an earlier
/// caller therefore cannot have left it in an inconsistent state, and
/// recovering the guard keeps the remaining connections working.
fn lock_callback<T>(callback: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    callback
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Runs one TCP connection until it ends.
///
/// The loop polls for a complete message and fires `on_timer` every `timer`
/// interval. It exits when a callback requests a disconnect or the reader
/// reports an error, and then calls `on_disconnect` once with the final state.
/// If the socket cannot be configured, the connection is dropped before any
/// state exists and no callback is invoked.
fn serve_tcp_connection<F, F1, F2, S>(
    mut stream: TcpStream,
    on_message: &Mutex<F>,
    on_timer: &Mutex<F1>,
    on_disconnect: &Mutex<F2>,
    timer: Option<Duration>,
) where
    F: Fn(&mut SocketHandle, &mut S, Vec<u8>),
    F1: Fn(&mut SocketHandle, &S),
    F2: Fn(&S),
    S: Default,
{
    // Non-blocking mode lets the loop service the timer while no data arrives.
    if stream.set_nodelay(true).is_err() || stream.set_nonblocking(true).is_err() {
        return;
    }

    let mut message_reader = MessageReader::new();
    let mut state = S::default();
    let mut last_tick = Instant::now();

    loop {
        let mut did_work = false;

        match message_reader.next(&mut stream) {
            Ok(Some(message)) => {
                did_work = true;
                let mut handle = SocketHandle::new(Sender::Tcp(&mut stream));
                (lock_callback(on_message))(&mut handle, &mut state, message);
                if handle.disconnect {
                    break;
                }
            }
            // Presumably no complete message is available yet.
            Ok(None) => {}
            // Any read error ends the connection.
            Err(_) => break,
        }

        if let Some(period) = timer {
            if last_tick.elapsed() >= period {
                did_work = true;
                last_tick = Instant::now();
                let mut handle = SocketHandle::new(Sender::Tcp(&mut stream));
                (lock_callback(on_timer))(&mut handle, &state);
                if handle.disconnect {
                    break;
                }
            }
        }

        if !did_work {
            std::thread::sleep(TCP_IDLE_POLL_INTERVAL);
        }
    }

    (lock_callback(on_disconnect))(&state);
}