shardmap 0.2.1

Sharded embedded in-memory map with optional cache, protocol, and server internals
Documentation
#[cfg(feature = "server")]
use bytes::BytesMut;

use crate::commands::redis::{array_bulk, bulk, error, int, wrong_arity};
#[cfg(feature = "server")]
use crate::commands::redis::{write_resp_array_header, write_resp_null, write_resp_wrong_arity};
use crate::protocol::Frame;
#[cfg(feature = "server")]
use crate::server::wire::ServerWire;
use crate::storage::EmbeddedStore;

/// Declares a zero-sized pub/sub command marker, a shared static instance of
/// it, and its `CommandSpec` metadata.
///
/// Arguments, in order: marker type name, static instance name, command name
/// (`CommandSpec::NAME`), and the `CommandSpec::MUTATES_VALUE` flag.
macro_rules! define_pubsub_command {
    ($marker:ident, $instance:ident, $command:literal, $writes:expr) => {
        #[derive(Debug, Clone, Copy)]
        pub(crate) struct $marker;

        impl crate::commands::CommandSpec for $marker {
            const NAME: &'static str = $command;
            const MUTATES_VALUE: bool = $writes;
        }

        pub(crate) static $instance: $marker = $marker;
    };
}

// None of the pub/sub commands is declared as mutating stored values.

// Publishing.
define_pubsub_command! { Publish, PUBLISH_COMMAND, "PUBLISH", false }
define_pubsub_command! { SPublish, SPUBLISH_COMMAND, "SPUBLISH", false }

// Introspection.
define_pubsub_command! { PubSub, PUBSUB_COMMAND, "PUBSUB", false }

// Channel, pattern, and shard-channel (un)subscription.
define_pubsub_command! { Subscribe, SUBSCRIBE_COMMAND, "SUBSCRIBE", false }
define_pubsub_command! { Unsubscribe, UNSUBSCRIBE_COMMAND, "UNSUBSCRIBE", false }
define_pubsub_command! { PSubscribe, PSUBSCRIBE_COMMAND, "PSUBSCRIBE", false }
define_pubsub_command! { PUnsubscribe, PUNSUBSCRIBE_COMMAND, "PUNSUBSCRIBE", false }
define_pubsub_command! { SSubscribe, SSUBSCRIBE_COMMAND, "SSUBSCRIBE", false }
define_pubsub_command! { SUnsubscribe, SUNSUBSCRIBE_COMMAND, "SUNSUBSCRIBE", false }

impl crate::commands::redis::RedisCommand for Publish {
    /// `PUBLISH channel message`: replies with the integer 0 (no subscriber
    /// lookup happens here). Any other argument count is an arity error.
    fn execute(_store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
        if args.len() == 2 {
            int(0)
        } else {
            wrong_arity("PUBLISH")
        }
    }

    /// RESP-encoded counterpart of `execute`.
    #[cfg(feature = "server")]
    fn write_resp(_store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
        if args.len() != 2 {
            write_resp_wrong_arity(out, "PUBLISH");
            return;
        }
        ServerWire::write_resp_integer(out, 0);
    }
}

impl crate::commands::redis::RedisCommand for SPublish {
    /// `SPUBLISH shardchannel message`: replies with the integer 0 (no
    /// subscriber lookup happens here). Any other argument count is an arity
    /// error.
    fn execute(_store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
        match args.len() {
            2 => int(0),
            _ => wrong_arity("SPUBLISH"),
        }
    }

    /// RESP-encoded counterpart of `execute`.
    #[cfg(feature = "server")]
    fn write_resp(_store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
        match args.len() {
            2 => ServerWire::write_resp_integer(out, 0),
            _ => write_resp_wrong_arity(out, "SPUBLISH"),
        }
    }
}

