#![allow(dead_code)]
use crate::bplustree::NodeView;
use crate::database::metadata::MetadataPage;
use crate::layout::PAGE_SIZE;
use crate::storage::epoch::EpochManager;
use crate::storage::{HasEpoch, NodeStorage, PageStorage, StorageError};
use std::path::Path;
use std::sync::{
Arc, Mutex,
atomic::{AtomicBool, AtomicU64, Ordering},
};
use zerocopy::FromBytes;
/// Observations recorded by the `TestStorage` mock for tests to inspect.
#[derive(Debug, Default)]
pub struct StorageState {
    /// One entry per recorded metadata-page write, laid out as
    /// `(offset as u8, txn_id, root_node_id, height, order, size)`.
    pub commits: Vec<(u8, u64, u64, u64, u64, u64)>,
    /// Number of flushes that completed without an injected failure.
    pub flushes: u64,
    /// Page/node ids handed to `free_page` / `free_node`, in call order.
    pub freed: Vec<u64>,
}
/// Cloneable in-memory storage double for tests.
///
/// Every field sits behind an `Arc`, so clones share the same recorded
/// state, failure switches, page-id counter and epoch manager.
#[derive(Clone)]
pub struct TestStorage {
    /// Commits, flush count and freed ids recorded so far.
    pub state: Arc<Mutex<StorageState>>,
    /// When set, `write_page_at_offset` fails with an injected I/O error.
    pub fail_commit: Arc<AtomicBool>,
    /// When set, the `flush` implementations fail with an injected I/O error.
    pub fail_flush: Arc<AtomicBool>,
    /// Next id handed out by the page/node allocation methods.
    pub next_page_id: Arc<AtomicU64>,
    /// Epoch manager exposed through `HasEpoch`.
    pub epoch_mgr: Arc<EpochManager>,
}
impl TestStorage {
pub fn new() -> Self {
Self {
state: Arc::new(Mutex::new(StorageState::default())),
fail_commit: Arc::new(AtomicBool::new(false)),
fail_flush: Arc::new(AtomicBool::new(false)),
next_page_id: Arc::new(AtomicU64::new(16)),
epoch_mgr: EpochManager::new_shared(),
}
}
pub fn inject_commit_failure(&self, on: bool) {
self.fail_commit.store(on, Ordering::Relaxed);
}
pub fn inject_flush_failure(&self, on: bool) {
self.fail_flush.store(on, Ordering::Relaxed);
}
pub fn last_commit(&self) -> Option<(u8, u64, u64, u64, u64, u64)> {
self.state.lock().unwrap().commits.last().copied()
}
pub fn all_commits(&self) -> Vec<(u8, u64, u64, u64, u64, u64)> {
self.state.lock().unwrap().commits.clone()
}
pub fn flush_count(&self) -> u64 {
self.state.lock().unwrap().flushes
}
pub fn freed_pages(&self) -> Vec<u64> {
self.state.lock().unwrap().freed.clone()
}
}
impl Default for TestStorage {
fn default() -> Self {
Self::new()
}
}
impl HasEpoch for TestStorage {
    /// Borrows the epoch manager shared by all clones of this mock.
    fn epoch_mgr(&self) -> &Arc<EpochManager> {
        let Self { epoch_mgr, .. } = self;
        epoch_mgr
    }
}
impl NodeStorage for TestStorage {
fn read_node_view(&self, _id: u64) -> Result<Option<NodeView>, StorageError> {
Ok(None)
}
fn write_node_view(&self, _node_view: &NodeView) -> Result<u64, StorageError> {
Ok(self.next_page_id.fetch_add(1, Ordering::SeqCst))
}
fn write_node_view_at_offset(
&self,
_node_view: &NodeView,
offset: u64,
) -> Result<u64, StorageError> {
Ok(offset)
}
fn flush(&self) -> Result<(), StorageError> {
if self.fail_flush.load(Ordering::Relaxed) {
return Err(StorageError::Io(std::io::Error::other(
"flush (injected failure)",
)));
}
self.state.lock().unwrap().flushes += 1;
Ok(())
}
fn free_node(&self, pid: u64) -> Result<(), StorageError> {
self.state.lock().unwrap().freed.push(pid);
Ok(())
}
}
/// Page-level view of the mock.
///
/// Records metadata commits, flushes and frees in `StorageState`, and honours
/// the injected commit/flush failure switches.
impl PageStorage for TestStorage {
    /// Ignores `_path` and returns a fresh, empty mock.
    fn open<P: AsRef<Path>>(_path: P) -> Result<Self, std::io::Error>
    where
        Self: Sized,
    {
        Ok(Self::new())
    }

