use std::{
collections::HashMap,
error::Error,
fmt::{Debug, Display},
io::{Read, Write},
net::{IpAddr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
sync::{
mpsc::{self, Receiver, Sender},
Arc, Mutex,
},
thread,
};
use once_cell::sync::Lazy;
use regex::Regex;
use crate::{cell_value::CellValue, replies::Reply};
/// Outcome of a single [`Reader::read_message`] call.
pub enum ReadMessageResult {
/// A message was read successfully; the payload is its text.
Message(String),
/// The source has been closed and will deliver no further messages.
ConnectionClosed,
/// Reading failed; the payload describes why.
Err(ConnectionError),
}
/// Outcome of a single [`Writer::write_message`] call.
pub enum WriteMessageResult {
/// The reply was written.
Ok,
/// The reply could not be delivered because the peer is gone.
ConnectionClosed,
/// Writing failed before anything was sent; the payload describes why.
Err(ConnectionError),
}
/// Errors reported by the readers, writers and address helpers in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionError {
/// Reading from the socket failed with an I/O error.
ConnectionLost,
/// No line terminator was found within the reader's buffer.
MessageTooLong,
/// The received line was not valid UTF-8.
MessageInvalidUtf8,
/// The address could not be resolved to a socket address.
InvalidAddress,
/// A reply could not be serialised to JSON.
CouldNotConvertToJson,
}
/// The receiving half of a connection.
pub trait Reader {
/// Reads the next message from this connection.
fn read_message(&mut self) -> ReadMessageResult;
/// Returns a textual identifier for this connection.
fn id(&self) -> String;
}
/// The sending half of a connection.
pub trait Writer {
/// Sends `message` to the peer of this connection.
fn write_message(&mut self, message: Reply) -> WriteMessageResult;
/// Returns a textual identifier for this connection.
fn id(&self) -> String;
}
/// Ties together the reader and writer types that belong to one kind of
/// connection.
pub trait ReaderWriter {
/// Reader half. NOTE(review): the `Send + 'static` bound presumably exists so
/// the reader can be moved into a spawned thread — confirm against callers.
type Reader: Reader + Send + 'static;
/// Writer half; same bounds as the reader.
type Writer: Writer + Send + 'static;
}
/// Type-level marker selecting the stdin/stdout-backed reader and writer.
pub struct TerminalReaderWriter;
impl ReaderWriter for TerminalReaderWriter {
type Reader = TerminalReader;
type Writer = TerminalWriter;
}
/// Type-level marker selecting the TCP-backed reader and writer.
pub struct ConnectionReaderWriter;
impl ReaderWriter for ConnectionReaderWriter {
type Reader = ConnectionReader;
type Writer = ConnectionWriter;
}
/// Result of asking a [`Manager`] for its next connection.
pub enum Connection<R: Reader, W: Writer> {
/// A new connection, split into its two halves.
NewConnection {
/// Half used to receive messages.
reader: R,
/// Half used to send replies.
writer: W,
},
/// The manager will not hand out any further connections.
NoMoreConnections,
}
/// A source of incoming connections.
pub trait Manager {
/// Selects the concrete reader/writer types this manager produces.
type ReaderWriter: ReaderWriter;
/// Returns the next connection, or [`Connection::NoMoreConnections`] once
/// the manager cannot provide any more.
fn accept_new_connection(
&mut self,
) -> Connection<
<Self::ReaderWriter as ReaderWriter>::Reader,
<Self::ReaderWriter as ReaderWriter>::Writer,
>;
}
/// [`Manager`] whose connections are simulated from lines typed on stdin.
pub struct TerminalManager {
// Handle of the thread that reads stdin; stored but never joined here.
_join_handle: thread::JoinHandle<()>,
// Delivers a reader/writer pair each time the stdin thread sees a new name.
receiver: Receiver<(TerminalReader, TerminalWriter)>,
}
/// Reader half of a simulated terminal connection.
pub struct TerminalReader {
// Commands forwarded by the stdin thread for this connection.
receiver: Receiver<String>,
// Name prefix under which this connection was opened.
id: String,
}
/// Writer half of a simulated terminal connection; prints replies to stdout.
pub struct TerminalWriter {
// Name prefix under which this connection was opened.
id: String,
// When true, error details are replaced by a fixed placeholder in the output.
mark_mode: bool,
}
impl TerminalManager {
    /// Starts a background thread that turns stdin lines into simulated
    /// connections.
    ///
    /// Each line has the form `[name: ]command`. The first time a given name
    /// prefix is seen (the empty prefix counts as a name), a fresh
    /// [`TerminalReader`]/[`TerminalWriter`] pair is handed to
    /// [`Manager::accept_new_connection`]; the command is then forwarded to
    /// the reader registered for that name.
    ///
    /// The thread stops on a read error, on end of input, or on a blank line.
    /// When it stops, all of its channels are dropped, so pending readers
    /// observe `ConnectionClosed` and the manager reports `NoMoreConnections`.
    ///
    /// `mark_mode` is copied into every writer that is created.
    pub fn launch(mark_mode: bool) -> Self {
        let (terminal_sender, terminal_receiver) = std::sync::mpsc::channel();
        // Only the spawned thread ever touches this map; the Arc is never cloned.
        let line_sender = Arc::new(Mutex::new(HashMap::<String, Sender<String>>::new()));
        let join_handle = thread::spawn(move || {
            static RE: Lazy<Regex> =
                Lazy::new(|| Regex::new(r"^(?<name>\w+: )?(?<command>.+)").unwrap());
            let stdin = std::io::stdin();
            loop {
                let mut input = String::new();
                if stdin.read_line(&mut input).is_err() {
                    return;
                }
                // At end of input `read_line` leaves `input` empty, so this
                // check covers EOF as well as an explicit blank line.
                if input.trim().is_empty() {
                    return;
                }
                let captures = match RE.captures(&input) {
                    Some(captures) => captures,
                    None => {
                        eprintln!("Invalid command: {input}");
                        continue;
                    }
                };
                let name = captures.name("name").map(|s| s.as_str()).unwrap_or("");
                let command = captures.name("command").map(|s| s.as_str()).unwrap_or("");
                let mut senders = line_sender.lock().unwrap();
                if !senders.contains_key(name) {
                    let (line_tx, line_rx) = mpsc::channel();
                    let new_reader = TerminalReader {
                        receiver: line_rx,
                        id: name.to_string(),
                    };
                    let new_writer = TerminalWriter {
                        id: name.to_string(),
                        mark_mode,
                    };
                    // The manager has been dropped: nobody can accept this or
                    // any later connection, so stop instead of panicking.
                    if terminal_sender.send((new_reader, new_writer)).is_err() {
                        return;
                    }
                    senders.insert(name.to_string(), line_tx);
                }
                let delivered = match senders.get(name) {
                    Some(sender) => sender.send(command.to_string()).is_ok(),
                    None => false,
                };
                if !delivered {
                    // The reader for this name was dropped. Forget the stale
                    // sender so a later line with the same name opens a fresh
                    // connection instead of killing the stdin thread.
                    eprintln!("Connection {name:?} is closed; discarding command: {command}");
                    senders.remove(name);
                }
            }
        });
        Self {
            _join_handle: join_handle,
            receiver: terminal_receiver,
        }
    }
}
impl Manager for TerminalManager {
    type ReaderWriter = TerminalReaderWriter;

