use std::{
borrow::Cow,
cmp::Ordering,
fs::OpenOptions,
io::{SeekFrom, Write},
ops::{Bound, RangeBounds},
path::Path,
sync::Arc,
};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use super::{State, TransactionHandle};
use crate::{
error::Error,
io::{File, FileManager, FileOp, ManagedFile, OpenableFile, OperableFile},
vault::AnyVault,
ArcBytes, Context, ErrorKind,
};
/// Size in bytes of one log page. The log is always written in whole pages,
/// and every entry begins on a page boundary.
const PAGE_SIZE: usize = 1024;
/// An open handle to a transaction log file, paired with the shared [`State`]
/// that tracks the log's length and transaction ids.
pub struct TransactionLog<File: ManagedFile> {
    // Optional vault: entries are encrypted with it when written and
    // decrypted with it when read.
    vault: Option<Arc<dyn AnyVault>>,
    // Shared log state; a clone is handed to every `LogWriter`.
    state: State,
    // The underlying file handle obtained from the context's file manager.
    log: <File::Manager as FileManager>::FileHandle,
}
impl<File: ManagedFile> TransactionLog<File> {
    /// Opens the log at `log_path` for reading via the context's file manager.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the file manager while opening the file.
    pub fn read(
        log_path: &Path,
        state: State,
        context: Context<File::Manager>,
    ) -> Result<Self, Error> {
        let log = context.file_manager.read(log_path)?;
        Ok(Self {
            vault: context.vault,
            state,
            log,
        })
    }

    /// Opens the log at `log_path` in append mode so new entries can be pushed.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the file manager while opening the file.
    pub fn open(
        log_path: &Path,
        state: State,
        context: Context<File::Manager>,
    ) -> Result<Self, Error> {
        let log = context.file_manager.append(log_path)?;
        Ok(Self {
            vault: context.vault,
            state,
            log,
        })
    }

    /// Returns the log length tracked by the shared state (`State::len`).
    pub fn total_size(&self) -> u64 {
        self.state.len()
    }

    /// Initializes `state` from the log file at `state.path()`.
    ///
    /// A missing or empty file initializes an empty state. If the file's length
    /// is not a multiple of [`PAGE_SIZE`], the trailing partial page is
    /// truncated away first. Otherwise the last entry is located by scanning
    /// backwards from the final page, and `state` is initialized with its id
    /// and the log length.
    ///
    /// # Errors
    ///
    /// Returns I/O errors from querying, truncating, or reading the file, and a
    /// data-integrity error if a non-empty log contains no readable entry.
    pub fn initialize_state(state: &State, context: &Context<File::Manager>) -> Result<(), Error> {
        let mut log_length = match context.file_manager.file_length(state.path()) {
            Ok(length) => length,
            Err(Error {
                kind: ErrorKind::Io(err),
                ..
            }) if err.kind() == std::io::ErrorKind::NotFound => 0,
            Err(other) => return Err(other),
        };

        let excess_length = log_length % PAGE_SIZE as u64;
        if excess_length > 0 {
            // Only whole pages are ever written, so a partial trailing page
            // can only come from an interrupted write; drop it.
            // NOTE(review): this goes through std::fs directly rather than
            // `context.file_manager`, so it only affects on-disk logs — confirm
            // this is intended for other file managers.
            eprintln!(
                "Transaction log has {} extra bytes. Truncating.",
                excess_length
            );
            let file = OpenOptions::new()
                .append(true)
                .write(true)
                .open(state.path())?;
            log_length -= excess_length;
            file.set_len(log_length)?;
            file.sync_all()?;
        }

        // Checked after truncation: a log shorter than one page truncates to
        // zero bytes, and `StateInitializer` would otherwise underflow while
        // computing the start of the last page.
        if log_length == 0 {
            state.initialize(0, 0);
            return Ok(());
        }

        let mut file = context.file_manager.read(state.path())?;
        file.execute(StateInitializer {
            state,
            log_length,
            vault: context.vault(),
        })
    }

    /// Appends `handles` to the log in order.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::TransactionPushedOutOfOrder` if the ids are not
    /// strictly increasing past the last written id, plus any serialization,
    /// encryption, or I/O error.
    pub fn push(&mut self, handles: Vec<LogEntry<'static>>) -> Result<(), Error> {
        self.log.execute(LogWriter {
            state: self.state.clone(),
            vault: self.vault.clone(),
            transactions: handles,
        })
    }

