#[cfg(test)]
mod test;
use std::any::Any;
use std::collections::HashMap;
use std::io;
use std::io::{BufRead, BufReader, Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
/// Errors reported while serving a connection.
#[derive(Debug)]
pub enum Error {
    /// The client sent a malformed request; the string describes the problem.
    Protocol(String),
    /// A read or write on the underlying stream failed.
    IoError(io::Error),
}

impl From<io::Error> for Error {
    fn from(error: io::Error) -> Self {
        Error::IoError(error)
    }
}

impl Error {
    /// Builds an [`Error::Protocol`] carrying `msg`.
    fn new(msg: &str) -> Error {
        Error::Protocol(msg.to_owned())
    }
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Protocol(s) => f.write_str(s),
            Error::IoError(e) => write!(f, "{}", e),
        }
    }
}

impl std::error::Error for Error {}

/// One client connection, handed to the server callbacks.
///
/// Replies written with the `write_*` methods are buffered in `wbuf` and sent
/// by [`Conn::flush`]; nothing is written to the socket directly.
pub struct Conn {
    id: u64,
    addr: SocketAddr,
    reader: BufReader<Box<dyn Read>>,
    wbuf: Vec<u8>,
    writer: Box<dyn Write>,
    closed: bool,
    shutdown: bool,
    // Commands of the current batch, stored in reverse so `pop` yields them in order.
    cmds: Vec<Vec<Vec<u8>>>,
    // Cross-close flags of all live connections, keyed by connection id.
    conns: Arc<Mutex<HashMap<u64, Arc<AtomicBool>>>>,
    /// Arbitrary per-connection state owned by the callbacks.
    pub context: Option<Box<dyn Any>>,
}

impl Conn {
    /// Unique id assigned to this connection when it was accepted.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Remote address of the client.
    pub fn addr(&self) -> &SocketAddr {
        &self.addr
    }

    /// Removes and returns the next queued command, or `None` when none is left.
    pub fn next_command(&mut self) -> Option<Vec<Vec<u8>>> {
        self.cmds.pop()
    }

    /// Appends a simple-string reply (`+msg`); control characters become spaces.
    pub fn write_string(&mut self, msg: &str) {
        if !self.closed {
            self.extend_lossy_line(b'+', msg);
        }
    }

    /// Appends a null bulk reply (`$-1`).
    pub fn write_null(&mut self) {
        if !self.closed {
            self.wbuf.extend_from_slice(b"$-1\r\n");
        }
    }

    /// Appends an error reply (`-msg`); control characters become spaces.
    pub fn write_error(&mut self, msg: &str) {
        if !self.closed {
            self.extend_lossy_line(b'-', msg);
        }
    }

    /// Appends an integer reply (`:x`).
    pub fn write_integer(&mut self, x: i64) {
        if !self.closed {
            self.wbuf.extend_from_slice(format!(":{}\r\n", x).as_bytes());
        }
    }

    /// Appends an array header (`*count`); the caller writes the `count` elements.
    pub fn write_array(&mut self, count: usize) {
        if !self.closed {
            self.wbuf.extend_from_slice(format!("*{}\r\n", count).as_bytes());
        }
    }

    /// Appends a binary-safe bulk-string reply.
    pub fn write_bulk(&mut self, msg: &[u8]) {
        if !self.closed {
            self.wbuf.extend_from_slice(format!("${}\r\n", msg.len()).as_bytes());
            self.wbuf.extend_from_slice(msg);
            self.wbuf.extend_from_slice(b"\r\n");
        }
    }

    /// Appends pre-encoded bytes verbatim.
    pub fn write_raw(&mut self, raw: &[u8]) {
        if !self.closed {
            self.wbuf.extend_from_slice(raw);
        }
    }

    /// Marks the connection closed; every `write_*` call after this is ignored.
    pub fn close(&mut self) {
        self.closed = true;
    }

    /// Closes the connection and flags it as requesting a server shutdown.
    pub fn shutdown(&mut self) {
        self.closed = true;
        self.shutdown = true;
    }