    /// Waits for the stdin thread to announce a new name and returns the
    /// reader/writer pair created for it. Once that thread has exited and its
    /// channel is drained, returns [`Connection::NoMoreConnections`].
    fn accept_new_connection(&mut self) -> Connection<TerminalReader, TerminalWriter> {
        if let Ok((reader, writer)) = self.receiver.recv() {
            Connection::NewConnection { reader, writer }
        } else {
            Connection::NoMoreConnections
        }
    }
}
impl Reader for TerminalReader {
    /// Returns the next forwarded command.
    ///
    /// Commands starting with `sleep` are consumed here rather than returned:
    /// the second whitespace-separated token is parsed as a number of
    /// milliseconds (falling back to 1 when missing or unparsable) and the
    /// calling thread sleeps for that long before waiting for the next
    /// command. Returns `ConnectionClosed` once the sending side is gone.
    fn read_message(&mut self) -> ReadMessageResult {
        loop {
            let Ok(line) = self.receiver.recv() else {
                return ReadMessageResult::ConnectionClosed;
            };
            if !line.starts_with("sleep") {
                return ReadMessageResult::Message(line);
            }
            let millis = line
                .split_whitespace()
                .nth(1)
                .and_then(|token| token.parse::<u64>().ok())
                .unwrap_or(1);
            thread::sleep(std::time::Duration::from_millis(millis));
        }
    }

