#![forbid(unsafe_code)]
use core::{ffi::CStr, marker::PhantomData, ops::Deref};
use buggy::BugExt as _;
use cfg_if::cfg_if;
use ciborium as cbor;
use ciborium_io::{Read, Write};
use rustix::{
fd::{AsFd, BorrowedFd, OwnedFd},
fs::{self, AtFlags, FlockOperation, Mode, OFlags},
io::{self, Errno},
path::Arg,
};
use spideroak_base58::{String32, ToBase58 as _};
use super::error::{Error, RootDeleted, UnexpectedEof};
use crate::{
BaseId, KeyStore,
engine::WrappedKey,
keystore::{Entry, Occupied, Vacant},
};
pub struct Store {
root: OwnedFd,
}
impl Store {
const fn new(dir: OwnedFd) -> Self {
Self { root: dir }
}
pub fn open(path: impl Arg) -> Result<Self, Error> {
let fd = fs::open(
path,
OFlags::DIRECTORY | OFlags::RDONLY | OFlags::CLOEXEC,
Mode::empty(),
)?;
Self::init_canary(fd.as_fd())?;
Ok(Self::new(fd))
}
pub fn try_clone(&self) -> Result<Self, Error> {
let root = self.root.try_clone().or_else(|err| {
#[allow(clippy::useless_conversion, reason = "depends on cfg")]
let raw: Option<i32> = err.raw_os_error().into();
let raw = raw.assume("should have a raw OS error")?;
Err(Error::from(Errno::from_raw_os_error(raw)))
})?;
Self::init_canary(root.as_fd())?;
Ok(Self::new(root))
}
fn alias(&self, id: BaseId) -> Alias {
Alias(id.to_base58())
}
fn init_canary(fd: BorrowedFd<'_>) -> Result<(), Error> {
if !cfg!(debug_assertions) {
return Ok(());
}
fs::openat(
fd,
"__canary",
OFlags::CREATE | OFlags::RDWR | OFlags::CLOEXEC,
Mode::RUSR | Mode::WUSR, )?;
Ok(())
}
fn check_canary(&self) -> Result<(), Error> {
if !cfg!(debug_assertions) {
return Ok(());
}
match fs::statat(&self.root, "__canary", AtFlags::empty()) {
Err(Errno::NOENT) => Err(RootDeleted(()).into()),
_ => Ok(()),
}
}
}
impl KeyStore for Store {
    type Error = Error;
    type Vacant<'a, T: WrappedKey> = VacantEntry<'a, T>;
    type Occupied<'a, T: WrappedKey> = OccupiedEntry<'a, T>;

    /// Returns the entry for `id`, holding an exclusive lock on its file.
    ///
    /// - If the key file already exists, it is opened and returned as an
    ///   [`Entry::Occupied`].
    /// - Otherwise a new, empty file is created with `O_EXCL` and returned
    ///   as an [`Entry::Vacant`].
    ///
    /// # Errors
    ///
    /// Returns any errno from opening, creating or locking the file, other
    /// than the two "lost the race" cases that are retried.
    fn entry<T: WrappedKey>(&mut self, id: BaseId) -> Result<Entry<'_, Self, T>, Self::Error> {
        let alias = self.alias(id);
        let entry = loop {
            // Opening with `O_CREAT` alone would not tell us whether we
            // created the file, so first try to open an existing file, and
            // only if that fails try to exclusively create a new one.
            match Exclusive::openat(&self.root, &*alias) {
                Ok(fd) => {
                    break Entry::Occupied(OccupiedEntry::new(self.root.as_fd(), fd, alias));
                }
                Err(Errno::NOENT) => {
                    // The file does not exist (yet); try to create it.
                }
                Err(err) => return Err(err.into()),
            }
            match Exclusive::create_new(&self.root, &*alias) {
                Ok(fd) => {
                    break Entry::Vacant(VacantEntry::new(self.root.as_fd(), fd, alias));
                }
                Err(Errno::EXIST) => {
                    // Someone else created the file between our open and
                    // create attempts; loop around and open theirs instead.
                    // (Retrying on `ENOENT` here instead would spin forever
                    // when creation is impossible, and would fail on this race.)
                }
                Err(err) => return Err(err.into()),
            }
        };
        Ok(entry)
    }

    /// Reads the key stored under `id`, holding a shared lock while reading.
    ///
    /// Returns `Ok(None)` when no file exists for `id`. In debug builds this
    /// first checks the root-directory canary and reports `RootDeleted` if
    /// it is missing.
    fn get<T: WrappedKey>(&self, id: BaseId) -> Result<Option<T>, Self::Error> {
        match Shared::openat(&self.root, &*self.alias(id)) {
            Ok(fd) => Ok(cbor::from_reader(fd)?),
            Err(Errno::NOENT) => {
                self.check_canary()?;
                Ok(None)
            }
            Err(err) => Err(err.into()),
        }
    }
}
/// File name under which a key is stored: the Base58 encoding of its id.
struct Alias(String32);

