rsheet_lib 0.2.0

Libraries to help implement cs6991-24T1-ass2
Documentation
//! Contains all the code for talking to and from connections.

use std::{
    collections::HashMap,
    error::Error,
    fmt::{Debug, Display},
    io::{Read, Write},
    net::{IpAddr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
    sync::{
        mpsc::{self, Receiver, Sender},
        Arc, Mutex,
    },
    thread,
};

use once_cell::sync::Lazy;
use regex::Regex;

use crate::{cell_value::CellValue, replies::Reply};

/// Outcome of a single [`Reader::read_message`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReadMessageResult {
    /// One complete message.
    Message(String),
    /// The connection has been closed; no message is available.
    ConnectionClosed,
    /// Reading failed; see the contained [`ConnectionError`].
    Err(ConnectionError),
}

/// Outcome of a single [`Writer::write_message`] call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteMessageResult {
    /// The message was written.
    Ok,
    /// Writing to the connection failed, so it is treated as closed.
    ConnectionClosed,
    /// The message could not be written; see the contained [`ConnectionError`].
    Err(ConnectionError),
}

/// Errors that can occur while talking to a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionError {
    /// Reading from the socket failed with a non-retryable I/O error.
    ConnectionLost,
    /// A line (including its `\n`) did not fit in the reader's fixed-size buffer.
    MessageTooLong,
    /// A received line was not valid UTF-8.
    MessageInvalidUtf8,
    /// An address string could not be resolved to a socket address.
    InvalidAddress,
    /// A reply could not be serialised to JSON.
    CouldNotConvertToJson,
}

/// The receiving half of a connection.
pub trait Reader {
    /// Returns the next message, or the reason no message is available.
    fn read_message(&mut self) -> ReadMessageResult;

    /// An identifier for this connection (e.g. a name or the peer's socket address).
    fn id(&self) -> String;
}

/// The sending half of a connection.
pub trait Writer {
    /// Sends `message` to the other side of the connection.
    fn write_message(&mut self, message: Reply) -> WriteMessageResult;

    /// An identifier for this connection (e.g. a name or the peer's socket address).
    fn id(&self) -> String;
}

/// Names the reader and writer types that belong together.
///
/// Both halves are `Send + 'static`, the bounds `thread::spawn` requires of values moved into a
/// new thread.
pub trait ReaderWriter {
    type Reader: 'static + Send + Reader;
    type Writer: 'static + Send + Writer;
}

/// Selects the stdin/stdout-backed [`TerminalReader`] / [`TerminalWriter`] pair.
pub struct TerminalReaderWriter;

impl ReaderWriter for TerminalReaderWriter {
    type Writer = TerminalWriter;
    type Reader = TerminalReader;
}

/// Selects the TCP-backed [`ConnectionReader`] / [`ConnectionWriter`] pair.
pub struct ConnectionReaderWriter;

impl ReaderWriter for ConnectionReaderWriter {
    type Writer = ConnectionWriter;
    type Reader = ConnectionReader;
}

/// The result of waiting for a new connection.
pub enum Connection<R: Reader, W: Writer> {
    /// A newly accepted connection, split into its receiving and sending halves.
    NewConnection { reader: R, writer: W },
    /// There will be no more new connections.
    NoMoreConnections,
}

/// A source of incoming connections (e.g. stdin name prefixes or a TCP listener).
pub trait Manager {
    /// The reader/writer pair this manager hands out.
    type ReaderWriter: ReaderWriter;

    /// Waits for the next connection, or reports that none will follow.
    fn accept_new_connection(
        &mut self,
    ) -> Connection<
        <Self::ReaderWriter as ReaderWriter>::Reader,
        <Self::ReaderWriter as ReaderWriter>::Writer,
    >;
}

/// Turns stdin into a set of pseudo-connections, one per `name: ` line prefix.
pub struct TerminalManager {
    /// Handle to the stdin-reading thread; held but never joined.
    _join_handle: thread::JoinHandle<()>,
    /// Newly created reader/writer pairs, sent by the stdin thread.
    receiver: Receiver<(TerminalReader, TerminalWriter)>,
}