    /// Returns the name this connection was opened under.
    fn id(&self) -> String {
        self.id.clone()
    }
}
impl Writer for TerminalWriter {
fn write_message(&mut self, message: Reply) -> WriteMessageResult {
match message {
Reply::Value(n, CellValue::Error(_)) if self.mark_mode => {
println!("{n} = Error (hidden by mark mode)");
}
Reply::Value(n, v) => {
println!("{n} = {v}");
}
Reply::Error(e) => {
if self.mark_mode {
println!("Error (hidden by mark mode)");
} else {
println!("Error: {e}");
}
}
}
WriteMessageResult::Ok
}
fn id(&self) -> String {
self.id.clone()
}
}
/// Resolves `addr` (anything accepted by [`ToSocketAddrs`] for `&str`, such
/// as `"host:port"`) and returns the first resulting socket address.
///
/// # Errors
///
/// Returns [`ConnectionError::InvalidAddress`] when resolution fails or
/// yields no addresses.
pub fn resolve_address(addr: &str) -> Result<SocketAddr, ConnectionError> {
    match addr.to_socket_addrs() {
        Ok(mut candidates) => candidates.next().ok_or(ConnectionError::InvalidAddress),
        Err(_) => Err(ConnectionError::InvalidAddress),
    }
}
/// [`Manager`] that hands out connections accepted from a TCP listener.
pub struct ConnectionManager {
    // Listening socket new clients are accepted from.
    listener: TcpListener,
}

impl ConnectionManager {
    /// Binds a TCP listener to `address:port` and wraps it in a manager.
    ///
    /// # Panics
    ///
    /// Panics if the listener cannot be bound. The message names the socket
    /// address (IPv6 addresses are bracketed) and includes the underlying
    /// I/O error so the cause is visible to the operator.
    pub fn launch(address: impl Into<IpAddr>, port: u16) -> Self {
        let socket_addr = SocketAddr::new(address.into(), port);
        match TcpListener::bind(socket_addr) {
            Ok(listener) => Self { listener },
            Err(err) => panic!("failed to bind to {socket_addr}: {err}"),
        }
    }
}
impl Manager for ConnectionManager {
    type ReaderWriter = ConnectionReaderWriter;