impl crate::commands::redis::RedisCommand for PubSub {
    fn execute(_store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
        match args {
            [] => wrong_arity("PUBSUB"),
            [sub] if sub.eq_ignore_ascii_case(b"CHANNELS") => Frame::Array(Vec::new()),
            [sub, _pattern] if sub.eq_ignore_ascii_case(b"CHANNELS") => Frame::Array(Vec::new()),
            [sub] if sub.eq_ignore_ascii_case(b"SHARDCHANNELS") => Frame::Array(Vec::new()),
            [sub, _pattern] if sub.eq_ignore_ascii_case(b"SHARDCHANNELS") => {
                Frame::Array(Vec::new())
            }
            [sub] if sub.eq_ignore_ascii_case(b"NUMPAT") => int(0),
            [sub] if sub.eq_ignore_ascii_case(b"SHARDNUMSUB") => Frame::Array(Vec::new()),
            [sub, channels @ ..] if sub.eq_ignore_ascii_case(b"NUMSUB") => {
                let mut items = Vec::with_capacity(channels.len().saturating_mul(2));
                for channel in channels {
                    items.push(bulk(channel.to_vec()));
                    items.push(int(0));
                }
                Frame::Array(items)
            }
            [sub, channels @ ..] if sub.eq_ignore_ascii_case(b"SHARDNUMSUB") => {
                let mut items = Vec::with_capacity(channels.len().saturating_mul(2));
                for channel in channels {
                    items.push(bulk(channel.to_vec()));
                    items.push(int(0));
                }
                Frame::Array(items)
            }
            [sub] if sub.eq_ignore_ascii_case(b"HELP") => array_bulk(vec![
                b"PUBSUB CHANNELS [pattern]".to_vec(),
                b"PUBSUB NUMSUB [channel ...]".to_vec(),
                b"PUBSUB NUMPAT".to_vec(),
                b"PUBSUB SHARDCHANNELS [pattern]".to_vec(),
                b"PUBSUB SHARDNUMSUB [channel ...]".to_vec(),
            ]),
            _ => error("ERR unknown PUBSUB subcommand or wrong number of arguments"),
        }
    }

    #[cfg(feature = "server")]
    fn write_resp(_store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
        match args {
            [] => write_resp_wrong_arity(out, "PUBSUB"),
            [sub] if sub.eq_ignore_ascii_case(b"CHANNELS") => write_resp_array_header(out, 0),
            [sub, _pattern] if sub.eq_ignore_ascii_case(b"CHANNELS") => {
                write_resp_array_header(out, 0);
            }
            [sub] if sub.eq_ignore_ascii_case(b"SHARDCHANNELS") => {
                write_resp_array_header(out, 0);
            }
            [sub, _pattern] if sub.eq_ignore_ascii_case(b"SHARDCHANNELS") => {
                write_resp_array_header(out, 0);
            }
            [sub] if sub.eq_ignore_ascii_case(b"NUMPAT") => {
                ServerWire::write_resp_integer(out, 0);
            }
            [sub] if sub.eq_ignore_ascii_case(b"SHARDNUMSUB") => {
                write_resp_array_header(out, 0);
            }
            [sub, channels @ ..] if sub.eq_ignore_ascii_case(b"NUMSUB") => {
                write_resp_array_header(out, channels.len().saturating_mul(2));
                for channel in channels {
                    ServerWire::write_resp_blob_string(out, channel);
                    ServerWire::write_resp_integer(out, 0);
                }
            }
            [sub, channels @ ..] if sub.eq_ignore_ascii_case(b"SHARDNUMSUB") => {
                write_resp_array_header(out, channels.len().saturating_mul(2));
                for channel in channels {
                    ServerWire::write_resp_blob_string(out, channel);
                    ServerWire::write_resp_integer(out, 0);
                }
            }
            [sub] if sub.eq_ignore_ascii_case(b"HELP") => {
                write_resp_array_header(out, 5);
                ServerWire::write_resp_blob_string(out, b"PUBSUB CHANNELS [pattern]");
                ServerWire::write_resp_blob_string(out, b"PUBSUB NUMSUB [channel ...]");
                ServerWire::write_resp_blob_string(out, b"PUBSUB NUMPAT");
                ServerWire::write_resp_blob_string(out, b"PUBSUB SHARDCHANNELS [pattern]");
                ServerWire::write_resp_blob_string(out, b"PUBSUB SHARDNUMSUB [channel ...]");
            }
            _ => ServerWire::write_resp_error(
                out,
                "ERR unknown PUBSUB subcommand or wrong number of arguments",
            ),
        }
    }
}

