use crate::error::{PicanteError, PicanteResult};
use crate::key::QueryKindId;
use crate::revision::Revision;
use crate::runtime::Runtime;
use crate::wal::{WalEntry, WalOperation, WalReader, WalWriter};
use facet::Facet;
use futures_util::future::BoxFuture;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tracing::{trace, warn};
/// On-disk format version written into every [`CacheFile`]; loading rejects
/// any file whose `format_version` differs from this value.
const FORMAT_VERSION: u32 = 1;
/// What `load_cache_with_options` does when loading a cache fails.
///
/// Despite the name, this is consulted for failures anywhere on the load path
/// (read errors, the `max_bytes` limit, decode errors, format-version or
/// kind/section mismatches, ingredient load/restore errors), not only for
/// corrupt bytes. A missing cache file is not a failure.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum OnCorruptCache {
    /// Return the error to the caller.
    Error,
    /// Log a warning and report that no cache was loaded (`Ok(false)`).
    Ignore,
    /// Log a warning, attempt to delete the cache file (a failed deletion is
    /// not returned as an error), and report that no cache was loaded
    /// (`Ok(false)`).
    Delete,
}
/// Options accepted by `load_cache_with_options`.
#[derive(Debug, Clone)]
pub struct CacheLoadOptions {
    /// Reject cache files larger than this many bytes. The check runs after
    /// the file has been read into memory and before decoding. `None` means
    /// no limit.
    pub max_bytes: Option<usize>,
    /// How a failed load is handled.
    pub on_corrupt: OnCorruptCache,
}
impl Default for CacheLoadOptions {
fn default() -> Self {
Self {
max_bytes: None,
on_corrupt: OnCorruptCache::Error,
}
}
}
/// Size limits applied by `save_cache_with_options`. Every limit defaults to
/// `None` (unlimited).
#[derive(Debug, Clone, Default)]
pub struct CacheSaveOptions {
    /// Upper bound on the encoded cache file size. Records are dropped
    /// (largest first; `Derived` sections before `Input` before `Interned`)
    /// until the file fits. Saving fails if it still cannot be made to fit.
    pub max_bytes: Option<usize>,
    /// Keep at most this many records per section: the first ones returned by
    /// the ingredient. The rest are discarded.
    pub max_records_per_section: Option<usize>,
    /// Skip, with a warning, any single record longer than this many bytes.
    pub max_record_bytes: Option<usize>,
}
/// Top-level structure of a snapshot file, encoded with `facet_postcard`.
///
/// NOTE: postcard's wire format is positional (field names are not stored), so
/// reordering or changing these fields changes the on-disk format. Bump
/// [`FORMAT_VERSION`] when doing so.
#[derive(Debug, Clone, Facet)]
pub struct CacheFile {
    /// Must equal [`FORMAT_VERSION`] for the file to be loaded.
    pub format_version: u32,
    /// Runtime revision at save time; written back into the runtime on load.
    pub current_revision: u64,
    /// One section per ingredient that was saved.
    pub sections: Vec<Section>,
}
/// The serialized records of a single ingredient inside a [`CacheFile`].
#[derive(Debug, Clone, Facet)]
pub struct Section {
    /// The ingredient's kind id (`QueryKindId::as_u32`). On load it is used to
    /// find the matching ingredient; unknown ids are skipped with a warning.
    pub kind_id: u32,
    /// The ingredient's kind name; a mismatch on load is an error.
    pub kind_name: String,
    /// The ingredient's section type; a mismatch on load is an error.
    pub section_type: SectionType,
    /// Opaque record blobs produced by `save_records` and handed back to
    /// `load_records`.
    pub records: Vec<Vec<u8>>,
}
/// Category of the ingredient a [`Section`] belongs to.
///
/// When a cache is shrunk to fit `CacheSaveOptions::max_bytes`, records are
/// dropped from `Derived` sections first, then `Input`, then `Interned`.
///
/// NOTE(review): this enum is serialized into the cache file, so variant order
/// is presumably part of the encoding. Append new variants and bump
/// `FORMAT_VERSION` rather than reordering.
#[repr(u8)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Facet)]
pub enum SectionType {
    Input,
    Derived,
    Interned,
}
/// An ingredient whose state can be written to and restored from a cache
/// snapshot, and optionally kept current through a write-ahead log (WAL).
pub trait PersistableIngredient: Send + Sync {
    /// Identifier stored as [`Section::kind_id`]. It must be unique among the
    /// ingredients passed to a single save/load call.
    fn kind(&self) -> QueryKindId;
    /// Name stored alongside the kind id. A load fails if the name recorded
    /// for this kind id differs.
    fn kind_name(&self) -> &'static str;
    /// Section type stored alongside the kind id. A load fails on mismatch.
    fn section_type(&self) -> SectionType;
    /// Called on every ingredient before a snapshot's records are loaded.
    /// Presumably resets the ingredient to an empty state; confirm in the
    /// implementations.
    fn clear(&self);
    /// Serializes this ingredient's contents into opaque record blobs for a
    /// snapshot.
    fn save_records(&self) -> BoxFuture<'_, PicanteResult<Vec<Vec<u8>>>>;
    /// Loads records taken from a snapshot section whose kind id matches this
    /// ingredient. Called after `clear`.
    fn load_records(&self, records: Vec<Vec<u8>>) -> PicanteResult<()>;
    /// Rebuilds runtime-side state once records are in place. Called for
    /// every ingredient at the end of a snapshot load and at the end of a WAL
    /// replay. The default does nothing.
    fn restore_runtime_state<'a>(
        &'a self,
        _runtime: &'a Runtime,
    ) -> BoxFuture<'a, PicanteResult<()>> {
        Box::pin(async { Ok(()) })
    }
    /// Returns changes as `(revision, key, value)` tuples, where
    /// `value == None` means the key was deleted. `append_to_wal` passes the
    /// WAL's base revision as `since_revision`. The default returns no
    /// changes.
    // The nested tuple-in-Vec-in-Result return type trips clippy's
    // `type_complexity`; it is kept inline instead of adding a type alias.
    #[allow(clippy::type_complexity)]
    fn save_incremental_records(
        &self,
        _since_revision: u64,
    ) -> BoxFuture<'_, PicanteResult<Vec<(u64, Vec<u8>, Option<Vec<u8>>)>>> {
        Box::pin(async { Ok(vec![]) })
    }
    /// Applies one WAL entry during replay. `value` is `Some` for a set and
    /// `None` for a delete. The default ignores the entry.
    fn apply_wal_entry(
        &self,
        _revision: u64,
        _key: Vec<u8>,
        _value: Option<Vec<u8>>,
    ) -> PicanteResult<()> {
        Ok(())
    }
}
/// Saves a snapshot of `ingredients` to `path` with no size limits
/// ([`CacheSaveOptions::default()`]).
///
/// Shorthand for [`save_cache_with_options`]; see it for details and errors.
pub async fn save_cache(
    path: impl AsRef<Path>,
    runtime: &Runtime,
    ingredients: &[&dyn PersistableIngredient],
) -> PicanteResult<()> {
    let defaults = CacheSaveOptions::default();
    save_cache_with_options(path, runtime, ingredients, &defaults).await
}
pub async fn save_cache_with_options(
path: impl AsRef<Path>,
runtime: &Runtime,
ingredients: &[&dyn PersistableIngredient],
options: &CacheSaveOptions,
) -> PicanteResult<()> {
use std::time::Instant;
let total_start = Instant::now();
let path = path.as_ref();
trace!(path = %path.display(), "save_cache: start");
ensure_unique_kinds(ingredients)?;
let collect_start = Instant::now();
let mut sections = Vec::with_capacity(ingredients.len());
for ingredient in ingredients {
let mut records = ingredient.save_records().await?;
if let Some(max) = options.max_record_bytes {
let before = records.len();
records.retain(|r| r.len() <= max);
let dropped = before - records.len();
if dropped != 0 {
warn!(
kind = ingredient.kind().as_u32(),
dropped,
max_record_bytes = max,
"save_cache: skipped oversized records"
);
}
}
sections.push(Section {
kind_id: ingredient.kind().as_u32(),
kind_name: ingredient.kind_name().to_string(),
section_type: ingredient.section_type(),
records,
});
}
let collect_elapsed = collect_start.elapsed();
let num_sections = sections.len();
let total_records: usize = sections.iter().map(|s| s.records.len()).sum();
let mut cache = CacheFile {
format_version: FORMAT_VERSION,
current_revision: runtime.current_revision().0,
sections,
};
if let Some(max) = options.max_records_per_section {
for section in &mut cache.sections {
if section.records.len() > max {
section.records.truncate(max);
}
}
}
if let Some(max_bytes) = options.max_bytes {
shrink_cache_to_fit(&mut cache, max_bytes)?;
}
let encode_start = Instant::now();
let bytes = encode_cache_file(&cache)?;
let encode_elapsed = encode_start.elapsed();
if let Some(parent) = path.parent() {
tokio::fs::create_dir_all(parent).await.map_err(|e| {
Arc::new(PicanteError::Cache {
message: format!("create_dir_all {}: {e}", parent.display()),
})
})?;
}
let write_start = Instant::now();
let tmp = path.with_extension("tmp");
tokio::fs::write(&tmp, &bytes).await.map_err(|e| {
Arc::new(PicanteError::Cache {
message: format!("write {}: {e}", tmp.display()),
})
})?;
tokio::fs::rename(&tmp, path).await.map_err(|e| {
Arc::new(PicanteError::Cache {
message: format!("rename {} -> {}: {e}", tmp.display(), path.display()),
})
})?;
let write_elapsed = write_start.elapsed();
let total_elapsed = total_start.elapsed();
trace!(
path = %path.display(),
bytes = bytes.len(),
rev = runtime.current_revision().0,
sections = num_sections,
records = total_records,
collect_ms = collect_elapsed.as_millis(),
encode_ms = encode_elapsed.as_millis(),
write_ms = write_elapsed.as_millis(),
total_ms = total_elapsed.as_millis(),
"save_cache: done"
);
Ok(())
}
/// Loads the snapshot at `path` using [`CacheLoadOptions::default()`]: no size
/// limit, and any load failure is returned as an error.
///
/// Returns `Ok(false)` if the file does not exist and `Ok(true)` once the
/// snapshot has been loaded. See [`load_cache_with_options`] for details.
pub async fn load_cache(
    path: impl AsRef<Path>,
    runtime: &Runtime,
    ingredients: &[&dyn PersistableIngredient],
) -> PicanteResult<bool> {
    let defaults = CacheLoadOptions::default();
    load_cache_with_options(path, runtime, ingredients, &defaults).await
}
pub async fn load_cache_with_options(
path: impl AsRef<Path>,
runtime: &Runtime,
ingredients: &[&dyn PersistableIngredient],
options: &CacheLoadOptions,
) -> PicanteResult<bool> {
match load_cache_inner(path.as_ref(), runtime, ingredients, options).await {
Ok(v) => Ok(v),
Err(e) => match options.on_corrupt {
OnCorruptCache::Error => Err(e),
OnCorruptCache::Ignore => {
warn!(error = %e, "load_cache: ignoring corrupt cache");
Ok(false)
}
OnCorruptCache::Delete => {
warn!(error = %e, "load_cache: deleting corrupt cache");
let path = path.as_ref();
let _ = tokio::fs::remove_file(path).await;
Ok(false)
}
},
}
}
/// Reads, validates and installs the snapshot at `path`. This is the fallible
/// core of [`load_cache_with_options`].
///
/// Returns `Ok(false)` when the file does not exist. Returns `Ok(true)` once:
/// - every section with a known kind id has been loaded,
/// - each ingredient's runtime state has been restored, and
/// - the runtime revision has been set to the snapshot's revision.
///
/// Every section's kind name and section type are validated before any
/// existing state is cleared. Sections whose kind id matches no ingredient are
/// skipped with a warning.
///
/// # Errors
///
/// Fails on:
/// - read errors other than `NotFound`,
/// - a file larger than `options.max_bytes`,
/// - decode errors or a format-version mismatch,
/// - a kind-name or section-type mismatch,
/// - an ingredient failing to load its records or restore its runtime state.
///
/// The last case happens after existing state was cleared. The dependency
/// graph and all ingredients are then cleared again, so no partially loaded
/// snapshot remains.
async fn load_cache_inner(
    path: &Path,
    runtime: &Runtime,
    ingredients: &[&dyn PersistableIngredient],
    options: &CacheLoadOptions,
) -> PicanteResult<bool> {
    use std::time::Instant;
    let total_start = Instant::now();
    trace!(path = %path.display(), "load_cache: start");
    ensure_unique_kinds(ingredients)?;

    let read_start = Instant::now();
    let bytes = match tokio::fs::read(path).await {
        Ok(b) => b,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(false),
        Err(e) => {
            return Err(Arc::new(PicanteError::Cache {
                message: format!("read {}: {e}", path.display()),
            }));
        }
    };
    let read_elapsed = read_start.elapsed();
    if let Some(max) = options.max_bytes
        && bytes.len() > max
    {
        return Err(Arc::new(PicanteError::Cache {
            message: format!("cache file too large ({} bytes > max {max})", bytes.len()),
        }));
    }

    let decode_start = Instant::now();
    let cache: CacheFile = decode_cache_file(&bytes)?;
    let decode_elapsed = decode_start.elapsed();
    let num_sections = cache.sections.len();
    let total_records: usize = cache.sections.iter().map(|s| s.records.len()).sum();
    if cache.format_version != FORMAT_VERSION {
        return Err(Arc::new(PicanteError::Cache {
            message: format!(
                "unsupported cache format version {}; expected {}",
                cache.format_version, FORMAT_VERSION
            ),
        }));
    }

    let by_kind: HashMap<u32, &dyn PersistableIngredient> = ingredients
        .iter()
        .map(|ingredient| (ingredient.kind().as_u32(), *ingredient))
        .collect();

    // Validate every section before touching any state, so a mismatch leaves
    // whatever is currently loaded intact.
    let current_revision = cache.current_revision;
    let mut matched = Vec::with_capacity(num_sections);
    for section in cache.sections {
        let Some(ingredient) = by_kind.get(&section.kind_id).copied() else {
            warn!(
                kind_id = section.kind_id,
                kind_name = %section.kind_name,
                "load_cache: ignoring unknown section"
            );
            continue;
        };
        if section.kind_name != ingredient.kind_name() {
            return Err(Arc::new(PicanteError::Cache {
                message: format!(
                    "kind name mismatch for id {}: file has `{}`, runtime has `{}`",
                    section.kind_id,
                    section.kind_name,
                    ingredient.kind_name()
                ),
            }));
        }
        if section.section_type != ingredient.section_type() {
            return Err(Arc::new(PicanteError::Cache {
                message: format!(
                    "section type mismatch for id {} (`{}`)",
                    section.kind_id, section.kind_name
                ),
            }));
        }
        matched.push((ingredient, section.records));
    }

    reset_ingredient_state(runtime, ingredients);

    let load_start = Instant::now();
    for (ingredient, records) in matched {
        if let Err(e) = ingredient.load_records(records) {
            // Don't leave a half-loaded snapshot behind.
            reset_ingredient_state(runtime, ingredients);
            return Err(e);
        }
    }
    let load_elapsed = load_start.elapsed();

    let restore_start = Instant::now();
    for ingredient in ingredients {
        if let Err(e) = ingredient.restore_runtime_state(runtime).await {
            reset_ingredient_state(runtime, ingredients);
            return Err(e);
        }
    }
    let restore_elapsed = restore_start.elapsed();

    runtime.set_current_revision(Revision(current_revision));
    let total_elapsed = total_start.elapsed();
    trace!(
        path = %path.display(),
        bytes = bytes.len(),
        rev = runtime.current_revision().0,
        sections = num_sections,
        records = total_records,
        read_ms = read_elapsed.as_millis(),
        decode_ms = decode_elapsed.as_millis(),
        load_ms = load_elapsed.as_millis(),
        restore_ms = restore_elapsed.as_millis(),
        total_ms = total_elapsed.as_millis(),
        "load_cache: done"
    );
    Ok(true)
}