impl Deref for Alias {
    type Target = CStr;

    /// Exposes the encoded name as a C string for the `*at` syscalls.
    fn deref(&self) -> &CStr {
        let Self(encoded) = self;
        encoded.as_cstr()
    }
}
/// A freshly created, still-empty key file that is exclusively locked.
///
/// If the entry is dropped without [`Vacant::insert`] succeeding, the
/// placeholder file is unlinked again.
pub struct VacantEntry<'a, T> {
    /// Root directory the file lives in.
    root: BorrowedFd<'a>,
    /// Locked descriptor for the newly created file.
    fd: Exclusive,
    /// File name within `root`.
    alias: Alias,
    /// Set once a key has been written and synced; suppresses the unlink in `Drop`.
    dirty: bool,
    _t: PhantomData<T>,
}

impl<'a, T> VacantEntry<'a, T> {
    /// Wraps a newly created, locked file; initially not dirty.
    const fn new(root: BorrowedFd<'a>, fd: Exclusive, alias: Alias) -> Self {
        Self {
            _t: PhantomData,
            dirty: false,
            alias,
            fd,
            root,
        }
    }
}

impl<T: WrappedKey> Vacant<T> for VacantEntry<'_, T> {
    type Error = Error;

    /// CBOR-encodes `key` into the file and syncs it to disk.
    ///
    /// The file is kept only if both the write and the sync succeed.
    fn insert(mut self, key: T) -> Result<(), Self::Error> {
        // The file was created with `O_EXCL`, so nothing can be in it yet.
        debug_assert_eq!(self.fd.fstat()?.st_size, 0);
        let out = &self.fd;
        cbor::into_writer(&key, out)?;
        out.fsync()?;
        self.dirty = true;
        Ok(())
    }
}

impl<T> Drop for VacantEntry<'_, T> {
    fn drop(&mut self) {
        if self.dirty {
            return;
        }
        // Nothing was stored: remove the placeholder we created. This is
        // best-effort, because `drop` has no way to report a failure. It
        // runs before `fd` is dropped, i.e. while the lock is still held.
        let _ = fs::unlinkat(self.root, &*self.alias, AtFlags::empty());
    }
}
/// An existing key file that is exclusively locked.
pub struct OccupiedEntry<'a, T> {
    /// Root directory the file lives in.
    root: BorrowedFd<'a>,
    /// Locked descriptor for the key file.
    fd: Exclusive,
    /// File name within `root`.
    alias: Alias,
    _t: PhantomData<T>,
}

impl<'a, T> OccupiedEntry<'a, T> {
    /// Wraps an existing, locked key file.
    const fn new(root: BorrowedFd<'a>, fd: Exclusive, alias: Alias) -> Self {
        Self {
            root,
            fd,
            alias,
            _t: PhantomData,
        }
    }
}

