use crate::Commands;
use crate::conn::Conn;
use crate::message::{Inbound, Op};
use crate::shard::Shard;
use kevy_persist::{load_snapshot, replay_aof};
use kevy_resp::parse_command_borrowed;
use kevy_sys::Socket;
use kevy_uring::{Completion, IoUring, ProvidedBufRing};
use kevy_map::KevyMap;
use std::io;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
/// Queue depth requested for every shard's io_uring instance.
const URING_ENTRIES: u32 = 256;
/// Consecutive idle loop iterations allowed before the event loop starts sleeping.
const URING_SPIN_LIMIT: u32 = 256;
/// Number of buffers registered in the provided-buffer ring used for recv.
const PBUF_ENTRIES: u16 = 128;
/// Size in bytes of each provided recv buffer (16 KiB).
const PBUF_SIZE: u32 = 16 << 10;
/// Buffer-group id under which the provided-buffer ring is registered.
const PBUF_GROUP: u16 = 0;
/// `ENOBUFS` errno; recv completions report it negated when no provided buffer was free.
/// NOTE(review): 105 is the asm-generic Linux value; some architectures (e.g. MIPS, Alpha)
/// use a different number — confirm the supported targets.
const ENOBUFS: i32 = 105;
/// Probe whether this host can run the io_uring event loop.
///
/// Builds a throwaway ring with the same parameters `run_uring` uses and tries to
/// register a provided-buffer ring on it; both steps must succeed. Everything the
/// probe creates is dropped before returning.
pub(crate) fn io_uring_available() -> bool {
    let Ok(probe) = IoUring::new(URING_ENTRIES) else {
        return false;
    };
    probe
        .register_buf_ring(PBUF_ENTRIES, PBUF_SIZE, PBUF_GROUP)
        .is_ok()
}
/// Completion `user_data` layout: the two top bits hold the operation tag and the
/// remaining low bits hold the connection id.
const OP_SHIFT: u32 = 62;
/// Tag for recv completions.
const OP_RECV: u64 = 0b01 << OP_SHIFT;
/// Tag for write completions.
const OP_WRITE: u64 = 0b10 << OP_SHIFT;
/// Tag for accept completions.
const OP_ACCEPT: u64 = 0b11 << OP_SHIFT;
/// Mask selecting the connection-id bits of `user_data`.
const CONN_MASK: u64 = u64::MAX >> (64 - OP_SHIFT);
/// io_uring-side bookkeeping for one client connection, kept next to the shard's `Conn`.
struct UringConn {
    /// A multishot recv is currently armed for this connection.
    recv_armed: bool,
    /// Bytes handed to the kernel for writing; its heap allocation must not be
    /// swapped, cleared, or freed while `write_inflight` is set.
    write_buf: Vec<u8>,
    /// Number of leading bytes of `write_buf` already written.
    write_off: usize,
    /// A write referencing `write_buf` has been queued and has not completed yet.
    write_inflight: bool,
    /// The connection should be closed once its pending output has drained.
    closing: bool,
}

