use std::{collections::HashMap, path::Path, sync::Arc, thread::JoinHandle};
use armour_rpc::RpcError;
use armour_rpc::protocol::{
OpCode, Request, RequestPayload, read_bound, read_bytes, read_u8, read_u32_be, read_u64_be,
read_upsert_key,
};
use compio::buf::{IoBuf, IoBufMut};
use compio::io::framed::codec::{Decoder, Encoder};
use compio::{
io::{
AsyncRead, AsyncWrite,
framed::{Framed, frame::LengthDelimited},
},
net::{TcpListener, ToSocketAddrsAsync, UnixListener},
runtime::RuntimeBuilder,
};
use futures_util::{
SinkExt, StreamExt,
future::{Either, select},
pin_mut,
};
use super::handler::RpcHandler;
use super::tree_db::Db;
/// Reply sent back to the client for each decoded request.
///
/// `Ok` carries an operation-specific payload. `Err` carries a numeric status code (this
/// module emits 400, 404, 500 and 501) and a human-readable message.
#[derive(Debug)]
enum Response {
    Ok(ResponsePayload),
    Err { code: u16, message: String },
}
/// Summary of one registered collection (tree), as returned by `ListCollections`.
#[derive(Debug)]
pub(crate) struct CollectionMeta {
    /// Collection name as reported by its handler's `name()`.
    pub name: String,
    /// Key under which the collection's handler is registered in the [`TreeMap`]; requests
    /// address a collection through this value.
    pub hashname: u64,
    /// First element of the handler's `info()` tuple (presumably a hash of the stored type —
    /// TODO confirm).
    pub typ_hash: u64,
    /// Second element of the handler's `info()` tuple (presumably a schema version — TODO
    /// confirm).
    pub version: u16,
    /// Number of entries; `list_collections` reports 0 when counting fails.
    pub count: u64,
}
/// Operation-specific body of a successful [`Response`].
///
/// `dispatch` maps operations to variants as follows:
/// - `Get` → `OptionalData`
/// - `Contains` → `OptionalLen`
/// - `First`/`Last` → `OptionalKV`
/// - `Range` → `KeyValues`
/// - `RangeKeys` → `Keys`
/// - `Upsert` → `Key`
/// - `Count` → `Count`
/// - `ListCollections` → `Collections`
/// - `Remove`/`ApplyBatch` → `Empty`
#[derive(Debug)]
enum ResponsePayload {
    /// No body.
    Empty,
    /// Optional value bytes.
    OptionalData(Option<Vec<u8>>),
    /// Optional length returned by the handler's `contains` (presumably the stored value's
    /// size — TODO confirm).
    OptionalLen(Option<u32>),
    /// Optional key/value pair.
    OptionalKV(Option<(Vec<u8>, Vec<u8>)>),
    /// List of key/value pairs.
    KeyValues(Vec<(Vec<u8>, Vec<u8>)>),
    /// List of keys.
    Keys(Vec<Vec<u8>>),
    /// A single key (the key returned by the handler's `upsert`).
    Key(Vec<u8>),
    /// An entry count.
    Count(u64),
    /// Metadata for every registered collection.
    Collections(Vec<CollectionMeta>),
}
/// Stateless codec: decodes [`Request`] frame bodies and encodes [`Response`] frame bodies.
struct RpcCodec;
impl<B: IoBuf> Decoder<Request, B> for RpcCodec {
    type Error = RpcError;