    /// Fetches the entry with transaction `id`, or `None` if it is not in the log.
    ///
    /// # Errors
    ///
    /// Returns I/O, decryption, or data-integrity errors hit while searching.
    pub fn get(&mut self, id: u64) -> Result<Option<LogEntry<'static>>, Error> {
        match self.log.execute(EntryFetcher {
            id,
            state: &self.state,
            vault: self.vault.as_deref(),
        })? {
            ScanResult::Found { entry, .. } => Ok(Some(entry)),
            ScanResult::NotFound { .. } => Ok(None),
        }
    }

    /// Calls `callback` for each entry whose id is in `ids`, in log order,
    /// stopping early when `callback` returns `false`.
    ///
    /// # Errors
    ///
    /// Returns I/O, decryption, or data-integrity errors hit while scanning.
    pub fn scan<Callback: FnMut(LogEntry<'static>) -> bool>(
        &mut self,
        ids: impl RangeBounds<u64>,
        callback: Callback,
    ) -> Result<(), Error> {
        self.log.execute(EntryScanner {
            ids,
            callback,
            state: &self.state,
            vault: self.vault.as_deref(),
        })
    }

    /// Closes the underlying file handle.
    ///
    /// # Errors
    ///
    /// Returns any error produced while closing the file.
    pub fn close(self) -> Result<(), Error> {
        self.log.close()
    }

    /// Returns `State::next_transaction_id` for this log.
    pub fn current_transaction_id(&self) -> u64 {
        self.state.next_transaction_id()
    }

    /// Begins a new transaction via `State::new_transaction`; `trees`
    /// presumably names the trees the transaction touches (TODO confirm).
    pub fn new_transaction<
        'a,
        I: IntoIterator<Item = &'a [u8], IntoIter = II>,
        II: ExactSizeIterator<Item = &'a [u8]>,
    >(
        &self,
        trees: I,
    ) -> TransactionHandle {
        self.state.new_transaction(trees)
    }

    /// Returns a clone of the shared log state.
    pub fn state(&self) -> State {
        self.state.clone()
    }
}
/// A [`FileOp`] that recovers the last written transaction from an existing
/// log and initializes the shared [`State`] from it.
struct StateInitializer<'a> {
    // State to initialize once the last entry has been found.
    state: &'a State,
    // Log length in bytes; `TransactionLog::initialize_state` truncates it to
    // a multiple of PAGE_SIZE before constructing this op.
    log_length: u64,
    // Vault used to decrypt the entry, if the log is encrypted.
    vault: Option<&'a dyn AnyVault>,
}
impl<'a> FileOp<Result<(), Error>> for StateInitializer<'a> {
    /// Scans backwards from the last page for the final entry and records its
    /// id and the log length in `state`.
    ///
    /// A log without a single complete page holds no entries, so it is
    /// initialized as empty rather than underflowing the page arithmetic.
    ///
    /// # Errors
    ///
    /// Returns a data-integrity error if a non-empty log contains no entry
    /// header, and propagates any error from reading or decoding the entry.
    fn execute(self, log: &mut dyn File) -> Result<(), Error> {
        let block_start = match self.log_length.checked_sub(PAGE_SIZE as u64) {
            Some(start) => start,
            None => {
                self.state.initialize(0, 0);
                return Ok(());
            }
        };

        let mut scratch_buffer = Vec::with_capacity(PAGE_SIZE);
        scratch_buffer.resize(4, 0);

        let last_transaction =
            match scan_for_transaction(log, &mut scratch_buffer, block_start, false, self.vault)? {
                ScanResult::Found { entry, .. } => entry,
                ScanResult::NotFound { .. } => {
                    return Err(Error::data_integrity(
                        "No entries found in an existing transaction log",
                    ))
                }
            };

        self.state.initialize(last_transaction.id, self.log_length);
        Ok(())
    }
}
/// Outcome of searching the log for an entry.
pub enum ScanResult {
    /// An entry was read from the log.
    Found {
        /// The decoded (and, if a vault is configured, decrypted) entry.
        entry: LogEntry<'static>,
        /// Byte offset of the entry's first page.
        position: u64,
        /// Length of the stored (possibly encrypted) payload as recorded in
        /// the page header, excluding page header bytes.
        length: u64,
    },
    /// No matching entry was found.
    NotFound {
        /// A byte offset near where the search ended; `EntryScanner` uses it
        /// as the starting point for a forward scan.
        nearest_position: u64,
    },
}
fn scan_for_transaction(
log: &mut dyn File,
scratch_buffer: &mut Vec<u8>,
mut block_start: u64,
scan_forward: bool,
vault: Option<&dyn AnyVault>,
) -> Result<ScanResult, Error> {
if scratch_buffer.len() < 4 {
scratch_buffer.resize(4, 0);
}
let file_length = log.length()?;
Ok(loop {
if block_start >= file_length {
return Ok(ScanResult::NotFound {
nearest_position: block_start,
});
}
log.seek(SeekFrom::Start(block_start))?;
log.read_exact(&mut scratch_buffer[0..4])?;
#[allow(clippy::match_on_vec_items)]
match scratch_buffer[0] {
0 => {
if block_start == 0 {
break ScanResult::NotFound {
nearest_position: 0,
};
}
if scan_forward {
block_start += PAGE_SIZE as u64;
} else {
block_start -= PAGE_SIZE as u64;
}
continue;
}
1 => {
let length = (scratch_buffer[1] as usize) << 16
| (scratch_buffer[2] as usize) << 8
| scratch_buffer[3] as usize;
scratch_buffer.resize(length, 0);
let mut initial_page = true;
let mut bytes_to_read = length;
let mut offset = 0;
while bytes_to_read > 0 {
let page_header_length = if initial_page {
initial_page = false;
4
} else {
log.seek(SeekFrom::Current(1))?;
1
};
let page_length = (PAGE_SIZE - page_header_length).min(length - offset);
log.read_exact(&mut scratch_buffer[offset..offset + page_length])?;
offset += page_length;
bytes_to_read -= page_length;
}
let payload = &scratch_buffer[0..length];
let decrypted = match &vault {
Some(vault) => Cow::Owned(vault.decrypt(payload)?),
None => Cow::Borrowed(payload),
};
let entry = LogEntry::deserialize(&decrypted)
.map_err(Error::data_integrity)?
.into_owned();
break ScanResult::Found {
entry,
position: block_start,
length: length as u64,
};
}
_ => unreachable!("corrupt transaction log"),
}
})
}
/// A [`FileOp`] that looks up a single entry by transaction id.
#[allow(clippy::redundant_pub_crate)]
pub(crate) struct EntryFetcher<'a> {
    /// Shared log state supplying the search's upper bound.
    pub state: &'a State,
    /// The transaction id to look up.
    pub id: u64,
    /// Vault used to decrypt entries, if the log is encrypted.
    pub vault: Option<&'a dyn AnyVault>,
}
impl<'a> FileOp<Result<ScanResult, Error>> for EntryFetcher<'a> {
    /// Runs `fetch_entry` for the requested id, using a fresh scratch buffer
    /// pre-sized to one page.
    fn execute(self, log: &mut dyn File) -> Result<ScanResult, Error> {
        let Self { state, id, vault } = self;
        let mut page_buffer = Vec::with_capacity(PAGE_SIZE);
        fetch_entry(log, &mut page_buffer, state, id, vault)
    }
}
/// Locates the entry with transaction `id` in the log.
///
/// Runs an interpolation-style search between two bounds. The lower bound is
/// the closest entry found so far with a smaller id. The upper bound starts at
/// `state.len()` / `state.next_transaction_id()` and narrows to the closest
/// entry found with a larger id. [`guess_page`] picks each page to probe, and
/// every entry read is reported to `state.note_transaction_id_status`.
///
/// Returns `ScanResult::Found` on an exact match. Otherwise returns
/// `ScanResult::NotFound`, whose `nearest_position` is intended as a starting
/// point for a forward scan (see `EntryScanner`).
///
/// # Errors
///
/// Propagates I/O, decryption, and data-integrity errors from reading entries.
fn fetch_entry(
    log: &mut dyn File,
    scratch_buffer: &mut Vec<u8>,
    state: &State,
    id: u64,
    vault: Option<&dyn AnyVault>,
) -> Result<ScanResult, Error> {
    // Id 0 is treated as never present.
    if id == 0 {
        return Ok(ScanResult::NotFound {
            nearest_position: 0,
        });
    }

    let mut upper_id = state.next_transaction_id();
    let mut upper_location = state.len();
    if upper_id <= id {
        return Ok(ScanResult::NotFound {
            nearest_position: upper_location,
        });
    }

    let mut lower_id: Option<u64> = None;
    let mut lower_location: Option<u64> = None;
    loop {
        let guessed_location =
            match guess_page(id, lower_location, lower_id, upper_location, upper_id) {
                Some(location) if location != upper_location => location,
                // Nothing left to probe between the bounds.
                _ => {
                    return Ok(ScanResult::NotFound {
                        nearest_position: upper_location,
                    })
                }
            };

        // guess_page never returns a location past upper_location, so in
        // practice this is a backward scan.
        let scan_forward = guessed_location >= upper_location;
        let mut result =
            scan_for_transaction(log, scratch_buffer, guessed_location, scan_forward, vault)?;

        // A backward scan that starts on a continuation page of the current
        // lower-bound entry lands back on that same entry, which says nothing
        // about the pages after it. In that case, scan forward from the guess
        // instead, to reach the entry immediately following the lower bound.
        // Without this, the search would stop with NotFound even when `id`
        // sits right after a multi-page lower-bound entry.
        let landed_on_lower = match (&result, lower_id) {
            (ScanResult::Found { entry, .. }, Some(lower)) => entry.id == lower,
            _ => false,
        };
        if landed_on_lower {
            result = scan_for_transaction(log, scratch_buffer, guessed_location, true, vault)?;
        }

        match result {
            ScanResult::Found {
                entry,
                position,
                length,
            } => {
                state.note_transaction_id_status(entry.id, Some(position));
                match entry.id.cmp(&id) {
                    Ordering::Less => {
                        if lower_id.map_or(true, |lower| entry.id > lower) {
                            lower_id = Some(entry.id);
                            lower_location = Some(position);
                        } else {
                            // The probe made no progress; give up.
                            return Ok(ScanResult::NotFound {
                                nearest_position: position,
                            });
                        }
                    }
                    Ordering::Equal => {
                        return Ok(ScanResult::Found {
                            entry,
                            position,
                            length,
                        });
                    }
                    Ordering::Greater => {
                        if entry.id < upper_id {
                            upper_id = entry.id;
                            upper_location = position;
                        } else {
                            // Hit the current upper bound: nothing between.
                            return Ok(ScanResult::NotFound {
                                nearest_position: position,
                            });
                        }
                    }
                }
            }
            ScanResult::NotFound { nearest_position } => {
                return Ok(ScanResult::NotFound { nearest_position });
            }
        }
    }
}
/// A [`FileOp`] that invokes `callback` for each log entry whose id falls in `ids`.
pub struct EntryScanner<'a, Range: RangeBounds<u64>, Callback: FnMut(LogEntry<'static>) -> bool> {
    /// Shared log state used to locate the first entry of the range.
    pub state: &'a State,
    /// The range of transaction ids to visit.
    pub ids: Range,
    /// Vault used to decrypt entries, if the log is encrypted.
    pub vault: Option<&'a dyn AnyVault>,
    /// Called with each matching entry; returning `false` stops the scan.
    pub callback: Callback,
}
impl<'a, Range, Callback> FileOp<Result<(), Error>> for EntryScanner<'a, Range, Callback>
where
    Range: RangeBounds<u64>,
    Callback: FnMut(LogEntry<'static>) -> bool,
{
    /// Walks the log forward from the start of `ids`, calling `callback` for
    /// every entry whose id is in `ids`.
    ///
    /// Stops when `callback` returns `false`, when an entry's id passes the
    /// range's end bound, or at the end of the log. `LogWriter` only writes
    /// strictly increasing ids, so no entry after the end bound can match.
    ///
    /// # Errors
    ///
    /// Propagates I/O, decryption, and data-integrity errors from reading entries.
    fn execute(mut self, log: &mut dyn File) -> Result<(), Error> {
        // True once `id` lies beyond the end bound of `ids`.
        fn is_past_end(ids: &impl RangeBounds<u64>, id: u64) -> bool {
            match ids.end_bound() {
                Bound::Included(end) => id > *end,
                Bound::Excluded(end) => id >= *end,
                Bound::Unbounded => false,
            }
        }

        let mut scratch = Vec::with_capacity(PAGE_SIZE);
        let (start_location, start_transaction, start_length) = match self.ids.start_bound() {
            Bound::Included(start_key) | Bound::Excluded(start_key) => {
                match fetch_entry(log, &mut scratch, self.state, *start_key, self.vault)? {
                    ScanResult::Found {
                        entry,
                        position,
                        length,
                    } => (position, Some(entry), length),
                    ScanResult::NotFound { nearest_position } => (nearest_position, None, 0),
                }
            }
            Bound::Unbounded => (0, None, 0),
        };

        if let Some(entry) = start_transaction {
            if is_past_end(&self.ids, entry.id) {
                return Ok(());
            }
            if self.ids.contains(&entry.id) && !(self.callback)(entry) {
                return Ok(());
            }
        }

        // Continue forward from the page after the starting point.
        let mut next_scan_start = next_page_start(start_location + start_length);
        while let ScanResult::Found {
            entry,
            position,
            length,
        } = scan_for_transaction(log, &mut scratch, next_scan_start, true, self.vault)?
        {
            if is_past_end(&self.ids, entry.id) {
                break;
            }
            if self.ids.contains(&entry.id) && !(self.callback)(entry) {
                break;
            }
            next_scan_start = next_page_start(position + length);
        }

        Ok(())
    }
}
/// Rounds `position` up to the next multiple of [`PAGE_SIZE`]; a position
/// already on a page boundary is returned unchanged.
const fn next_page_start(position: u64) -> u64 {
    let page_size = PAGE_SIZE as u64;
    match position % page_size {
        0 => position,
        partial => position + (page_size - partial),
    }
}
/// A [`FileOp`] that appends a batch of entries to the end of the log.
struct LogWriter {
    // Shared log state; locked for the duration of the write.
    state: State,
    // Entries to write, expected in strictly increasing id order.
    transactions: Vec<LogEntry<'static>>,
    // Vault used to encrypt each serialized entry, if configured.
    vault: Option<Arc<dyn AnyVault>>,
}
impl FileOp<Result<(), Error>> for LogWriter {
    /// Appends every queued entry to the log as a run of whole pages, flushes,
    /// and reports the written ids (with their offsets) as completed.
    ///
    /// Page layout: the first page of an entry is `[1, len_hi, len_mid, len_lo]`
    /// followed by payload bytes. Each continuation page is `[0]` followed by
    /// payload bytes. The unused tail of the final page is zero-filled.
    ///
    /// # Errors
    ///
    /// - `ErrorKind::TransactionPushedOutOfOrder` if the batch's ids are not
    ///   strictly increasing past the last written id. This is checked before
    ///   anything is written, so such a batch leaves the log untouched.
    /// - `"transaction too large"` if a (possibly encrypted) payload does not
    ///   fit the 24-bit length header.
    /// - Any serialization, encryption, or I/O error.
    fn execute(mut self, log: &mut dyn File) -> Result<(), Error> {
        let mut log_position = self.state.lock_for_write();

        // Validate ordering for the whole batch up front. Checking inside the
        // write loop let an out-of-order id late in a batch leave the earlier
        // entries written to the file but never reported as completed.
        let mut previous_id = log_position.last_written_transaction;
        for transaction in &self.transactions {
            if transaction.id <= previous_id {
                return Err(Error::from(ErrorKind::TransactionPushedOutOfOrder));
            }
            previous_id = transaction.id;
        }

        let mut scratch = [0_u8; PAGE_SIZE];
        let mut completed_transactions = Vec::with_capacity(self.transactions.len());
        for transaction in self.transactions.drain(..) {
            log_position.last_written_transaction = transaction.id;
            completed_transactions.push((transaction.id, Some(log_position.file_offset)));

            let mut bytes = transaction.serialize()?;
            if let Some(vault) = &self.vault {
                bytes = vault.encrypt(&bytes)?;
            }

            // The first page stores the payload length in 3 bytes.
            let length = u32::try_from(bytes.len())
                .map_err(|_| Error::from("transaction too large"))?;
            if length & 0xFF00_0000 != 0 {
                return Err(Error::from("transaction too large"));
            }

            let mut offset = 0;
            while offset < bytes.len() {
                let header_len = if offset == 0 {
                    scratch[0] = 1;
                    scratch[1..4].copy_from_slice(&length.to_be_bytes()[1..]);
                    4
                } else {
                    scratch[0] = 0;
                    1
                };
                let bytes_to_write = (bytes.len() - offset).min(PAGE_SIZE - header_len);
                let payload_end = header_len + bytes_to_write;
                scratch[header_len..payload_end]
                    .copy_from_slice(&bytes[offset..offset + bytes_to_write]);
                // Zero the padding so stale bytes from an earlier page or
                // entry are never written to disk.
                scratch[payload_end..].fill(0);
                log.write_all(&scratch)?;
                offset += bytes_to_write;
                log_position.file_offset += PAGE_SIZE as u64;
            }
        }

        drop(log_position);
        log.flush()?;
        self.state
            .note_transaction_ids_completed(&completed_transactions);
        Ok(())
    }
}
/// A single transaction log entry: a transaction id plus an optional payload.
#[derive(Eq, PartialEq, Debug)]
pub struct LogEntry<'a> {
    /// The transaction id. `LogWriter` only appends entries in strictly
    /// increasing id order.
    pub id: u64,
    /// Optional payload, stored after the id when serialized.
    pub(crate) data: Option<ArcBytes<'a>>,
}
impl<'a> LogEntry<'a> {
    /// Converts this entry into one that owns its payload, so it is no longer
    /// tied to the `'a` lifetime.
    #[must_use]
    pub fn into_owned(self) -> LogEntry<'static> {
        let Self { id, data } = self;
        let data = data.map(ArcBytes::into_owned);
        LogEntry { id, data }
    }
}
impl<'a> LogEntry<'a> {
    /// Returns the entry's payload, if any.
    #[must_use]
    pub const fn data(&self) -> Option<&ArcBytes<'a>> {
        self.data.as_ref()
    }

    /// Sets the entry's payload.
    ///
    /// A log record stores its payload length in 24 bits (at most `2^24 - 1`
    /// bytes), and serialization prepends an 8-byte id. So `data` may be at
    /// most `2^24 - 9` bytes. The previous limit of `2^24 - 8` accepted a
    /// payload that `LogWriter` would later reject as "transaction too large".
    ///
    /// NOTE(review): when a vault is configured, encryption may add overhead
    /// that this check does not account for.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::ValueTooLarge` if `data` exceeds the limit above.
    pub fn set_data(&mut self, data: impl Into<ArcBytes<'a>>) -> Result<(), Error> {
        // Largest length the 3-byte header can describe, minus the 8-byte id.
        const MAX_DATA_LENGTH: usize = (1 << 24) - 1 - 8;
        let data = data.into();
        if data.len() <= MAX_DATA_LENGTH {
            self.data = Some(data);
            Ok(())
        } else {
            Err(Error::from(ErrorKind::ValueTooLarge))
        }
    }

    /// Encodes the entry as an 8-byte big-endian id followed by the raw
    /// payload bytes (nothing after the id if there is no payload).
    pub(crate) fn serialize(&self) -> Result<Vec<u8>, Error> {
        let mut buffer = Vec::with_capacity(8 + self.data.as_ref().map_or(0, |data| data.len()));
        buffer.write_u64::<BigEndian>(self.id)?;
        if let Some(data) = &self.data {
            buffer.write_all(data)?;
        }
        Ok(buffer)
    }

    /// Decodes an entry produced by [`LogEntry::serialize`], borrowing the
    /// payload from `buffer`. An empty remainder after the id decodes as
    /// `data: None`, so `Some(empty)` does not round-trip.
    ///
    /// # Errors
    ///
    /// Returns an error if `buffer` is shorter than 8 bytes.
    pub(crate) fn deserialize(mut buffer: &'a [u8]) -> Result<Self, Error> {
        let id = buffer.read_u64::<BigEndian>()?;
        let data = if buffer.is_empty() {
            None
        } else {
            Some(ArcBytes::from(buffer))
        };
        Ok(Self { id, data })
    }
}