impl UringConn {
    /// A fresh connection: nothing armed, nothing buffered, not closing.
    fn new() -> Self {
        Self {
            closing: false,
            write_inflight: false,
            write_off: 0,
            write_buf: Vec::new(),
            recv_armed: false,
        }
    }
}
impl<C: Commands> Shard<C> {
    /// Run this shard's event loop on io_uring until `stop` is set.
    ///
    /// Before serving, loads the snapshot file if it exists (a load failure is only
    /// logged) and replays the AOF when AOF is enabled. Each iteration then:
    /// 1. keeps one accept outstanding, queues writes for pending output and
    ///    (re)arms a multishot recv per connection,
    /// 2. submits without waiting and handles every completion,
    /// 3. drains cross-shard inboxes and runs the `flush_*` passes,
    /// 4. reaps connections that are closing and fully drained,
    /// 5. runs timeout scans and the periodic shard tick, and
    /// 6. sleeps 200µs per iteration once `URING_SPIN_LIMIT` consecutive idle
    ///    iterations have passed.
    ///
    /// # Errors
    /// Returns an error if AOF replay fails, if the ring or the provided-buffer ring
    /// cannot be created, or if submitting to the ring fails.
    pub(crate) fn run_uring(mut self, stop: Arc<AtomicBool>) -> io::Result<()> {
        let snap = self.snapshot_path();
        if snap.exists()
            && let Err(e) = load_snapshot(&mut self.store, &snap)
        {
            eprintln!("kevy: shard {} failed to load {}: {e}", self.id, snap.display());
        }
        if self.aof.is_some() {
            let aof_path = self.aof_path();
            let commands = &self.commands;
            let store = &mut self.store;
            replay_aof(&aof_path, |args| {
                commands.dispatch(store, &args);
            })?;
        }
        let mut ring = IoUring::new(URING_ENTRIES)?;
        let mut pbuf = ring.register_buf_ring(PBUF_ENTRIES, PBUF_SIZE, PBUF_GROUP)?;
        // io_uring bookkeeping per connection, keyed by the same id as `self.conns`.
        let mut io: KevyMap<u64, UringConn> = KevyMap::new();
        let mut accept_inflight = false;
        // Scratch vectors reused across iterations.
        let mut comps: Vec<Completion> = Vec::with_capacity(URING_ENTRIES as usize);
        let mut cids: Vec<u64> = Vec::new();
        let mut idle_spins: u32 = 0;
        let mut tick_interval = match self.commands.shard_tick_interval_ms() {
            0 => None,
            ms => Some(Duration::from_millis(ms)),
        };
        let mut last_tick = Instant::now();
        let mut tick_check_counter: u32 = 0;
        while !stop.load(Ordering::Relaxed) {
            // Keep exactly one accept outstanding on the listener.
            if !accept_inflight {
                accept_inflight = ring.prep_accept(self.listener.raw(), OP_ACCEPT);
            }
            self.uring_arm_conns(&mut ring, &mut io, &mut cids, pbuf.group());
            // Submit queued entries without waiting, then collect whatever has completed.
            ring.submit_and_wait(0)?;
            comps.clear();
            ring.for_each_completion(|c| comps.push(c));
            for c in &comps {
                // Top bits of `user_data` are the op tag, the rest is the connection id.
                let op = c.user_data & !CONN_MASK;
                let cid = c.user_data & CONN_MASK;
                match op {
                    OP_ACCEPT => {
                        accept_inflight = false;
                        if c.res >= 0 {
                            // SAFETY: a non-negative accept result is a newly accepted
                            // descriptor owned by no one else; ownership moves into
                            // `Socket` here and nowhere else.
                            let sock = unsafe { Socket::from_raw_fd(c.res) };
                            let _ = sock.set_nodelay();
                            let ncid = self.next_conn_id;
                            self.next_conn_id += 1;
                            self.conns.insert(ncid, Conn::new(sock));
                            io.insert(ncid, UringConn::new());
                        }
                    }
                    OP_RECV => self.uring_on_recv(cid, c, &mut io, &mut pbuf),
                    OP_WRITE => self.uring_on_write(cid, c.res, &mut io),
                    _ => {}
                }
            }
            let did_inbound = self.uring_drain_inbound();
            self.dirty.clear();
            self.flush_backlog();
            self.flush_requests();
            self.flush_publish();
            self.flush_wakes();
            if let Some(aof) = &mut self.aof {
                let _ = aof.maybe_sync();
            }
            self.uring_reap_closed(&mut io);
            if let Some(iv) = tick_interval {
                // Timeout scans and the clock check only run every `tick_check_every`
                // iterations to keep `Instant::now()` off the hot path.
                tick_check_counter = tick_check_counter.wrapping_add(1);
                if tick_check_counter >= self.tick_check_every {
                    tick_check_counter = 0;
                    let now = Instant::now();
                    self.tick_blocked_timeouts();
                    self.tick_xshard_timeouts();
                    if now.duration_since(last_tick) >= iv {
                        self.commands.on_shard_tick(&mut self.store);
                        self.apply_live_runtime_config(&mut tick_interval);
                        self.maybe_auto_rewrite_aof();
                        last_tick = now;
                    }
                }
            }
            // Back off with short sleeps once the loop has been idle for a while.
            if comps.is_empty() && !did_inbound {
                idle_spins = idle_spins.saturating_add(1);
                if idle_spins >= URING_SPIN_LIMIT {
                    std::thread::sleep(Duration::from_micros(200));
                }
            } else {
                idle_spins = 0;
            }
        }
        Ok(())
    }

