#![deny(missing_docs)]
pub use spg_engine::{CatalogSnapshot, Engine, EngineError, ParsedStatement, QueryResult};
pub use spg_storage::{ColumnSchema, DataType, Value};
/// A prepared SQL statement: the parsed form paired with the SQL text it was
/// prepared from.
#[derive(Debug, Clone)]
pub struct Statement {
/// Parsed statement; cloned for every prepared execution.
pub(crate) stmt: ParsedStatement,
/// SQL text that was handed to `prepare`.
pub(crate) sql: String,
}
impl Statement {
    /// Returns the SQL text this statement was prepared from.
    #[must_use]
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }
}
/// Hands `stmt` and `params` to `spg_engine::substitute_placeholders` so the
/// statement can afterwards be rendered as self-contained SQL for the WAL.
///
/// Whatever the substitution call returns is deliberately discarded.
fn wal_render_with_params(stmt: &mut ParsedStatement, params: &[Value]) {
    drop(spg_engine::substitute_placeholders(stmt, params));
}
use std::collections::BTreeMap;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Returns the current wall-clock time as microseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch and
/// saturates at `i64::MAX` if the microsecond count does not fit in an `i64`.
fn wall_clock_micros() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_micros()).unwrap_or(i64::MAX),
        Err(_) => 0,
    }
}
use spg_manifest::{CatalogManifest, ColdSegmentEntry, manifest_path as spg_manifest_path};
/// Bit 31 of a record's leading length word; set on every record that is not
/// in the oldest (bare length + payload) format.
const WAL_V2_SENTINEL: u32 = 0x8000_0000;
/// Bit 30 of the length word; set together with `WAL_V2_SENTINEL` on records
/// that carry a type byte after the CRC (9-byte header).
const WAL_V3_FLAG: u32 = 0x4000_0000;
/// Type byte: auto-committed SQL without LSN/timestamp fields.
const WAL_V3_TYPE_AUTO_COMMIT_SQL: u8 = 0x01;
/// Type byte: durability checkpoint; skipped during replay.
const WAL_V3_TYPE_DURABILITY_CHECKPOINT: u8 = 0x02;
/// Type byte: auto-committed SQL preceded by commit LSN and timestamp.
const WAL_V4_TYPE_AUTO_COMMIT_SQL: u8 = 0x10;
/// Timestamp value that the parser reports as "no commit time" (`None`).
const WAL_V4_NO_CLOCK: i64 = i64::MIN;
/// Extra bytes between the type byte and the SQL in v4 SQL records:
/// 8-byte commit LSN followed by 8-byte commit timestamp.
const WAL_V4_EXTRA_HEADER: usize = 16;
/// Type byte: checkpoint marker (LSN, timestamp, snapshot path).
const WAL_V4_TYPE_CHECKPOINT_MARKER: u8 = 0x11;
/// Type byte: committed transaction script preceded by LSN and timestamp.
const WAL_V4_TYPE_TX_COMMIT_SQL: u8 = 0x12;
/// Reads the auto-checkpoint threshold from `SPG_EMBEDDED_CHECKPOINT_BYTES`.
///
/// Falls back to 4 MiB when the variable is unset, not a valid `u64`, or zero.
fn default_checkpoint_threshold_bytes() -> u64 {
    const FALLBACK_BYTES: u64 = 4 * 1024 * 1024;
    match std::env::var("SPG_EMBEDDED_CHECKPOINT_BYTES") {
        Ok(raw) => match raw.parse::<u64>() {
            Ok(bytes) if bytes > 0 => bytes,
            _ => FALLBACK_BYTES,
        },
        Err(_) => FALLBACK_BYTES,
    }
}
/// Encodes `sql` as a v3 auto-commit WAL record.
///
/// Layout: 4-byte little-endian length word (payload length OR-ed with the
/// v2 sentinel and v3 flag), 4-byte little-endian CRC32 over the type byte
/// plus payload, the type byte, then the SQL bytes.
fn encode_v3_auto_commit(sql: &str) -> Vec<u8> {
    let body = sql.as_bytes();
    let length_word = (body.len() as u32) | WAL_V2_SENTINEL | WAL_V3_FLAG;

    let mut record = Vec::with_capacity(4 + 4 + 1 + body.len());
    record.extend_from_slice(&length_word.to_le_bytes());
    // CRC slot, patched once the checksummed region has been laid out.
    record.extend_from_slice(&[0u8; 4]);
    record.push(WAL_V3_TYPE_AUTO_COMMIT_SQL);
    record.extend_from_slice(body);

    // Everything after the 8-byte length+CRC prefix is what the CRC covers.
    let checksum = spg_crypto::crc32::crc32(&record[8..]);
    record[4..8].copy_from_slice(&checksum.to_le_bytes());
    record
}
/// Shared group-commit state for one WAL file: a buffer of pending record
/// bytes, a condition variable for waiters, and the file handle itself.
#[derive(Debug)]
struct WalGroup {
/// Pending bytes plus sequence and status bookkeeping.
state: Mutex<WalGroupState>,
/// Notified after every flush attempt, successful or not.
cond: std::sync::Condvar,
/// File the buffered bytes are written to; swapped by `rotate_file`.
file: Mutex<File>,
}
/// Mutable bookkeeping protected by `WalGroup::state`.
#[derive(Debug)]
struct WalGroupState {
/// Record bytes enqueued but not yet handed to a flush.
buf: Vec<u8>,
/// Number of records enqueued so far; each enqueue increments it by one.
enqueued_seq: u64,
/// Sequence number up to which records have been written and synced.
flushed_seq: u64,
/// True while one waiter is performing the flush on behalf of the group.
leader_active: bool,
/// Message of a failed flush; once set, waits and flushes return an error.
failed: Option<String>,
/// Bytes accounted to the current file: its initial length plus every
/// successfully flushed batch; reset to 0 when the file is rotated.
written_len: u64,
}
/// Handle for one enqueued WAL record; `wait` blocks until that record has
/// been flushed (when synchronous commit is enabled).
#[derive(Debug)]
pub struct WalTicket {
/// Group the record was enqueued on.
group: Arc<WalGroup>,
/// Sequence number returned by the enqueue.
seq: u64,
}
impl WalGroup {
fn new(file: File, initial_len: u64) -> Self {
Self {
state: Mutex::new(WalGroupState {
buf: Vec::new(),
enqueued_seq: 0,
flushed_seq: 0,
leader_active: false,
failed: None,
written_len: initial_len,
}),
cond: std::sync::Condvar::new(),
file: Mutex::new(file),
}
}
fn enqueue(&self, record: &[u8]) -> u64 {
let mut g = self.state.lock().expect("wal state poisoned");
g.buf.extend_from_slice(record);
g.enqueued_seq += 1;
g.enqueued_seq
}
fn wait_flushed(&self, seq: u64) -> Result<(), EngineError> {
let mut g = self.state.lock().expect("wal state poisoned");
loop {
if let Some(e) = &g.failed {
return Err(EngineError::Storage(spg_storage::StorageError::Corrupt(
format!("WAL poisoned by earlier flush failure: {e}"),
)));
}
if g.flushed_seq >= seq {
return Ok(());
}
if !g.leader_active {
g.leader_active = true;
drop(g);
let delay = commit_delay_us();
if delay > 0 {
std::thread::sleep(std::time::Duration::from_micros(delay));
}
let (batch, flush_to) = {
let mut g2 = self.state.lock().expect("wal state poisoned");
(core::mem::take(&mut g2.buf), g2.enqueued_seq)
};
let io_result: std::io::Result<()> = (|| {
let mut f = self.file.lock().expect("wal file poisoned");
f.write_all(&batch)?;
f.sync_data()
})();
g = self.state.lock().expect("wal state poisoned");
g.leader_active = false;
match io_result {
Ok(()) => {
g.flushed_seq = flush_to;
g.written_len = g.written_len.saturating_add(batch.len() as u64);
}
Err(e) => {
g.failed = Some(e.to_string());
}
}
self.cond.notify_all();
continue;
}
g = self.cond.wait(g).expect("wal condvar poisoned");
}
}
fn flush_now(&self) -> Result<(), EngineError> {
let mut g = self.state.lock().expect("wal state poisoned");
if let Some(e) = &g.failed {
return Err(EngineError::Storage(spg_storage::StorageError::Corrupt(
format!("WAL poisoned: {e}"),
)));
}
let batch = core::mem::take(&mut g.buf);
let flush_to = g.enqueued_seq;
if batch.is_empty() {
return Ok(());
}
drop(g);
let io: std::io::Result<()> = (|| {
let mut f = self.file.lock().expect("wal file poisoned");
f.write_all(&batch)?;
f.sync_data()
})();
let mut g = self.state.lock().expect("wal state poisoned");
match io {
Ok(()) => {
g.flushed_seq = flush_to;
g.written_len = g.written_len.saturating_add(batch.len() as u64);
self.cond.notify_all();
Ok(())
}
Err(e) => {
g.failed = Some(e.to_string());
self.cond.notify_all();
Err(io_err(e))
}
}
}
fn rotate_file(&self, new_file: File) {
let mut g = self.state.lock().expect("wal state poisoned");
let mut f = self.file.lock().expect("wal file poisoned");
*f = new_file;
g.written_len = 0;
}
fn written_len(&self) -> u64 {
let g = self.state.lock().expect("wal state poisoned");
g.written_len + g.buf.len() as u64
}
}
impl WalTicket {
    /// Waits until the record behind this ticket has been flushed.
    ///
    /// Returns immediately with `Ok(())` when synchronous commit is disabled.
    ///
    /// # Errors
    /// Propagates the error from the WAL group when a flush has failed.
    pub fn wait(&self) -> Result<(), EngineError> {
        if synchronous_commit_on() {
            self.group.wait_flushed(self.seq)
        } else {
            Ok(())
        }
    }
}
/// Body of the retention background thread.
///
/// Runs one sweep, then sleeps for `check_interval` in 250 ms slices so a
/// shutdown request is noticed promptly; repeats until `shutdown` is set.
/// Sweep errors are reported on stderr and do not stop the loop.
fn retention_sweep_loop(
    wal_dir: PathBuf,
    retention_hours: u64,
    check_interval: std::time::Duration,
    archive_cmd: Option<String>,
    shutdown: Arc<AtomicBool>,
) {
    let poll_slice = std::time::Duration::from_millis(250);
    loop {
        if shutdown.load(Ordering::SeqCst) {
            break;
        }
        match retention_sweep_once(&wal_dir, retention_hours, archive_cmd.as_deref()) {
            Ok(()) => {}
            Err(e) => eprintln!("spg-embedded: retention sweep error: {e}"),
        }
        let mut waited = std::time::Duration::ZERO;
        while waited < check_interval {
            if shutdown.load(Ordering::SeqCst) {
                return;
            }
            std::thread::sleep(poll_slice);
            waited += poll_slice;
        }
    }
}
/// Performs one retention pass over the `.wal` chunks in `wal_dir`.
///
/// A chunk is eligible when the hexadecimal timestamp prefix of its file name
/// (the part before the first `_`) is older than `retention_hours` before
/// now. For each eligible chunk, a non-empty `archive_cmd` is first run as
/// `sh -c <cmd> -- <chunk>`; if it exits unsuccessfully the chunk is kept.
/// Otherwise the chunk and its `<name>.checksum` sidecar are removed.
///
/// A missing `wal_dir` is not an error. Removal failures are logged to stderr
/// and skipped.
///
/// NOTE(review): a `.wal` file whose name has no parseable timestamp prefix
/// is treated as timestamp 0 and is therefore always eligible.
///
/// # Errors
/// Returns the I/O error from listing the directory or from spawning the
/// archive command.
pub fn retention_sweep_once(
    wal_dir: &Path,
    retention_hours: u64,
    archive_cmd: Option<&str>,
) -> std::io::Result<()> {
    if !wal_dir.exists() {
        return Ok(());
    }
    let now_us = wall_clock_micros();
    // Clamp instead of casting: for a very large `retention_hours` the i128
    // difference falls below `i64::MIN`, and a wrapping `as i64` could turn it
    // into a huge positive cutoff that makes every chunk look expired.
    let window_us = i128::from(retention_hours) * 3_600 * 1_000_000;
    let cutoff_us = i64::try_from(i128::from(now_us) - window_us).unwrap_or(i64::MIN);
    let chunks = sorted_wal_chunks(wal_dir)?;
    for chunk in chunks {
        let stem = match chunk.file_stem().and_then(|s| s.to_str()) {
            Some(s) => s,
            None => continue,
        };
        let chunk_us: i64 = stem
            .split_once('_')
            .and_then(|(prefix, _)| i64::from_str_radix(prefix, 16).ok())
            .unwrap_or(0);
        if chunk_us >= cutoff_us {
            continue;
        }
        if let Some(cmd) = archive_cmd.filter(|c| !c.is_empty()) {
            // With `sh -c CMD -- PATH`, "--" becomes $0 and the chunk path $1.
            let output = std::process::Command::new("sh")
                .arg("-c")
                .arg(cmd)
                .arg("--")
                .arg(&chunk)
                .output()?;
            if !output.status.success() {
                eprintln!(
                    "spg-embedded: SPG_PITR_ARCHIVE_CMD failed for {} (exit {}); chunk stays on disk",
                    chunk.display(),
                    output.status.code().unwrap_or(-1)
                );
                continue;
            }
        }
        if let Err(e) = std::fs::remove_file(&chunk) {
            eprintln!(
                "spg-embedded: retention remove {} failed: {e}",
                chunk.display()
            );
            continue;
        }
        // Best-effort removal of the sidecar; it may not exist.
        let mut cs = chunk.clone();
        let mut name = cs.file_name().map(|n| n.to_os_string()).unwrap_or_default();
        name.push(".checksum");
        cs.set_file_name(name);
        let _ = std::fs::remove_file(&cs);
    }
    Ok(())
}
/// Group-commit delay in microseconds, read once from `SPG_COMMIT_DELAY_US`
/// and cached for the life of the process. Defaults to 150 when the variable
/// is unset or not a valid `u64`.
fn commit_delay_us() -> u64 {
    static CACHED: std::sync::OnceLock<u64> = std::sync::OnceLock::new();
    let read_env = || match std::env::var("SPG_COMMIT_DELAY_US") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(150),
        Err(_) => 150,
    };
    *CACHED.get_or_init(read_env)
}
/// Whether commits wait for the WAL flush. Read once from
/// `SPG_SYNCHRONOUS_COMMIT` and cached; only `off`, `false`
/// (case-insensitive) or `0` turn it off. Unset or any other value means on.
fn synchronous_commit_on() -> bool {
    static CACHED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *CACHED.get_or_init(|| {
        let disabled = match std::env::var("SPG_SYNCHRONOUS_COMMIT") {
            Ok(v) => {
                v.eq_ignore_ascii_case("off") || v == "0" || v.eq_ignore_ascii_case("false")
            }
            Err(_) => false,
        };
        !disabled
    })
}
/// Interval of the background WAL flusher in milliseconds, read once from
/// `SPG_WAL_WRITER_DELAY_MS` and cached. Defaults to 200 when the variable is
/// unset, unparsable, or zero.
fn wal_writer_delay_ms() -> u64 {
    static CACHED: std::sync::OnceLock<u64> = std::sync::OnceLock::new();
    let read_env = || match std::env::var("SPG_WAL_WRITER_DELAY_MS") {
        Ok(raw) => match raw.parse::<u64>() {
            Ok(ms) if ms > 0 => ms,
            _ => 200,
        },
        Err(_) => 200,
    };
    *CACHED.get_or_init(read_env)
}
/// Retention window in hours from `SPG_PITR_RETENTION_HOURS`; `0` when the
/// variable is unset or not a valid `u64`.
fn pitr_retention_hours() -> u64 {
    match std::env::var("SPG_PITR_RETENTION_HOURS") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(0),
        Err(_) => 0,
    }
}
/// Seconds between retention sweeps from `SPG_PITR_RETENTION_CHECK_SEC`;
/// 60 when the variable is unset, unparsable, or zero.
fn pitr_retention_check_sec() -> u64 {
    const FALLBACK_SECS: u64 = 60;
    match std::env::var("SPG_PITR_RETENTION_CHECK_SEC") {
        Ok(raw) => match raw.parse::<u64>() {
            Ok(secs) if secs > 0 => secs,
            _ => FALLBACK_SECS,
        },
        Err(_) => FALLBACK_SECS,
    }
}
/// Archive command from `SPG_PITR_ARCHIVE_CMD`; `None` when the variable is
/// unset, not valid Unicode, or empty.
fn pitr_archive_cmd() -> Option<String> {
    match std::env::var("SPG_PITR_ARCHIVE_CMD") {
        Ok(cmd) if !cmd.is_empty() => Some(cmd),
        _ => None,
    }
}
fn replay_wal_filtered(
wal_bytes: &[u8],
engine: &mut Engine,
floor_lsn: u64,
) -> Result<usize, String> {
let records = parse_wal_records(wal_bytes)?;
let mut applied = 0usize;
for r in &records {
if r.type_byte == WAL_V3_TYPE_DURABILITY_CHECKPOINT
|| r.type_byte == WAL_V4_TYPE_CHECKPOINT_MARKER
{
continue;
}
if r.type_byte == WAL_V4_TYPE_AUTO_COMMIT_SQL || r.type_byte == WAL_V4_TYPE_TX_COMMIT_SQL {
if let Some(lsn) = r.commit_lsn {
if lsn <= floor_lsn {
continue;
}
}
}
let sql = match std::str::from_utf8(r.sql) {
Ok(s) => s,
Err(e) => return Err(format!("non-UTF-8 SQL at offset {}: {e}", r.offset)),
};
for stmt in split_statements(sql) {
engine.execute(stmt).map_err(|e| {
format!(
"WAL replay: apply {stmt:?} at offset {} rejected: {e:?}",
r.offset
)
})?;
}
applied += 1;
}
Ok(applied)
}
/// Builds a WAL chunk file name of the form
/// `<timestamp:016x>_<leading_lsn:016x>.wal`.
///
/// Negative timestamps are clamped to zero.
fn chunk_filename(unix_us: i64, leading_lsn: u64) -> String {
    let clamped_us = u64::try_from(unix_us).unwrap_or(0);
    format!("{:016x}_{:016x}.wal", clamped_us, leading_lsn)
}
/// File name used for a migrated single-file WAL: timestamp 0 and LSN 0.
fn legacy_chunk_filename() -> String {
    let (unix_us, leading_lsn) = (0_i64, 0_u64);
    chunk_filename(unix_us, leading_lsn)
}
/// Lists the paths in `wal_dir` whose extension is `wal`, sorted ascending.
///
/// A missing directory yields an empty list; any other I/O error is returned.
fn sorted_wal_chunks(wal_dir: &Path) -> std::io::Result<Vec<PathBuf>> {
    let entries = match std::fs::read_dir(wal_dir) {
        Ok(entries) => entries,
        Err(err) => {
            if err.kind() == std::io::ErrorKind::NotFound {
                return Ok(Vec::new());
            }
            return Err(err);
        }
    };
    let mut chunks: Vec<PathBuf> = Vec::new();
    for entry in entries {
        let candidate = entry?.path();
        let is_wal = matches!(
            candidate.extension().and_then(|ext| ext.to_str()),
            Some("wal")
        );
        if is_wal {
            chunks.push(candidate);
        }
    }
    chunks.sort();
    Ok(chunks)
}
/// Encodes a v4 checkpoint-marker WAL record.
///
/// Payload: 8-byte LSN, 8-byte timestamp, 2-byte path length, then the
/// lossily UTF-8-converted snapshot path truncated to `u16::MAX` bytes (all
/// integers little-endian). The payload is framed like other typed records:
/// length word, CRC32 over type byte plus payload, type byte, payload.
fn encode_v4_checkpoint_marker(
    checkpoint_lsn: u64,
    checkpoint_unix_us: i64,
    snapshot_path: &Path,
) -> Vec<u8> {
    let path_text = snapshot_path.to_string_lossy();
    let path_bytes = path_text.as_bytes();
    let stored_len = u16::try_from(path_bytes.len()).unwrap_or(u16::MAX);
    let body_len = 8 + 8 + 2 + usize::from(stored_len);
    let length_word = (body_len as u32) | WAL_V2_SENTINEL | WAL_V3_FLAG;

    let mut record = Vec::with_capacity(4 + 4 + 1 + body_len);
    record.extend_from_slice(&length_word.to_le_bytes());
    // CRC slot, patched below.
    record.extend_from_slice(&[0u8; 4]);
    record.push(WAL_V4_TYPE_CHECKPOINT_MARKER);
    record.extend_from_slice(&checkpoint_lsn.to_le_bytes());
    record.extend_from_slice(&checkpoint_unix_us.to_le_bytes());
    record.extend_from_slice(&stored_len.to_le_bytes());
    record.extend_from_slice(&path_bytes[..usize::from(stored_len)]);

    // The CRC covers the type byte and the payload, i.e. everything past
    // the 8-byte length+CRC prefix.
    let checksum = spg_crypto::crc32::crc32(&record[8..]);
    record[4..8].copy_from_slice(&checksum.to_le_bytes());
    record
}
/// Encodes a single auto-committed statement as a v4 SQL record.
fn encode_v4_auto_commit(sql: &str, commit_lsn: u64, commit_unix_us: i64) -> Vec<u8> {
    let record_type = WAL_V4_TYPE_AUTO_COMMIT_SQL;
    encode_v4_sql_record(record_type, sql, commit_lsn, commit_unix_us)
}
/// Encodes the statement script of a committed transaction as a v4 SQL record.
fn encode_v4_tx_commit(script: &str, commit_lsn: u64, commit_unix_us: i64) -> Vec<u8> {
    let record_type = WAL_V4_TYPE_TX_COMMIT_SQL;
    encode_v4_sql_record(record_type, script, commit_lsn, commit_unix_us)
}
/// Encodes a v4 SQL record of the given type.
///
/// Layout: length word (SQL length OR-ed with the v2/v3 flag bits; the
/// 16-byte LSN/timestamp header is not counted), CRC32, type byte, 8-byte
/// commit LSN, 8-byte commit timestamp, SQL bytes. All integers are
/// little-endian; the CRC covers everything after the CRC field.
fn encode_v4_sql_record(type_byte: u8, sql: &str, commit_lsn: u64, commit_unix_us: i64) -> Vec<u8> {
    let body = sql.as_bytes();
    let length_word = (body.len() as u32) | WAL_V2_SENTINEL | WAL_V3_FLAG;

    let mut record = Vec::with_capacity(4 + 4 + 1 + WAL_V4_EXTRA_HEADER + body.len());
    record.extend_from_slice(&length_word.to_le_bytes());
    // CRC slot, patched below.
    record.extend_from_slice(&[0u8; 4]);
    record.push(type_byte);
    record.extend_from_slice(&commit_lsn.to_le_bytes());
    record.extend_from_slice(&commit_unix_us.to_le_bytes());
    record.extend_from_slice(body);

    let checksum = spg_crypto::crc32::crc32(&record[8..]);
    record[4..8].copy_from_slice(&checksum.to_le_bytes());
    record
}
/// Replays every SQL record in `wal_bytes` into `engine`, with no LSN filter.
///
/// Handles all record generations: bare length + SQL, length + CRC + SQL, and
/// typed records (9-byte header). Checkpoint records are skipped. A truncated
/// tail ends the replay without an error. CRC fields are not verified here.
///
/// Returns the number of SQL records applied, or a message for the first
/// unknown type byte, non-UTF-8 payload, or rejected statement.
fn replay_wal_into_engine(wal_bytes: &[u8], engine: &mut Engine) -> Result<usize, String> {
let mut applied = 0usize;
let mut cur = 0usize;
while cur < wal_bytes.len() {
// Not even a full length word left: treat as a torn tail.
if wal_bytes.len() - cur < 4 {
break;
}
let raw_len = u32::from_le_bytes(wal_bytes[cur..cur + 4].try_into().unwrap());
// The top two bits of the length word select the record generation.
let is_v2 = raw_len & WAL_V2_SENTINEL != 0;
let is_v3 = is_v2 && (raw_len & WAL_V3_FLAG != 0);
let len_mask = if is_v3 {
!(WAL_V2_SENTINEL | WAL_V3_FLAG)
} else {
!WAL_V2_SENTINEL
};
let rec_len = (raw_len & len_mask) as usize;
// 4 = length only; 8 = length + CRC; 9 = length + CRC + type byte.
let header_len = if is_v3 {
9
} else if is_v2 {
8
} else {
4
};
if wal_bytes.len() - cur < header_len + rec_len {
break;
}
if is_v3 {
let type_byte = wal_bytes[cur + 8];
match type_byte {
// Falls through to the generic SQL handling after this `if`.
WAL_V3_TYPE_AUTO_COMMIT_SQL => {}
WAL_V3_TYPE_DURABILITY_CHECKPOINT => {
cur += header_len + rec_len;
continue;
}
WAL_V4_TYPE_CHECKPOINT_MARKER => {
cur += header_len + rec_len;
continue;
}
WAL_V4_TYPE_AUTO_COMMIT_SQL | WAL_V4_TYPE_TX_COMMIT_SQL => {
// `rec_len` counts only the SQL; the LSN/timestamp header is extra.
let v4_total = header_len + WAL_V4_EXTRA_HEADER + rec_len;
if wal_bytes.len() - cur < v4_total {
break;
}
let sql_start = cur + header_len + WAL_V4_EXTRA_HEADER;
let sql_bytes = &wal_bytes[sql_start..sql_start + rec_len];
let sql = std::str::from_utf8(sql_bytes)
.map_err(|e| format!("WAL replay: non-UTF-8 SQL at offset {cur}: {e}"))?;
for stmt in split_statements(sql) {
engine.execute(stmt).map_err(|e| {
format!("WAL replay: apply {stmt:?} at offset {cur} rejected: {e:?}")
})?;
}
applied += 1;
cur += v4_total;
continue;
}
other => {
return Err(format!(
"WAL replay: unknown v3 type byte {other:#04x} at offset {cur}"
));
}
}
}
// Pre-v4 SQL record: the payload is executed as one piece, not split.
let sql_bytes = &wal_bytes[cur + header_len..cur + header_len + rec_len];
let sql = std::str::from_utf8(sql_bytes)
.map_err(|e| format!("WAL replay: non-UTF-8 SQL at offset {cur}: {e}"))?;
engine
.execute(sql)
.map_err(|e| format!("WAL replay: apply {sql:?} at offset {cur} rejected: {e:?}"))?;
applied += 1;
cur += header_len + rec_len;
}
Ok(applied)
}
/// One record decoded from a WAL byte buffer; borrows its payload from it.
#[derive(Debug, Clone)]
pub struct WalRecord<'a> {
/// Byte offset of the record's length word within the parsed buffer.
pub offset: usize,
/// Record type byte; untyped (pre-v3) records are reported with the v3
/// auto-commit type.
pub type_byte: u8,
/// Commit or checkpoint LSN; `None` for records that do not carry one.
pub commit_lsn: Option<u64>,
/// Timestamp in microseconds; `None` when the record carries none or
/// stores the "no clock" sentinel.
pub commit_unix_us: Option<i64>,
/// Payload bytes: SQL text for SQL records, the snapshot path for
/// checkpoint markers, empty for durability checkpoints.
pub sql: &'a [u8],
}
/// Decodes `wal_bytes` into a list of [`WalRecord`]s without executing them.
///
/// All record generations are understood. A truncated record at the end of
/// the buffer ends parsing silently, and the records decoded so far are
/// returned. CRC fields are not verified here.
///
/// # Errors
/// Returns a message for an unknown type byte or a checkpoint marker whose
/// payload is shorter than its fixed fields or its declared path length.
pub fn parse_wal_records(wal_bytes: &[u8]) -> Result<Vec<WalRecord<'_>>, String> {
let mut out = Vec::new();
let mut cur = 0usize;
while cur < wal_bytes.len() {
// Not even a full length word left: treat as a torn tail.
if wal_bytes.len() - cur < 4 {
break;
}
let raw_len = u32::from_le_bytes(wal_bytes[cur..cur + 4].try_into().unwrap());
// The top two bits of the length word select the record generation.
let is_v2 = raw_len & WAL_V2_SENTINEL != 0;
let is_v3 = is_v2 && (raw_len & WAL_V3_FLAG != 0);
let len_mask = if is_v3 {
!(WAL_V2_SENTINEL | WAL_V3_FLAG)
} else {
!WAL_V2_SENTINEL
};
let rec_len = (raw_len & len_mask) as usize;
// 4 = length only; 8 = length + CRC; 9 = length + CRC + type byte.
let header_len = if is_v3 {
9
} else if is_v2 {
8
} else {
4
};
if wal_bytes.len() - cur < header_len + rec_len {
break;
}
if !is_v3 {
// Untyped record: report it as a v3 auto-commit SQL record.
let sql = &wal_bytes[cur + header_len..cur + header_len + rec_len];
out.push(WalRecord {
offset: cur,
type_byte: WAL_V3_TYPE_AUTO_COMMIT_SQL,
commit_lsn: None,
commit_unix_us: None,
sql,
});
cur += header_len + rec_len;
continue;
}
let type_byte = wal_bytes[cur + 8];
match type_byte {
WAL_V3_TYPE_AUTO_COMMIT_SQL => {
let sql = &wal_bytes[cur + header_len..cur + header_len + rec_len];
out.push(WalRecord {
offset: cur,
type_byte,
commit_lsn: None,
commit_unix_us: None,
sql,
});
cur += header_len + rec_len;
}
WAL_V3_TYPE_DURABILITY_CHECKPOINT => {
// Payload (if any) is not exposed.
out.push(WalRecord {
offset: cur,
type_byte,
commit_lsn: None,
commit_unix_us: None,
sql: &[],
});
cur += header_len + rec_len;
}
WAL_V4_TYPE_CHECKPOINT_MARKER => {
// Fixed fields: 8-byte LSN + 8-byte timestamp + 2-byte path length.
if rec_len < 18 {
return Err(format!(
"WAL parse: checkpoint marker at offset {cur} too short ({rec_len} bytes)"
));
}
let lsn = u64::from_le_bytes(
wal_bytes[cur + header_len..cur + header_len + 8]
.try_into()
.unwrap(),
);
let ts_raw = i64::from_le_bytes(
wal_bytes[cur + header_len + 8..cur + header_len + 16]
.try_into()
.unwrap(),
);
let path_len = u16::from_le_bytes(
wal_bytes[cur + header_len + 16..cur + header_len + 18]
.try_into()
.unwrap(),
) as usize;
if rec_len < 18 + path_len {
return Err(format!(
"WAL parse: checkpoint marker at offset {cur} truncated path"
));
}
let path_start = cur + header_len + 18;
let path_bytes = &wal_bytes[path_start..path_start + path_len];
let commit_unix_us = if ts_raw == WAL_V4_NO_CLOCK {
None
} else {
Some(ts_raw)
};
// The snapshot path travels in the `sql` field for this record type.
out.push(WalRecord {
offset: cur,
type_byte,
commit_lsn: Some(lsn),
commit_unix_us,
sql: path_bytes,
});
cur += header_len + rec_len;
}
WAL_V4_TYPE_AUTO_COMMIT_SQL | WAL_V4_TYPE_TX_COMMIT_SQL => {
// `rec_len` counts only the SQL; the LSN/timestamp header is extra.
let v4_total = header_len + WAL_V4_EXTRA_HEADER + rec_len;
if wal_bytes.len() - cur < v4_total {
break;
}
let lsn = u64::from_le_bytes(
wal_bytes[cur + header_len..cur + header_len + 8]
.try_into()
.unwrap(),
);
let ts_raw = i64::from_le_bytes(
wal_bytes[cur + header_len + 8..cur + header_len + 16]
.try_into()
.unwrap(),
);
let commit_unix_us = if ts_raw == WAL_V4_NO_CLOCK {
None
} else {
Some(ts_raw)
};
let sql_start = cur + header_len + WAL_V4_EXTRA_HEADER;
let sql = &wal_bytes[sql_start..sql_start + rec_len];
out.push(WalRecord {
offset: cur,
type_byte,
commit_lsn: Some(lsn),
commit_unix_us,
sql,
});
cur += v4_total;
}
other => {
return Err(format!(
"WAL parse: unknown type byte {other:#04x} at offset {cur}"
));
}
}
}
Ok(out)
}
/// Returns `true` when the leading keyword of `sql` marks a statement that is
/// never written to the WAL as data.
///
/// The keyword is the text before the first whitespace, `;`, or `(` after
/// leading whitespace, compared case-insensitively.
///
/// NOTE(review): `with` is on the list, so any statement starting with a CTE
/// is classified as read-only — confirm the engine has no data-modifying CTEs.
fn sql_is_read_only(sql: &str) -> bool {
    const NON_LOGGED_KEYWORDS: [&str; 10] = [
        "select",
        "show",
        "explain",
        "begin",
        "commit",
        "rollback",
        "checkpoint",
        "compact",
        "wait",
        "with",
    ];
    let trimmed = sql.trim_start();
    let keyword_end = trimmed
        .find(|c: char| c.is_whitespace() || c == ';' || c == '(')
        .unwrap_or(trimmed.len());
    let keyword = trimmed[..keyword_end].to_ascii_lowercase();
    NON_LOGGED_KEYWORDS.contains(&keyword.as_str())
}
/// An embedded database: an engine plus optional on-disk persistence
/// (snapshot file and write-ahead log).
#[derive(Debug)]
pub struct Database {
/// Engine that executes all SQL.
engine: Engine,
/// On-disk state; `None` for in-memory and snapshot-restored databases.
persistence: Option<PersistenceCtx>,
/// Last commit LSN handed out; incremented for every WAL SQL record.
commit_lsn: AtomicU64,
/// Statements collected since BEGIN; `None` outside an explicit transaction.
tx_wal: Option<TxWalBuffer>,
}
/// Statements buffered for the WAL while an explicit transaction is open.
#[derive(Debug, Default)]
struct TxWalBuffer {
/// Non-read-only statements executed so far, in order.
statements: Vec<String>,
/// Savepoint name paired with `statements.len()` at the time it was set.
savepoints: Vec<(String, usize)>,
}
/// Transaction-control statement kinds recognised from the leading keyword(s).
enum TxControl {
/// `BEGIN` / `START ...`.
Begin,
/// `COMMIT` / `END`.
Commit,
/// `ROLLBACK` without a `TO` clause.
Rollback,
/// `ROLLBACK TO [SAVEPOINT] <name>`; the name is lower-cased.
RollbackToSavepoint(String),
/// `SAVEPOINT <name>`; the name is lower-cased.
Savepoint(String),
/// `RELEASE ...`.
ReleaseSavepoint,
}
/// Classifies `sql` as a transaction-control statement, if it is one.
///
/// The text is split on whitespace and `;`, lower-cased, and matched on its
/// first word. Returns `None` for any other statement, and also for
/// `SAVEPOINT` or `ROLLBACK TO [SAVEPOINT]` that lack a name.
fn tx_control_kind(sql: &str) -> Option<TxControl> {
    let mut tokens = sql
        .split(|c: char| c.is_whitespace() || c == ';')
        .filter(|tok| !tok.is_empty())
        .map(str::to_ascii_lowercase);
    let keyword = tokens.next()?;
    let kind = match keyword.as_str() {
        "begin" | "start" => TxControl::Begin,
        "commit" | "end" => TxControl::Commit,
        "release" => TxControl::ReleaseSavepoint,
        "savepoint" => TxControl::Savepoint(tokens.next()?),
        "rollback" => {
            if tokens.next().as_deref() == Some("to") {
                // The SAVEPOINT keyword is optional before the name.
                let first = tokens.next()?;
                let name = if first == "savepoint" {
                    tokens.next()?
                } else {
                    first
                };
                TxControl::RollbackToSavepoint(name)
            } else {
                TxControl::Rollback
            }
        }
        _ => return None,
    };
    Some(kind)
}
/// Everything a path-backed database needs besides the engine: file
/// locations, the WAL group, and handles of the background threads.
#[derive(Debug)]
#[allow(dead_code)] struct PersistenceCtx {
/// Snapshot file written by `checkpoint`.
db_path: PathBuf,
/// Directory holding the WAL chunk files (`<db file name>.wal`).
wal_dir: PathBuf,
/// Chunk file currently being appended to.
current_chunk_path: PathBuf,
/// Stop flag for the retention thread; `None` when retention is disabled.
retention_shutdown: Option<Arc<AtomicBool>>,
/// Join handle of the retention thread, if one was started.
retention_thread: Option<std::thread::JoinHandle<()>>,
/// Stop flag for the background flusher; `None` with synchronous commit.
flusher_shutdown: Option<Arc<AtomicBool>>,
/// Join handle of the background flusher, if one was started.
flusher_thread: Option<std::thread::JoinHandle<()>>,
/// Group-commit writer for the current chunk.
wal: Arc<WalGroup>,
/// WAL size at which a commit triggers an automatic checkpoint.
checkpoint_threshold_bytes: u64,
/// Directory where frozen cold segments are written.
cold_segments_dir: PathBuf,
/// Segment id → file path of every known cold segment.
cold_segment_paths: BTreeMap<u32, PathBuf>,
/// Path handed to `acquire_path_lock` when the database was opened.
lock_path: PathBuf,
}
impl Database {
/// Creates an empty database with no on-disk persistence.
#[must_use]
pub fn open_in_memory() -> Self {
    let engine = Engine::new().with_clock(wall_clock_micros);
    Self {
        engine,
        persistence: None,
        commit_lsn: AtomicU64::new(0),
        tx_wal: None,
    }
}
pub fn open_path(db_path: impl AsRef<Path>) -> Result<Self, EngineError> {
let db_path = db_path.as_ref().to_path_buf();
let wal_path = {
let mut p = db_path.clone();
let name = p
.file_name()
.map(|n| {
let mut s = n.to_os_string();
s.push(".wal");
s
})
.unwrap_or_else(|| std::ffi::OsString::from(".wal"));
p.set_file_name(name);
p
};
let wal_dir = wal_path.clone();
if let Some(parent) = db_path.parent()
&& !parent.as_os_str().is_empty()
{
std::fs::create_dir_all(parent).map_err(io_err)?;
}
let lock_path = {
let mut p = db_path.clone();
let name = p
.file_name()
.map(|n| {
let mut s = n.to_os_string();
s.push(".lock");
s
})
.unwrap_or_else(|| std::ffi::OsString::from(".lock"));
p.set_file_name(name);
p
};
acquire_path_lock(&lock_path)?;
let mut engine = if db_path.exists() {
let bytes = std::fs::read(&db_path).map_err(io_err)?;
let engine = Engine::restore_envelope(&bytes).map_err(|e| {
EngineError::Storage(spg_storage::StorageError::Corrupt(format!(
"restore from {}: {e}",
db_path.display()
)))
})?;
engine.with_clock(wall_clock_micros)
} else {
Engine::new().with_clock(wall_clock_micros)
};
let cold_segments_dir = {
let parent = db_path.parent().unwrap_or_else(|| Path::new("."));
let stem = db_path
.file_stem()
.unwrap_or_else(|| std::ffi::OsStr::new("db"))
.to_string_lossy()
.into_owned();
parent.join(format!("{stem}.spg")).join("segments")
};
let mut cold_segment_paths: BTreeMap<u32, PathBuf> = BTreeMap::new();
let manifest_pth = spg_manifest_path(&db_path);
if manifest_pth.exists() && db_path.exists() {
let m_bytes = std::fs::read(&manifest_pth).map_err(io_err)?;
if let Ok(m) = CatalogManifest::deserialize(&m_bytes) {
let snap_bytes = std::fs::read(&db_path).map_err(io_err)?;
let snap_crc = spg_crypto::crc32::crc32(&snap_bytes);
if snap_crc == m.catalog_crc32 {
for entry in &m.cold_segments {
if let Ok(seg_bytes) = std::fs::read(&entry.path) {
let computed = spg_crypto::crc32::crc32(&seg_bytes);
if computed != entry.crc32 {
eprintln!(
"spg-embedded: manifest skip segment {}: CRC mismatch",
entry.segment_id
);
continue;
}
if engine.catalog().cold_segment(entry.segment_id).is_some() {
continue;
}
let mut new_cat = engine.catalog().clone();
if let Err(e) =
new_cat.load_segment_bytes_at(entry.segment_id, seg_bytes)
{
eprintln!(
"spg-embedded: manifest load segment {} failed: {e}",
entry.segment_id
);
continue;
}
engine.replace_catalog(new_cat);
cold_segment_paths.insert(entry.segment_id, entry.path.clone());
} else {
eprintln!(
"spg-embedded: manifest skip segment {}: file unreadable",
entry.segment_id
);
}
}
}
}
}
let mut initial_lsn: u64 = 0;
if wal_path.is_file() {
let legacy_bytes = std::fs::read(&wal_path).map_err(io_err)?;
std::fs::remove_file(&wal_path).map_err(io_err)?;
std::fs::create_dir_all(&wal_dir).map_err(io_err)?;
if !legacy_bytes.is_empty() {
let migrated = wal_dir.join(legacy_chunk_filename());
std::fs::write(&migrated, &legacy_bytes).map_err(io_err)?;
}
} else if !wal_dir.exists() {
std::fs::create_dir_all(&wal_dir).map_err(io_err)?;
}
let chunk_paths = sorted_wal_chunks(&wal_dir).map_err(io_err)?;
let mut snapshot_lsn: u64 = 0;
for chunk in &chunk_paths {
let bytes = std::fs::read(chunk).map_err(io_err)?;
if let Ok(records) = parse_wal_records(&bytes) {
for r in &records {
if r.type_byte == WAL_V4_TYPE_CHECKPOINT_MARKER {
if let Some(l) = r.commit_lsn {
if l > snapshot_lsn {
snapshot_lsn = l;
}
}
}
}
}
}
for chunk in &chunk_paths {
let bytes = std::fs::read(chunk).map_err(io_err)?;
if bytes.is_empty() {
continue;
}
replay_wal_filtered(&bytes, &mut engine, snapshot_lsn)
.map_err(|m| EngineError::Storage(spg_storage::StorageError::Corrupt(m)))?;
if let Ok(records) = parse_wal_records(&bytes) {
if let Some(max) = records.iter().filter_map(|r| r.commit_lsn).max() {
if max > initial_lsn {
initial_lsn = max;
}
}
}
}
let now_us = wall_clock_micros();
let current_chunk_path = if let Some(last) = chunk_paths.last() {
last.clone()
} else {
wal_dir.join(chunk_filename(now_us, initial_lsn + 1))
};
let wal_file = OpenOptions::new()
.create(true)
.append(true)
.read(true)
.open(¤t_chunk_path)
.map_err(io_err)?;
let wal_len = wal_file.metadata().map_err(io_err)?.len();
let wal = Arc::new(WalGroup::new(wal_file, wal_len));
let retention_hours = pitr_retention_hours();
let (retention_shutdown, retention_thread) = if retention_hours > 0 {
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = Arc::clone(&shutdown);
let wal_dir_clone = wal_dir.clone();
let check_interval = std::time::Duration::from_secs(pitr_retention_check_sec());
let archive_cmd = pitr_archive_cmd();
let handle = std::thread::Builder::new()
.name("spg-pitr-retention".into())
.spawn(move || {
retention_sweep_loop(
wal_dir_clone,
retention_hours,
check_interval,
archive_cmd,
shutdown_clone,
);
})
.map_err(io_err)?;
(Some(shutdown), Some(handle))
} else {
(None, None)
};
let (flusher_shutdown, flusher_thread) = if synchronous_commit_on() {
(None, None)
} else {
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = Arc::clone(&shutdown);
let group = Arc::clone(&wal);
let interval = std::time::Duration::from_millis(wal_writer_delay_ms());
let handle = std::thread::Builder::new()
.name("spg-wal-flusher".into())
.spawn(move || {
while !shutdown_clone.load(Ordering::SeqCst) {
std::thread::sleep(interval);
if let Err(e) = group.flush_now() {
eprintln!("spg-embedded: background WAL flush failed: {e:?}");
}
}
let _ = group.flush_now();
})
.map_err(io_err)?;
(Some(shutdown), Some(handle))
};
Ok(Self {
engine,
commit_lsn: AtomicU64::new(initial_lsn),
tx_wal: None,
persistence: Some(PersistenceCtx {
db_path,
wal_dir,
current_chunk_path,
wal,
checkpoint_threshold_bytes: default_checkpoint_threshold_bytes(),
cold_segments_dir,
cold_segment_paths,
lock_path,
retention_shutdown,
retention_thread,
flusher_shutdown,
flusher_thread,
}),
})
}
pub fn freeze_oldest_to_cold(
&mut self,
table_name: &str,
index_name: &str,
max_rows: usize,
) -> Result<spg_storage::FreezeReport, EngineError> {
let report = self
.engine
.freeze_oldest_to_cold(table_name, index_name, max_rows)?;
if let Some(p) = &mut self.persistence {
std::fs::create_dir_all(&p.cold_segments_dir).map_err(io_err)?;
let final_path = p
.cold_segments_dir
.join(format!("seg_{}.spg", report.segment_id));
let tmp_path = p
.cold_segments_dir
.join(format!("seg_{}.spg.tmp", report.segment_id));
std::fs::write(&tmp_path, &report.segment_bytes).map_err(io_err)?;
std::fs::rename(&tmp_path, &final_path).map_err(io_err)?;
p.cold_segment_paths.insert(report.segment_id, final_path);
}
Ok(report)
}
/// Sets the WAL size (in bytes) at which a commit triggers an automatic
/// checkpoint. Values below 1 are raised to 1. No effect on databases
/// without persistence.
pub fn set_checkpoint_threshold_bytes(&mut self, bytes: u64) {
    let Some(ctx) = self.persistence.as_mut() else {
        return;
    };
    ctx.checkpoint_threshold_bytes = std::cmp::max(bytes, 1);
}
pub fn checkpoint(&mut self) -> Result<(), EngineError> {
let snapshot = self.engine.snapshot();
let Some(p) = &mut self.persistence else {
return Ok(());
};
let tmp = {
let mut t = p.db_path.clone();
let mut name = t
.file_name()
.map(std::ffi::OsStr::to_os_string)
.unwrap_or_default();
name.push(".tmp");
t.set_file_name(name);
t
};
std::fs::write(&tmp, &snapshot).map_err(io_err)?;
std::fs::rename(&tmp, &p.db_path).map_err(io_err)?;
if !p.cold_segment_paths.is_empty() {
let snap_crc = spg_crypto::crc32::crc32(&snapshot);
let entries: Vec<ColdSegmentEntry> = p
.cold_segment_paths
.iter()
.filter_map(|(&segment_id, path)| {
let bytes = std::fs::read(path).ok()?;
Some(ColdSegmentEntry {
segment_id,
path: path.clone(),
crc32: spg_crypto::crc32::crc32(&bytes),
})
})
.collect();
let manifest = CatalogManifest {
catalog_crc32: snap_crc,
cold_segments: entries,
wal_baseline_offset: 0,
};
let m_bytes = manifest.serialize();
let m_path = spg_manifest_path(&p.db_path);
if let Some(dir) = m_path.parent() {
std::fs::create_dir_all(dir).map_err(io_err)?;
}
let m_tmp = {
let mut t = m_path.clone();
let mut name = t
.file_name()
.map(std::ffi::OsStr::to_os_string)
.unwrap_or_default();
name.push(".tmp");
t.set_file_name(name);
t
};
std::fs::write(&m_tmp, &m_bytes).map_err(io_err)?;
std::fs::rename(&m_tmp, &m_path).map_err(io_err)?;
}
let marker_lsn = self.commit_lsn.load(Ordering::SeqCst);
let marker_ts = wall_clock_micros();
let marker = encode_v4_checkpoint_marker(marker_lsn, marker_ts, &p.db_path);
p.wal.enqueue(&marker);
p.wal.flush_now()?;
let new_chunk_path = p.wal_dir.join(chunk_filename(marker_ts, marker_lsn + 1));
let new_handle = OpenOptions::new()
.create(true)
.append(true)
.read(true)
.open(&new_chunk_path)
.map_err(io_err)?;
p.current_chunk_path = new_chunk_path;
p.wal.rotate_file(new_handle);
Ok(())
}
/// Builds a database without persistence from snapshot bytes.
///
/// # Errors
/// Returns a storage-corruption error when the snapshot cannot be restored.
pub fn restore(snapshot: &[u8]) -> Result<Self, EngineError> {
    match Engine::restore_envelope(snapshot) {
        Ok(engine) => Ok(Self {
            engine,
            persistence: None,
            commit_lsn: AtomicU64::new(0),
            tx_wal: None,
        }),
        Err(e) => Err(EngineError::Storage(spg_storage::StorageError::Corrupt(
            format!("restore: {e}"),
        ))),
    }
}
/// Returns the engine's current state serialised as snapshot bytes.
#[must_use]
pub fn snapshot(&self) -> Vec<u8> {
self.engine.snapshot()
}
/// Executes `sql` and, if it produced a WAL record, waits for that record
/// to be flushed before returning.
///
/// # Errors
/// Returns the engine's error, or the WAL error if the flush fails.
pub fn execute(&mut self, sql: &str) -> Result<QueryResult, EngineError> {
    let (outcome, pending) = self.execute_buffered(sql)?;
    match pending {
        Some(ticket) => ticket.wait()?,
        None => {}
    }
    Ok(outcome)
}
/// Executes `sql` and enqueues any resulting WAL record without waiting for
/// it to be flushed.
///
/// The returned ticket, if any, can be waited on to ensure durability.
///
/// # Errors
/// Returns the engine's error, or an error raised while enqueuing or
/// checkpointing.
pub fn execute_buffered(
    &mut self,
    sql: &str,
) -> Result<(QueryResult, Option<WalTicket>), EngineError> {
    let outcome = self.engine.execute(sql)?;
    let touched_catalog = if let QueryResult::CommandOk {
        modified_catalog, ..
    } = &outcome
    {
        *modified_catalog
    } else {
        false
    };
    let ticket = self.wal_after_ok(sql, touched_catalog)?;
    Ok((outcome, ticket))
}
/// WAL bookkeeping for a statement the engine has already executed.
///
/// * `BEGIN` opens a transaction buffer; `ROLLBACK` discards it.
/// * `COMMIT` turns the buffered statements (joined with `";\n"`) into one
///   transaction record, if there are any.
/// * `SAVEPOINT` / `ROLLBACK TO` record and restore positions in the buffer.
/// * Any other statement is buffered while a transaction is open; outside
///   one it becomes an auto-commit record when it modified the catalog.
///   Read-only statements are never logged.
///
/// When a record is produced it is enqueued, and a checkpoint is taken if
/// the WAL has reached the configured threshold. Returns the ticket for the
/// enqueued record, or `None` when nothing was logged or the database has
/// no persistence.
fn wal_after_ok(
    &mut self,
    canonical: &str,
    modified_catalog: bool,
) -> Result<Option<WalTicket>, EngineError> {
    if self.persistence.is_none() {
        return Ok(None);
    }

    let pending: Option<Vec<u8>> = match tx_control_kind(canonical) {
        Some(TxControl::Begin) => {
            self.tx_wal = Some(TxWalBuffer::default());
            None
        }
        Some(TxControl::Commit) => match self.tx_wal.take() {
            Some(buf) if !buf.statements.is_empty() => {
                let script = buf.statements.join(";\n");
                let lsn = self.commit_lsn.fetch_add(1, Ordering::SeqCst) + 1;
                Some(encode_v4_tx_commit(&script, lsn, wall_clock_micros()))
            }
            _ => None,
        },
        Some(TxControl::Rollback) => {
            self.tx_wal = None;
            None
        }
        Some(TxControl::Savepoint(name)) => {
            if let Some(buf) = self.tx_wal.as_mut() {
                // Re-declaring a savepoint replaces the earlier one.
                buf.savepoints.retain(|(existing, _)| *existing != name);
                let mark = buf.statements.len();
                buf.savepoints.push((name, mark));
            }
            None
        }
        Some(TxControl::RollbackToSavepoint(name)) => {
            if let Some(buf) = self.tx_wal.as_mut() {
                let found = buf
                    .savepoints
                    .iter()
                    .position(|(existing, _)| *existing == name);
                if let Some(pos) = found {
                    let mark = buf.savepoints[pos].1;
                    buf.statements.truncate(mark);
                    // Keep the target savepoint itself; drop later ones.
                    buf.savepoints.truncate(pos + 1);
                }
            }
            None
        }
        Some(TxControl::ReleaseSavepoint) => None,
        None => match self.tx_wal.as_mut() {
            Some(buf) => {
                if !sql_is_read_only(canonical) {
                    buf.statements.push(canonical.to_string());
                }
                None
            }
            None if modified_catalog && !sql_is_read_only(canonical) => {
                let lsn = self.commit_lsn.fetch_add(1, Ordering::SeqCst) + 1;
                Some(encode_v4_auto_commit(canonical, lsn, wall_clock_micros()))
            }
            None => None,
        },
    };

    let Some(record) = pending else {
        return Ok(None);
    };
    let p = self.persistence.as_mut().expect("checked above");
    let seq = p.wal.enqueue(&record);
    let ticket = WalTicket {
        group: Arc::clone(&p.wal),
        seq,
    };
    if p.wal.written_len() >= p.checkpoint_threshold_bytes {
        self.checkpoint()?;
    }
    Ok(Some(ticket))
}
/// Runs a query and converts every row with `T::from_spg_row`.
///
/// # Errors
/// Returns the error from `query`, or the first row-conversion error.
pub fn query_typed<T: FromSpgRow>(&mut self, sql: &str) -> Result<Vec<T>, EngineError> {
    let mut typed = Vec::new();
    for row in self.query(sql)? {
        typed.push(T::from_spg_row(&row)?);
    }
    Ok(typed)
}
/// Executes `sql` on the engine and returns the values of each result row.
///
/// The statement is not written to the WAL by this method.
///
/// # Errors
/// Returns the engine's error, or `Unsupported` when the statement did not
/// produce rows.
pub fn query(&mut self, sql: &str) -> Result<Vec<Vec<Value>>, EngineError> {
    let outcome = self.engine.execute(sql)?;
    if let QueryResult::Rows { rows, .. } = outcome {
        return Ok(rows.into_iter().map(|row| row.values).collect());
    }
    Err(EngineError::Unsupported(
        "query() expects a SELECT — use execute() for DML/DDL".into(),
    ))
}
/// Like `query`, but also returns the column schema of the result.
///
/// # Errors
/// Returns the engine's error, or `Unsupported` when the statement did not
/// produce rows.
pub fn query_with_columns(
    &mut self,
    sql: &str,
) -> Result<(Vec<spg_storage::ColumnSchema>, Vec<Vec<Value>>), EngineError> {
    let outcome = self.engine.execute(sql)?;
    if let QueryResult::Rows { columns, rows } = outcome {
        let values = rows.into_iter().map(|row| row.values).collect();
        return Ok((columns, values));
    }
    Err(EngineError::Unsupported(
        "query_with_columns() expects a SELECT — use execute() for DML/DDL".into(),
    ))
}
/// Runs a prepared query with `params` and returns the column schema
/// together with the values of each row.
///
/// # Errors
/// Returns the engine's error, or `Unsupported` when the statement did not
/// produce rows.
pub fn query_prepared_with_columns(
    &mut self,
    stmt: &Statement,
    params: &[Value],
) -> Result<(Vec<spg_storage::ColumnSchema>, Vec<Vec<Value>>), EngineError> {
    let outcome = self.engine.execute_prepared(stmt.stmt.clone(), params)?;
    if let QueryResult::Rows { columns, rows } = outcome {
        let values = rows.into_iter().map(|row| row.values).collect();
        return Ok((columns, values));
    }
    Err(EngineError::Unsupported(
        "query_prepared_with_columns() expects a SELECT — use execute_prepared() for DML/DDL".into(),
    ))
}
/// Borrows the underlying engine.
#[must_use]
pub const fn engine(&self) -> &Engine {
&self.engine
}
/// Mutably borrows the underlying engine. Statements run directly on the
/// engine do not go through this type's WAL handling.
pub const fn engine_mut(&mut self) -> &mut Engine {
&mut self.engine
}
/// Parses `sql` through the engine's statement cache and returns a reusable
/// [`Statement`].
///
/// # Errors
/// Returns `EngineError::Parse` when the SQL cannot be parsed.
pub fn prepare(&mut self, sql: &str) -> Result<Statement, EngineError> {
    let parsed = match self.engine.prepare_cached(sql) {
        Ok(parsed) => parsed,
        Err(parse_err) => return Err(EngineError::Parse(parse_err)),
    };
    Ok(Statement {
        stmt: parsed,
        sql: sql.to_string(),
    })
}
/// Parses `sql` through the engine's `prepare_cached` and returns whatever
/// the engine's `describe_prepared` reports for it.
///
/// The `Vec<u32>` half is presumably the parameter type identifiers and the
/// `Vec<ColumnSchema>` half the result columns — confirm against
/// `Engine::describe_prepared`.
///
/// # Errors
///
/// Returns `EngineError::Parse` when the engine rejects the SQL.
pub fn describe(&mut self, sql: &str) -> Result<(Vec<u32>, Vec<ColumnSchema>), EngineError> {
    let parsed = match self.engine.prepare_cached(sql) {
        Ok(parsed) => parsed,
        Err(parse_err) => return Err(EngineError::Parse(parse_err)),
    };
    Ok(self.engine.describe_prepared(&parsed))
}
/// Executes a prepared statement and, when the buffered variant hands back a
/// WAL ticket, waits on that ticket before returning the result.
///
/// # Errors
///
/// Propagates errors from [`Self::execute_prepared_buffered`] and from
/// waiting on the ticket.
pub fn execute_prepared(
    &mut self,
    stmt: &Statement,
    params: &[Value],
) -> Result<QueryResult, EngineError> {
    let (outcome, pending) = self.execute_prepared_buffered(stmt, params)?;
    match pending {
        Some(ticket) => {
            ticket.wait()?;
        }
        None => {}
    }
    Ok(outcome)
}
/// Executes a prepared statement and, when required, records it in the WAL
/// without waiting on the returned ticket.
///
/// A WAL record is written only when persistence is configured and at least
/// one of these holds: the engine reported `modified_catalog: true`, a
/// transaction WAL buffer is active and the statement text is not read-only,
/// or the statement text is a transaction-control statement. The text that
/// is logged is the parsed statement rendered after parameter substitution.
///
/// # Errors
///
/// Propagates engine errors and errors from the WAL step.
pub fn execute_prepared_buffered(
    &mut self,
    stmt: &Statement,
    params: &[Value],
) -> Result<(QueryResult, Option<WalTicket>), EngineError> {
    let result = self.engine.execute_prepared(stmt.stmt.clone(), params)?;
    let modified = if let QueryResult::CommandOk {
        modified_catalog, ..
    } = &result
    {
        *modified_catalog
    } else {
        false
    };
    if self.persistence.is_none() {
        return Ok((result, None));
    }
    let needs_wal = modified
        || (self.tx_wal.is_some() && !sql_is_read_only(&stmt.sql))
        || tx_control_kind(&stmt.sql).is_some();
    if !needs_wal {
        return Ok((result, None));
    }
    // Log the statement with its parameters substituted in, not the original text.
    let mut wal_stmt = stmt.stmt.clone();
    crate::wal_render_with_params(&mut wal_stmt, params);
    let canonical = format!("{wal_stmt}");
    let ticket = self.wal_after_ok(&canonical, modified)?;
    Ok((result, ticket))
}
/// Executes a prepared statement with `params` and returns the values of
/// every result row.
///
/// The engine call happens before the result kind is checked, so a statement
/// that does not produce rows has already been executed when the error is
/// returned.
///
/// # Errors
///
/// Propagates engine errors; returns `EngineError::Unsupported` when the
/// engine's result is anything other than a row set.
pub fn query_prepared(
    &mut self,
    stmt: &Statement,
    params: &[Value],
) -> Result<Vec<Vec<Value>>, EngineError> {
    let parsed = stmt.stmt.clone();
    let QueryResult::Rows { rows, .. } = self.engine.execute_prepared(parsed, params)? else {
        return Err(EngineError::Unsupported(
            "query_prepared() expects a SELECT — use execute_prepared() for DML/DDL".into(),
        ));
    };
    Ok(rows.into_iter().map(|row| row.values).collect())
}
/// Parses `sql` against a catalog snapshot and wraps the parsed form,
/// together with a copy of the original text, in a [`Statement`].
///
/// # Errors
///
/// Returns `EngineError::Parse` when the engine rejects the SQL.
pub fn prepare_on_snapshot(
    snapshot: &CatalogSnapshot,
    sql: &str,
) -> Result<Statement, EngineError> {
    match spg_engine::Engine::prepare_on_snapshot(snapshot, sql) {
        Ok(parsed) => Ok(Statement {
            stmt: parsed,
            sql: String::from(sql),
        }),
        Err(parse_err) => Err(EngineError::Parse(parse_err)),
    }
}
/// Runs a prepared statement against a catalog snapshot by delegating to the
/// engine's `execute_readonly_prepared_on_snapshot`.
///
/// # Errors
///
/// Propagates whatever error the engine returns.
pub fn execute_prepared_on_snapshot(
    snapshot: &CatalogSnapshot,
    stmt: &Statement,
    params: &[Value],
) -> Result<QueryResult, EngineError> {
    let parsed = stmt.stmt.clone();
    spg_engine::Engine::execute_readonly_prepared_on_snapshot(snapshot, parsed, params)
}
/// Parses `sql` against a catalog snapshot and returns what the engine's
/// `describe_prepared_on_snapshot` reports for it.
///
/// # Errors
///
/// Returns `EngineError::Parse` when the engine rejects the SQL.
pub fn describe_on_snapshot(
    snapshot: &CatalogSnapshot,
    sql: &str,
) -> Result<(Vec<u32>, Vec<ColumnSchema>), EngineError> {
    let parsed = match spg_engine::Engine::prepare_on_snapshot(snapshot, sql) {
        Ok(parsed) => parsed,
        Err(parse_err) => return Err(EngineError::Parse(parse_err)),
    };
    let described = spg_engine::Engine::describe_prepared_on_snapshot(snapshot, &parsed);
    Ok(described)
}
/// Splits `sql` with [`split_statements`] and executes each piece through
/// [`Self::execute_dump_statement`], returning one result per statement.
///
/// When the script has more than one statement, contains no
/// transaction-control statement of its own, and the engine is not already in
/// a transaction, the whole script is wrapped in `BEGIN` … `COMMIT`; the
/// first failing statement then triggers a best-effort `ROLLBACK` and its
/// error is returned. Otherwise the statements run one by one and the first
/// error is returned as-is.
///
/// # Errors
///
/// Returns the first statement error, or the error from `BEGIN` / `COMMIT`.
pub fn execute_script(&mut self, sql: &str) -> Result<Vec<QueryResult>, EngineError> {
    let statements = split_statements(sql);
    let script_owns_tx = statements.iter().any(|s| tx_control_kind(s).is_some());
    let wrap = statements.len() > 1 && !script_owns_tx && !self.engine.in_transaction();
    if !wrap {
        return statements
            .iter()
            .map(|s| self.execute_dump_statement(s))
            .collect();
    }
    self.execute("BEGIN")?;
    let mut results = Vec::with_capacity(statements.len());
    for s in &statements {
        let outcome = self.execute_dump_statement(s);
        if outcome.is_err() {
            // Best effort: the statement's own error is what the caller sees.
            let _ = self.execute("ROLLBACK");
        }
        results.push(outcome?);
    }
    self.execute("COMMIT")?;
    Ok(results)
}
/// Executes one statement taken from a dump script.
///
/// After leading comments are skipped, a statement whose text starts with
/// `copy` (any case), contains a `;`, and whose part before the first `;`
/// is accepted by `parse_copy_from_stdin_head` is treated as inline COPY
/// data: every non-empty line after the `;` is decoded as a COPY text row,
/// turned into an `INSERT` and executed. The returned `CommandOk` carries the
/// summed affected-row count and `modified_catalog: false`.
///
/// Any other statement is passed unchanged to [`Self::execute`].
///
/// # Errors
///
/// Returns the first error from executing the statement or a generated insert.
pub fn execute_dump_statement(&mut self, stmt: &str) -> Result<QueryResult, EngineError> {
    let cleaned = strip_leading_sql_noise(stmt);
    let looks_like_copy = cleaned
        .get(..4)
        .is_some_and(|kw| kw.eq_ignore_ascii_case("copy"));
    let copy_parts = if looks_like_copy {
        cleaned.split_once(';').and_then(|(head, data)| {
            spg_engine::copy::parse_copy_from_stdin_head(head).map(|spec| (spec, data))
        })
    } else {
        None
    };
    let Some((spec, data)) = copy_parts else {
        return self.execute(stmt);
    };
    let mut affected: usize = 0;
    let data_rows = data
        .lines()
        .map(|raw| raw.strip_suffix('\r').unwrap_or(raw))
        .filter(|row| !row.is_empty());
    for row in data_rows {
        let values = spg_engine::copy::decode_copy_text_row(row);
        let insert =
            spg_engine::copy::build_copy_insert(&spec.table, spec.columns.as_deref(), &values);
        affected += match self.execute(&insert)? {
            QueryResult::CommandOk { affected: n, .. } => n,
            _ => 1,
        };
    }
    Ok(QueryResult::CommandOk {
        affected,
        modified_catalog: false,
    })
}
/// Runs `body` between `BEGIN` and `COMMIT`.
///
/// If `body` returns `Ok`, `COMMIT` is executed and the value is returned.
/// If `body` returns `Err`, a best-effort `ROLLBACK` is executed (its own
/// result is ignored) and the body's error is returned.
///
/// # Errors
///
/// Returns the error from `BEGIN`, from `body`, or from `COMMIT`.
pub fn with_transaction<R, F>(&mut self, body: F) -> Result<R, EngineError>
where
    F: FnOnce(&mut Self) -> Result<R, EngineError>,
{
    self.execute("BEGIN")?;
    let outcome = body(self);
    if outcome.is_ok() {
        self.execute("COMMIT")?;
    } else {
        let _ = self.execute("ROLLBACK");
    }
    outcome
}
}
impl Default for Database {
fn default() -> Self {
Self::open_in_memory()
}
}
/// Point-in-time counters for a [`Database`], as filled in by
/// `Database::metrics`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct EmbeddedMetrics {
/// Sum of `row_count()` over every table in the catalog (saturating).
pub hot_rows: u64,
/// Sum of `hot_bytes()` over every table in the catalog (saturating).
pub hot_bytes: u64,
/// The catalog's `cold_segment_count()`.
pub cold_segments: u64,
/// The catalog's `table_count()`.
pub tables: u64,
/// The WAL's `written_len()`, or `0` when there is no persistence.
pub wal_bytes: u64,
/// `true` when the database has a persistence context.
pub persistent: bool,
}
/// Owner handle for the background freezer thread.
///
/// Stopping or dropping the handle sets the shutdown flag and joins the
/// worker thread.
#[must_use = "the background freezer keeps running until this handle is dropped"]
#[derive(Debug)]
pub struct FreezerHandle {
    /// Flag set to `true` to ask the worker to exit.
    shutdown: Arc<AtomicBool>,
    /// Join handle of the worker; `None` once it has been joined.
    join: Option<JoinHandle<()>>,
}

