armour 0.30.27

DDL and serialization for key-value storage
Documentation
use std::{
    collections::HashMap,
    sync::{Arc, Mutex, OnceLock},
};

use compio::{
    io::{
        AsyncRead, AsyncWrite,
        framed::{Framed, frame::LengthDelimited},
    },
    net::{TcpListener, UnixListener},
};
use futures_util::{
    SinkExt, StreamExt,
    future::{Either, select},
    pin_mut,
};

use super::{codec::RpcCodec, handler::RpcHandler, protocol::*};

static SHUTDOWN_SENDERS: OnceLock<Mutex<Vec<async_broadcast::Sender<()>>>> = OnceLock::new();

fn register_shutdown(tx: async_broadcast::Sender<()>) {
    let mutex = SHUTDOWN_SENDERS.get_or_init(|| {
        ctrlc::set_handler(|| {
            if let Some(mutex) = SHUTDOWN_SENDERS.get()
                && let Ok(senders) = mutex.lock()
            {
                for tx in senders.iter() {
                    let _ = tx.try_broadcast(());
                }
            }
        })
        .ok();
        Mutex::new(Vec::new())
    });

    if let Ok(mut senders) = mutex.lock() {
        senders.push(tx);
    }
}

/// Shared, immutable map from a tree's `hashname` (as carried in each
/// request) to the handler that serves it. It is cloned cheaply into every
/// connection task.
pub(crate) type TreeMap = Arc<HashMap<u64, Arc<dyn RpcHandler>>>;

/// Serves a single RPC connection until it ends.
///
/// Requests are read as length-delimited frames decoded by `RpcCodec`. Each
/// one is answered with the `Response` produced by `dispatch`. A frame that
/// fails to decode is answered with a `500` error carrying the decoder's
/// message, and the loop keeps going.
///
/// The loop returns when any of these happens:
/// - the peer closes the stream;
/// - writing a response fails;
/// - `stop_rx` yields (a shutdown broadcast or a channel error).
pub(crate) async fn handle_connection<R, W>(
    reader: R,
    writer: W,
    trees: TreeMap,
    mut stop_rx: async_broadcast::Receiver<()>,
) where
    R: AsyncRead + Unpin + 'static,
    W: AsyncWrite + Unpin + 'static,
{
    let mut conn = Framed::new::<Response, Request>(RpcCodec, LengthDelimited::new())
        .with_reader(reader)
        .with_writer(writer);

    loop {
        let incoming = conn.next();
        let shutdown = stop_rx.recv();
        pin_mut!(incoming, shutdown);

        // Race the next inbound frame against the shutdown signal.
        let decoded = match select(incoming, shutdown).await {
            Either::Left((Some(frame), _)) => frame,
            // Peer hung up, or shutdown was signalled.
            Either::Left((None, _)) | Either::Right(_) => return,
        };

        let reply = decoded.map_or_else(
            |e| Response::Err {
                code: 500,
                message: e.to_string(),
            },
            |request| dispatch(&trees, request),
        );

        if conn.send(reply).await.is_err() {
            return;
        }
    }
}