/// Clears the runtime's dependency graph, then calls `clear` on every
/// ingredient.
fn reset_ingredient_state(runtime: &Runtime, ingredients: &[&dyn PersistableIngredient]) {
    runtime.clear_dependency_graph();
    for ingredient in ingredients {
        ingredient.clear();
    }
}
/// Checks that no two ingredients share a kind id.
///
/// # Errors
///
/// Returns a `PicanteError::Cache` naming the first id, in slice order, that
/// appears a second time.
fn ensure_unique_kinds(ingredients: &[&dyn PersistableIngredient]) -> PicanteResult<()> {
    let mut seen_ids = std::collections::HashSet::<u32>::with_capacity(ingredients.len());
    let duplicate = ingredients
        .iter()
        .map(|ingredient| ingredient.kind().as_u32())
        .find(|id| !seen_ids.insert(*id));
    match duplicate {
        None => Ok(()),
        Some(id) => Err(Arc::new(PicanteError::Cache {
            message: format!("duplicate ingredient kind id {id}"),
        })),
    }
}
/// Serializes `cache` with `facet_postcard`.
///
/// # Errors
///
/// Wraps a serializer failure in `PicanteError::Encode` with
/// `what = "cache file"` and the error's `Debug` text as the message.
fn encode_cache_file(cache: &CacheFile) -> PicanteResult<Vec<u8>> {
    match facet_postcard::to_vec(cache) {
        Ok(encoded) => Ok(encoded),
        Err(e) => Err(Arc::new(PicanteError::Encode {
            what: "cache file",
            message: format!("{e:?}"),
        })),
    }
}
/// Deserializes a [`CacheFile`] from `bytes` with `facet_postcard`.
///
/// # Errors
///
/// Wraps a deserializer failure in `PicanteError::Decode` with
/// `what = "cache file"` and the error's `Debug` text as the message.
fn decode_cache_file(bytes: &[u8]) -> PicanteResult<CacheFile> {
    match facet_postcard::from_slice(bytes) {
        Ok(decoded) => Ok(decoded),
        Err(e) => Err(Arc::new(PicanteError::Decode {
            what: "cache file",
            message: format!("{e:?}"),
        })),
    }
}
/// Drops records from `cache` until its encoded size is at most `max_bytes`.
///
/// Records are dropped largest-first. All `Derived` records go before any
/// `Input` record, and all `Input` records before any `Interned` record. If
/// anything has to be dropped, each section's records are re-sorted by
/// ascending length as a side effect.
///
/// "Overhead" is the encoded size minus the summed record payload lengths: the
/// header fields, section metadata and per-record length prefixes. Dropping a
/// record also removes its length prefix, so one pass normally suffices. The
/// result is still re-encoded and re-checked a few times defensively.
///
/// # Errors
///
/// Fails if the overhead alone reaches `max_bytes`, if the cache still does
/// not fit after truncation, or on encode failure.
fn shrink_cache_to_fit(cache: &mut CacheFile, max_bytes: usize) -> PicanteResult<()> {
    let bytes = encode_cache_file(cache)?;
    if bytes.len() <= max_bytes {
        return Ok(());
    }
    let record_bytes = total_record_bytes(cache);
    // Byte vectors are length-prefixed in postcard, so the encoding should
    // never be smaller than the raw payload; `saturating_sub` only guards that
    // assumption.
    let overhead = bytes.len().saturating_sub(record_bytes);
    if overhead >= max_bytes {
        return Err(Arc::new(PicanteError::Cache {
            message: format!("cache overhead ({overhead} bytes) exceeds max_bytes ({max_bytes})"),
        }));
    }
    // Ascending order makes each section's last record its largest, which
    // `drop_one_record` relies on.
    for section in &mut cache.sections {
        section.records.sort_by_key(|r| r.len());
    }
    drop_records_until_within(cache, record_bytes, max_bytes - overhead);

    for _ in 0..3 {
        let bytes = encode_cache_file(cache)?;
        if bytes.len() <= max_bytes {
            trace!(
                bytes = bytes.len(),
                max_bytes, "save_cache: cache truncated to fit"
            );
            return Ok(());
        }
        let record_bytes = total_record_bytes(cache);
        let overhead = bytes.len().saturating_sub(record_bytes);
        if overhead >= max_bytes {
            break;
        }
        drop_records_until_within(cache, record_bytes, max_bytes - overhead);
    }

    let bytes = encode_cache_file(cache)?;
    if bytes.len() > max_bytes {
        return Err(Arc::new(PicanteError::Cache {
            message: format!(
                "cache remains too large after truncation ({} > {max_bytes})",
                bytes.len()
            ),
        }));
    }
    Ok(())
}