/// Receiving half of a terminal pseudo-connection.
pub struct TerminalReader {
    /// Command text routed to this connection, one entry per input line.
    receiver: Receiver<String>,
    /// The raw prefix this connection was created for, including its trailing `": "`
    /// (empty for unprefixed lines).
    id: String,
}

/// Sending half of a terminal pseudo-connection; prints replies to stdout.
pub struct TerminalWriter {
    /// Same id as the matching [`TerminalReader`].
    id: String,
    /// When set, error details are replaced by a placeholder in the output.
    mark_mode: bool,
}

impl TerminalManager {
    pub fn launch(mark_mode: bool) -> Self {
        let (terminal_sender, terminal_receiver) = std::sync::mpsc::channel();
        let line_sender = Arc::new(Mutex::new(HashMap::<String, Sender<String>>::new()));

        let join_handle = thread::spawn(move || {
            let stdin = std::io::stdin();

            loop {
                let mut input = String::new();
                if stdin.read_line(&mut input).is_err() {
                    return;
                }

                if input.trim().is_empty() {
                    return;
                }

                static RE: Lazy<Regex> =
                    Lazy::new(|| Regex::new(r"^(?<name>\w+: )?(?<command>.+)").unwrap());
                let captures = match RE.captures(&input) {
                    Some(captures) => captures,
                    None => {
                        eprintln!("Invalid command: {input}");
                        continue;
                    }
                };

                let name = captures.name("name").map(|s| s.as_str()).unwrap_or("");
                let command = captures.name("command").map(|s| s.as_str()).unwrap_or("");

                let mut senders = line_sender.lock().unwrap();

                if !senders.contains_key(name) {
                    let (line_sender, line_receiver) = mpsc::channel();
                    let new_reader = TerminalReader {
                        receiver: line_receiver,
                        id: name.to_string(),
                    };

                    let new_writer = TerminalWriter {
                        id: name.to_string(),
                        mark_mode: mark_mode,
                    };

                    terminal_sender.send((new_reader, new_writer)).unwrap();
                    senders.insert(name.to_string(), line_sender.clone());
                }

                senders
                    .get(name)
                    .expect("Created above.")
                    .send(command.to_string())
                    .unwrap();
            }
        });

        Self {
            _join_handle: join_handle,
            receiver: terminal_receiver,
        }
    }
}

impl Manager for TerminalManager {
    type ReaderWriter = TerminalReaderWriter;

    /// Waits for the stdin thread to create a new reader/writer pair.
    ///
    /// Returns [`Connection::NoMoreConnections`] once the channel is disconnected, i.e. the
    /// stdin thread has dropped its sending end.
    fn accept_new_connection(&mut self) -> Connection<TerminalReader, TerminalWriter> {
        self.receiver
            .recv()
            .map_or(Connection::NoMoreConnections, |(reader, writer)| {
                Connection::NewConnection { reader, writer }
            })
    }
}

impl Reader for TerminalReader {
    /// Returns the next line routed to this connection.
    ///
    /// A line whose first whitespace-separated word is exactly `sleep` is consumed rather than
    /// returned. The calling thread sleeps for the number of milliseconds given by the second
    /// word (1 if it is missing or not a valid `u64`) and then waits for the next line.
    ///
    /// Returns [`ReadMessageResult::ConnectionClosed`] once every sender for this connection's
    /// channel has been dropped (e.g. the stdin thread has stopped).
    fn read_message(&mut self) -> ReadMessageResult {
        loop {
            let Ok(message) = self.receiver.recv() else {
                return ReadMessageResult::ConnectionClosed;
            };

            let mut words = message.split_whitespace();
            // Compare the whole first word: a bare prefix test would also swallow lines such as
            // `sleepy ...` as sleep directives.
            if words.next() != Some("sleep") {
                return ReadMessageResult::Message(message);
            }

            let millis = words
                .next()
                .and_then(|word| word.parse::<u64>().ok())
                .unwrap_or(1);
            thread::sleep(std::time::Duration::from_millis(millis));
        }
    }

    fn id(&self) -> String {
        self.id.clone()
    }
}

impl Writer for TerminalWriter {
    fn write_message(&mut self, message: Reply) -> WriteMessageResult {
        match message {
            Reply::Value(n, CellValue::Error(_)) if self.mark_mode => {
                println!("{n} = Error (hidden by mark mode)");
            }
            Reply::Value(n, v) => {
                println!("{n} = {v}");
            }
            Reply::Error(e) => {
                if self.mark_mode {
                    println!("Error (hidden by mark mode)");
                } else {
                    println!("Error: {e}");
                }
            }
        }
        WriteMessageResult::Ok
    }