    /// No-op; always succeeds.
    fn close(&self) -> Result<(), std::io::Error> {
        Ok(())
    }

    /// Zero-fills `target` regardless of `_page_id`; no page data is stored.
    fn read_page(
        &self,
        _page_id: u64,
        target: &mut [u8; PAGE_SIZE],
    ) -> Result<(), std::io::Error> {
        target.fill(0);
        Ok(())
    }

    /// Discards `_data` and returns the current `next_page_id`, then bumps it.
    fn write_page(&self, _data: &[u8]) -> Result<u64, std::io::Error> {
        Ok(self.next_page_id.fetch_add(1, Ordering::SeqCst))
    }

    /// Simulates an in-place page write and echoes `offset` back.
    ///
    /// Any `PAGE_SIZE`-byte write that decodes as a `MetadataPage` is appended to
    /// the recorded commits as
    /// `(offset as u8, txn_id, root_node_id, height, order, size)`.
    ///
    /// # Errors
    /// Returns "commit (injected failure)" when commit failure injection is on.
    /// Nothing is recorded in that case.
    fn write_page_at_offset(&self, offset: u64, data: &[u8]) -> Result<u64, std::io::Error> {
        if self.fail_commit.load(Ordering::Relaxed) {
            return Err(std::io::Error::other("commit (injected failure)"));
        }
        if data.len() == PAGE_SIZE {
            // `read_from` copies the bytes out, so unlike `ref_from` it does not
            // require `data` to be aligned for `MetadataPage`. With `ref_from`, a
            // byte buffer that merely has 1-byte alignment made the commit go
            // unrecorded without any error.
            if let Some(page) = MetadataPage::read_from(data) {
                let m = &page.data;
                // NOTE(review): `as u8` silently truncates offsets above 255;
                // presumably callers only pass small metadata slot offsets — confirm.
                self.state.lock().unwrap().commits.push((
                    offset as u8,
                    m.txn_id,
                    m.root_node_id,
                    m.height,
                    m.order,
                    m.size,
                ));
            }
        }
        Ok(offset)
    }

    /// Returns the current `next_page_id`, then bumps it.
    fn allocate_page(&self) -> Result<u64, std::io::Error> {
        Ok(self.next_page_id.fetch_add(1, Ordering::SeqCst))
    }

    /// Appends `page_id` to the recorded freed list; never fails.
    fn free_page(&self, page_id: u64) -> Result<(), std::io::Error> {
        self.state.lock().unwrap().freed.push(page_id);
        Ok(())
    }

    /// Counts a flush, unless flush failure injection is on.
    ///
    /// # Errors
    /// Returns "flush (injected failure)" when flush failure injection is on.
    fn flush(&self) -> Result<(), std::io::Error> {
        if self.fail_flush.load(Ordering::Relaxed) {
            return Err(std::io::Error::other("flush (injected failure)"));
        }
        self.state.lock().unwrap().flushes += 1;
        Ok(())
    }

    /// Overwrites the next id the allocation methods will hand out.
    fn set_next_page_id(&self, next_page_id: u64) -> Result<(), std::io::Error> {
        self.next_page_id.store(next_page_id, Ordering::SeqCst);
        Ok(())
    }

    /// Ignored by the mock; the recorded freed list is not replaced.
    fn set_freelist(&self, _freed_pages: Vec<u64>) -> Result<(), std::io::Error> {
        Ok(())
    }

    /// Current value of the page-id counter.
    fn get_next_page_id(&self) -> u64 {
        self.next_page_id.load(Ordering::SeqCst)
    }

    /// Always empty; recorded frees are exposed via `freed_pages` instead.
    fn get_freelist(&self) -> Vec<u64> {
        Vec::new()
    }
}