use crate::Commands;
use crate::reduce::pubsub_pmessage;
use crate::shard::Shard;
use kevy_resp::{ArgvView, encode_array_len, encode_bulk, encode_integer, encode_null_bulk};
use kevy_store::glob_match;
impl<C: Commands> Shard<C> {
    /// Handles `PSUBSCRIBE` for `conn_id`.
    ///
    /// Every argument after `args[0]` (presumably the command name) is treated as a
    /// pattern. Requests for connections this shard does not know are ignored. Patterns
    /// the connection did not already hold are recorded locally and mirrored into the
    /// shared pattern registry; the RESP reply is then queued under `seq`.
    pub(crate) fn do_psubscribe<A: ArgvView + ?Sized>(
        &mut self,
        conn_id: u64,
        seq: u64,
        args: &A,
    ) {
        if self.conns.get(&conn_id).is_none() {
            return;
        }
        let patterns: Vec<Vec<u8>> = (1..args.len()).map(|i| args[i].to_vec()).collect();
        let (reply, changed) = self.apply_psub_to_conn(conn_id, &patterns, true);
        self.apply_psub_to_registry(&changed, true);
        self.fold_pubsub_reply(conn_id, seq, reply);
    }

    /// Handles `PUNSUBSCRIBE` for `conn_id`.
    ///
    /// With explicit patterns (`args.len() > 1`) only those are removed; with none, every
    /// pattern the connection currently holds is removed. Unknown connections are ignored.
    pub(crate) fn do_punsubscribe<A: ArgvView + ?Sized>(
        &mut self,
        conn_id: u64,
        seq: u64,
        args: &A,
    ) {
        let patterns: Vec<Vec<u8>> = match self.conns.get(&conn_id) {
            None => return,
            Some(_) if args.len() > 1 => (1..args.len()).map(|i| args[i].to_vec()).collect(),
            // Bare PUNSUBSCRIBE: drop everything the connection is subscribed to.
            Some(c) => c.psub.iter().cloned().collect(),
        };
        let (reply, changed) = self.apply_psub_to_conn(conn_id, &patterns, false);
        self.apply_psub_to_registry(&changed, false);
        self.fold_pubsub_reply(conn_id, seq, reply);
    }

    /// Applies (un)subscriptions to the connection's local state and builds the reply.
    ///
    /// Returns `(reply, changed)`:
    /// - `reply` holds one 3-element RESP array `[verb, pattern, count]` per input
    ///   pattern, where `count` is the connection's channel + pattern subscription total
    ///   after that step. With no patterns, a single `[verb, null, count]` is emitted.
    /// - `changed` lists the patterns whose membership actually changed (duplicates and
    ///   no-ops are excluded), which is what the shared registry must be told about.
    fn apply_psub_to_conn(
        &mut self,
        conn_id: u64,
        patterns: &[Vec<u8>],
        subscribe: bool,
    ) -> (Vec<u8>, Vec<Vec<u8>>) {
        let verb: &[u8] = if subscribe { b"psubscribe" } else { b"punsubscribe" };
        let mut out = Vec::new();
        let mut changed: Vec<Vec<u8>> = Vec::new();
        if patterns.is_empty() {
            let count = self.psub_count_for(conn_id);
            encode_array_len(&mut out, 3);
            encode_bulk(&mut out, verb);
            encode_null_bulk(&mut out);
            encode_integer(&mut out, count as i64);
            return (out, changed);
        }
        for pat in patterns {
            let did = if subscribe {
                self.add_psub_local(conn_id, pat)
            } else {
                self.remove_psub_local(conn_id, pat)
            };
            if did {
                changed.push(pat.clone());
            }
            // Count is taken after each step so every reply reflects the running total.
            let count = self.psub_count_for(conn_id);
            encode_array_len(&mut out, 3);
            encode_bulk(&mut out, verb);
            encode_bulk(&mut out, pat);
            encode_integer(&mut out, count as i64);
        }
        (out, changed)
    }

    /// Records `pattern` for `conn_id` in both the connection and the shard-local index.
    ///
    /// Returns `false` if the connection is unknown or already held the pattern.
    fn add_psub_local(&mut self, conn_id: u64, pattern: &[u8]) -> bool {
        let Some(c) = self.conns.get_mut(&conn_id) else { return false };
        if !c.psub.insert(pattern.to_vec()) {
            return false;
        }
        self.psub_local
            .entry(pattern.to_vec())
            .or_default()
            .push(conn_id);
        true
    }

    /// Removes `pattern` for `conn_id` from the connection and the shard-local index,
    /// dropping the index entry once no local connection holds the pattern.
    ///
    /// Returns `false` if the connection is unknown or did not hold the pattern.
    fn remove_psub_local(&mut self, conn_id: u64, pattern: &[u8]) -> bool {
        let Some(c) = self.conns.get_mut(&conn_id) else { return false };
        if !c.psub.remove(pattern) {
            return false;
        }
        if let Some(ids) = self.psub_local.get_mut(pattern) {
            ids.retain(|&id| id != conn_id);
            if ids.is_empty() {
                self.psub_local.remove(pattern);
            }
        }
        true
    }

    /// Total channel + pattern subscriptions held by `conn_id` (0 if unknown).
    fn psub_count_for(&self, conn_id: u64) -> usize {
        match self.conns.get(&conn_id) {
            Some(c) => c.sub.len() + c.psub.len(),
            None => 0,
        }
    }

    /// Whether at least one connection on this shard currently holds `pattern`.
    fn has_local_psub(&self, pattern: &[u8]) -> bool {
        self.psub_local
            .get(pattern)
            .is_some_and(|ids| !ids.is_empty())
    }