impl FreezerHandle {
    /// Signals shutdown and blocks until the worker thread has exited.
    ///
    /// Calling this again after the worker was joined only re-sets the flag.
    /// A panic inside the worker is swallowed.
    pub fn stop(&mut self) {
        self.shutdown.store(true, Ordering::Release);
        let Some(worker) = self.join.take() else {
            return;
        };
        let _ = worker.join();
    }
}

impl Drop for FreezerHandle {
    /// Stops the worker when the handle goes out of scope.
    fn drop(&mut self) {
        Self::stop(self);
    }
}
/// Tuning knobs read by the background freezer loop.
#[derive(Debug, Clone)]
pub struct FreezerOptions {
    /// Minimum time that must elapse between two freezer passes.
    pub tick: Duration,
    /// A pass only freezes rows while the catalog's hot-tier size exceeds this.
    pub hot_tier_bytes: u64,
    /// Upper bound on the number of rows frozen in one pass.
    pub batch_rows: usize,
    /// Compaction runs after a freeze when the cold segment count exceeds this.
    pub compact_when_segments_exceed: usize,
    /// Target size handed to the engine's cold-segment compaction.
    pub compact_target_bytes: u64,
}

impl Default for FreezerOptions {
    /// One-second tick, 4 GiB hot-tier threshold, 1000-row batches,
    /// compaction above 64 segments with a 64 MiB target.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;
        const GIB: u64 = 1024 * MIB;
        Self {
            tick: Duration::from_secs(1),
            hot_tier_bytes: 4 * GIB,
            batch_rows: 1_000,
            compact_when_segments_exceed: 64,
            compact_target_bytes: 64 * MIB,
        }
    }
}
impl Database {
    /// Returns the catalog's cold segment count.
    #[must_use]
    pub fn cold_segment_count(&self) -> usize {
        let catalog = self.engine.catalog();
        catalog.cold_segment_count()
    }