    fn id(&self) -> String {
        self.id.clone()
    }
}

/// Resolves `addr` (e.g. `"127.0.0.1:5000"`) to a socket address.
///
/// If `addr` resolves to several addresses, the first one is returned.
///
/// # Errors
///
/// Returns [`ConnectionError::InvalidAddress`] in either of these cases:
/// - `addr` cannot be parsed or resolved;
/// - `addr` resolves to no addresses at all.
pub fn resolve_address(addr: &str) -> Result<SocketAddr, ConnectionError> {
    let mut candidates = addr
        .to_socket_addrs()
        .map_err(|_| ConnectionError::InvalidAddress)?;
    // A unit variant costs nothing to build, so no closure is needed
    // (clippy::unnecessary_lazy_evaluations).
    candidates.next().ok_or(ConnectionError::InvalidAddress)
}

/// Accepts TCP connections on a bound listening socket.
pub struct ConnectionManager {
    listener: TcpListener,
}

impl ConnectionManager {
    /// Binds a TCP listener to `address:port`.
    ///
    /// # Panics
    ///
    /// Panics if the socket cannot be bound (e.g. the port is already in use). The panic message
    /// names the address and includes the underlying I/O error.
    pub fn launch(address: impl Into<IpAddr>, port: u16) -> Self {
        // `SocketAddr`'s Display brackets IPv6 addresses (`[::1]:8080`), which a plain
        // `{ip}:{port}` format does not.
        let socket_addr = SocketAddr::new(address.into(), port);
        let listener = TcpListener::bind(socket_addr)
            .unwrap_or_else(|err| panic!("failed to bind to {socket_addr}: {err}"));

        Self { listener }
    }
}

impl Manager for ConnectionManager {
    type ReaderWriter = ConnectionReaderWriter;

    /// Blocks until a client connects, then returns its reader and writer halves.
    ///
    /// Problems with a single connection do not stop the server:
    /// - An `accept` interrupted by a signal is retried.
    /// - An `accept` failing because the client aborted or reset the connection is retried.
    /// - A connection whose socket cannot be duplicated for the reader is dropped.
    ///
    /// Any other `accept` error is treated as fatal and reported as
    /// [`Connection::NoMoreConnections`].
    fn accept_new_connection(&mut self) -> Connection<ConnectionReader, ConnectionWriter> {
        use std::io::ErrorKind;

        loop {
            let (socket, addr) = match self.listener.accept() {
                Ok(accepted) => accepted,
                // Transient failures: a signal interrupted `accept`, or the client gave up
                // before it was accepted. Neither says anything about future connections.
                Err(err)
                    if matches!(
                        err.kind(),
                        ErrorKind::Interrupted
                            | ErrorKind::ConnectionAborted
                            | ErrorKind::ConnectionReset
                    ) =>
                {
                    continue
                }
                Err(_) => return Connection::NoMoreConnections,
            };

            // The reader and writer each need their own handle to the same socket.
            let Ok(socket_read) = socket.try_clone() else {
                // Dropping `socket` closes just this connection; keep serving others.
                continue;
            };

            return Connection::NewConnection {
                reader: ConnectionReader::from_socket(socket_read, addr),
                writer: ConnectionWriter::from_socket(socket, addr),
            };
        }
    }
}

/// Receiving half of a TCP connection; splits the byte stream into `\n`-terminated lines.
pub struct ConnectionReader {
    socket: TcpStream,
    /// Address of the peer, used as this connection's id.
    socket_addr: SocketAddr,
    /// Bytes received but not yet returned; a line (including its `\n`) must fit in here.
    buffer: Box<[u8; 512]>,
    /// Number of valid bytes at the front of `buffer`.
    buflen: usize,
}

/// Sending half of a TCP connection; writes each reply as one line of JSON.
pub struct ConnectionWriter {
    socket: TcpStream,
    /// Address of the peer, used as this connection's id.
    socket_addr: SocketAddr,
}

/// Displays the variant name, exactly as the `Debug` output does (e.g. `MessageTooLong`).
impl Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(self, f)
    }
}

