//! Embedded, in-process front end over `spg_engine`.
//!
//! The central type is [`Database`], which wraps an [`Engine`] and can run
//! either purely in memory or backed by a snapshot file plus a write-ahead
//! log (WAL) that is replayed on open and truncated on checkpoint. The crate
//! also exposes WAL record parsing, a background hot-to-cold freezer thread,
//! and the [`FromSpgRow`] / [`FromSpgValue`] traits for decoding result rows.
#![deny(missing_docs)]
pub use spg_engine::{CatalogSnapshot, Engine, EngineError, ParsedStatement, QueryResult};
pub use spg_storage::{ColumnSchema, DataType, Value};
/// A prepared statement: the parsed form produced by the engine together with
/// the SQL text it was prepared from.
#[derive(Debug, Clone)]
pub struct Statement {
// Parsed representation handed to the engine's `execute_prepared*` calls.
pub(crate) stmt: ParsedStatement,
// Original SQL text, returned by `Statement::sql`.
pub(crate) sql: String,
}
impl Statement {
    /// Returns the SQL text this statement was prepared from.
    #[must_use]
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }
}
/// Substitutes `params` into the placeholders of `stmt` in place, so the
/// caller can render the statement as self-contained SQL for the WAL.
///
/// NOTE(review): whatever `substitute_placeholders` returns is discarded here.
/// If it can report a failure, the WAL would record a statement that still
/// contains placeholders — confirm this cannot happen after a successful
/// `execute_prepared`.
fn wal_render_with_params(stmt: &mut ParsedStatement, params: &[Value]) {
let _ = spg_engine::substitute_placeholders(stmt, params);
}
use std::collections::BTreeMap;
use std::fs::{File, OpenOptions};
use std::io::{Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Current wall-clock time as microseconds since the Unix epoch.
///
/// Returns `0` when the system clock reads earlier than the epoch and
/// saturates at `i64::MAX` if the microsecond count does not fit in an `i64`.
fn wall_clock_micros() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_micros()).unwrap_or(i64::MAX),
        Err(_) => 0,
    }
}
use spg_manifest::{CatalogManifest, ColdSegmentEntry, manifest_path as spg_manifest_path};
/// High bit of a record's leading little-endian `u32` length word. When set,
/// the record header is at least 8 bytes (length word plus 4 more bytes; the
/// v3/v4 encoders in this file store a CRC32 there).
const WAL_V2_SENTINEL: u32 = 0x8000_0000;
/// Second-highest bit of the length word. Together with `WAL_V2_SENTINEL` it
/// marks a typed record: a one-byte type tag follows the CRC (9-byte header).
const WAL_V3_FLAG: u32 = 0x4000_0000;
/// Type tag: auto-committed SQL statement, payload is the SQL text.
const WAL_V3_TYPE_AUTO_COMMIT_SQL: u8 = 0x01;
/// Type tag: durability checkpoint; readers in this file skip its payload.
const WAL_V3_TYPE_DURABILITY_CHECKPOINT: u8 = 0x02;
/// Type tag: auto-committed SQL carrying a commit LSN and commit timestamp
/// between the type byte and the SQL text.
const WAL_V4_TYPE_AUTO_COMMIT_SQL: u8 = 0x10;
/// Timestamp value that the parser reports as "no timestamp" (`None`).
const WAL_V4_NO_CLOCK: i64 = i64::MIN;
/// Size of the v4 auto-commit extra header (8-byte LSN + 8-byte timestamp).
/// These bytes are NOT included in the length stored in the length word.
const WAL_V4_EXTRA_HEADER: usize = 16;
/// Type tag: checkpoint marker whose payload is LSN, timestamp, a `u16` path
/// length and the snapshot path bytes.
const WAL_V4_TYPE_CHECKPOINT_MARKER: u8 = 0x11;
/// WAL size (in bytes) at which an automatic checkpoint is triggered.
///
/// Reads `SPG_EMBEDDED_CHECKPOINT_BYTES`; falls back to 4 MiB when the
/// variable is unset, not a valid `u64`, or zero.
fn default_checkpoint_threshold_bytes() -> u64 {
    const FALLBACK_BYTES: u64 = 4 * 1024 * 1024;
    let Ok(raw) = std::env::var("SPG_EMBEDDED_CHECKPOINT_BYTES") else {
        return FALLBACK_BYTES;
    };
    match raw.parse::<u64>() {
        Ok(bytes) if bytes > 0 => bytes,
        _ => FALLBACK_BYTES,
    }
}
/// Encodes `sql` as a v3 auto-commit WAL record.
///
/// Layout: flagged little-endian length word (SQL length with the v2 sentinel
/// and v3 flag bits set), CRC32 over the type byte plus SQL, the type byte,
/// then the SQL bytes.
fn encode_v3_auto_commit(sql: &str) -> Vec<u8> {
    let sql_bytes = sql.as_bytes();

    // The checksummed bytes and the on-disk body are identical: tag, then SQL.
    let mut body = Vec::with_capacity(1 + sql_bytes.len());
    body.push(WAL_V3_TYPE_AUTO_COMMIT_SQL);
    body.extend_from_slice(sql_bytes);
    let checksum = spg_crypto::crc32::crc32(&body);

    let len_word = (sql_bytes.len() as u32) | WAL_V2_SENTINEL | WAL_V3_FLAG;
    let mut record = Vec::with_capacity(4 + 4 + body.len());
    record.extend_from_slice(&len_word.to_le_bytes());
    record.extend_from_slice(&checksum.to_le_bytes());
    record.extend_from_slice(&body);
    record
}
fn encode_v4_checkpoint_marker(
checkpoint_lsn: u64,
checkpoint_unix_us: i64,
snapshot_path: &Path,
) -> Vec<u8> {
let snapshot_bytes = snapshot_path.to_string_lossy().into_owned();
let snap_payload = snapshot_bytes.as_bytes();
let snap_len_u16: u16 = snap_payload.len().min(u16::MAX as usize) as u16;
let mut payload = Vec::with_capacity(8 + 8 + 2 + snap_payload.len());
payload.extend_from_slice(&checkpoint_lsn.to_le_bytes());
payload.extend_from_slice(&checkpoint_unix_us.to_le_bytes());
payload.extend_from_slice(&snap_len_u16.to_le_bytes());
payload.extend_from_slice(&snap_payload[..snap_len_u16 as usize]);
let mut crc_buf = Vec::with_capacity(1 + payload.len());
crc_buf.push(WAL_V4_TYPE_CHECKPOINT_MARKER);
crc_buf.extend_from_slice(&payload);
let crc = spg_crypto::crc32::crc32(&crc_buf);
let header = ((payload.len() as u32) | WAL_V2_SENTINEL | WAL_V3_FLAG).to_le_bytes();
let mut out = Vec::with_capacity(4 + 4 + 1 + payload.len());
out.extend_from_slice(&header);
out.extend_from_slice(&crc.to_le_bytes());
out.push(WAL_V4_TYPE_CHECKPOINT_MARKER);
out.extend_from_slice(&payload);
out
}
/// Encodes `sql` as a v4 auto-commit WAL record stamped with `commit_lsn` and
/// `commit_unix_us`.
///
/// Layout: flagged length word holding the SQL length only (the 16-byte
/// LSN/timestamp header is not counted), CRC32 over everything after the CRC,
/// the type byte, LSN (u64 LE), timestamp (i64 LE), then the SQL bytes.
fn encode_v4_auto_commit(sql: &str, commit_lsn: u64, commit_unix_us: i64) -> Vec<u8> {
    let sql_bytes = sql.as_bytes();

    // Everything after the CRC field, which is also the CRC's input.
    let mut body = Vec::with_capacity(1 + WAL_V4_EXTRA_HEADER + sql_bytes.len());
    body.push(WAL_V4_TYPE_AUTO_COMMIT_SQL);
    body.extend_from_slice(&commit_lsn.to_le_bytes());
    body.extend_from_slice(&commit_unix_us.to_le_bytes());
    body.extend_from_slice(sql_bytes);
    let checksum = spg_crypto::crc32::crc32(&body);

    let len_word = (sql_bytes.len() as u32) | WAL_V2_SENTINEL | WAL_V3_FLAG;
    let mut record = Vec::with_capacity(4 + 4 + body.len());
    record.extend_from_slice(&len_word.to_le_bytes());
    record.extend_from_slice(&checksum.to_le_bytes());
    record.extend_from_slice(&body);
    record
}
/// Replays every SQL-bearing record in `wal_bytes` against `engine`, in file
/// order, and returns how many statements were applied.
///
/// Handles untyped records (4- or 8-byte header) and typed v3/v4 records.
/// Checkpoint-type records are skipped. A record that is cut short by the end
/// of the buffer stops the replay without an error.
///
/// # Errors
/// Returns a message when a typed record has an unknown type byte, when a SQL
/// payload is not valid UTF-8, or when the engine rejects a statement.
///
/// NOTE(review): the CRC field is skipped, not verified, by this function.
fn replay_wal_into_engine(wal_bytes: &[u8], engine: &mut Engine) -> Result<usize, String> {
let mut applied = 0usize;
let mut cur = 0usize;
while cur < wal_bytes.len() {
// Not even a full length word left: treat as a torn tail.
if wal_bytes.len() - cur < 4 {
break;
}
let raw_len = u32::from_le_bytes(wal_bytes[cur..cur + 4].try_into().unwrap());
let is_v2 = raw_len & WAL_V2_SENTINEL != 0;
let is_v3 = is_v2 && (raw_len & WAL_V3_FLAG != 0);
// Strip the flag bits to recover the payload length.
let len_mask = if is_v3 {
!(WAL_V2_SENTINEL | WAL_V3_FLAG)
} else {
!WAL_V2_SENTINEL
};
let rec_len = (raw_len & len_mask) as usize;
// 4 = length only, 8 = length + 4 more bytes, 9 = length + 4 bytes + type byte.
let header_len = if is_v3 {
9
} else if is_v2 {
8
} else {
4
};
if wal_bytes.len() - cur < header_len + rec_len {
break;
}
if is_v3 {
let type_byte = wal_bytes[cur + 8];
match type_byte {
// Falls through to the shared "plain SQL payload" path below.
WAL_V3_TYPE_AUTO_COMMIT_SQL => {}
WAL_V3_TYPE_DURABILITY_CHECKPOINT => {
cur += header_len + rec_len;
continue;
}
WAL_V4_TYPE_CHECKPOINT_MARKER => {
cur += header_len + rec_len;
continue;
}
WAL_V4_TYPE_AUTO_COMMIT_SQL => {
// v4 records carry 16 extra header bytes that `rec_len` does not count.
let v4_total = header_len + WAL_V4_EXTRA_HEADER + rec_len;
if wal_bytes.len() - cur < v4_total {
break;
}
let sql_start = cur + header_len + WAL_V4_EXTRA_HEADER;
let sql_bytes = &wal_bytes[sql_start..sql_start + rec_len];
let sql = std::str::from_utf8(sql_bytes).map_err(|e| {
format!("WAL replay: non-UTF-8 SQL at offset {cur}: {e}")
})?;
engine.execute(sql).map_err(|e| {
format!(
"WAL replay: apply {sql:?} at offset {cur} rejected: {e:?}"
)
})?;
applied += 1;
cur += v4_total;
continue;
}
other => {
return Err(format!(
"WAL replay: unknown v3 type byte {other:#04x} at offset {cur}"
));
}
}
}
// Untyped records and v3 auto-commit records: payload is the SQL text.
let sql_bytes = &wal_bytes[cur + header_len..cur + header_len + rec_len];
let sql = std::str::from_utf8(sql_bytes)
.map_err(|e| format!("WAL replay: non-UTF-8 SQL at offset {cur}: {e}"))?;
engine
.execute(sql)
.map_err(|e| format!("WAL replay: apply {sql:?} at offset {cur} rejected: {e:?}"))?;
applied += 1;
cur += header_len + rec_len;
}
Ok(applied)
}
/// One record decoded from a WAL byte buffer by [`parse_wal_records`].
///
/// The record borrows its payload from the buffer it was parsed from.
#[derive(Debug, Clone)]
pub struct WalRecord<'a> {
/// Byte offset of the record's first header byte within the WAL buffer.
pub offset: usize,
/// Record type tag. Untyped (pre-v3) records are reported with the v3
/// auto-commit SQL tag.
pub type_byte: u8,
/// Commit LSN; present only for v4 auto-commit and checkpoint-marker records.
pub commit_lsn: Option<u64>,
/// Timestamp in microseconds since the Unix epoch; `None` when the record
/// type carries no timestamp or the stored value is the "no clock" sentinel.
pub commit_unix_us: Option<i64>,
/// Payload bytes: the SQL text for auto-commit records, the snapshot path
/// bytes for checkpoint markers, and empty for durability checkpoints.
pub sql: &'a [u8],
}
/// Decodes `wal_bytes` into a list of [`WalRecord`]s without executing anything.
///
/// Records are returned in file order. A record cut short by the end of the
/// buffer ends parsing without an error; the records before it are returned.
///
/// # Errors
/// Returns a message when a typed record has an unknown type byte or when a
/// checkpoint marker's payload is shorter than its fixed fields / declared
/// path length.
///
/// NOTE(review): the CRC field is skipped, not verified, by this function.
pub fn parse_wal_records(wal_bytes: &[u8]) -> Result<Vec<WalRecord<'_>>, String> {
let mut out = Vec::new();
let mut cur = 0usize;
while cur < wal_bytes.len() {
// Not even a full length word left: treat as a torn tail.
if wal_bytes.len() - cur < 4 {
break;
}
let raw_len = u32::from_le_bytes(wal_bytes[cur..cur + 4].try_into().unwrap());
let is_v2 = raw_len & WAL_V2_SENTINEL != 0;
let is_v3 = is_v2 && (raw_len & WAL_V3_FLAG != 0);
// Strip the flag bits to recover the payload length.
let len_mask = if is_v3 {
!(WAL_V2_SENTINEL | WAL_V3_FLAG)
} else {
!WAL_V2_SENTINEL
};
let rec_len = (raw_len & len_mask) as usize;
// 4 = length only, 8 = length + 4 more bytes, 9 = length + 4 bytes + type byte.
let header_len = if is_v3 {
9
} else if is_v2 {
8
} else {
4
};
if wal_bytes.len() - cur < header_len + rec_len {
break;
}
// Untyped records: the payload is SQL; report them with the v3 SQL tag.
if !is_v3 {
let sql = &wal_bytes[cur + header_len..cur + header_len + rec_len];
out.push(WalRecord {
offset: cur,
type_byte: WAL_V3_TYPE_AUTO_COMMIT_SQL,
commit_lsn: None,
commit_unix_us: None,
sql,
});
cur += header_len + rec_len;
continue;
}
let type_byte = wal_bytes[cur + 8];
match type_byte {
WAL_V3_TYPE_AUTO_COMMIT_SQL => {
let sql = &wal_bytes[cur + header_len..cur + header_len + rec_len];
out.push(WalRecord {
offset: cur,
type_byte,
commit_lsn: None,
commit_unix_us: None,
sql,
});
cur += header_len + rec_len;
}
WAL_V3_TYPE_DURABILITY_CHECKPOINT => {
// Payload (if any) is skipped; the record is reported with empty bytes.
out.push(WalRecord {
offset: cur,
type_byte,
commit_lsn: None,
commit_unix_us: None,
sql: &[],
});
cur += header_len + rec_len;
}
WAL_V4_TYPE_CHECKPOINT_MARKER => {
// Fixed part of the payload: 8-byte LSN + 8-byte timestamp + 2-byte path length.
if rec_len < 18 {
return Err(format!(
"WAL parse: checkpoint marker at offset {cur} too short ({rec_len} bytes)"
));
}
let lsn = u64::from_le_bytes(
wal_bytes[cur + header_len..cur + header_len + 8]
.try_into()
.unwrap(),
);
let ts_raw = i64::from_le_bytes(
wal_bytes[cur + header_len + 8..cur + header_len + 16]
.try_into()
.unwrap(),
);
let path_len = u16::from_le_bytes(
wal_bytes[cur + header_len + 16..cur + header_len + 18]
.try_into()
.unwrap(),
) as usize;
if rec_len < 18 + path_len {
return Err(format!(
"WAL parse: checkpoint marker at offset {cur} truncated path"
));
}
let path_start = cur + header_len + 18;
let path_bytes = &wal_bytes[path_start..path_start + path_len];
let commit_unix_us = if ts_raw == WAL_V4_NO_CLOCK {
None
} else {
Some(ts_raw)
};
// The snapshot path is surfaced through the `sql` field.
out.push(WalRecord {
offset: cur,
type_byte,
commit_lsn: Some(lsn),
commit_unix_us,
sql: path_bytes,
});
cur += header_len + rec_len;
}
WAL_V4_TYPE_AUTO_COMMIT_SQL => {
// v4 records carry 16 extra header bytes that `rec_len` does not count.
let v4_total = header_len + WAL_V4_EXTRA_HEADER + rec_len;
if wal_bytes.len() - cur < v4_total {
break;
}
let lsn = u64::from_le_bytes(
wal_bytes[cur + header_len..cur + header_len + 8]
.try_into()
.unwrap(),
);
let ts_raw = i64::from_le_bytes(
wal_bytes[cur + header_len + 8..cur + header_len + 16]
.try_into()
.unwrap(),
);
let commit_unix_us = if ts_raw == WAL_V4_NO_CLOCK {
None
} else {
Some(ts_raw)
};
let sql_start = cur + header_len + WAL_V4_EXTRA_HEADER;
let sql = &wal_bytes[sql_start..sql_start + rec_len];
out.push(WalRecord {
offset: cur,
type_byte,
commit_lsn: Some(lsn),
commit_unix_us,
sql,
});
cur += v4_total;
}
other => {
return Err(format!(
"WAL parse: unknown type byte {other:#04x} at offset {cur}"
));
}
}
}
Ok(out)
}
/// Cheap keyword check: does `sql` start with a verb that is never written to
/// the WAL?
///
/// The first token (up to whitespace, `;` or `(`) is compared
/// case-insensitively against a fixed keyword list. Leading whitespace is
/// ignored; nothing else about the statement is inspected.
fn sql_is_read_only(sql: &str) -> bool {
    const SKIPPED_VERBS: [&str; 10] = [
        "select",
        "show",
        "explain",
        "begin",
        "commit",
        "rollback",
        "checkpoint",
        "compact",
        "wait",
        "with",
    ];
    let body = sql.trim_start();
    let verb_end = body
        .find(|c: char| c.is_whitespace() || c == ';' || c == '(')
        .unwrap_or(body.len());
    let verb = &body[..verb_end];
    SKIPPED_VERBS
        .iter()
        .any(|known| verb.eq_ignore_ascii_case(known))
}
/// An embedded database handle: an [`Engine`] plus optional on-disk
/// persistence (snapshot file, write-ahead log and lock directory).
///
/// Create one with [`Database::open_in_memory`], [`Database::open_path`] or
/// [`Database::restore`].
#[derive(Debug)]
pub struct Database {
// The query engine that owns the catalog and executes SQL.
engine: Engine,
// `Some` only for databases opened with `open_path`.
persistence: Option<PersistenceCtx>,
// Last commit LSN handed out; incremented for every WAL-logged statement.
commit_lsn: AtomicU64,
}
/// On-disk state that accompanies a [`Database`] opened from a path.
#[derive(Debug)]
#[allow(dead_code)] struct PersistenceCtx {
// Snapshot file the engine is restored from and checkpointed to.
db_path: PathBuf,
// Write-ahead log: `db_path` with `.wal` appended to the file name.
wal_path: PathBuf,
// Open append-mode handle on `wal_path`.
wal: File,
// Bytes currently in the WAL as tracked by this process.
wal_len: u64,
// WAL size at or above which a write triggers a checkpoint.
checkpoint_threshold_bytes: u64,
// Directory where frozen cold segments are written (`<stem>.spg/segments`).
cold_segments_dir: PathBuf,
// Segment id -> file path for cold segments known to this handle.
cold_segment_paths: BTreeMap<u32, PathBuf>,
// Lock directory created on open and removed on drop.
lock_path: PathBuf,
}
impl Database {
/// Creates an empty database that lives only in memory.
///
/// No snapshot file, WAL or lock is used; the commit LSN starts at zero and
/// the engine is given the wall clock as its time source.
#[must_use]
pub fn open_in_memory() -> Self {
    let engine = Engine::new().with_clock(wall_clock_micros);
    Self {
        commit_lsn: AtomicU64::new(0),
        persistence: None,
        engine,
    }
}
pub fn open_path(db_path: impl AsRef<Path>) -> Result<Self, EngineError> {
let db_path = db_path.as_ref().to_path_buf();
let wal_path = {
let mut p = db_path.clone();
let name = p
.file_name()
.map(|n| {
let mut s = n.to_os_string();
s.push(".wal");
s
})
.unwrap_or_else(|| std::ffi::OsString::from(".wal"));
p.set_file_name(name);
p
};
if let Some(parent) = db_path.parent()
&& !parent.as_os_str().is_empty()
{
std::fs::create_dir_all(parent).map_err(io_err)?;
}
let lock_path = {
let mut p = db_path.clone();
let name = p
.file_name()
.map(|n| {
let mut s = n.to_os_string();
s.push(".lock");
s
})
.unwrap_or_else(|| std::ffi::OsString::from(".lock"));
p.set_file_name(name);
p
};
std::fs::create_dir(&lock_path).map_err(|e| {
if e.kind() == std::io::ErrorKind::AlreadyExists {
EngineError::Unsupported(format!(
"database is locked by another process (or stale lock): {}; \
remove the directory manually after confirming no other \
process holds it, or call Database::force_unlock()",
lock_path.display()
))
} else {
io_err(e)
}
})?;
let mut engine = if db_path.exists() {
let bytes = std::fs::read(&db_path).map_err(io_err)?;
let engine = Engine::restore_envelope(&bytes).map_err(|e| {
EngineError::Storage(spg_storage::StorageError::Corrupt(format!(
"restore from {}: {e}",
db_path.display()
)))
})?;
engine.with_clock(wall_clock_micros)
} else {
Engine::new().with_clock(wall_clock_micros)
};
let cold_segments_dir = {
let parent = db_path.parent().unwrap_or_else(|| Path::new("."));
let stem = db_path
.file_stem()
.unwrap_or_else(|| std::ffi::OsStr::new("db"))
.to_string_lossy()
.into_owned();
parent.join(format!("{stem}.spg")).join("segments")
};
let mut cold_segment_paths: BTreeMap<u32, PathBuf> = BTreeMap::new();
let manifest_pth = spg_manifest_path(&db_path);
if manifest_pth.exists() && db_path.exists() {
let m_bytes = std::fs::read(&manifest_pth).map_err(io_err)?;
if let Ok(m) = CatalogManifest::deserialize(&m_bytes) {
let snap_bytes = std::fs::read(&db_path).map_err(io_err)?;
let snap_crc = spg_crypto::crc32::crc32(&snap_bytes);
if snap_crc == m.catalog_crc32 {
for entry in &m.cold_segments {
if let Ok(seg_bytes) = std::fs::read(&entry.path) {
let computed = spg_crypto::crc32::crc32(&seg_bytes);
if computed != entry.crc32 {
eprintln!(
"spg-embedded: manifest skip segment {}: CRC mismatch",
entry.segment_id
);
continue;
}
if engine.catalog().cold_segment(entry.segment_id).is_some() {
continue;
}
let mut new_cat = engine.catalog().clone();
if let Err(e) =
new_cat.load_segment_bytes_at(entry.segment_id, seg_bytes)
{
eprintln!(
"spg-embedded: manifest load segment {} failed: {e}",
entry.segment_id
);
continue;
}
engine.replace_catalog(new_cat);
cold_segment_paths.insert(entry.segment_id, entry.path.clone());
} else {
eprintln!(
"spg-embedded: manifest skip segment {}: file unreadable",
entry.segment_id
);
}
}
}
}
}
let mut initial_lsn: u64 = 0;
if wal_path.exists() {
let wal_bytes = std::fs::read(&wal_path).map_err(io_err)?;
if !wal_bytes.is_empty() {
replay_wal_into_engine(&wal_bytes, &mut engine)
.map_err(|m| EngineError::Storage(spg_storage::StorageError::Corrupt(m)))?;
if let Ok(records) = parse_wal_records(&wal_bytes) {
if let Some(max) = records.iter().filter_map(|r| r.commit_lsn).max() {
initial_lsn = max;
}
}
}
}
let wal = OpenOptions::new()
.create(true)
.append(true)
.read(true)
.open(&wal_path)
.map_err(io_err)?;
let wal_len = wal.metadata().map_err(io_err)?.len();
Ok(Self {
engine,
commit_lsn: AtomicU64::new(initial_lsn),
persistence: Some(PersistenceCtx {
db_path,
wal_path,
wal,
wal_len,
checkpoint_threshold_bytes: default_checkpoint_threshold_bytes(),
cold_segments_dir,
cold_segment_paths,
lock_path,
}),
})
}
/// Freezes up to `max_rows` of the oldest rows of `table_name` (ordered by
/// `index_name`) into a cold segment and returns the engine's report.
///
/// For a persistent database the segment bytes are also written to
/// `seg_<id>.spg` in the cold-segment directory (via a `.tmp` file and a
/// rename) and the path is remembered for the next checkpoint's manifest.
///
/// # Errors
/// Propagates the engine's error, or a storage error if writing the segment
/// file fails. In the latter case the in-memory freeze has already happened.
pub fn freeze_oldest_to_cold(
    &mut self,
    table_name: &str,
    index_name: &str,
    max_rows: usize,
) -> Result<spg_storage::FreezeReport, EngineError> {
    let report = self
        .engine
        .freeze_oldest_to_cold(table_name, index_name, max_rows)?;
    let Some(ctx) = self.persistence.as_mut() else {
        return Ok(report);
    };

    let segments_dir = &ctx.cold_segments_dir;
    std::fs::create_dir_all(segments_dir).map_err(io_err)?;
    let segment_file = segments_dir.join(format!("seg_{}.spg", report.segment_id));
    let staging_file = segments_dir.join(format!("seg_{}.spg.tmp", report.segment_id));
    std::fs::write(&staging_file, &report.segment_bytes).map_err(io_err)?;
    std::fs::rename(&staging_file, &segment_file).map_err(io_err)?;
    ctx.cold_segment_paths
        .insert(report.segment_id, segment_file);
    Ok(report)
}
/// Sets the WAL size, in bytes, at which a write triggers an automatic
/// checkpoint. Values below 1 are raised to 1. Has no effect on an in-memory
/// database.
pub fn set_checkpoint_threshold_bytes(&mut self, bytes: u64) {
    let Some(ctx) = self.persistence.as_mut() else {
        return;
    };
    ctx.checkpoint_threshold_bytes = bytes.max(1);
}
pub fn checkpoint(&mut self) -> Result<(), EngineError> {
let snapshot = self.engine.snapshot();
let Some(p) = &mut self.persistence else {
return Ok(());
};
let tmp = {
let mut t = p.db_path.clone();
let mut name = t
.file_name()
.map(std::ffi::OsStr::to_os_string)
.unwrap_or_default();
name.push(".tmp");
t.set_file_name(name);
t
};
std::fs::write(&tmp, &snapshot).map_err(io_err)?;
std::fs::rename(&tmp, &p.db_path).map_err(io_err)?;
if !p.cold_segment_paths.is_empty() {
let snap_crc = spg_crypto::crc32::crc32(&snapshot);
let entries: Vec<ColdSegmentEntry> = p
.cold_segment_paths
.iter()
.filter_map(|(&segment_id, path)| {
let bytes = std::fs::read(path).ok()?;
Some(ColdSegmentEntry {
segment_id,
path: path.clone(),
crc32: spg_crypto::crc32::crc32(&bytes),
})
})
.collect();
let manifest = CatalogManifest {
catalog_crc32: snap_crc,
cold_segments: entries,
wal_baseline_offset: 0,
};
let m_bytes = manifest.serialize();
let m_path = spg_manifest_path(&p.db_path);
if let Some(dir) = m_path.parent() {
std::fs::create_dir_all(dir).map_err(io_err)?;
}
let m_tmp = {
let mut t = m_path.clone();
let mut name = t
.file_name()
.map(std::ffi::OsStr::to_os_string)
.unwrap_or_default();
name.push(".tmp");
t.set_file_name(name);
t
};
std::fs::write(&m_tmp, &m_bytes).map_err(io_err)?;
std::fs::rename(&m_tmp, &m_path).map_err(io_err)?;
}
let marker_lsn = self.commit_lsn.load(Ordering::SeqCst);
let marker_ts = wall_clock_micros();
let marker = encode_v4_checkpoint_marker(marker_lsn, marker_ts, &p.db_path);
p.wal.write_all(&marker).map_err(io_err)?;
p.wal.sync_data().map_err(io_err)?;
p.wal.set_len(0).map_err(io_err)?;
p.wal.seek(SeekFrom::Start(0)).map_err(io_err)?;
p.wal.sync_data().map_err(io_err)?;
p.wal_len = 0;
Ok(())
}
/// Builds an in-memory database from snapshot bytes previously produced by
/// [`Database::snapshot`].
///
/// The result has no persistence and its commit LSN starts at zero.
///
/// NOTE(review): unlike the `open_*` constructors, no wall clock is attached
/// to the restored engine here — confirm that is intentional.
///
/// # Errors
/// Returns a storage "corrupt" error when the engine cannot decode `snapshot`.
pub fn restore(snapshot: &[u8]) -> Result<Self, EngineError> {
    match Engine::restore_envelope(snapshot) {
        Ok(engine) => Ok(Self {
            commit_lsn: AtomicU64::new(0),
            persistence: None,
            engine,
        }),
        Err(e) => Err(EngineError::Storage(spg_storage::StorageError::Corrupt(
            format!("restore: {e}"),
        ))),
    }
}
/// Serializes the engine's current state to bytes.
///
/// The bytes can be turned back into a database with [`Database::restore`].
#[must_use]
pub fn snapshot(&self) -> Vec<u8> {
self.engine.snapshot()
}
/// Executes one SQL statement and returns the engine's result.
///
/// On a persistent database, a statement that is not classified as read-only
/// and whose result reports a modified catalog is appended to the WAL (with a
/// fresh commit LSN and timestamp) and synced. If the WAL has reached the
/// checkpoint threshold afterwards, a checkpoint is taken.
///
/// # Errors
/// Propagates engine errors, and returns a storage error if the WAL write,
/// sync or the follow-up checkpoint fails. In that case the statement has
/// already been applied in memory.
pub fn execute(&mut self, sql: &str) -> Result<QueryResult, EngineError> {
    let result = self.engine.execute(sql)?;

    let changed_catalog = matches!(
        &result,
        QueryResult::CommandOk {
            modified_catalog: true,
            ..
        }
    );
    if self.persistence.is_none() || sql_is_read_only(sql) || !changed_catalog {
        return Ok(result);
    }

    let lsn = self.commit_lsn.fetch_add(1, Ordering::SeqCst) + 1;
    let record = encode_v4_auto_commit(sql, lsn, wall_clock_micros());
    let checkpoint_due = {
        let ctx = self.persistence.as_mut().expect("checked above");
        ctx.wal.write_all(&record).map_err(io_err)?;
        ctx.wal.sync_data().map_err(io_err)?;
        ctx.wal_len = ctx.wal_len.saturating_add(record.len() as u64);
        ctx.wal_len >= ctx.checkpoint_threshold_bytes
    };
    if checkpoint_due {
        self.checkpoint()?;
    }
    Ok(result)
}
/// Runs a query and decodes every result row into `T`.
///
/// # Errors
/// Returns the error from [`Database::query`], or the first error produced
/// by `T::from_spg_row`.
pub fn query_typed<T: FromSpgRow>(&mut self, sql: &str) -> Result<Vec<T>, EngineError> {
    let rows = self.query(sql)?;
    let mut decoded = Vec::with_capacity(rows.len());
    for row in &rows {
        decoded.push(T::from_spg_row(row)?);
    }
    Ok(decoded)
}
/// Runs a row-returning statement and returns the rows as plain value vectors.
///
/// The statement is routed through [`Database::execute`], so if a caller
/// passes a mutating statement by mistake the change that the engine has
/// already applied is still recorded in the WAL before the error is returned.
///
/// # Errors
/// Propagates errors from `execute`, and returns `EngineError::Unsupported`
/// when the statement does not produce rows.
pub fn query(&mut self, sql: &str) -> Result<Vec<Vec<Value>>, EngineError> {
    match self.execute(sql)? {
        QueryResult::Rows { rows, .. } => Ok(rows.into_iter().map(|r| r.values).collect()),
        _ => Err(EngineError::Unsupported(
            "query() expects a SELECT — use execute() for DML/DDL".into(),
        )),
    }
}
/// Runs a row-returning statement and returns the result column schemas
/// together with the rows.
///
/// The statement is routed through [`Database::execute`], so if a caller
/// passes a mutating statement by mistake the change that the engine has
/// already applied is still recorded in the WAL before the error is returned.
///
/// # Errors
/// Propagates errors from `execute`, and returns `EngineError::Unsupported`
/// when the statement does not produce rows.
pub fn query_with_columns(
    &mut self,
    sql: &str,
) -> Result<(Vec<spg_storage::ColumnSchema>, Vec<Vec<Value>>), EngineError> {
    match self.execute(sql)? {
        QueryResult::Rows { columns, rows } => {
            Ok((columns, rows.into_iter().map(|r| r.values).collect()))
        }
        _ => Err(EngineError::Unsupported(
            "query_with_columns() expects a SELECT — use execute() for DML/DDL".into(),
        )),
    }
}
/// Runs a prepared row-returning statement with `params` and returns the
/// result column schemas together with the rows.
///
/// The statement is routed through [`Database::execute_prepared`], so if a
/// mutating statement is passed by mistake the change that the engine has
/// already applied is still recorded in the WAL before the error is returned.
///
/// # Errors
/// Propagates errors from `execute_prepared`, and returns
/// `EngineError::Unsupported` when the statement does not produce rows.
pub fn query_prepared_with_columns(
    &mut self,
    stmt: &Statement,
    params: &[Value],
) -> Result<(Vec<spg_storage::ColumnSchema>, Vec<Vec<Value>>), EngineError> {
    match self.execute_prepared(stmt, params)? {
        QueryResult::Rows { columns, rows } => {
            Ok((columns, rows.into_iter().map(|r| r.values).collect()))
        }
        _ => Err(EngineError::Unsupported(
            "query_prepared_with_columns() expects a SELECT — use execute_prepared() for DML/DDL".into(),
        )),
    }
}
/// Shared access to the underlying [`Engine`].
#[must_use]
pub const fn engine(&self) -> &Engine {
&self.engine
}
/// Exclusive access to the underlying [`Engine`].
///
/// Statements executed directly on the engine bypass this type's WAL
/// logging, which only happens in `Database::execute*`.
pub const fn engine_mut(&mut self) -> &mut Engine {
&mut self.engine
}
/// Parses `sql` through the engine's prepared-statement cache and returns a
/// reusable [`Statement`].
///
/// # Errors
/// Returns `EngineError::Parse` when the SQL cannot be parsed.
pub fn prepare(&mut self, sql: &str) -> Result<Statement, EngineError> {
    match self.engine.prepare_cached(sql) {
        Ok(parsed) => Ok(Statement {
            sql: sql.to_owned(),
            stmt: parsed,
        }),
        Err(parse_err) => Err(EngineError::Parse(parse_err)),
    }
}
/// Prepares `sql` and returns the engine's description of it: a list of
/// `u32` values and the result column schemas.
///
/// NOTE(review): the `u32` values are presumably parameter type ids — confirm
/// against `Engine::describe_prepared`.
///
/// # Errors
/// Returns `EngineError::Parse` when the SQL cannot be parsed.
pub fn describe(&mut self, sql: &str) -> Result<(Vec<u32>, Vec<ColumnSchema>), EngineError> {
    match self.engine.prepare_cached(sql) {
        Ok(parsed) => Ok(self.engine.describe_prepared(&parsed)),
        Err(parse_err) => Err(EngineError::Parse(parse_err)),
    }
}
/// Executes a prepared statement with `params` and returns the engine's result.
///
/// On a persistent database, when the result reports a modified catalog the
/// statement is rendered with the parameters substituted in, appended to the
/// WAL (with a fresh commit LSN and timestamp) and synced. If the WAL has
/// reached the checkpoint threshold afterwards, a checkpoint is taken.
///
/// # Errors
/// Propagates engine errors, and returns a storage error if the WAL write,
/// sync or the follow-up checkpoint fails. In that case the statement has
/// already been applied in memory.
pub fn execute_prepared(
    &mut self,
    stmt: &Statement,
    params: &[Value],
) -> Result<QueryResult, EngineError> {
    let result = self.engine.execute_prepared(stmt.stmt.clone(), params)?;

    let changed_catalog = matches!(
        &result,
        QueryResult::CommandOk {
            modified_catalog: true,
            ..
        }
    );
    if self.persistence.is_none() || !changed_catalog {
        return Ok(result);
    }

    // The WAL stores literal SQL, so bake the parameters into a copy first.
    let mut rendered = stmt.stmt.clone();
    crate::wal_render_with_params(&mut rendered, params);
    let canonical = rendered.to_string();

    let lsn = self.commit_lsn.fetch_add(1, Ordering::SeqCst) + 1;
    let record = encode_v4_auto_commit(&canonical, lsn, wall_clock_micros());
    let checkpoint_due = {
        let ctx = self.persistence.as_mut().expect("checked above");
        ctx.wal.write_all(&record).map_err(io_err)?;
        ctx.wal.sync_data().map_err(io_err)?;
        ctx.wal_len = ctx.wal_len.saturating_add(record.len() as u64);
        ctx.wal_len >= ctx.checkpoint_threshold_bytes
    };
    if checkpoint_due {
        self.checkpoint()?;
    }
    Ok(result)
}
/// Runs a prepared row-returning statement with `params` and returns the rows
/// as plain value vectors.
///
/// The statement is routed through [`Database::execute_prepared`], so if a
/// mutating statement is passed by mistake the change that the engine has
/// already applied is still recorded in the WAL before the error is returned.
///
/// # Errors
/// Propagates errors from `execute_prepared`, and returns
/// `EngineError::Unsupported` when the statement does not produce rows.
pub fn query_prepared(
    &mut self,
    stmt: &Statement,
    params: &[Value],
) -> Result<Vec<Vec<Value>>, EngineError> {
    match self.execute_prepared(stmt, params)? {
        QueryResult::Rows { rows, .. } => Ok(rows.into_iter().map(|r| r.values).collect()),
        _ => Err(EngineError::Unsupported(
            "query_prepared() expects a SELECT — use execute_prepared() for DML/DDL".into(),
        )),
    }
}
pub fn prepare_on_snapshot(
snapshot: &CatalogSnapshot,
sql: &str,
) -> Result<Statement, EngineError> {
let stmt = spg_engine::Engine::prepare_on_snapshot(snapshot, sql)
.map_err(EngineError::Parse)?;
Ok(Statement {
stmt,
sql: sql.to_string(),
})
}
pub fn execute_prepared_on_snapshot(
snapshot: &CatalogSnapshot,
stmt: &Statement,
params: &[Value],
) -> Result<QueryResult, EngineError> {
spg_engine::Engine::execute_readonly_prepared_on_snapshot(
snapshot,
stmt.stmt.clone(),
params,
)
}
pub fn describe_on_snapshot(
snapshot: &CatalogSnapshot,
sql: &str,
) -> Result<(Vec<u32>, Vec<ColumnSchema>), EngineError> {
let stmt = spg_engine::Engine::prepare_on_snapshot(snapshot, sql)
.map_err(EngineError::Parse)?;
Ok(spg_engine::Engine::describe_prepared_on_snapshot(
snapshot, &stmt,
))
}
/// Runs `body` between `BEGIN` and `COMMIT`.
///
/// If `body` returns `Ok`, `COMMIT` is executed and the value is returned. If
/// `body` returns `Err`, `ROLLBACK` is attempted (its own error is ignored)
/// and the body's error is returned.
///
/// NOTE(review): `BEGIN`, `COMMIT` and `ROLLBACK` are classified as read-only
/// by `sql_is_read_only` and therefore never reach the WAL — confirm how
/// statements executed inside `body` interact with WAL replay after a
/// rollback.
///
/// # Errors
/// Returns the error from `BEGIN`, from `body`, or from `COMMIT`.
pub fn with_transaction<R, F>(&mut self, body: F) -> Result<R, EngineError>
where
    F: FnOnce(&mut Self) -> Result<R, EngineError>,
{
    self.execute("BEGIN")?;
    let outcome = body(self);
    if outcome.is_ok() {
        self.execute("COMMIT")?;
    } else {
        // Best effort: the body's error is what the caller needs to see.
        let _ = self.execute("ROLLBACK");
    }
    outcome
}
}
// The default database is an empty in-memory one.
impl Default for Database {
fn default() -> Self {
Self::open_in_memory()
}
}
/// Point-in-time counters returned by [`Database::metrics`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct EmbeddedMetrics {
/// Sum of `row_count()` over all tables in the catalog.
pub hot_rows: u64,
/// Sum of `hot_bytes()` over all tables in the catalog.
pub hot_bytes: u64,
/// Number of cold segments in the catalog.
pub cold_segments: u64,
/// Number of tables in the catalog.
pub tables: u64,
/// Current WAL size in bytes as tracked by the handle; `0` when in-memory.
pub wal_bytes: u64,
/// `true` when the database was opened with on-disk persistence.
pub persistent: bool,
}
/// Owner of the background freezer thread started by
/// `Database::spawn_background_freezer`.
///
/// Dropping the handle stops the thread and waits for it to exit.
#[must_use = "the background freezer keeps running until this handle is dropped"]
#[derive(Debug)]
pub struct FreezerHandle {
    // Set to `true` to ask the worker loop to exit.
    shutdown: Arc<AtomicBool>,
    // `Some` until the worker has been joined.
    join: Option<JoinHandle<()>>,
}

