#![forbid(unsafe_code)]
pub mod cache;
pub mod checksum;
pub mod freelist;
pub mod header;
pub mod page;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use crate::pager::cache::{Cache, Evicted};
use crate::pager::checksum::{
page_trailer_valid, page_trailer_valid_v1, write_page_trailer, write_page_trailer_v1,
};
use crate::pager::freelist::{
decode as decode_freelist_page, encode as encode_freelist_page, FreeListPage,
};
use crate::pager::header::{
decode_header, encode_header, FileHeader, FEATURE_FLAG_COMPRESSION, FEATURE_FLAG_ENCRYPTION,
};
use crate::pager::page::{Page, PageId, ENCRYPTION_OVERHEAD, PAGE_SIZE, PAGE_TRAILER_SIZE};
use crate::platform::{FileBackend, FileHandle, SyncMode};
use crate::wal::{Lsn, Wal, WalConfig};
pub use crate::pager::page::PAGE_SIZE as PAGER_PAGE_SIZE;
#[derive(Debug)]
pub struct PageRef<'a> {
page_id: PageId,
bytes: &'a Page,
}
impl<'a> PageRef<'a> {
fn new(page_id: PageId, bytes: &'a Page) -> Self {
Self { page_id, bytes }
}
#[must_use]
pub fn page_id(&self) -> PageId {
self.page_id
}
#[must_use]
pub fn as_bytes(&self) -> &'a [u8; PAGE_SIZE] {
self.bytes.as_bytes()
}
#[must_use]
pub fn to_owned_page(&self) -> Page {
self.bytes.clone()
}
}
/// Point-in-time copy of the pager's mutable header fields and WAL view.
///
/// Built by `Pager::header_snapshot` and consumed by
/// `Pager::restore_header_snapshot`.
#[derive(Debug, Clone)]
pub struct HeaderSnapshot {
/// Copy of `header.root_catalog` at capture time.
pub root_catalog: u64,
/// Copy of `header.freelist_head` at capture time.
pub freelist_head: u64,
/// Copy of `header.page_count` at capture time.
pub page_count: u64,
/// Clone of the WAL overlay's `view` map; empty when the pager has no WAL.
pub view: HashMap<PageId, Arc<Page>>,
}
/// Identifier of a reader snapshot: a transparent wrapper around a raw `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[repr(transparent)]
#[serde(transparent)]
pub struct SnapshotId(u64);

impl SnapshotId {
    /// Wraps a raw id value.
    #[must_use]
    pub const fn new(raw: u64) -> Self {
        SnapshotId(raw)
    }

    /// Returns the wrapped raw value.
    #[must_use]
    pub const fn get(self) -> u64 {
        let SnapshotId(raw) = self;
        raw
    }
}
/// Registration of one reader snapshot in the shared pin table.
///
/// Dropping the pin removes its entry from the table again.
#[derive(Debug)]
struct SnapshotPin {
    /// Key of this pin inside `map`.
    id: SnapshotId,
    /// Shared table mapping snapshot ids to their pinned LSN.
    map: Arc<Mutex<HashMap<SnapshotId, Lsn>>>,
}

impl Drop for SnapshotPin {
    fn drop(&mut self) {
        // Best effort: a poisoned mutex is ignored instead of panicking in drop.
        let Ok(mut pins) = self.map.lock() else {
            return;
        };
        pins.remove(&self.id);
    }
}
/// A page handed to a reader: either shared with other holders through an
/// `Arc`, or owned outright by the handle.
#[derive(Debug, Clone)]
pub enum PageHandle {
    /// Reference-counted page that may be shared with other holders.
    Shared(Arc<Page>),
    /// Page owned exclusively by this handle.
    Owned(Page),
}

impl PageHandle {
    /// Borrows the raw bytes of the underlying page, whichever variant holds it.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8; PAGE_SIZE] {
        match self {
            PageHandle::Shared(page) => page.as_bytes(),
            PageHandle::Owned(page) => page.as_bytes(),
        }
    }

    /// Converts the handle into an owned [`Page`].
    ///
    /// A `Shared` page is moved out of its `Arc` without copying when this
    /// handle holds the last strong reference; it is cloned only when other
    /// holders still exist. An `Owned` page is returned as-is.
    #[must_use]
    pub fn into_page(self) -> Page {
        match self {
            PageHandle::Shared(page) => {
                // `try_unwrap` hands the `Arc` back on failure, so the copy
                // is only paid while the page is still shared.
                Arc::try_unwrap(page).unwrap_or_else(|still_shared| (*still_shared).clone())
            }
            PageHandle::Owned(page) => page,
        }
    }
}
/// A reader's frozen view of the database, created by `Pager::reader_snapshot`.
///
/// Holds clones of the WAL overlay's page map and header page as they were
/// at creation time, plus a pin that is released when the snapshot is dropped.
#[derive(Debug)]
pub struct ReaderSnapshot<F: FileBackend> {
    /// LSN recorded in the pin table for this snapshot.
    pinned_lsn: Lsn,
    /// Clone of the WAL overlay's `view` map taken at creation.
    frozen_view: HashMap<PageId, Arc<Page>>,
    /// Clone of the WAL overlay's `view_header` taken at creation, if any.
    frozen_header: Option<Page>,
    /// Root catalog value captured at creation.
    root_catalog: u64,
    /// Pin-table registration; removed from the table on drop.
    pin: SnapshotPin,
    /// Ties the snapshot to backend type `F` without storing one.
    _phantom: std::marker::PhantomData<fn() -> F>,
}

impl<F: FileBackend> ReaderSnapshot<F> {
    /// Returns the LSN this snapshot pinned.
    #[must_use]
    pub fn pinned_lsn(&self) -> Lsn {
        self.pinned_lsn
    }

    /// Iterates over the frozen `(id, page)` pairs, in the map's arbitrary order.
    pub fn frozen_pages(&self) -> impl Iterator<Item = (PageId, &Page)> + '_ {
        self.frozen_view
            .iter()
            .map(|(page_id, shared)| (*page_id, shared.as_ref()))
    }

    /// Returns the frozen header page, if one was captured.
    #[must_use]
    pub fn frozen_header(&self) -> Option<&Page> {
        self.frozen_header.as_ref()
    }

    /// Returns the snapshot's id in the pin table.
    #[must_use]
    pub fn id(&self) -> SnapshotId {
        self.pin.id
    }

    /// Returns the root catalog value captured when the snapshot was taken.
    #[must_use]
    pub fn root_catalog(&self) -> u64 {
        self.root_catalog
    }

    /// Reads page `id` as seen by this snapshot.
    ///
    /// A page present in the frozen view is returned as a shared handle.
    /// Otherwise a memory-backed pager is read through its cache, and a
    /// file-backed pager is read from the main file. Ids at or beyond the
    /// main file's physical page count yield a zeroed page.
    ///
    /// # Errors
    /// Propagates any error from the pager's read paths.
    pub fn read_page(&self, pager: &Pager<F>, id: PageId) -> Result<PageHandle> {
        if let Some(frozen) = self.frozen_view.get(&id) {
            return Ok(PageHandle::Shared(Arc::clone(frozen)));
        }
        let page = if pager.is_memory_backed() {
            pager.read_cache_or_main(id)?
        } else if id.get() >= pager.main_physical_page_count()? {
            Page::zeroed()
        } else {
            pager.read_main_file_page(id)?
        };
        Ok(PageHandle::Owned(page))
    }
}
/// Number of page-cache frames used by `Config::default()`.
pub const DEFAULT_CACHE_FRAMES: usize = 64;
/// Page compression setting; selects which header variant is built for a
/// newly created database.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum CompressionMode {
/// No compression (the default).
#[default]
Off,
/// LZ4 compression; rejected with `FormatFeatureUnsupported` when the
/// `compression` cargo feature is not compiled in.
Lz4,
}
/// Storage type for a 32-byte key when the `encryption` feature is enabled:
/// the bytes are wrapped in `zeroize::Zeroizing` (which, per the `zeroize`
/// crate, wipes them on drop).
#[cfg(feature = "encryption")]
pub type MasterKeyBytes = zeroize::Zeroizing<[u8; 32]>;
/// Storage type for a 32-byte key when the `encryption` feature is disabled:
/// a plain array.
#[cfg(not(feature = "encryption"))]
pub type MasterKeyBytes = [u8; 32];
/// Converts raw key bytes into [`MasterKeyBytes`] (feature-enabled variant:
/// wraps them in `Zeroizing`).
#[cfg(feature = "encryption")]
#[inline]
#[allow(dead_code)] pub(crate) fn wrap_master_key(bytes: [u8; 32]) -> MasterKeyBytes {
zeroize::Zeroizing::new(bytes)
}
/// Converts raw key bytes into [`MasterKeyBytes`] (feature-disabled variant:
/// identity, since the alias is the plain array).
#[cfg(not(feature = "encryption"))]
#[inline]
#[allow(dead_code)] pub(crate) fn wrap_master_key(bytes: [u8; 32]) -> MasterKeyBytes {
bytes
}
/// Pager configuration.
///
/// `Copy` is derived only when the `encryption` feature is off, because the
/// feature-enabled key type is not `Copy`.
///
/// NOTE(review): the derived `Serialize` impl includes `encryption_key`,
/// whereas the hand-written `Debug` impl redacts it — confirm that
/// serialising the key material is intended.
#[cfg_attr(not(feature = "encryption"), derive(Copy))]
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct Config {
/// Number of page-cache frames; constructors reject zero.
pub cache_frames: usize,
/// Sync mode passed to the main file's `sync_data` calls and to the WAL config.
pub sync_mode: SyncMode,
/// Forwarded to the WAL as `WalConfig::size_limit`.
pub wal_size_limit: u64,
/// `commit` triggers a checkpoint once the WAL's committed frame count
/// reaches this value; also forwarded to the WAL config.
pub checkpoint_threshold: u64,
/// Compression setting used when building the header of a new database.
pub compression_mode: CompressionMode,
/// Optional 32-byte master key; `None` means no encryption key was supplied.
pub encryption_key: Option<MasterKeyBytes>,
}
/// Debug output that never prints key material: the `encryption_key` field is
/// rendered only as `"<set>"` or `"<not set>"`.
impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redacted stand-in for the key field.
        let key_state: &str = if self.encryption_key.is_some() {
            "<set>"
        } else {
            "<not set>"
        };
        let mut out = f.debug_struct("Config");
        out.field("cache_frames", &self.cache_frames);
        out.field("sync_mode", &self.sync_mode);
        out.field("wal_size_limit", &self.wal_size_limit);
        out.field("checkpoint_threshold", &self.checkpoint_threshold);
        out.field("compression_mode", &self.compression_mode);
        out.field("encryption_key", &key_state);
        out.finish()
    }
}
impl Default for Config {
fn default() -> Self {
Self {
cache_frames: DEFAULT_CACHE_FRAMES,
sync_mode: SyncMode::Full,
wal_size_limit: crate::wal::DEFAULT_WAL_SIZE_LIMIT,
checkpoint_threshold: crate::wal::DEFAULT_CHECKPOINT_THRESHOLD,
compression_mode: CompressionMode::Off,
encryption_key: None,
}
}
}
impl Config {
    /// Returns the config with `cache_frames` replaced.
    ///
    /// # Errors
    /// `InvalidArgument` when `frames` is zero.
    pub fn with_cache_frames(mut self, frames: usize) -> Result<Self> {
        if frames == 0 {
            return Err(Error::InvalidArgument("cache_frames must be >= 1"));
        }
        self.cache_frames = frames;
        Ok(self)
    }