#[test]
fn serialization_tests() {
    let transaction = LogEntry {
        id: 1,
        data: Some(ArcBytes::from(b"hello")),
    };
    let serialized = transaction.serialize().unwrap();
    let deserialized = LogEntry::deserialize(&serialized).unwrap();
    assert_eq!(transaction, deserialized);

    let transaction = LogEntry {
        id: u64::MAX,
        data: None,
    };
    let serialized = transaction.serialize().unwrap();
    let deserialized = LogEntry::deserialize(&serialized).unwrap();
    assert_eq!(transaction, deserialized);

    let mut transaction = LogEntry { id: 0, data: None };
    let mut big_data = Vec::new();
    big_data.resize(2_usize.pow(24), 0);
    let mut big_data = ArcBytes::from(big_data);
    assert!(matches!(
        transaction.set_data(big_data.clone()),
        Err(Error {
            kind: ErrorKind::ValueTooLarge,
            ..
        })
    ));

    // One byte over the limit: id + data would need a 2^24-byte length.
    let mut just_too_large = big_data.clone();
    let just_too_large = just_too_large.read_bytes(2_usize.pow(24) - 8).unwrap();
    assert!(matches!(
        transaction.set_data(just_too_large),
        Err(Error {
            kind: ErrorKind::ValueTooLarge,
            ..
        })
    ));

    // Exactly at the limit: accepted, and the serialized entry fits 24 bits.
    let big_data = big_data.read_bytes(big_data.len() - 9).unwrap();
    transaction.set_data(big_data).unwrap();
    let serialized = transaction.serialize().unwrap();
    assert!(serialized.len() < 2_usize.pow(24));
    let deserialized = LogEntry::deserialize(&serialized).unwrap();
    assert_eq!(transaction, deserialized);
}
/// Estimates the page-aligned byte offset to probe next when searching for
/// transaction `looking_for`.
///
/// - With a lower bound (`lower_location`/`lower_id`), interpolates forward
///   from it, always advancing at least one page. A guess at or beyond
///   `upper_location` is capped to the page just before it. Returns `None`
///   when that capped page is not past `lower_location`.
/// - With only an upper bound, estimates backwards from `upper_location`
///   using the average over the whole range, saturating at offset 0.
///
/// Callers are expected to pass `lower_id < looking_for < upper_id` and
/// `lower_location < upper_location` (fetch_entry's bounds satisfy this);
/// otherwise the subtractions below can underflow.
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
fn guess_page(
    looking_for: u64,
    lower_location: Option<u64>,
    lower_id: Option<u64>,
    upper_location: u64,
    upper_id: u64,
) -> Option<u64> {
    debug_assert_ne!(looking_for, upper_id);
    let total_pages = upper_location / PAGE_SIZE as u64;
    if let (Some(lower_location), Some(lower_id)) = (lower_location, lower_id) {
        // Interpolate between the lower and upper bounds.
        let current_page = lower_location / PAGE_SIZE as u64;
        let delta_from_current = looking_for - lower_id;
        let local_avg_per_page = (upper_id - lower_id) as f64 / (total_pages - current_page) as f64;
        // NOTE(review): `local_avg_per_page` (and `avg_per_page` below) is ids
        // per page, so multiplying an id delta by it is not dimensionally a
        // page count (that would be `delta / avg`). Every entry occupies at
        // least one page, so ids-per-page is typically <= 1 and the product
        // underestimates the distance. Backward probes then approach the
        // target from above. That may be deliberate — confirm before changing.
        let delta_estimated_pages = (delta_from_current as f64 * local_avg_per_page).floor() as u64;
        let guess = lower_location + delta_estimated_pages.max(1) * PAGE_SIZE as u64;
        if guess >= upper_location {
            // Never probe the upper bound itself; try the page before it.
            let capped_guess = upper_location - PAGE_SIZE as u64;
            if capped_guess > lower_location {
                Some(capped_guess)
            } else {
                // No page remains strictly between the bounds.
                None
            }
        } else {
            Some(guess)
        }
    } else if upper_id > looking_for {
        // No lower bound yet: step backwards from the upper bound.
        let avg_per_page = upper_id as f64 / total_pages as f64;
        let id_delta = upper_id - looking_for;
        let delta_estimated_pages = (id_delta as f64 * avg_per_page).ceil() as u64;
        let delta_bytes = delta_estimated_pages.saturating_mul(PAGE_SIZE as u64);
        Some(upper_location.saturating_sub(delta_bytes))
    } else {
        None
    }
}
#[cfg(test)]
#[allow(clippy::semicolon_if_nothing_returned, clippy::future_not_send)]
mod tests {
    use std::collections::{BTreeSet, HashSet};