    /// Accepts the next TCP client and splits its socket into a reader and a
    /// writer that share the same underlying connection.
    ///
    /// A failure that concerns only one client does not stop the server: if
    /// the accepted socket cannot be duplicated, or `accept` reports a
    /// per-connection error (`ConnectionAborted`, `ConnectionReset`,
    /// `Interrupted`), that client is skipped and the next one is awaited.
    /// Any other `accept` error yields [`Connection::NoMoreConnections`].
    fn accept_new_connection(&mut self) -> Connection<ConnectionReader, ConnectionWriter> {
        use std::io::ErrorKind;

        loop {
            let (socket, addr) = match self.listener.accept() {
                Ok(accepted) => accepted,
                Err(err) => match err.kind() {
                    // The pending client went away before we picked it up;
                    // the listener itself is still healthy.
                    ErrorKind::ConnectionAborted
                    | ErrorKind::ConnectionReset
                    | ErrorKind::Interrupted => continue,
                    _ => return Connection::NoMoreConnections,
                },
            };
            // The reader and writer each need their own handle to the socket.
            let socket_read = match socket.try_clone() {
                Ok(clone) => clone,
                // Dropping `socket` closes this client; keep serving others.
                Err(_) => continue,
            };
            return Connection::NewConnection {
                reader: ConnectionReader::from_socket(socket_read, addr),
                writer: ConnectionWriter::from_socket(socket, addr),
            };
        }
    }
}
/// Reader half of a TCP connection; splits the byte stream into LF-terminated
/// messages.
pub struct ConnectionReader {
// Handle used for reading from the peer.
socket: TcpStream,
// Peer address; its string form is this reader's id.
socket_addr: SocketAddr,
// Fixed-size receive buffer holding bytes not yet returned as a message.
buffer: Box<[u8; 512]>,
// Number of valid bytes at the start of `buffer`.
buflen: usize,
}
/// Writer half of a TCP connection; sends each reply as one line of JSON.
pub struct ConnectionWriter {
// Handle used for writing to the peer.
socket: TcpStream,
// Peer address; its string form is this writer's id.
socket_addr: SocketAddr,
}
impl Display for ConnectionError {
    /// Formats the error exactly like its `Debug` representation, i.e. as the
    /// bare variant name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(self, f)
    }
}
// Empty impl: all `Error` methods keep their default implementations.
impl Error for ConnectionError {}
impl ConnectionReader {
fn from_socket(socket: TcpStream, socket_addr: SocketAddr) -> Self {
Self {
socket,
socket_addr,
buffer: Box::from([0; 512]),
buflen: 0,
}
}
fn buffer_lf(&self) -> Option<usize> {
self.buffer[..self.buflen]
.iter()
.enumerate()
.find(|(_, byte)| **byte == b'\n')
.map(|(index, _)| index)
}
}
impl Reader for ConnectionReader {
    /// Returns the next LF-terminated line received from the peer, without
    /// the LF itself.
    ///
    /// Bytes are read from the socket until the buffer contains an LF. A
    /// single `read` may deliver only part of a line (TCP does not preserve
    /// message boundaries), so reading continues across as many `read` calls
    /// as needed. Bytes following the LF stay buffered for the next call.
    ///
    /// Results:
    /// * `Message` — a complete line that is valid UTF-8.
    /// * `ConnectionClosed` — the peer closed the connection before a
    ///   complete line arrived.
    /// * `Err(ConnectionLost)` — the socket reported an I/O error other than
    ///   `Interrupted`.
    /// * `Err(MessageTooLong)` — the buffer filled up without containing an
    ///   LF; the buffered bytes are discarded.
    /// * `Err(MessageInvalidUtf8)` — the line was not valid UTF-8; it is
    ///   still consumed from the buffer.
    fn read_message(&mut self) -> ReadMessageResult {
        use std::io::ErrorKind;

        let end = loop {
            if let Some(end) = self.buffer_lf() {
                break end;
            }
            // Only give up once the buffer is genuinely full; a partial line
            // just means the rest has not arrived yet.
            if self.buflen == self.buffer.len() {
                self.buflen = 0;
                return ReadMessageResult::Err(ConnectionError::MessageTooLong);
            }
            match self.socket.read(&mut self.buffer[self.buflen..]) {
                Ok(0) => return ReadMessageResult::ConnectionClosed,
                Ok(n_bytes) => self.buflen += n_bytes,
                Err(err) if err.kind() == ErrorKind::Interrupted => continue,
                Err(_) => return ReadMessageResult::Err(ConnectionError::ConnectionLost),
            }
        };

        let bytes = self.buffer[..end].to_vec();
        // Shift whatever follows the LF to the front before validating, so a
        // malformed line is consumed rather than returned again.
        let after_lf = end + 1;
        self.buffer.copy_within(after_lf..self.buflen, 0);
        self.buflen -= after_lf;

        match String::from_utf8(bytes) {
            Ok(message) => ReadMessageResult::Message(message),
            Err(_) => ReadMessageResult::Err(ConnectionError::MessageInvalidUtf8),
        }
    }

    /// Returns the peer's socket address as a string.
    fn id(&self) -> String {
        self.socket_addr.to_string()
    }
}
impl ConnectionWriter {
fn from_socket(socket: TcpStream, socket_addr: SocketAddr) -> Self {
Self {
socket,
socket_addr,
}
}
}
impl Writer for ConnectionWriter {
    /// Serialises `message` to JSON and sends it to the peer as a single
    /// LF-terminated line.
    ///
    /// Returns `Err(CouldNotConvertToJson)` if serialisation fails and
    /// `ConnectionClosed` if the bytes cannot be written. A failure to flush
    /// after a successful write is ignored.
    fn write_message(&mut self, message: Reply) -> WriteMessageResult {
        let json = match serde_json::to_string(&message) {
            Ok(json) => json,
            Err(_) => return WriteMessageResult::Err(ConnectionError::CouldNotConvertToJson),
        };
        let line = format!("{json}\n");
        match self.socket.write_all(line.as_bytes()) {
            Ok(()) => {
                let _ = self.socket.flush();
                WriteMessageResult::Ok
            }
            Err(_) => WriteMessageResult::ConnectionClosed,
        }
    }

    /// Returns the peer's socket address as a string.
    fn id(&self) -> String {
        self.socket_addr.to_string()
    }
}