/// Marker impl so `ConnectionError` can be used wherever a `std::error::Error` is expected.
impl Error for ConnectionError {}

impl ConnectionReader {
    /// Wraps the reading half of an accepted socket, starting with an empty line buffer.
    fn from_socket(socket: TcpStream, socket_addr: SocketAddr) -> Self {
        Self {
            socket,
            socket_addr,
            buffer: Box::new([0; 512]),
            buflen: 0,
        }
    }

    /// Index of the first `\n` among the buffered bytes, if any.
    fn buffer_lf(&self) -> Option<usize> {
        self.buffer[..self.buflen]
            .iter()
            .position(|&byte| byte == b'\n')
    }
}

impl Reader for ConnectionReader {
    /// Returns the next `\n`-terminated line from the socket, without the `\n`.
    ///
    /// A single `read` may deliver only part of a line, because TCP does not preserve message
    /// boundaries. This method therefore keeps reading until a `\n` arrives or the buffer is full.
    /// Bytes after the `\n` stay buffered for the next call.
    ///
    /// Returns:
    /// - [`ReadMessageResult::ConnectionClosed`] when the peer closes the stream (a partial,
    ///   unterminated line is not returned);
    /// - [`ConnectionError::MessageTooLong`] when the buffer fills without a `\n` (the buffered
    ///   bytes are discarded);
    /// - [`ConnectionError::MessageInvalidUtf8`] when the line is not valid UTF-8 (the line is
    ///   still consumed);
    /// - [`ConnectionError::ConnectionLost`] on any I/O error other than `Interrupted`.
    fn read_message(&mut self) -> ReadMessageResult {
        use std::io::ErrorKind;

        let end = loop {
            if let Some(end) = self.buffer_lf() {
                break end;
            }
            if self.buflen == self.buffer.len() {
                // A full buffer with no newline can never hold this line. Clear out their data...
                self.buflen = 0;
                return ReadMessageResult::Err(ConnectionError::MessageTooLong);
            }
            match self.socket.read(&mut self.buffer[self.buflen..]) {
                Ok(0) => return ReadMessageResult::ConnectionClosed,
                Ok(n_bytes) => self.buflen += n_bytes,
                // Retry `read` if interrupted...
                Err(err) if err.kind() == ErrorKind::Interrupted => {}
                Err(_) => return ReadMessageResult::Err(ConnectionError::ConnectionLost),
            }
        };

        let bytes = self.buffer[..end].to_vec();

        // Consume the line plus its '\n', shifting any following bytes to the front.
        let after_lf = end + 1;
        self.buffer.copy_within(after_lf..self.buflen, 0);
        self.buflen -= after_lf;

        match String::from_utf8(bytes) {
            Ok(message) => ReadMessageResult::Message(message),
            Err(_) => ReadMessageResult::Err(ConnectionError::MessageInvalidUtf8),
        }
    }

    fn id(&self) -> String {
        self.socket_addr.to_string()
    }
}

impl ConnectionWriter {
    /// Wraps the writing half of an accepted socket.
    fn from_socket(socket: TcpStream, socket_addr: SocketAddr) -> Self {
        Self { socket_addr, socket }
    }
}

impl Writer for ConnectionWriter {
    /// Serialises `message` to JSON and sends it as a single `\n`-terminated line.
    ///
    /// Returns [`WriteMessageResult::Err`] with [`ConnectionError::CouldNotConvertToJson`] if
    /// serialisation fails. Returns [`WriteMessageResult::ConnectionClosed`] if the write fails.
    fn write_message(&mut self, message: Reply) -> WriteMessageResult {
        let mut line = match serde_json::to_string(&message) {
            Ok(json) => json,
            Err(_) => return WriteMessageResult::Err(ConnectionError::CouldNotConvertToJson),
        };
        line.push('\n');

        match self.socket.write_all(line.as_bytes()) {
            Err(_) => WriteMessageResult::ConnectionClosed,
            Ok(()) => {
                // Best effort only: `flush` on a `TcpStream` has nothing buffered to push out.
                let _ = self.socket.flush();
                WriteMessageResult::Ok
            }
        }
    }

    fn id(&self) -> String {
        self.socket_addr.to_string()
    }
}