    /// Decodes the body of one request frame.
    ///
    /// Layout: a `u8` opcode, then the target tree's `hashname` as a big-endian `u64`, then an
    /// opcode-specific payload:
    /// - `Get`/`Contains`: key bytes.
    /// - `First`/`Last`/`ListCollections`: nothing.
    /// - `Count`: one `exact` byte (non-zero means true).
    /// - `Range`/`RangeKeys`: start bound, then end bound.
    /// - `Upsert`: upsert key, a flag byte (`0` = `None`, `1` = `Some(true)`,
    ///   `2` = `Some(false)`), then value bytes.
    /// - `Remove`/`Take`: key bytes, then one `soft` byte (non-zero means true).
    /// - `ApplyBatch`: a big-endian `u32` item count. Each item is key bytes, a presence byte,
    ///   and — only when that byte is `1` — value bytes.
    ///
    /// Trailing bytes after the payload are ignored.
    ///
    /// # Errors
    /// Returns [`RpcError::Protocol`] for an unknown opcode or an invalid upsert flag. Also
    /// propagates any error reported by the `armour_rpc::protocol` field readers.
    fn decode(&mut self, buf: &compio::buf::Slice<B>) -> Result<Request, Self::Error> {
        let bytes: &[u8] = buf;
        let mut pos = 0;
        let op_byte = read_u8(bytes, &mut pos)?;
        let op = OpCode::from_repr(op_byte)
            .ok_or_else(|| RpcError::Protocol("unknown opcode".to_string()))?;
        let hashname = read_u64_be(bytes, &mut pos)?;
        let payload = match op {
            OpCode::Get | OpCode::Contains => {
                let key = read_bytes(bytes, &mut pos)?;
                RequestPayload::Key(key)
            }
            OpCode::First | OpCode::Last | OpCode::ListCollections => RequestPayload::Empty,
            OpCode::Count => {
                let exact = read_u8(bytes, &mut pos)? != 0;
                RequestPayload::Count { exact }
            }
            OpCode::Range | OpCode::RangeKeys => {
                let start = read_bound(bytes, &mut pos)?;
                let end = read_bound(bytes, &mut pos)?;
                RequestPayload::Range { start, end }
            }
            OpCode::Upsert => {
                let key = read_upsert_key(bytes, &mut pos)?;
                let flag_byte = read_u8(bytes, &mut pos)?;
                let flag = match flag_byte {
                    0 => None,
                    1 => Some(true),
                    2 => Some(false),
                    _ => {
                        return Err(RpcError::Protocol("invalid upsert flag".to_string()));
                    }
                };
                let value = read_bytes(bytes, &mut pos)?;
                RequestPayload::Upsert { key, flag, value }
            }
            OpCode::Remove => {
                let key = read_bytes(bytes, &mut pos)?;
                let soft = read_u8(bytes, &mut pos)? != 0;
                RequestPayload::Remove { key, soft }
            }
            OpCode::Take => {
                let key = read_bytes(bytes, &mut pos)?;
                let soft = read_u8(bytes, &mut pos)? != 0;
                RequestPayload::Take { key, soft }
            }
            OpCode::ApplyBatch => {
                let count = read_u32_be(bytes, &mut pos)? as usize;
                // `count` comes straight off the wire, so it must not drive preallocation:
                // a few-byte frame could otherwise request a multi-gigabyte allocation and
                // abort the process. Every item consumes at least one byte (its presence
                // flag), so the bytes left in the frame are a safe upper bound.
                let remaining = bytes.len().saturating_sub(pos);
                let mut items = Vec::with_capacity(count.min(remaining));
                for _ in 0..count {
                    let key = read_bytes(bytes, &mut pos)?;
                    let has_val = read_u8(bytes, &mut pos)?;
                    let val = if has_val == 1 {
                        Some(read_bytes(bytes, &mut pos)?)
                    } else {
                        None
                    };
                    items.push((key, val));
                }
                RequestPayload::Batch(items)
            }
        };
        Ok(Request {
            op,
            hashname,
            payload,
        })
    }
}
impl<B: IoBufMut> Encoder<Response, B> for RpcCodec {
    type Error = RpcError;