    use nanorand::{Pcg64, Rng};
    use tempfile::tempdir;

    use super::*;
    use crate::{
        io::{
            any::AnyFileManager,
            fs::{StdFile, StdFileManager},
            memory::MemoryFileManager,
        },
        test_util::RotatorVault,
        transaction::TransactionManager,
        ChunkCache,
    };

    #[test]
    fn file_log_file_tests() {
        log_file_tests("file_log_file", StdFileManager::default(), None, None);
        log_file_tests(
            "file_log_file_encrypted",
            StdFileManager::default(),
            Some(Arc::new(RotatorVault::new(13))),
            None,
        );
    }

    #[test]
    fn memory_log_file_tests() {
        log_file_tests("memory_log_file", MemoryFileManager::default(), None, None);
        log_file_tests(
            "memory_log_file_encrypted",
            MemoryFileManager::default(),
            Some(Arc::new(RotatorVault::new(13))),
            None,
        );
    }

    #[test]
    fn any_log_file_tests() {
        log_file_tests("any_file_log_file", AnyFileManager::std(), None, None);
        log_file_tests("any_memory_log_file", AnyFileManager::memory(), None, None);
    }

    /// Writes 1,000 entries of mixed single- and multi-page sizes, reopening
    /// the log for each one. Then checks wrong-vault rejection, out-of-order
    /// rejection, lookup of every id, and range scanning.
    fn log_file_tests<Manager: FileManager>(
        file_name: &str,
        file_manager: Manager,
        vault: Option<Arc<dyn AnyVault>>,
        cache: Option<ChunkCache>,
    ) {
        let temp_dir = crate::test_util::TestDirectory::new(file_name);
        let context = Context {
            file_manager,
            vault,
            cache,
        };
        std::fs::create_dir(&temp_dir).unwrap();
        let log_path = {
            let directory: &Path = &temp_dir;
            directory.join("_transactions")
        };

        let mut rng = Pcg64::new_seed(1);
        let data = (0..PAGE_SIZE * 10)
            .map(|_| rng.generate())
            .collect::<Vec<u8>>();

        for id in 1..=1_000 {
            let state = State::from_path(&log_path);
            TransactionLog::<Manager::File>::initialize_state(&state, &context).unwrap();
            let mut transactions =
                TransactionLog::<Manager::File>::open(&log_path, state, context.clone()).unwrap();
            assert_eq!(transactions.current_transaction_id(), id);
            let mut tx = transactions.new_transaction([&b"hello"[..]]);
            tx.transaction.data = Some(ArcBytes::from(id.to_be_bytes()));
            #[allow(clippy::cast_possible_truncation)]
            if id % 2 == 0 {
                if id % 3 == 0 {
                    tx.set_data(data[0..PAGE_SIZE * (id as usize % 10).max(3)].to_vec())
                        .unwrap();
                } else {
                    tx.set_data(data[0..PAGE_SIZE * (id as usize % 10).max(2)].to_vec())
                        .unwrap();
                }
            } else {
                tx.set_data(data[0..id as usize].to_vec()).unwrap();
            }
            assert!(tx.data.as_ref().unwrap().len() > 0);
            transactions.push(vec![tx.transaction]).unwrap();
            transactions.close().unwrap();
        }

        let state = State::from_path(&log_path);
        if context.vault.is_none() {
            // Reading an unencrypted log through a vault must fail.
            assert!(TransactionLog::<Manager::File>::initialize_state(
                &state,
                &Context {
                    file_manager: context.file_manager.clone(),
                    vault: Some(Arc::new(RotatorVault::new(13))),
                    cache: None
                }
            )
            .is_err());
        }
        TransactionLog::<Manager::File>::initialize_state(&state, &context).unwrap();
        let mut transactions =
            TransactionLog::<Manager::File>::open(&log_path, state, context).unwrap();

        let out_of_order = transactions.new_transaction([&b"test"[..]]);
        transactions
            .push(vec![
                transactions.new_transaction([&b"test2"[..]]).transaction,
            ])
            .unwrap();
        assert!(matches!(
            transactions
                .push(vec![out_of_order.transaction])
                .unwrap_err()
                .kind,
            ErrorKind::TransactionPushedOutOfOrder
        ));

        assert!(transactions.get(0).unwrap().is_none());
        for id in 1..=1_000 {
            let transaction = transactions.get(id).unwrap();
            match transaction {
                Some(transaction) => {
                    assert_eq!(transaction.id, id);
                    assert_eq!(
                        &data[..transaction.data().unwrap().len()],
                        transaction.data().unwrap().as_slice()
                    );
                }
                None => {
                    unreachable!("failed to fetch transaction {}", id)
                }
            }
        }
        assert!(transactions.get(1001).unwrap().is_none());

        let mut first_ten = Vec::new();
        transactions
            .scan(.., |entry| {
                first_ten.push(entry);
                first_ten.len() < 10
            })
            .unwrap();
        assert_eq!(first_ten.len(), 10);

        let mut after_first = None;
        transactions
            .scan(first_ten[0].id + 1.., |entry| {
                after_first = Some(entry);
                false
            })
            .unwrap();
        assert_eq!(after_first.as_ref(), first_ten.get(1));
    }