/// Sum of the payload lengths of every record in every section.
fn total_record_bytes(cache: &CacheFile) -> usize {
    cache
        .sections
        .iter()
        .flat_map(|s| s.records.iter())
        .map(|r| r.len())
        .sum()
}

/// Repeatedly drops one record (trying `Derived`, then `Input`, then
/// `Interned` sections) until the remaining payload `current` fits within
/// `budget` or nothing is left to drop.
fn drop_records_until_within(cache: &mut CacheFile, mut current: usize, budget: usize) {
    while current > budget {
        if !drop_one_record(cache, SectionType::Derived, &mut current)
            && !drop_one_record(cache, SectionType::Input, &mut current)
            && !drop_one_record(cache, SectionType::Interned, &mut current)
        {
            break;
        }
    }
}
/// Removes the last record of the section of type `ty` whose last record is
/// the longest. Ties go to the earliest such section.
///
/// Callers sort each section's records by ascending length first, so a
/// section's last record is its largest. Returns `false` if no section of
/// type `ty` has any records left. Otherwise it subtracts the removed length
/// from `current_record_bytes` (saturating) and returns `true`.
fn drop_one_record(
    cache: &mut CacheFile,
    ty: SectionType,
    current_record_bytes: &mut usize,
) -> bool {
    let mut victim: Option<(usize, usize)> = None;
    for (section_idx, section) in cache.sections.iter().enumerate() {
        if section.section_type != ty {
            continue;
        }
        if let Some(last) = section.records.last() {
            let size = last.len();
            let is_larger = match victim {
                None => true,
                Some((_, best_size)) => size > best_size,
            };
            if is_larger {
                victim = Some((section_idx, size));
            }
        }
    }
    match victim {
        Some((section_idx, size)) => {
            cache.sections[section_idx].records.pop();
            *current_record_bytes = current_record_bytes.saturating_sub(size);
            true
        }
        None => false,
    }
}
/// Appends every ingredient's incremental changes to `wal` and returns the
/// number of entries written.
///
/// Each ingredient is asked for changes since `wal.base_revision()`. Each
/// change becomes one entry tagged with the ingredient's kind id and the
/// change's revision:
/// - a change with a value becomes `WalOperation::Set`,
/// - a change without a value becomes `WalOperation::Delete`.
///
/// `_runtime` is currently unused.
///
/// NOTE(review): changes are requested relative to the WAL's *base* revision
/// on every call. Calling this repeatedly without compaction may append the
/// same changes again. Confirm that is intended, or that
/// `save_incremental_records` accounts for it.
///
/// # Errors
///
/// Propagates failures from `save_incremental_records` and `WalWriter::append`.
pub async fn append_to_wal(
    wal: &mut WalWriter,
    _runtime: &Runtime,
    ingredients: &[&dyn PersistableIngredient],
) -> PicanteResult<usize> {
    let since = wal.base_revision();
    let mut entry_count = 0;
    for ingredient in ingredients {
        let kind_id = ingredient.kind().0;
        for (revision, key, value) in ingredient.save_incremental_records(since).await? {
            let operation = if let Some(value) = value {
                WalOperation::Set { key, value }
            } else {
                WalOperation::Delete { key }
            };
            wal.append(WalEntry {
                revision,
                kind_id,
                operation,
            })?;
            entry_count += 1;
        }
    }
    trace!("Appended {entry_count} entries to WAL");
    Ok(entry_count)
}
pub async fn replay_wal(
path: impl AsRef<Path>,
runtime: &Runtime,
ingredients: &[&dyn PersistableIngredient],
) -> PicanteResult<usize> {
let path = path.as_ref();
if let Err(e) = std::fs::metadata(path)
&& e.kind() == std::io::ErrorKind::NotFound
{
trace!("No WAL file found at {}, skipping replay", path.display());
return Ok(0);
}
let mut reader = WalReader::open(path)?;
let base_revision = reader.header().base_revision;
if runtime.current_revision().0 != base_revision {
return Err(Arc::new(PicanteError::Cache {
message: format!(
"WAL base revision ({}) does not match snapshot revision ({})",
base_revision,
runtime.current_revision().0,
),
}));
}
trace!(
"Replaying WAL from {} (base revision: {})",
path.display(),
base_revision
);
let mut ingredient_map: HashMap<u32, &dyn PersistableIngredient> = HashMap::new();
for ingredient in ingredients {
ingredient_map.insert(ingredient.kind().0, *ingredient);
}
let mut entry_count = 0;
let mut max_revision = base_revision;
for entry_result in reader.entries() {
let entry = entry_result?;
let Some(ingredient) = ingredient_map.get(&entry.kind_id) else {
warn!(
"WAL entry references unknown ingredient kind_id={}, skipping",
entry.kind_id
);
continue;
};
match entry.operation {
WalOperation::Set { key, value } => {
ingredient.apply_wal_entry(entry.revision, key, Some(value))?;
}
WalOperation::Delete { key } => {
ingredient.apply_wal_entry(entry.revision, key, None)?;
}
}
max_revision = max_revision.max(entry.revision);
entry_count += 1;
}
if max_revision > base_revision {
runtime.set_current_revision(Revision(max_revision));
trace!("Set runtime revision to {max_revision} from WAL");
}
for ingredient in ingredients {
ingredient.restore_runtime_state(runtime).await?;
}
trace!("Replayed {entry_count} WAL entries");
Ok(entry_count)
}
/// Folds the current in-memory state into a fresh snapshot at `cache_path`
/// and discards the WAL at `wal_path`. Returns the runtime revision read after
/// the snapshot was saved.
///
/// Steps:
/// 1. Save the snapshot to a sibling `<file name>.compact.tmp` (or
///    `cache-<pid>.compact.tmp` when the file name is not valid UTF-8).
/// 2. Rename that file over `cache_path`.
/// 3. Delete the old WAL, if present.
/// 4. If `create_new_wal` is set, create a new WAL via `WalWriter::create`
///    with that revision as its base.
///
/// # Errors
///
/// Fails if:
/// - saving the snapshot fails,
/// - the temp snapshot cannot be renamed (the temp file is then removed,
///   best effort),
/// - the old WAL exists but cannot be deleted,
/// - the new WAL cannot be created.
pub async fn compact_wal(
    cache_path: impl AsRef<Path>,
    wal_path: impl AsRef<Path>,
    runtime: &Runtime,
    ingredients: &[&dyn PersistableIngredient],
    options: &CacheSaveOptions,
    create_new_wal: bool,
) -> PicanteResult<u64> {
    let cache_path = cache_path.as_ref();
    let wal_path = wal_path.as_ref();
    trace!("Compacting WAL: creating new snapshot");
    let temp_cache_path = {
        let temp_name = match cache_path.file_name().and_then(|s| s.to_str()) {
            Some(name) => format!("{name}.compact.tmp"),
            None => format!("cache-{}.compact.tmp", std::process::id()),
        };
        cache_path.with_file_name(temp_name)
    };
    save_cache_with_options(&temp_cache_path, runtime, ingredients, options).await?;
    // NOTE(review): the revision is read here, separately from the read inside
    // `save_cache_with_options`. If another task advances the revision in
    // between, the new WAL's base will not match the snapshot's revision and
    // `replay_wal` will reject it. Confirm compaction is serialized with writes.
    let new_revision = runtime.current_revision().0;
    let rename_result = tokio::fs::rename(&temp_cache_path, cache_path).await;
    if let Err(e) = rename_result {
        if let Err(cleanup_err) = tokio::fs::remove_file(&temp_cache_path).await {
            warn!(
                "Failed to remove temporary snapshot at {} after rename error: {}",
                temp_cache_path.display(),
                cleanup_err
            );
        }
        return Err(Arc::new(PicanteError::Cache {
            message: format!(
                "Failed to rename temporary snapshot from {} to {}: {}",
                temp_cache_path.display(),
                cache_path.display(),
                e
            ),
        }));
    }
    trace!("Atomically installed new snapshot");
    // Remove the old WAL directly, treating "already gone" as success. This
    // avoids a blocking `Path::exists` probe and the race between probing and
    // removing.
    match tokio::fs::remove_file(wal_path).await {
        Ok(()) => {
            trace!("Deleted old WAL file");
        }
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        Err(e) => {
            return Err(Arc::new(PicanteError::Cache {
                message: format!("Failed to delete old WAL at {}: {}", wal_path.display(), e),
            }));
        }
    }
    if create_new_wal {
        let _new_wal = WalWriter::create(wal_path, new_revision)?;
        trace!("Created new WAL at revision {new_revision}");
    }
    trace!("WAL compaction complete at revision {new_revision}");
    Ok(new_revision)
}