    /// Queue io_uring operations for every live connection.
    ///
    /// For each connection: when no write is in flight and the write slot is empty,
    /// the connection's pending `output` is swapped into the slot. Any unwritten tail
    /// of the slot is then queued as one write, and a multishot recv is (re)armed
    /// unless one is already active or the connection is closing. `cids` is scratch
    /// space reused between calls.
    fn uring_arm_conns(
        &mut self,
        ring: &mut IoUring,
        io: &mut KevyMap<u64, UringConn>,
        cids: &mut Vec<u64>,
        bgid: u16,
    ) {
        cids.clear();
        cids.extend(self.conns.keys().copied());
        for &cid in cids.iter() {
            let (Some(uc), Some(conn)) = (io.get_mut(&cid), self.conns.get_mut(&cid)) else {
                continue;
            };
            // Hand the pending replies to the kernel-owned slot. The swap also gives
            // `conn.output` the slot's (cleared) allocation back for reuse.
            if !uc.write_inflight && uc.write_buf.is_empty() && !conn.output.is_empty() {
                std::mem::swap(&mut uc.write_buf, &mut conn.output);
                uc.write_off = 0;
            }
            let fd = conn.sock.raw();
            if !uc.write_inflight && uc.write_off < uc.write_buf.len() {
                let pending = &uc.write_buf[uc.write_off..];
                // Clamp rather than truncate: a short write is resumed from
                // `write_off`, while a wrapped length could even become 0.
                let len = u32::try_from(pending.len()).unwrap_or(u32::MAX);
                // SAFETY: `pending` points into `uc.write_buf`'s heap allocation. That
                // allocation is not swapped, cleared, or freed while `write_inflight`
                // is set: the swap above, `uring_on_write`, and `uring_reap_closed`
                // all wait for the write to complete first. Moving the `UringConn`
                // inside the map does not move the Vec's heap buffer.
                let queued = unsafe { ring.prep_write(fd, pending.as_ptr(), len, OP_WRITE | cid) };
                if queued {
                    uc.write_inflight = true;
                }
            }
            if !uc.recv_armed && !uc.closing && ring.prep_recv_multishot(fd, bgid, OP_RECV | cid) {
                uc.recv_armed = true;
            }
        }
    }

    /// Handle one multishot-recv completion for connection `cid`.
    ///
    /// Clears `recv_armed` once the kernel reports that the multishot request ended.
    /// Marks the connection closing on EOF or on any error other than `ENOBUFS`; with
    /// `ENOBUFS` the recv is simply re-armed on the next iteration. Otherwise it
    /// copies the received bytes into the connection's input and always returns the
    /// provided buffer to the ring. It then executes every complete command now
    /// available.
    ///
    /// A parse error replies via `protocol_error` and marks the connection closing.
    /// Data for a connection already marked closing is discarded, so it is not
    /// re-parsed and does not produce further replies.
    fn uring_on_recv(
        &mut self,
        cid: u64,
        c: &Completion,
        io: &mut KevyMap<u64, UringConn>,
        pbuf: &mut ProvidedBufRing,
    ) {
        if !c.has_more()
            && let Some(uc) = io.get_mut(&cid)
        {
            uc.recv_armed = false;
        }
        if c.res <= 0 {
            if c.res != -ENOBUFS
                && let Some(uc) = io.get_mut(&cid)
            {
                uc.closing = true;
            }
            return;
        }
        let Some(bid) = c.buffer_id() else {
            return;
        };
        // `res > 0` was checked above: this is the byte count placed in buffer `bid`.
        let n = c.res as usize;
        let closing = io.get(&cid).is_some_and(|uc| uc.closing);
        let input = match self.conns.get_mut(&cid) {
            Some(conn) if !closing => {
                conn.input.extend_from_slice(pbuf.bytes(bid, n));
                Some(std::mem::take(&mut conn.input))
            }
            _ => None,
        };
        // The buffer goes back to the kernel on every path, including discarded data.
        pbuf.recycle(bid);
        let Some(mut input_buf) = input else {
            return;
        };
        // Offset of the first unparsed byte. The parsed prefix is removed once
        // after the loop instead of shifting the buffer after every command.
        let mut parsed = 0;
        let mut had_protocol_error = false;
        self.aof_begin_group();
        loop {
            let (argv, consumed) = match parse_command_borrowed(&input_buf[parsed..]) {
                Ok(Some(t)) => t,
                Ok(None) => break,
                Err(_) => {
                    had_protocol_error = true;
                    break;
                }
            };
            if let Some(key) = argv.get(1) {
                self.store.prefetch_for_key(key);
            }
            self.handle_command(cid, &argv);
            parsed += consumed;
            // The command may have closed the connection; its buffered input is moot.
            // NOTE(review): the matching `io` entry is not flagged closing here, so it
            // is only reaped if a later completion marks it — confirm whether that is
            // intended.
            if !self.conns.contains_key(&cid) {
                self.uring_aof_end_group();
                return;
            }
        }
        self.uring_aof_end_group();
        input_buf.drain(..parsed);
        if let Some(conn) = self.conns.get_mut(&cid) {
            conn.input = input_buf;
        }
        if had_protocol_error {
            self.protocol_error(cid);
            if let Some(uc) = io.get_mut(&cid) {
                uc.closing = true;
            }
        }
    }

    /// End the current AOF group. A sync failure is logged and not propagated, so
    /// request handling continues.
    #[inline]
    fn uring_aof_end_group(&mut self) {
        if let Err(e) = self.aof_end_group() {
            eprintln!("kevy: shard {} aof group sync failed: {e}", self.id);
        }
    }