    /// Raises the cross-close flag of the connection with id `id`, if it is live.
    /// That connection observes the flag when one of its reads times out.
    pub fn cross_close(&mut self, id: u64) {
        if let Some(xcloser) = self.conns.lock().unwrap().get(&id) {
            xcloser.store(true, Ordering::SeqCst);
        }
    }

    /// Whether `err` is what a socket read timeout produces
    /// (`WouldBlock` on Unix, `TimedOut` on Windows).
    fn is_read_timeout(err: &io::Error) -> bool {
        matches!(err.kind(), io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut)
    }

    /// Decides what to do when a read times out in the middle of a command.
    ///
    /// Returning the timeout there would discard the bytes of the command
    /// consumed so far and desynchronize the stream, so the caller keeps
    /// waiting (`Ok`) unless this connection has been cross-closed, in which
    /// case the timeout error is returned. A client that stalls mid-command
    /// therefore holds this thread until it sends more data, disconnects, or
    /// is cross-closed.
    fn resume_after_timeout(&self, err: io::Error) -> Result<(), Error> {
        let cross_closed = self
            .conns
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .get(&self.id)
            .map_or(false, |flag| flag.load(Ordering::SeqCst));
        if cross_closed {
            Err(Error::IoError(err))
        } else {
            Ok(())
        }
    }

    /// Parses a RESP multibulk request whose `*<count>` header is `line`.
    ///
    /// Returns `Ok(None)` if the stream ends before the request is complete.
    ///
    /// # Errors
    /// `Error::Protocol` for a malformed header or element, `Error::IoError`
    /// for read failures.
    fn pl_read_array(&mut self, line: Vec<u8>) -> Result<Option<Vec<Vec<u8>>>, Error> {
        const MAX_BULK_LEN: i32 = 512 * 1024 * 1024;
        let count = match String::from_utf8_lossy(&line[1..]).parse::<i32>() {
            Ok(n) => n,
            Err(_) => return Err(Error::new("invalid multibulk length")),
        };
        // Deliberately not pre-sized from `count`, which the client controls.
        let mut arr = Vec::new();
        for _ in 0..count {
            let line = match self.pl_read_line_in(true)? {
                Some(line) => line,
                None => return Ok(None),
            };
            match line.first() {
                Some(&b'$') => {}
                other => {
                    // Only echo printable ASCII back in the error message.
                    let shown = match other {
                        Some(&b) if b.is_ascii_graphic() => b as char,
                        _ => ' ',
                    };
                    return Err(Error::new(&format!("expected '$', got '{}'", shown)));
                }
            }
            let len = String::from_utf8_lossy(&line[1..])
                .parse::<i32>()
                .unwrap_or(-1);
            if !(0..=MAX_BULK_LEN).contains(&len) {
                return Err(Error::new("invalid bulk length"));
            }
            match self.pl_read_bulk(len as usize)? {
                Some(arg) => arr.push(arg),
                None => return Ok(None),
            }
        }
        Ok(Some(arr))
    }

    /// Reads a `len`-byte bulk payload followed by its two-byte terminator.
    ///
    /// The terminator is skipped without being validated. The buffer grows
    /// with the data actually received, so announcing a huge length does not
    /// by itself force a huge allocation. Returns `Ok(None)` if the stream
    /// ends before the payload is complete.
    fn pl_read_bulk(&mut self, len: usize) -> Result<Option<Vec<u8>>, Error> {
        let total = len + 2;
        let mut buf = Vec::new();
        while buf.len() < total {
            let want = (total - buf.len()) as u64;
            // `read_to_end` keeps already-read bytes in `buf` when it fails,
            // so a retry after a timeout resumes where it stopped.
            let read = (&mut self.reader).take(want).read_to_end(&mut buf);
            match read {
                Ok(0) => return Ok(None),
                Ok(_) => {}
                Err(e) if Self::is_read_timeout(&e) => self.resume_after_timeout(e)?,
                Err(e) => return Err(Error::from(e)),
            }
        }
        buf.truncate(len);
        Ok(Some(buf))
    }

