use crate::Commands;
use crate::conn::Conn;
use crate::message::{Inbound, Op, PubMsg, PubSubReg, ReqBatch};
use kevy_persist::{Aof, load_snapshot, replay_aof};
use kevy_resp::{Argv, parse_command_into};
use kevy_ring::{Consumer, Producer};
use kevy_store::Store;
use kevy_sys::{Event, Poller, Socket, Waker};
use kevy_map::KevyMap;
use std::collections::VecDeque;
use std::io;
use std::time::{Duration, Instant};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// State owned by one shard's event loop: its key/value store, its client
/// connections, and the queues used to exchange messages with peer shards.
///
/// A shard is consumed by [`Shard::run`], which drives everything below.
pub(crate) struct Shard<C: Commands> {
/// Index of this shard; used in persistence file names and as this shard's
/// slot in the per-shard vectors (`parked`, `inboxes`, ...).
pub(crate) id: usize,
/// Total number of shards; upper bound when iterating over peers.
pub(crate) nshards: usize,
/// The data store that commands, snapshot loading and AOF replay operate on.
pub(crate) store: Store,
/// Command implementation: dispatches replayed AOF entries and receives the
/// periodic shard tick.
pub(crate) commands: C,
/// Readiness poller for the listener, the waker and all client sockets.
pub(crate) poller: Poller,
/// Listening socket from which new client connections are accepted.
pub(crate) listener: Socket,
/// This shard's own waker; its read end is registered with `poller` and
/// drained when it fires.
pub(crate) waker: Arc<Waker>,
/// Inbound message queues indexed by source shard. The slot for this shard's
/// own id is never read; every other slot is expected to be `Some`.
pub(crate) inboxes: Vec<Option<Consumer<Inbound>>>,
/// Outbound message queues indexed by destination shard. A `None` slot means
/// messages for that destination are dropped.
pub(crate) outboxes: Vec<Option<Producer<Inbound>>>,
/// Per-destination overflow queues holding messages that did not fit into
/// the corresponding outbox; drained in order by `flush_backlog`.
pub(crate) backlog: Vec<VecDeque<Inbound>>,
/// Wakers of all shards, indexed by shard id, used to wake parked peers.
pub(crate) wakers: Vec<Arc<Waker>>,
/// Live client connections keyed by connection id.
pub(crate) conns: KevyMap<u64, Conn>,
/// Maps a socket file descriptor to the id of the connection that owns it.
pub(crate) fd_to_conn: KevyMap<i32, u64>,
/// Id handed to the next accepted connection; incremented on every accept.
pub(crate) next_conn_id: u64,
/// Reusable buffer that `poller.wait` fills with ready events.
pub(crate) events: Vec<Event>,
/// Reusable scratch buffer for socket reads.
pub(crate) read_buf: Vec<u8>,
/// Per-destination flags: `true` when a message was queued for that shard
/// since the last `flush_wakes`.
pub(crate) pending_wakes: Vec<bool>,
/// Per-shard "parked" flags. This shard sets its own flag before a blocking
/// wait; senders consult the destination's flag to decide whether to wake it.
pub(crate) parked: Vec<Arc<AtomicBool>>,
/// Directory that holds this shard's snapshot and AOF files.
pub(crate) data_dir: PathBuf,
/// Append-only file handle; `None` skips AOF replay, syncing and rewriting.
pub(crate) aof: Option<Aof>,
/// Growth percentage over the size at the last rewrite that triggers an
/// automatic AOF rewrite; `0` disables automatic rewrites.
pub(crate) auto_aof_rewrite_pct: u32,
/// Minimum AOF size below which no automatic rewrite is attempted.
pub(crate) auto_aof_rewrite_min_size: u64,
/// Connection ids that `flush_dirty` pops and flushes. Where ids are pushed
/// is not visible in this file section.
pub(crate) dirty: Vec<u64>,
/// Pub/sub registry; a write guard is taken to decrement per-channel counts
/// when a connection closes.
pub(crate) pubsub: PubSubReg,
/// NOTE(review): presumably per-destination publish messages accumulated
/// until `flush_publish` -- confirm against its definition.
pub(crate) publish_batch: Vec<Vec<PubMsg>>,
/// NOTE(review): presumably per-destination request batches accumulated
/// until `flush_requests` -- confirm against its definition.
pub(crate) request_batch: Vec<ReqBatch>,
/// Reusable argument vector that the RESP parser fills for each command.
pub(crate) scratch_argv: Argv,
}
/// Number of consecutive idle loop iterations polled with a zero timeout
/// before the shard parks itself in a blocking wait.
const SPIN_LIMIT: u32 = 256;
/// Timeout handed to `poller.wait` while the shard is parked (presumably
/// milliseconds, going by the name).
const PARK_TIMEOUT_MS: i32 = 50;
/// Iteration count after which the run loop's tick counter triggers a clock
/// read for the periodic shard tick.
const TICK_CHECK_EVERY: u32 = 256;
impl<C: Commands> Shard<C> {
/// Returns the path of this shard's snapshot file inside `data_dir`.
pub(crate) fn snapshot_path(&self) -> PathBuf {
    let file_name = format!("dump-{}.rdb", self.id);
    self.data_dir.join(file_name)
}
/// Returns the path of this shard's append-only file inside `data_dir`.
pub(crate) fn aof_path(&self) -> PathBuf {
    let file_name = format!("aof-{}.aof", self.id);
    self.data_dir.join(file_name)
}
/// Runs this shard's event loop until `stop` becomes `true`.
///
/// Start-up: loads the snapshot if one exists (a load failure is logged and
/// ignored), replays the AOF when one is configured, then registers the
/// listener and this shard's waker with the poller.
///
/// Each loop iteration polls for socket events, drains the peer inboxes,
/// flushes all outbound queues and dirty connections, syncs the AOF on a
/// best-effort basis and, when a tick interval is configured, runs the
/// periodic shard tick.
///
/// The loop spins with a zero poll timeout while there is work; after
/// `SPIN_LIMIT` consecutive idle iterations it parks in a blocking wait of
/// `PARK_TIMEOUT_MS`.
///
/// # Errors
///
/// Returns the first I/O error from AOF replay, poller registration or
/// waiting, or from any connection handler that propagates one.
pub(crate) fn run(mut self, stop: Arc<AtomicBool>) -> io::Result<()> {
    let snap = self.snapshot_path();
    if snap.exists()
        && let Err(e) = load_snapshot(&mut self.store, &snap)
    {
        eprintln!(
            "kevy: shard {} failed to load {}: {e}",
            self.id,
            snap.display()
        );
    }
    if self.aof.is_some() {
        let aof_path = self.aof_path();
        // Borrow the two fields separately so the replay closure can use both.
        let commands = &self.commands;
        let store = &mut self.store;
        replay_aof(&aof_path, |args| {
            commands.dispatch(store, &args);
        })?;
    }
    self.listener.set_nonblocking()?;
    self.poller.add(self.listener.raw(), true, false)?;
    self.poller.add(self.waker.read_fd(), true, false)?;
    let listener_fd = self.listener.raw();
    let waker_fd = self.waker.read_fd();
    let me = self.id;
    // An interval of 0 disables the periodic tick entirely.
    let tick_interval = match self.commands.shard_tick_interval_ms() {
        0 => None,
        ms => Some(Duration::from_millis(ms)),
    };
    let mut last_tick = Instant::now();
    let mut tick_check_counter: u32 = 0;
    let mut idle_spins: u32 = 0;
    while !stop.load(Ordering::Relaxed) {
        let spinning = idle_spins < SPIN_LIMIT;
        let timeout = if spinning {
            Some(0)
        } else {
            // Announce that we are about to block, then look at the inboxes
            // once more: a peer that queued a message before it could see
            // the flag will not have woken us.
            self.parked[me].store(true, Ordering::SeqCst);
            if self.drain_inbound()? {
                self.parked[me].store(false, Ordering::SeqCst);
                self.flush_backlog();
                self.flush_dirty()?;
                self.flush_wakes();
                idle_spins = 0;
                continue;
            }
            Some(PARK_TIMEOUT_MS)
        };
        self.poller.wait(&mut self.events, timeout)?;
        if !spinning {
            self.parked[me].store(false, Ordering::SeqCst);
        }
        let mut did_work = !self.events.is_empty();
        if did_work {
            // Move the buffer out so the handlers can borrow `self` mutably;
            // it is put back afterwards to keep its allocation.
            let events = std::mem::take(&mut self.events);
            for ev in &events {
                if ev.fd == listener_fd {
                    self.accept_ready()?;
                } else if ev.fd == waker_fd {
                    self.waker.drain();
                } else if let Some(&conn_id) = self.fd_to_conn.get(&ev.fd) {
                    // `conn_readable` finishes by flushing the connection,
                    // so a readable+writable event needs only this branch.
                    if ev.readable || ev.hup {
                        self.conn_readable(conn_id)?;
                    } else if ev.writable {
                        self.flush_conn(conn_id)?;
                    }
                }
            }
            self.events = events;
        }
        if self.drain_inbound()? {
            did_work = true;
        }
        self.flush_backlog();
        self.flush_requests();
        self.flush_publish();
        self.flush_dirty()?;
        self.flush_wakes();
        if let Some(aof) = &mut self.aof {
            // Best effort: a failed sync must not stop the shard.
            let _ = aof.maybe_sync();
        }
        if let Some(iv) = tick_interval {
            tick_check_counter = tick_check_counter.wrapping_add(1);
            // While spinning, the clock is read only every
            // `TICK_CHECK_EVERY` iterations to keep the hot path cheap.
            // A parked iteration may have blocked for the whole park
            // timeout, so counting iterations there would postpone the
            // tick by `TICK_CHECK_EVERY` timeouts; read the clock after
            // every parked wait instead.
            if !spinning || tick_check_counter >= TICK_CHECK_EVERY {
                tick_check_counter = 0;
                let now = Instant::now();
                if now.duration_since(last_tick) >= iv {
                    self.commands.on_shard_tick(&mut self.store);
                    self.maybe_auto_rewrite_aof();
                    last_tick = now;
                }
            }
        }
        // Undelivered backlog counts as pending work: keep spinning so it
        // is retried promptly instead of waiting out a park timeout.
        let has_backlog = self.backlog.iter().any(|b| !b.is_empty());
        idle_spins = if did_work || has_backlog {
            0
        } else {
            idle_spins.saturating_add(1)
        };
    }
    Ok(())
}
/// Rewrites the AOF from the current store when it has grown enough.
///
/// Nothing happens when automatic rewrites are disabled (percentage of 0),
/// when no AOF is configured, when the file is smaller than
/// `auto_aof_rewrite_min_size`, or when it has not yet grown by
/// `auto_aof_rewrite_pct` percent over its size at the last rewrite.
/// A failed rewrite is logged and otherwise ignored.
fn maybe_auto_rewrite_aof(&mut self) {
    let pct = self.auto_aof_rewrite_pct;
    if pct == 0 {
        return;
    }
    let Some(aof) = self.aof.as_mut() else {
        return;
    };
    let current = aof.size_bytes();
    if current < self.auto_aof_rewrite_min_size {
        return;
    }
    // Compare current * 100 against baseline * (100 + pct) to stay in
    // integers; a baseline of 0 is treated as 1.
    let baseline = aof.size_at_last_rewrite().max(1);
    let scaled_current = current.saturating_mul(100);
    let threshold = baseline.saturating_mul(100u64.saturating_add(u64::from(pct)));
    if scaled_current < threshold {
        return;
    }
    if let Err(e) = aof.rewrite_from(&self.store) {
        eprintln!(
            "kevy: shard {} auto AOF rewrite failed: {e}",
            self.id,
        );
    }
}
/// Clears every pending-wake flag and wakes each flagged shard that is
/// currently parked. Wake failures are ignored.
pub(crate) fn flush_wakes(&mut self) {
    for (shard, pending) in self.pending_wakes.iter_mut().enumerate() {
        // `take` reads the flag and resets it to `false` in one step.
        if !std::mem::take(pending) {
            continue;
        }
        if self.parked[shard].load(Ordering::SeqCst) {
            let _ = self.wakers[shard].wake();
        }
    }
}
/// Flushes every connection queued in `dirty`, newest entry first.
///
/// # Errors
///
/// Stops at the first flush error; ids not yet popped stay queued.
#[inline]
fn flush_dirty(&mut self) -> io::Result<()> {
    loop {
        let Some(conn_id) = self.dirty.pop() else {
            return Ok(());
        };
        self.flush_conn(conn_id)?;
    }
}
/// Accepts connections from the listener until it would block.
///
/// Each accepted socket is made non-blocking, registered with the poller
/// for read readiness and recorded in `conns` / `fd_to_conn` under a fresh
/// connection id.
///
/// A socket that cannot be set up is logged and dropped; the failure of a
/// single client connection does not stop the shard. `accept` errors other
/// than `Interrupted` end the accept loop for this round.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` is kept so callers can
/// keep using `?`.
fn accept_ready(&mut self) -> io::Result<()> {
    loop {
        let sock = match self.listener.accept() {
            Ok(sock) => sock,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            // `WouldBlock` means the queue is drained; any other error ends
            // this round as well.
            Err(_) => break,
        };
        if let Err(e) = sock.set_nonblocking() {
            eprintln!(
                "kevy: shard {} failed to set up accepted connection: {e}",
                self.id,
            );
            // Dropping `sock` here releases the connection.
            continue;
        }
        // Best effort: the connection still works without this option.
        let _ = sock.set_nodelay();
        let fd = sock.raw();
        if let Err(e) = self.poller.add(fd, true, false) {
            eprintln!(
                "kevy: shard {} failed to set up accepted connection: {e}",
                self.id,
            );
            continue;
        }
        let id = self.next_conn_id;
        self.next_conn_id += 1;
        self.fd_to_conn.insert(fd, id);
        self.conns.insert(id, Conn::new(sock));
    }
    Ok(())
}
/// Handles read readiness for one connection.
///
/// Reads everything currently available into the connection's input
/// buffer, then parses and executes complete commands one at a time. A
/// parse error stops parsing and is reported through `protocol_error`.
/// The connection is flushed at the end.
///
/// End of stream or a read error marks the connection as closing; commands
/// already buffered are still processed. Returns `Ok(())` immediately when
/// the connection no longer exists.
///
/// # Errors
///
/// Propagates errors from the final `flush_conn`.
fn conn_readable(&mut self, conn_id: u64) -> io::Result<()> {
    // Phase 1: pull in all bytes the socket has for us.
    let Some(conn) = self.conns.get_mut(&conn_id) else {
        return Ok(());
    };
    loop {
        match conn.sock.read(&mut self.read_buf) {
            Ok(n) if n > 0 => conn.input.extend_from_slice(&self.read_buf[..n]),
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
            // Zero bytes (end of stream) or any other error.
            Ok(_) | Err(_) => {
                conn.closing = true;
                break;
            }
        }
    }

    // Phase 2: parse and run every complete command in the buffer.
    let mut had_protocol_error = false;
    loop {
        // Look the connection up again each round: handling a command
        // needs `&mut self`, so the borrow cannot be held across it.
        let Some(conn) = self.conns.get_mut(&conn_id) else {
            return Ok(());
        };
        let consumed = match parse_command_into(&conn.input, &mut self.scratch_argv) {
            Ok(Some(len)) => len,
            Ok(None) => break,
            Err(_) => {
                had_protocol_error = true;
                break;
            }
        };
        conn.input.drain(..consumed);

        // Move the parsed arguments out while the command runs, then hand
        // the vector back so it can be reused.
        let argv = std::mem::take(&mut self.scratch_argv);
        if let Some(key) = argv.get(1) {
            self.store.prefetch_for_key(key);
        }
        self.handle_command(conn_id, &argv);
        self.scratch_argv = argv;
    }

    if had_protocol_error {
        self.protocol_error(conn_id);
    }
    self.flush_conn(conn_id)
}
/// Queues `msg` for shard `dst` and marks that shard as needing a wake.
///
/// If earlier messages for `dst` are still backlogged, `msg` is appended
/// to the backlog to keep ordering. Otherwise it goes straight into the
/// outbox, falling back to the backlog when the outbox rejects it. When
/// `dst` has no outbox the message is dropped and no wake is recorded.
pub(crate) fn send_to(&mut self, dst: usize, msg: Inbound) {
    let queue = &mut self.backlog[dst];
    if !queue.is_empty() {
        queue.push_back(msg);
    } else {
        let Some(ring) = self.outboxes[dst].as_mut() else {
            return;
        };
        if let Err(rejected) = ring.push(msg) {
            queue.push_back(rejected);
        }
    }
    self.pending_wakes[dst] = true;
}
/// Moves backlogged messages into their outboxes, in order, for as long
/// as each outbox accepts them.
///
/// Every message delivered marks its destination as needing a wake. A
/// backlog whose destination has no outbox is discarded.
#[inline]
pub(crate) fn flush_backlog(&mut self) {
    // Common case: nothing is backlogged anywhere.
    if self.backlog.iter().all(|queue| queue.is_empty()) {
        return;
    }
    for dst in 0..self.nshards {
        let queue = &mut self.backlog[dst];
        if queue.is_empty() {
            continue;
        }
        match self.outboxes[dst].as_mut() {
            None => queue.clear(),
            Some(ring) => {
                while let Some(msg) = queue.pop_front() {
                    match ring.push(msg) {
                        Ok(_) => self.pending_wakes[dst] = true,
                        Err(rejected) => {
                            // Outbox is full: put the message back at the
                            // front and retry on a later call.
                            queue.push_front(rejected);
                            break;
                        }
                    }
                }
            }
        }
    }
}
/// Processes every message currently queued in the inboxes of all peer
/// shards and reports whether at least one message was handled.
///
/// * Requests (single or batched) are executed here and the results are
///   sent back to the originating shard.
/// * Responses (single or batched) are folded into their connections,
///   which are then flushed; in a batch each connection is flushed once.
/// * Publish batches are delivered to local subscribers.
///
/// # Panics
///
/// Panics if the inbox slot of a peer shard is `None`.
///
/// # Errors
///
/// Propagates errors from `flush_conn`.
fn drain_inbound(&mut self) -> io::Result<bool> {
    let me = self.id;
    let mut progressed = false;
    for peer in 0..self.nshards {
        if peer == me {
            continue;
        }
        loop {
            let next = self.inboxes[peer].as_mut().expect("peer inbox").pop();
            let Some(msg) = next else {
                break;
            };
            progressed = true;
            match msg {
                Inbound::Request {
                    origin,
                    conn,
                    seq,
                    op,
                } => {
                    let reply = Inbound::Response {
                        conn,
                        seq,
                        part: self.exec_op(op),
                    };
                    self.send_to(origin, reply);
                }
                Inbound::Response { conn, seq, part } => {
                    self.fold(conn, seq, part);
                    self.flush_conn(conn)?;
                }
                Inbound::RequestBatch { origin, reqs } => {
                    let resps: Vec<_> = reqs
                        .into_iter()
                        .map(|(conn, seq, argv)| (conn, seq, self.exec_op(Op::Dispatch(argv))))
                        .collect();
                    self.send_to(origin, Inbound::ResponseBatch(resps));
                }
                Inbound::ResponseBatch(resps) => {
                    // Remember each connection once, in first-seen order,
                    // and flush only after the whole batch is folded.
                    let mut touched: Vec<u64> = Vec::new();
                    for (conn, seq, part) in resps {
                        self.fold(conn, seq, part);
                        if !touched.iter().any(|&seen| seen == conn) {
                            touched.push(conn);
                        }
                    }
                    touched
                        .into_iter()
                        .try_for_each(|conn| self.flush_conn(conn))?;
                }
                Inbound::DeliverPublish(batch) => {
                    for m in &batch {
                        self.deliver_publish(&m.0, &m.1);
                    }
                }
            }
        }
    }
    Ok(progressed)
}
/// Writes as much buffered output as the socket accepts and updates the
/// connection's state accordingly.
///
/// * When everything was written, the output buffer is reset.
/// * When output remains, write readiness is requested from the poller;
///   once it has drained, the request is withdrawn.
/// * A fatal write error marks the connection as closing and discards
///   the buffered output, since it can no longer be delivered. Keeping it
///   would leave `out_remaining` true forever and the connection would
///   never be closed.
/// * A closing connection is removed once it has no unsent output and no
///   pending entries.
///
/// Does nothing when the connection no longer exists.
///
/// # Errors
///
/// Propagates errors from `poller.modify`.
pub(crate) fn flush_conn(&mut self, conn_id: u64) -> io::Result<()> {
    let (close, want_write, fd) = {
        let Some(conn) = self.conns.get_mut(&conn_id) else {
            return Ok(());
        };
        let mut write_failed = false;
        while conn.write_pos < conn.output.len() {
            match conn.sock.write(&conn.output[conn.write_pos..]) {
                // No progress possible right now; retry on the next
                // writable event.
                Ok(0) => break,
                Ok(n) => conn.write_pos += n,
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(_) => {
                    conn.closing = true;
                    write_failed = true;
                    break;
                }
            }
        }
        // Reset the buffer when it is fully sent, or drop what is left
        // when the socket can no longer be written to.
        if write_failed || conn.write_pos == conn.output.len() {
            conn.output.clear();
            conn.write_pos = 0;
        }
        let out_remaining = conn.write_pos < conn.output.len();
        let close = conn.closing && conn.pending.is_empty() && !out_remaining;
        (close, out_remaining, conn.sock.raw())
    };
    if close {
        self.close_conn(conn_id);
        return Ok(());
    }
    // Only touch the poller when the write interest actually changes.
    if let Some(conn) = self.conns.get_mut(&conn_id)
        && want_write != conn.want_write
    {
        conn.want_write = want_write;
        self.poller.modify(fd, true, want_write)?;
    }
    Ok(())
}
/// Removes a connection: deregisters its socket from the poller, forgets
/// its fd mapping and releases its pub/sub subscriptions. Does nothing
/// when the id is unknown. A poller deregistration failure is ignored.
fn close_conn(&mut self, conn_id: u64) {
    let Some(conn) = self.conns.remove(&conn_id) else {
        return;
    };
    let fd = conn.sock.raw();
    let _ = self.poller.delete(fd);
    self.fd_to_conn.remove(&fd);
    self.unregister_subs(&conn.sub);
}
/// Decrements the registry count of every channel in `subs` and removes
/// entries whose count reaches zero. Channels that are not registered
/// are skipped. The registry lock is not taken when `subs` is empty.
///
/// # Panics
///
/// Panics if acquiring the registry's write guard fails.
pub(crate) fn unregister_subs(&self, subs: &std::collections::HashSet<Vec<u8>>) {
    if subs.is_empty() {
        return;
    }
    let mut reg = self.pubsub.write().expect("pubsub registry");
    for ch in subs {
        let Some(entry) = reg.get_mut(ch) else {
            continue;
        };
        entry.0 = entry.0.saturating_sub(1);
        if entry.0 == 0 {
            reg.remove(ch);
        }
    }
}
}