    #[test]
    fn discontiguous_log_file_tests() {
        let temp_dir = tempdir().unwrap();
        let file_manager = StdFileManager::default();
        let context = Context {
            file_manager,
            vault: None,
            cache: None,
        };
        let log_path = temp_dir.path().join("_transactions");
        let mut rng = Pcg64::new_seed(1);

        let state = State::from_path(&log_path);
        TransactionLog::<StdFile>::initialize_state(&state, &context).unwrap();
        let mut transactions = TransactionLog::<StdFile>::open(&log_path, state, context).unwrap();

        // Randomly skip a few ids so the log has gaps.
        let mut valid_ids = HashSet::new();
        for id in 1..=10_000 {
            assert_eq!(transactions.current_transaction_id(), id);
            let tx = transactions.new_transaction([&b"hello"[..]]);
            if rng.generate::<u8>() < 8 {
                continue;
            }
            valid_ids.insert(tx.id);
            transactions.push(vec![tx.transaction]).unwrap();
        }

        for id in 1..=10_000 {
            let transaction = transactions.get(id).unwrap();
            match transaction {
                Some(transaction) => assert_eq!(transaction.id, id),
                None => {
                    assert!(!valid_ids.contains(&id));
                }
            }
        }
    }

    #[test]
    fn file_log_manager_tests() {
        log_manager_tests("file_log_manager", StdFileManager::default(), None, None);
    }