    /// Reads one LF- or CRLF-terminated line and strips the terminator.
    ///
    /// Returns `Ok(None)` at end of stream, including when the stream ends
    /// in the middle of a line: an incomplete command is never executed.
    fn pl_read_line(&mut self) -> Result<Option<Vec<u8>>, Error> {
        self.pl_read_line_in(false)
    }

    /// Same as `pl_read_line`. `in_command` says whether bytes of the
    /// current command were already consumed before this line. In that case,
    /// or once part of this line has arrived, a read timeout is waited out
    /// (see `resume_after_timeout`) instead of being returned.
    fn pl_read_line_in(&mut self, in_command: bool) -> Result<Option<Vec<u8>>, Error> {
        let mut line = Vec::new();
        loop {
            // `read_until` keeps the bytes it already read in `line` on error.
            let read = self.reader.read_until(b'\n', &mut line);
            match read {
                Ok(_) => break,
                Err(e) if Self::is_read_timeout(&e) && (in_command || !line.is_empty()) => {
                    self.resume_after_timeout(e)?;
                }
                Err(e) => return Err(Error::from(e)),
            }
        }
        if line.last() != Some(&b'\n') {
            return Ok(None);
        }
        line.pop();
        if line.last() == Some(&b'\r') {
            line.pop();
        }
        Ok(Some(line))
    }

    /// Splits an inline (non-RESP) command line into arguments.
    ///
    /// Spaces and tabs separate arguments. A `"..."` section supports the
    /// escapes `\t \n \r \b \v \xHH`; any other escaped character, including
    /// an `x` not followed by two hex digits, stands for itself. A `'...'`
    /// section is literal. Quoted sections may be empty and still produce an
    /// (empty) argument.
    ///
    /// # Errors
    /// `Error::Protocol("unbalanced quotes in request")` for an unterminated
    /// quote, or a closing quote not followed by whitespace or end of line.
    fn pl_read_inline(&mut self, line: Vec<u8>) -> Result<Option<Vec<Vec<u8>>>, Error> {
        const UNBALANCED: &str = "unbalanced quotes in request";
        let is_space = |b: u8| b == b' ' || b == b'\t';
        let hex = |b: Option<&u8>| b.and_then(|&b| (b as char).to_digit(16));
        let mut arr = Vec::new();
        let mut arg = Vec::new();
        // True once the current argument has started; lets "" count as an argument.
        let mut in_arg = false;
        let mut i = 0;
        while i < line.len() {
            let b = line[i];
            if is_space(b) {
                if in_arg {
                    arr.push(std::mem::take(&mut arg));
                    in_arg = false;
                }
                i += 1;
            } else if b == b'\'' || b == b'"' {
                let quote = b;
                in_arg = true;
                i += 1;
                loop {
                    if i == line.len() {
                        return Err(Error::new(UNBALANCED));
                    }
                    let c = line[i];
                    i += 1;
                    if c == quote {
                        break;
                    }
                    if c != b'\\' || quote != b'"' {
                        arg.push(c);
                        continue;
                    }
                    if i == line.len() {
                        return Err(Error::new(UNBALANCED));
                    }
                    let escaped = line[i];
                    i += 1;
                    match escaped {
                        b't' => arg.push(b'\t'),
                        b'n' => arg.push(b'\n'),
                        b'r' => arg.push(b'\r'),
                        b'b' => arg.push(8),
                        b'v' => arg.push(11),
                        b'x' => match (hex(line.get(i)), hex(line.get(i + 1))) {
                            (Some(hi), Some(lo)) => {
                                arg.push((hi * 16 + lo) as u8);
                                i += 2;
                            }
                            _ => arg.push(b'x'),
                        },
                        other => arg.push(other),
                    }
                }
                // `i` is just past the closing quote; it must end the argument.
                if i < line.len() && !is_space(line[i]) {
                    return Err(Error::new(UNBALANCED));
                }
            } else {
                in_arg = true;
                arg.push(b);
                i += 1;
            }
        }
        if in_arg {
            arr.push(arg);
        }
        Ok(Some(arr))
    }

