use std::{
collections::HashMap,
sync::{Arc, Mutex, OnceLock},
};
use compio::{
io::{
AsyncRead, AsyncWrite,
framed::{Framed, frame::LengthDelimited},
},
net::{TcpListener, UnixListener},
};
use futures_util::{
SinkExt, StreamExt,
future::{Either, select},
pin_mut,
};
use super::{codec::RpcCodec, handler::RpcHandler, protocol::*};
/// Process-wide registry of shutdown senders.
///
/// Lazily initialised by `register_shutdown`, which also installs the Ctrl-C
/// handler that calls `try_broadcast(())` on every sender stored here. Each
/// `accept_*` loop in this file registers one sender when it starts.
static SHUTDOWN_SENDERS: OnceLock<Mutex<Vec<async_broadcast::Sender<()>>>> = OnceLock::new();
/// Registers `tx` to receive a `()` broadcast when the process gets Ctrl-C.
///
/// The first call installs the `ctrlc` handler; later calls only append to
/// `SHUTDOWN_SENDERS`. When the handler fires it calls `try_broadcast(())` on
/// every registered sender and ignores per-sender failures (full or closed
/// channel).
///
/// Before appending, senders whose channels are already closed (no receivers
/// left, e.g. a server that has stopped) are dropped, so the registry does not
/// grow each time a server is started. A poisoned registry lock is recovered
/// instead of being skipped. Skipping it would silently lose this
/// registration, or, inside the handler, drop the Ctrl-C signal for every
/// server.
fn register_shutdown(tx: async_broadcast::Sender<()>) {
    let registry = SHUTDOWN_SENDERS.get_or_init(|| {
        // NOTE(review): the result is discarded, so if installing the handler
        // fails (e.g. another handler is already set) Ctrl-C will never reach
        // these senders. Consider logging the error.
        ctrlc::set_handler(|| {
            if let Some(registry) = SHUTDOWN_SENDERS.get() {
                let senders = registry
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                for tx in senders.iter() {
                    let _ = tx.try_broadcast(());
                }
            }
        })
        .ok();
        Mutex::new(Vec::new())
    });

    let mut senders = registry
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // A closed channel can never deliver the signal again; keep only live ones.
    senders.retain(|sender| !sender.is_closed());
    senders.push(tx);
}
/// Shared, immutable map from a tree's hash name (the `hashname` field of a
/// request) to the handler that serves operations on that tree.
pub(crate) type TreeMap = Arc<HashMap<u64, Arc<dyn RpcHandler>>>;
/// Serves a single client connection.
///
/// Frames are length-delimited and decoded with `RpcCodec`. Each decoded
/// `Request` is routed through `dispatch` and the resulting `Response` is
/// written back. If reading or decoding the next frame yields an error, the
/// peer receives a `500` response carrying the error text and the loop keeps
/// going.
///
/// The loop ends when the peer closes the stream, when writing a response
/// fails, or when anything arrives on `stop_rx` (a shutdown broadcast or a
/// receive error).
pub(crate) async fn handle_connection<R, W>(
    reader: R,
    writer: W,
    trees: TreeMap,
    mut stop_rx: async_broadcast::Receiver<()>,
) where
    R: AsyncRead + Unpin + 'static,
    W: AsyncWrite + Unpin + 'static,
{
    let mut conn = Framed::new::<Response, Request>(RpcCodec, LengthDelimited::new())
        .with_reader(reader)
        .with_writer(writer);

    loop {
        let incoming = conn.next();
        let shutdown = stop_rx.recv();
        pin_mut!(incoming, shutdown);

        let Either::Left((Some(frame), _)) = select(incoming, shutdown).await else {
            // Either the peer hung up (stream exhausted) or shutdown was signalled.
            break;
        };

        let reply = frame.map_or_else(
            |err| Response::Err {
                code: 500,
                message: err.to_string(),
            },
            |request| dispatch(&trees, request),
        );

        if conn.send(reply).await.is_err() {
            break;
        }
    }
}
/// Routes a decoded request to its tree handler and wraps the outcome in a `Response`.
///
/// - `ListCollections` is answered from `trees` directly, without a tree lookup.
/// - An unknown `hashname` yields a `404` error response.
/// - An op paired with a payload variant it does not accept yields a `400`.
/// - A handler error yields the code from its `to_resp()` and its `Display` text.
fn dispatch(trees: &TreeMap, request: Request) -> Response {
    if request.op == OpCode::ListCollections {
        return super::collections::list_collections(trees);
    }

    let Some(handler) = trees.get(&request.hashname) else {
        return Response::Err {
            code: 404,
            message: format!("tree not found: {:#X}", request.hashname),
        };
    };

    // Each arm yields `Result<ResponsePayload, _>`; success is wrapped once below.
    let outcome = match (request.op, request.payload) {
        (OpCode::Get, RequestPayload::Key(key)) => {
            handler.get(&key).map(ResponsePayload::OptionalData)
        }
        (OpCode::Contains, RequestPayload::Key(key)) => {
            handler.contains(&key).map(ResponsePayload::OptionalLen)
        }
        (OpCode::First, RequestPayload::Empty) => handler.first().map(ResponsePayload::OptionalKV),
        (OpCode::Last, RequestPayload::Empty) => handler.last().map(ResponsePayload::OptionalKV),
        (OpCode::Range, RequestPayload::Range { start, end }) => {
            handler.range(start, end).map(ResponsePayload::KeyValues)
        }
        (OpCode::RangeKeys, RequestPayload::Range { start, end }) => {
            handler.range_keys(start, end).map(ResponsePayload::Keys)
        }
        (OpCode::Count, RequestPayload::Count { exact }) => {
            handler.count(exact).map(ResponsePayload::Count)
        }
        (OpCode::Upsert, RequestPayload::Upsert { key, flag, value }) => {
            handler.upsert(key, flag, value).map(ResponsePayload::Key)
        }
        (OpCode::Remove, RequestPayload::Remove { key, soft }) => {
            handler.remove(&key, soft).map(|_| ResponsePayload::Empty)
        }
        (OpCode::Take, RequestPayload::Take { key, soft }) => {
            handler.take(&key, soft).map(ResponsePayload::OptionalData)
        }
        (OpCode::ApplyBatch, RequestPayload::Batch(items)) => {
            handler.apply_batch(items).map(|_| ResponsePayload::Empty)
        }
        (_, _) => {
            return Response::Err {
                code: 400,
                message: "invalid op/payload combination".into(),
            };
        }
    };

    outcome.map(Response::Ok).unwrap_or_else(|err| {
        let (code, _) = err.to_resp();
        Response::Err {
            code,
            message: err.to_string(),
        }
    })
}
pub(crate) async fn accept_tcp(listener: TcpListener, trees: TreeMap) {
let (shutdown_tx, mut shutdown_rx) = async_broadcast::broadcast::<()>(1);
register_shutdown(shutdown_tx);
loop {
let accept_fut = listener.accept();
let stop_fut = shutdown_rx.recv();
pin_mut!(accept_fut);
pin_mut!(stop_fut);
match select(accept_fut, stop_fut).await {
Either::Left((Ok((stream, _addr)), _)) => {
let trees = trees.clone();
let (reader, writer) = stream.into_split();
let rx = shutdown_rx.clone();
compio::runtime::spawn(handle_connection(reader, writer, trees, rx)).detach();
}
Either::Left((Err(e), _)) => {
error!("tcp accept error: {e}");
}
Either::Right(_) => {
info!("Received Ctrl-C, shutting down TCP server...");
break;
}
}
}
}
pub(crate) async fn accept_uds(listener: UnixListener, trees: TreeMap) {
let (shutdown_tx, mut shutdown_rx) = async_broadcast::broadcast::<()>(1);
register_shutdown(shutdown_tx);
loop {
let accept_fut = listener.accept();
let stop_fut = shutdown_rx.recv();
pin_mut!(accept_fut);
pin_mut!(stop_fut);
match select(accept_fut, stop_fut).await {
Either::Left((Ok((stream, _addr)), _)) => {
let trees = trees.clone();
let (reader, writer) = stream.into_split();
let rx = shutdown_rx.clone();
compio::runtime::spawn(handle_connection(reader, writer, trees, rx)).detach();
}
Either::Left((Err(e), _)) => {
error!("uds accept error: {e}");
}
Either::Right(_) => {
info!("Received Ctrl-C, shutting down UDS server...");
break;
}
}
}
}