impl FreezerHandle {
    /// Signals the worker to stop and blocks until it has exited.
    ///
    /// Calling this more than once is harmless; a panic in the worker is
    /// swallowed.
    pub fn stop(&mut self) {
        self.shutdown.store(true, Ordering::Release);
        let Some(worker) = self.join.take() else {
            return;
        };
        let _ = worker.join();
    }
}

impl Drop for FreezerHandle {
    fn drop(&mut self) {
        self.stop();
    }
}
/// Tuning knobs for the background freezer thread.
#[derive(Debug, Clone)]
pub struct FreezerOptions {
    /// Minimum time between two freeze passes.
    pub tick: Duration,
    /// A pass only freezes when the catalog's hot tier is larger than this
    /// many bytes.
    pub hot_tier_bytes: u64,
    /// Maximum number of rows frozen per pass.
    pub batch_rows: usize,
    /// After a freeze, compaction runs when the cold segment count is above
    /// this value.
    pub compact_when_segments_exceed: usize,
    /// Target size, in bytes, passed to the engine's cold-segment compaction.
    pub compact_target_bytes: u64,
}

impl Default for FreezerOptions {
    /// One-second tick, 4 GiB hot tier, 1000-row batches, compaction above 64
    /// segments with a 64 MiB target.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;
        const GIB: u64 = 1024 * MIB;
        Self {
            tick: Duration::from_secs(1),
            hot_tier_bytes: 4 * GIB,
            batch_rows: 1000,
            compact_when_segments_exceed: 64,
            compact_target_bytes: 64 * MIB,
        }
    }
}
impl Database {
/// Number of cold segments currently in the engine's catalog.
#[must_use]
pub fn cold_segment_count(&self) -> usize {
self.engine.catalog().cold_segment_count()
}
/// Collects a snapshot of size counters for this database.
///
/// Row and byte totals are summed over every table in the catalog
/// (saturating on overflow); WAL figures come from the persistence context
/// and are `0` / `false` for an in-memory database.
#[must_use]
pub fn metrics(&self) -> EmbeddedMetrics {
    let catalog = self.engine.catalog();

    let mut total_rows: u64 = 0;
    let mut total_bytes: u64 = 0;
    for table_name in catalog.table_names() {
        let Some(table) = catalog.get(&table_name) else {
            continue;
        };
        total_rows = total_rows.saturating_add(table.row_count() as u64);
        total_bytes = total_bytes.saturating_add(table.hot_bytes());
    }

    let (wal_bytes, persistent) = self
        .persistence
        .as_ref()
        .map_or((0, false), |ctx| (ctx.wal_len, true));

    EmbeddedMetrics {
        hot_rows: total_rows,
        hot_bytes: total_bytes,
        cold_segments: catalog.cold_segment_count() as u64,
        tables: catalog.table_count() as u64,
        wal_bytes,
        persistent,
    }
}
/// Starts a thread named `spg-embedded-freezer` that periodically freezes
/// hot rows of `db` into cold segments according to `opts`.
///
/// The thread runs until the returned [`FreezerHandle`] is stopped or dropped.
///
/// # Panics
/// Panics if the operating system refuses to spawn the thread.
pub fn spawn_background_freezer(
    db: Arc<Mutex<Database>>,
    opts: FreezerOptions,
) -> FreezerHandle {
    let stop_flag = Arc::new(AtomicBool::new(false));
    let worker_flag = Arc::clone(&stop_flag);
    let worker = thread::Builder::new()
        .name("spg-embedded-freezer".into())
        .spawn(move || background_freezer_loop(db, opts, worker_flag))
        .expect("spawn background freezer thread");
    FreezerHandle {
        shutdown: stop_flag,
        join: Some(worker),
    }
}
}
/// Body of the background freezer thread.
///
/// Sleeps in short slices (at most 50 ms, or `opts.tick` if shorter) so a
/// shutdown request is noticed quickly. Once `opts.tick` has elapsed it locks
/// the database and, if the hot tier is larger than `opts.hot_tier_bytes`,
/// freezes up to `opts.batch_rows` rows of the table chosen by
/// `pick_freeze_target`, then compacts cold segments when their count is
/// above `opts.compact_when_segments_exceed`.
///
/// Returns when `shutdown` is set or the database mutex is poisoned. Freeze
/// and compaction failures are reported on stderr and do not stop the loop.
fn background_freezer_loop(
db: Arc<Mutex<Database>>,
opts: FreezerOptions,
shutdown: Arc<AtomicBool>,
) {
let slice = Duration::from_millis(50.min(opts.tick.as_millis() as u64));
let mut last_tick = std::time::Instant::now();
loop {
if shutdown.load(Ordering::Acquire) {
return;
}
thread::sleep(slice);
// Not yet time for the next pass; go back to checking the shutdown flag.
if last_tick.elapsed() < opts.tick {
continue;
}
last_tick = std::time::Instant::now();
// The guard is held for the rest of this iteration and released on
// every `continue` and at the end of the loop body.
let Ok(mut guard) = db.lock() else {
return;
};
if guard.engine.catalog().hot_tier_bytes() <= opts.hot_tier_bytes {
continue;
}
let Some((table, index)) = pick_freeze_target(&guard) else {
continue;
};
let row_count = guard
.engine
.catalog()
.get(&table)
.map_or(0, spg_storage::Table::row_count);
let to_freeze = opts.batch_rows.min(row_count);
if to_freeze == 0 {
continue;
}
if let Err(e) = guard.freeze_oldest_to_cold(&table, &index, to_freeze) {
eprintln!("spg-embedded: background freeze on {table}.{index} failed: {e:?}");
continue;
}
let count = guard.engine.catalog().cold_segment_count();
if count > opts.compact_when_segments_exceed {
if let Err(e) = guard
.engine
.compact_cold_segments_with_target(opts.compact_target_bytes)
{
eprintln!(
"spg-embedded: background compact failed (segments={count}, \
threshold={}): {e:?}",
opts.compact_when_segments_exceed,
);
}
}
}
}
/// Chooses the table (and one of its indexes) the background freezer should
/// freeze next.
///
/// Only non-empty tables that have a B-tree index on a `SmallInt`, `Int` or
/// `BigInt` column qualify; the first such index of a table is used. Among
/// qualifying tables the one with the most hot bytes wins, with ties going to
/// the table seen first. Returns `(table name, index name)`, or `None` when
/// no table qualifies.
fn pick_freeze_target(db: &Database) -> Option<(String, String)> {
    let catalog = db.engine.catalog();
    let mut winner: Option<(String, String, u64)> = None;

    for table_name in catalog.table_names() {
        let Some(table) = catalog.get(&table_name) else {
            continue;
        };
        if table.row_count() == 0 {
            continue;
        }

        let columns = &table.schema().columns;
        let int_btree = table.indices().iter().find(|candidate| {
            let is_btree = matches!(candidate.kind, spg_storage::IndexKind::BTree(_));
            is_btree
                && candidate.column_position < columns.len()
                && matches!(
                    columns[candidate.column_position].ty,
                    spg_storage::DataType::SmallInt
                        | spg_storage::DataType::Int
                        | spg_storage::DataType::BigInt
                )
        });
        let Some(index) = int_btree else {
            continue;
        };

        let hot = table.hot_bytes();
        // Strictly greater: an equally large later table does not displace
        // the current winner.
        let beats_current = winner
            .as_ref()
            .map_or(true, |(_, _, best_hot)| hot > *best_hot);
        if beats_current {
            winner = Some((table_name, index.name.clone(), hot));
        }
    }

    winner.map(|(table, index, _)| (table, index))
}
/// Rebuilds a database by replaying the first `to_seq` SQL statements of the
/// WAL at `wal_path` into a fresh engine and writing the resulting snapshot
/// to `out_db_path`.
///
/// Records without a SQL payload (checkpoint records, empty records) are
/// skipped and do not count towards `to_seq`. Returns the number of
/// statements actually applied, which is smaller than `to_seq` when the WAL
/// holds fewer statements.
///
/// # Errors
/// Returns a storage error when the WAL cannot be read, a record is truncated
/// or not valid UTF-8, or the snapshot cannot be written; engine errors from
/// replayed statements are propagated unchanged.
pub fn revert_wal_to_seq(
    wal_path: impl AsRef<Path>,
    to_seq: u64,
    out_db_path: impl AsRef<Path>,
) -> Result<u64, EngineError> {
    let wal_bytes = std::fs::read(wal_path.as_ref()).map_err(io_err)?;
    let mut engine = Engine::new();
    let mut applied = 0u64;
    let mut cur = 0usize;
    while cur < wal_bytes.len() && applied < to_seq {
        // Remember where this record starts: `cur` is advanced before the
        // payload is validated, and the error must point at the bad record.
        let record_start = cur;
        let (sql_bytes, total) = decode_wal_record(&wal_bytes[record_start..])?;
        cur += total;
        if sql_bytes.is_empty() {
            continue;
        }
        let sql = core::str::from_utf8(&sql_bytes).map_err(|e| {
            EngineError::Storage(spg_storage::StorageError::Corrupt(format!(
                "WAL record at offset {record_start}: non-UTF-8 SQL: {e}"
            )))
        })?;
        engine.execute(sql)?;
        applied += 1;
    }
    let snapshot = engine.snapshot();
    std::fs::write(out_db_path.as_ref(), &snapshot).map_err(io_err)?;
    Ok(applied)
}
/// Decodes the single WAL record at the start of `tail`.
///
/// Returns the record's SQL payload (copied) and the total number of bytes
/// the record occupies. Typed records that are not SQL-bearing yield an empty
/// payload.
///
/// # Errors
/// Returns a storage "corrupt" error when `tail` is shorter than the header
/// or than the length the header declares.
fn decode_wal_record(tail: &[u8]) -> Result<(Vec<u8>, usize), EngineError> {
    let corrupt = |msg: String| EngineError::Storage(spg_storage::StorageError::Corrupt(msg));

    if tail.len() < 4 {
        return Err(corrupt(format!(
            "WAL truncated record: {} < 4 header bytes",
            tail.len()
        )));
    }
    let raw_len = u32::from_le_bytes([tail[0], tail[1], tail[2], tail[3]]);
    let has_sentinel = (raw_len & WAL_V2_SENTINEL) != 0;
    let is_typed = has_sentinel && (raw_len & WAL_V3_FLAG) != 0;

    // Flag bits are stripped from the length; the header grows with each
    // format generation (length, +4 bytes, +type byte).
    let (rec_len, header_len) = if is_typed {
        ((raw_len & !(WAL_V2_SENTINEL | WAL_V3_FLAG)) as usize, 9usize)
    } else if has_sentinel {
        ((raw_len & !WAL_V2_SENTINEL) as usize, 8usize)
    } else {
        ((raw_len & !WAL_V2_SENTINEL) as usize, 4usize)
    };

    let plain_total = header_len + rec_len;
    if tail.len() < plain_total {
        return Err(corrupt(format!(
            "WAL truncated record: header+payload {} > available {}",
            header_len + rec_len,
            tail.len()
        )));
    }
    if !is_typed {
        return Ok((tail[header_len..plain_total].to_vec(), plain_total));
    }

    match tail[8] {
        WAL_V3_TYPE_AUTO_COMMIT_SQL => Ok((tail[header_len..plain_total].to_vec(), plain_total)),
        WAL_V4_TYPE_AUTO_COMMIT_SQL => {
            // The LSN/timestamp header sits between the type byte and the SQL
            // and is not counted in `rec_len`.
            let v4_total = header_len + WAL_V4_EXTRA_HEADER + rec_len;
            if tail.len() < v4_total {
                return Err(corrupt(format!(
                    "WAL truncated v4 record: header+payload {v4_total} > available {}",
                    tail.len()
                )));
            }
            let sql_start = header_len + WAL_V4_EXTRA_HEADER;
            Ok((tail[sql_start..sql_start + rec_len].to_vec(), v4_total))
        }
        // Any other typed record carries no SQL; skip over it.
        _ => Ok((Vec::new(), plain_total)),
    }
}
impl Drop for Database {
fn drop(&mut self) {
if self.persistence.is_some() {
if let Err(e) = self.checkpoint() {
eprintln!(
"spg-embedded: final checkpoint on Drop failed: {e:?} \
(WAL is intact; next open_path will replay)"
);
}
}
if let Some(ctx) = &self.persistence
&& ctx.lock_path.exists()
{
if let Err(e) = std::fs::remove_dir(&ctx.lock_path) {
eprintln!(
"spg-embedded: lock release on Drop failed for {}: {e:?}",
ctx.lock_path.display()
);
}
}
}
}
impl Database {
/// Removes the lock directory (`<name>.lock`) that belongs to `db_path`.
///
/// Intended for clearing a stale lock; nothing checks whether another
/// process still holds it. Succeeds without doing anything when the lock
/// directory does not exist.
///
/// # Errors
/// Returns a storage error when the directory exists but cannot be removed.
pub fn force_unlock(db_path: impl AsRef<Path>) -> Result<(), EngineError> {
    let db_path = db_path.as_ref();
    let mut lock_name = db_path
        .file_name()
        .map(std::ffi::OsStr::to_os_string)
        .unwrap_or_default();
    lock_name.push(".lock");
    let lock_path = db_path.with_file_name(lock_name);

    if lock_path.exists() {
        std::fs::remove_dir(&lock_path).map_err(io_err)
    } else {
        Ok(())
    }
}
}
/// Wraps an I/O error in the engine's storage error, prefixing the message
/// with `io: `.
fn io_err(e: std::io::Error) -> EngineError {
    let message = format!("io: {e}");
    EngineError::Storage(spg_storage::StorageError::Corrupt(message))
}
// Compile-time check only: this function fails to compile if `Database` ever
// stops being `Send`. It is never called.
#[allow(dead_code)]
fn _database_is_send() {
fn assert_send<T: Send>() {}
assert_send::<Database>();
}
/// Conversion from one result row into a user type.
///
/// Used by [`Database::query_typed`]; the [`spg_row!`] macro generates an
/// implementation for plain structs.
pub trait FromSpgRow: Sized {
/// Builds `Self` from the values of one row, in column order.
///
/// # Errors
/// Implementations return an error when the row does not match the
/// expected shape.
fn from_spg_row(row: &[Value]) -> Result<Self, EngineError>;
}
/// Declares a struct and implements [`FromSpgRow`] for it.
///
/// The struct is emitted with `#[derive(Debug, Clone)]` in addition to any
/// attributes given. Fields are filled from the row positionally, in
/// declaration order, using [`FromSpgValue`] for each field type. Every field
/// must end with a trailing comma.
///
/// The generated `from_spg_row` returns `EngineError::Unsupported` when the
/// row has fewer values than the struct has fields, or when a value cannot be
/// converted to the field's type. Extra trailing values in the row are
/// ignored.
#[macro_export]
macro_rules! spg_row {
(
$(#[$meta:meta])*
$vis:vis struct $name:ident {
$(
$(#[$fmeta:meta])*
$fvis:vis $field:ident : $ty:ty,
)*
}
) => {
$(#[$meta])*
#[derive(Debug, Clone)]
$vis struct $name {
$(
$(#[$fmeta])*
$fvis $field : $ty,
)*
}
impl $crate::FromSpgRow for $name {
fn from_spg_row(row: &[$crate::Value]) -> ::core::result::Result<Self, $crate::EngineError> {
// One iterator shared by all fields: each field consumes the next column.
let mut __spg_row_iter = row.iter();
$(
let $field: $ty = {
let v = __spg_row_iter
.next()
.ok_or_else(|| $crate::EngineError::Unsupported(
::std::format!(
"spg_row! {}: missing column for field `{}`",
::core::stringify!($name),
::core::stringify!($field)
)
))?;
<$ty as $crate::FromSpgValue>::from_spg_value(v)
.map_err(|e| $crate::EngineError::Unsupported(
::std::format!(
"spg_row! {}: column `{}`: {}",
::core::stringify!($name),
::core::stringify!($field),
e
)
))?
};
)*
Ok(Self { $($field,)* })
}
}
};
}
/// Conversion from a single [`Value`] into a Rust type.
///
/// Implemented in this crate for `i16`, `i32`, `i64`, `f32`, `f64`, `bool`,
/// `String`, `Vec<f32>` and `Option<T>` of any implementing `T`.
pub trait FromSpgValue: Sized {
/// Converts `v` into `Self`.
///
/// # Errors
/// Returns a static description when `v` has the wrong variant, is `NULL`
/// for a non-optional target, or does not fit the target type.
fn from_spg_value(v: &Value) -> Result<Self, &'static str>;
}
// Implements `FromSpgValue` for each listed integer type. Any of the three
// integer `Value` variants is accepted as long as the number fits the target
// type; `NULL` and non-integer values are rejected.
macro_rules! impl_from_value_int {
($($t:ty),* $(,)?) => {
$(
impl FromSpgValue for $t {
fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
match v {
Value::SmallInt(n) => <$t>::try_from(*n).map_err(|_| "SmallInt does not fit target int type"),
Value::Int(n) => <$t>::try_from(*n).map_err(|_| "Int does not fit target int type"),
Value::BigInt(n) => <$t>::try_from(*n).map_err(|_| "BigInt does not fit target int type"),
Value::Null => Err("NULL in non-Option int column"),
_ => Err("non-integer value in int column"),
}
}
}
)*
};
}
// Signed integer targets supported by row decoding.
impl_from_value_int!(i16, i32, i64);
// `Value::Float` narrowed to `f32` with an `as` cast (precision may be lost).
impl FromSpgValue for f32 {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Float(f) = v {
            return Ok(*f as f32);
        }
        Err(if matches!(v, Value::Null) {
            "NULL in non-Option float column"
        } else {
            "non-float value in float column"
        })
    }
}
// `Value::Float` is returned as-is; `NULL` and every other variant are errors.
impl FromSpgValue for f64 {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Float(f) = v {
            return Ok(*f);
        }
        Err(if matches!(v, Value::Null) {
            "NULL in non-Option float column"
        } else {
            "non-float value in float column"
        })
    }
}
// Only `Value::Bool` converts; `NULL` and every other variant are errors.
impl FromSpgValue for bool {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Bool(flag) = v {
            return Ok(*flag);
        }
        Err(if matches!(v, Value::Null) {
            "NULL in non-Option bool column"
        } else {
            "non-bool value in bool column"
        })
    }
}
// `Value::Text` is cloned into an owned `String`; everything else is an error.
impl FromSpgValue for String {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Text(text) = v {
            return Ok(text.clone());
        }
        Err(if matches!(v, Value::Null) {
            "NULL in non-Option text column"
        } else {
            "non-text value in String column"
        })
    }
}
// `Value::Vector` is cloned into an owned `Vec<f32>`; everything else is an error.
impl FromSpgValue for Vec<f32> {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Vector(components) = v {
            return Ok(components.clone());
        }
        Err(if matches!(v, Value::Null) {
            "NULL in non-Option vector column"
        } else {
            "non-vector value in Vec<f32> column"
        })
    }
}
// `NULL` becomes `None`; any other value is converted by `T` and wrapped in `Some`.
impl<T: FromSpgValue> FromSpgValue for Option<T> {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if matches!(v, Value::Null) {
            return Ok(None);
        }
        T::from_spg_value(v).map(Some)
    }
}
// Unit tests for the in-memory paths of `Database` and the `FromSpgRow` trait.
#[cfg(test)]
mod tests {
use super::*;
// DDL + two inserts, then a filtered SELECT returns exactly the matching row.
#[test]
fn in_memory_create_insert_select() {
let mut db = Database::open_in_memory();
db.execute("CREATE TABLE t (id INT NOT NULL, name TEXT)")
.unwrap();
db.execute("INSERT INTO t VALUES (1, 'alice')").unwrap();
db.execute("INSERT INTO t VALUES (2, 'bob')").unwrap();
let rows = db.query("SELECT id FROM t WHERE id = 1").unwrap();
assert_eq!(rows.len(), 1);
match &rows[0][0] {
Value::Int(1) => {}
other => panic!("expected Int(1), got {other:?}"),
}
}
// `query()` must reject statements that do not return rows.
#[test]
fn query_on_non_select_errors() {
let mut db = Database::open_in_memory();
db.execute("CREATE TABLE t (id INT)").unwrap();
let r = db.query("INSERT INTO t VALUES (1)");
assert!(r.is_err(), "query() on INSERT must error");
}
// A snapshot restored into a new database exposes the same data.
#[test]
fn snapshot_roundtrip() {
let mut db = Database::open_in_memory();
db.execute("CREATE TABLE t (id INT NOT NULL)").unwrap();
db.execute("INSERT INTO t VALUES (42)").unwrap();
let bytes = db.snapshot();
let mut restored = Database::restore(&bytes).unwrap();
let rows = restored.query("SELECT id FROM t WHERE id = 42").unwrap();
assert_eq!(rows.len(), 1);
match &rows[0][0] {
Value::Int(42) => {}
other => panic!("expected Int(42), got {other:?}"),
}
}
// A hand-written `FromSpgRow` implementation compiles and decodes a row.
#[test]
fn from_spg_row_trait_shape() {
struct User {
_id: i32,
}
impl FromSpgRow for User {
fn from_spg_row(row: &[Value]) -> Result<Self, EngineError> {
match row.first() {
Some(Value::Int(n)) => Ok(Self { _id: *n }),
_ => Err(EngineError::Unsupported("bad id".into())),
}
}
}
let row = vec![Value::Int(7)];
let _u = User::from_spg_row(&row).unwrap();
}
}