    /// Reads the next batch of pipelined commands.
    ///
    /// Waits for at least one command, then keeps parsing while input is
    /// already buffered, so pipelined requests form one batch. Commands with
    /// no arguments are skipped. At end of stream `closed` is set, but only
    /// once no complete command is pending: commands received before the end
    /// of stream are returned first, and the next call closes.
    fn read_pipeline(&mut self) -> Result<Vec<Vec<Vec<u8>>>, Error> {
        let mut cmds = Vec::new();
        loop {
            let args = match self.pl_read_line()? {
                Some(line) if line.is_empty() => continue,
                Some(line) if line[0] == b'*' => self.pl_read_array(line)?,
                Some(line) => self.pl_read_inline(line)?,
                None => None,
            };
            match args {
                Some(args) if !args.is_empty() => cmds.push(args),
                Some(_) => {}
                None => {
                    if cmds.is_empty() {
                        self.closed = true;
                    }
                    break;
                }
            }
            if !cmds.is_empty() && self.reader.buffer().is_empty() {
                break;
            }
        }
        Ok(cmds)
    }

    /// Appends `prefix`, `msg` with every byte below `b' '` replaced by a
    /// space (so it cannot break the line framing), and CRLF.
    fn extend_lossy_line(&mut self, prefix: u8, msg: &str) {
        self.wbuf.push(prefix);
        self.wbuf
            .extend(msg.bytes().map(|b| if b < b' ' { b' ' } else { b }));
        self.wbuf.extend_from_slice(b"\r\n");
    }

    /// Sends all buffered replies to the writer and empties the buffer.
    /// A buffer that grew past 1 MiB is released rather than kept for reuse.
    fn flush(&mut self) -> Result<(), Error> {
        if self.wbuf.is_empty() {
            return Ok(());
        }
        self.writer.write_all(&self.wbuf)?;
        self.writer.flush()?;
        if self.wbuf.len() > 1024 * 1024 {
            self.wbuf = Vec::new();
        } else {
            self.wbuf.clear();
        }
        Ok(())
    }
}
/// A Redis-protocol TCP server, created by [`listen`].
///
/// Fill in the public callback fields, then call `serve`. Every callback
/// receives a shared reference to the user data given to [`listen`].
pub struct Server<T> {
    /// Called with each received command, as its list of arguments.
    pub command: Option<fn(&mut Conn, &T, Vec<Vec<u8>>)>,
    /// Called once per connection, before any of its commands is read.
    pub opened: Option<fn(&mut Conn, &T)>,
    /// Called once per connection when it ends, with the error that ended it, if any.
    pub closed: Option<fn(&mut Conn, &T, Option<Error>)>,
    /// Called repeatedly on a dedicated thread; returns how long to sleep
    /// before the next call, or `None` to shut the server down.
    pub tick: Option<fn(&T) -> Option<Duration>>,
    // Both moved out (left `None`) when the server starts serving.
    listener: Option<TcpListener>,
    data: Option<T>,
    local_addr: SocketAddr,
}