    #[test]
    fn memory_log_manager_tests() {
        log_manager_tests(
            "memory_log_manager",
            MemoryFileManager::default(),
            None,
            None,
        );
    }

    #[test]
    fn any_log_manager_tests() {
        log_manager_tests("any_file_log_manager", AnyFileManager::std(), None, None);
        log_manager_tests(
            "any_memory_log_manager",
            AnyFileManager::memory(),
            None,
            None,
        );
    }

    #[test]
    fn file_encrypted_log_manager_tests() {
        // Previously this used MemoryFileManager despite its name, so
        // encrypted on-disk logs were never exercised through the manager.
        log_manager_tests(
            "encrypted_file_log_manager",
            StdFileManager::default(),
            Some(Arc::new(RotatorVault::new(13))),
            None,
        );
    }

    #[test]
    fn memory_encrypted_log_manager_tests() {
        log_manager_tests(
            "encrypted_memory_log_manager",
            MemoryFileManager::default(),
            Some(Arc::new(RotatorVault::new(13))),
            None,
        );
    }

    /// Commits 10,000 transactions from 10 threads through a
    /// `TransactionManager`, then checks ids, success status, and scanning.
    fn log_manager_tests<Manager: FileManager>(
        file_name: &str,
        file_manager: Manager,
        vault: Option<Arc<dyn AnyVault>>,
        cache: Option<ChunkCache>,
    ) {
        let temp_dir = crate::test_util::TestDirectory::new(file_name);
        std::fs::create_dir(&temp_dir).unwrap();
        let context = Context {
            file_manager,
            vault,
            cache,
        };
        let manager = TransactionManager::spawn(&temp_dir, context).unwrap();
        assert_eq!(manager.current_transaction_id(), None);
        assert_eq!(manager.len(), 0);
        assert!(manager.is_empty());

        let mut handles = Vec::new();
        for _ in 0..10 {
            let manager = manager.clone();
            handles.push(std::thread::spawn(move || {
                for id in 0_u32..1_000 {
                    let tx = manager.new_transaction([&id.to_be_bytes()[..]]);
                    tx.commit().unwrap();
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(manager.current_transaction_id(), Some(10_000));
        assert_eq!(manager.next_transaction_id(), 10_001);

        assert!(manager
            .transaction_was_successful(manager.current_transaction_id().unwrap())
            .unwrap());

        assert!(!manager
            .transaction_was_successful(manager.next_transaction_id())
            .unwrap());

        let mut ten = None;
        manager
            .scan(10.., |entry| {
                ten = Some(entry);
                false
            })
            .unwrap();
        assert_eq!(ten.unwrap().id, 10);
    }

    #[test]
    fn file_out_of_order_log_manager_tests() {
        out_of_order_log_manager_tests(
            "file_out_of_order_log_manager",
            StdFileManager::default(),
            None,
            None,
        );
    }

    #[test]
    fn memory_out_of_order_log_manager_tests() {
        out_of_order_log_manager_tests(
            "memory_out_of_order_log_manager",
            MemoryFileManager::default(),
            None,
            None,
        );
    }

    #[test]
    fn any_out_of_order_log_manager_tests() {
        out_of_order_log_manager_tests(
            "any_file_out_of_order_log_manager",
            AnyFileManager::std(),
            None,
            None,
        );
        out_of_order_log_manager_tests(
            "any_memory_out_of_order_log_manager",
            AnyFileManager::memory(),
            None,
            None,
        );
    }

    /// Creates batches of transactions and commits or rolls them back from
    /// separate threads in shuffled order. Then checks that a scan from the
    /// first committed id visits every committed id.
    fn out_of_order_log_manager_tests<Manager: FileManager>(
        file_name: &str,
        file_manager: Manager,
        vault: Option<Arc<dyn AnyVault>>,
        cache: Option<ChunkCache>,
    ) {
        let temp_dir = crate::test_util::TestDirectory::new(file_name);
        std::fs::create_dir(&temp_dir).unwrap();
        let context = Context {
            file_manager,
            vault,
            cache,
        };
        let manager = TransactionManager::spawn(&temp_dir, context).unwrap();
        let mut rng = Pcg64::new_seed(1);

        for batch in 1..=100_u8 {
            let mut handles = Vec::new();
            for tree in 1..=batch {
                handles.push(manager.new_transaction([&tree.to_be_bytes()[..]]));
            }
            rng.shuffle(&mut handles);
            let (handle_sender, handle_receiver) = flume::unbounded();
            let mut should_commit_handles = Vec::new();
            let mut expected_ids = BTreeSet::new();
            for (index, handle) in handles.into_iter().enumerate() {
                // Always commit at least one handle so the scan has a start.
                let should_commit_handle = rng.generate::<f32>() > 0.25 || expected_ids.is_empty();
                if should_commit_handle {
                    expected_ids.insert(handle.id);
                }
                should_commit_handles.push(should_commit_handle);
                handle_sender.send((index, handle)).unwrap();
            }
            let should_commit_handles = Arc::new(should_commit_handles);
            let mut threads = Vec::new();
            for _ in 1..=batch {
                let handle_receiver = handle_receiver.clone();
                let should_commit_handles = should_commit_handles.clone();
                threads.push(std::thread::spawn(move || {
                    let (handle_index, handle) = handle_receiver.recv().unwrap();
                    if should_commit_handles[handle_index] {
                        handle.commit().unwrap();
                    } else {
                        handle.rollback();
                    }
                }));
            }

            for thread in threads {
                thread.join().unwrap();
            }

            let first_expected = *expected_ids.iter().next().unwrap();
            manager
                .scan(first_expected.., |tx| {
                    expected_ids.remove(&tx.id);
                    true
                })
                .unwrap();
            assert!(expected_ids.is_empty(), "{:?}", expected_ids);
        }
    }
}