    /// Collects an [`EmbeddedMetrics`] snapshot from the catalog and, when
    /// persistence is configured, from the WAL.
    #[must_use]
    pub fn metrics(&self) -> EmbeddedMetrics {
        let catalog = self.engine.catalog();
        let mut hot_rows: u64 = 0;
        let mut hot_bytes: u64 = 0;
        for table_name in catalog.table_names() {
            let Some(table) = catalog.get(&table_name) else {
                continue;
            };
            hot_rows = hot_rows.saturating_add(table.row_count() as u64);
            hot_bytes = hot_bytes.saturating_add(table.hot_bytes());
        }
        let (wal_bytes, persistent) = self
            .persistence
            .as_ref()
            .map_or((0, false), |ctx| (ctx.wal.written_len(), true));
        EmbeddedMetrics {
            hot_rows,
            hot_bytes,
            cold_segments: catalog.cold_segment_count() as u64,
            tables: catalog.table_count() as u64,
            wal_bytes,
            persistent,
        }
    }

    /// Starts the freezer loop on a thread named `spg-embedded-freezer` and
    /// returns the handle that stops it.
    ///
    /// # Panics
    ///
    /// Panics if the operating system refuses to spawn the thread.
    pub fn spawn_background_freezer(
        db: Arc<Mutex<Database>>,
        opts: FreezerOptions,
    ) -> FreezerHandle {
        let shutdown = Arc::new(AtomicBool::new(false));
        let worker_flag = Arc::clone(&shutdown);
        let worker = thread::Builder::new()
            .name(String::from("spg-embedded-freezer"))
            .spawn(move || background_freezer_loop(db, opts, worker_flag))
            .expect("spawn background freezer thread");
        FreezerHandle {
            shutdown,
            join: Some(worker),
        }
    }
}
/// Body of the freezer thread.
///
/// Sleeps in slices of at most 50 ms (shorter when `opts.tick` is shorter) so
/// the shutdown flag is observed promptly. Once `opts.tick` has elapsed it
/// locks the database and runs one freeze pass. Returns when the shutdown
/// flag is set or the mutex is poisoned.
fn background_freezer_loop(
    db: Arc<Mutex<Database>>,
    opts: FreezerOptions,
    shutdown: Arc<AtomicBool>,
) {
    /// One freeze pass, run while the database lock is held: freeze up to
    /// `batch_rows` rows of the chosen table when the hot tier is over
    /// budget, then compact if too many cold segments have accumulated.
    fn run_pass(db: &mut Database, opts: &FreezerOptions) {
        if db.engine.catalog().hot_tier_bytes() <= opts.hot_tier_bytes {
            return;
        }
        let Some((table, index)) = pick_freeze_target(db) else {
            return;
        };
        let row_count = db
            .engine
            .catalog()
            .get(&table)
            .map_or(0, spg_storage::Table::row_count);
        let to_freeze = opts.batch_rows.min(row_count);
        if to_freeze == 0 {
            return;
        }
        if let Err(e) = db.freeze_oldest_to_cold(&table, &index, to_freeze) {
            eprintln!("spg-embedded: background freeze on {table}.{index} failed: {e:?}");
            return;
        }
        let count = db.engine.catalog().cold_segment_count();
        if count <= opts.compact_when_segments_exceed {
            return;
        }
        if let Err(e) = db
            .engine
            .compact_cold_segments_with_target(opts.compact_target_bytes)
        {
            eprintln!(
                "spg-embedded: background compact failed (segments={count}, \
                 threshold={}): {e:?}",
                opts.compact_when_segments_exceed,
            );
        }
    }

    let slice = Duration::from_millis((opts.tick.as_millis() as u64).min(50));
    let mut last_tick = std::time::Instant::now();
    while !shutdown.load(Ordering::Acquire) {
        thread::sleep(slice);
        if last_tick.elapsed() < opts.tick {
            continue;
        }
        last_tick = std::time::Instant::now();
        match db.lock() {
            // The guard lives for the whole pass and is released at the end of the arm.
            Ok(mut guard) => run_pass(&mut *guard, &opts),
            Err(_) => return,
        }
    }
}
/// Chooses the table and index the freezer should work on next.
///
/// Only non-empty tables that have a B-tree index over a `SmallInt`, `Int`
/// or `BigInt` column are eligible; for each such table the first matching
/// index is used. Among eligible tables the one with the largest
/// `hot_bytes()` wins, the earliest one winning ties. Returns
/// `(table name, index name)`, or `None` when nothing qualifies.
fn pick_freeze_target(db: &Database) -> Option<(String, String)> {
    let catalog = db.engine.catalog();
    let mut best: Option<(String, String, u64)> = None;
    for table_name in catalog.table_names() {
        let Some(table) = catalog.get(&table_name) else {
            continue;
        };
        if table.row_count() == 0 {
            continue;
        }
        let columns = &table.schema().columns;
        let candidate = table.indices().iter().find(|index| {
            matches!(index.kind, spg_storage::IndexKind::BTree(_))
                && columns.get(index.column_position).is_some_and(|column| {
                    matches!(
                        column.ty,
                        spg_storage::DataType::SmallInt
                            | spg_storage::DataType::Int
                            | spg_storage::DataType::BigInt
                    )
                })
        });
        let Some(index) = candidate else {
            continue;
        };
        let hot = table.hot_bytes();
        // Strictly greater: an equally hot later table does not displace the current pick.
        let beats_current = best
            .as_ref()
            .map_or(true, |(_, _, best_hot)| hot > *best_hot);
        if beats_current {
            best = Some((table_name, index.name.clone(), hot));
        }
    }
    best.map(|(table, index, _)| (table, index))
}
/// Rebuilds a database by replaying the first `to_seq` SQL-bearing WAL
/// records into a fresh engine and writes the resulting snapshot to
/// `out_db_path`.
///
/// `wal_path` may be a single WAL file or a directory; for a directory the
/// chunks returned by `sorted_wal_chunks` are concatenated in that order.
/// Records that decode to an empty payload are skipped and do not count
/// towards `to_seq`. Each record's SQL is split with [`split_statements`]
/// and executed statement by statement.
///
/// Returns the number of records that were applied.
///
/// # Errors
///
/// Returns an error when the WAL cannot be read, a record is truncated or
/// not valid UTF-8, a replayed statement fails, or the snapshot cannot be
/// written. The offset in the non-UTF-8 error is the start of the offending
/// record within the concatenated WAL bytes.
pub fn revert_wal_to_seq(
    wal_path: impl AsRef<Path>,
    to_seq: u64,
    out_db_path: impl AsRef<Path>,
) -> Result<u64, EngineError> {
    let path = wal_path.as_ref();
    let wal_bytes = if path.is_dir() {
        let mut combined = Vec::new();
        let chunks = sorted_wal_chunks(path).map_err(io_err)?;
        for chunk in chunks {
            let bytes = std::fs::read(&chunk).map_err(io_err)?;
            combined.extend_from_slice(&bytes);
        }
        combined
    } else {
        std::fs::read(path).map_err(io_err)?
    };
    let mut engine = Engine::new();
    let mut applied = 0u64;
    let mut cur = 0usize;
    while cur < wal_bytes.len() && applied < to_seq {
        // Remember where this record begins: `cur` is advanced before the
        // payload is validated, and diagnostics must point at this record,
        // not at the one that follows it.
        let record_start = cur;
        let (sql_bytes, total) = decode_wal_record(&wal_bytes[cur..])?;
        cur += total;
        if sql_bytes.is_empty() {
            continue;
        }
        let sql = core::str::from_utf8(&sql_bytes).map_err(|e| {
            EngineError::Storage(spg_storage::StorageError::Corrupt(format!(
                "WAL record at offset {record_start}: non-UTF-8 SQL: {e}"
            )))
        })?;
        for stmt in split_statements(sql) {
            engine.execute(stmt)?;
        }
        applied += 1;
    }
    let snapshot = engine.snapshot();
    std::fs::write(out_db_path.as_ref(), &snapshot).map_err(io_err)?;
    Ok(applied)
}
fn decode_wal_record(tail: &[u8]) -> Result<(Vec<u8>, usize), EngineError> {
if tail.len() < 4 {
return Err(EngineError::Storage(spg_storage::StorageError::Corrupt(
format!("WAL truncated record: {} < 4 header bytes", tail.len()),
)));
}
let raw_len = u32::from_le_bytes(tail[..4].try_into().unwrap());
let is_v2 = raw_len & WAL_V2_SENTINEL != 0;
let is_v3 = is_v2 && (raw_len & WAL_V3_FLAG != 0);
let len_mask = if is_v3 {
!(WAL_V2_SENTINEL | WAL_V3_FLAG)
} else {
!WAL_V2_SENTINEL
};
let rec_len = (raw_len & len_mask) as usize;
let header_len = if is_v3 {
9
} else if is_v2 {
8
} else {
4
};
if tail.len() < header_len + rec_len {
return Err(EngineError::Storage(spg_storage::StorageError::Corrupt(
format!(
"WAL truncated record: header+payload {} > available {}",
header_len + rec_len,
tail.len()
),
)));
}
if is_v3 {
let type_byte = tail[8];
if type_byte == WAL_V3_TYPE_AUTO_COMMIT_SQL {
let payload = &tail[header_len..header_len + rec_len];
return Ok((payload.to_vec(), header_len + rec_len));
}
if type_byte == WAL_V4_TYPE_AUTO_COMMIT_SQL || type_byte == WAL_V4_TYPE_TX_COMMIT_SQL {
let v4_total = header_len + WAL_V4_EXTRA_HEADER + rec_len;
if tail.len() < v4_total {
return Err(EngineError::Storage(spg_storage::StorageError::Corrupt(
format!(
"WAL truncated v4 record: header+payload {v4_total} > available {}",
tail.len()
),
)));
}
let sql_start = header_len + WAL_V4_EXTRA_HEADER;
let sql_bytes = tail[sql_start..sql_start + rec_len].to_vec();
return Ok((sql_bytes, v4_total));
}
return Ok((Vec::new(), header_len + rec_len));
}
let payload = &tail[header_len..header_len + rec_len];
Ok((payload.to_vec(), header_len + rec_len))
}
impl Drop for Database {
fn drop(&mut self) {
if self.persistence.is_some() {
if let Err(e) = self.checkpoint() {
eprintln!(
"spg-embedded: final checkpoint on Drop failed: {e:?} \
(WAL is intact; next open_path will replay)"
);
}
}
if let Some(ctx) = self.persistence.as_mut() {
if let Some(shutdown) = ctx.retention_shutdown.take() {
shutdown.store(true, Ordering::SeqCst);
}
if let Some(handle) = ctx.retention_thread.take() {
let _ = handle.join();
}
if let Some(shutdown) = ctx.flusher_shutdown.take() {
shutdown.store(true, Ordering::SeqCst);
}
if let Some(handle) = ctx.flusher_thread.take() {
let _ = handle.join();
}
}
if let Some(ctx) = &self.persistence
&& ctx.lock_path.exists()
{
if let Err(e) = std::fs::remove_dir_all(&ctx.lock_path) {
eprintln!(
"spg-embedded: lock release on Drop failed for {}: {e:?}",
ctx.lock_path.display()
);
}
}
}
}
impl Database {
    /// Removes the lock directory that belongs to `db_path`, if it exists.
    ///
    /// The lock path is the database path with `.lock` appended to its file
    /// name (or a sibling named `.lock` when the path has no file name).
    /// No check is made that the lock's owner has actually gone away.
    ///
    /// # Errors
    ///
    /// Returns an error when the directory exists but cannot be removed.
    pub fn force_unlock(db_path: impl AsRef<Path>) -> Result<(), EngineError> {
        let db_path = db_path.as_ref();
        let lock_name = match db_path.file_name() {
            Some(base) => {
                let mut name = base.to_os_string();
                name.push(".lock");
                name
            }
            None => std::ffi::OsString::from(".lock"),
        };
        let lock_path = db_path.with_file_name(lock_name);
        if lock_path.exists() {
            std::fs::remove_dir_all(&lock_path).map_err(io_err)
        } else {
            Ok(())
        }
    }
}
/// Wraps an I/O error as `EngineError::Storage(StorageError::Corrupt(..))`
/// with the message prefixed by `io: `.
fn io_err(e: std::io::Error) -> EngineError {
    let detail = format!("io: {e}");
    EngineError::Storage(spg_storage::StorageError::Corrupt(detail))
}
/// Compile-time check that [`Database`] is `Send`; never called at runtime.
#[allow(dead_code)]
fn _database_is_send() {
    fn require_send<T>()
    where
        T: Send,
    {
    }
    require_send::<Database>();
}
/// Conversion from one result row (a slice of [`Value`]s) into a typed value.
///
/// Used by `Database::query_typed`; the `spg_row!` macro generates
/// implementations for plain structs.
pub trait FromSpgRow: Sized {
/// Builds `Self` from the values of a single row.
///
/// # Errors
///
/// Implementations return an [`EngineError`] when the row cannot be converted.
fn from_spg_row(row: &[Value]) -> Result<Self, EngineError>;
}
/// Declares a struct and implements `FromSpgRow` for it.
///
/// The struct is emitted with the given attributes and visibility plus
/// `#[derive(Debug, Clone)]`. The generated `from_spg_row` consumes row
/// values positionally, one per field in declaration order, converting each
/// with `FromSpgValue`. A row with too few values or a value that fails to
/// convert yields `EngineError::Unsupported` naming the struct and field.
/// Values beyond the last field are ignored.
///
/// Every field, including the last, must be followed by a comma.
#[macro_export]
macro_rules! spg_row {
(
$(#[$meta:meta])*
$vis:vis struct $name:ident {
$(
$(#[$fmeta:meta])*
$fvis:vis $field:ident : $ty:ty,
)*
}
) => {
// Re-emit the struct itself, adding Debug and Clone.
$(#[$meta])*
#[derive(Debug, Clone)]
$vis struct $name {
$(
$(#[$fmeta])*
$fvis $field : $ty,
)*
}
impl $crate::FromSpgRow for $name {
fn from_spg_row(row: &[$crate::Value]) -> ::core::result::Result<Self, $crate::EngineError> {
// Columns are matched to fields by position, not by name.
let mut __spg_row_iter = row.iter();
$(
let $field: $ty = {
let v = __spg_row_iter
.next()
.ok_or_else(|| $crate::EngineError::Unsupported(
::std::format!(
"spg_row! {}: missing column for field `{}`",
::core::stringify!($name),
::core::stringify!($field)
)
))?;
<$ty as $crate::FromSpgValue>::from_spg_value(v)
.map_err(|e| $crate::EngineError::Unsupported(
::std::format!(
"spg_row! {}: column `{}`: {}",
::core::stringify!($name),
::core::stringify!($field),
e
)
))?
};
)*
Ok(Self { $($field,)* })
}
}
};
}
/// Conversion from a single [`Value`] into a Rust type.
///
/// Implemented in this file for `i16`, `i32`, `i64`, `f32`, `f64`, `bool`,
/// `String`, `Vec<f32>` and `Option<T>` of any implementing `T`.
pub trait FromSpgValue: Sized {
/// Converts `v`, returning a static description of the mismatch on failure.
///
/// # Errors
///
/// Implementations return a message when `v` has the wrong variant, is
/// `NULL` for a non-`Option` target, or does not fit the target type.
fn from_spg_value(v: &Value) -> Result<Self, &'static str>;
}
// Implements `FromSpgValue` for each listed integer type: any of the three
// integer `Value` variants is accepted as long as the number fits the target;
// `NULL` and non-integer variants are rejected with a static message.
macro_rules! impl_from_value_int {
    ($($int:ty),* $(,)?) => {
        $(
            impl FromSpgValue for $int {
                fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
                    match v {
                        Value::Null => Err("NULL in non-Option int column"),
                        Value::SmallInt(small) => <$int>::try_from(*small)
                            .or(Err("SmallInt does not fit target int type")),
                        Value::Int(medium) => <$int>::try_from(*medium)
                            .or(Err("Int does not fit target int type")),
                        Value::BigInt(big) => <$int>::try_from(*big)
                            .or(Err("BigInt does not fit target int type")),
                        _ => Err("non-integer value in int column"),
                    }
                }
            }
        )*
    };
}
impl_from_value_int!(i16, i32, i64);
/// Accepts `Value::Float`, narrowing it to `f32` with an `as` cast.
impl FromSpgValue for f32 {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Float(wide) = v {
            return Ok(*wide as f32);
        }
        if matches!(v, Value::Null) {
            return Err("NULL in non-Option float column");
        }
        Err("non-float value in float column")
    }
}
/// Accepts `Value::Float` unchanged.
impl FromSpgValue for f64 {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Float(number) = v {
            return Ok(*number);
        }
        if matches!(v, Value::Null) {
            return Err("NULL in non-Option float column");
        }
        Err("non-float value in float column")
    }
}
/// Accepts `Value::Bool` only.
impl FromSpgValue for bool {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Bool(flag) = v {
            return Ok(*flag);
        }
        if matches!(v, Value::Null) {
            return Err("NULL in non-Option bool column");
        }
        Err("non-bool value in bool column")
    }
}
/// Accepts `Value::Text`, cloning the text.
impl FromSpgValue for String {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Text(text) = v {
            return Ok(text.clone());
        }
        if matches!(v, Value::Null) {
            return Err("NULL in non-Option text column");
        }
        Err("non-text value in String column")
    }
}
/// Accepts `Value::Vector`, cloning its elements.
impl FromSpgValue for Vec<f32> {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Vector(elements) = v {
            return Ok(elements.clone());
        }
        if matches!(v, Value::Null) {
            return Err("NULL in non-Option vector column");
        }
        Err("non-vector value in Vec<f32> column")
    }
}
/// Maps `Value::Null` to `None`; any other value is converted by `T` and
/// wrapped in `Some`.
impl<T: FromSpgValue> FromSpgValue for Option<T> {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if matches!(v, Value::Null) {
            Ok(None)
        } else {
            T::from_spg_value(v).map(Some)
        }
    }
}
/// Takes the directory-based lock at `lock_path`.
///
/// Creating the directory is the acquisition; the current process id is then
/// written (best effort) to a `pid` file inside it. If the directory already
/// exists, the recorded pid is probed once: a live owner is an error, while a
/// missing, unparsable or dead pid causes the stale lock to be removed and
/// creation retried a single time.
///
/// # Errors
///
/// Returns `EngineError::Unsupported` when another process holds the lock,
/// or a wrapped I/O error for any other filesystem failure.
fn acquire_path_lock(lock_path: &Path) -> Result<(), EngineError> {
    for attempt in 0..2 {
        let create_err = match std::fs::create_dir(lock_path) {
            Ok(()) => {
                let _ = std::fs::write(lock_path.join("pid"), std::process::id().to_string());
                return Ok(());
            }
            Err(err) => err,
        };
        if create_err.kind() != std::io::ErrorKind::AlreadyExists {
            return Err(io_err(create_err));
        }
        if attempt != 0 {
            // Second collision: someone re-took the lock after we reclaimed it.
            return Err(EngineError::Unsupported(format!(
                "database is locked by another process: {}; \
                 stop that process first, or call Database::force_unlock()",
                lock_path.display()
            )));
        }
        let owner = std::fs::read_to_string(lock_path.join("pid"))
            .ok()
            .and_then(|text| text.trim().parse::<u32>().ok());
        if owner.is_some_and(pid_alive) {
            return Err(EngineError::Unsupported(format!(
                "database is locked by another process (pid {}): {}; \
                 stop that process first, or call Database::force_unlock()",
                owner.unwrap_or(0),
                lock_path.display()
            )));
        }
        eprintln!(
            "spg-embedded: reclaiming stale lock {} (owner pid {:?} not alive)",
            lock_path.display(),
            owner
        );
        std::fs::remove_dir_all(lock_path).map_err(io_err)?;
    }
    unreachable!("acquire_path_lock loop covers both attempts")
}
/// Reports whether process `pid` appears to be running, by checking the exit
/// status of `ps -p <pid>`. If `ps` cannot be run at all the process is
/// assumed to be alive.
#[cfg(unix)]
fn pid_alive(pid: u32) -> bool {
    use std::process::{Command, Stdio};

    let probe = Command::new("ps")
        .arg("-p")
        .arg(pid.to_string())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();
    probe.map_or(true, |status| status.success())
}