    /// Serializes a [`Response`] into `buf`.
    ///
    /// Body layout (framing is applied separately by `LengthDelimited` in `handle_connection`):
    /// - Success: `0x00`, followed by the payload written by `encode_ok_payload`.
    /// - Error: `0x01`, the `u16` big-endian status code, a `u16` big-endian message length,
    ///   then the UTF-8 message bytes.
    ///
    /// Error messages longer than `u16::MAX` bytes are truncated on a UTF-8 character
    /// boundary, so the length prefix always matches the bytes that follow it.
    ///
    /// # Errors
    /// Returns [`RpcError::Other`] if `buf` cannot be extended.
    fn encode(&mut self, item: Response, buf: &mut B) -> Result<(), Self::Error> {
        let mut tmp = Vec::new();
        match item {
            Response::Ok(payload) => {
                tmp.push(0x00);
                encode_ok_payload(&mut tmp, payload);
            }
            Response::Err { code, message } => {
                tmp.push(0x01);
                tmp.extend_from_slice(&code.to_be_bytes());
                // A bare `len as u16` silently wraps for long messages, which desynchronizes
                // the client's parser. Cap the message at the largest encodable length,
                // backing off to a char boundary so the bytes stay valid UTF-8.
                let mut end = message.len().min(usize::from(u16::MAX));
                while !message.is_char_boundary(end) {
                    end -= 1;
                }
                let msg_bytes = &message.as_bytes()[..end];
                let msg_len = u16::try_from(msg_bytes.len()).unwrap_or(u16::MAX);
                tmp.extend_from_slice(&msg_len.to_be_bytes());
                tmp.extend_from_slice(msg_bytes);
            }
        }
        buf.extend_from_slice(&tmp)
            .map_err(|e| RpcError::Other(e.to_string()))?;
        Ok(())
    }
}
/// Appends the wire encoding of a successful response body to `buf`.
///
/// All integers are big-endian. Byte strings are written as a `u32` length followed by the
/// bytes. Per variant:
/// - `Empty`: nothing.
/// - `OptionalData`: `0`, or `1` + byte string.
/// - `OptionalLen`: `0`, or `1` + `u32`.
/// - `OptionalKV`: `0`, or `1` + key byte string + value byte string.
/// - `KeyValues`: `u32` pair count, then key/value byte strings per pair.
/// - `Keys`: `u32` key count, then one byte string per key.
/// - `Key`: one byte string.
/// - `Count`: `u64`.
/// - `Collections`: `u32` collection count. Each record is the name byte string, then
///   `hashname` (`u64`), `typ_hash` (`u64`), `version` (`u16`) and `count` (`u64`).
///
/// NOTE(review): lengths are narrowed with `as u32`, which would wrap for items of 4 GiB or
/// more.
fn encode_ok_payload(buf: &mut Vec<u8>, payload: ResponsePayload) {
    match payload {
        ResponsePayload::Empty => {}
        ResponsePayload::OptionalData(opt) => match opt {
            None => buf.push(0),
            Some(data) => {
                buf.push(1);
                buf.extend_from_slice(&(data.len() as u32).to_be_bytes());
                buf.extend_from_slice(&data);
            }
        },
        ResponsePayload::OptionalLen(opt) => match opt {
            None => buf.push(0),
            Some(len) => {
                buf.push(1);
                buf.extend_from_slice(&len.to_be_bytes());
            }
        },
        ResponsePayload::OptionalKV(opt) => match opt {
            None => buf.push(0),
            Some((key, val)) => {
                buf.push(1);
                buf.extend_from_slice(&(key.len() as u32).to_be_bytes());
                buf.extend_from_slice(&key);
                buf.extend_from_slice(&(val.len() as u32).to_be_bytes());
                buf.extend_from_slice(&val);
            }
        },
        ResponsePayload::KeyValues(pairs) => {
            buf.extend_from_slice(&(pairs.len() as u32).to_be_bytes());
            for (key, val) in pairs {
                buf.extend_from_slice(&(key.len() as u32).to_be_bytes());
                buf.extend_from_slice(&key);
                buf.extend_from_slice(&(val.len() as u32).to_be_bytes());
                buf.extend_from_slice(&val);
            }
        }
        ResponsePayload::Keys(keys) => {
            buf.extend_from_slice(&(keys.len() as u32).to_be_bytes());
            for key in keys {
                buf.extend_from_slice(&(key.len() as u32).to_be_bytes());
                buf.extend_from_slice(&key);
            }
        }
        ResponsePayload::Key(key) => {
            buf.extend_from_slice(&(key.len() as u32).to_be_bytes());
            buf.extend_from_slice(&key);
        }
        ResponsePayload::Count(n) => {
            buf.extend_from_slice(&n.to_be_bytes());
        }
        ResponsePayload::Collections(collections) => {
            buf.extend_from_slice(&(collections.len() as u32).to_be_bytes());
            for c in collections {
                // The name is written exactly once. An earlier revision emitted it twice,
                // which shifted every following field of the record. Client decoders must
                // follow the layout documented above.
                let name = c.name.as_bytes();
                buf.extend_from_slice(&(name.len() as u32).to_be_bytes());
                buf.extend_from_slice(name);
                buf.extend_from_slice(&c.hashname.to_be_bytes());
                buf.extend_from_slice(&c.typ_hash.to_be_bytes());
                buf.extend_from_slice(&c.version.to_be_bytes());
                buf.extend_from_slice(&c.count.to_be_bytes());
            }
        }
    }
}
/// Shared, immutable registry of RPC handlers keyed by collection `hashname`; cloned cheaply
/// (reference-count bump) into each listener and connection task.
pub type TreeMap = Arc<HashMap<u64, Arc<dyn RpcHandler>>>;
/// Serves a single client connection until it ends.
///
/// Requests are read as length-delimited frames, decoded by [`RpcCodec`], routed through
/// `dispatch`, and each reply is written back before the next request is read. A frame that
/// fails to decode is answered with a 500 error and the connection stays open.
///
/// The loop exits when the peer closes the stream, when writing a response fails, or when
/// `stop_rx` yields (either a stop message or a closed channel).
async fn handle_connection<R, W>(
    reader: R,
    writer: W,
    trees: TreeMap,
    mut stop_rx: async_broadcast::Receiver<()>,
) where
    R: AsyncRead + Unpin + 'static,
    W: AsyncWrite + Unpin + 'static,
{
    let mut framed = Framed::new::<Response, Request>(RpcCodec, LengthDelimited::new())
        .with_reader(reader)
        .with_writer(writer);

    loop {
        let incoming = framed.next();
        let shutdown = stop_rx.recv();
        pin_mut!(incoming, shutdown);

        let decoded = match select(incoming, shutdown).await {
            Either::Left((Some(decoded), _)) => decoded,
            // `None`: the peer closed the stream. `Right`: stop requested (or channel closed).
            Either::Left((None, _)) | Either::Right(_) => break,
        };

        let response = match decoded {
            Ok(request) => dispatch(&trees, request),
            Err(err) => Response::Err {
                code: 500,
                message: err.to_string(),
            },
        };

        let Ok(()) = framed.send(response).await else {
            break;
        };
    }
}
/// Executes one decoded request against the registered handlers and builds the reply.
///
/// `ListCollections` is answered from the whole registry. Every other operation is routed to
/// the handler registered under `request.hashname`.
///
/// Error codes produced here:
/// - 404: no handler is registered for the hashname.
/// - 501: `Take` (not implemented).
/// - 400: the opcode and payload shape do not match.
/// - 500: the handler itself returned an error.
///
/// The `exact` flag of `Count` and the `soft` flag of `Remove` are currently ignored.
fn dispatch(trees: &TreeMap, request: Request) -> Response {
    // Listing spans the whole registry, so it is handled before any per-tree lookup.
    if request.op == OpCode::ListCollections {
        return list_collections(trees);
    }

    let Some(handler) = trees.get(&request.hashname) else {
        return Response::Err {
            code: 404,
            message: format!("tree not found: {:#X}", request.hashname),
        };
    };

    let outcome = match (request.op, request.payload) {
        (OpCode::Get, RequestPayload::Key(key)) => {
            handler.get(&key).map(ResponsePayload::OptionalData)
        }
        (OpCode::Contains, RequestPayload::Key(key)) => {
            handler.contains(&key).map(ResponsePayload::OptionalLen)
        }
        (OpCode::First, RequestPayload::Empty) => handler.first().map(ResponsePayload::OptionalKV),
        (OpCode::Last, RequestPayload::Empty) => handler.last().map(ResponsePayload::OptionalKV),
        (OpCode::Range, RequestPayload::Range { start, end }) => {
            handler.range(start, end).map(ResponsePayload::KeyValues)
        }
        (OpCode::RangeKeys, RequestPayload::Range { start, end }) => {
            handler.range_keys(start, end).map(ResponsePayload::Keys)
        }
        (OpCode::Count, RequestPayload::Count { .. }) => handler.count().map(ResponsePayload::Count),
        (OpCode::Upsert, RequestPayload::Upsert { key, flag, value }) => {
            handler.upsert(key, flag, value).map(ResponsePayload::Key)
        }
        (OpCode::Remove, RequestPayload::Remove { key, .. }) => {
            handler.remove(&key).map(|_| ResponsePayload::Empty)
        }
        (OpCode::Take, RequestPayload::Take { .. }) => {
            return Response::Err {
                code: 501,
                message: "take not implemented".into(),
            };
        }
        (OpCode::ApplyBatch, RequestPayload::Batch(items)) => {
            handler.apply_batch(items).map(|_| ResponsePayload::Empty)
        }
        _ => {
            return Response::Err {
                code: 400,
                message: "invalid op/payload combination".into(),
            };
        }
    };

    outcome.map_or_else(
        |e| Response::Err {
            code: 500,
            message: e.to_string(),
        },
        Response::Ok,
    )
}
/// Builds a `Collections` response describing every registered handler.
///
/// The order follows the `HashMap`'s iteration order, so it is unspecified. A handler whose
/// `count()` fails is still listed, with a count of 0.
fn list_collections(trees: &TreeMap) -> Response {
    let mut collections = Vec::with_capacity(trees.len());
    for (hashname, handler) in trees.iter() {
        let (typ_hash, version) = handler.info();
        // Best effort: one failing count must not hide the rest of the listing.
        let count = handler.count().unwrap_or(0);
        let name = handler.name().to_string();
        collections.push(CollectionMeta {
            name,
            hashname: *hashname,
            typ_hash,
            version,
            count,
        });
    }
    Response::Ok(ResponsePayload::Collections(collections))
}
async fn accept_tcp(
listener: TcpListener,
trees: TreeMap,
mut shutdown_rx: async_broadcast::Receiver<()>,
) {
loop {
let accept_fut = listener.accept();
let stop_fut = shutdown_rx.recv();
pin_mut!(accept_fut);
pin_mut!(stop_fut);
match select(accept_fut, stop_fut).await {
Either::Left((Ok((stream, _addr)), _)) => {
let trees = trees.clone();
let (reader, writer) = stream.into_split();
let rx = shutdown_rx.clone();
compio::runtime::spawn(handle_connection(reader, writer, trees, rx)).detach();
}
Either::Left((Err(e), _)) => {
tracing::error!("tcp accept error: {e}");
}
Either::Right(_) => {
tracing::info!("shutting down TCP server");
break;
}
}
}
}
async fn accept_uds(
listener: UnixListener,
trees: TreeMap,
mut shutdown_rx: async_broadcast::Receiver<()>,
) {
loop {
let accept_fut = listener.accept();
let stop_fut = shutdown_rx.recv();
pin_mut!(accept_fut);
pin_mut!(stop_fut);
match select(accept_fut, stop_fut).await {
Either::Left((Ok((stream, _addr)), _)) => {
let trees = trees.clone();
let (reader, writer) = stream.into_split();
let rx = shutdown_rx.clone();
compio::runtime::spawn(handle_connection(reader, writer, trees, rx)).detach();
}
Either::Left((Err(e), _)) => {
tracing::error!("uds accept error: {e}");
}
Either::Right(_) => {
tracing::info!("shutting down UDS server");
break;
}
}
}
}
/// Handle to one background RPC listener thread.
///
/// Stopping it, explicitly via `stop` or implicitly on drop, broadcasts on `shutdown_tx` and
/// then joins the thread.
pub struct RpcHandle {
    // Listener thread; `None` once it has been joined.
    thread: Option<JoinHandle<()>>,
    // Sender that `stop` broadcasts on before joining the thread.
    shutdown_tx: async_broadcast::Sender<()>,
}
impl RpcHandle {
    /// Signals the listener to stop and blocks until its thread has exited.
    ///
    /// Safe to call more than once: the thread is joined at most once, and later calls only
    /// repeat the (best-effort) broadcast. Broadcast failures (full or closed channel) and a
    /// panicked listener thread are both ignored.
    pub fn stop(&mut self) {
        let _ = self.shutdown_tx.try_broadcast(());
        let Some(thread) = self.thread.take() else {
            return;
        };
        let _ = thread.join();
    }
}
impl Drop for RpcHandle {
    /// Stops the listener via [`RpcHandle::stop`], blocking until its thread has exited.
    fn drop(&mut self) {
        Self::stop(self);
    }
}
impl Db {
pub fn listen_tcp(&self, addr: impl ToSocketAddrsAsync + Send + 'static) {
let trees = self.build_tree_map();
let shutdown_rx = self.shutdown.subscribe_broadcast();
let (local_tx, _) = async_broadcast::broadcast::<()>(1);
let local_tx2 = local_tx.clone();
let thread = std::thread::spawn(move || {
RuntimeBuilder::new()
.build()
.expect("Failed to build compio runtime")
.block_on(async {
let listener = TcpListener::bind(addr)
.await
.expect("Failed to bind TCP listener");
accept_tcp(listener, trees, shutdown_rx).await;
});
});
self.rpc_handles.lock().push(RpcHandle {
thread: Some(thread),
shutdown_tx: local_tx2,
});
}
pub fn listen_uds(&self, path: impl AsRef<Path> + Send + 'static) {
let trees = self.build_tree_map();
let shutdown_rx = self.shutdown.subscribe_broadcast();
let (local_tx, _) = async_broadcast::broadcast::<()>(1);
let local_tx2 = local_tx.clone();
let thread = std::thread::spawn(move || {
RuntimeBuilder::new()
.build()
.expect("Failed to build compio runtime")
.block_on(async {
let _ = std::fs::remove_file(path.as_ref());
let listener = UnixListener::bind(path)
.await
.expect("Failed to bind UDS listener");
accept_uds(listener, trees, shutdown_rx).await;
});
});
self.rpc_handles.lock().push(RpcHandle {
thread: Some(thread),
shutdown_tx: local_tx2,
});
}
}