use std::collections::HashMap;
use std::fs;
use std::hash::Hash;
use std::path::Path;
use zerocopy::FromBytes;
use crate::Key;
use crate::const_map::MapEntry;
use crate::disk_loc::DiskLoc;
use crate::entry::{EntryHeader, compute_crc32, entry_size};
use crate::error::{DbError, DbResult};
use crate::hint;
use crate::io::direct;
use crate::skiplist::SkipList;
#[cfg(feature = "var-collections")]
use crate::skiplist::node::VarNode;
use crate::skiplist::node::{ConstNode, SkipNode, random_height};
use crate::sync::{self, Mutex};
#[cfg(feature = "typed-tree")]
use crate::codec::Codec;
#[cfg(feature = "typed-tree")]
use crate::skiplist::node::{TypedData, TypedNode};
#[cfg(feature = "encryption")]
use std::sync::Arc;
#[cfg(feature = "encryption")]
use crate::crypto::PageCipher;
#[cfg(feature = "encryption")]
use crate::io::tags::{self, TagFile};
/// Signature shared by every value reader: `(file, offset, len) -> bytes`.
type ReadFn = dyn Fn(&std::fs::File, u64, usize) -> DbResult<Vec<u8>>;

/// Returns a reader that fetches bytes straight from disk with no decryption.
fn plain_reader() -> Box<ReadFn> {
    Box::new(|f, off, len| direct::pread_value(f, off, len))
}
/// Builds a reader that decrypts values with `cipher`, authenticating them
/// against the tags recorded in `tag_file` for data file `file_id`.
#[cfg(feature = "encryption")]
fn encrypted_reader(
    cipher: &Arc<PageCipher>,
    tag_file: &Arc<TagFile>,
    file_id: u32,
) -> Box<ReadFn> {
    // Clone the shared handles into the closure so it owns its state.
    let cipher = Arc::clone(cipher);
    let tag_file = Arc::clone(tag_file);
    Box::new(move |f, off, len| {
        direct::pread_value_encrypted(f, &tag_file, &cipher, file_id, off, len)
    })
}
/// Chooses the reader for `{file_id:06}.data` in `dir`: encrypted when a
/// cipher is configured (which requires the sidecar tag file), plain
/// otherwise.
///
/// # Errors
/// Returns [`DbError::EncryptionError`] when encryption is enabled but the
/// tag file for the data file is missing — the ciphertext cannot be
/// authenticated without it, so we refuse to fall back to plaintext reads.
#[cfg(feature = "encryption")]
fn make_reader(
    cipher: &Option<Arc<PageCipher>>,
    dir: &Path,
    file_id: u32,
) -> DbResult<Box<ReadFn>> {
    let Some(cipher) = cipher else {
        return Ok(plain_reader());
    };
    let tp = tags::tags_path_for_data(&dir.join(format!("{file_id:06}.data")));
    if !tp.exists() {
        // Use the imported `DbError` for consistency with the rest of the file.
        return Err(DbError::EncryptionError(format!(
            "tag file missing for encrypted data file {file_id:06}.data"
        )));
    }
    let tag_file = Arc::new(TagFile::open_read(&tp)?);
    Ok(encrypted_reader(cipher, &tag_file, file_id))
}
/// Without the `encryption` feature every data file is read directly from
/// disk; the unused arguments keep call sites identical to the encrypted
/// variant.
#[cfg(not(feature = "encryption"))]
fn make_reader(_dir: &Path, _file_id: u32) -> DbResult<Box<ReadFn>> {
    Ok(plain_reader())
}
/// Rebuilds the skip-list index of a fixed-size-value tree by replaying the
/// `.data` files of every shard directory, one scoped thread per shard.
///
/// For each data file a matching well-formed `.hint` file (when `hints` is
/// true) is replayed instead of scanning the data file itself. Returns the
/// highest GSN (sequence number) observed across all shards, or the first
/// shard error encountered.
pub fn recover_const_tree<K: Key, const V: usize>(
    shard_dirs: &[&Path],
    shard_ids: &[u8],
    index: &SkipList<ConstNode<K, V>>,
    hints: bool,
    #[cfg(feature = "encryption")] cipher: Option<Arc<PageCipher>>,
) -> DbResult<u64> {
    // One recovery thread per shard; scoped threads let us borrow `index`
    // and the shard paths without 'static bounds.
    let results: Vec<DbResult<u64>> = std::thread::scope(|s| {
        let handles: Vec<_> = shard_dirs
            .iter()
            .enumerate()
            .map(|(i, &dir)| {
                let shard_id = shard_ids[i];
                #[cfg(feature = "encryption")]
                let cipher = cipher.clone();
                s.spawn(move || -> DbResult<u64> {
                    let shard_start = std::time::Instant::now();
                    let mut local_max_gsn: u64 = 0;
                    let mut used_hints = false;
                    // Enter the collector once for the whole replay so node
                    // reclamation is coordinated with concurrent readers.
                    let guard = index.collector().enter();
                    // Ascending file ids: oldest data first, so later files
                    // naturally overwrite earlier state on replay.
                    let file_ids = scan_data_files(dir)?;
                    let hint_ids = if hints {
                        hint::scan_hint_files(dir)?
                    } else {
                        Vec::new()
                    };
                    for file_id in &file_ids {
                        let data_path = dir.join(format!("{file_id:06}.data"));
                        #[cfg(feature = "encryption")]
                        let read_fn = make_reader(&cipher, dir, *file_id)?;
                        #[cfg(not(feature = "encryption"))]
                        let read_fn = make_reader(dir, *file_id)?;
                        if hint_ids.contains(file_id) {
                            let hint_path = dir.join(format!("{file_id:06}.hint"));
                            if let Some(hint_data) = hint::read_hint_file(&hint_path)? {
                                // Only trust a hint whose length is a whole
                                // number of fixed-size hint entries.
                                if hint_data.len() % hint::hint_entry_size(size_of::<K>()) == 0 {
                                    let data_file = direct::open_read(&data_path)?;
                                    recover_const_hint::<K, V>(
                                        shard_id,
                                        *file_id,
                                        &hint_data,
                                        &data_file,
                                        index,
                                        &guard,
                                        &mut local_max_gsn,
                                        &read_fn,
                                    )?;
                                    used_hints = true;
                                    continue;
                                }
                                tracing::warn!(
                                    shard_id,
                                    file_id,
                                    "hint file has unexpected size, falling back to full scan"
                                );
                            }
                        }
                        // No usable hint: replay the raw data file entry by entry.
                        recover_const_full_scan::<K, V>(
                            shard_id,
                            *file_id,
                            &data_path,
                            index,
                            &guard,
                            &mut local_max_gsn,
                            &read_fn,
                        )?;
                    }
                    // Discard the hint of the highest-numbered data file.
                    // NOTE(review): presumably because it can lag behind the
                    // still-active file — confirm against the hint writer.
                    if let Some(&last_id) = file_ids.last() {
                        let _ = fs::remove_file(dir.join(format!("{last_id:06}.hint")));
                    }
                    let elapsed = shard_start.elapsed().as_secs_f64();
                    // Tag the timing by which recovery path this shard took.
                    metrics::histogram!("armdb.recovery.duration_seconds", "path" => if used_hints { "hint" } else { "full_scan" })
                        .record(elapsed);
                    Ok(local_max_gsn)
                })
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("recovery thread panicked"))
            .collect()
    });
    // First shard error wins; otherwise fold the per-shard maxima into one GSN.
    let mut max_gsn: u64 = 0;
    for result in results {
        let shard_gsn = result?;
        if shard_gsn > max_gsn {
            max_gsn = shard_gsn;
        }
    }
    Ok(max_gsn)
}
/// Rebuilds the skip-list index of a variable-size-value tree, one scoped
/// thread per shard directory.
///
/// Unlike the const variant, replay is deduplicated per key through a local
/// `key_gsn` map so an entry with an equal-or-lower sequence number never
/// overwrites newer state. Returns the highest GSN observed across shards.
#[cfg(feature = "var-collections")]
pub fn recover_var_tree<K: Key>(
    shard_dirs: &[&Path],
    shard_ids: &[u8],
    index: &SkipList<VarNode<K>>,
    hints: bool,
    #[cfg(feature = "encryption")] cipher: Option<Arc<PageCipher>>,
) -> DbResult<u64> {
    // One recovery thread per shard; scoped threads allow borrowing `index`.
    let results: Vec<DbResult<u64>> = std::thread::scope(|s| {
        let handles: Vec<_> = shard_dirs
            .iter()
            .enumerate()
            .map(|(i, &dir)| {
                let shard_id = shard_ids[i];
                #[cfg(feature = "encryption")]
                let cipher = cipher.clone();
                s.spawn(move || -> DbResult<u64> {
                    let shard_start = std::time::Instant::now();
                    let mut local_max_gsn: u64 = 0;
                    let mut used_hints = false;
                    let guard = index.collector().enter();
                    // Ascending ids: oldest file first.
                    let file_ids = scan_data_files(dir)?;
                    let hint_ids = if hints {
                        hint::scan_hint_files(dir)?
                    } else {
                        Vec::new()
                    };
                    // Newest sequence number applied per key, shared across
                    // all files of this shard.
                    let mut key_gsn: std::collections::HashMap<Vec<u8>, u64> =
                        std::collections::HashMap::new();
                    for file_id in &file_ids {
                        let data_path = dir.join(format!("{file_id:06}.data"));
                        #[cfg(feature = "encryption")]
                        let read_fn = make_reader(&cipher, dir, *file_id)?;
                        #[cfg(not(feature = "encryption"))]
                        let read_fn = make_reader(dir, *file_id)?;
                        if hint_ids.contains(file_id) {
                            let hint_path = dir.join(format!("{file_id:06}.hint"));
                            if let Some(hint_data) = hint::read_hint_file(&hint_path)? {
                                // Only trust a hint with a whole number of entries.
                                if hint_data.len() % hint::hint_entry_size(size_of::<K>()) == 0 {
                                    recover_var_hint::<K>(
                                        shard_id,
                                        *file_id,
                                        &hint_data,
                                        index,
                                        &guard,
                                        &mut local_max_gsn,
                                        &mut key_gsn,
                                    )?;
                                    used_hints = true;
                                    continue;
                                }
                                tracing::warn!(
                                    shard_id,
                                    file_id,
                                    "hint file has unexpected size, falling back to full scan"
                                );
                            }
                        }
                        // No usable hint: replay the raw data file.
                        recover_var_full_scan::<K>(
                            shard_id,
                            *file_id,
                            &data_path,
                            index,
                            &guard,
                            &mut local_max_gsn,
                            &read_fn,
                            &mut key_gsn,
                        )?;
                    }
                    // Drop the hint of the newest data file. NOTE(review):
                    // presumably it may lag the active file — confirm.
                    if let Some(&last_id) = file_ids.last() {
                        let _ = fs::remove_file(dir.join(format!("{last_id:06}.hint")));
                    }
                    let elapsed = shard_start.elapsed().as_secs_f64();
                    metrics::histogram!("armdb.recovery.duration_seconds", "path" => if used_hints { "hint" } else { "full_scan" })
                        .record(elapsed);
                    Ok(local_max_gsn)
                })
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("recovery thread panicked"))
            .collect()
    });
    // First shard error wins; otherwise keep the maximum GSN.
    let mut max_gsn: u64 = 0;
    for result in results {
        let shard_gsn = result?;
        if shard_gsn > max_gsn {
            max_gsn = shard_gsn;
        }
    }
    Ok(max_gsn)
}
/// Rebuilds the skip-list index of a typed tree, decoding each live value
/// through `codec`; one scoped thread per shard directory.
///
/// Hint files are used when `hints` is true and well-formed; otherwise data
/// files are fully scanned. Returns the highest GSN across all shards.
#[cfg(feature = "typed-tree")]
pub fn recover_typed_tree<K: Key, T: Send + Sync, C: Codec<T> + Sync>(
    shard_dirs: &[&Path],
    shard_ids: &[u8],
    index: &SkipList<TypedNode<K, T>>,
    codec: &C,
    hints: bool,
    #[cfg(feature = "encryption")] cipher: Option<Arc<PageCipher>>,
) -> DbResult<u64> {
    // One recovery thread per shard; scoped threads allow borrowing
    // `index` and `codec` without 'static bounds.
    let results: Vec<DbResult<u64>> = std::thread::scope(|s| {
        let handles: Vec<_> = shard_dirs
            .iter()
            .enumerate()
            .map(|(i, &dir)| {
                let shard_id = shard_ids[i];
                #[cfg(feature = "encryption")]
                let cipher = cipher.clone();
                s.spawn(move || -> DbResult<u64> {
                    let shard_start = std::time::Instant::now();
                    let mut local_max_gsn: u64 = 0;
                    let mut used_hints = false;
                    let guard = index.collector().enter();
                    // Ascending ids: replay oldest data first.
                    let file_ids = scan_data_files(dir)?;
                    let hint_ids = if hints {
                        hint::scan_hint_files(dir)?
                    } else {
                        Vec::new()
                    };
                    for file_id in &file_ids {
                        let data_path = dir.join(format!("{file_id:06}.data"));
                        #[cfg(feature = "encryption")]
                        let read_fn = make_reader(&cipher, dir, *file_id)?;
                        #[cfg(not(feature = "encryption"))]
                        let read_fn = make_reader(dir, *file_id)?;
                        if hint_ids.contains(file_id) {
                            let hint_path = dir.join(format!("{file_id:06}.hint"));
                            if let Some(hint_data) = hint::read_hint_file(&hint_path)? {
                                // Only trust a hint with a whole number of entries.
                                if hint_data.len() % hint::hint_entry_size(size_of::<K>()) == 0 {
                                    let data_file = direct::open_read(&data_path)?;
                                    recover_typed_hint::<K, T, C>(
                                        shard_id,
                                        *file_id,
                                        &hint_data,
                                        &data_file,
                                        index,
                                        codec,
                                        &guard,
                                        &mut local_max_gsn,
                                        &read_fn,
                                    )?;
                                    used_hints = true;
                                    continue;
                                }
                                tracing::warn!(
                                    shard_id,
                                    file_id,
                                    "hint file has unexpected size, falling back to full scan"
                                );
                            }
                        }
                        // No usable hint: replay the raw data file.
                        recover_typed_full_scan::<K, T, C>(
                            shard_id,
                            *file_id,
                            &data_path,
                            index,
                            codec,
                            &guard,
                            &mut local_max_gsn,
                            &read_fn,
                        )?;
                    }
                    // Drop the hint of the newest data file. NOTE(review):
                    // presumably it may lag the active file — confirm.
                    if let Some(&last_id) = file_ids.last() {
                        let _ = fs::remove_file(dir.join(format!("{last_id:06}.hint")));
                    }
                    let elapsed = shard_start.elapsed().as_secs_f64();
                    metrics::histogram!("armdb.recovery.duration_seconds", "path" => if used_hints { "hint" } else { "full_scan" })
                        .record(elapsed);
                    Ok(local_max_gsn)
                })
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("recovery thread panicked"))
            .collect()
    });
    // First shard error wins; otherwise keep the maximum GSN.
    let mut max_gsn: u64 = 0;
    for result in results {
        let shard_gsn = result?;
        if shard_gsn > max_gsn {
            max_gsn = shard_gsn;
        }
    }
    Ok(max_gsn)
}
/// Replays one parsed hint file into the const-value skip list.
///
/// Tombstones remove the key; live entries read `value_len` bytes from the
/// data file at `value_offset` and insert (or update in place) a node.
/// `max_gsn` is raised to the largest sequence number encountered.
#[allow(clippy::too_many_arguments)]
fn recover_const_hint<K: Key, const V: usize>(
    shard_id: u8,
    file_id: u32,
    hint_data: &[u8],
    data_file: &std::fs::File,
    index: &SkipList<ConstNode<K, V>>,
    guard: &seize::LocalGuard<'_>,
    max_gsn: &mut u64,
    read_fn: &ReadFn,
) -> DbResult<()> {
    for entry in hint::parse_hint_entries::<K>(hint_data) {
        let seq = entry.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        if entry.is_tombstone() {
            index.remove(entry.key.as_bytes(), guard);
        } else {
            let value_bytes = read_fn(data_file, entry.value_offset, entry.value_len as usize)?;
            // A const tree stores exactly V bytes per value; any other
            // length means the entry (or hint) is corrupt.
            let value: [u8; V] = value_bytes
                .try_into()
                .map_err(|_| DbError::CorruptedEntry {
                    offset: entry.value_offset,
                })?;
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                entry.value_offset as u32,
                entry.value_len,
            );
            let height = random_height();
            let node_ptr = ConstNode::alloc(entry.key, value, disk, height);
            match index.insert(node_ptr, guard) {
                crate::skiplist::InsertResult::Inserted => {}
                crate::skiplist::InsertResult::Exists(existing) => {
                    // Key already present: update the resident node and free
                    // our speculative allocation.
                    existing.write_data(disk, &value);
                    // SAFETY: `insert` returned Exists, so `node_ptr` was
                    // never linked into the list and this thread is its sole
                    // owner.
                    unsafe {
                        ConstNode::<K, V>::dealloc_node(node_ptr);
                    }
                }
            }
        }
    }
    Ok(())
}
/// Replays one parsed hint file into the variable-value skip list.
///
/// `key_gsn` records the newest sequence number applied per key; entries
/// with an equal-or-lower sequence are skipped so stale log records never
/// overwrite newer state. `max_gsn` is raised to the largest sequence seen.
#[cfg(feature = "var-collections")]
fn recover_var_hint<K: Key>(
    shard_id: u8,
    file_id: u32,
    hint_data: &[u8],
    index: &SkipList<VarNode<K>>,
    guard: &seize::LocalGuard<'_>,
    max_gsn: &mut u64,
    key_gsn: &mut std::collections::HashMap<Vec<u8>, u64>,
) -> DbResult<()> {
    for entry in hint::parse_hint_entries::<K>(hint_data) {
        let seq = entry.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        if entry.is_tombstone() {
            // Skip tombstones that are older than the state already applied.
            if let Some(&prev_gsn) = key_gsn.get(entry.key.as_bytes())
                && seq <= prev_gsn
            {
                continue;
            }
            index.remove(entry.key.as_bytes(), guard);
            key_gsn.insert(entry.key.as_bytes().to_vec(), seq);
        } else {
            // Skip live entries that are older than the state already applied.
            if let Some(&prev_gsn) = key_gsn.get(entry.key.as_bytes())
                && seq <= prev_gsn
            {
                continue;
            }
            key_gsn.insert(entry.key.as_bytes().to_vec(), seq);
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                entry.value_offset as u32,
                entry.value_len,
            );
            let height = random_height();
            let node_ptr = VarNode::alloc(entry.key, disk, height);
            match index.insert(node_ptr, guard) {
                crate::skiplist::InsertResult::Inserted => {}
                crate::skiplist::InsertResult::Exists(existing) => {
                    // Key already present: publish a new DiskLoc into the
                    // resident node and free the one it replaced.
                    let new_disk = Box::into_raw(Box::new(disk));
                    let old_disk = existing.swap_disk(new_disk);
                    // SAFETY: `old_disk` was produced by Box::into_raw and is
                    // no longer reachable from the list after the swap.
                    unsafe { drop(Box::from_raw(old_disk)) }
                    // SAFETY: `node_ptr` was never linked (insert returned
                    // Exists), so we own it. Its disk slot is nulled first so
                    // `dealloc_node` frees only the node itself —
                    // NOTE(review): relies on dealloc_node skipping a null
                    // disk pointer; confirm in the VarNode impl.
                    unsafe {
                        (*node_ptr)
                            .disk
                            .store(std::ptr::null_mut(), std::sync::atomic::Ordering::Relaxed);
                        VarNode::<K>::dealloc_node(node_ptr);
                    }
                }
            }
        }
    }
    Ok(())
}
/// Replays one parsed hint file into the typed skip list, decoding each live
/// value through `codec`.
///
/// Tombstones remove the key; live entries read the value bytes from the
/// data file, decode them, and insert (or replace the payload of) a node.
#[cfg(feature = "typed-tree")]
#[allow(clippy::too_many_arguments)]
fn recover_typed_hint<K: Key, T: Send + Sync, C: Codec<T>>(
    shard_id: u8,
    file_id: u32,
    hint_data: &[u8],
    data_file: &std::fs::File,
    index: &SkipList<TypedNode<K, T>>,
    codec: &C,
    guard: &seize::LocalGuard<'_>,
    max_gsn: &mut u64,
    read_fn: &ReadFn,
) -> DbResult<()> {
    for entry in hint::parse_hint_entries::<K>(hint_data) {
        let seq = entry.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        if entry.is_tombstone() {
            index.remove(entry.key.as_bytes(), guard);
        } else {
            let value_bytes = read_fn(data_file, entry.value_offset, entry.value_len as usize)?;
            // NOTE(review): the re-slice to `value_len` looks redundant —
            // `read_fn` was asked for exactly that many bytes (cf.
            // recover_typed_map_hint, which decodes the whole buffer).
            let value: T = codec.decode_from(&value_bytes[..entry.value_len as usize])?;
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                entry.value_offset as u32,
                entry.value_len,
            );
            let height = random_height();
            let node_ptr = TypedNode::alloc(entry.key, value, disk, height);
            match index.insert(node_ptr, guard) {
                crate::skiplist::InsertResult::Inserted => {}
                // SAFETY: `node_ptr` was never linked (insert returned
                // Exists), so we own it. We steal its freshly decoded payload
                // (leaving null behind), hand it to the resident node, and
                // drop the payload that was replaced. `dealloc_node` then
                // frees a node whose data pointer is null — NOTE(review):
                // relies on dealloc_node tolerating a null payload; confirm.
                crate::skiplist::InsertResult::Exists(existing) => unsafe {
                    let new_data = (*node_ptr)
                        .data
                        .swap(std::ptr::null_mut(), std::sync::atomic::Ordering::Relaxed);
                    let old_data = existing.swap_data(new_data);
                    drop(Box::from_raw(old_data));
                    TypedNode::<K, T>::dealloc_node(node_ptr);
                },
            }
        }
    }
    Ok(())
}
/// Replays one data file entry by entry into the const-value skip list.
///
/// The scan stops cleanly (rather than erroring) at a torn tail: an
/// unreadable/unparsable header, an entry running past EOF, or a CRC
/// mismatch on the final entry. A CRC mismatch mid-file is treated as real
/// corruption and aborts with [`DbError::CrcMismatch`].
fn recover_const_full_scan<K: Key, const V: usize>(
    shard_id: u8,
    file_id: u32,
    data_path: &Path,
    index: &SkipList<ConstNode<K, V>>,
    guard: &seize::LocalGuard<'_>,
    max_gsn: &mut u64,
    read_fn: &ReadFn,
) -> DbResult<()> {
    let file = direct::open_read(data_path)?;
    let file_len = file.metadata()?.len();
    let mut offset: u64 = 0;
    // Walk entries while a full header still fits in the file.
    while offset + size_of::<EntryHeader>() as u64 <= file_len {
        // A failed read or unparsable header marks the end of the valid log.
        let header_bytes = match read_fn(&file, offset, size_of::<EntryHeader>()) {
            Ok(b) => b,
            Err(_) => break,
        };
        let header = match EntryHeader::read_from_bytes(&header_bytes) {
            Ok(h) => h,
            Err(_) => break,
        };
        let total = entry_size(size_of::<K>(), header.value_len);
        if offset + total > file_len {
            tracing::warn!(
                shard_id,
                file_id,
                offset,
                "truncating partial entry at end of file"
            );
            break;
        }
        // Key and value are stored back to back right after the header.
        let key_value_len = size_of::<K>() + header.value_len as usize;
        let key_value_bytes = read_fn(
            &file,
            offset + size_of::<EntryHeader>() as u64,
            key_value_len,
        )?;
        let key_bytes = &key_value_bytes[..size_of::<K>()];
        let value_bytes = &key_value_bytes[size_of::<K>()..];
        // CRC covers (gsn, value_len, key, value). A mismatch on the final
        // entry is a torn write; mid-file it is corruption.
        let expected_crc = compute_crc32(header.gsn, header.value_len, key_bytes, value_bytes);
        if expected_crc != header.crc32 {
            if offset + total >= file_len {
                tracing::warn!(shard_id, file_id, offset, "CRC mismatch at end of file");
                break;
            }
            return Err(DbError::CrcMismatch {
                expected: expected_crc,
                actual: header.crc32,
            });
        }
        let seq = header.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        let key: K = K::from_bytes(key_bytes);
        if header.is_tombstone() {
            index.remove(key.as_bytes(), guard);
        } else {
            // A const tree stores exactly V bytes per value.
            let value: [u8; V] = value_bytes
                .try_into()
                .map_err(|_| DbError::CorruptedEntry { offset })?;
            let value_offset = offset + size_of::<EntryHeader>() as u64 + size_of::<K>() as u64;
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                value_offset as u32,
                header.value_len,
            );
            let height = random_height();
            let node_ptr = ConstNode::alloc(key, value, disk, height);
            match index.insert(node_ptr, guard) {
                crate::skiplist::InsertResult::Inserted => {}
                crate::skiplist::InsertResult::Exists(existing) => {
                    // Key already present: update in place, free our node.
                    existing.write_data(disk, &value);
                    // SAFETY: insert returned Exists, so `node_ptr` was never
                    // linked and this thread is its sole owner.
                    unsafe {
                        ConstNode::<K, V>::dealloc_node(node_ptr);
                    }
                }
            }
        }
        offset += total;
    }
    Ok(())
}
/// Replays one data file entry by entry into the variable-value skip list.
///
/// Same torn-tail/CRC handling as the const scan; additionally deduplicates
/// through `key_gsn` so records with an equal-or-lower sequence number than
/// the state already applied for a key are skipped.
#[cfg(feature = "var-collections")]
#[allow(clippy::too_many_arguments)]
fn recover_var_full_scan<K: Key>(
    shard_id: u8,
    file_id: u32,
    data_path: &Path,
    index: &SkipList<VarNode<K>>,
    guard: &seize::LocalGuard<'_>,
    max_gsn: &mut u64,
    read_fn: &ReadFn,
    key_gsn: &mut std::collections::HashMap<Vec<u8>, u64>,
) -> DbResult<()> {
    let file = direct::open_read(data_path)?;
    let file_len = file.metadata()?.len();
    let mut offset: u64 = 0;
    // Walk entries while a full header still fits in the file.
    while offset + size_of::<EntryHeader>() as u64 <= file_len {
        // A failed read or unparsable header marks the end of the valid log.
        let header_bytes = match read_fn(&file, offset, size_of::<EntryHeader>()) {
            Ok(b) => b,
            Err(_) => break,
        };
        let header = match EntryHeader::read_from_bytes(&header_bytes) {
            Ok(h) => h,
            Err(_) => break,
        };
        let total = entry_size(size_of::<K>(), header.value_len);
        if offset + total > file_len {
            tracing::warn!(
                shard_id,
                file_id,
                offset,
                "truncating partial entry at end of file"
            );
            break;
        }
        let key_value_len = size_of::<K>() + header.value_len as usize;
        let key_value_bytes = read_fn(
            &file,
            offset + size_of::<EntryHeader>() as u64,
            key_value_len,
        )?;
        let key_bytes = &key_value_bytes[..size_of::<K>()];
        let value_bytes = &key_value_bytes[size_of::<K>()..];
        // CRC mismatch on the final entry = torn write; mid-file = corruption.
        let expected_crc = compute_crc32(header.gsn, header.value_len, key_bytes, value_bytes);
        if expected_crc != header.crc32 {
            if offset + total >= file_len {
                tracing::warn!(shard_id, file_id, offset, "CRC mismatch at end of file");
                break;
            }
            return Err(DbError::CrcMismatch {
                expected: expected_crc,
                actual: header.crc32,
            });
        }
        let seq = header.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        let key: K = K::from_bytes(key_bytes);
        if header.is_tombstone() {
            // Skip tombstones older than the state already applied.
            if let Some(&prev_gsn) = key_gsn.get(key.as_bytes())
                && seq <= prev_gsn
            {
                offset += total;
                continue;
            }
            index.remove(key.as_bytes(), guard);
            key_gsn.insert(key.as_bytes().to_vec(), seq);
        } else {
            // Skip live entries older than the state already applied.
            if let Some(&prev_gsn) = key_gsn.get(key.as_bytes())
                && seq <= prev_gsn
            {
                offset += total;
                continue;
            }
            key_gsn.insert(key.as_bytes().to_vec(), seq);
            let value_offset = offset + size_of::<EntryHeader>() as u64 + size_of::<K>() as u64;
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                value_offset as u32,
                header.value_len,
            );
            let height = random_height();
            let node_ptr = VarNode::alloc(key, disk, height);
            match index.insert(node_ptr, guard) {
                crate::skiplist::InsertResult::Inserted => {}
                crate::skiplist::InsertResult::Exists(existing) => {
                    // Publish the new DiskLoc and free the one it replaced.
                    let new_disk = Box::into_raw(Box::new(disk));
                    let old_disk = existing.swap_disk(new_disk);
                    // SAFETY: `old_disk` came from Box::into_raw and is no
                    // longer reachable after the swap.
                    unsafe { drop(Box::from_raw(old_disk)) }
                    // SAFETY: `node_ptr` was never linked; disk slot is
                    // nulled first so dealloc frees only the node itself —
                    // NOTE(review): relies on dealloc_node skipping a null
                    // disk pointer; confirm in the VarNode impl.
                    unsafe {
                        (*node_ptr)
                            .disk
                            .store(std::ptr::null_mut(), std::sync::atomic::Ordering::Relaxed);
                        VarNode::<K>::dealloc_node(node_ptr);
                    }
                }
            }
        }
        offset += total;
    }
    Ok(())
}
/// Replays one data file entry by entry into the typed skip list, decoding
/// each live value through `codec`.
///
/// Same torn-tail/CRC handling as the const scan.
#[cfg(feature = "typed-tree")]
#[allow(clippy::too_many_arguments)]
fn recover_typed_full_scan<K: Key, T: Send + Sync, C: Codec<T>>(
    shard_id: u8,
    file_id: u32,
    data_path: &Path,
    index: &SkipList<TypedNode<K, T>>,
    codec: &C,
    guard: &seize::LocalGuard<'_>,
    max_gsn: &mut u64,
    read_fn: &ReadFn,
) -> DbResult<()> {
    let file = direct::open_read(data_path)?;
    let file_len = file.metadata()?.len();
    let mut offset: u64 = 0;
    // Walk entries while a full header still fits in the file.
    while offset + size_of::<EntryHeader>() as u64 <= file_len {
        // A failed read or unparsable header marks the end of the valid log.
        let header_bytes = match read_fn(&file, offset, size_of::<EntryHeader>()) {
            Ok(b) => b,
            Err(_) => break,
        };
        let header = match EntryHeader::read_from_bytes(&header_bytes) {
            Ok(h) => h,
            Err(_) => break,
        };
        let total = entry_size(size_of::<K>(), header.value_len);
        if offset + total > file_len {
            tracing::warn!(
                shard_id,
                file_id,
                offset,
                "truncating partial entry at end of file"
            );
            break;
        }
        let key_value_len = size_of::<K>() + header.value_len as usize;
        let key_value_bytes = read_fn(
            &file,
            offset + size_of::<EntryHeader>() as u64,
            key_value_len,
        )?;
        let key_bytes = &key_value_bytes[..size_of::<K>()];
        let value_bytes = &key_value_bytes[size_of::<K>()..];
        // CRC mismatch on the final entry = torn write; mid-file = corruption.
        let expected_crc = compute_crc32(header.gsn, header.value_len, key_bytes, value_bytes);
        if expected_crc != header.crc32 {
            if offset + total >= file_len {
                tracing::warn!(shard_id, file_id, offset, "CRC mismatch at end of file");
                break;
            }
            return Err(DbError::CrcMismatch {
                expected: expected_crc,
                actual: header.crc32,
            });
        }
        let seq = header.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        let key: K = K::from_bytes(key_bytes);
        if header.is_tombstone() {
            index.remove(key.as_bytes(), guard);
        } else {
            let value: T = codec.decode_from(value_bytes)?;
            let value_offset = offset + size_of::<EntryHeader>() as u64 + size_of::<K>() as u64;
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                value_offset as u32,
                header.value_len,
            );
            let height = random_height();
            let node_ptr = TypedNode::alloc(key, value, disk, height);
            match index.insert(node_ptr, guard) {
                crate::skiplist::InsertResult::Inserted => {}
                // SAFETY: `node_ptr` was never linked (insert returned
                // Exists), so we own it. We steal its freshly decoded payload
                // (leaving null behind), hand it to the resident node, and
                // drop the payload that was replaced. `dealloc_node` then
                // frees a node whose data pointer is null — NOTE(review):
                // relies on dealloc_node tolerating a null payload; confirm.
                crate::skiplist::InsertResult::Exists(existing) => unsafe {
                    let new_data = (*node_ptr)
                        .data
                        .swap(std::ptr::null_mut(), std::sync::atomic::Ordering::Relaxed);
                    let old_data = existing.swap_data(new_data);
                    drop(Box::from_raw(old_data));
                    TypedNode::<K, T>::dealloc_node(node_ptr);
                },
            }
        }
        offset += total;
    }
    Ok(())
}
/// Rebuilds the per-shard hash maps of a fixed-size-value map collection,
/// one scoped thread per shard directory.
///
/// Each thread locks its own shard's map for the whole replay (shards are
/// disjoint, so there is no cross-thread contention on a single map).
/// Returns the highest GSN observed across all shards.
pub fn recover_const_map<K: Key + Send + Sync + Hash + Eq, const V: usize>(
    shard_dirs: &[&Path],
    shard_ids: &[u8],
    indexes: &[Mutex<HashMap<K, MapEntry<V, DiskLoc>>>],
    hints: bool,
    #[cfg(feature = "encryption")] cipher: Option<Arc<PageCipher>>,
) -> DbResult<u64> {
    let results: Vec<DbResult<u64>> = std::thread::scope(|s| {
        let handles: Vec<_> = shard_dirs
            .iter()
            .enumerate()
            .map(|(i, &dir)| {
                let shard_id = shard_ids[i];
                let index = &indexes[i];
                #[cfg(feature = "encryption")]
                let cipher = cipher.clone();
                s.spawn(move || -> DbResult<u64> {
                    let shard_start = std::time::Instant::now();
                    let mut local_max_gsn: u64 = 0;
                    let mut used_hints = false;
                    // Hold the shard's map lock for the entire replay.
                    let mut map = sync::lock(index);
                    // Ascending ids: replay oldest data first.
                    let file_ids = scan_data_files(dir)?;
                    let hint_ids = if hints {
                        hint::scan_hint_files(dir)?
                    } else {
                        Vec::new()
                    };
                    for file_id in &file_ids {
                        let data_path = dir.join(format!("{file_id:06}.data"));
                        #[cfg(feature = "encryption")]
                        let read_fn = make_reader(&cipher, dir, *file_id)?;
                        #[cfg(not(feature = "encryption"))]
                        let read_fn = make_reader(dir, *file_id)?;
                        if hint_ids.contains(file_id) {
                            let hint_path = dir.join(format!("{file_id:06}.hint"));
                            if let Some(hint_data) = hint::read_hint_file(&hint_path)? {
                                // Only trust a hint with a whole number of entries.
                                if hint_data.len() % hint::hint_entry_size(size_of::<K>()) == 0 {
                                    let data_file = direct::open_read(&data_path)?;
                                    recover_const_map_hint::<K, V>(
                                        shard_id,
                                        *file_id,
                                        &hint_data,
                                        &data_file,
                                        &mut map,
                                        &mut local_max_gsn,
                                        &read_fn,
                                    )?;
                                    used_hints = true;
                                    continue;
                                }
                                tracing::warn!(
                                    shard_id,
                                    file_id,
                                    "hint file has unexpected size, falling back to full scan"
                                );
                            }
                        }
                        // No usable hint: replay the raw data file.
                        recover_const_map_full_scan::<K, V>(
                            shard_id,
                            *file_id,
                            &data_path,
                            &mut map,
                            &mut local_max_gsn,
                            &read_fn,
                        )?;
                    }
                    // Drop the hint of the newest data file. NOTE(review):
                    // presumably it may lag the active file — confirm.
                    if let Some(&last_id) = file_ids.last() {
                        let _ = fs::remove_file(dir.join(format!("{last_id:06}.hint")));
                    }
                    let elapsed = shard_start.elapsed().as_secs_f64();
                    metrics::histogram!("armdb.recovery.duration_seconds", "path" => if used_hints { "hint" } else { "full_scan" })
                        .record(elapsed);
                    Ok(local_max_gsn)
                })
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("recovery thread panicked"))
            .collect()
    });
    // First shard error wins; otherwise keep the maximum GSN.
    let mut max_gsn: u64 = 0;
    for result in results {
        let shard_gsn = result?;
        if shard_gsn > max_gsn {
            max_gsn = shard_gsn;
        }
    }
    Ok(max_gsn)
}
/// Rebuilds the per-shard hash maps of a variable-size-value map collection,
/// one scoped thread per shard directory.
///
/// Values are not materialized: only each key's latest [`DiskLoc`] is kept,
/// so the hint path needs no data-file reads at all. Returns the highest
/// GSN observed across all shards.
#[cfg(feature = "var-collections")]
pub fn recover_var_map<K: Key + Send + Sync + Hash + Eq>(
    shard_dirs: &[&Path],
    shard_ids: &[u8],
    indexes: &[Mutex<HashMap<K, DiskLoc>>],
    hints: bool,
    #[cfg(feature = "encryption")] cipher: Option<Arc<PageCipher>>,
) -> DbResult<u64> {
    let results: Vec<DbResult<u64>> = std::thread::scope(|s| {
        let handles: Vec<_> = shard_dirs
            .iter()
            .enumerate()
            .map(|(i, &dir)| {
                let shard_id = shard_ids[i];
                let index = &indexes[i];
                #[cfg(feature = "encryption")]
                let cipher = cipher.clone();
                s.spawn(move || -> DbResult<u64> {
                    let shard_start = std::time::Instant::now();
                    let mut local_max_gsn: u64 = 0;
                    let mut used_hints = false;
                    // Hold the shard's map lock for the entire replay.
                    let mut map = sync::lock(index);
                    // Ascending ids: replay oldest data first.
                    let file_ids = scan_data_files(dir)?;
                    let hint_ids = if hints {
                        hint::scan_hint_files(dir)?
                    } else {
                        Vec::new()
                    };
                    for file_id in &file_ids {
                        let data_path = dir.join(format!("{file_id:06}.data"));
                        #[cfg(feature = "encryption")]
                        let read_fn = make_reader(&cipher, dir, *file_id)?;
                        #[cfg(not(feature = "encryption"))]
                        let read_fn = make_reader(dir, *file_id)?;
                        if hint_ids.contains(file_id) {
                            let hint_path = dir.join(format!("{file_id:06}.hint"));
                            if let Some(hint_data) = hint::read_hint_file(&hint_path)? {
                                // Only trust a hint with a whole number of entries.
                                if hint_data.len() % hint::hint_entry_size(size_of::<K>()) == 0 {
                                    recover_var_map_hint::<K>(
                                        shard_id,
                                        *file_id,
                                        &hint_data,
                                        &mut map,
                                        &mut local_max_gsn,
                                    )?;
                                    used_hints = true;
                                    continue;
                                }
                                tracing::warn!(
                                    shard_id,
                                    file_id,
                                    "hint file has unexpected size, falling back to full scan"
                                );
                            }
                        }
                        // No usable hint: replay the raw data file.
                        recover_var_map_full_scan::<K>(
                            shard_id,
                            *file_id,
                            &data_path,
                            &mut map,
                            &mut local_max_gsn,
                            &read_fn,
                        )?;
                    }
                    // Drop the hint of the newest data file. NOTE(review):
                    // presumably it may lag the active file — confirm.
                    if let Some(&last_id) = file_ids.last() {
                        let _ = fs::remove_file(dir.join(format!("{last_id:06}.hint")));
                    }
                    let elapsed = shard_start.elapsed().as_secs_f64();
                    metrics::histogram!("armdb.recovery.duration_seconds", "path" => if used_hints { "hint" } else { "full_scan" })
                        .record(elapsed);
                    Ok(local_max_gsn)
                })
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("recovery thread panicked"))
            .collect()
    });
    // First shard error wins; otherwise keep the maximum GSN.
    let mut max_gsn: u64 = 0;
    for result in results {
        let shard_gsn = result?;
        if shard_gsn > max_gsn {
            max_gsn = shard_gsn;
        }
    }
    Ok(max_gsn)
}
fn recover_const_map_hint<K: Key + Send + Sync + Hash + Eq, const V: usize>(
shard_id: u8,
file_id: u32,
hint_data: &[u8],
data_file: &std::fs::File,
map: &mut HashMap<K, crate::const_map::MapEntry<V, DiskLoc>>,
max_gsn: &mut u64,
read_fn: &ReadFn,
) -> DbResult<()> {
for entry in hint::parse_hint_entries::<K>(hint_data) {
let seq = entry.sequence();
if seq > *max_gsn {
*max_gsn = seq;
}
if entry.is_tombstone() {
map.remove(&entry.key);
} else {
let value_bytes = read_fn(data_file, entry.value_offset, entry.value_len as usize)?;
let value: [u8; V] = value_bytes
.try_into()
.map_err(|_| DbError::CorruptedEntry {
offset: entry.value_offset,
})?;
let disk = DiskLoc::new(
shard_id,
file_id as u16,
entry.value_offset as u32,
entry.value_len,
);
map.insert(entry.key, crate::const_map::MapEntry { loc: disk, value });
}
}
Ok(())
}
/// Replays one data file entry by entry into a fixed-size-value hash map
/// shard.
///
/// Same torn-tail/CRC handling as the skip-list scans: unreadable headers,
/// entries running past EOF, and a CRC mismatch on the final entry end the
/// scan cleanly; a CRC mismatch mid-file aborts with [`DbError::CrcMismatch`].
fn recover_const_map_full_scan<K: Key + Send + Sync + Hash + Eq, const V: usize>(
    shard_id: u8,
    file_id: u32,
    data_path: &Path,
    map: &mut HashMap<K, crate::const_map::MapEntry<V, DiskLoc>>,
    max_gsn: &mut u64,
    read_fn: &ReadFn,
) -> DbResult<()> {
    let file = direct::open_read(data_path)?;
    let file_len = file.metadata()?.len();
    let mut offset: u64 = 0;
    // Walk entries while a full header still fits in the file.
    while offset + size_of::<EntryHeader>() as u64 <= file_len {
        // A failed read or unparsable header marks the end of the valid log.
        let header_bytes = match read_fn(&file, offset, size_of::<EntryHeader>()) {
            Ok(b) => b,
            Err(_) => break,
        };
        let header = match EntryHeader::read_from_bytes(&header_bytes) {
            Ok(h) => h,
            Err(_) => break,
        };
        let total = entry_size(size_of::<K>(), header.value_len);
        if offset + total > file_len {
            tracing::warn!(
                shard_id,
                file_id,
                offset,
                "truncating partial entry at end of file"
            );
            break;
        }
        let key_value_len = size_of::<K>() + header.value_len as usize;
        let key_value_bytes = read_fn(
            &file,
            offset + size_of::<EntryHeader>() as u64,
            key_value_len,
        )?;
        let key_bytes = &key_value_bytes[..size_of::<K>()];
        let value_bytes = &key_value_bytes[size_of::<K>()..];
        // CRC mismatch on the final entry = torn write; mid-file = corruption.
        let expected_crc = compute_crc32(header.gsn, header.value_len, key_bytes, value_bytes);
        if expected_crc != header.crc32 {
            if offset + total >= file_len {
                tracing::warn!(shard_id, file_id, offset, "CRC mismatch at end of file");
                break;
            }
            return Err(DbError::CrcMismatch {
                expected: expected_crc,
                actual: header.crc32,
            });
        }
        let seq = header.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        let key: K = K::from_bytes(key_bytes);
        if header.is_tombstone() {
            map.remove(&key);
        } else {
            // A const map stores exactly V bytes per value.
            let value: [u8; V] = value_bytes
                .try_into()
                .map_err(|_| DbError::CorruptedEntry { offset })?;
            let value_offset = offset + size_of::<EntryHeader>() as u64 + size_of::<K>() as u64;
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                value_offset as u32,
                header.value_len,
            );
            map.insert(key, crate::const_map::MapEntry { loc: disk, value });
        }
        offset += total;
    }
    Ok(())
}
/// Replays one parsed hint file into a variable-size-value hash map shard:
/// tombstones remove the key, live entries record the value's disk location.
/// `max_gsn` is raised to the largest sequence number encountered.
#[cfg(feature = "var-collections")]
fn recover_var_map_hint<K: Key + Send + Sync + Hash + Eq>(
    shard_id: u8,
    file_id: u32,
    hint_data: &[u8],
    map: &mut HashMap<K, DiskLoc>,
    max_gsn: &mut u64,
) -> DbResult<()> {
    for hint_entry in hint::parse_hint_entries::<K>(hint_data) {
        let seq = hint_entry.sequence();
        *max_gsn = (*max_gsn).max(seq);
        if hint_entry.is_tombstone() {
            map.remove(&hint_entry.key);
            continue;
        }
        let loc = DiskLoc::new(
            shard_id,
            file_id as u16,
            hint_entry.value_offset as u32,
            hint_entry.value_len,
        );
        map.insert(hint_entry.key, loc);
    }
    Ok(())
}
/// Replays one data file entry by entry into a variable-size-value hash map
/// shard, recording only each key's latest [`DiskLoc`].
///
/// Same torn-tail/CRC handling as the other full scans.
#[cfg(feature = "var-collections")]
fn recover_var_map_full_scan<K: Key + Send + Sync + Hash + Eq>(
    shard_id: u8,
    file_id: u32,
    data_path: &Path,
    map: &mut HashMap<K, DiskLoc>,
    max_gsn: &mut u64,
    read_fn: &ReadFn,
) -> DbResult<()> {
    let file = direct::open_read(data_path)?;
    let file_len = file.metadata()?.len();
    let mut offset: u64 = 0;
    // Walk entries while a full header still fits in the file.
    while offset + size_of::<EntryHeader>() as u64 <= file_len {
        // A failed read or unparsable header marks the end of the valid log.
        let header_bytes = match read_fn(&file, offset, size_of::<EntryHeader>()) {
            Ok(b) => b,
            Err(_) => break,
        };
        let header = match EntryHeader::read_from_bytes(&header_bytes) {
            Ok(h) => h,
            Err(_) => break,
        };
        let total = entry_size(size_of::<K>(), header.value_len);
        if offset + total > file_len {
            tracing::warn!(
                shard_id,
                file_id,
                offset,
                "truncating partial entry at end of file"
            );
            break;
        }
        let key_value_len = size_of::<K>() + header.value_len as usize;
        let key_value_bytes = read_fn(
            &file,
            offset + size_of::<EntryHeader>() as u64,
            key_value_len,
        )?;
        let key_bytes = &key_value_bytes[..size_of::<K>()];
        let value_bytes = &key_value_bytes[size_of::<K>()..];
        // CRC mismatch on the final entry = torn write; mid-file = corruption.
        let expected_crc = compute_crc32(header.gsn, header.value_len, key_bytes, value_bytes);
        if expected_crc != header.crc32 {
            if offset + total >= file_len {
                tracing::warn!(shard_id, file_id, offset, "CRC mismatch at end of file");
                break;
            }
            return Err(DbError::CrcMismatch {
                expected: expected_crc,
                actual: header.crc32,
            });
        }
        let seq = header.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        let key: K = K::from_bytes(key_bytes);
        if header.is_tombstone() {
            map.remove(&key);
        } else {
            let value_offset = offset + size_of::<EntryHeader>() as u64 + size_of::<K>() as u64;
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                value_offset as u32,
                header.value_len,
            );
            map.insert(key, disk);
        }
        offset += total;
    }
    Ok(())
}
/// Rebuilds the per-shard hash maps of a typed map collection, decoding each
/// live value through `codec`; one scoped thread per shard directory.
///
/// Map entries own raw `Box` pointers ([`crate::typed_map::TypedMapEntry`]);
/// the per-file replay helpers free replaced/removed payloads as they go.
/// Returns the highest GSN observed across all shards.
#[cfg(feature = "typed-tree")]
pub fn recover_typed_map<K: Key + Send + Sync + Hash + Eq, T: Send + Sync, C: Codec<T> + Sync>(
    shard_dirs: &[&Path],
    shard_ids: &[u8],
    indexes: &[Mutex<HashMap<K, crate::typed_map::TypedMapEntry<T>>>],
    codec: &C,
    hints: bool,
    #[cfg(feature = "encryption")] cipher: Option<Arc<PageCipher>>,
) -> DbResult<u64> {
    let results: Vec<DbResult<u64>> = std::thread::scope(|s| {
        let handles: Vec<_> = shard_dirs
            .iter()
            .enumerate()
            .map(|(i, &dir)| {
                let shard_id = shard_ids[i];
                let index = &indexes[i];
                #[cfg(feature = "encryption")]
                let cipher = cipher.clone();
                s.spawn(move || -> DbResult<u64> {
                    let shard_start = std::time::Instant::now();
                    let mut local_max_gsn: u64 = 0;
                    let mut used_hints = false;
                    // Hold the shard's map lock for the entire replay.
                    let mut map = sync::lock(index);
                    // Ascending ids: replay oldest data first.
                    let file_ids = scan_data_files(dir)?;
                    let hint_ids = if hints {
                        hint::scan_hint_files(dir)?
                    } else {
                        Vec::new()
                    };
                    for file_id in &file_ids {
                        let data_path = dir.join(format!("{file_id:06}.data"));
                        #[cfg(feature = "encryption")]
                        let read_fn = make_reader(&cipher, dir, *file_id)?;
                        #[cfg(not(feature = "encryption"))]
                        let read_fn = make_reader(dir, *file_id)?;
                        if hint_ids.contains(file_id) {
                            let hint_path = dir.join(format!("{file_id:06}.hint"));
                            if let Some(hint_data) = hint::read_hint_file(&hint_path)? {
                                // Only trust a hint with a whole number of entries.
                                if hint_data.len() % hint::hint_entry_size(size_of::<K>()) == 0 {
                                    let data_file = direct::open_read(&data_path)?;
                                    recover_typed_map_hint::<K, T, C>(
                                        shard_id,
                                        *file_id,
                                        &hint_data,
                                        &data_file,
                                        &mut map,
                                        codec,
                                        &mut local_max_gsn,
                                        &read_fn,
                                    )?;
                                    used_hints = true;
                                    continue;
                                }
                                tracing::warn!(
                                    shard_id,
                                    file_id,
                                    "hint file has unexpected size, falling back to full scan"
                                );
                            }
                        }
                        // No usable hint: replay the raw data file.
                        recover_typed_map_full_scan::<K, T, C>(
                            shard_id,
                            *file_id,
                            &data_path,
                            &mut map,
                            codec,
                            &mut local_max_gsn,
                            &read_fn,
                        )?;
                    }
                    // Drop the hint of the newest data file. NOTE(review):
                    // presumably it may lag the active file — confirm.
                    if let Some(&last_id) = file_ids.last() {
                        let _ = fs::remove_file(dir.join(format!("{last_id:06}.hint")));
                    }
                    let elapsed = shard_start.elapsed().as_secs_f64();
                    metrics::histogram!("armdb.recovery.duration_seconds", "path" => if used_hints { "hint" } else { "full_scan" })
                        .record(elapsed);
                    Ok(local_max_gsn)
                })
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("recovery thread panicked"))
            .collect()
    });
    // First shard error wins; otherwise keep the maximum GSN.
    let mut max_gsn: u64 = 0;
    for result in results {
        let shard_gsn = result?;
        if shard_gsn > max_gsn {
            max_gsn = shard_gsn;
        }
    }
    Ok(max_gsn)
}
/// Replays one parsed hint file into a typed hash map shard, decoding each
/// live value through `codec`.
///
/// The map stores raw `Box` pointers, so every payload displaced by a
/// remove or overwrite is freed here to avoid leaks.
#[cfg(feature = "typed-tree")]
#[allow(clippy::too_many_arguments)]
fn recover_typed_map_hint<K: Key + Send + Sync + Hash + Eq, T: Send + Sync, C: Codec<T>>(
    shard_id: u8,
    file_id: u32,
    hint_data: &[u8],
    data_file: &std::fs::File,
    map: &mut HashMap<K, crate::typed_map::TypedMapEntry<T>>,
    codec: &C,
    max_gsn: &mut u64,
    read_fn: &ReadFn,
) -> DbResult<()> {
    for entry in hint::parse_hint_entries::<K>(hint_data) {
        let seq = entry.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        if entry.is_tombstone() {
            if let Some(old) = map.remove(&entry.key) {
                // SAFETY: `old.ptr` came from Box::into_raw when the entry
                // was inserted, and removing it from the map leaves this the
                // only remaining reference.
                unsafe {
                    drop(Box::from_raw(old.ptr));
                }
            }
        } else {
            let value_bytes = read_fn(data_file, entry.value_offset, entry.value_len as usize)?;
            let value = codec.decode_from(&value_bytes)?;
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                entry.value_offset as u32,
                entry.value_len,
            );
            let ptr = Box::into_raw(Box::new(TypedData { disk, value }));
            if let Some(old) = map.insert(entry.key, crate::typed_map::TypedMapEntry { ptr }) {
                // SAFETY: same ownership argument as above — the displaced
                // entry's pointer is no longer reachable from the map.
                unsafe {
                    drop(Box::from_raw(old.ptr));
                }
            }
        }
    }
    Ok(())
}
/// Replays one data file entry by entry into a typed hash map shard,
/// decoding each live value through `codec`.
///
/// Same torn-tail/CRC handling as the other full scans; decode failures are
/// reported as [`DbError::CorruptedEntry`] at the entry's offset. Displaced
/// raw `Box` payloads are freed to avoid leaks.
#[cfg(feature = "typed-tree")]
fn recover_typed_map_full_scan<K: Key + Send + Sync + Hash + Eq, T: Send + Sync, C: Codec<T>>(
    shard_id: u8,
    file_id: u32,
    data_path: &Path,
    map: &mut HashMap<K, crate::typed_map::TypedMapEntry<T>>,
    codec: &C,
    max_gsn: &mut u64,
    read_fn: &ReadFn,
) -> DbResult<()> {
    let file = direct::open_read(data_path)?;
    let file_len = file.metadata()?.len();
    let mut offset: u64 = 0;
    // Walk entries while a full header still fits in the file.
    while offset + size_of::<EntryHeader>() as u64 <= file_len {
        // A failed read or unparsable header marks the end of the valid log.
        let header_bytes = match read_fn(&file, offset, size_of::<EntryHeader>()) {
            Ok(b) => b,
            Err(_) => break,
        };
        let header = match EntryHeader::read_from_bytes(&header_bytes) {
            Ok(h) => h,
            Err(_) => break,
        };
        let total = entry_size(size_of::<K>(), header.value_len);
        if offset + total > file_len {
            tracing::warn!(
                shard_id,
                file_id,
                offset,
                "truncating partial entry at end of file"
            );
            break;
        }
        let key_value_len = size_of::<K>() + header.value_len as usize;
        let key_value_bytes = read_fn(
            &file,
            offset + size_of::<EntryHeader>() as u64,
            key_value_len,
        )?;
        let key_bytes = &key_value_bytes[..size_of::<K>()];
        let value_bytes = &key_value_bytes[size_of::<K>()..];
        // CRC mismatch on the final entry = torn write; mid-file = corruption.
        let expected_crc = compute_crc32(header.gsn, header.value_len, key_bytes, value_bytes);
        if expected_crc != header.crc32 {
            if offset + total >= file_len {
                tracing::warn!(shard_id, file_id, offset, "CRC mismatch at end of file");
                break;
            }
            return Err(DbError::CrcMismatch {
                expected: expected_crc,
                actual: header.crc32,
            });
        }
        let seq = header.sequence();
        if seq > *max_gsn {
            *max_gsn = seq;
        }
        let key: K = K::from_bytes(key_bytes);
        if header.is_tombstone() {
            if let Some(old) = map.remove(&key) {
                // SAFETY: `old.ptr` came from Box::into_raw on insertion;
                // after removal this is the only remaining reference.
                unsafe {
                    drop(Box::from_raw(old.ptr));
                }
            }
        } else {
            // Surface decode failures as corruption at this entry's offset.
            let value = codec
                .decode_from(value_bytes)
                .map_err(|_| DbError::CorruptedEntry { offset })?;
            let value_offset = offset + size_of::<EntryHeader>() as u64 + size_of::<K>() as u64;
            let disk = DiskLoc::new(
                shard_id,
                file_id as u16,
                value_offset as u32,
                header.value_len,
            );
            let ptr = Box::into_raw(Box::new(TypedData { disk, value }));
            if let Some(old) = map.insert(key, crate::typed_map::TypedMapEntry { ptr }) {
                // SAFETY: the displaced entry's pointer is no longer
                // reachable from the map.
                unsafe {
                    drop(Box::from_raw(old.ptr));
                }
            }
        }
        offset += total;
    }
    Ok(())
}
/// Lists the numeric ids of all `{id:06}.data` files in `dir`, sorted
/// ascending so recovery replays the oldest file first.
///
/// A missing directory yields an empty list rather than an error.
fn scan_data_files(dir: &Path) -> DbResult<Vec<u32>> {
    let mut ids = Vec::new();
    if !dir.exists() {
        return Ok(ids);
    }
    for entry in fs::read_dir(dir)? {
        let entry = entry?;
        let name = entry.file_name();
        let name = name.to_string_lossy();
        // `strip_suffix` removes exactly one ".data" — unlike the previous
        // `trim_end_matches`, which strips repeats and would have accepted a
        // stray "1.data.data" as id 1. Non-numeric stems are ignored.
        if let Some(stem) = name.strip_suffix(".data")
            && let Ok(id) = stem.parse::<u32>()
        {
            ids.push(id);
        }
    }
    // Plain u32 ids: unstable sort is allocation-free and order-equivalent.
    ids.sort_unstable();
    Ok(ids)
}