impl<T: WrappedKey> Occupied<T> for OccupiedEntry<'_, T> {
    type Error = Error;

    /// Decodes the stored key.
    ///
    /// This can be called any number of times: it always reads from the
    /// start of the file.
    fn get(&self) -> Result<T, Self::Error> {
        // Reads advance the descriptor's file offset. Without rewinding, a
        // second `get` (or `get` followed by `remove`) would start mid-file
        // and fail to decode.
        fs::seek(&self.fd.0, fs::SeekFrom::Start(0))?;
        Ok(cbor::from_reader(&self.fd)?)
    }

    /// Decodes the stored key, then deletes its file.
    ///
    /// # Errors
    ///
    /// If decoding fails, the file is left in place so the key is not lost.
    /// If the unlink fails, the key remains stored and the error is returned.
    fn remove(self) -> Result<T, Self::Error> {
        let key = self.get()?;
        fs::unlinkat(self.root, &*self.alias, AtFlags::empty())?;
        Ok(key)
    }
}
struct Exclusive(OwnedFd);
impl Exclusive {
fn openat(dir: impl AsFd, path: impl Arg) -> io::Result<Self> {
let fd = fs::openat(dir, path, OFlags::RDWR | OFlags::CLOEXEC, Mode::empty())?;
fs::flock(&fd, FlockOperation::LockExclusive)?;
Ok(Self(fd))
}
fn create_new(dir: impl AsFd, path: impl Arg) -> io::Result<Self> {
let fd = fs::openat(
dir,
path,
OFlags::CREATE | OFlags::EXCL | OFlags::RDWR | OFlags::CLOEXEC,
Mode::RUSR | Mode::WUSR, )?;
fs::flock(&fd, FlockOperation::LockExclusive)?;
Ok(Self(fd))
}
fn fstat(&self) -> io::Result<fs::Stat> {
fs::fstat(&self.0)
}
fn fsync(&self) -> io::Result<()> {
cfg_if! {
if #[cfg(any(target_os = "macos", target_os = "ios"))] {
fs::fcntl_fullfsync(&self.0)?;
} else {
fs::fdatasync(&self.0)?;
}
}
Ok(())
}
}
impl Write for Exclusive {
type Error = Error;
fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
(&*self).write_all(buf)
}
fn flush(&mut self) -> Result<(), Self::Error> {
(&*self).flush()
}
}
impl Write for &Exclusive {
type Error = Error;
fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
write_all(self.0.as_fd(), buf)
}
fn flush(&mut self) -> Result<(), Self::Error> {
Ok(())
}
}
impl Read for &Exclusive {
type Error = Error;
fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
read_exact(self.0.as_fd(), buf)
}
}
struct Shared(OwnedFd);
impl Shared {
fn openat(dir: impl AsFd, path: impl Arg) -> io::Result<Self> {
let fd = fs::openat(dir, path, OFlags::RDONLY | OFlags::CLOEXEC, Mode::empty())?;
fs::flock(&fd, FlockOperation::LockShared)?;
Ok(Self(fd))
}
}
impl Read for Shared {
type Error = Error;
fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
(&*self).read_exact(buf)
}
}
impl Read for &Shared {
type Error = Error;
fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
read_exact(self.0.as_fd(), buf)
}
}
/// Fills all of `buf` from `fd`, retrying on `EINTR` and short reads.
///
/// # Errors
///
/// Returns [`UnexpectedEof`] if end-of-file is reached before `buf` is
/// full, or the underlying errno for any other read failure.
fn read_exact(fd: BorrowedFd<'_>, mut buf: &mut [u8]) -> Result<(), Error> {
    loop {
        if buf.is_empty() {
            return Ok(());
        }
        match io::read(fd, buf) {
            Ok(0) => return Err(UnexpectedEof.into()),
            Ok(n) => buf = &mut buf[n..],
            Err(Errno::INTR) => continue,
            Err(e) => return Err(e.into()),
        }
    }
}
/// Writes all of `buf` to `fd`, retrying on `EINTR` and short writes.
///
/// # Errors
///
/// Returns the underlying errno for a failed write. Returns `EIO` if a
/// write makes no progress (returns 0) while bytes remain: the old code
/// reported success in that case, which silently truncated the output.
fn write_all(fd: BorrowedFd<'_>, mut buf: &[u8]) -> Result<(), Error> {
    while !buf.is_empty() {
        match io::write(fd, buf) {
            // No progress: surface it instead of claiming a complete write.
            Ok(0) => return Err(Errno::IO.into()),
            Ok(n) => buf = &buf[n..],
            Err(Errno::INTR) => {}
            Err(e) => return Err(e.into()),
        }
    }
    Ok(())
}