use crate::message::{Agg, Inbound, Op, Part, PendingSlot};
use crate::reduce::{drain_front, shard_of};
use crate::shard::Shard;
use crate::{Commands, ResolvedCmd, Route};
use kevy_resp::{Argv, ArgvView, encode_array_len};
use std::collections::HashMap;
impl<C: Commands> Shard<C> {
pub(crate) fn do_watch<A: ArgvView + ?Sized>(
&mut self,
conn_id: u64,
seq: u64,
args: &A,
) {
let mut by_shard: HashMap<usize, Vec<Vec<u8>>> = HashMap::new();
for i in 1..args.len() {
let key = &args[i];
by_shard
.entry(shard_of(key, self.nshards))
.or_default()
.push(key.to_vec());
}
let targets: Vec<(usize, Op)> = by_shard
.into_iter()
.map(|(s, ks)| (s, Op::CollectWatchVersions(ks)))
.collect();
let remaining = targets.len().max(1) as u32;
if let Some(c) = self.conns.get_mut(&conn_id) {
c.pending.push_back(PendingSlot {
remaining,
agg: Agg::WatchCollect { pairs: Vec::new() },
done: None,
});
}
if targets.is_empty() {
self.fold(conn_id, seq, Part::WatchVersions(Vec::new()));
return;
}
for (shard, op) in targets {
if shard == self.id {
let part = self.exec_op(op);
self.fold(conn_id, seq, part);
} else {
self.send_to(
shard,
Inbound::Request {
origin: self.id,
conn: conn_id,
seq,
op,
},
);
}
}
}
pub(crate) fn do_unwatch(&mut self, conn_id: u64, seq: u64) {
if let Some(c) = self.conns.get_mut(&conn_id) {
c.watched.clear();
c.pending.push_back(PendingSlot {
remaining: 1,
agg: Agg::First(None),
done: None,
});
}
self.fold(conn_id, seq, Part::Reply(b"+OK\r\n".to_vec()));
}
pub(crate) fn exec_transaction_watched(
&mut self,
conn_id: u64,
queued: Vec<Argv>,
watched: Vec<(Vec<u8>, u64)>,
) {
let n = queued.len();
let Some((header_seq, base_idx)) = self.preallocate_exec_slots(conn_id, queued) else {
return;
};
let by_shard = self.group_watched_pairs(watched);
let groups = by_shard.len().max(1) as u32;
if let Some(c) = self.conns.get_mut(&conn_id)
&& let Some(slot) = c.pending.get_mut(base_idx)
{
slot.remaining = groups;
}
if by_shard.is_empty() {
self.fold(conn_id, header_seq, Part::Int(0));
return;
}
for (shard, pairs) in by_shard {
self.send_check_watch(conn_id, header_seq, shard, pairs);
}
let _ = n;
}
fn preallocate_exec_slots(
&mut self,
conn_id: u64,
queued: Vec<Argv>,
) -> Option<(u64, usize)> {
let n = queued.len();
let c = self.conns.get_mut(&conn_id)?;
let header_seq = c.next_seq;
let base_idx = c.pending.len();
c.next_seq += 1 + n as u64;
c.pending.push_back(PendingSlot {
remaining: 1, agg: Agg::ExecPrep { dirty: false, queued, header_seq },
done: None,
});
for _ in 0..n {
c.pending.push_back(PendingSlot {
remaining: 1,
agg: Agg::First(None),
done: None,
});
}
Some((header_seq, base_idx))
}
fn group_watched_pairs(
&self,
watched: Vec<(Vec<u8>, u64)>,
) -> HashMap<usize, Vec<(Vec<u8>, u64)>> {
let mut by_shard: HashMap<usize, Vec<(Vec<u8>, u64)>> = HashMap::new();
for (k, v) in watched {
by_shard
.entry(shard_of(&k, self.nshards))
.or_default()
.push((k, v));
}
by_shard
}
fn send_check_watch(
&mut self,
conn_id: u64,
seq: u64,
shard: usize,
pairs: Vec<(Vec<u8>, u64)>,
) {
let op = Op::CheckWatch(pairs);
if shard == self.id {
let part = self.exec_op(op);
self.fold(conn_id, seq, part);
} else {
self.send_to(
shard,
Inbound::Request {
origin: self.id,
conn: conn_id,
seq,
op,
},
);
}
}
pub(crate) fn finalize_watch_agg(&mut self, conn_id: u64, seq: u64, agg: Agg) {
match agg {
Agg::WatchCollect { pairs } => self.finalize_watch_collect(conn_id, seq, pairs),
Agg::ExecPrep {
dirty,
queued,
header_seq,
} => self.finalize_exec_prep(conn_id, header_seq, dirty, queued),
_ => {}
}
}
fn finalize_watch_collect(
&mut self,
conn_id: u64,
seq: u64,
pairs: Vec<(Vec<u8>, u64)>,
) {
let Some(c) = self.conns.get_mut(&conn_id) else { return };
c.watched.extend(pairs);
let idx = (seq - c.next_emit) as usize;
if let Some(slot) = c.pending.get_mut(idx) {
slot.done = Some(b"+OK\r\n".to_vec());
}
drain_front(c);
}
fn finalize_exec_prep(
&mut self,
conn_id: u64,
header_seq: u64,
dirty: bool,
queued: Vec<Argv>,
) {
let n = queued.len();
if dirty {
if let Some(c) = self.conns.get_mut(&conn_id) {
let base_idx = (header_seq - c.next_emit) as usize;
if let Some(h) = c.pending.get_mut(base_idx) {
h.done = Some(b"*-1\r\n".to_vec());
}
for i in 0..n {
if let Some(p) = c.pending.get_mut(base_idx + 1 + i) {
p.done = Some(Vec::new());
}
}
drain_front(c);
}
return;
}
let mut header_bytes = Vec::with_capacity(8);
encode_array_len(&mut header_bytes, n as i64);
if let Some(c) = self.conns.get_mut(&conn_id) {
let base_idx = (header_seq - c.next_emit) as usize;
if let Some(h) = c.pending.get_mut(base_idx) {
h.done = Some(header_bytes);
}
drain_front(c);
}
for (i, cmd) in queued.iter().enumerate() {
let qseq = header_seq + 1 + i as u64;
let resolved = self.commands.resolve(cmd);
self.start_command_at_seq(conn_id, qseq, cmd, resolved);
}
}
fn start_command_at_seq<A: ArgvView + ?Sized>(
&mut self,
conn_id: u64,
seq: u64,
args: &A,
resolved: ResolvedCmd,
) {
let ResolvedCmd { route, is_quit, is_write, .. } = resolved;
match route {
Route::Unwatch => self.fill_placeholder(conn_id, seq, b"+OK\r\n".to_vec()),
Route::Subscribe
| Route::Unsubscribe
| Route::Psubscribe
| Route::Punsubscribe
| Route::Publish
| Route::Watch => self.fill_placeholder(
conn_id,
seq,
b"-ERR pub/sub or WATCH not allowed inside MULTI\r\n".to_vec(),
),
Route::Local => {
self.start_single_at_seq(conn_id, seq, args, self.id, is_quit, is_write)
}
Route::Single(idx) => {
let shard = shard_of(&args[idx], self.nshards);
self.start_single_at_seq(conn_id, seq, args, shard, is_quit, is_write)
}
other => self.start_multi_at_seq(conn_id, seq, args, other, is_quit),
}
}
fn fill_placeholder(&mut self, conn_id: u64, seq: u64, bytes: Vec<u8>) {
let Some(c) = self.conns.get_mut(&conn_id) else { return };
let idx = (seq - c.next_emit) as usize;
if let Some(slot) = c.pending.get_mut(idx) {
slot.done = Some(bytes);
}
drain_front(c);
}
fn start_single_at_seq<A: ArgvView + ?Sized>(
&mut self,
conn_id: u64,
seq: u64,
args: &A,
shard: usize,
is_quit: bool,
is_write: bool,
) {
if is_quit
&& let Some(c) = self.conns.get_mut(&conn_id)
{
c.closing = true;
}
if shard == self.id {
let part = self.exec_op(Op::Dispatch(args.to_argv()));
self.fold(conn_id, seq, part);
let _ = is_write;
} else {
self.request_batch[shard].push((conn_id, seq, args.to_argv()));
}
}
fn start_multi_at_seq<A: ArgvView + ?Sized>(
&mut self,
conn_id: u64,
seq: u64,
args: &A,
route: Route,
is_quit: bool,
) {
let (targets, agg) = self.build_multi_targets(args, route);
let remaining = targets.len().max(1) as u32;
if let Some(c) = self.conns.get_mut(&conn_id) {
let idx = (seq - c.next_emit) as usize;
if let Some(slot) = c.pending.get_mut(idx) {
slot.remaining = remaining;
slot.agg = agg;
}
if is_quit {
c.closing = true;
}
}
if targets.is_empty() {
self.fold(conn_id, seq, Part::Int(0));
return;
}
self.dispatch_targets(conn_id, seq, targets);
}
}