impl crate::commands::redis::RedisCommand for Subscribe {
    /// `SUBSCRIBE channel [channel ...]`, acknowledged via `subscription_ack`
    /// with kind `subscribe`; an empty channel list is an arity error.
    fn execute(_store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
        let require_channel = true;
        subscription_ack("subscribe", args, require_channel)
    }

    /// RESP-encoded counterpart of `execute`.
    #[cfg(feature = "server")]
    fn write_resp(_store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
        let require_channel = true;
        write_subscription_ack_resp(out, "subscribe".as_bytes(), args, require_channel);
    }
}

impl crate::commands::redis::RedisCommand for PSubscribe {
    /// `PSUBSCRIBE pattern [pattern ...]`, acknowledged via `subscription_ack`
    /// with kind `psubscribe`; an empty pattern list is an arity error.
    fn execute(_store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
        let require_channel = true;
        subscription_ack("psubscribe", args, require_channel)
    }

    /// RESP-encoded counterpart of `execute`.
    #[cfg(feature = "server")]
    fn write_resp(_store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
        let require_channel = true;
        write_subscription_ack_resp(out, "psubscribe".as_bytes(), args, require_channel);
    }
}

impl crate::commands::redis::RedisCommand for Unsubscribe {
    /// `UNSUBSCRIBE [channel ...]`, acknowledged via `subscription_ack` with
    /// kind `unsubscribe`; an empty list is accepted and reported with a null
    /// channel and a zero count.
    fn execute(_store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
        let require_channel = false;
        subscription_ack("unsubscribe", args, require_channel)
    }

    /// RESP-encoded counterpart of `execute`.
    #[cfg(feature = "server")]
    fn write_resp(_store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
        let require_channel = false;
        write_subscription_ack_resp(out, "unsubscribe".as_bytes(), args, require_channel);
    }
}

impl crate::commands::redis::RedisCommand for PUnsubscribe {
    /// `PUNSUBSCRIBE [pattern ...]`, acknowledged via `subscription_ack` with
    /// kind `punsubscribe`; an empty list is accepted and reported with a null
    /// channel and a zero count.
    fn execute(_store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
        let require_channel = false;
        subscription_ack("punsubscribe", args, require_channel)
    }

    /// RESP-encoded counterpart of `execute`.
    #[cfg(feature = "server")]
    fn write_resp(_store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
        let require_channel = false;
        write_subscription_ack_resp(out, "punsubscribe".as_bytes(), args, require_channel);
    }
}

impl crate::commands::redis::RedisCommand for SSubscribe {
    /// `SSUBSCRIBE shardchannel [shardchannel ...]`, acknowledged via
    /// `subscription_ack` with kind `ssubscribe`; an empty list is an arity
    /// error.
    fn execute(_store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
        let require_channel = true;
        subscription_ack("ssubscribe", args, require_channel)
    }

    /// RESP-encoded counterpart of `execute`.
    #[cfg(feature = "server")]
    fn write_resp(_store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
        let require_channel = true;
        write_subscription_ack_resp(out, "ssubscribe".as_bytes(), args, require_channel);
    }
}

impl crate::commands::redis::RedisCommand for SUnsubscribe {
    /// `SUNSUBSCRIBE [shardchannel ...]`, acknowledged via `subscription_ack`
    /// with kind `sunsubscribe`; an empty list is accepted and reported with a
    /// null channel and a zero count.
    fn execute(_store: &EmbeddedStore, args: &[&[u8]]) -> Frame {
        let require_channel = false;
        subscription_ack("sunsubscribe", args, require_channel)
    }

    /// RESP-encoded counterpart of `execute`.
    #[cfg(feature = "server")]
    fn write_resp(_store: &EmbeddedStore, args: &[&[u8]], out: &mut BytesMut) {
        let require_channel = false;
        write_subscription_ack_resp(out, "sunsubscribe".as_bytes(), args, require_channel);
    }
}

