use std::collections::{BTreeMap, HashMap};
use crate::catalog::{encode_catalog_payload, Catalog, CatalogRecordWire};
use crate::db::build_non_pk_values_in_schema_order;
use crate::error::{DbError, FormatError, SchemaError};
use crate::file_format::{check_decode_entry_count, check_segment_payload_len};
use crate::index::{decode_index_payload, encode_index_payload, IndexEntry, IndexState};
use crate::record::{
encode_record_payload_v2, encode_record_payload_v3, non_pk_defs_in_order, RowValue, ScalarValue,
};
use crate::schema::CollectionId;
use crate::db::LatestMap;
/// Version tag of the V0 checkpoint layout.
///
/// The encoder writes it as the leading little-endian `u16` of the payload and
/// the decoder rejects any payload whose leading `u16` differs from it.
pub const CHECKPOINT_VERSION_V0: u16 = 0;
/// In-memory form of a version-0 checkpoint payload.
///
/// Built by `checkpoint_from_state` or `decode_checkpoint_payload`, and
/// serialized by `encode_checkpoint_payload_v0`.
#[derive(Debug, Clone)]
pub struct CheckpointV0 {
/// Stored as a `u64` directly after the version tag; `checkpoint_from_state`
/// always initializes it to 0.
/// NOTE(review): presumably the log offset from which replay resumes after the
/// checkpoint is loaded — confirm against the callers.
pub replay_from_offset: u64,
/// Catalog records; on load they are applied in order to an empty `Catalog`.
pub catalog_records: Vec<CatalogRecordWire>,
/// Already-encoded record payloads, one per row; on load each is decoded and
/// applied to the latest-row map.
pub record_payloads: Vec<Vec<u8>>,
/// Index entries; on load they are applied in order to an empty `IndexState`.
pub index_entries: Vec<IndexEntry>,
}
/// Serializes a [`CheckpointV0`] into its little-endian wire form.
///
/// Layout, in order:
/// 1. `u16` format version ([`CHECKPOINT_VERSION_V0`]).
/// 2. `u64` `replay_from_offset`.
/// 3. `u32` catalog record count, then every record as `u32` length + bytes.
/// 4. `u32` record payload count, then every payload as `u32` length + bytes.
/// 5. `u32` length + bytes of the blob produced by `encode_index_payload`.
///
/// NOTE(review): counts and lengths are narrowed with `as u32`, which silently
/// truncates values above `u32::MAX`; the function has no error channel to
/// report that, so callers must keep inputs within that bound.
pub fn encode_checkpoint_payload_v0(cp: &CheckpointV0) -> Vec<u8> {
    #[cfg(feature = "tracing")]
    tracing::debug!(
        catalog_records = cp.catalog_records.len(),
        record_payloads = cp.record_payloads.len(),
        index_entries = cp.index_entries.len(),
        "encode_checkpoint_payload_v0"
    );

    // Appends `len` as a little-endian `u32` (truncating cast, see the note above).
    fn push_len(buf: &mut Vec<u8>, len: usize) {
        buf.extend_from_slice(&(len as u32).to_le_bytes());
    }

    // Appends `chunk` preceded by its `u32` length.
    fn push_framed(buf: &mut Vec<u8>, chunk: &[u8]) {
        push_len(buf, chunk.len());
        buf.extend_from_slice(chunk);
    }

    let mut buf = Vec::new();

    // Fixed-size header.
    buf.extend_from_slice(&CHECKPOINT_VERSION_V0.to_le_bytes());
    buf.extend_from_slice(&cp.replay_from_offset.to_le_bytes());

    // Catalog section.
    push_len(&mut buf, cp.catalog_records.len());
    for record in &cp.catalog_records {
        let encoded = encode_catalog_payload(record);
        push_framed(&mut buf, encoded.as_slice());
    }

    // Row section: payloads are stored exactly as given.
    push_len(&mut buf, cp.record_payloads.len());
    for payload in &cp.record_payloads {
        push_framed(&mut buf, payload.as_slice());
    }

    // Index section: a single framed blob.
    let index_blob = encode_index_payload(&cp.index_entries);
    push_framed(&mut buf, &index_blob);

    buf
}
pub fn decode_checkpoint_payload(bytes: &[u8]) -> Result<CheckpointV0, DbError> {
#[cfg(feature = "tracing")]
let _span = tracing::info_span!("decode_checkpoint_payload", bytes = bytes.len()).entered();
let mut cur = Cursor::new(bytes);
let ver = cur.take_u16()?;
if ver != CHECKPOINT_VERSION_V0 {
return Err(DbError::Format(FormatError::UnsupportedVersion {
major: 0,
minor: ver,
}));
}
let replay_from_offset = cur.take_u64()?;
let n_catalog = cur.take_u32()? as usize;
check_decode_entry_count(n_catalog)?;
let mut catalog_records = Vec::with_capacity(n_catalog.min(1024));
for _ in 0..n_catalog {
let n = cur.take_u32()? as usize;
check_segment_payload_len(n as u64)?;
let b = cur.take_bytes(n)?;
let rec = crate::catalog::decode_catalog_payload(&b)?;
catalog_records.push(rec);
}
let n_records = cur.take_u32()? as usize;
check_decode_entry_count(n_records)?;
let mut record_payloads = Vec::with_capacity(n_records.min(1024));
for _ in 0..n_records {
let n = cur.take_u32()? as usize;
check_segment_payload_len(n as u64)?;
record_payloads.push(cur.take_bytes(n)?);
}
let idx_blob_len = cur.take_u32()? as usize;
check_segment_payload_len(idx_blob_len as u64)?;
let idx_blob = cur.take_bytes(idx_blob_len)?;
let index_entries = decode_index_payload(&idx_blob)?;
if cur.remaining() != 0 {
return Err(DbError::Format(FormatError::InvalidCatalogPayload {
message: "trailing bytes in checkpoint payload".to_string(),
}));
}
let cp = CheckpointV0 {
replay_from_offset,
catalog_records,
record_payloads,
index_entries,
};
#[cfg(feature = "tracing")]
tracing::info!(
replay_from_offset = cp.replay_from_offset,
catalog_records = cp.catalog_records.len(),
record_payloads = cp.record_payloads.len(),
index_entries = cp.index_entries.len(),
"decode_checkpoint_payload_ok"
);
Ok(cp)
}
/// Builds a [`CheckpointV0`] snapshot from the live in-memory state.
///
/// * Catalog: collections are visited in ascending collection id. Each yields
///   a `CreateCollection` record with `schema_version: 1`, followed by one
///   `NewSchemaVersion` record for every version from 2 up to the collection's
///   current version. Every emitted record carries the collection's *current*
///   fields and indexes.
/// * Rows: every entry of `latest` is re-encoded at the collection's current
///   schema version. The v3 record encoding is used when the schema has any
///   field whose path is not exactly one segment, otherwise v2.
/// * Indexes: taken from `indexes.entries_for_checkpoint()`.
///
/// `replay_from_offset` is left at 0 in the returned value. Rows are emitted
/// in the iteration order of `latest`.
///
/// # Errors
/// - `SchemaError::NoPrimaryKey` if a collection has no primary field.
/// - `SchemaError::UnknownCollection` if a row refers to a collection id that
///   is not in `catalog`.
/// - `SchemaError::PrimaryFieldNotFound` if the primary field is not a
///   single-segment field of the schema.
/// - `SchemaError::RowMissingPrimary` if a row has no value for the primary
///   field.
/// - Any error from scalar conversion, value ordering or record encoding.
pub fn checkpoint_from_state(
    catalog: &Catalog,
    latest: &LatestMap,
    indexes: &IndexState,
) -> Result<CheckpointV0, DbError> {
    #[cfg(feature = "tracing")]
    let _span = tracing::info_span!("checkpoint_from_state").entered();

    let mut catalog_records: Vec<CatalogRecordWire> = Vec::new();
    let mut cols = catalog.collections();
    cols.sort_by_key(|c| c.id.0);
    for c in &cols {
        let pk = c.primary_field.as_deref().ok_or_else(|| {
            DbError::Schema(SchemaError::NoPrimaryKey {
                collection_id: c.id.0,
            })
        })?;
        catalog_records.push(CatalogRecordWire::CreateCollection {
            collection_id: c.id.0,
            name: c.name.clone(),
            schema_version: 1,
            fields: c.fields.clone(),
            indexes: c.indexes.clone(),
            primary_field: Some(pk.to_string()),
        });
        // One record per later version so that applying them brings the
        // collection up to its current version number.
        for v in 2..=c.current_version.0 {
            catalog_records.push(CatalogRecordWire::NewSchemaVersion {
                collection_id: c.id.0,
                schema_version: v,
                fields: c.fields.clone(),
                indexes: c.indexes.clone(),
            });
        }
    }

    let mut record_payloads: Vec<Vec<u8>> = Vec::with_capacity(latest.len().min(1_000_000));
    for ((cid, _pk_key), row) in latest.iter() {
        // Error values are built lazily (`ok_or_else`) so the success path of
        // this per-row loop does not allocate strings that are thrown away.
        let col = catalog
            .get(CollectionId(*cid))
            .ok_or_else(|| DbError::Schema(SchemaError::UnknownCollection { id: *cid }))?;
        let pk_name = col.primary_field.as_deref().ok_or_else(|| {
            DbError::Schema(SchemaError::NoPrimaryKey {
                collection_id: col.id.0,
            })
        })?;
        let pk_def = col
            .fields
            .iter()
            .find(|f| f.path.0.len() == 1 && f.path.0[0] == pk_name)
            .ok_or_else(|| {
                DbError::Schema(SchemaError::PrimaryFieldNotFound {
                    name: pk_name.to_string(),
                })
            })?;
        let pk_cell = row.get(pk_name).ok_or_else(|| {
            DbError::Schema(SchemaError::RowMissingPrimary {
                name: pk_name.to_string(),
            })
        })?;
        let pk_scalar: ScalarValue = pk_cell.clone().into_scalar()?;

        // Schemas with any non-single-segment path take the v3 route, which
        // uses every non-primary field definition as-is.
        let has_multi_segment_schema = col.fields.iter().any(|f| f.path.0.len() != 1);
        let non_pk_defs = if has_multi_segment_schema {
            col.fields
                .iter()
                .filter(|f| !(f.path.0.len() == 1 && f.path.0[0] == pk_name))
                .collect::<Vec<_>>()
        } else {
            non_pk_defs_in_order(&col.fields, pk_name)
        };
        let ordered = build_non_pk_values_in_schema_order(row, &non_pk_defs)?;

        let payload = if has_multi_segment_schema {
            encode_record_payload_v3(
                *cid,
                col.current_version.0,
                &pk_scalar,
                &pk_def.ty,
                &ordered,
            )?
        } else {
            encode_record_payload_v2(
                *cid,
                col.current_version.0,
                &pk_scalar,
                &pk_def.ty,
                &ordered,
            )?
        };
        record_payloads.push(payload);
    }

    let index_entries = indexes.entries_for_checkpoint();
    let cp = CheckpointV0 {
        replay_from_offset: 0,
        catalog_records,
        record_payloads,
        index_entries,
    };
    #[cfg(feature = "tracing")]
    tracing::info!(
        catalog_records = cp.catalog_records.len(),
        record_payloads = cp.record_payloads.len(),
        index_entries = cp.index_entries.len(),
        "checkpoint_from_state_ok"
    );
    Ok(cp)
}
pub fn state_from_checkpoint_payload(
payload: &[u8],
) -> Result<(u64, Catalog, LatestMap, IndexState), DbError> {
let cp = decode_checkpoint_payload(payload)?;
let mut catalog = Catalog::default();
for r in &cp.catalog_records {
catalog.apply_record(r.clone())?;
}
let mut latest: LatestMap = HashMap::new();
for rec in &cp.record_payloads {
apply_checkpoint_record_payload(rec, &catalog, &mut latest)?;
}
let mut indexes = IndexState::default();
for e in cp.index_entries {
indexes.apply(e)?;
}
Ok((cp.replay_from_offset, catalog, latest, indexes))
}
/// Decodes one stored record payload and applies it to `latest`.
///
/// The collection id is read from bytes 2..6 of the payload (little-endian
/// `u32`) so the right schema can be looked up before the full decode.
/// NOTE(review): bytes 0..2 are skipped here; presumably a version/op prefix
/// owned by the record codec — confirm against `crate::record`.
///
/// A record whose op is `OP_DELETE` removes the row keyed by
/// `(collection_id, canonical primary-key bytes)`; any other record replaces
/// that row with a map holding the primary key plus all decoded fields.
///
/// # Errors
/// - `FormatError::TruncatedRecordPayload` if the payload is shorter than 6
///   bytes.
/// - `SchemaError::UnknownCollection`, `NoPrimaryKey` or
///   `PrimaryFieldNotFound` if the catalog cannot supply the primary-key
///   definition for the record's collection.
/// - `SchemaError::InvalidSchemaVersion` if the record's schema version is
///   newer than the catalog's current version.
/// - `FormatError::InvalidCatalogPayload` if a non-delete record from an older
///   schema version carries more fields than the current schema has
///   single-segment non-primary fields.
/// - Any error from the record decoder.
fn apply_checkpoint_record_payload(
    payload: &[u8],
    catalog: &Catalog,
    latest: &mut LatestMap,
) -> Result<(), DbError> {
    if payload.len() < 6 {
        return Err(DbError::Format(FormatError::TruncatedRecordPayload));
    }
    let collection_id = u32::from_le_bytes([payload[2], payload[3], payload[4], payload[5]]);

    // Errors are constructed lazily (`ok_or_else`): this runs once per stored
    // row, and the eager form allocated a `String` on every successful lookup.
    let col = catalog.get(CollectionId(collection_id)).ok_or_else(|| {
        DbError::Schema(SchemaError::UnknownCollection { id: collection_id })
    })?;
    let pk_name = col.primary_field.as_deref().ok_or_else(|| {
        DbError::Schema(SchemaError::NoPrimaryKey {
            collection_id: col.id.0,
        })
    })?;
    let pk_ty = col
        .fields
        .iter()
        .find(|f| f.path.0.len() == 1 && f.path.0[0] == pk_name)
        .map(|f| &f.ty)
        .ok_or_else(|| {
            DbError::Schema(SchemaError::PrimaryFieldNotFound {
                name: pk_name.to_string(),
            })
        })?;

    let decode_fields = crate::record::fields_for_record_decode(col, payload, pk_name, pk_ty)?;
    let decoded = crate::record::decode_record_payload(payload, pk_name, pk_ty, &decode_fields)?;

    // A record can never be newer than the catalog that describes it.
    if decoded.schema_version > col.current_version.0 {
        return Err(DbError::Schema(SchemaError::InvalidSchemaVersion {
            expected: col.current_version.0,
            got: decoded.schema_version,
        }));
    }
    // Older records are tolerated only when they do not carry more fields
    // than the current schema has single-segment non-primary fields.
    if decoded.schema_version < col.current_version.0 {
        let non_pk_count = col
            .fields
            .iter()
            .filter(|f| f.path.0.len() == 1 && f.path.0[0] != pk_name)
            .count();
        if decoded.op != crate::record::OP_DELETE && decoded.fields.len() > non_pk_count {
            return Err(DbError::Format(FormatError::InvalidCatalogPayload {
                message: format!(
                    "checkpoint record schema_version {} layout incompatible with catalog version {}",
                    decoded.schema_version, col.current_version.0
                ),
            }));
        }
    }

    let pk_key = decoded.pk.canonical_key_bytes();
    if decoded.op == crate::record::OP_DELETE {
        latest.remove(&(collection_id, pk_key));
        return Ok(());
    }

    // Assemble the full row: primary key first, then every decoded field
    // (a decoded field with the same name would overwrite the key entry).
    let mut full: BTreeMap<String, RowValue> = BTreeMap::new();
    full.insert(pk_name.to_string(), RowValue::from_scalar(decoded.pk));
    for (k, v) in decoded.fields {
        full.insert(k, v);
    }
    latest.insert((collection_id, pk_key), full);
    Ok(())
}
/// Forward-only reader over a borrowed byte slice, used to parse checkpoint
/// payloads.
struct Cursor<'a> {
/// The complete input being read.
bytes: &'a [u8],
/// Offset into `bytes` of the next unread byte.
pos: usize,
}
impl<'a> Cursor<'a> {
fn new(bytes: &'a [u8]) -> Self {
Self { bytes, pos: 0 }
}
fn remaining(&self) -> usize {
self.bytes.len().saturating_sub(self.pos)
}
fn take_u16(&mut self) -> Result<u16, DbError> {
if self.remaining() < 2 {
return Err(DbError::Format(FormatError::InvalidCatalogPayload {
message: "unexpected eof".to_string(),
}));
}
let v = u16::from_le_bytes([self.bytes[self.pos], self.bytes[self.pos + 1]]);
self.pos += 2;
Ok(v)
}
fn take_u32(&mut self) -> Result<u32, DbError> {
if self.remaining() < 4 {
return Err(DbError::Format(FormatError::InvalidCatalogPayload {
message: "unexpected eof".to_string(),
}));
}
let v = u32::from_le_bytes([
self.bytes[self.pos],
self.bytes[self.pos + 1],
self.bytes[self.pos + 2],
self.bytes[self.pos + 3],
]);
self.pos += 4;
Ok(v)
}
fn take_u64(&mut self) -> Result<u64, DbError> {
if self.remaining() < 8 {
return Err(DbError::Format(FormatError::InvalidCatalogPayload {
message: "unexpected eof".to_string(),
}));
}
let v = u64::from_le_bytes(self.bytes[self.pos..self.pos + 8].try_into().unwrap());
self.pos += 8;
Ok(v)
}
fn take_bytes(&mut self, n: usize) -> Result<Vec<u8>, DbError> {
if self.remaining() < n {
return Err(DbError::Format(FormatError::InvalidCatalogPayload {
message: "unexpected eof".to_string(),
}));
}
let slice = &self.bytes[self.pos..self.pos + n];
self.pos += n;
Ok(slice.to_vec())
}
}
// Unit tests are kept in `tests/unit/src_checkpoint_tests.rs` under the crate
// root and are textually included into this child module, so they compile as
// part of this file's module tree (test builds only).
#[cfg(test)]
mod tests {
include!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/tests/unit/src_checkpoint_tests.rs"
));
}