/// Routes one decoded `Request` to the tree handler it names.
///
/// - `OpCode::ListCollections` is answered directly from `trees`.
/// - An unknown `hashname` yields a `404` error.
/// - An op whose payload variant does not match it yields a `400` error.
/// - A handler error is reported with the code from `to_resp()` and the
///   error's `Display` text.
fn dispatch(trees: &TreeMap, request: Request) -> Response {
    // Collection listing covers every tree, so it is answered before any
    // per-tree lookup.
    if request.op == OpCode::ListCollections {
        return super::collections::list_collections(trees);
    }

    let Some(handler) = trees.get(&request.hashname) else {
        return Response::Err {
            code: 404,
            message: format!("tree not found: {:#X}", request.hashname),
        };
    };

    // Each arm yields the handler's result mapped onto its payload variant.
    let outcome = match (request.op, request.payload) {
        (OpCode::Get, RequestPayload::Key(key)) => {
            handler.get(&key).map(ResponsePayload::OptionalData)
        }
        (OpCode::Contains, RequestPayload::Key(key)) => {
            handler.contains(&key).map(ResponsePayload::OptionalLen)
        }
        (OpCode::First, RequestPayload::Empty) => handler.first().map(ResponsePayload::OptionalKV),
        (OpCode::Last, RequestPayload::Empty) => handler.last().map(ResponsePayload::OptionalKV),
        (OpCode::Range, RequestPayload::Range { start, end }) => {
            handler.range(start, end).map(ResponsePayload::KeyValues)
        }
        (OpCode::RangeKeys, RequestPayload::Range { start, end }) => {
            handler.range_keys(start, end).map(ResponsePayload::Keys)
        }
        (OpCode::Count, RequestPayload::Count { exact }) => {
            handler.count(exact).map(ResponsePayload::Count)
        }
        (OpCode::Upsert, RequestPayload::Upsert { key, flag, value }) => {
            handler.upsert(key, flag, value).map(ResponsePayload::Key)
        }
        (OpCode::Remove, RequestPayload::Remove { key, soft }) => {
            handler.remove(&key, soft).map(|_| ResponsePayload::Empty)
        }
        (OpCode::Take, RequestPayload::Take { key, soft }) => {
            handler.take(&key, soft).map(ResponsePayload::OptionalData)
        }
        (OpCode::ApplyBatch, RequestPayload::Batch(items)) => {
            handler.apply_batch(items).map(|_| ResponsePayload::Empty)
        }
        _ => {
            return Response::Err {
                code: 400,
                message: "invalid op/payload combination".into(),
            };
        }
    };

    outcome.map(Response::Ok).unwrap_or_else(|e| {
        // NOTE(review): only the status code from `to_resp()` is used; the
        // message comes from `Display` — confirm the second element is meant
        // to be discarded.
        let (code, _) = e.to_resp();
        Response::Err {
            code,
            message: e.to_string(),
        }
    })
}

pub(crate) async fn accept_tcp(listener: TcpListener, trees: TreeMap) {
    let (shutdown_tx, mut shutdown_rx) = async_broadcast::broadcast::<()>(1);
    register_shutdown(shutdown_tx);

    loop {
        let accept_fut = listener.accept();
        let stop_fut = shutdown_rx.recv();
        pin_mut!(accept_fut);
        pin_mut!(stop_fut);

        match select(accept_fut, stop_fut).await {
            Either::Left((Ok((stream, _addr)), _)) => {
                let trees = trees.clone();
                let (reader, writer) = stream.into_split();
                let rx = shutdown_rx.clone();
                compio::runtime::spawn(handle_connection(reader, writer, trees, rx)).detach();
            }
            Either::Left((Err(e), _)) => {
                error!("tcp accept error: {e}");
            }
            Either::Right(_) => {
                info!("Received Ctrl-C, shutting down TCP server...");
                break;
            }
        }
    }
}

pub(crate) async fn accept_uds(listener: UnixListener, trees: TreeMap) {
    let (shutdown_tx, mut shutdown_rx) = async_broadcast::broadcast::<()>(1);
    register_shutdown(shutdown_tx);

    loop {
        let accept_fut = listener.accept();
        let stop_fut = shutdown_rx.recv();
        pin_mut!(accept_fut);
        pin_mut!(stop_fut);

        match select(accept_fut, stop_fut).await {
            Either::Left((Ok((stream, _addr)), _)) => {
                let trees = trees.clone();
                let (reader, writer) = stream.into_split();
                let rx = shutdown_rx.clone();
                compio::runtime::spawn(handle_connection(reader, writer, trees, rx)).detach();
            }
            Either::Left((Err(e), _)) => {
                error!("uds accept error: {e}");
            }
            Either::Right(_) => {
                info!("Received Ctrl-C, shutting down UDS server...");
                break;
            }
        }
    }
}