    /// Returns the config with `sync_mode` replaced.
    #[must_use]
    pub fn with_sync_mode(mut self, sync_mode: SyncMode) -> Self {
        self.sync_mode = sync_mode;
        self
    }

    /// Returns the config with `wal_size_limit` replaced.
    #[must_use]
    pub fn with_wal_size_limit(mut self, limit: u64) -> Self {
        self.wal_size_limit = limit;
        self
    }

    /// Returns the config with `checkpoint_threshold` replaced.
    #[must_use]
    pub fn with_checkpoint_threshold(mut self, frames: u64) -> Self {
        self.checkpoint_threshold = frames;
        self
    }

    /// Returns the config with `compression_mode` replaced.
    #[must_use]
    pub fn with_compression_mode(mut self, mode: CompressionMode) -> Self {
        self.compression_mode = mode;
        self
    }

    /// Returns the config with the encryption key replaced; `None` clears it.
    #[must_use]
    pub fn with_encryption_key(mut self, key: Option<[u8; 32]>) -> Self {
        self.encryption_key = key.map(wrap_master_key);
        self
    }

    /// Projects the WAL-related settings into a `WalConfig`.
    fn wal_config(&self) -> WalConfig {
        let Self {
            sync_mode,
            wal_size_limit,
            checkpoint_threshold,
            ..
        } = self;
        WalConfig {
            sync_mode: *sync_mode,
            size_limit: *wal_size_limit,
            checkpoint_threshold: *checkpoint_threshold,
        }
    }

    /// Borrows the master key as a plain 32-byte array, if one is set.
    fn master_key(&self) -> Option<&[u8; 32]> {
        match self.encryption_key.as_ref() {
            Some(stored) => {
                // The annotated binding coerces either key representation
                // (wrapped or plain) to `&[u8; 32]`.
                let raw: &[u8; 32] = stored;
                Some(raw)
            }
            None => None,
        }
    }
}
/// Page manager over a main store (file or in-memory buffer) with a page
/// cache and, for file-backed pagers, a WAL overlay.
#[derive(Debug)]
pub struct Pager<F: FileBackend = FileHandle> {
/// Main storage: a file handle or an in-memory byte buffer.
backend: Backend<F>,
/// In-memory copy of the file header; mutated by allocation, free and root updates.
header: FileHeader,
/// Page cache with `config.cache_frames` frames.
cache: Cache,
/// WAL overlay; `Some` for pagers built by `open_with_backends`, `None` for `memory`.
wal: Option<WalState<F>>,
/// Configuration the pager was opened with.
config: Config,
/// Pin table shared with live `ReaderSnapshot`s: snapshot id -> pinned LSN.
snapshots: Arc<Mutex<HashMap<SnapshotId, Lsn>>>,
/// Source of snapshot ids; starts at 1 and is incremented per snapshot.
next_snapshot_id: Arc<AtomicU64>,
/// Page encryption key derived at open; `None` for unencrypted databases.
derived_key: Option<PageEncryptionKey>,
/// Page count the main file physically covers (see `main_pages_for_len`);
/// refreshed at open and whenever a checkpoint grows the file.
main_high_water: u64,
}
/// Derived per-database page key. Its `Debug` impl prints a redacted
/// placeholder instead of the key bytes.
#[cfg_attr(not(feature = "encryption"), derive(Copy))]
#[derive(Clone)]
#[allow(dead_code)]
struct PageEncryptionKey(MasterKeyBytes);

#[allow(dead_code)]
impl PageEncryptionKey {
    /// Borrows the key as a plain 32-byte array.
    #[inline]
    fn as_bytes(&self) -> &[u8; 32] {
        // The annotated binding coerces either key representation to `&[u8; 32]`.
        let raw: &[u8; 32] = &self.0;
        raw
    }
}