/// Binds a TCP listener on `addr` and returns an unconfigured [`Server`]
/// that will share `data` with all of its callbacks.
///
/// # Errors
/// Returns [`Error::IoError`] if binding fails or the bound address cannot
/// be queried.
pub fn listen<A: ToSocketAddrs, T>(addr: A, data: T) -> Result<Server<T>, Error> {
    let listener = TcpListener::bind(addr)?;
    Ok(Server {
        // Evaluated before `listener` is moved into the struct below.
        local_addr: listener.local_addr()?,
        listener: Some(listener),
        data: Some(data),
        command: None,
        opened: None,
        closed: None,
        tick: None,
    })
}
impl<T: Send + Sync + 'static> Server<T> {
pub fn serve(&mut self) -> Result<(), Error> {
serve(self)
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
fn serve<T: Send + Sync + 'static>(s: &mut Server<T>) -> Result<(), Error> {
let listener = match s.listener.take() {
Some(listener) => listener,
None => return Err(Error::IoError(io::Error::from(io::ErrorKind::Other))),
};
let data = s.data.take().unwrap();
let command = s.command.take();
let opened = s.opened.take();
let closed = s.closed.take();
let tick = s.tick.take();
let laddr = s.local_addr;
drop(s);
let wg = Arc::new(AtomicI64::new(0));
let conns: HashMap<u64, Arc<AtomicBool>> = HashMap::new();
let conns = Arc::new(Mutex::new(conns));
let data = Arc::new(data);
let mut next_id: u64 = 1;
let shutdown = Arc::new(AtomicBool::new(false));
let init_shutdown = |shutdown: Arc<AtomicBool>, laddr: &SocketAddr| {
let aord = Ordering::SeqCst;
if shutdown.compare_exchange(false, true, aord, aord).is_err() {
return;
}
let _ = TcpStream::connect(&laddr);
};
if let Some(tick) = tick {
let data = data.clone();
let shutdown = shutdown.clone();
let wg = wg.clone();
wg.fetch_add(1, Ordering::SeqCst);
thread::spawn(move || {
while !shutdown.load(Ordering::SeqCst) {
match (tick)(&data) {
Some(delay) => thread::sleep(delay),
None => {
init_shutdown(shutdown.clone(), &laddr);
break;
}
}
}
wg.fetch_add(-1, Ordering::SeqCst)
});
}
for stream in listener.incoming() {
let shutdown = shutdown.clone();
if shutdown.load(Ordering::SeqCst) {
break;
}
match stream {
Ok(stream) => {
if stream
.set_read_timeout(Some(Duration::from_millis(100)))
.is_err()
{
continue;
}
let addr = match stream.peer_addr() {
Ok(addr) => addr,
_ => continue,
};
let streams = (
match stream.try_clone() {
Ok(stream) => stream,
_ => continue,
},
stream,
);
let data = data.clone();
let conn_id = next_id;
next_id += 1;
let xcloser = Arc::new(AtomicBool::new(false));
let conns = conns.clone();
conns.lock().unwrap().insert(conn_id, xcloser.clone());
let wg = wg.clone();
wg.fetch_add(1, Ordering::SeqCst);
thread::spawn(move || {
let mut conn = Conn {
id: conn_id,
cmds: Vec::new(),
context: None,
addr,
reader: BufReader::new(Box::new(streams.0)),
wbuf: Vec::new(),
writer: Box::new(streams.1),
closed: false,
shutdown: false,
conns: conns.clone(),
};
let mut final_err: Option<Error> = None;
if let Some(opened) = opened {
(opened)(&mut conn, &data);
}
loop {
if let Err(e) = conn.flush() {
if final_err.is_none() {
final_err = Some(From::from(e));
}
conn.closed = true;
}
if conn.closed {
break;
}
match conn.read_pipeline() {
Ok(cmds) => {
conn.cmds = cmds;
conn.cmds.reverse();
while let Some(cmd) = conn.next_command() {
if let Some(command) = command {
(command)(&mut conn, &data, cmd);
}
if conn.closed {
break;
}
}
}
Err(e) => {
if let Error::Protocol(msg) = &e {
conn.write_error(&format!("ERR Protocol error: {}", msg));
} else if let Error::IoError(e) = &e {
if let io::ErrorKind::WouldBlock = e.kind() {
if shutdown.load(Ordering::SeqCst) {
conn.closed = true;
}
if xcloser.load(Ordering::SeqCst) {
conn.closed = true;
}
continue;
}
}
final_err = Some(e);
conn.closed = true;
}
}
}
if conn.shutdown {
init_shutdown(shutdown.clone(), &laddr);
}
if let Some(closed) = closed {
(closed)(&mut conn, &data, final_err);
}
conns.lock().unwrap().remove(&conn.id);
wg.fetch_add(-1, Ordering::SeqCst);
});
}
Err(_) => {}
}
}
while wg.load(Ordering::SeqCst) > 0 {
thread::sleep(Duration::from_millis(10));
}
Ok(())
}