// redis-on-mysql 0.0.1
//
// A Redis-compatible proxy that stores all data and Pub/Sub state in MySQL.
// Documentation
use std::sync::Arc;

use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, PubSubHandle, PushHandle, State, Value};
use sqlx::Row;
use tokio::time::{self, Instant};

use crate::handlers::util::{arg_as_bytes, wrong_arity};
use crate::state::{AppState, Session, SessionHandle};
use crate::storage::map_sql_err;

/// Handles `SUBSCRIBE channel [channel ...]`.
///
/// Persists the subscriptions for this session's subscriber id, records the channels in the
/// session's Pub/Sub state, and spawns the mailbox poller when `try_activate_poller` grants it.
///
/// # Errors
///
/// Returns a wrong-arity error when no channel is given, `RespError::NoAuth` when the session
/// has no auth, and propagates argument-parsing and SQL errors.
pub async fn subscribe(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
    push: PushHandle,
    pubsub: PubSubHandle,
) -> Result<Value, RespError> {
    if cmd.args.is_empty() {
        return Err(wrong_arity("SUBSCRIBE"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);

    // Write the subscriptions to the database before touching in-memory state, so a SQL
    // failure leaves the session state unchanged.
    let channels = parse_channels(&cmd.args)?;
    let subscriber_id = ensure_subscriber_id(session.as_ref(), auth.pool.as_ref()).await?;
    insert_subscriptions(auth.pool.as_ref(), subscriber_id, &channels).await?;

    let (responses, total) = {
        let mut guard = session.pubsub_state().await;
        // Running subscription count; bumped only for channels that were not already present.
        let mut running = guard.channels.len() as i64;
        let replies: Vec<Value> = channels
            .iter()
            .map(|channel| {
                running += i64::from(guard.channels.insert(channel.clone()));
                subscribe_reply("subscribe", channel, running)
            })
            .collect();
        (replies, running as usize)
    };

    pubsub.set(total);
    if total > 0 && session.try_activate_poller() {
        let poller = run_poller(
            Arc::clone(&auth.pool),
            subscriber_id,
            push.clone(),
            pubsub.clone(),
            state.config.pubsub.clone(),
            session.clone(),
        );
        tokio::spawn(poller);
    }

    send_multi_response(responses, push).await
}

/// Handles `UNSUBSCRIBE [channel ...]`.
///
/// With no arguments every channel currently recorded in the session is targeted. The
/// matching subscription rows are deleted (when the session already has a subscriber id),
/// the channels are removed from the session state, and one `unsubscribe` reply is produced
/// per targeted channel. When the resulting count is zero the poller is deactivated.
///
/// # Errors
///
/// Returns `RespError::NoAuth` when the session has no auth, and propagates
/// argument-parsing and SQL errors.
pub async fn unsubscribe(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
    push: PushHandle,
    pubsub: PubSubHandle,
) -> Result<Value, RespError> {
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);

    // Reply shape used when there is no channel name to report.
    let nil_reply = |remaining: i64| {
        Value::Array(vec![
            Value::Bulk(Bytes::from_static(b"unsubscribe")),
            Value::Null,
            Value::Integer(remaining),
        ])
    };

    // Snapshot the target channels and the current subscription count under one lock.
    let (targets, subscribed_before) = {
        let guard = session.pubsub_state().await;
        let targets = if cmd.args.is_empty() {
            guard.channels.iter().cloned().collect::<Vec<_>>()
        } else {
            parse_channels(&cmd.args)?
        };
        (targets, guard.channels.len())
    };

    if targets.is_empty() && subscribed_before == 0 {
        return Ok(nil_reply(0));
    }

    // A session that never subscribed has no subscriber row, hence nothing to delete.
    let known_id = session.pubsub_state().await.subscriber_id;
    if let Some(id) = known_id {
        remove_subscriptions(auth.pool.as_ref(), id, &targets).await?;
    }

    let (responses, total) = {
        let mut guard = session.pubsub_state().await;
        let mut remaining = guard.channels.len() as i64;
        let replies = if targets.is_empty() {
            vec![nil_reply(remaining)]
        } else {
            let mut out = Vec::with_capacity(targets.len());
            for channel in &targets {
                // Only channels that were actually subscribed lower the count.
                if guard.channels.remove(channel) {
                    remaining -= 1;
                }
                out.push(subscribe_reply("unsubscribe", channel, remaining));
            }
            out
        };
        (replies, remaining.max(0) as usize)
    };

    pubsub.set(total);
    if total == 0 {
        session.deactivate_poller();
    }

    send_multi_response(responses, push).await
}

/// Handles `PUBLISH channel payload`.
///
/// Inside one transaction, stores the message and then inserts one mailbox row for every
/// subscription row whose channel matches. Replies with the number of mailbox rows inserted.
///
/// # Errors
///
/// Returns a wrong-arity error unless exactly two arguments are given, `RespError::NoAuth`
/// when the session has no auth, and propagates argument-parsing and SQL errors.
pub async fn publish(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if cmd.args.len() != 2 {
        return Err(wrong_arity("PUBLISH"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let channel = arg_as_bytes(&cmd.args[0])?;
    let payload = arg_as_bytes(&cmd.args[1])?;

    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;

    // Step 1: persist the message itself and capture its generated id.
    let stored = sqlx::query("INSERT INTO redis_pubsub_message (channel, payload) VALUES (?, ?)")
        .bind(channel.as_ref())
        .bind(payload.as_ref())
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?;
    let message_id = stored.last_insert_id() as i64;

    // Step 2: fan the message out to the mailbox of each subscriber of this channel.
    let fan_out = sqlx::query(
        "INSERT INTO redis_pubsub_mailbox (subscriber_id, message_id, channel) \
         SELECT subscriber_id, ?, channel FROM redis_pubsub_subscription WHERE channel = ?",
    )
    .bind(message_id)
    .bind(channel.as_ref())
    .execute(&mut *tx)
    .await
    .map_err(map_sql_err)?;

    tx.commit().await.map_err(map_sql_err)?;

    let delivered = fan_out.rows_affected() as i64;
    Ok(Value::Integer(delivered))
}

/// Converts every command argument into an owned channel name.
///
/// Stops at, and returns, the first error reported by `arg_as_bytes`.
fn parse_channels(args: &[Value]) -> Result<Vec<Bytes>, RespError> {
    let mut channels = Vec::with_capacity(args.len());
    for arg in args {
        channels.push(arg_as_bytes(arg)?.clone());
    }
    Ok(channels)
}

/// Builds the three-element reply `[kind, channel, count]` used for subscribe/unsubscribe
/// confirmations.
///
/// The two well-known kinds reuse static byte strings; any other kind is copied.
fn subscribe_reply(kind: &str, channel: &Bytes, count: i64) -> Value {
    let label = if kind == "subscribe" {
        Bytes::from_static(b"subscribe")
    } else if kind == "unsubscribe" {
        Bytes::from_static(b"unsubscribe")
    } else {
        Bytes::from(kind.as_bytes().to_vec())
    };
    let fields = vec![
        Value::Bulk(label),
        Value::Bulk(channel.clone()),
        Value::Integer(count),
    ];
    Value::Array(fields)
}

/// Returns the first reply as the command result and sends the remaining ones, in order,
/// through `push`.
///
/// An empty `replies` yields `Value::Null`. Failures of `push.send` are ignored.
async fn send_multi_response(
    replies: Vec<Value>,
    push: PushHandle,
) -> Result<Value, RespError> {
    let mut pending = replies.into_iter();
    let first = match pending.next() {
        Some(value) => value,
        None => return Ok(Value::Null),
    };
    for extra in pending {
        // Best effort: a failed push is not reported to the caller.
        let _ = push.send(extra).await;
    }
    Ok(first)
}

/// Returns the session's subscriber id, inserting a new `redis_pubsub_subscriber` row the
/// first time it is needed.
///
/// The session lock is not held during the insert. If another task stored an id in the
/// meantime, that id is returned and the freshly inserted one is left unused.
async fn ensure_subscriber_id(session: &Session, pool: &sqlx::MySqlPool) -> Result<u64, RespError> {
    // Fast path: the id is already cached in the session state.
    let cached = session.pubsub_state().await.subscriber_id;
    if let Some(existing) = cached {
        return Ok(existing);
    }

    let fresh_id = sqlx::query("INSERT INTO redis_pubsub_subscriber () VALUES ()")
        .execute(pool)
        .await
        .map_err(map_sql_err)?
        .last_insert_id();

    // Re-check under the lock; keep whatever id is already there, else store ours.
    let mut guard = session.pubsub_state().await;
    let id = *guard.subscriber_id.get_or_insert(fresh_id);
    Ok(id)
}

/// Inserts one `(subscriber_id, channel)` subscription row per channel in a single
/// `INSERT IGNORE` statement. Does nothing when `channels` is empty.
async fn insert_subscriptions(
    pool: &sqlx::MySqlPool,
    subscriber_id: u64,
    channels: &[Bytes],
) -> Result<(), RespError> {
    if !channels.is_empty() {
        let id = subscriber_id as i64;
        let mut builder = sqlx::QueryBuilder::new(
            "INSERT IGNORE INTO redis_pubsub_subscription (subscriber_id, channel) ",
        );
        builder.push_values(channels, |mut row, channel| {
            row.push_bind(id).push_bind(channel.as_ref());
        });
        builder.build().execute(pool).await.map_err(map_sql_err)?;
    }
    Ok(())
}

/// Deletes the subscription rows of `subscriber_id` for the given channels with a single
/// `DELETE ... IN (...)` statement. Does nothing when `channels` is empty.
async fn remove_subscriptions(
    pool: &sqlx::MySqlPool,
    subscriber_id: u64,
    channels: &[Bytes],
) -> Result<(), RespError> {
    if !channels.is_empty() {
        let mut builder =
            sqlx::QueryBuilder::new("DELETE FROM redis_pubsub_subscription WHERE subscriber_id = ");
        builder
            .push_bind(subscriber_id as i64)
            .push(" AND channel IN (");
        {
            // One bound placeholder per channel, comma separated.
            let mut in_list = builder.separated(", ");
            for channel in channels {
                in_list.push_bind(channel.as_ref());
            }
        }
        builder.push(")");
        builder.build().execute(pool).await.map_err(map_sql_err)?;
    }
    Ok(())
}

/// Background task that delivers queued Pub/Sub messages for one subscriber.
///
/// On every tick of the poll interval the task:
/// 1. stops if `pubsub.count()` is zero,
/// 2. reads up to `cfg.poll_batch` mailbox rows for `subscriber_id`, ordered by message id,
/// 3. pushes each row as a `["message", channel, payload]` array,
/// 4. deletes the mailbox rows that were pushed,
/// 5. updates the subscriber's `last_seen` once `cfg.heartbeat_interval` has elapsed.
///
/// `session.deactivate_poller()` is called on both exit paths: when the count reaches zero
/// and when a push fails. Query, delete and heartbeat failures are ignored and the loop
/// simply tries again on the next tick.
async fn run_poller(
    pool: Arc<sqlx::MySqlPool>,
    subscriber_id: u64,
    push: PushHandle,
    pubsub: PubSubHandle,
    cfg: crate::config::PubSubConfig,
    session: Arc<Session>,
) {
    // `time::interval` panics on a zero period; clamp it so a zero `poll_interval` cannot
    // kill the task before it reaches `deactivate_poller`.
    let period = cfg.poll_interval.max(std::time::Duration::from_millis(1));
    let mut interval = time::interval(period);
    // The default `Burst` behaviour replays every missed tick back-to-back after a slow
    // iteration. For a poller that only means extra queries, so space the ticks out instead.
    interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);

    let mut last_heartbeat = Instant::now();
    loop {
        interval.tick().await;
        if pubsub.count() == 0 {
            break;
        }
        let rows = sqlx::query(
            "SELECT mb.message_id, mb.channel, msg.payload \
             FROM redis_pubsub_mailbox mb \
             JOIN redis_pubsub_message msg ON mb.message_id = msg.id \
             WHERE mb.subscriber_id = ? ORDER BY mb.message_id LIMIT ?",
        )
        .bind(subscriber_id as i64)
        .bind(cfg.poll_batch as i64)
        .fetch_all(pool.as_ref())
        .await;

        // Ids of the messages successfully pushed during this iteration.
        let mut message_ids = Vec::new();
        match rows {
            Ok(rows) => {
                for row in rows {
                    // A column that fails to decode falls back to its default value.
                    let message_id: i64 = row.try_get("message_id").unwrap_or_default();
                    let channel: Vec<u8> = row.try_get("channel").unwrap_or_default();
                    let payload: Vec<u8> = row.try_get("payload").unwrap_or_default();
                    let reply = Value::Array(vec![
                        Value::Bulk(Bytes::from_static(b"message")),
                        Value::Bulk(Bytes::from(channel)),
                        Value::Bulk(Bytes::from(payload)),
                    ]);
                    if push.send(reply).await.is_err() {
                        // The push failed, so stop polling for this session.
                        session.deactivate_poller();
                        return;
                    }
                    message_ids.push(message_id);
                }
            }
            Err(_) => {
                // Failed query: skip this iteration (including the heartbeat) and retry on
                // the next tick.
                continue;
            }
        }

        if !message_ids.is_empty() {
            let mut qb =
                sqlx::QueryBuilder::new("DELETE FROM redis_pubsub_mailbox WHERE subscriber_id = ");
            qb.push_bind(subscriber_id as i64);
            qb.push(" AND message_id IN (");
            let mut separated = qb.separated(", ");
            for id in &message_ids {
                separated.push_bind(id);
            }
            qb.push(")");
            // Best effort: if the delete fails the rows stay in the mailbox.
            let _ = qb.build().execute(pool.as_ref()).await;
        }

        if last_heartbeat.elapsed() >= cfg.heartbeat_interval {
            let _ = sqlx::query(
                "UPDATE redis_pubsub_subscriber SET last_seen = CURRENT_TIMESTAMP(3) WHERE id = ?",
            )
            .bind(subscriber_id as i64)
            .execute(pool.as_ref())
            .await;
            last_heartbeat = Instant::now();
        }
    }

    session.deactivate_poller();
}