/// Builds the acknowledgement `Frame` for a (un)subscribe-style command.
///
/// * `kind` - reply label, also used as the command name in arity errors
///   (e.g. `"subscribe"`).
/// * `args` - the channels/patterns named by the client.
/// * `require_channel` - `true` for subscribe-style commands, which reject an
///   empty argument list and report growing counts; `false` for
///   unsubscribe-style commands, which accept an empty list (replying with a
///   null channel) and always report a count of zero.
///
/// Each acknowledgement is a `[kind, channel, count]` array. With more than one
/// channel the acknowledgements are wrapped in one outer array, because only a
/// single `Frame` can be returned.
fn subscription_ack(kind: &str, args: &[&[u8]], require_channel: bool) -> Frame {
    let ack = |channel: Option<&[u8]>, count: i64| {
        Frame::Array(vec![
            bulk(kind.as_bytes().to_vec()),
            channel.map_or(Frame::Null, |channel| bulk(channel.to_vec())),
            int(count),
        ])
    };

    match args {
        [] if require_channel => wrong_arity(kind),
        [] => ack(None, 0),
        [channel] => ack(Some(*channel), if require_channel { 1 } else { 0 }),
        channels => Frame::Array(
            subscription_counts(channels, require_channel)
                .map(|(channel, count)| ack(Some(channel), count))
                .collect(),
        ),
    }
}

/// Pairs each channel with the subscription count to report after it.
///
/// For subscribe-style commands (`require_channel == true`) the count is the
/// number of *distinct* channels seen so far in this command. This matches
/// Redis, which does not bump the count when a channel is repeated
/// (`SUBSCRIBE a a` acknowledges with counts 1, 1 rather than 1, 2).
/// Unsubscribe-style commands always report 0.
fn subscription_counts<'a>(
    channels: &'a [&'a [u8]],
    require_channel: bool,
) -> impl Iterator<Item = (&'a [u8], i64)> + 'a {
    let mut seen = std::collections::HashSet::new();
    channels.iter().map(move |&channel| {
        let count = if require_channel {
            seen.insert(channel);
            // The set is bounded by the slice length, which always fits in i64.
            seen.len() as i64
        } else {
            0
        };
        (channel, count)
    })
}

/// RESP-encoding counterpart of `subscription_ack`: writes the same
/// acknowledgements (same counts, same error cases) straight into `out`.
///
/// `kind` labels each acknowledgement and names the command in arity errors.
/// Every caller in this file passes an ASCII literal such as `b"subscribe"`.
///
/// NOTE(review): with several channels this writes one outer array wrapping the
/// per-channel acknowledgements, mirroring `subscription_ack`. A stock Redis
/// server sends each acknowledgement as its own top-level reply. Confirm that
/// the wrapping is what connected clients expect.
#[cfg(feature = "server")]
fn write_subscription_ack_resp(
    out: &mut BytesMut,
    kind: &'static [u8],
    args: &[&[u8]],
    require_channel: bool,
) {
    match args {
        [] if require_channel => {
            // `kind` is an ASCII literal at every call site; the fallback only
            // guards against a non-UTF-8 label.
            let command = std::str::from_utf8(kind).unwrap_or("subscribe");
            write_resp_wrong_arity(out, command);
        }
        [] => write_subscription_ack_item_resp(out, kind, None, 0),
        [channel] => write_subscription_ack_item_resp(
            out,
            kind,
            Some(channel),
            if require_channel { 1 } else { 0 },
        ),
        channels => {
            write_resp_array_header(out, channels.len());
            for (channel, count) in subscription_counts(channels, require_channel) {
                write_subscription_ack_item_resp(out, kind, Some(channel), count);
            }
        }
    }
}

/// Writes one `[kind, channel, count]` acknowledgement array. A missing
/// channel is encoded with `write_resp_null`.
#[cfg(feature = "server")]
fn write_subscription_ack_item_resp(
    out: &mut BytesMut,
    kind: &[u8],
    channel: Option<&[u8]>,
    count: i64,
) {
    write_resp_array_header(out, 3);
    ServerWire::write_resp_blob_string(out, kind);
    match channel {
        Some(channel) => ServerWire::write_resp_blob_string(out, channel),
        None => write_resp_null(out),
    }
    ServerWire::write_resp_integer(out, count);
}