    /// Handle a write completion for connection `cid`.
    ///
    /// Advances the write offset by the bytes written and empties the write slot once
    /// it is fully flushed.
    ///
    /// On error the connection is marked closing and the unsent remainder is
    /// discarded. Keeping it would make `uring_arm_conns` resubmit the failing write
    /// every iteration, and `uring_reap_closed` (which waits for an empty slot) would
    /// never close the connection.
    fn uring_on_write(&mut self, cid: u64, res: i32, io: &mut KevyMap<u64, UringConn>) {
        let Some(uc) = io.get_mut(&cid) else {
            return;
        };
        uc.write_inflight = false;
        if res < 0 {
            uc.closing = true;
            uc.write_buf.clear();
            uc.write_off = 0;
            return;
        }
        uc.write_off += res as usize;
        if uc.write_off >= uc.write_buf.len() {
            uc.write_buf.clear();
            uc.write_off = 0;
        }
    }

    /// Process every message waiting in the peer shards' inboxes.
    ///
    /// Requests are executed locally and answered via `send_to`; responses are folded
    /// into the waiting connection; publish and blocking-command messages go to their
    /// handlers. Batched requests run inside one AOF group.
    ///
    /// Returns `true` if at least one message was processed.
    fn uring_drain_inbound(&mut self) -> bool {
        let mut did = false;
        for src in 0..self.nshards {
            if src == self.id {
                continue;
            }
            while let Some(msg) = self.inboxes[src].as_mut().expect("peer inbox").pop() {
                did = true;
                match msg {
                    Inbound::Request { origin, conn, seq, op } => {
                        let part = self.exec_op(op);
                        self.send_to(origin, Inbound::Response { conn, seq, part });
                    }
                    Inbound::Response { conn, seq, part } => {
                        self.fold(conn, seq, part);
                    }
                    Inbound::RequestBatch { origin, reqs } => {
                        let mut resps = Vec::with_capacity(reqs.len());
                        self.aof_begin_group();
                        for (conn, seq, argv, proto) in reqs {
                            let part = self.exec_op(Op::Dispatch(argv, proto));
                            resps.push((conn, seq, part));
                        }
                        self.uring_aof_end_group();
                        self.send_to(origin, Inbound::ResponseBatch(resps));
                    }
                    Inbound::ResponseBatch(resps) => {
                        for (conn, seq, part) in resps {
                            self.fold(conn, seq, part);
                        }
                    }
                    Inbound::DeliverPublish(batch) => {
                        for m in &batch {
                            self.deliver_publish(&m.0, &m.1);
                        }
                    }
                    Inbound::BlockArm {
                        origin,
                        conn,
                        key,
                        kind,
                        serve_argv,
                        proto,
                    } => self.target_arm(origin, conn, key, kind, serve_argv, proto),
                    Inbound::BlockReady { conn, key } => self.origin_on_ready(conn, &key),
                    Inbound::BlockServeReq { origin, conn, key } => {
                        let reply = self.target_serve(origin, conn, &key);
                        self.send_to(origin, Inbound::BlockServeResp { conn, key, reply });
                    }
                    Inbound::BlockServeResp { conn, key, reply } => {
                        self.origin_on_serve_resp(conn, key, reply);
                    }
                    Inbound::BlockCancel { origin, conn } => self.target_cancel(origin, conn),
                }
            }
        }
        did
    }

    /// Close and forget every connection that is closing and fully drained.
    ///
    /// A connection qualifies when it is flagged closing (either on its `UringConn`
    /// or on its `Conn`), no write is in flight, its write slot is empty, and its
    /// `Conn` (if still present) has no output, no pending entries, and `write_pos == 0`.
    ///
    /// NOTE(review): an armed multishot recv is not cancelled here. io_uring requests
    /// hold their own file reference, so closing the fd may not tear the socket down
    /// until that recv ends. Confirm that `close_conn` shuts the socket down, or
    /// issue a cancel.
    fn uring_reap_closed(&mut self, io: &mut KevyMap<u64, UringConn>) {
        let done: Vec<u64> = io
            .iter()
            .filter(|(cid, uc)| {
                let conn = self.conns.get(cid);
                let drained = conn.is_none_or(|c| {
                    c.output.is_empty() && c.pending.is_empty() && c.write_pos == 0
                });
                let closing = uc.closing || conn.is_some_and(|c| c.closing);
                closing && !uc.write_inflight && uc.write_buf.is_empty() && drained
            })
            .map(|(&cid, _)| cid)
            .collect();
        for cid in done {
            self.close_conn(cid);
            io.remove(&cid);
        }
    }
}