    /// Mirrors local pattern (un)subscriptions into the shared `pubsub_patterns` registry.
    ///
    /// Registry entries are `(pattern, count, shard_bits)`. Subscribing bumps `count` for
    /// each pattern in `changed`, creating the entry if needed, and sets this shard's bit
    /// while it has a local subscriber. Unsubscribing is delegated to
    /// [`Self::release_psubs_in_registry`].
    fn apply_psub_to_registry(&self, changed: &[Vec<u8>], subscribe: bool) {
        if changed.is_empty() {
            return;
        }
        if !subscribe {
            self.release_psubs_in_registry(changed);
            return;
        }
        // NOTE(review): `1u64 << self.id` assumes `self.id < 64`; larger ids overflow
        // the shift (panic in debug builds).
        let bit = 1u64 << self.id;
        let mut reg = self.pubsub_patterns.write().expect("pubsub patterns");
        for pat in changed {
            let shard_bit = if self.has_local_psub(pat) { bit } else { 0 };
            match reg.iter().position(|(p, ..)| p == pat) {
                Some(i) => {
                    // Saturate to stay symmetric with the decrement path.
                    reg[i].1 = reg[i].1.saturating_add(1);
                    reg[i].2 |= shard_bit;
                }
                None => reg.push((pat.clone(), 1, shard_bit)),
            }
        }
    }

    /// Drops one registry reference for each pattern in `patterns`.
    ///
    /// This shard's bit is cleared when no local connection still holds the pattern
    /// (as seen in `psub_local` at call time), and entries whose count reaches zero are
    /// removed. Patterns absent from the registry are skipped. Callers are expected to
    /// skip the call for empty inputs so the write lock is not taken needlessly.
    fn release_psubs_in_registry<'a, I>(&self, patterns: I)
    where
        I: IntoIterator<Item = &'a Vec<u8>>,
    {
        // NOTE(review): assumes `self.id < 64` (see `apply_psub_to_registry`).
        let bit = 1u64 << self.id;
        let mut reg = self.pubsub_patterns.write().expect("pubsub patterns");
        for pat in patterns {
            let Some(i) = reg.iter().position(|(p, ..)| p == pat) else { continue };
            reg[i].1 = reg[i].1.saturating_sub(1);
            if !self.has_local_psub(pat) {
                reg[i].2 &= !bit;
            }
            if reg[i].1 == 0 {
                // Registry order is irrelevant, so the O(1) removal is fine.
                reg.swap_remove(i);
            }
        }
    }

    /// Queues `reply` as the complete result for request `seq` on `conn_id`.
    ///
    /// A single-part pending slot is pushed (only if the connection still exists) and
    /// the reply is then folded into it via `self.fold`.
    /// NOTE(review): relies on `fold` matching the reply to the slot just pushed —
    /// confirm against `fold`'s implementation.
    fn fold_pubsub_reply(&mut self, conn_id: u64, seq: u64, reply: Vec<u8>) {
        if let Some(c) = self.conns.get_mut(&conn_id) {
            c.pending.push_back(crate::message::PendingSlot {
                remaining: 1,
                agg: crate::message::Agg::First(None),
                done: None,
            });
        }
        self.fold(conn_id, seq, crate::message::Part::Reply(reply));
    }

    /// Looks up which registered patterns match `channel`.
    ///
    /// Returns `(count, bits)`: the saturating sum of the `count` field of every matching
    /// registry entry, and the OR of their shard bitmasks.
    pub(crate) fn pattern_match_for_channel(&self, channel: &[u8]) -> (u32, u64) {
        let reg = self.pubsub_patterns.read().expect("pubsub patterns");
        if reg.is_empty() {
            return (0, 0);
        }
        let mut count: u32 = 0;
        let mut bits: u64 = 0;
        for (pat, cnt, b) in reg.iter() {
            if glob_match(pat, channel) {
                count = count.saturating_add(*cnt);
                bits |= *b;
            }
        }
        (count, bits)
    }

    /// Appends a `pmessage` frame to every local connection subscribed to a pattern that
    /// matches `channel`, then marks those connections dirty.
    ///
    /// The frame depends only on `(pattern, channel, msg)`, so it is encoded once per
    /// matching pattern rather than once per subscriber. A connection matching several
    /// patterns receives one frame per pattern and may appear in `dirty` more than once.
    pub(crate) fn deliver_pmessages(&mut self, channel: &[u8], msg: &[u8]) {
        if self.psub_local.is_empty() {
            return;
        }
        let mut touched: Vec<u64> = Vec::new();
        // `psub_local` is borrowed immutably while `conns` is borrowed mutably; these are
        // disjoint fields, so no intermediate plan list is needed.
        for (pat, ids) in &self.psub_local {
            if ids.is_empty() || !glob_match(pat, channel) {
                continue;
            }
            let frame = pubsub_pmessage(pat, channel, msg);
            for id in ids {
                if let Some(c) = self.conns.get_mut(id) {
                    c.output.extend_from_slice(&frame);
                    touched.push(*id);
                }
            }
        }
        self.dirty.extend_from_slice(&touched);
    }

    /// Releases registry references for `patterns` (e.g. a closing connection's set).
    ///
    /// This only updates the shared registry; it does not modify `psub_local`.
    /// NOTE(review): the shard bit is cleared based on `psub_local`, so the caller
    /// presumably removes the connection from `psub_local` first — confirm at call sites.
    pub(crate) fn unregister_psubs(&mut self, patterns: &std::collections::HashSet<Vec<u8>>) {
        if patterns.is_empty() {
            return;
        }
        self.release_psubs_in_registry(patterns);
    }
}