use std::sync::Arc;
use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, PubSubHandle, PushHandle, State, Value};
use sqlx::Row;
use tokio::time::{self, Instant};
use crate::handlers::util::{arg_as_bytes, wrong_arity};
use crate::state::{AppState, Session, SessionHandle};
use crate::storage::map_sql_err;
/// Handles `SUBSCRIBE channel [channel ...]`.
///
/// Persists the subscriptions under this session's subscriber id, records the
/// channels in the session's local pub/sub state, publishes the resulting
/// subscription count through `pubsub`, and starts the mailbox poller when
/// `session.try_activate_poller()` grants it. One `["subscribe", channel, count]`
/// confirmation is produced per argument: the first is returned as the command
/// reply and the rest are sent through `push`.
///
/// # Errors
/// Wrong arity when no channel is given, `RespError::NoAuth` when the session
/// is not authenticated, plus any argument or SQL error raised by the helpers.
pub async fn subscribe(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
    push: PushHandle,
    pubsub: PubSubHandle,
) -> Result<Value, RespError> {
    if cmd.args.is_empty() {
        return Err(wrong_arity("SUBSCRIBE"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);

    let channels = parse_channels(&cmd.args)?;
    let subscriber_id = ensure_subscriber_id(session.as_ref(), auth.pool.as_ref()).await?;
    insert_subscriptions(auth.pool.as_ref(), subscriber_id, &channels).await?;

    let mut confirmations = Vec::with_capacity(channels.len());
    let subscribed = {
        let mut local = session.pubsub_state().await;
        let mut running = local.channels.len() as i64;
        for channel in &channels {
            // Only a channel that was not already held raises the count.
            running += i64::from(local.channels.insert(channel.clone()));
            confirmations.push(subscribe_reply("subscribe", channel, running));
        }
        running as usize
    };
    pubsub.set(subscribed);

    if subscribed > 0 && session.try_activate_poller() {
        let poller = run_poller(
            Arc::clone(&auth.pool),
            subscriber_id,
            push.clone(),
            pubsub.clone(),
            state.config.pubsub.clone(),
            session.clone(),
        );
        tokio::spawn(poller);
    }
    send_multi_response(confirmations, push).await
}
/// Handles `UNSUBSCRIBE [channel ...]`.
///
/// With no arguments, every channel currently held in the session's pub/sub
/// state is targeted. If nothing is targeted and nothing is held, a single
/// `["unsubscribe", nil, 0]` reply is returned immediately. Otherwise the
/// matching subscription rows are deleted (when the session already has a
/// subscriber id), the channels are dropped from local state, and one
/// `["unsubscribe", channel, count]` confirmation is produced per target. The
/// remaining count is published through `pubsub`, and the poller is
/// deactivated once it reaches zero.
///
/// # Errors
/// `RespError::NoAuth` when the session is not authenticated, plus any
/// argument or SQL error raised by the helpers.
pub async fn unsubscribe(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
    push: PushHandle,
    pubsub: PubSubHandle,
) -> Result<Value, RespError> {
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);

    // Reply shape used when no specific channel is reported.
    let channelless_reply = |remaining: i64| {
        Value::Array(vec![
            Value::Bulk(Bytes::from_static(b"unsubscribe")),
            Value::Null,
            Value::Integer(remaining),
        ])
    };

    let (targets, held) = {
        let local = session.pubsub_state().await;
        let targets = if cmd.args.is_empty() {
            // No arguments: unsubscribe from everything currently held.
            local.channels.iter().cloned().collect::<Vec<_>>()
        } else {
            parse_channels(&cmd.args)?
        };
        (targets, local.channels.len())
    };
    if targets.is_empty() && held == 0 {
        return Ok(channelless_reply(0));
    }

    let known_id = session.pubsub_state().await.subscriber_id;
    if let Some(subscriber_id) = known_id {
        remove_subscriptions(auth.pool.as_ref(), subscriber_id, &targets).await?;
    }

    let (replies, remaining) = {
        let mut local = session.pubsub_state().await;
        let mut left = local.channels.len() as i64;
        let mut replies = Vec::new();
        if targets.is_empty() {
            replies.push(channelless_reply(left));
        }
        for channel in &targets {
            if local.channels.remove(channel) {
                left -= 1;
            }
            replies.push(subscribe_reply("unsubscribe", channel, left));
        }
        (replies, left.max(0) as usize)
    };

    pubsub.set(remaining);
    if remaining == 0 {
        session.deactivate_poller();
    }
    send_multi_response(replies, push).await
}
/// Handles `PUBLISH channel message`.
///
/// Within a single transaction, the payload is stored in `redis_pubsub_message`
/// and then fanned out by inserting one `redis_pubsub_mailbox` row for every
/// `redis_pubsub_subscription` row on `channel`. The reply is the number of
/// mailbox rows created.
///
/// # Errors
/// Wrong arity unless exactly two arguments are given, `RespError::NoAuth`
/// when the session is not authenticated, plus any argument or SQL error.
pub async fn publish(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let args: &[Value] = &cmd.args;
    let (channel_arg, payload_arg) = match args {
        [channel_arg, payload_arg] => (channel_arg, payload_arg),
        _ => return Err(wrong_arity("PUBLISH")),
    };
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let channel = arg_as_bytes(channel_arg)?;
    let payload = arg_as_bytes(payload_arg)?;

    let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
    let stored = sqlx::query("INSERT INTO redis_pubsub_message (channel, payload) VALUES (?, ?)")
        .bind(channel.as_ref())
        .bind(payload.as_ref())
        .execute(&mut *tx)
        .await
        .map_err(map_sql_err)?;
    let message_id = stored.last_insert_id() as i64;

    let fanned_out = sqlx::query(
        "INSERT INTO redis_pubsub_mailbox (subscriber_id, message_id, channel) \
         SELECT subscriber_id, ?, channel FROM redis_pubsub_subscription WHERE channel = ?",
    )
    .bind(message_id)
    .bind(channel.as_ref())
    .execute(&mut *tx)
    .await
    .map_err(map_sql_err)?;
    tx.commit().await.map_err(map_sql_err)?;

    let receivers = fanned_out.rows_affected() as i64;
    Ok(Value::Integer(receivers))
}
/// Converts every command argument into an owned channel name.
///
/// Stops at the first argument `arg_as_bytes` rejects and returns that error.
fn parse_channels(args: &[Value]) -> Result<Vec<Bytes>, RespError> {
    let mut names = Vec::with_capacity(args.len());
    for arg in args {
        names.push(arg_as_bytes(arg)?.clone());
    }
    Ok(names)
}
/// Builds a `[kind, channel, count]` confirmation frame.
///
/// `"subscribe"` and `"unsubscribe"` map to static byte strings. Any other
/// `kind` is copied into a fresh buffer.
fn subscribe_reply(kind: &str, channel: &Bytes, count: i64) -> Value {
    let label = if kind == "subscribe" {
        Bytes::from_static(b"subscribe")
    } else if kind == "unsubscribe" {
        Bytes::from_static(b"unsubscribe")
    } else {
        Bytes::copy_from_slice(kind.as_bytes())
    };
    let frame = vec![
        Value::Bulk(label),
        Value::Bulk(channel.clone()),
        Value::Integer(count),
    ];
    Value::Array(frame)
}
/// Splits a batch of replies into a direct reply plus pushes.
///
/// The first reply is returned to the caller. Every remaining reply is sent
/// through `push` in order, and send failures are ignored. An empty batch
/// yields `Value::Null`.
/// NOTE(review): the pushed frames are sent before the caller emits the
/// returned reply. Whether the client observes them in order depends on how
/// `PushHandle` is flushed relative to command replies — confirm.
async fn send_multi_response(
    replies: Vec<Value>,
    push: PushHandle,
) -> Result<Value, RespError> {
    let mut pending = replies.into_iter();
    let first = match pending.next() {
        Some(reply) => reply,
        None => return Ok(Value::Null),
    };
    for extra in pending {
        let _ = push.send(extra).await;
    }
    Ok(first)
}
/// Returns this session's pub/sub subscriber id, creating a
/// `redis_pubsub_subscriber` row on first use.
///
/// The session's pub/sub lock is not held across the INSERT. If another task
/// stores an id while the INSERT is in flight, that stored id wins. The row
/// created here is then deleted (best-effort) so it does not linger as an
/// orphan with no owner.
///
/// # Errors
/// Returns the mapped SQL error if the INSERT fails.
async fn ensure_subscriber_id(session: &Session, pool: &sqlx::MySqlPool) -> Result<u64, RespError> {
    if let Some(id) = session.pubsub_state().await.subscriber_id {
        return Ok(id);
    }
    let created = sqlx::query("INSERT INTO redis_pubsub_subscriber () VALUES ()")
        .execute(pool)
        .await
        .map_err(map_sql_err)?
        .last_insert_id();
    let winner = {
        let mut state_guard = session.pubsub_state().await;
        *state_guard.subscriber_id.get_or_insert(created)
    };
    if winner != created {
        // Lost the race: nothing references `created`, so remove it. A failure
        // here only leaves the orphan row the original code always left.
        let _ = sqlx::query("DELETE FROM redis_pubsub_subscriber WHERE id = ?")
            .bind(created as i64)
            .execute(pool)
            .await;
    }
    Ok(winner)
}
/// Inserts one `(subscriber_id, channel)` row per channel into
/// `redis_pubsub_subscription`, using a single multi-row `INSERT IGNORE`.
///
/// Pairs that already exist do not cause an error. An empty `channels` slice
/// issues no query.
///
/// # Errors
/// Returns the mapped SQL error if the statement fails.
async fn insert_subscriptions(
    pool: &sqlx::MySqlPool,
    subscriber_id: u64,
    channels: &[Bytes],
) -> Result<(), RespError> {
    if !channels.is_empty() {
        let mut builder = sqlx::QueryBuilder::new(
            "INSERT IGNORE INTO redis_pubsub_subscription (subscriber_id, channel) ",
        );
        let owner = subscriber_id as i64;
        builder.push_values(channels.iter(), |mut values, channel| {
            values.push_bind(owner).push_bind(channel.as_ref());
        });
        builder.build().execute(pool).await.map_err(map_sql_err)?;
    }
    Ok(())
}
/// Deletes this subscriber's `redis_pubsub_subscription` rows for the given
/// channels, using a single `DELETE ... IN (...)`.
///
/// An empty `channels` slice issues no query.
///
/// # Errors
/// Returns the mapped SQL error if the statement fails.
async fn remove_subscriptions(
    pool: &sqlx::MySqlPool,
    subscriber_id: u64,
    channels: &[Bytes],
) -> Result<(), RespError> {
    if channels.is_empty() {
        return Ok(());
    }
    let mut builder =
        sqlx::QueryBuilder::new("DELETE FROM redis_pubsub_subscription WHERE subscriber_id = ");
    builder
        .push_bind(subscriber_id as i64)
        .push(" AND channel IN (");
    {
        let mut list = builder.separated(", ");
        channels.iter().for_each(|channel| {
            list.push_bind(channel.as_ref());
        });
    }
    builder.push(")");
    builder.build().execute(pool).await.map_err(map_sql_err)?;
    Ok(())
}
/// Background task that drains this subscriber's mailbox and forwards each
/// queued message to the client as a `["message", channel, payload]` push.
///
/// On every tick of `cfg.poll_interval` it:
/// 1. stops, calling `session.deactivate_poller()`, once `pubsub.count()` is 0;
/// 2. fetches up to `cfg.poll_batch` mailbox rows for `subscriber_id`, oldest
///    `message_id` first, and pushes each one;
/// 3. deletes the rows it handled from `redis_pubsub_mailbox`;
/// 4. refreshes `redis_pubsub_subscriber.last_seen` once
///    `cfg.heartbeat_interval` has elapsed since the previous refresh.
///
/// Database errors are treated as transient. The failed step is skipped and
/// retried on a later tick, and a failed fetch does not skip the heartbeat. A
/// failed push ends the task, after first acknowledging the messages that were
/// already delivered.
async fn run_poller(
    pool: Arc<sqlx::MySqlPool>,
    subscriber_id: u64,
    push: PushHandle,
    pubsub: PubSubHandle,
    cfg: crate::config::PubSubConfig,
    session: Arc<Session>,
) {
    let mut interval = time::interval(cfg.poll_interval);
    // The default `Burst` behaviour follows a slow poll with back-to-back
    // catch-up ticks, i.e. a burst of DB queries. Resume one period later instead.
    interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
    let mut last_heartbeat = Instant::now();
    loop {
        interval.tick().await;
        if pubsub.count() == 0 {
            break;
        }
        let fetched = sqlx::query(
            "SELECT mb.message_id, mb.channel, msg.payload \
             FROM redis_pubsub_mailbox mb \
             JOIN redis_pubsub_message msg ON mb.message_id = msg.id \
             WHERE mb.subscriber_id = ? ORDER BY mb.message_id LIMIT ?",
        )
        .bind(subscriber_id as i64)
        .bind(cfg.poll_batch as i64)
        .fetch_all(pool.as_ref())
        .await;
        // A failed fetch is retried on the next tick; the heartbeat below still runs.
        if let Ok(rows) = fetched {
            let mut handled = Vec::with_capacity(rows.len());
            for row in &rows {
                // Without a decodable id the row can never be deleted. Leave it
                // queued rather than delivering it under a placeholder id, which
                // would redeliver it on every tick.
                let message_id = match decode_message_id(row) {
                    Some(id) => id,
                    None => continue,
                };
                let channel = row.try_get::<Vec<u8>, _>("channel");
                let payload = row.try_get::<Vec<u8>, _>("payload");
                // An undecodable body is acknowledged but not pushed, instead of
                // being sent to the client as empty strings.
                if let (Ok(channel), Ok(payload)) = (channel, payload) {
                    let reply = Value::Array(vec![
                        Value::Bulk(Bytes::from_static(b"message")),
                        Value::Bulk(Bytes::from(channel)),
                        Value::Bulk(Bytes::from(payload)),
                    ]);
                    if push.send(reply).await.is_err() {
                        // The client is gone. Do not keep what it already received queued.
                        ack_mailbox_messages(pool.as_ref(), subscriber_id, &handled).await;
                        session.deactivate_poller();
                        return;
                    }
                }
                handled.push(message_id);
            }
            ack_mailbox_messages(pool.as_ref(), subscriber_id, &handled).await;
        }
        if last_heartbeat.elapsed() >= cfg.heartbeat_interval {
            let _ = sqlx::query(
                "UPDATE redis_pubsub_subscriber SET last_seen = CURRENT_TIMESTAMP(3) WHERE id = ?",
            )
            .bind(subscriber_id as i64)
            .execute(pool.as_ref())
            .await;
            last_heartbeat = Instant::now();
        }
    }
    session.deactivate_poller();
}

/// Best-effort deletion of the given `redis_pubsub_mailbox` rows for `subscriber_id`.
///
/// Errors are ignored. Rows that survive a failed delete are fetched and
/// delivered again by a later poll.
async fn ack_mailbox_messages(pool: &sqlx::MySqlPool, subscriber_id: u64, message_ids: &[i64]) {
    if message_ids.is_empty() {
        return;
    }
    let mut qb =
        sqlx::QueryBuilder::new("DELETE FROM redis_pubsub_mailbox WHERE subscriber_id = ");
    qb.push_bind(subscriber_id as i64);
    qb.push(" AND message_id IN (");
    let mut separated = qb.separated(", ");
    for id in message_ids {
        separated.push_bind(*id);
    }
    qb.push(")");
    let _ = qb.build().execute(pool).await;
}

/// Reads `message_id` from a mailbox row as `i64`. Also accepts an unsigned
/// column whose value fits in `i64`. Returns `None` if neither decode succeeds.
// NOTE(review): the column type isn't visible here. Ids originate from
// `last_insert_id()` (a `u64`), so an UNSIGNED column is plausible — confirm.
fn decode_message_id(row: &sqlx::mysql::MySqlRow) -> Option<i64> {
    if let Ok(id) = row.try_get::<i64, _>("message_id") {
        return Some(id);
    }
    row.try_get::<u64, _>("message_id")
        .ok()
        .filter(|id| *id <= i64::MAX as u64)
        .map(|id| id as i64)
}