impl std::fmt::Debug for PageEncryptionKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "PageEncryptionKey(<redacted>)")
    }
}
/// In-memory state of the WAL overlay for a file-backed pager.
#[derive(Debug)]
struct WalState<F: FileBackend> {
/// The write-ahead log itself.
wal: Wal<F>,
/// Pages staged by writes, frees and fresh allocations that are not yet
/// committed; moved into `view` by `commit_inner`, discarded by rollback.
pending: HashMap<PageId, Page>,
/// Committed page versions not yet checkpointed into the main file.
view: HashMap<PageId, Arc<Page>>,
/// Set when the in-memory header changed and must ride along with the next commit.
header_dirty: bool,
/// Header page of the latest commit that included the header, until a
/// checkpoint writes it to the main file.
view_header: Option<Page>,
/// `header.root_catalog` as of open or of the last commit that included the header.
committed_root_catalog: u64,
/// Nesting depth maintained by `begin_txn` / `end_txn`.
txn_depth: u32,
}
/// Main storage of a pager.
#[derive(Debug)]
enum Backend<F: FileBackend> {
/// Pages live in a file accessed through the backend handle.
File(F),
/// Pages live in an in-memory byte buffer; the first `PAGE_SIZE` bytes hold the header.
Memory(Vec<u8>),
}
impl Pager<FileHandle> {
pub fn open<P: AsRef<Path>>(path: P, config: Config) -> Result<Self> {
let main_path = path.as_ref().to_path_buf();
let main = FileHandle::open_or_create(&main_path)?;
let wal_path = wal_path_for(&main_path);
let wal = FileHandle::open_or_create(&wal_path)?;
Self::open_with_backends(main, wal, wal_path, config)
}
pub fn memory(config: Config) -> Result<Self> {
if config.cache_frames == 0 {
return Err(Error::InvalidArgument("cache_frames must be >= 1"));
}
refuse_compression_without_feature(config.compression_mode)?;
refuse_encryption_without_feature(config.encryption_key.is_some())?;
let header = build_new_file_header(config.compression_mode, config.master_key())?;
let mut bytes = vec![0u8; PAGE_SIZE];
let mut p = Page::zeroed();
encode_header(&header, &mut p);
bytes[..PAGE_SIZE].copy_from_slice(p.as_bytes());
let derived_key = derive_key_for_open(&config, &header)?;
Ok(Self {
backend: Backend::Memory(bytes),
header,
cache: Cache::new(config.cache_frames),
wal: None,
config,
snapshots: Arc::new(Mutex::new(HashMap::new())),
next_snapshot_id: Arc::new(AtomicU64::new(1)),
derived_key,
main_high_water: 0,
})
}
}
impl<F: FileBackend> Pager<F> {
/// Opens a pager over an already-opened main file and WAL file.
///
/// Steps, in order: validate `config`; create the header (when `main` is
/// empty) or load it; reject files whose feature flags this build cannot
/// handle; derive the page key; recover or create the WAL; assemble the
/// pager with a fresh cache and an empty pin table.
///
/// # Errors
/// `InvalidArgument` when `config.cache_frames` is zero; otherwise
/// propagates errors from the feature checks, header I/O, key derivation,
/// WAL recovery and the file-length query.
pub fn open_with_backends(
main: F,
wal: F,
wal_path: std::path::PathBuf,
config: Config,
) -> Result<Self> {
if config.cache_frames == 0 {
return Err(Error::InvalidArgument("cache_frames must be >= 1"));
}
refuse_compression_without_feature(config.compression_mode)?;
refuse_encryption_without_feature(config.encryption_key.is_some())?;
// An empty main file is a brand-new database: write a fresh header.
let mut header = if main.is_empty()? {
initialise_file(&main, config.compression_mode, config.master_key())?
} else {
load_header(&main)?
};
refuse_unsupported_features(&header)?;
let derived_key = derive_key_for_open(&config, &header)?;
// `header` is passed mutably, so recovery may update it.
let (wal_state, recovered_view, view_header) = recover_or_create_wal(
&main,
wal,
wal_path,
&mut header,
&config,
derived_key.as_ref(),
)?;
// Recovered pages become the committed view, shared through `Arc`.
let view: HashMap<PageId, Arc<Page>> = recovered_view
.into_iter()
.map(|(id, page)| (id, Arc::new(page)))
.collect();
let committed_root_catalog = header.root_catalog;
let file_len = main.len()?;
let mut pager = Self {
backend: Backend::File(main),
header,
cache: Cache::new(config.cache_frames),
wal: Some(WalState {
wal: wal_state,
pending: HashMap::new(),
view,
header_dirty: false,
view_header,
committed_root_catalog,
txn_depth: 0,
}),
config,
snapshots: Arc::new(Mutex::new(HashMap::new())),
next_snapshot_id: Arc::new(AtomicU64::new(1)),
derived_key,
main_high_water: 0,
};
// The stride depends on the header's feature flags, so the high-water
// mark is computed once the pager (and its header) is in place.
pager.main_high_water = pager.main_pages_for_len(file_len);
pager.debug_assert_recovered_pages_covered();
Ok(pager)
}
/// Debug-build check that every page id between the physical high-water
/// mark and `header.page_count` is present in the WAL overlay. Compiles to
/// an empty body without `debug_assertions`.
fn debug_assert_recovered_pages_covered(&self) {
    #[cfg(debug_assertions)]
    {
        if let Some(state) = self.wal.as_ref() {
            // Page 0 is the header, so the scan never starts below 1.
            let first = self.main_high_water.max(1);
            for id_raw in first..self.header.page_count {
                if let Some(pid) = PageId::new(id_raw) {
                    debug_assert!(
                        state.view.contains_key(&pid) || state.pending.contains_key(&pid),
                        "#91: recovered page {id_raw} beyond the physical \
                         high-water must be resident in the WAL view",
                    );
                }
            }
        }
    }
}
/// Returns the logical page count recorded in the in-memory header.
#[must_use]
pub fn page_count(&self) -> u64 {
    let header = &self.header;
    header.page_count
}
/// Returns how many pages the main store physically holds.
///
/// A file backend derives this from the current file length; a memory
/// backend reports the header's page count.
///
/// # Errors
/// Propagates a failure to query the file length.
pub fn main_physical_page_count(&self) -> Result<u64> {
    let pages = match &self.backend {
        Backend::Memory(_) => self.header.page_count,
        Backend::File(handle) => self.main_pages_for_len(handle.len()?),
    };
    Ok(pages)
}
/// Returns the page size recorded in the header.
#[must_use]
pub fn page_size(&self) -> u16 {
    let header = &self.header;
    header.page_size
}
/// Returns the header's format version as `(major, minor)`.
#[must_use]
pub fn format_version(&self) -> (u16, u16) {
    let major = self.header.format_major;
    let minor = self.header.format_minor;
    (major, minor)
}
/// Returns the raw id of the first page on the free list, as stored in the
/// in-memory header.
#[must_use]
pub fn freelist_head(&self) -> u64 {
    let header = &self.header;
    header.freelist_head
}
/// Returns the root catalog value from the in-memory header, which may
/// include changes not yet committed.
#[must_use]
pub fn root_catalog(&self) -> u64 {
    let header = &self.header;
    header.root_catalog
}
/// Updates the header's root catalog value, then stages the header for the
/// next commit (WAL mode) or writes it out immediately (no WAL).
///
/// # Errors
/// Propagates a header write failure.
pub fn set_root_catalog(&mut self, root: u64) -> Result<()> {
    let header = &mut self.header;
    header.root_catalog = root;
    self.stage_or_write_header()
}
/// Allocates a page: reuses the head of the free list when there is one,
/// otherwise appends a fresh page.
///
/// Debug builds assert that the call happens inside a pager transaction.
///
/// # Errors
/// Propagates errors from either allocation path.
pub fn alloc_page(&mut self) -> Result<PageId> {
    debug_assert!(
        self.in_txn(),
        "alloc_page must be inside a Pager txn (begin_txn/end_txn)"
    );
    match PageId::new(self.header.freelist_head) {
        Some(head) => self.alloc_from_freelist(head),
        None => self.alloc_fresh(),
    }
}
/// Returns a borrowed view of page `id`.
///
/// Lookup order: the WAL overlay (`pending`, then `view`), then the page
/// cache, and finally the backing store via `read_through`, whose result is
/// inserted into the cache before being returned.
///
/// # Errors
/// `InvalidArgument` when `id` is not below `header.page_count` (debug
/// builds assert first); otherwise propagates read, eviction write-back
/// and lookup errors.
pub fn read_page(&mut self, id: PageId) -> Result<PageRef<'_>> {
debug_assert!(id.get() > 0, "PageId is non-zero by construction");
debug_assert!(
id.get() < self.header.page_count,
"read_page called with out-of-range id",
);
// Release builds skip the assertions above, so the range check is
// repeated as a recoverable error.
if id.get() >= self.header.page_count {
return Err(Error::InvalidArgument("page id out of range"));
}
// The WAL overlay takes precedence over the cache and the backing store.
if self.wal_lookup_some(id) {
return self.lookup_in_wal(id);
}
// Probe first, then fetch again inside the helper (presumably to keep
// the returned borrow out of this scope — confirm).
if self.cache.get(id).is_some() {
return self.lookup_in_cache(id);
}
// Cache miss: load from the backing store and insert with the third
// argument `false` (presumably the dirty flag — confirm against `Cache`).
let buf = self.read_through(id)?;
let evicted = self.cache.insert(id, buf, false);
self.handle_eviction(evicted)?;
self.lookup_in_cache(id)
}
/// Reports whether the WAL overlay holds a version of page `id`, either
/// staged (`pending`) or committed (`view`). Always `false` without a WAL.
fn wal_lookup_some(&self, id: PageId) -> bool {
    self.wal
        .as_ref()
        .is_some_and(|overlay| overlay.pending.contains_key(&id) || overlay.view.contains_key(&id))
}
fn lookup_in_wal(&self, id: PageId) -> Result<PageRef<'_>> {
let state = self
.wal
.as_ref()
.ok_or(Error::InvalidArgument("internal: wal overlay missing"))?;
let page = state
.pending
.get(&id)
.or_else(|| state.view.get(&id).map(Arc::as_ref))
.ok_or(Error::InvalidArgument("internal: wal lookup race"))?;
Ok(PageRef::new(id, page))
}
fn lookup_in_cache(&mut self, id: PageId) -> Result<PageRef<'_>> {
let page = self
.cache
.get(id)
.ok_or(Error::InvalidArgument("internal: cache miss after insert"))?;
Ok(PageRef::new(id, page))
}
/// Stores a new version of page `id`.
///
/// With a WAL, the page is staged in `pending` until the next commit.
/// Without one, the cached copy is overwritten in place, or the page is
/// inserted into the cache and any evicted frame is handled.
///
/// # Errors
/// `InvalidArgument` when `id` is not below `header.page_count` (debug
/// builds assert first); otherwise propagates eviction write-back errors.
pub fn write_page(&mut self, id: PageId, page: &Page) -> Result<()> {
debug_assert!(id.get() < self.header.page_count);
if id.get() >= self.header.page_count {
return Err(Error::InvalidArgument("page id out of range"));
}
// WAL mode: stage only; nothing touches the cache or the main store here.
if let Some(state) = self.wal.as_mut() {
state.pending.insert(id, page.clone());
return Ok(());
}
// NOTE(review): this path overwrites the cached frame without passing an
// explicit dirty flag — presumably `Cache::get_mut` marks the frame dirty;
// confirm against `Cache`.
if let Some(slot) = self.cache.get_mut(id) {
*slot = page.clone();
return Ok(());
}
// Not cached yet: insert with the third argument `true` (presumably the
// dirty flag — confirm against `Cache`).
let evicted = self.cache.insert(id, page.clone(), true);
self.handle_eviction(evicted)
}
/// Returns page `id` to the free list by making it the new list head.
///
/// The page is overwritten with a free-list record that points at the
/// previous head, and the header's `freelist_head` is updated to `id`.
/// Debug builds assert that the call happens inside a pager transaction.
///
/// # Errors
/// `InvalidArgument` when `id` is not below `header.page_count`;
/// otherwise propagates page or header write errors.
pub fn free_page(&mut self, id: PageId) -> Result<()> {
debug_assert!(id.get() > 0);
debug_assert!(id.get() < self.header.page_count);
debug_assert!(
self.in_txn(),
"free_page must be inside a Pager txn (begin_txn/end_txn)"
);
if id.get() >= self.header.page_count {
return Err(Error::InvalidArgument("page id out of range"));
}
// Drop any cached copy of the page being freed; the result is ignored.
let _ = self.cache.evict(id);
// The freed page links to the previous head of the list.
let next = self.header.freelist_head;
let mut buf = Page::zeroed();
encode_freelist_page(FreeListPage::new(next), &mut buf);
write_page_trailer(&mut buf);
// WAL mode stages the record; otherwise it is written straight to the store.
if let Some(state) = self.wal.as_mut() {
state.pending.insert(id, buf);
} else {
self.write_back_page(id, &buf)?;
}
self.header.freelist_head = id.get();
self.stage_or_write_header()?;
Ok(())
}
/// Commits staged changes and, in WAL mode, runs a checkpoint once the
/// WAL's committed frame count has reached `config.checkpoint_threshold`.
///
/// # Errors
/// Propagates commit and checkpoint errors.
pub fn commit(&mut self) -> Result<Lsn> {
    let lsn = self.commit_inner()?;
    let threshold_reached = self
        .wal
        .as_ref()
        .is_some_and(|state| state.wal.committed_frames() >= self.config.checkpoint_threshold);
    if threshold_reached {
        self.checkpoint()?;
    }
    Ok(lsn)
}
/// Appends all staged pages (and the header, when dirty) to the WAL as one
/// transaction, then publishes them into the committed view.
///
/// Returns `Lsn::ZERO` without touching the WAL when there is no WAL
/// overlay or nothing is staged.
///
/// # Errors
/// Propagates WAL append and commit errors. On failure the staged pages
/// remain in `pending`.
fn commit_inner(&mut self) -> Result<Lsn> {
// Encode the header up front, before `self.wal` is mutably borrowed.
let mut header_page: Page = Page::zeroed();
encode_header(&self.header, &mut header_page);
let Some(state) = self.wal.as_mut() else {
return Ok(Lsn::ZERO);
};
let header_dirty = state.header_dirty;
if state.pending.is_empty() && !header_dirty {
return Ok(Lsn::ZERO);
}
let mut txn = state.wal.begin_txn();
// Append pages in ascending id order so the frame order is deterministic.
let mut ids: Vec<PageId> = state.pending.keys().copied().collect();
ids.sort_unstable();
for id in &ids {
if let Some(page) = state.pending.get(id) {
txn.append(*id, page)?;
}
}
if header_dirty {
txn.append_header(&header_page)?;
}
let lsn = txn.commit()?;
// The transaction is committed: move each page from `pending` into the
// committed view and drop any cached copy of it.
for id in ids {
if let Some(page) = state.pending.remove(&id) {
let fresh = Arc::new(page);
debug_assert_eq!(
Arc::strong_count(&fresh),
1,
"#53: a freshly-committed page version must be \
uniquely owned when published into the view",
);
state.view.insert(id, fresh);
}
let _ = self.cache.evict(id);
}
if header_dirty {
state.view_header = Some(header_page);
state.header_dirty = false;
state.committed_root_catalog = self.header.root_catalog;
}
Ok(lsn)
}
/// Closes the pager: discards uncommitted staged pages, checkpoints, and
/// removes the WAL file once its contents are fully applied.
///
/// `checkpoint` returns `Ok(())` without doing any work while a reader
/// snapshot pins an older LSN. In that case the committed view has not been
/// written to the main file, so the WAL is the only copy of those commits.
/// The WAL file is therefore removed only when the view and the pending
/// header are both drained; otherwise it is left on disk.
///
/// # Errors
/// Propagates checkpoint errors and WAL removal errors.
pub fn close(mut self) -> Result<()> {
    if let Some(state) = self.wal.as_mut() {
        state.pending.clear();
    }
    self.checkpoint()?;
    // Only a fully drained overlay makes the WAL file redundant.
    let removable_wal = self.wal.as_ref().and_then(|state| {
        let drained = state.view.is_empty() && state.view_header.is_none();
        drained.then(|| state.wal.path().to_path_buf())
    });
    drop(self);
    if let Some(p) = removable_wal {
        crate::wal::remove_wal(&p)?;
    }
    Ok(())
}
/// Commits, checkpoints, writes back dirty cache frames, rewrites the
/// header and syncs the main file.
///
/// At most `cache.capacity()` dirty frames are taken from the cache.
///
/// # Errors
/// Propagates errors from any of the steps.
pub fn flush(&mut self) -> Result<()> {
    let _ = self.commit()?;
    self.checkpoint()?;

    let limit = self.cache.capacity();
    let dirty: Vec<(PageId, Page)> = self.cache.drain_dirty().take(limit).collect();
    for (page_id, contents) in dirty {
        self.write_back_page(page_id, &contents)?;
    }

    self.write_header()?;
    // Only a file backend has anything to sync.
    if let Backend::File(handle) = &self.backend {
        handle.sync_data(self.config.sync_mode)?;
    }
    Ok(())
}
/// Copies every committed WAL page (and the committed header, if any) into
/// the main store, then resets the WAL.
///
/// Does nothing when a reader snapshot pins an LSN older than the WAL's
/// latest, or when the pager has no WAL.
///
/// The overlay is applied from cheap `Arc` clones and cleared only after
/// the main store has been written and synced. If applying fails midway,
/// the in-memory view therefore still holds every committed page, so later
/// reads keep seeing committed data and the checkpoint can be retried.
///
/// # Errors
/// Propagates errors from writing pages, syncing, or resetting the WAL.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(name = "pager.checkpoint", level = "debug", skip_all)
)]
pub fn checkpoint(&mut self) -> Result<()> {
    if self.checkpoint_deferred_for_pinned_reader() {
        #[cfg(feature = "tracing")]
        tracing::debug!(reason = "reader_pin", "deferred");
        return Ok(());
    }
    let Some(state) = self.wal.as_ref() else {
        return Ok(());
    };
    let view_pages: Vec<(PageId, Arc<Page>)> = state
        .view
        .iter()
        .map(|(id, page)| (*id, Arc::clone(page)))
        .collect();
    let header_page: Option<Page> = state.view_header.clone();
    let nothing_to_do = view_pages.is_empty() && header_page.is_none();

    self.apply_checkpoint_view(view_pages, header_page)?;

    // The main store now holds everything; the overlay can be dropped.
    if let Some(state) = self.wal.as_mut() {
        state.view.clear();
        state.view_header = None;
    }
    if nothing_to_do {
        return Ok(());
    }
    self.rotate_wal_salt_and_persist()
}
/// Reports whether a checkpoint must wait: true when some snapshot pins an
/// LSN strictly older than the WAL's last assigned LSN.
///
/// Without a WAL the comparison is made against `Lsn::ZERO`.
fn checkpoint_deferred_for_pinned_reader(&self) -> bool {
    match self.min_pinned_lsn() {
        None => false,
        Some(oldest_pin) => {
            let end_lsn = match self.wal.as_ref() {
                Some(state) => state.wal.next_lsn().prev_saturating(),
                None => Lsn::ZERO,
            };
            oldest_pin < end_lsn
        }
    }
}
/// Writes the given committed pages and optional header page into the main
/// store, then syncs a file backend.
///
/// The main file is grown first so that every page id is physically
/// covered. Each written page is also evicted from the cache.
///
/// # Errors
/// Propagates grow, write and sync errors.
fn apply_checkpoint_view(
    &mut self,
    view_pages: Vec<(PageId, Arc<Page>)>,
    drained_header: Option<Page>,
) -> Result<()> {
    self.grow_main_to_cover(&view_pages)?;

    for (page_id, shared) in view_pages {
        self.write_back_page(page_id, &*shared)?;
        let _ = self.cache.evict(page_id);
    }

    if let Some(header_page) = drained_header {
        let raw = header_page.as_bytes();
        match &mut self.backend {
            Backend::File(handle) => handle.write_all_at(raw, 0)?,
            Backend::Memory(store) => {
                // Make sure the buffer can hold the header page.
                if store.len() < PAGE_SIZE {
                    store.resize(PAGE_SIZE, 0);
                }
                store[..PAGE_SIZE].copy_from_slice(raw);
            }
        }
    }

    if let Backend::File(handle) = &self.backend {
        handle.sync_data(self.config.sync_mode)?;
    }
    Ok(())
}
/// Extends a file-backed main store so that it physically covers the
/// largest page id in `view_pages`, and updates `main_high_water`.
///
/// No-op for a memory backend, for an empty slice, or when the file already
/// covers the largest id.
///
/// # Errors
/// `InvalidArgument` when the required length overflows; propagates
/// `set_len` failures.
fn grow_main_to_cover(&mut self, view_pages: &[(PageId, Arc<Page>)]) -> Result<()> {
    let Backend::File(_) = &self.backend else {
        return Ok(());
    };
    let Some(max_id) = view_pages.iter().map(|(id, _)| id.get()).max() else {
        return Ok(());
    };
    if max_id < self.main_high_water {
        return Ok(());
    }

    let target_len = self.file_length_for(max_id)?;
    if let Backend::File(handle) = &mut self.backend {
        handle.set_len(target_len)?;
    }
    self.main_high_water = self.main_pages_for_len(target_len);
    debug_assert!(
        self.main_high_water > max_id,
        "#91: grown file must physically cover the max drained id",
    );
    Ok(())
}
/// Resets the WAL after a checkpoint, stamps the WAL's salt into the
/// header, writes the header to the main store and syncs a file backend.
///
/// # Errors
/// Propagates WAL reset, header write and sync errors.
fn rotate_wal_salt_and_persist(&mut self) -> Result<()> {
    if let Some(overlay) = self.wal.as_mut() {
        overlay.wal.reset_after_checkpoint()?;
        stamp_salt_into_header(&mut self.header, overlay.wal.salt());
    }
    self.write_header()?;
    if let Backend::File(handle) = &self.backend {
        handle.sync_data(self.config.sync_mode)?;
    }
    Ok(())
}
pub fn reader_snapshot(&mut self) -> Result<ReaderSnapshot<F>> {
let pinned_lsn = self
.wal
.as_ref()
.map_or(Lsn::ZERO, |s| s.wal.next_lsn().prev_saturating());
let frozen_view = self
.wal
.as_ref()
.map(|s| s.view.clone())
.unwrap_or_default();
let frozen_header = self.wal.as_ref().and_then(|s| s.view_header.clone());
let root_catalog = match self.wal.as_ref() {
Some(state) => state.committed_root_catalog,
None => self.header.root_catalog,
};
let snapshot_id = SnapshotId::new(self.next_snapshot_id.fetch_add(1, Ordering::Relaxed));
let mut guard = self
.snapshots
.lock()
.map_err(|_| Error::InvalidArgument("snapshot map poisoned"))?;
debug_assert!(
!guard.contains_key(&snapshot_id),
"next_snapshot_id is monotonic; collisions are impossible",
);
guard.insert(snapshot_id, pinned_lsn);
drop(guard);
Ok(ReaderSnapshot {
pinned_lsn,
frozen_view,
frozen_header,
root_catalog,
pin: SnapshotPin {
id: snapshot_id,
map: Arc::clone(&self.snapshots),
},
_phantom: std::marker::PhantomData,
})
}
#[must_use]
pub fn header_snapshot(&self) -> HeaderSnapshot {
HeaderSnapshot {
root_catalog: self.header.root_catalog,
freelist_head: self.header.freelist_head,
page_count: self.header.page_count,
view: self
.wal
.as_ref()
.map(|s| s.view.clone())
.unwrap_or_default(),
}
}
pub fn restore_header_snapshot(&mut self, snap: HeaderSnapshot) -> Result<()> {
self.header.root_catalog = snap.root_catalog;
self.header.freelist_head = snap.freelist_head;
self.header.page_count = snap.page_count;
if let Some(state) = self.wal.as_mut() {
state.view = snap.view;
}
self.write_header()
}
/// Discards all staged (uncommitted) pages and clears the header-dirty
/// flag. Does nothing without a WAL.
pub fn rollback_pending_writes(&mut self) {
    let Some(overlay) = self.wal.as_mut() else {
        return;
    };
    overlay.pending.clear();
    overlay.header_dirty = false;
}
/// Returns the number of registered reader snapshots, or 0 when the
/// snapshot table's mutex is poisoned.
#[must_use]
pub fn live_snapshot_count(&self) -> usize {
    match self.snapshots.lock() {
        Ok(pins) => pins.len(),
        Err(_) => 0,
    }
}
/// Returns the smallest LSN pinned by any live snapshot.
///
/// `None` when there are no snapshots or when the snapshot table's mutex is
/// poisoned.
pub fn min_pinned_lsn(&self) -> Option<Lsn> {
    match self.snapshots.lock() {
        Ok(pins) => pins.values().copied().min(),
        Err(_) => None,
    }
}
/// Reports whether this pager runs without a WAL overlay, which is how
/// `Pager::memory` constructs it.
#[must_use]
pub fn is_memory_backed(&self) -> bool {
    let overlay = &self.wal;
    overlay.is_none()
}
/// Returns an owned copy of page `id`, taken from the cache via `peek`
/// when present and read from the backing store otherwise. The WAL overlay
/// is not consulted.
///
/// # Errors
/// `InvalidArgument` when `id` is not below `header.page_count` (debug
/// builds assert first); otherwise propagates read errors.
pub(crate) fn read_cache_or_main(&self, id: PageId) -> Result<Page> {
    let raw_id = id.get();
    debug_assert!(raw_id > 0);
    debug_assert!(raw_id < self.header.page_count);
    if raw_id >= self.header.page_count {
        return Err(Error::InvalidArgument("page id out of range"));
    }
    match self.cache.peek(id) {
        Some(cached) => Ok(cached.clone()),
        None => self.read_through(id),
    }
}
pub fn read_main_file_page_zero(&self, buf: &mut [u8; PAGE_SIZE]) -> Result<()> {
match &self.backend {
Backend::File(handle) => handle.read_exact_at(buf, 0),
Backend::Memory(_) => Err(Error::BackupNotSupportedForMemoryPager),
}
}
pub fn read_main_file_page(&self, id: PageId) -> Result<Page> {
debug_assert!(id.get() > 0, "PageId is non-zero by construction");
debug_assert!(id.get() < self.header.page_count);
if id.get() >= self.header.page_count {
return Err(Error::InvalidArgument("page id out of range"));
}
self.read_through(id)
}
/// Returns an owned copy of page `id` for free-list decoding: the staged
/// version if any, else the committed WAL version, else the backing store.
///
/// # Errors
/// Propagates backing-store read errors.
fn read_freelist_page(&self, id: PageId) -> Result<Page> {
    if let Some(overlay) = self.wal.as_ref() {
        let overlay_hit = overlay
            .pending
            .get(&id)
            .or_else(|| overlay.view.get(&id).map(Arc::as_ref));
        if let Some(found) = overlay_hit {
            return Ok(found.clone());
        }
    }
    self.read_through(id)
}
fn alloc_from_freelist(&mut self, head: PageId) -> Result<PageId> {
let head_page = self.read_freelist_page(head)?;
let entry = decode_freelist_page(&head_page).ok_or(Error::Corruption {
page_id: head.get(),
})?;
if entry.next != 0 && (entry.next == head.get() || entry.next >= self.header.page_count) {
return Err(Error::Corruption {
page_id: head.get(),
});
}
self.header.freelist_head = entry.next;
let _ = self.cache.evict(head);
if let Some(state) = self.wal.as_mut() {
state.pending.remove(&head);
}
self.stage_or_write_header()?;
Ok(head)
}
fn alloc_fresh(&mut self) -> Result<PageId> {
debug_assert!(
self.in_txn(),
"alloc_fresh must be inside a Pager txn (begin_txn/end_txn)"
);
let new_id_raw = self.header.page_count;
let new_id =
PageId::new(new_id_raw).ok_or(Error::InvalidArgument("page_count overflow"))?;
if self.wal.is_some() {
self.alloc_fresh_wal(new_id, new_id_raw)
} else {
self.alloc_fresh_memory(new_id, new_id_raw)
}
}
fn alloc_fresh_wal(&mut self, new_id: PageId, new_id_raw: u64) -> Result<PageId> {
let mut blank = Page::zeroed();
write_page_trailer(&mut blank);
self.header.page_count = new_id_raw + 1;
let state = self
.wal
.as_mut()
.ok_or(Error::InvalidArgument("internal: wal overlay missing"))?;
state.pending.insert(new_id, blank);
debug_assert!(
state.pending.contains_key(&new_id),
"#91: fresh page must be staged in pending before commit",
);
self.stage_or_write_header()?;
Ok(new_id)
}
/// WAL-less fresh allocation: grows the store by one page stride, bumps the
/// page count, inserts a zeroed page into the cache, writes it back and
/// updates the header.
///
/// # Errors
/// Propagates store growth, eviction, page write and header write errors.
fn alloc_fresh_memory(&mut self, new_id: PageId, new_id_raw: u64) -> Result<PageId> {
    self.extend_main_for(new_id_raw)?;
    self.header.page_count = new_id_raw + 1;

    let displaced = self.cache.insert(new_id, Page::zeroed(), true);
    self.handle_eviction(displaced)?;

    let blank = Page::zeroed();
    self.write_back_page(new_id, &blank)?;
    self.stage_or_write_header()?;
    Ok(new_id)
}
/// Grows an in-memory store by one physical page stride.
///
/// `new_id_raw` is currently unused. Calling this on a file backend is a
/// logic error: debug builds assert, release builds return an error.
///
/// # Errors
/// `InvalidArgument` for a file backend.
fn extend_main_for(&mut self, new_id_raw: u64) -> Result<()> {
    let _ = new_id_raw;
    let stride = self.physical_stride();
    match &mut self.backend {
        Backend::File(_) => {
            debug_assert!(
                false,
                "#91: file backend extends the main file at checkpoint, not at alloc",
            );
            Err(Error::InvalidArgument(
                "internal: extend_main_for on file backend",
            ))
        }
        Backend::Memory(store) => {
            let grown_len = store.len() + stride;
            store.resize(grown_len, 0);
            Ok(())
        }
    }
}
/// Computes `PAGE_SIZE + new_id_raw * stride`: the header page plus
/// `new_id_raw` data pages at the physical stride.
///
/// # Errors
/// `InvalidArgument("file too large")` when the arithmetic overflows `u64`.
fn file_length_for(&self, new_id_raw: u64) -> Result<u64> {
    let stride = self.physical_stride() as u64;
    new_id_raw
        .checked_mul(stride)
        .and_then(|data_bytes| (PAGE_SIZE as u64).checked_add(data_bytes))
        .ok_or(Error::InvalidArgument("file too large"))
}
/// Converts a main-file length into a page count: 0 when the file is
/// shorter than the header page, otherwise 1 (the header) plus the number
/// of whole strides that follow it.
fn main_pages_for_len(&self, file_len: u64) -> u64 {
    let stride = self.physical_stride() as u64;
    let header_len = PAGE_SIZE as u64;
    match file_len.checked_sub(header_len) {
        None => 0,
        Some(data_bytes) => 1 + data_bytes / stride,
    }
}
/// Reports whether the header's feature flags include the compression bit.
#[must_use]
pub fn is_compression_capable(&self) -> bool {
    let flags = self.header.feature_flags;
    flags & FEATURE_FLAG_COMPRESSION != 0
}
/// Reports whether the header's feature flags include the encryption bit.
#[must_use]
pub fn is_encryption_capable(&self) -> bool {
    let flags = self.header.feature_flags;
    flags & FEATURE_FLAG_ENCRYPTION != 0
}
/// Returns the byte offset of page `id` in the main store, as computed by
/// `page::physical_offset_for` from the header's feature flags.
#[must_use]
fn physical_offset(&self, id: PageId) -> u64 {
    let flags = self.header.feature_flags;
    crate::pager::page::physical_offset_for(id.get(), flags)
}
/// Returns the number of bytes one page occupies in the main store, as
/// computed by `page::physical_page_stride` from the header's feature flags.
#[must_use]
fn physical_stride(&self) -> usize {
    let flags = self.header.feature_flags;
    crate::pager::page::physical_page_stride(flags)
}
fn read_through(&self, id: PageId) -> Result<Page> {
let mut p = Page::zeroed();
let off = self.physical_offset(id);
if self.is_encryption_capable() {
self.read_encrypted_into(id, off, &mut p)?;
} else {
self.read_plain_into(id, off, &mut p)?;
}
if self.is_compression_capable() {
decode_page_v1(&p, id.get())
} else {
if !page_trailer_valid(&p) {
return Err(Error::Corruption { page_id: id.get() });
}
Ok(p)
}
}
fn read_plain_into(&self, id: PageId, off: u64, p: &mut Page) -> Result<()> {
match &self.backend {
Backend::File(handle) => handle.read_exact_at(p.as_bytes_mut(), off)?,
Backend::Memory(bytes) => {
let start =
usize::try_from(off).map_err(|_| Error::InvalidArgument("offset overflow"))?;
let end = start
.checked_add(PAGE_SIZE)
.ok_or(Error::InvalidArgument("offset overflow"))?;
if end > bytes.len() {
return Err(Error::Corruption { page_id: id.get() });
}
p.as_bytes_mut().copy_from_slice(&bytes[start..end]);
}
}
Ok(())
}
fn read_encrypted_into(&self, id: PageId, off: u64, p: &mut Page) -> Result<()> {
let stride = self.physical_stride();
let mut phys = [0u8; PAGE_SIZE + ENCRYPTION_OVERHEAD];
debug_assert_eq!(stride, phys.len(), "stride must match encrypted buffer");
match &self.backend {
Backend::File(handle) => handle.read_exact_at(&mut phys, off)?,
Backend::Memory(bytes) => {
let start =
usize::try_from(off).map_err(|_| Error::InvalidArgument("offset overflow"))?;
let end = start
.checked_add(stride)
.ok_or(Error::InvalidArgument("offset overflow"))?;
if end > bytes.len() {
return Err(Error::Corruption { page_id: id.get() });
}
phys.copy_from_slice(&bytes[start..end]);
}
}
self.decrypt_physical(id, &phys, p)
}
/// Decrypts one physical page image into `out` using the derived page key.
///
/// # Errors
/// `EncryptionKeyRequired` when no key was derived at open. Without the
/// `encryption` feature the function always fails with
/// `FormatFeatureUnsupported`; with it, the result of
/// `crypto::decrypt_page` is returned.
fn decrypt_physical(
&self,
id: PageId,
phys: &[u8; PAGE_SIZE + ENCRYPTION_OVERHEAD],
out: &mut Page,
) -> Result<()> {
let Some(key) = self.derived_key.as_ref() else {
return Err(Error::EncryptionKeyRequired);
};
// Exactly one of the two cfg-gated blocks below is compiled and forms
// the function's tail expression.
#[cfg(feature = "encryption")]
{
crate::crypto::decrypt_page(key.as_bytes(), id.get(), phys, out.as_bytes_mut())
}
#[cfg(not(feature = "encryption"))]
{
// Touch the otherwise-unused bindings so this build has no warnings.
let _ = (id, phys, out, key);
Err(Error::FormatFeatureUnsupported {
feature: "encryption",
})
}
}
/// Encodes `page` for disk and writes it at the physical offset of `id`,
/// encrypting first when the database is encryption-capable.
///
/// # Errors
/// Propagates encoding, encryption and write errors.
fn write_back_page(&mut self, id: PageId, page: &Page) -> Result<()> {
    let offset = self.physical_offset(id);
    let encoded = self.encode_page_for_disk(page)?;
    if !self.is_encryption_capable() {
        return self.write_phys_4096(offset, encoded.as_bytes());
    }
    let ciphertext = self.encrypt_logical(id, &encoded)?;
    self.write_phys_encrypted(offset, &ciphertext)
}
/// Encrypts one logical page into its physical image (`PAGE_SIZE` plus
/// `ENCRYPTION_OVERHEAD` bytes) using the derived page key.
///
/// # Errors
/// `EncryptionKeyRequired` when no key was derived at open. Without the
/// `encryption` feature the function always fails with
/// `FormatFeatureUnsupported`; with it, errors from `crypto::encrypt_page`
/// are propagated.
fn encrypt_logical(
&self,
id: PageId,
page: &Page,
) -> Result<[u8; PAGE_SIZE + ENCRYPTION_OVERHEAD]> {
let Some(key) = self.derived_key.as_ref() else {
return Err(Error::EncryptionKeyRequired);
};
// Exactly one of the two cfg-gated blocks below is compiled and forms
// the function's tail expression.
#[cfg(feature = "encryption")]
{
let mut out = [0u8; PAGE_SIZE + ENCRYPTION_OVERHEAD];
crate::crypto::encrypt_page(key.as_bytes(), id.get(), page.as_bytes(), &mut out)?;
Ok(out)
}
#[cfg(not(feature = "encryption"))]
{
// Touch the otherwise-unused bindings so this build has no warnings.
let _ = (id, page, key);
Err(Error::FormatFeatureUnsupported {
feature: "encryption",
})
}
}
/// Writes an encrypted physical page image at byte offset `off`. A memory
/// backend is grown with zero bytes when the target range lies past its end.
///
/// # Errors
/// `InvalidArgument` on offset overflow (memory backend); propagates file
/// write errors.
fn write_phys_encrypted(
    &mut self,
    off: u64,
    phys: &[u8; PAGE_SIZE + ENCRYPTION_OVERHEAD],
) -> Result<()> {
    match &mut self.backend {
        Backend::File(handle) => handle.write_all_at(phys, off)?,
        Backend::Memory(store) => {
            let start =
                usize::try_from(off).map_err(|_| Error::InvalidArgument("offset overflow"))?;
            let end = start
                .checked_add(PAGE_SIZE + ENCRYPTION_OVERHEAD)
                .ok_or(Error::InvalidArgument("offset overflow"))?;
            if store.len() < end {
                store.resize(end, 0);
            }
            store[start..end].copy_from_slice(phys);
        }
    }
    Ok(())
}
/// Writes one unencrypted page image at byte offset `off`. A memory backend
/// is grown with zero bytes when the target range lies past its end.
///
/// # Errors
/// `InvalidArgument` on offset overflow (memory backend); propagates file
/// write errors.
fn write_phys_4096(&mut self, off: u64, page_bytes: &[u8; PAGE_SIZE]) -> Result<()> {
    match &mut self.backend {
        Backend::File(handle) => handle.write_all_at(page_bytes, off)?,
        Backend::Memory(store) => {
            let start =
                usize::try_from(off).map_err(|_| Error::InvalidArgument("offset overflow"))?;
            let end = start
                .checked_add(PAGE_SIZE)
                .ok_or(Error::InvalidArgument("offset overflow"))?;
            if store.len() < end {
                store.resize(end, 0);
            }
            store[start..end].copy_from_slice(page_bytes);
        }
    }
    Ok(())
}
/// Produces the on-disk form of `page`.
///
/// Compression-capable databases delegate to `encode_page_v1`. All others
/// get a copy of the page with the page trailer stamped. The copy is made
/// only on the uncompressed path, so the compression path no longer pays
/// for a page clone it never uses.
///
/// # Errors
/// Propagates `encode_page_v1` errors.
fn encode_page_for_disk(&self, page: &Page) -> Result<Page> {
    if self.is_compression_capable() {
        return encode_page_v1(page);
    }
    let mut stamped = page.clone();
    write_page_trailer(&mut stamped);
    Ok(stamped)
}
/// Writes an evicted frame back to the store when it is dirty. Clean
/// frames and `None` are ignored.
///
/// # Errors
/// Propagates the write-back error.
fn handle_eviction(&mut self, evicted: Option<Evicted>) -> Result<()> {
    match evicted {
        Some(frame) if frame.dirty => self.write_back_page(frame.page_id, &frame.buffer),
        _ => Ok(()),
    }
}
/// Encodes the in-memory header and writes it at offset 0 of the main
/// store. No sync is performed here.
///
/// # Errors
/// Propagates the file write error.
fn write_header(&mut self) -> Result<()> {
    let mut header_page = Page::zeroed();
    encode_header(&self.header, &mut header_page);
    let raw = header_page.as_bytes();
    match &mut self.backend {
        Backend::File(handle) => handle.write_all_at(raw, 0)?,
        Backend::Memory(store) => {
            // Make sure the buffer can hold the header page.
            if store.len() < PAGE_SIZE {
                store.resize(PAGE_SIZE, 0);
            }
            store[..PAGE_SIZE].copy_from_slice(raw);
        }
    }
    Ok(())
}
/// Records a header change: in WAL mode it only marks the header dirty for
/// the next commit; without a WAL it writes the header immediately.
///
/// # Errors
/// Propagates the header write error (WAL-less mode only).
fn stage_or_write_header(&mut self) -> Result<()> {
    if let Some(overlay) = self.wal.as_mut() {
        overlay.header_dirty = true;
        return Ok(());
    }
    self.write_header()
}
/// Enters a pager transaction by incrementing the nesting depth
/// (saturating). Does nothing without a WAL.
pub fn begin_txn(&mut self) {
    let Some(overlay) = self.wal.as_mut() else {
        return;
    };
    overlay.txn_depth = overlay.txn_depth.saturating_add(1);
}
/// Leaves a pager transaction by decrementing the nesting depth; the depth
/// saturates at zero. Does nothing without a WAL.
pub fn end_txn(&mut self) {
    let Some(overlay) = self.wal.as_mut() else {
        return;
    };
    overlay.txn_depth = overlay.txn_depth.saturating_sub(1);
}
/// Reports whether a pager transaction is open. A pager without a WAL is
/// always considered to be inside a transaction.
#[must_use]
pub fn in_txn(&self) -> bool {
    self.wal
        .as_ref()
        .map_or(true, |overlay| overlay.txn_depth > 0)
}
}
/// Initialises an empty main file: builds a new header, sizes the file to
/// one page, writes the encoded header at offset 0 and syncs the file.
///
/// # Errors
/// Propagates header construction and file I/O errors.
fn initialise_file<F: FileBackend>(
    handle: &F,
    compression_mode: CompressionMode,
    encryption_key: Option<&[u8; 32]>,
) -> Result<FileHeader> {
    let header = build_new_file_header(compression_mode, encryption_key)?;

    let mut header_page = Page::zeroed();
    encode_header(&header, &mut header_page);

    handle.set_len(PAGE_SIZE as u64)?;
    handle.write_all_at(header_page.as_bytes(), 0)?;
    handle.sync_all()?;
    Ok(header)
}
/// Builds the header for a new database according to the requested
/// compression mode and whether an encryption key was supplied.
///
/// Only the presence of the key matters here; its bytes are not read. A
/// fresh KDF salt is generated when a key is present.
///
/// # Errors
/// Propagates salt generation errors.
fn build_new_file_header(
    compression_mode: CompressionMode,
    encryption_key: Option<&[u8; 32]>,
) -> Result<FileHeader> {
    let header = match encryption_key {
        None => match compression_mode {
            CompressionMode::Off => FileHeader::new_empty(),
            CompressionMode::Lz4 => FileHeader::new_empty_with_compression(),
        },
        Some(_) => {
            let salt = fresh_kdf_salt()?;
            match compression_mode {
                CompressionMode::Off => FileHeader::new_empty_with_encryption(salt),
                CompressionMode::Lz4 => {
                    FileHeader::new_empty_with_encryption_and_compression(salt)
                }
            }
        }
    };
    Ok(header)
}
/// Generates 32 random bytes for use as a KDF salt.
///
/// With the `encryption` feature the bytes come from `getrandom`; without
/// it they come from `rand::rng()`.
///
/// # Errors
/// With the `encryption` feature, a `getrandom` failure is reported as
/// `Error::Io`. The feature-less variant never fails, hence the
/// `unnecessary_wraps` allowance.
#[allow(clippy::unnecessary_wraps)]
fn fresh_kdf_salt() -> Result<[u8; 32]> {
// Exactly one of the two cfg-gated blocks below is compiled and forms the
// function's tail expression.
#[cfg(feature = "encryption")]
{
let mut out = [0u8; 32];
getrandom::getrandom(&mut out)
.map_err(|e| Error::Io(std::io::Error::other(format!("getrandom failure: {e}"))))?;
Ok(out)
}
#[cfg(not(feature = "encryption"))]
{
use rand::RngCore;
let mut out = [0u8; 32];
rand::rng().fill_bytes(&mut out);
Ok(out)
}
}
/// Decides which page encryption key, if any, applies when opening a
/// database, based on the header's encryption flag, whether the
/// `encryption` feature is compiled in, and whether the config carries a key.
///
/// Returns `Ok(None)` for an unencrypted file opened without a key, and
/// `Ok(Some(key))` for an encrypted file opened with a key in a build that
/// has the feature (the key is derived from the user key and the header's
/// KDF salt).
///
/// # Errors
/// - `EncryptionKeyMismatch`: a key was supplied for an unencrypted file.
/// - `FormatFeatureUnsupported`: the file is encrypted but this build lacks
///   the `encryption` feature.
/// - `EncryptionKeyRequired`: the file is encrypted and no key was supplied.
fn derive_key_for_open(config: &Config, header: &FileHeader) -> Result<Option<PageEncryptionKey>> {
let file_is_encrypted = (header.feature_flags & FEATURE_FLAG_ENCRYPTION) != 0;
let has_feature = cfg!(feature = "encryption");
match (file_is_encrypted, has_feature, config.master_key()) {
(false, _, None) => Ok(None),
(false, _, Some(_)) => Err(Error::EncryptionKeyMismatch),
(true, false, _) => Err(Error::FormatFeatureUnsupported {
feature: "encryption",
}),
(true, true, None) => Err(Error::EncryptionKeyRequired),
#[allow(unused_variables)]
(true, true, Some(user_key)) => {
// Only the feature-enabled block can actually run for this arm; the
// other block exists so that the arm type-checks in feature-less builds.
#[cfg(feature = "encryption")]
{
let derived = crate::crypto::derive_page_key(user_key, &header.kdf_salt);
Ok(Some(PageEncryptionKey(wrap_master_key(derived))))
}
#[cfg(not(feature = "encryption"))]
{
let _ = user_key;
Err(Error::FormatFeatureUnsupported {
feature: "encryption",
})
}
}
}
}
fn refuse_unsupported_features(header: &FileHeader) -> Result<()> {
let uses_compression = (header.feature_flags & FEATURE_FLAG_COMPRESSION) != 0;
if uses_compression && !cfg!(feature = "compression") {
return Err(Error::FormatFeatureUnsupported {
feature: "compression",
});
}
let uses_encryption = (header.feature_flags & FEATURE_FLAG_ENCRYPTION) != 0;
if uses_encryption && !cfg!(feature = "encryption") {
return Err(Error::FormatFeatureUnsupported {
feature: "encryption",
});
}
Ok(())
}
fn refuse_compression_without_feature(mode: CompressionMode) -> Result<()> {
if matches!(mode, CompressionMode::Lz4) && !cfg!(feature = "compression") {
return Err(Error::FormatFeatureUnsupported {
feature: "compression",
});
}
Ok(())
}
fn refuse_encryption_without_feature(has_key: bool) -> Result<()> {
if has_key && !cfg!(feature = "encryption") {
return Err(Error::FormatFeatureUnsupported {
feature: "encryption",
});
}
Ok(())
}
/// Reports whether this build was compiled with the `encryption` feature.
#[must_use]
pub const fn encryption_feature_compiled_in() -> bool {
    const COMPILED_IN: bool = cfg!(feature = "encryption");
    COMPILED_IN
}
/// Opens the WAL for recovery and either adopts its committed frames or
/// starts a fresh WAL.
///
/// Returns the WAL handle, the page view taken from recovery (empty for a
/// fresh WAL), and the recovered header page if recovery produced one.
///
/// - When recovery reports at least one committed frame, the WAL is rebuilt
///   from the recovered metadata. If a header page was recovered, its
///   `root_catalog`, `freelist_head` and `page_count` overwrite the matching
///   fields of `header`.
/// - Otherwise a fresh WAL is created, its salt is stamped into `header`, and
///   the header is written to `main` and synced with `config.sync_mode`.
///
/// # Errors
///
/// Propagates failures from WAL recovery or creation, header decoding, and
/// the header write/sync on `main`.
#[allow(clippy::type_complexity)]
fn recover_or_create_wal<F: FileBackend>(
    main: &F,
    wal: F,
    wal_path: std::path::PathBuf,
    header: &mut FileHeader,
    config: &Config,
    derived_key: Option<&PageEncryptionKey>,
) -> Result<(Wal<F>, HashMap<PageId, Page>, Option<Page>)> {
    let expected_salt = salt_from_header(header);
    let wal_key_bytes = derived_key.map(|key| *key.as_bytes());
    let recovered = Wal::<F>::open_for_recovery_with_key(
        &wal,
        expected_salt,
        config.wal_size_limit,
        wal_key_bytes,
    )?;

    let has_committed_frames = recovered.committed_frames > 0;
    if !has_committed_frames {
        // Nothing to replay: start a fresh WAL and persist its salt in the
        // main file's header before handing the WAL back.
        let mut fresh = Wal::<F>::create_fresh_with(wal, wal_path, config.wal_config())?;
        fresh.set_key(wal_key_bytes);
        stamp_salt_into_header(header, fresh.salt());
        write_header_to_backend(main, header)?;
        main.sync_data(config.sync_mode)?;
        return Ok((fresh, HashMap::new(), None));
    }

    let mut reopened = Wal::<F>::from_recovered_meta(
        wal,
        wal_path,
        recovered.salt,
        recovered.next_lsn,
        recovered.end_offset,
        recovered.committed_frames,
        config.wal_config(),
    );
    reopened.set_key(wal_key_bytes);

    // Clone the header page first: `into_view` below consumes `recovered`.
    let recovered_header = recovered.header.clone();
    if let Some(header_page) = recovered_header.as_ref() {
        let decoded = decode_header(header_page)?;
        header.root_catalog = decoded.root_catalog;
        header.freelist_head = decoded.freelist_head;
        header.page_count = decoded.page_count;
    }
    Ok((reopened, recovered.into_view(), recovered_header))
}
/// Length of a v1 page's body: the page size minus the trailer size.
const V1_BODY_END: usize = PAGE_SIZE - PAGE_TRAILER_SIZE;
/// Largest compressed payload a v1 page body can hold: the body length minus
/// the 2-byte little-endian length prefix stored at the start of the body.
const V1_MAX_COMPRESSED_LEN: usize = V1_BODY_END - 2;
/// Encodes a page into its v1 on-disk form, compressing the body when that
/// is possible.
///
/// With the `compression` feature, the first `V1_BODY_END` bytes are
/// LZ4-compressed into a stack scratch buffer. If the result is non-empty
/// and at most `V1_MAX_COMPRESSED_LEN` bytes, the output page holds a 2-byte
/// little-endian length prefix followed by the compressed bytes, and the v1
/// trailer is written with its flag set to `true`.
///
/// In every other case (feature disabled, scratch buffer too small for LZ4's
/// worst-case output, compression error, or an out-of-range length) the page
/// is cloned unchanged and the v1 trailer is written with the flag `false`.
///
/// # Errors
///
/// Returns `Error::InvalidArgument` if the compressed length does not fit
/// the `u16` length prefix.
#[cfg_attr(not(feature = "compression"), allow(clippy::unnecessary_wraps))]
fn encode_page_v1(page: &Page) -> Result<Page> {
    #[cfg(feature = "compression")]
    {
        let mut scratch = [0u8; 8192];
        let max_out = lz4_flex::block::get_maximum_output_size(V1_BODY_END);
        if max_out <= scratch.len() {
            let raw = &page.as_bytes()[..V1_BODY_END];
            let usable_len = lz4_flex::block::compress_into(raw, &mut scratch[..max_out])
                .ok()
                .filter(|len| (1..=V1_MAX_COMPRESSED_LEN).contains(len));
            if let Some(compressed_len) = usable_len {
                let len_prefix = u16::try_from(compressed_len).map_err(|_| {
                    Error::InvalidArgument(
                        "encode_page_v1: compressed length exceeds u16 length prefix",
                    )
                })?;
                let mut packed = Page::zeroed();
                let dest = packed.as_bytes_mut();
                dest[..2].copy_from_slice(&len_prefix.to_le_bytes());
                dest[2..2 + compressed_len].copy_from_slice(&scratch[..compressed_len]);
                write_page_trailer_v1(&mut packed, true);
                return Ok(packed);
            }
        }
    }

    // Fallback: store the page as-is and mark it uncompressed.
    let mut stored = page.clone();
    write_page_trailer_v1(&mut stored, false);
    Ok(stored)
}
/// Decodes a v1 on-disk page into an in-memory page.
///
/// The v1 trailer is validated first. If the trailer flag is set, the page
/// is handed to the compressed decoder. Otherwise the first `V1_BODY_END`
/// bytes are copied into a zeroed page, which then receives a trailer from
/// `write_page_trailer`.
///
/// # Errors
///
/// Returns `Error::Corruption` carrying `page_id` when the v1 trailer is
/// invalid, and propagates any error from the compressed decoder.
pub fn decode_page_v1(disk: &Page, page_id: u64) -> Result<Page> {
    if !page_trailer_valid_v1(disk) {
        return Err(Error::Corruption { page_id });
    }
    if crate::pager::checksum::page_trailer_flag_v1(disk) {
        return decode_compressed_page_v1(disk, page_id);
    }

    let stored_body = &disk.as_bytes()[..V1_BODY_END];
    let mut plain = Page::zeroed();
    plain.as_bytes_mut()[..V1_BODY_END].copy_from_slice(stored_body);
    write_page_trailer(&mut plain);
    Ok(plain)
}
/// Decompresses the body of a v1 page that was stored compressed.
///
/// The body begins with a 2-byte little-endian length, followed by that many
/// bytes of LZ4 block data. The data must decompress to exactly
/// `V1_BODY_END` bytes; the result then receives a trailer from
/// `write_page_trailer`.
///
/// # Errors
///
/// Returns `Error::Corruption` carrying `page_id` when the length prefix is
/// zero or above `V1_MAX_COMPRESSED_LEN`, when decompression fails, or when
/// the decompressed size is wrong. On a build without the `compression`
/// feature, a page with a valid length prefix yields
/// `Error::FormatFeatureUnsupported`.
fn decode_compressed_page_v1(disk: &Page, page_id: u64) -> Result<Page> {
    let body = &disk.as_bytes()[..V1_BODY_END];
    let compressed_len = usize::from(u16::from_le_bytes([body[0], body[1]]));
    if !(1..=V1_MAX_COMPRESSED_LEN).contains(&compressed_len) {
        return Err(Error::Corruption { page_id });
    }

    #[cfg(feature = "compression")]
    {
        let mut restored = Page::zeroed();
        let written = lz4_flex::block::decompress_into(
            &body[2..2 + compressed_len],
            &mut restored.as_bytes_mut()[..V1_BODY_END],
        )
        .map_err(|_| Error::Corruption { page_id })?;
        if written != V1_BODY_END {
            return Err(Error::Corruption { page_id });
        }
        write_page_trailer(&mut restored);
        Ok(restored)
    }
    #[cfg(not(feature = "compression"))]
    {
        let _ = (compressed_len, body);
        Err(Error::FormatFeatureUnsupported {
            feature: "compression",
        })
    }
}
/// Reads the first page of the file and decodes it as the file header.
///
/// # Errors
///
/// Returns `Error::InvalidFormat` when the file is shorter than one page,
/// and propagates backend read errors and header decoding failures.
fn load_header<F: FileBackend>(handle: &F) -> Result<FileHeader> {
    let one_page = PAGE_SIZE as u64;
    if handle.len()? < one_page {
        return Err(Error::InvalidFormat {
            reason: "file is shorter than one page",
        });
    }

    let mut first_page = Page::zeroed();
    handle.read_exact_at(first_page.as_bytes_mut(), 0)?;
    decode_header(&first_page)
}
/// Reads the WAL salt as a little-endian `u32` from the first four bytes of
/// the header's `wal_salt` field.
fn salt_from_header(header: &FileHeader) -> u32 {
    let mut prefix = [0u8; 4];
    prefix.copy_from_slice(&header.wal_salt[..4]);
    u32::from_le_bytes(prefix)
}
/// Stores `salt` little-endian in the first four bytes of the header's
/// `wal_salt` field and zeroes the remaining twelve.
fn stamp_salt_into_header(header: &mut FileHeader, salt: u32) {
    let mut stamped = [0u8; 16];
    stamped[..4].copy_from_slice(&salt.to_le_bytes());
    header.wal_salt = stamped;
}
/// Encodes `header` into a zeroed page and writes that page at offset 0.
///
/// No sync is performed here; callers that need one issue it themselves.
///
/// # Errors
///
/// Propagates the backend's write error.
fn write_header_to_backend<F: FileBackend>(handle: &F, header: &FileHeader) -> Result<()> {
    let mut header_page = Page::zeroed();
    encode_header(header, &mut header_page);
    let encoded = header_page.as_bytes();
    handle.write_all_at(encoded, 0)
}
/// Returns the WAL file path for `main`: the same path with `-wal` appended
/// to its final characters (no separator is inserted and no extension is
/// replaced).
#[must_use]
pub fn wal_path_for(main: &Path) -> PathBuf {
    let mut with_suffix = main.as_os_str().to_owned();
    with_suffix.push("-wal");
    with_suffix.into()
}
/// Returns the lock file path for `main`: the same path with `-lock`
/// appended to its final characters (no separator is inserted and no
/// extension is replaced).
#[must_use]
pub fn lock_path_for(main: &Path) -> PathBuf {
    let mut with_suffix = main.as_os_str().to_owned();
    with_suffix.push("-lock");
    with_suffix.into()
}
// Unit tests for this module; compiled only in test builds.
#[cfg(test)]
mod tests;
// NOTE(review): stacked `cfg` attributes must all hold, so the pair below
// reduces to `cfg(test)` and the `fault-injection` feature has no effect on
// whether this module is compiled. Confirm which condition was intended.
#[cfg(any(test, feature = "fault-injection"))]
#[cfg(test)]
mod tests_fault;