/// On non-Unix targets no probe is performed; every pid is reported alive.
#[cfg(not(unix))]
fn pid_alive(_pid: u32) -> bool {
    true
}
/// Updates `mysql_escapes` from dialect hints found in one statement.
///
/// Statements longer than 4096 bytes are not inspected. Otherwise, matching
/// case-insensitively: any mention of `sql_mode` turns the flag on; failing
/// that, a mention of `standard_conforming_strings` sets the flag to whether
/// the statement also contains `off`. Statements with neither leave the flag
/// untouched.
fn note_dialect_signals(chunk: &str, mysql_escapes: &mut bool) {
    const MAX_INSPECTED_LEN: usize = 4096;
    if chunk.len() <= MAX_INSPECTED_LEN {
        let folded = chunk.to_ascii_lowercase();
        let verdict = if folded.contains("sql_mode") {
            Some(true)
        } else if folded.contains("standard_conforming_strings") {
            Some(folded.contains("off"))
        } else {
            None
        };
        if let Some(flag) = verdict {
            *mysql_escapes = flag;
        }
    }
}
/// Skips leading whitespace, `--` line comments and `/* … */` block comments
/// and returns the remainder of `s`.
///
/// Block comments nest, mirroring the depth tracking used by
/// `split_statements`: `/* a /* b */ c */ X` yields `X`. MySQL conditional
/// comments (`/*! … */`) are treated as content and are not skipped. An
/// unterminated block comment, or a line comment with no newline after it,
/// yields the empty string.
fn strip_leading_sql_noise(mut s: &str) -> &str {
    loop {
        let t = s.trim_start();
        if let Some(rest) = t.strip_prefix("--") {
            s = rest.split_once('\n').map_or("", |(_, r)| r);
            continue;
        }
        if t.starts_with("/*") && !t.starts_with("/*!") {
            let bytes = t.as_bytes();
            let mut depth = 1usize;
            let mut i = 2usize;
            // Walk to the `*/` that closes the outermost comment, counting
            // nested openers so an inner `*/` does not end the scan early.
            while i < bytes.len() && depth > 0 {
                if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
                    depth += 1;
                    i += 2;
                } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
                    depth -= 1;
                    i += 2;
                } else {
                    i += 1;
                }
            }
            if depth > 0 {
                return "";
            }
            // `i` sits just past an ASCII `/`, so it is a valid char boundary.
            s = &t[i..];
            continue;
        }
        return t;
    }
}
/// Splits a SQL script into individual statements.
///
/// Statements are separated by `;`. The returned slices borrow from `sql`,
/// exclude the terminating `;`, and keep any leading whitespace or comments
/// that followed the previous separator. Chunks that contain nothing but
/// whitespace and comments are dropped. A trailing statement without `;` is
/// returned when it has content.
///
/// A `;` is not treated as a separator inside single-quoted strings,
/// double-quoted identifiers, dollar-quoted bodies (`$tag$ … $tag$`), `--`
/// line comments, or (nested) `/* … */` block comments.
///
/// Two dump-specific behaviours:
/// - a line starting with `\` while no statement content has been seen is
/// skipped up to the end of the line (psql meta-commands);
/// - a statement that parses as a COPY-from-stdin head absorbs the data lines
/// that follow it, up to (not including) a line consisting of `\.`, or to the
/// end of input when no such line exists.
pub fn split_statements(sql: &str) -> Vec<&str> {
let bytes = sql.as_bytes();
let mut stmts = Vec::new();
// Byte offset where the current statement begins.
let mut start = 0usize;
// Whether anything other than whitespace/comments was seen since `start`.
let mut has_content = false;
// Whether backslash escapes apply inside every single-quoted string;
// updated from completed statements by `note_dialect_signals`.
let mut mysql_escapes = false;
let mut i = 0usize;
while i < bytes.len() {
match bytes[i] {
// A backslash before any statement content: skip the rest of the line
// and restart the statement after it.
b'\\' if !has_content => {
while i < bytes.len() && bytes[i] != b'\n' {
i += 1;
}
start = if i < bytes.len() { i + 1 } else { i };
}
// Single-quoted string literal.
b'\'' => {
has_content = true;
// Backslash escapes apply when the dialect flag is set, or when the
// quote is directly preceded by a standalone `e`/`E` (an `E'…'`
// literal rather than the tail of an identifier).
let escape_string = mysql_escapes
|| (i >= 1
&& matches!(bytes[i - 1], b'e' | b'E')
&& !(i >= 2
&& (bytes[i - 2].is_ascii_alphanumeric() || bytes[i - 2] == b'_')));
i += 1;
while i < bytes.len() {
if escape_string && bytes[i] == b'\\' {
// Skip the backslash and the byte it escapes.
i += 2;
continue;
}
if bytes[i] == b'\'' {
// A doubled quote is an escaped quote, not the end of the string.
if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
i += 2;
continue;
}
break;
}
i += 1;
}
}
// Double-quoted identifier: skip to the next double quote.
b'"' => {
has_content = true;
i += 1;
while i < bytes.len() && bytes[i] != b'"' {
i += 1;
}
}
// Possible dollar-quote opener: `$`, an optional tag of ASCII
// alphanumerics/underscores, then `$`.
b'$' => {
let tag_end = bytes[i + 1..]
.iter()
.position(|&b| !(b.is_ascii_alphanumeric() || b == b'_'))
.map(|off| i + 1 + off);
if let Some(te) = tag_end
&& te < bytes.len()
&& bytes[te] == b'$'
{
has_content = true;
let tag = &sql[i..=te];
// Jump past the matching closing tag; with no closing tag the
// rest of the input belongs to this statement.
if let Some(close) = sql[te + 1..].find(tag) {
i = te + 1 + close + tag.len();
continue;
}
i = bytes.len();
continue;
}
// A lone `$` (for example a `$1` placeholder) is ordinary content.
has_content = true;
}
// `--` line comment: skip to the end of the line.
b'-' if i + 1 < bytes.len() && bytes[i + 1] == b'-' => {
while i < bytes.len() && bytes[i] != b'\n' {
i += 1;
}
}
// Block comment, with nesting tracked by `depth`.
b'/' if i + 1 < bytes.len() && bytes[i + 1] == b'*' => {
// `/*! … */` counts as statement content rather than as a comment.
if i + 2 < bytes.len() && bytes[i + 2] == b'!' {
has_content = true;
}
let mut depth = 1usize;
i += 2;
while i < bytes.len() && depth > 0 {
if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
depth += 1;
i += 2;
} else if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
depth -= 1;
i += 2;
} else {
i += 1;
}
}
continue;
}
// Statement separator.
b';' => {
if has_content {
let head = &sql[start..i];
let head_clean = strip_leading_sql_noise(head);
let is_copy_head = head_clean
.get(..4)
.is_some_and(|p| p.eq_ignore_ascii_case("copy"))
&& spg_engine::copy::parse_copy_from_stdin_head(head_clean).is_some();
if is_copy_head {
// Scan the following lines for the `\.` terminator; everything
// before it (including the head and its `;`) is one statement.
let mut j = i + 1;
let data_end;
loop {
if j >= bytes.len() {
data_end = bytes.len();
break;
}
let line_end = sql[j..].find('\n').map_or(bytes.len(), |off| j + off);
if sql[j..line_end].trim_end_matches('\r').trim() == "\\." {
data_end = j;
i = line_end; break;
}
j = line_end + 1;
}
stmts.push(&sql[start..data_end]);
// No terminator found: the data ran to the end of the input.
if data_end == bytes.len() {
i = bytes.len();
}
start = i + 1;
has_content = false;
i += 1;
continue;
}
note_dialect_signals(head, &mut mysql_escapes);
stmts.push(head);
}
start = i + 1;
has_content = false;
}
// Any other byte: non-whitespace marks the statement as having content.
b => {
if !b.is_ascii_whitespace() {
has_content = true;
}
}
}
i += 1;
}
// Final statement without a terminating `;`.
if has_content {
stmts.push(&sql[start..]);
}
stmts
}
// Unit tests for the statement splitter and the basic in-memory database API.
#[cfg(test)]
mod tests {
use super::*;
// Two statements split on `;`; the second keeps its leading space.
// Input made only of separators, whitespace and comments yields nothing.
#[test]
fn split_statements_basic_and_trailing() {
assert_eq!(
split_statements("CREATE TABLE a (x INT); INSERT INTO a VALUES (1)"),
vec!["CREATE TABLE a (x INT)", " INSERT INTO a VALUES (1)"]
);
assert!(split_statements(" ;; -- nothing\n;").is_empty());
}
// A `;` inside each quoting form must not split the statement, and a `;`
// after the quoted region must still split.
#[test]
fn split_statements_quoting_forms() {
let cases = [
"INSERT INTO t VALUES ('a;b')",
"INSERT INTO t VALUES ('it''s; fine')",
r"INSERT INTO t VALUES (E'it\'s; fine')",
"CREATE TABLE \"odd;name\" (x INT)",
"DO $body$ BEGIN PERFORM 1; END $body$",
"DO $$ SELECT 1; $$",
];
for sql in cases {
assert_eq!(split_statements(sql), vec![sql], "must stay whole: {sql}");
}
for sql in cases {
let script = format!("{sql};\nSELECT 2");
assert_eq!(
split_statements(&script),
vec![sql, "\nSELECT 2"],
"must split after: {sql}"
);
}
}
// Lines that start with a backslash between statements are dropped, while a
// backslash inside an `E'…'` literal stays part of the statement.
#[test]
fn split_statements_drops_psql_meta_lines() {
let script = "\\restrict TOKEN123\nSELECT 1;\n\\unrestrict TOKEN123\nSELECT 2;\n\\.\n";
assert_eq!(split_statements(script), vec!["SELECT 1", "SELECT 2"]);
let s2 = r"SELECT E'a\\b'";
assert_eq!(split_statements(s2), vec![s2]);
}
// Semicolons inside line comments and nested block comments are ignored.
#[test]
fn split_statements_comments_hide_semicolons() {
let script = "-- c1 ; still comment\nSELECT 1; /* a ; b /* nested ; */ */ SELECT 2";
let got = split_statements(script);
assert_eq!(got.len(), 2);
assert!(got[0].contains("SELECT 1"));
assert!(got[1].contains("SELECT 2"));
}
// End-to-end: create a table, insert two rows, select one back.
#[test]
fn in_memory_create_insert_select() {
let mut db = Database::open_in_memory();
db.execute("CREATE TABLE t (id INT NOT NULL, name TEXT)")
.unwrap();
db.execute("INSERT INTO t VALUES (1, 'alice')").unwrap();
db.execute("INSERT INTO t VALUES (2, 'bob')").unwrap();
let rows = db.query("SELECT id FROM t WHERE id = 1").unwrap();
assert_eq!(rows.len(), 1);
match &rows[0][0] {
Value::Int(1) => {}
other => panic!("expected Int(1), got {other:?}"),
}
}
// `query()` rejects a statement that does not return rows.
#[test]
fn query_on_non_select_errors() {
let mut db = Database::open_in_memory();
db.execute("CREATE TABLE t (id INT)").unwrap();
let r = db.query("INSERT INTO t VALUES (1)");
assert!(r.is_err(), "query() on INSERT must error");
}
// A snapshot restored into a new database serves the same row.
#[test]
fn snapshot_roundtrip() {
let mut db = Database::open_in_memory();
db.execute("CREATE TABLE t (id INT NOT NULL)").unwrap();
db.execute("INSERT INTO t VALUES (42)").unwrap();
let bytes = db.snapshot();
let mut restored = Database::restore(&bytes).unwrap();
let rows = restored.query("SELECT id FROM t WHERE id = 42").unwrap();
assert_eq!(rows.len(), 1);
match &rows[0][0] {
Value::Int(42) => {}
other => panic!("expected Int(42), got {other:?}"),
}
}
// A hand-written `FromSpgRow` implementation compiles and converts a row.
#[test]
fn from_spg_row_trait_shape() {
struct User {
_id: i32,
}
impl FromSpgRow for User {
fn from_spg_row(row: &[Value]) -> Result<Self, EngineError> {
match row.first() {
Some(Value::Int(n)) => Ok(Self { _id: *n }),
_ => Err(EngineError::Unsupported("bad id".into())),
}
}
}
let row = vec![Value::Int(7)];
let _u = User::from_spg_row(&row).unwrap();
}
}