use std::sync::{Arc, Mutex, OnceLock, Weak};
#[cfg(unix)]
use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
use crate::db::Db;
/// Process-wide list of weak handles to every `Db` that asked for an exit flush.
///
/// Entries are weak so that being registered never keeps a database alive.
static REGISTRY: OnceLock<Mutex<Vec<Weak<Db>>>> = OnceLock::new();

/// Returns the global registry, creating it (empty) on first use.
fn registry() -> &'static Mutex<Vec<Weak<Db>>> {
    REGISTRY.get_or_init(Mutex::default)
}
fn flush_all_registered() {
if let Some(reg) = REGISTRY.get() {
let mut guard = match reg.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(), };
guard.retain(|w| w.strong_count() > 0);
for w in guard.iter() {
if let Some(db) = w.upgrade() {
db.flush_all();
}
}
}
}
impl Db {
    /// Registers `self_arc` for a best-effort flush when the process receives
    /// a termination signal, and makes sure the signal hook is installed.
    ///
    /// * A database whose root is `:memory:` is skipped entirely.
    /// * Registering the same `Arc` more than once is a no-op.
    /// * Only a `Weak` handle is stored, so registration never extends the
    ///   database's lifetime.
    pub fn install_exit_flush(self_arc: Arc<Db>) {
        if self_arc.root == std::path::PathBuf::from(":memory:") {
            return;
        }
        {
            let mut reg = registry()
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());

            // Drop entries whose database is gone. Without this the list grows
            // for every database ever opened, and each stale `Weak` keeps the
            // corresponding `Arc` allocation from being freed.
            reg.retain(|w| w.strong_count() > 0);

            // A `Weak` keeps its allocation reserved, so its address cannot be
            // reused while it is in the list; pointer equality is therefore an
            // exact identity check and avoids upgrading every entry.
            let target = Arc::as_ptr(&self_arc);
            let already = reg.iter().any(|w| std::ptr::eq(w.as_ptr(), target));
            if !already {
                reg.push(Arc::downgrade(&self_arc));
            }
        }
        // The registry lock is released before touching process-wide signal state.
        install_signal_handler_once();
    }
}
/// Guards the one-time signal setup: `install_signal_handler_once` swaps this
/// to `true` and returns early when it was already set; it is put back to
/// `false` when setup cannot be completed (e.g. the pipe cannot be created).
#[cfg(unix)]
static INSTALLED: AtomicBool = AtomicBool::new(false);
/// Write end of the pipe that `handler` writes the signal number into.
/// Holds `-1` until the pipe has been created, which `handler` treats as
/// "nothing to notify".
#[cfg(unix)]
static PIPE_WRITE_FD: AtomicI32 = AtomicI32::new(-1);
/// Signal handler: forwards the signal number, as a single byte, through the
/// pipe whose write end is stored in `PIPE_WRITE_FD`.
///
/// It only performs an atomic load and one `write(2)` call; the real work is
/// left to whoever reads the other end of the pipe. If no pipe has been
/// published yet (`-1`), the signal is silently dropped.
// NOTE(review): `write` may modify `errno` of the interrupted thread; it is not
// saved/restored here — confirm whether that matters for this program.
#[cfg(unix)]
extern "C" fn handler(sig: libc::c_int) {
    let write_fd = PIPE_WRITE_FD.load(Ordering::SeqCst);
    if write_fd < 0 {
        return;
    }
    // Only the low byte of the signal number is transmitted.
    let payload: u8 = sig as u8;
    // SAFETY: `payload` is a live one-byte local for the duration of the call.
    // The result is deliberately ignored.
    unsafe {
        let _ = libc::write(
            write_fd,
            (&payload as *const u8).cast::<libc::c_void>(),
            1,
        );
    }
}
/// Installs the process-wide SIGINT/SIGTERM hook; only the first successful
/// call does any work.
///
/// Setup order matters:
/// 1. create the pipe (write end non-blocking, both ends close-on-exec),
/// 2. start the `nedb-exit-flush` thread that runs `reader_loop` on the read end,
/// 3. publish the write end in `PIPE_WRITE_FD`,
/// 4. point SIGINT and SIGTERM at `handler`.
///
/// The handlers are installed last so that a signal is never captured while
/// nothing is draining the pipe. If the pipe or the thread cannot be created,
/// everything is rolled back and `INSTALLED` is reset, leaving the default
/// signal behaviour intact and allowing a later call to retry.
#[cfg(unix)]
fn install_signal_handler_once() {
    if INSTALLED.swap(true, Ordering::SeqCst) {
        return;
    }

    let mut fds = [0i32; 2];
    // SAFETY: `fds` is a writable array of two `c_int`s, as pipe(2) requires.
    if unsafe { libc::pipe(fds.as_mut_ptr()) } != 0 {
        INSTALLED.store(false, Ordering::SeqCst);
        return;
    }
    let (read_fd, write_fd) = (fds[0], fds[1]);

    // SAFETY: both descriptors were just returned by pipe(2) and are not yet
    // shared with any other code.
    unsafe {
        // Keep the pipe out of child processes started via exec.
        for fd in [read_fd, write_fd] {
            let fd_flags = libc::fcntl(fd, libc::F_GETFD);
            if fd_flags != -1 {
                libc::fcntl(fd, libc::F_SETFD, fd_flags | libc::FD_CLOEXEC);
            }
        }
        // The handler must never block, so the write end is non-blocking.
        let status_flags = libc::fcntl(write_fd, libc::F_GETFL);
        if status_flags != -1 {
            libc::fcntl(write_fd, libc::F_SETFL, status_flags | libc::O_NONBLOCK);
        }
    }

    // Start the reader before any signal can be redirected to the pipe.
    let reader = std::thread::Builder::new()
        .name("nedb-exit-flush".into())
        .spawn(move || {
            reader_loop(read_fd);
        });
    if reader.is_err() {
        // Without a reader the handler would swallow SIGINT/SIGTERM forever,
        // so undo the setup and keep the default dispositions.
        // SAFETY: the descriptors are still exclusively owned here; the
        // closure that captured `read_fd` never ran.
        unsafe {
            libc::close(read_fd);
            libc::close(write_fd);
        }
        INSTALLED.store(false, Ordering::SeqCst);
        return;
    }

    PIPE_WRITE_FD.store(write_fd, Ordering::SeqCst);

    // SAFETY: `sa` is fully initialised (zeroed, then handler/mask/flags set)
    // before being passed to sigaction(2); `handler` has the plain
    // `fn(c_int)` signature expected when `SA_SIGINFO` is not set.
    unsafe {
        let mut sa: libc::sigaction = std::mem::zeroed();
        sa.sa_sigaction = handler as extern "C" fn(libc::c_int) as libc::sighandler_t;
        libc::sigemptyset(&mut sa.sa_mask);
        sa.sa_flags = libc::SA_RESTART;
        libc::sigaction(libc::SIGINT, &sa, std::ptr::null_mut());
        libc::sigaction(libc::SIGTERM, &sa, std::ptr::null_mut());
    }
}
#[cfg(unix)]
fn reader_loop(read_fd: i32) -> ! {
let mut buf = [0u8; 1];
loop {
let n = unsafe { libc::read(read_fd, buf.as_mut_ptr() as *mut libc::c_void, 1) };
if n <= 0 {
continue; }
let sig = buf[0] as libc::c_int;
flush_all_registered();
unsafe {
let mut sa: libc::sigaction = std::mem::zeroed();
sa.sa_sigaction = libc::SIG_DFL;
libc::sigemptyset(&mut sa.sa_mask);
libc::sigaction(sig, &sa, std::ptr::null_mut());
libc::raise(sig);
}
std::process::exit(128 + sig);
}
}
/// Non-Unix counterpart of the signal-hook installer: intentionally a no-op,
/// so `Db::install_exit_flush` can call it unconditionally on every platform.
#[cfg(not(unix))]
fn install_signal_handler_once() {
    // No signal handler is installed on this platform.
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Counts the live registry entries that point at exactly `target`.
    fn registrations_of(target: &Arc<Db>) -> usize {
        let guard = registry().lock().unwrap();
        guard
            .iter()
            .filter_map(|weak| weak.upgrade())
            .filter(|live| Arc::ptr_eq(live, target))
            .count()
    }

    /// An in-memory database must never end up in the registry.
    #[test]
    fn in_memory_is_not_registered() {
        let mem_db = Arc::new(Db::in_memory());
        Db::install_exit_flush(Arc::clone(&mem_db));

        let found = registrations_of(&mem_db) != 0;
        assert!(!found, ":memory: db must not be registered");
    }

    /// Installing twice for the same handle leaves a single registry entry.
    #[test]
    fn durable_registration_is_idempotent() {
        let scratch = tempfile::tempdir().unwrap();
        let handle = Arc::new(Db::open(scratch.path(), None).unwrap());

        Db::install_exit_flush(Arc::clone(&handle));
        assert_eq!(registrations_of(&handle), 1, "first install registers exactly once");

        Db::install_exit_flush(Arc::clone(&handle));
        assert_eq!(registrations_of(&handle), 1, "second install does not duplicate");
    }

    /// A write followed by the registry-wide flush is visible after reopening.
    #[test]
    fn registered_flush_makes_writes_durable_on_reopen() {
        let scratch = tempfile::tempdir().unwrap();
        let root = scratch.path().to_path_buf();

        // The writer handle is dropped at the end of this scope, before reopening.
        {
            let writer = Arc::new(Db::open(&root, None).unwrap());
            Db::install_exit_flush(Arc::clone(&writer));
            writer
                .put("k", "v1", serde_json::json!({ "n": 1 }), vec![], None, None)
                .unwrap();
            flush_all_registered();
        }

        let reopened = Db::open(&root, None).unwrap();
        let record = reopened.get("k", "v1");
        assert!(record.is_some(), "write must survive flush + reopen");
        let n = record.unwrap().data.get("n").and_then(|v| v.as_i64());
        assert_eq!(n, Some(1));
    }
}