mod alloc_budget;
mod backup;
mod flusher;
mod freezer;
mod manifest;
mod observability;
mod prefetch;
mod pubsub;
thread_local! {
// Per-thread count of cold-tier prefetch hits recorded before `ServerState`
// exists. `run` moves the value accumulated on its own thread (via
// `Cell::take`) into `metrics.cold_prefetch_hits` right after building the
// shared state. NOTE(review): the code that increments this counter is not in
// this view — confirm which boot path writes it.
static PREFETCH_HITS_BOOT: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
}
mod pgwire;
mod replication;
mod scram;
use std::collections::{BTreeMap, VecDeque};
use std::env;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::path::{Path, PathBuf};
use std::process;
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU8, AtomicUsize, Ordering};
use std::sync::mpsc::{self, SyncSender};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use spg_audit::AuditLog;
use spg_engine::{Engine, EngineError, QueryResult, Role};
use spg_storage::{ColumnSchema, DataType, Row, Value};
use spg_wire::{
ColumnDesc, Frame, FrameError, Op, WireType, WireValue, build_command_complete, build_data_row,
build_data_row_batch, build_error_response, build_row_description, build_stats_response,
decode, encode, parse_auth, parse_auth_user, parse_query,
};
// Route every heap allocation in the process through the allocator defined in
// `alloc_budget`. NOTE(review): its accounting / limit policy lives in that
// module — confirm there before relying on it.
#[global_allocator]
static GLOBAL_ALLOC: alloc_budget::BudgetAllocator = alloc_budget::BudgetAllocator;
// 256 MiB. NOTE(review): not referenced in this view — presumably the fallback
// when `SPG_MAX_QUERY_BYTES` is unset; confirm at the use site.
const DEFAULT_MAX_QUERY_BYTES: u64 = 256 * 1024 * 1024;
// Listen address used by `main` when neither the first CLI argument nor
// `SPG_ADDR` supplies one.
const DEFAULT_ADDR: &str = "127.0.0.1:5544";
// Size of each socket read in `handle` (also the initial reassembly capacity).
const READ_CHUNK: usize = 4096;
// NOTE(review): not referenced in this view — presumably the row count packed
// into one batched data-row frame; confirm at the use site.
const BATCH_ROWS_PER_FRAME: usize = 256;
// Poll interval for the shutdown waker thread and the connection drain loop.
const SHUTDOWN_POLL: Duration = Duration::from_millis(50);
// Drain deadline used when `SPG_SHUTDOWN_DEADLINE_SEC` is unset.
const DEFAULT_SHUTDOWN_DEADLINE_SEC: u64 = 30;
// NOTE(review): not referenced in this view — presumably the maximum number
// of commit tasks folded into one group commit; confirm at the use site.
const DEFAULT_COMMIT_GROUP_MAX: usize = 16;
// Raised by the SIGTERM/SIGINT handler; the accept loop and the shutdown
// waker poll it to begin a graceful shutdown.
static SHUTDOWN_FLAG: AtomicBool = AtomicBool::new(false);
/// Operator resource limits read from `SPG_*` environment variables in `main`.
///
/// Every field is `None` when its variable is unset, unparseable, or zero
/// (`parse_env_usize` / `parse_env_u64` filter out `0`).
#[derive(Debug, Default, Clone, Copy)]
struct Limits {
/// `SPG_MAX_CONNECTIONS`: concurrent-connection cap enforced by `ConnectionGuard::try_claim`.
max_connections: Option<usize>,
/// `SPG_MAX_QUERY_ROWS`: passed to the engine via `Engine::with_max_query_rows` in `run`.
max_query_rows: Option<usize>,
/// `SPG_MAX_QUERY_BYTES`. NOTE(review): enforcement point not in this view.
max_query_bytes: Option<u64>,
/// `SPG_QUERY_TIMEOUT_MS`. NOTE(review): enforcement point not in this view.
query_timeout_ms: Option<u64>,
/// `SPG_MAX_QUERY_NS`. NOTE(review): enforcement point not in this view.
max_query_ns: Option<u64>,
/// `SPG_IDLE_TIMEOUT_SEC`: installed as the socket read timeout in `handle`.
idle_timeout_sec: Option<u64>,
/// `SPG_SLOW_QUERY_LOG_MS`. NOTE(review): consumer not in this view.
slow_query_log_ms: Option<u64>,
/// `SPG_WAL_MIN_FREE_BYTES`. NOTE(review): consumer not in this view.
wal_min_free_bytes: Option<u64>,
/// `SPG_SHUTDOWN_DEADLINE_SEC`: drain deadline used by `drain_connections`
/// (falls back to `DEFAULT_SHUTDOWN_DEADLINE_SEC`).
shutdown_deadline_sec: Option<u64>,
}
/// Fault-injection knobs read from the environment in `run`.
#[derive(Debug, Default, Clone, Copy)]
struct ChaosKnobs {
/// `SPG_FAIL_WAL_QUOTA_BYTES` (zero/unset ⇒ `None`). NOTE(review): consumer
/// not in this view — presumably an artificial WAL size quota.
wal_quota_bytes: Option<u64>,
/// True when `SPG_DISABLE_WAL_PREFLIGHT` is set to a non-empty value other than `"0"`.
disable_wal_preflight: bool,
}
/// Outcome delivered back to a queued commit's submitter.
/// NOTE(review): producer (the commit leader) is outside this view — confirm.
struct CommitResult {
/// Result of executing the statement in the engine.
result: Result<QueryResult, EngineError>,
/// Result of the associated WAL write.
wal_outcome: std::io::Result<()>,
}
/// One statement waiting in `ServerState::commit_queue`.
/// NOTE(review): enqueue/dequeue logic is outside this view — confirm.
struct CommitTask {
/// SQL text to execute.
sql: String,
/// Cancellation flag shared with the submitting connection.
cancel_flag: Arc<AtomicBool>,
/// Channel on which the `CommitResult` is sent back.
ack: SyncSender<CommitResult>,
}
/// Mutex-protected state of the commit queue (initialised empty in `run`).
struct CommitQueueState {
/// Tasks waiting to be committed, in FIFO order.
pending: VecDeque<CommitTask>,
/// NOTE(review): presumably true while some thread is draining `pending`
/// as the group-commit leader — confirm against the commit path.
leader_active: bool,
}
/// Process-wide state built once in `run` and shared through `Arc` by every
/// connection thread and background worker.
pub(crate) struct ServerState {
/// The SQL engine behind a reader/writer lock.
pub(crate) engine: RwLock<Engine>,
/// Snapshot (db) file restored at boot, if configured.
db_path: Option<PathBuf>,
/// In-memory audit hash chain, read by `audit_chain_snapshot` / `audit_verify_snapshot`.
audit_log: Mutex<AuditLog>,
/// Audit log file, if configured.
audit_path: Option<PathBuf>,
/// WAL file opened in append mode at boot; `None` without a WAL path.
wal: Option<Mutex<File>>,
/// `try_clone` of the WAL handle (`None` if cloning failed). NOTE(review):
/// presumably used to fsync without holding `wal`'s mutex — confirm.
wal_sync_clone: Option<Arc<File>>,
/// WAL file path, if configured.
wal_path: Option<PathBuf>,
/// Pending commit tasks plus the leader flag.
commit_queue: Mutex<CommitQueueState>,
/// `SPG_PASSWORD`, when set and non-empty.
password: Option<String>,
/// Resource limits from the environment.
limits: Limits,
/// Live connection slots, claimed/released by `ConnectionGuard`.
active_connections: AtomicUsize,
/// Server metrics (e.g. `cold_prefetch_hits`).
metrics: Arc<observability::Metrics>,
/// Running subscription workers: subscription name → stop flag (stored `true` to stop).
sub_workers: Mutex<BTreeMap<String, Arc<AtomicBool>>>,
/// Node id loaded from the `.cluster_id` sidecar or freshly generated.
pub(crate) cluster_id: u64,
/// Current WAL level (`WAL_LEVEL_*`), changeable via `SET effective_wal_level`.
pub(crate) wal_level: AtomicU8,
/// Fault-injection knobs.
chaos: ChaosKnobs,
/// Replication progress; `follower_applied_pos` is polled by `WAIT FOR`.
pub(crate) lag_state: Arc<replication::LagState>,
/// Specs parsed from `SPG_PRELOAD_COLD_SEGMENT`.
cold_preload: Vec<ColdPreloadSpec>,
/// Set once every preload spec has been attempted (or none were configured).
cold_preload_done: AtomicBool,
/// `SPG_HOT_TIER_BYTES`, or `DEFAULT_HOT_TIER_BYTES` when unset/zero.
pub(crate) hot_tier_byte_budget: u64,
/// Cold segment id → file it was loaded from.
pub(crate) cold_segment_paths: Mutex<BTreeMap<u32, PathBuf>>,
/// Per-connection activity records read by `activity_snapshot`.
/// NOTE(review): registration/removal happens outside this view.
pub(crate) connections: RwLock<Vec<Arc<ConnState>>>,
}
/// Live activity record for one client connection, surfaced through `activity_snapshot`.
pub(crate) struct ConnState {
/// Connection identifier reported as `pid`.
pub(crate) pid: u32,
/// Authenticated user name.
pub(crate) user: String,
/// Connection start time. NOTE(review): presumably µs since the Unix epoch — confirm at the writer.
pub(crate) started_at_us: i64,
/// SQL text currently executing (or last executed).
pub(crate) current_sql: RwLock<String>,
/// Wait-event code: 1 = write_lock, 2 = fsync, 3 = group_commit, anything else = none.
pub(crate) wait_event: AtomicU8,
/// Start of the current query in µs since the Unix epoch; 0 means "no query".
pub(crate) last_query_start_us: AtomicI64,
/// Whether the connection is inside an explicit transaction.
pub(crate) in_transaction: AtomicBool,
}
impl ConnState {
    /// Microseconds since the current query started.
    ///
    /// Returns 0 when no start has been recorded (`last_query_start_us == 0`).
    /// A wall clock that moved backwards also yields 0, never a negative value.
    pub(crate) fn elapsed_us(&self) -> i64 {
        let started = self.last_query_start_us.load(Ordering::Relaxed);
        if started == 0 {
            return 0;
        }
        let now_us = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_micros() as i64,
            Err(_) => 0,
        };
        (now_us - started).max(0)
    }

    /// Human-readable label for `wait_event`; unknown codes (including 0) map to `""`.
    pub(crate) fn wait_event_str(&self) -> &'static str {
        const LABELS: [&str; 4] = ["", "write_lock", "fsync", "group_commit"];
        let code = usize::from(self.wait_event.load(Ordering::Relaxed));
        LABELS.get(code).copied().unwrap_or("")
    }
}
/// Global handle to the server state, installed once by `run`. The engine
/// callbacks registered there (`activity_snapshot`, `audit_chain_snapshot`,
/// `audit_verify_snapshot`) read it. Before it is set they return empty
/// or default results.
pub(crate) static ACTIVITY_STATE: std::sync::OnceLock<Arc<ServerState>> =
std::sync::OnceLock::new();
/// Copy the in-memory audit chain into engine rows.
///
/// Hashes are rendered as lowercase hex and out-of-range integers saturate to
/// `i64::MAX`. Returns an empty list before `ACTIVITY_STATE` is installed or
/// when the audit mutex is poisoned.
pub(crate) fn audit_chain_snapshot() -> Vec<spg_engine::AuditRow> {
    let state = match ACTIVITY_STATE.get() {
        Some(state) => state,
        None => return Vec::new(),
    };
    let log = match state.audit_log.lock() {
        Ok(guard) => guard,
        Err(_) => return Vec::new(),
    };
    let entries = log.entries();
    let mut rows = Vec::with_capacity(entries.len());
    for entry in entries.iter() {
        rows.push(spg_engine::AuditRow {
            seq: i64::try_from(entry.seq).unwrap_or(i64::MAX),
            ts_ms: i64::try_from(entry.ts_ms).unwrap_or(i64::MAX),
            prev_hash_hex: hex_encode(&entry.prev_hash),
            entry_hash_hex: hex_encode(&entry.hash),
            sql: entry.sql.clone(),
        });
    }
    rows
}
pub(crate) fn audit_verify_snapshot() -> (i64, i64) {
let Some(state) = ACTIVITY_STATE.get() else {
return (0, -1);
};
let Ok(log) = state.audit_log.lock() else {
return (0, -1);
};
let n = log.entries().len() as i64;
match log.verify() {
Ok(()) => (n, -1),
Err(spg_audit::AuditError::BrokenChain { seq })
| Err(spg_audit::AuditError::HashMismatch { seq })
| Err(spg_audit::AuditError::InvalidUtf8 { seq }) => {
(i64::try_from(seq).unwrap_or(i64::MAX), i64::try_from(seq).unwrap_or(i64::MAX))
}
Err(_) => (0, 0),
}
}
pub(crate) fn log_slow_query(sql: &str, elapsed_us: u64) {
let elapsed_str = elapsed_us.to_string();
observability::log_event(
"warn",
"slow_query",
&[("elapsed_us", &elapsed_str), ("sql", sql)],
);
}
/// Render `bytes` as lowercase hexadecimal, two digits per byte (high nibble first).
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&byte| [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]])
        .map(char::from)
        .collect()
}
/// Build one activity row per registered connection, for the engine's activity view.
///
/// Returns an empty list before `ACTIVITY_STATE` is installed or if the
/// connections lock is poisoned. A poisoned per-connection SQL lock yields an
/// empty `current_sql` for that row.
pub(crate) fn activity_snapshot() -> Vec<spg_engine::ActivityRow> {
    let state = match ACTIVITY_STATE.get() {
        Some(state) => state,
        None => return Vec::new(),
    };
    let conns = match state.connections.read() {
        Ok(guard) => guard,
        Err(_) => return Vec::new(),
    };
    let mut rows = Vec::with_capacity(conns.len());
    for conn in conns.iter() {
        let current_sql = match conn.current_sql.read() {
            Ok(sql) => sql.clone(),
            Err(_) => String::new(),
        };
        rows.push(spg_engine::ActivityRow {
            pid: conn.pid,
            user: conn.user.clone(),
            started_at_us: conn.started_at_us,
            current_sql,
            wait_event: conn.wait_event_str().to_string(),
            elapsed_us: conn.elapsed_us(),
            in_transaction: conn.in_transaction.load(Ordering::Relaxed),
        });
    }
    rows
}
/// Default hot-tier byte budget (4 GiB), used when `SPG_HOT_TIER_BYTES` is unset or zero.
pub(crate) const DEFAULT_HOT_TIER_BYTES: u64 = 4 * 1024 * 1024 * 1024;
/// One `table:index:path` entry of `SPG_PRELOAD_COLD_SEGMENT`, loaded lazily
/// by `try_lazy_preload_cold` once the table and index exist.
struct ColdPreloadSpec {
/// Target table name.
table: String,
/// Index on `table` that receives the cold row locators.
index: String,
/// Segment file to read.
path: PathBuf,
/// Set after the load was attempted, whether it succeeded or failed, so it is never retried.
loaded: AtomicBool,
}
/// Interpret an optional path argument; empty strings and `-` mean "not configured".
fn parse_optional_path(arg: Option<String>) -> Option<PathBuf> {
    match arg {
        Some(text) if !text.is_empty() && text != "-" => Some(PathBuf::from(text)),
        _ => None,
    }
}
/// Resolve a path from the CLI argument, falling back to environment variable `env_key`.
///
/// Both sources go through `parse_optional_path`, so `""` and `-` count as
/// absent. The environment is only consulted when the CLI value is absent.
fn resolve_path(cli: Option<String>, env_key: &str) -> Option<PathBuf> {
    if let Some(path) = parse_optional_path(cli) {
        return Some(path);
    }
    parse_optional_path(env::var(env_key).ok())
}
/// Process entry point.
///
/// Positional arguments are `[addr] [db] [audit] [wal]`, each falling back to
/// `SPG_ADDR`, `SPG_DB`, `SPG_AUDIT` or `SPG_WAL`. `--replay-only` may appear
/// anywhere: it replays the snapshot and WAL and exits without serving.
/// Any fatal error exits with status 1.
fn main() {
    let (flags, positional): (Vec<String>, Vec<String>) =
        env::args().skip(1).partition(|arg| arg == "--replay-only");
    let replay_only = !flags.is_empty();
    let mut positional = positional.into_iter();

    let addr = positional
        .next()
        .or_else(|| env::var("SPG_ADDR").ok())
        .unwrap_or_else(|| DEFAULT_ADDR.to_string());
    let db_path = resolve_path(positional.next(), "SPG_DB");
    let audit_path = resolve_path(positional.next(), "SPG_AUDIT");
    let wal_path = resolve_path(positional.next(), "SPG_WAL");
    let password = env::var("SPG_PASSWORD").ok().filter(|pw| !pw.is_empty());
    let limits = Limits {
        max_connections: parse_env_usize("SPG_MAX_CONNECTIONS"),
        max_query_rows: parse_env_usize("SPG_MAX_QUERY_ROWS"),
        max_query_bytes: parse_env_u64("SPG_MAX_QUERY_BYTES"),
        query_timeout_ms: parse_env_u64("SPG_QUERY_TIMEOUT_MS"),
        max_query_ns: parse_env_u64("SPG_MAX_QUERY_NS"),
        idle_timeout_sec: parse_env_u64("SPG_IDLE_TIMEOUT_SEC"),
        slow_query_log_ms: parse_env_u64("SPG_SLOW_QUERY_LOG_MS"),
        wal_min_free_bytes: parse_env_u64("SPG_WAL_MIN_FREE_BYTES"),
        shutdown_deadline_sec: parse_env_u64("SPG_SHUTDOWN_DEADLINE_SEC"),
    };
    install_shutdown_handlers();

    if replay_only {
        match run_replay_only(db_path, wal_path) {
            Ok(()) => eprintln!("spg-server: --replay-only complete; exiting 0"),
            Err(e) => {
                eprintln!("spg-server: replay-only fatal: {e}");
                process::exit(1);
            }
        }
        return;
    }

    if let Err(e) = run(&addr, db_path, audit_path, wal_path, password, limits) {
        eprintln!("spg-server: fatal: {e}");
        process::exit(1);
    }
}
/// Restore the snapshot (if the db file exists), then replay the WAL (if it exists).
///
/// Nothing is served. NOTE(review): the value returned by `engine.snapshot()`
/// is discarded, so this function persists nothing itself — confirm that is
/// intended.
///
/// # Errors
/// I/O errors reading either file, a snapshot that fails to restore, or a WAL
/// replay failure.
fn run_replay_only(
    db_path: Option<PathBuf>,
    wal_path: Option<PathBuf>,
) -> std::io::Result<()> {
    let mut engine = if let Some(db) = db_path.as_deref().filter(|p| p.exists()) {
        let snapshot = fs::read(db)?;
        let restored = Engine::restore_envelope(&snapshot)
            .map_err(|e| std::io::Error::other(format!("restore: {e}")))?;
        eprintln!(
            "spg-server: --replay-only restored {} table(s) from {}",
            restored.catalog().table_count(),
            db.display()
        );
        restored
    } else {
        Engine::new()
    };

    match wal_path.as_deref() {
        Some(wal) if wal.exists() => {
            let wal_bytes = fs::read(wal)?;
            let applied = replay_wal_bytes(&wal_bytes, &mut engine)?;
            eprintln!(
                "spg-server: --replay-only applied {applied} WAL record(s) from {}",
                wal.display()
            );
        }
        Some(wal) => eprintln!(
            "spg-server: --replay-only WAL path {} doesn't exist; nothing to replay",
            wal.display()
        ),
        None => {}
    }

    let _ = engine.snapshot();
    Ok(())
}
/// Read `env_key` as a positive `usize`, ignoring surrounding whitespace.
///
/// Unset, non-numeric and zero values all yield `None`.
fn parse_env_usize(env_key: &str) -> Option<usize> {
    let raw = env::var(env_key).ok()?;
    match raw.trim().parse::<usize>() {
        Ok(0) | Err(_) => None,
        Ok(n) => Some(n),
    }
}
/// Read `env_key` as a positive `u64`, ignoring surrounding whitespace.
///
/// Unset, non-numeric and zero values all yield `None`.
fn parse_env_u64(env_key: &str) -> Option<u64> {
    let value: u64 = env::var(env_key).ok()?.trim().parse().ok()?;
    (value > 0).then_some(value)
}
/// Parse `SPG_PRELOAD_COLD_SEGMENT` into preload specs.
///
/// The variable holds `;`-separated `table:index:path` entries. The path is
/// everything after the second `:`, so it may itself contain colons.
/// Malformed or partially empty entries are reported on stderr and skipped.
/// Returns an empty list when the variable is unset.
fn parse_cold_preload_env() -> Vec<ColdPreloadSpec> {
    let raw = match env::var("SPG_PRELOAD_COLD_SEGMENT") {
        Ok(value) => value,
        Err(_) => return Vec::new(),
    };
    let mut specs = Vec::new();
    for entry in raw.split(';').map(str::trim).filter(|e| !e.is_empty()) {
        let mut fields = entry.splitn(3, ':');
        let (Some(table), Some(index), Some(path)) = (fields.next(), fields.next(), fields.next())
        else {
            eprintln!(
                "spg-server: SPG_PRELOAD_COLD_SEGMENT entry {entry:?} \
                 ignored — expected `table:index:path`"
            );
            continue;
        };
        let (table, index, path) = (table.trim(), index.trim(), path.trim());
        if table.is_empty() || index.is_empty() || path.is_empty() {
            eprintln!(
                "spg-server: SPG_PRELOAD_COLD_SEGMENT entry {entry:?} \
                 ignored — empty table / index / path"
            );
            continue;
        }
        specs.push(ColdPreloadSpec {
            table: table.to_string(),
            index: index.to_string(),
            path: PathBuf::from(path),
            loaded: AtomicBool::new(false),
        });
    }
    if !specs.is_empty() {
        eprintln!(
            "spg-server: cold-tier preload queue has {} spec(s); each one \
             will load on the first Op::Query after its table + index \
             both exist",
            specs.len()
        );
    }
    specs
}
/// Location of the sidecar file that persists this node's cluster id.
///
/// The sidecar sits next to the WAL (preferred) or, failing that, the db file,
/// with `.cluster_id` appended to that file's name — e.g. `data/spg.wal` →
/// `data/spg.wal.cluster_id`. Returns `None` when neither path is configured,
/// i.e. there is nowhere durable to keep the id.
fn cluster_id_sidecar_path(wal_path: Option<&Path>, db_path: Option<&Path>) -> Option<PathBuf> {
    let base = wal_path.or(db_path)?;
    let mut name = base
        .file_name()
        .map(std::ffi::OsStr::to_os_string)
        .unwrap_or_default();
    name.push(".cluster_id");
    Some(base.with_file_name(name))
}
fn load_or_generate_cluster_id(wal_path: Option<&Path>, db_path: Option<&Path>) -> u64 {
if let Some(p) = cluster_id_sidecar_path(wal_path, db_path) {
if p.exists()
&& let Ok(bytes) = std::fs::read(&p)
&& bytes.len() == 8
{
return u64::from_le_bytes(bytes.try_into().unwrap());
}
let id = generate_cluster_id();
if let Err(e) = std::fs::write(&p, id.to_le_bytes()) {
eprintln!(
"spg-server: cluster_id sidecar write to {} failed: {e} — \
keeping in-memory id (cycle detection won't survive restart)",
p.display()
);
}
id
} else {
generate_cluster_id()
}
}
/// Derive a fresh, well-mixed 64-bit id from the wall clock (ns) and process id.
///
/// The seed is scrambled with a multiply-add and then the splitmix64 output
/// finalizer (shift/xor/multiply rounds), so nearby timestamps give unrelated
/// ids. This is not cryptographically random.
fn generate_cluster_id() -> u64 {
    let pid = u64::from(std::process::id());
    let nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_nanos()).unwrap_or(u64::MAX),
        Err(_) => 0,
    };
    let seed = nanos.wrapping_mul(6364136223846793005).wrapping_add(pid);
    let round = |x: u64, shift: u32, mul: u64| (x ^ (x >> shift)).wrapping_mul(mul);
    let mixed = round(round(seed, 30, 0xbf58476d1ce4e5b9), 27, 0x94d049bb133111eb);
    mixed ^ (mixed >> 31)
}
/// Auto-analyze sweep interval in milliseconds, from `SPG_AUTO_ANALYZE_INTERVAL_MS`.
///
/// Surrounding whitespace is ignored, consistent with `parse_env_u64`. `0` is
/// preserved because it disables the worker (see `spawn_auto_analyze_worker`).
/// Unset or unparseable values fall back to 30 000 ms.
pub(crate) fn auto_analyze_interval_ms() -> u64 {
    std::env::var("SPG_AUTO_ANALYZE_INTERVAL_MS")
        .ok()
        .and_then(|s| s.trim().parse::<u64>().ok())
        .unwrap_or(30_000)
}
/// Start the detached `spg-auto-analyze` background thread.
///
/// Does nothing when the configured interval is zero. A failure to create the
/// thread is logged instead of being silently dropped; the server keeps
/// running without auto-analyze.
pub(crate) fn spawn_auto_analyze_worker(state: Arc<ServerState>) {
    let interval = std::time::Duration::from_millis(auto_analyze_interval_ms());
    if interval.is_zero() {
        return;
    }
    let spawned = thread::Builder::new()
        .name("spg-auto-analyze".into())
        .spawn(move || {
            run_auto_analyze_loop(state, interval);
        });
    if let Err(e) = spawned {
        eprintln!("spg-server: failed to spawn auto-analyze worker: {e}");
    }
}
/// Wake-up granularity of `run_auto_analyze_loop`; a sweep still runs at most once per configured interval.
const AUTO_ANALYZE_TICK: std::time::Duration = std::time::Duration::from_millis(200);
/// Body of the auto-analyze thread; loops forever.
///
/// Every `AUTO_ANALYZE_TICK` it checks whether `interval` has elapsed since
/// the last sweep. A sweep asks the engine (read lock) which tables need
/// analysis, then runs `ANALYZE` on each under a separate write lock, skipping
/// tables dropped in between. Failures are logged. A poisoned lock skips the
/// sweep (read) or aborts the rest of it (write).
fn run_auto_analyze_loop(state: Arc<ServerState>, interval: std::time::Duration) {
    let mut last_sweep = std::time::Instant::now();
    loop {
        thread::sleep(AUTO_ANALYZE_TICK);
        if last_sweep.elapsed() < interval {
            continue;
        }
        last_sweep = std::time::Instant::now();

        let pending: Vec<String> = match state.engine.read() {
            Ok(engine) => engine.tables_needing_analyze(),
            Err(_) => continue,
        };

        for table in &pending {
            let mut engine = match state.engine.write() {
                Ok(guard) => guard,
                Err(_) => break,
            };
            if engine.catalog().get(table).is_none() {
                continue;
            }
            let statement = format!("ANALYZE {}", quote_ident_simple(table));
            if let Err(e) = engine.execute(&statement) {
                eprintln!("spg-server: auto-analyze {table:?} failed: {e}");
            }
        }
    }
}
/// Render a catalog name as a SQL identifier.
///
/// Names made only of lowercase ASCII letters, digits and `_`, not starting
/// with a digit, are emitted bare. Everything else is double-quoted with
/// embedded `"` doubled. That covers empty names, names with spaces or
/// punctuation, non-ASCII names, and names containing uppercase letters.
/// Uppercase names are quoted because SQL front ends commonly case-fold
/// unquoted identifiers: `ANALYZE Users` could resolve to `users` rather than
/// the catalog's `Users`. Quoting is exact either way.
fn quote_ident_simple(name: &str) -> String {
    let mut chars = name.chars();
    let plain = match chars.next() {
        Some(first) => {
            (first.is_ascii_lowercase() || first == '_')
                && chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
        }
        None => false,
    };
    if plain {
        name.to_string()
    } else {
        format!("\"{}\"", name.replace('"', "\"\""))
    }
}
pub(crate) fn reconcile_subscriptions(state: &Arc<ServerState>) {
use std::collections::BTreeMap;
let want: BTreeMap<String, (String, bool)> = {
let Ok(eng) = state.engine.read() else {
return;
};
eng.subscriptions()
.iter()
.map(|(n, s)| (n.clone(), (s.conn_str.clone(), s.enabled)))
.collect()
};
let Ok(mut workers) = state.sub_workers.lock() else {
return;
};
let stale: Vec<String> = workers
.keys()
.filter(|n| !want.contains_key(n.as_str()))
.cloned()
.collect();
for name in stale {
if let Some(flag) = workers.remove(&name) {
flag.store(true, std::sync::atomic::Ordering::Release);
}
}
for (name, (conn_str, enabled)) in &want {
if !enabled {
continue;
}
if workers.contains_key(name) {
continue;
}
let flag = Arc::new(AtomicBool::new(false));
let flag_for_worker = Arc::clone(&flag);
let state_for_worker = Arc::clone(state);
let name_clone = name.clone();
let conn_clone = conn_str.clone();
let thread_name = format!("spg-sub-{}", name.chars().take(20).collect::<String>());
thread::Builder::new()
.name(thread_name)
.spawn(move || {
replication::run_subscription_worker(
name_clone,
conn_clone,
state_for_worker,
flag_for_worker,
);
})
.ok();
workers.insert(name.clone(), flag);
}
}
/// Load any configured cold-tier segments whose table and index now exist.
///
/// Each pending `ColdPreloadSpec` is processed in turn:
/// 1. Check readiness under the engine read lock.
/// 2. Read the segment file before taking the write lock, so disk I/O does not
///    block other queries.
/// 3. Under the write lock, parse the segment into a cloned catalog and
///    register a cold locator for every primary key on the spec's index.
/// 4. Swap the updated catalog in.
///
/// Every failure is logged and the spec is marked loaded so it is not retried
/// on each query. A segment whose primary key does not fit in `i64` is now
/// rejected the same way instead of panicking. The old panic happened while
/// holding the engine write lock, which would have poisoned it for every
/// connection. Once no spec is still waiting for its table/index,
/// `cold_preload_done` short-circuits future calls.
#[allow(
    clippy::too_many_lines,
    reason = "single-purpose preload routine; splitting hurts readability more than the line count helps"
)]
pub(crate) fn try_lazy_preload_cold(state: &ServerState) {
    if state.cold_preload_done.load(Ordering::Relaxed) {
        return;
    }
    let mut still_pending = 0usize;
    for spec in &state.cold_preload {
        if spec.loaded.load(Ordering::Relaxed) {
            continue;
        }
        let ready = {
            let Ok(engine) = state.engine.read() else {
                return;
            };
            let cat = engine.catalog();
            cat.get(&spec.table)
                .is_some_and(|t| t.indices().iter().any(|i| i.name == spec.index))
        };
        if !ready {
            still_pending += 1;
            continue;
        }
        let bytes = match std::fs::read(&spec.path) {
            Ok(b) => b,
            Err(e) => {
                eprintln!(
                    "spg-server: cold preload {}:{} from {} failed: {e}; \
                     marking loaded to avoid retry storm",
                    spec.table,
                    spec.index,
                    spec.path.display()
                );
                spec.loaded.store(true, Ordering::Relaxed);
                continue;
            }
        };
        let Ok(mut engine) = state.engine.write() else {
            return;
        };
        // Work on a clone so a failure part-way through leaves the live catalog untouched.
        let mut cat = engine.catalog().clone();
        let seg_id = match cat.load_segment_bytes(bytes) {
            Ok(id) => id,
            Err(e) => {
                eprintln!(
                    "spg-server: cold preload {}:{} parse failed: {e}",
                    spec.table, spec.index
                );
                spec.loaded.store(true, Ordering::Relaxed);
                continue;
            }
        };
        // `None` when some primary key does not fit in i64 (corrupt or foreign segment).
        let pairs: Option<Vec<(spg_storage::IndexKey, spg_storage::RowLocator)>> = {
            let Some(seg) = cat.cold_segment(seg_id) else {
                eprintln!(
                    "spg-server: cold preload {}:{} segment_id {seg_id} \
                     vanished after load — should be impossible",
                    spec.table, spec.index
                );
                spec.loaded.store(true, Ordering::Relaxed);
                continue;
            };
            seg.scan()
                .map(|(key, _payload)| {
                    i64::try_from(key).ok().map(|pk| {
                        (
                            spg_storage::IndexKey::Int(pk),
                            spg_storage::RowLocator::Cold {
                                segment_id: seg_id,
                                page_offset: 0,
                            },
                        )
                    })
                })
                .collect()
        };
        let Some(pairs) = pairs else {
            eprintln!(
                "spg-server: cold preload {}:{} from {} has a primary key outside \
                 the i64 range; skipping segment",
                spec.table,
                spec.index,
                spec.path.display()
            );
            spec.loaded.store(true, Ordering::Relaxed);
            continue;
        };
        let pairs_count = pairs.len();
        let Some(table_mut) = cat.get_mut(&spec.table) else {
            eprintln!(
                "spg-server: cold preload {}:{} table disappeared mid-load",
                spec.table, spec.index
            );
            spec.loaded.store(true, Ordering::Relaxed);
            continue;
        };
        if let Err(e) = table_mut.register_cold_locators(&spec.index, pairs) {
            eprintln!(
                "spg-server: cold preload {}:{} register_cold_locators failed: {e}",
                spec.table, spec.index
            );
            spec.loaded.store(true, Ordering::Relaxed);
            continue;
        }
        engine.replace_catalog(cat);
        spec.loaded.store(true, Ordering::Relaxed);
        if let Ok(mut paths) = state.cold_segment_paths.lock() {
            paths.insert(seg_id, spec.path.clone());
        }
        eprintln!(
            "spg-server: cold preload {}:{} loaded {} row(s) from {}",
            spec.table,
            spec.index,
            pairs_count,
            spec.path.display()
        );
    }
    if still_pending == 0 {
        state.cold_preload_done.store(true, Ordering::Relaxed);
    }
}
#[allow(clippy::too_many_lines)] fn run(
addr: &str,
db_path: Option<PathBuf>,
audit_path: Option<PathBuf>,
wal_path: Option<PathBuf>,
password: Option<String>,
limits: Limits,
) -> std::io::Result<()> {
let mut cold_segment_paths: BTreeMap<u32, PathBuf> = BTreeMap::new();
let mut manifest_wal_baseline: u64 = 0;
let mut engine = match &db_path {
Some(p) if p.exists() => {
let bytes = fs::read(p)?;
let path_str = p.display();
let mut engine = Engine::restore_envelope(&bytes)
.map_err(|e| std::io::Error::other(format!("restore from {path_str}: {e}")))?;
eprintln!(
"spg-server: restored {} table(s), {} user(s) from {path_str}",
engine.catalog().table_count(),
engine.users().len()
);
manifest_wal_baseline =
load_manifest_and_preload_cold(&mut engine, p, &bytes, &mut cold_segment_paths);
engine
}
Some(p) => {
eprintln!(
"spg-server: db file {} does not exist yet — starting fresh",
p.display()
);
Engine::new()
}
None => Engine::new(),
}
.with_clock(wall_clock_micros)
.with_salt_fn(urandom_salt_or_panic);
if let Some(n) = limits.max_query_rows {
engine = engine.with_max_query_rows(n);
}
let audit_log = match &audit_path {
Some(p) if p.exists() => {
let bytes = fs::read(p)?;
let log = AuditLog::deserialize(&bytes).map_err(|e| {
std::io::Error::other(format!("audit log {} rejected: {e}", p.display()))
})?;
eprintln!(
"spg-server: verified audit log {} ({} entries)",
p.display(),
log.len()
);
log
}
Some(p) => {
fs::write(p, AuditLog::header_bytes())?;
eprintln!("spg-server: started fresh audit log at {}", p.display());
AuditLog::new()
}
None => AuditLog::new(),
};
if let Some(p) = &wal_path
&& p.exists()
{
let mut bytes = fs::read(p)?;
if let Ok(s) = env::var("SPG_REPLAY_UPTO")
&& let Ok(upto) = s.trim().parse::<u64>()
{
let upto_usize = usize::try_from(upto).unwrap_or(usize::MAX);
if bytes.len() > upto_usize {
eprintln!(
"spg-server: PITR — truncating WAL replay at offset {upto} \
(of {} total bytes)",
bytes.len()
);
bytes.truncate(upto_usize);
}
}
let baseline_usize = usize::try_from(manifest_wal_baseline).unwrap_or(usize::MAX);
if baseline_usize > 0 && baseline_usize <= bytes.len() {
eprintln!(
"spg-server: manifest skip — WAL replay starts at offset {manifest_wal_baseline} \
(of {} total bytes)",
bytes.len()
);
bytes.drain(..baseline_usize);
} else if baseline_usize > bytes.len() {
eprintln!(
"spg-server: manifest WAL baseline {manifest_wal_baseline} exceeds file size {}; \
replaying from start as a safety net",
bytes.len()
);
}
let applied = replay_wal_bytes(&bytes, &mut engine)?;
eprintln!(
"spg-server: replayed {} WAL entries from {}",
applied,
p.display()
);
if engine.in_transaction() {
eprintln!("spg-server: WAL ended mid-transaction — auto-rollback");
engine
.execute("ROLLBACK")
.map_err(|e| std::io::Error::other(format!("post-replay rollback: {e}")))?;
}
} else if let Some(p) = &wal_path {
fs::write(p, b"")?;
eprintln!("spg-server: started fresh WAL at {}", p.display());
}
bootstrap_admin_from_env(&mut engine, db_path.as_deref())?;
let (wal, wal_sync_clone) = match &wal_path {
Some(p) => {
let file = OpenOptions::new().append(true).open(p).map_err(|e| {
std::io::Error::other(format!("open WAL {} for append: {e}", p.display()))
})?;
let sync_clone = file.try_clone().ok().map(Arc::new);
(Some(Mutex::new(file)), sync_clone)
}
None => (None, None),
};
let auth_msg = if password.is_some() {
" (AUTH required)"
} else {
""
};
let chaos = ChaosKnobs {
wal_quota_bytes: parse_env_u64("SPG_FAIL_WAL_QUOTA_BYTES"),
disable_wal_preflight: env::var("SPG_DISABLE_WAL_PREFLIGHT")
.ok()
.is_some_and(|s| !s.is_empty() && s != "0"),
};
let cold_preload = parse_cold_preload_env();
let cold_preload_done = AtomicBool::new(cold_preload.is_empty());
let hot_tier_byte_budget =
parse_env_u64("SPG_HOT_TIER_BYTES").unwrap_or(DEFAULT_HOT_TIER_BYTES);
let cluster_id = load_or_generate_cluster_id(wal_path.as_deref(), db_path.as_deref());
let state = Arc::new(ServerState {
engine: RwLock::new(engine),
db_path,
audit_log: Mutex::new(audit_log),
audit_path,
wal,
wal_sync_clone,
wal_path,
commit_queue: Mutex::new(CommitQueueState {
pending: VecDeque::new(),
leader_active: false,
}),
password,
limits,
active_connections: AtomicUsize::new(0),
metrics: Arc::new(observability::Metrics::default()),
chaos,
lag_state: Arc::new(replication::LagState::default()),
cold_preload,
cold_preload_done,
hot_tier_byte_budget,
cold_segment_paths: Mutex::new(cold_segment_paths),
sub_workers: Mutex::new(BTreeMap::new()),
cluster_id,
wal_level: AtomicU8::new(parse_wal_level_env()),
connections: RwLock::new(Vec::new()),
});
let _ = ACTIVITY_STATE.set(Arc::clone(&state));
PREFETCH_HITS_BOOT.with(|cell| {
let hits = cell.take();
if hits > 0 {
state
.metrics
.cold_prefetch_hits
.store(hits, std::sync::atomic::Ordering::Relaxed);
}
});
if let Ok(mut e) = state.engine.write() {
let prev = std::mem::replace(&mut *e, Engine::new());
let slow_us: u64 = std::env::var("SPG_SLOW_QUERY_THRESHOLD_MS")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(100)
* 1_000;
*e = prev
.with_activity_provider(activity_snapshot)
.with_audit_providers(audit_chain_snapshot, audit_verify_snapshot)
.with_slow_query_log(slow_us, log_slow_query);
if let Ok(s) = std::env::var("SPG_PLAN_CACHE_MAX")
&& let Ok(n) = s.parse::<usize>()
{
e.set_plan_cache_max(n);
}
}
reconcile_subscriptions(&state);
spawn_auto_analyze_worker(Arc::clone(&state));
let listener = TcpListener::bind(addr)?;
let local = listener.local_addr()?;
eprintln!("spg-server: listening on {local}{auth_msg}");
if let Ok(pg_addr) = env::var("SPG_PG_ADDR")
&& !pg_addr.is_empty()
{
match pgwire::spawn_listener(&pg_addr, Arc::clone(&state)) {
Ok(pg_local) => eprintln!("spg-server: pg-wire listening on {pg_local}"),
Err(e) => eprintln!("spg-server: pg-wire failed to start on {pg_addr}: {e}"),
}
}
if let Ok(http_addr) = env::var("SPG_HTTP_ADDR")
&& !http_addr.is_empty()
{
match observability::spawn_http(&http_addr, Arc::clone(&state)) {
Ok(http_local) => eprintln!("spg-server: http listening on {http_local}"),
Err(e) => eprintln!("spg-server: http failed to start on {http_addr}: {e}"),
}
}
if let Ok(repl_addr) = env::var("SPG_REPL_ADDR")
&& !repl_addr.is_empty()
{
match replication::spawn_master_listener(&repl_addr, Arc::clone(&state)) {
Ok(repl_local) => {
eprintln!("spg-server: replication listening on {repl_local}");
}
Err(e) => {
eprintln!("spg-server: replication failed to start on {repl_addr}: {e}");
}
}
}
if let Ok(master_addr) = env::var("SPG_FOLLOW_OF")
&& !master_addr.is_empty()
{
if let (Some(db), Some(wal)) = (state.db_path.clone(), state.wal_path.clone()) {
let state_for_follower = Arc::clone(&state);
thread::Builder::new()
.name("spg-follower".into())
.spawn(move || {
replication::run_follower(master_addr, db, wal, state_for_follower);
})
.ok();
eprintln!("spg-server: started as follower");
} else {
eprintln!(
"spg-server: SPG_FOLLOW_OF set but db_path or wal_path missing — \
follower mode requires both"
);
}
}
if freezer::spawn(Arc::clone(&state)).is_none() {
eprintln!("spg-server: freezer disabled via SPG_FREEZER_DISABLE");
}
if flusher::spawn(Arc::clone(&state)).is_none() {
}
spawn_shutdown_waker(&listener)?;
for stream in listener.incoming() {
if SHUTDOWN_FLAG.load(Ordering::Acquire) {
drop(stream); break;
}
let mut stream = stream?;
let guard = ConnectionGuard::try_claim(&state);
let Some(guard) = guard else {
let peer = stream.peer_addr().ok();
let _ = write_frame(
&mut stream,
&build_error_response(&format!(
"max_connections reached ({} active)",
state.limits.max_connections.unwrap_or(0)
)),
);
eprintln!("spg-server: rejected {peer:?}: max_connections reached");
continue;
};
let state_for_thread = Arc::clone(&state);
thread::spawn(move || {
let _guard = guard; let peer = stream.peer_addr().ok();
if let Err(e) = handle(stream, &state_for_thread) {
eprintln!("spg-server: conn {peer:?}: {e}");
}
});
}
drain_connections(&state);
Ok(())
}
/// Spawn a thread that unblocks the accept loop once `SHUTDOWN_FLAG` is raised.
///
/// `listener.incoming()` blocks in `accept`. After the flag is set, the thread
/// opens one throwaway connection to the listener so the loop wakes, sees the
/// flag, and exits. A listener bound to the wildcard address reports
/// `0.0.0.0` / `[::]` as its local address; connecting there is not portable,
/// so the loopback address of the same family is used instead. A failed
/// wake-up connection is logged, because the accept loop would then only
/// notice shutdown on the next real client.
///
/// # Errors
/// Fails if the listener's local address cannot be read or the thread cannot be spawned.
fn spawn_shutdown_waker(listener: &TcpListener) -> std::io::Result<()> {
    let mut wake_addr = listener.local_addr()?;
    if wake_addr.ip().is_unspecified() {
        let loopback: std::net::IpAddr = if wake_addr.is_ipv4() {
            std::net::Ipv4Addr::LOCALHOST.into()
        } else {
            std::net::Ipv6Addr::LOCALHOST.into()
        };
        wake_addr.set_ip(loopback);
    }
    thread::Builder::new()
        .name("spg-shutdown-waker".into())
        .spawn(move || {
            while !SHUTDOWN_FLAG.load(Ordering::Acquire) {
                thread::sleep(SHUTDOWN_POLL);
            }
            if let Err(e) = TcpStream::connect(wake_addr) {
                eprintln!(
                    "spg-server: shutdown waker could not connect to {wake_addr}: {e} — \
                     accept loop will exit on the next incoming connection"
                );
            }
        })?;
    Ok(())
}
/// Wait for live connections to finish after shutdown was requested.
///
/// Polls `active_connections` every `SHUTDOWN_POLL` until it reaches zero or
/// the deadline (`SPG_SHUTDOWN_DEADLINE_SEC`, default
/// `DEFAULT_SHUTDOWN_DEADLINE_SEC`) expires. Either way it returns normally
/// and logs the outcome.
fn drain_connections(state: &ServerState) {
    let began = Instant::now();
    let deadline_sec = match state.limits.shutdown_deadline_sec {
        Some(secs) => secs,
        None => DEFAULT_SHUTDOWN_DEADLINE_SEC,
    };
    let deadline = Duration::from_secs(deadline_sec);
    eprintln!(
        "spg-server: shutdown signal received — draining {} connection(s), deadline {}s",
        state.active_connections.load(Ordering::Acquire),
        deadline_sec,
    );
    loop {
        match state.active_connections.load(Ordering::Acquire) {
            0 => {
                eprintln!("spg-server: drained — exiting 0");
                break;
            }
            active if began.elapsed() >= deadline => {
                eprintln!(
                    "spg-server: drain deadline hit with {active} connection(s) still active — exiting 0"
                );
                break;
            }
            _ => thread::sleep(SHUTDOWN_POLL),
        }
    }
}
/// Install SIGTERM and SIGINT handlers that only raise `SHUTDOWN_FLAG`.
///
/// The accept loop and `spawn_shutdown_waker` poll that flag to start a
/// graceful drain.
#[allow(unsafe_code)]
fn install_shutdown_handlers() {
// Runs in signal context: it performs nothing but a store to a static
// atomic — no allocation, locking or I/O.
extern "C" fn handler(_sig: libc::c_int) {
SHUTDOWN_FLAG.store(true, Ordering::Release);
}
// SAFETY: `handler` is a plain `extern "C" fn(c_int)` with the signature
// `signal` expects, lives for the whole program, and only touches a static
// atomic. NOTE(review): the `SIG_ERR` return values are ignored here.
unsafe {
libc::signal(libc::SIGTERM, handler as *const () as libc::sighandler_t);
libc::signal(libc::SIGINT, handler as *const () as libc::sighandler_t);
}
}
/// Holds one slot of `ServerState::active_connections`.
///
/// The slot is claimed by `ConnectionGuard::try_claim` and released when the
/// guard is dropped.
struct ConnectionGuard {
/// Shared state whose `active_connections` counter this guard decrements on drop.
state: Arc<ServerState>,
}
impl ConnectionGuard {
fn try_claim(state: &Arc<ServerState>) -> Option<Self> {
let max = state.limits.max_connections;
loop {
let current = state.active_connections.load(Ordering::Acquire);
if let Some(cap) = max
&& current >= cap
{
return None;
}
if state
.active_connections
.compare_exchange_weak(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return Some(Self {
state: Arc::clone(state),
});
}
}
}
}
impl Drop for ConnectionGuard {
    /// Release the slot claimed by `ConnectionGuard::try_claim`.
    fn drop(&mut self) {
        let counter = &self.state.active_connections;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
fn handle(mut stream: TcpStream, state: &Arc<ServerState>) -> std::io::Result<()> {
let _ = stream.set_nodelay(true);
if let Some(secs) = state.limits.idle_timeout_sec {
let _ = stream.set_read_timeout(Some(std::time::Duration::from_secs(secs)));
}
let mut buf: Vec<u8> = Vec::with_capacity(READ_CHUNK);
let mut chunk = [0u8; READ_CHUNK];
let mut role = initial_role(state)?;
let mut in_tx = false;
loop {
let n = match stream.read(&mut chunk) {
Ok(n) => n,
Err(e)
if matches!(
e.kind(),
std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
) =>
{
let _ = write_frame(
&mut stream,
&build_error_response("idle timeout reached, closing connection"),
);
return Ok(());
}
Err(e) => return Err(e),
};
if n == 0 {
return Ok(());
}
buf.extend_from_slice(&chunk[..n]);
loop {
match decode(&buf) {
Ok((frame, consumed)) => {
buf.drain(..consumed);
dispatch(&mut stream, &frame, state, &mut role, &mut in_tx)?;
}
Err(FrameError::ShortHeader | FrameError::ShortPayload) => break,
Err(e) => {
let _ = write_frame(&mut stream, &build_error_response(&e.to_string()));
return Err(std::io::Error::other(e.to_string()));
}
}
}
}
}
/// Role a fresh connection starts with, before any AUTH.
///
/// Returns `Some(Role::Admin)` only when the server is open: no RBAC users
/// exist and no `SPG_PASSWORD` is configured. Otherwise returns `None`, and
/// the client must authenticate first.
///
/// # Errors
/// Fails if the engine lock is poisoned.
fn initial_role(state: &ServerState) -> std::io::Result<Option<Role>> {
    let no_users = {
        let engine = match state.engine.read() {
            Ok(guard) => guard,
            Err(_) => return Err(std::io::Error::other("engine rwlock poisoned")),
        };
        engine.users().is_empty()
    };
    let open_access = no_users && state.password.is_none();
    Ok(if open_access { Some(Role::Admin) } else { None })
}
/// Unwrap the connection's role.
///
/// # Panics
/// Panics if `role` is `None`. Callers must only use this after `dispatch`
/// has rejected unauthenticated requests.
fn current_role(role: Option<Role>) -> Role {
    let Some(active) = role else {
        panic!("dispatch already gated on role.is_some()")
    };
    active
}
/// True when `sql` looks like `CREATE USER …` or `DROP USER …` (case-insensitive).
///
/// The verb and object may be separated by any whitespace. The old check
/// required a literal space after the verb, so `CREATE\nUSER bob` or
/// `DROP\tUSER x` slipped past this user-management gate. As before, any
/// object word that starts with `user` matches. That deliberately errs on the
/// side of treating a statement as user management.
fn sql_is_user_mgmt(sql: &str) -> bool {
    let mut words = sql.split_whitespace();
    let verb = words.next().unwrap_or("");
    if !(verb.eq_ignore_ascii_case("create") || verb.eq_ignore_ascii_case("drop")) {
        return false;
    }
    words
        .next()
        .and_then(|object| object.get(..4))
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("user"))
}
/// True when the first word of `sql` is a transaction-control verb.
///
/// The verbs are BEGIN, START, COMMIT, ROLLBACK, SAVEPOINT, RELEASE and END,
/// compared case-insensitively. The first word ends at whitespace or `;`, and
/// leading whitespace is skipped.
fn sql_is_tx_control(sql: &str) -> bool {
    const TX_VERBS: [&str; 7] = [
        "begin", "start", "commit", "rollback", "savepoint", "release", "end",
    ];
    let head = sql
        .trim_start()
        .split(|c: char| c.is_whitespace() || c == ';')
        .next()
        .unwrap_or("");
    TX_VERBS.iter().any(|verb| head.eq_ignore_ascii_case(verb))
}
/// True when the leading keyword of `sql` is `SELECT` or `SHOW` (case-insensitive).
///
/// Only spaces, tabs, `\n` and `\r` are skipped before the keyword. The
/// keyword is the run of ASCII letters that follows, so `select1` still
/// counts but `selects` does not.
fn sql_is_read_only(sql: &str) -> bool {
    let rest = sql.trim_start_matches([' ', '\t', '\n', '\r']);
    let word_len = rest.bytes().take_while(u8::is_ascii_alphabetic).count();
    let keyword = &rest[..word_len];
    keyword.eq_ignore_ascii_case("select") || keyword.eq_ignore_ascii_case("show")
}
/// `ServerState::wal_level` value for the default "replica" level.
pub(crate) const WAL_LEVEL_REPLICA: u8 = 0;
/// `ServerState::wal_level` value for the "logical" level.
///
/// Selected by `SPG_WAL_LEVEL=logical` or `SET effective_wal_level = logical`.
pub(crate) const WAL_LEVEL_LOGICAL: u8 = 1;
/// Initial WAL level from `SPG_WAL_LEVEL`.
///
/// The value is case-insensitive and surrounding whitespace is ignored,
/// consistent with the other env parsers. Unset, empty and `replica` give
/// `WAL_LEVEL_REPLICA`; `logical` gives `WAL_LEVEL_LOGICAL`. Anything else
/// warns on stderr and falls back to replica.
fn parse_wal_level_env() -> u8 {
    let normalized = std::env::var("SPG_WAL_LEVEL")
        .ok()
        .map(|s| s.trim().to_ascii_lowercase());
    match normalized.as_deref() {
        None | Some("" | "replica") => WAL_LEVEL_REPLICA,
        Some("logical") => WAL_LEVEL_LOGICAL,
        Some(other) => {
            eprintln!(
                "spg-server: SPG_WAL_LEVEL={other:?} unknown — defaulting to replica. \
                 Valid values: replica, logical"
            );
            WAL_LEVEL_REPLICA
        }
    }
}
pub(crate) fn wal_level_label(v: u8) -> &'static str {
match v {
WAL_LEVEL_LOGICAL => "logical",
_ => "replica",
}
}
/// True when `sql` (leading whitespace ignored) starts with `set effective_wal_level`, case-insensitively.
fn sql_looks_like_set_wal_level(sql: &str) -> bool {
    const PREFIX: &str = "set effective_wal_level";
    sql.trim_start()
        .get(..PREFIX.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(PREFIX))
}
/// True when `sql` is `SHOW effective_wal_level`, case-insensitively.
///
/// Leading whitespace is skipped. The keyword must be followed by end of
/// input, any whitespace, or `;`. The old check accepted only end of input,
/// a space, or `;`, so a trailing newline or tab routed the statement to the
/// engine instead.
fn sql_looks_like_show_wal_level(sql: &str) -> bool {
    const KEYWORD: &str = "show effective_wal_level";
    let trimmed = sql.trim_start();
    let Some(head) = trimmed.get(..KEYWORD.len()) else {
        return false;
    };
    head.eq_ignore_ascii_case(KEYWORD)
        && trimmed[KEYWORD.len()..]
            .chars()
            .next()
            .is_none_or(|c| c.is_whitespace() || c == ';')
}
/// Parse `SET effective_wal_level {=|TO} <value>` into a `WAL_LEVEL_*` code.
///
/// Matching is case-insensitive.
/// - `TO` may be followed by any whitespace; previously only a literal space
///   was accepted.
/// - Trailing `;` and whitespace are stripped before any quotes, so
///   `= 'logical' ;` now parses; previously it left a stray quote and was
///   rejected.
/// - The value may be single- or double-quoted.
///
/// # Errors
/// Returns a user-facing message when the prefix, the `=`/`TO` separator, or
/// the value is not recognised.
fn parse_set_wal_level_value(sql: &str) -> Result<u8, String> {
    let lower = sql.trim().to_ascii_lowercase();
    let rest = lower
        .strip_prefix("set effective_wal_level")
        .ok_or_else(|| "expected `set effective_wal_level …`".to_string())?
        .trim_start();
    let val_part = if let Some(r) = rest.strip_prefix('=') {
        r
    } else if let Some(r) = rest
        .strip_prefix("to")
        .filter(|r| r.starts_with(char::is_whitespace))
    {
        r
    } else {
        return Err("expected `=` or `TO` after effective_wal_level".to_string());
    };
    let value = val_part
        .trim()
        .trim_end_matches(|c: char| c == ';' || c.is_whitespace())
        .trim_matches(|c: char| matches!(c, '\'' | '"' | ';'))
        .trim();
    match value {
        "replica" => Ok(WAL_LEVEL_REPLICA),
        "logical" => Ok(WAL_LEVEL_LOGICAL),
        other => Err(format!(
            "unknown effective_wal_level {other:?}; expected `replica` or `logical`"
        )),
    }
}
/// Execute `SET effective_wal_level …` entirely in the server.
///
/// On success it stores the new level in `state.wal_level` and replies
/// `CommandOk { affected: 1 }`. A parse failure is sent back as an error
/// frame and the level is left unchanged.
fn handle_set_wal_level(
    stream: &mut TcpStream,
    state: &Arc<ServerState>,
    sql: &str,
) -> std::io::Result<()> {
    let level = match parse_set_wal_level_value(sql) {
        Ok(level) => level,
        Err(msg) => return write_frame(stream, &build_error_response(&msg)),
    };
    state.wal_level.store(level, Ordering::Release);
    let ack = spg_engine::QueryResult::CommandOk {
        affected: 1,
        modified_catalog: false,
    };
    emit_result(stream, Ok(ack))
}
/// Answer `SHOW effective_wal_level` with a one-row, one-column text result.
fn handle_show_wal_level(
    stream: &mut TcpStream,
    state: &Arc<ServerState>,
) -> std::io::Result<()> {
    let label = wal_level_label(state.wal_level.load(Ordering::Acquire));
    let columns = vec![ColumnSchema::new(
        "effective_wal_level",
        DataType::Text,
        false,
    )];
    let rows = vec![Row::new(vec![Value::Text(label.to_string())])];
    emit_result(stream, Ok(spg_engine::QueryResult::Rows { columns, rows }))
}
/// True when `sql` (leading whitespace ignored) starts with the word `WAIT`.
///
/// The match is case-insensitive and `WAIT` must be followed by an ASCII
/// whitespace byte.
fn sql_looks_like_wait_for(sql: &str) -> bool {
    let bytes = sql.trim_start().as_bytes();
    let keyword_ok = bytes
        .get(..4)
        .is_some_and(|head| head.eq_ignore_ascii_case(b"WAIT"));
    keyword_ok && bytes.get(4).is_some_and(u8::is_ascii_whitespace)
}
fn handle_wait_for_wal_position(
stream: &mut TcpStream,
state: &Arc<ServerState>,
target: u64,
timeout_ms: Option<u64>,
) -> std::io::Result<()> {
const POLL: std::time::Duration = std::time::Duration::from_millis(5);
let deadline = timeout_ms.map(|ms| std::time::Instant::now() + std::time::Duration::from_millis(ms));
loop {
let current = state
.lag_state
.follower_applied_pos
.load(Ordering::Acquire);
if current >= target {
return emit_result(
stream,
Ok(spg_engine::QueryResult::CommandOk {
affected: 1,
modified_catalog: false,
}),
);
}
if let Some(d) = deadline
&& std::time::Instant::now() >= d
{
return emit_result(
stream,
Ok(spg_engine::QueryResult::CommandOk {
affected: 0,
modified_catalog: false,
}),
);
}
std::thread::sleep(POLL);
}
}
/// Handles one decoded client frame on an established connection.
///
/// Before authentication (`role` is `None`), only `Ping`, `Auth` and `AuthUser`
/// are accepted. `Op::Query` is routed through these stages, in order:
/// 1. the server-side `WAIT FOR` and `effective_wal_level` statements;
/// 2. a read-only fast path under the shared engine lock (outside transactions);
/// 3. role checks;
/// 4. the admin-only BACKUP / CHECKPOINT / COMPACT COLD SEGMENTS commands;
/// 5. a WAL-quota preflight;
/// 6. execution, either through the group-commit queue (auto-commit statements
///    when a WAL is configured) or directly under the engine write lock.
///
/// Afterwards, `in_tx` reflects the engine's transaction state.
///
/// # Errors
/// Returns these errors:
/// - socket write errors;
/// - lock-poisoning errors;
/// - durability failures (snapshot write, WAL append, audit append).
///
/// For durability failures, an error frame is sent to the client first.
#[allow(clippy::too_many_lines)]
fn dispatch(
    stream: &mut TcpStream,
    frame: &Frame,
    state: &Arc<ServerState>,
    role: &mut Option<Role>,
    in_tx: &mut bool,
) -> std::io::Result<()> {
    // Only liveness probes and the two login opcodes are allowed before login.
    if role.is_none() && !matches!(frame.op, Op::Ping | Op::Auth | Op::AuthUser) {
        return write_frame(
            stream,
            &build_error_response("authentication required: send AUTH first"),
        );
    }
    match frame.op {
        Op::Ping => write_frame(stream, &Frame::pong()),
        Op::Auth => {
            let candidate = match parse_auth(frame) {
                Ok(s) => s,
                Err(e) => return write_frame(stream, &build_error_response(&e.to_string())),
            };
            // Shared-password AUTH is only valid while no RBAC users exist.
            let users_exist = state.engine.read().is_ok_and(|e| !e.users().is_empty());
            if users_exist {
                return write_frame(
                    stream,
                    &build_error_response("RBAC active: use AUTH USER <name> <password>"),
                );
            }
            // Without a configured password, every AUTH succeeds as admin.
            let ok = match &state.password {
                Some(pw) => candidate == pw,
                None => true,
            };
            if ok {
                *role = Some(Role::Admin);
                write_frame(stream, &Frame::pong())
            } else {
                write_frame(stream, &build_error_response("AUTH: wrong password"))
            }
        }
        Op::AuthUser => {
            let (user, pw) = match parse_auth_user(frame) {
                Ok(t) => t,
                Err(e) => return write_frame(stream, &build_error_response(&e.to_string())),
            };
            let verified = state
                .engine
                .read()
                .map_err(|_| std::io::Error::other("engine rwlock poisoned"))?
                .verify_user(user, pw);
            match verified {
                Some(r) => {
                    *role = Some(r);
                    write_frame(stream, &Frame::pong())
                }
                None => write_frame(stream, &build_error_response("AUTH: invalid credentials")),
            }
        }
        Op::Stats => {
            let body = render_stats(state)?;
            write_frame(stream, &build_stats_response(&body))
        }
        Op::Query => {
            state.metrics.queries_total.fetch_add(1, Ordering::Relaxed);
            try_lazy_preload_cold(state);
            let sql = match parse_query(frame) {
                Ok(s) => s.to_string(),
                Err(e) => {
                    state.metrics.errors_total.fetch_add(1, Ordering::Relaxed);
                    return write_frame(stream, &build_error_response(&e.to_string()));
                }
            };
            let _slow_log = SlowLogGuard::new(state, &sql, *role);
            // Replication side-channel statements are answered by the server, not the engine.
            if sql_looks_like_wait_for(&sql)
                && let Ok(stmt) = spg_sql::parser::parse_statement(&sql)
                && let spg_sql::ast::Statement::WaitForWalPosition { pos, timeout_ms } = stmt
            {
                return handle_wait_for_wal_position(stream, state, pos, timeout_ms);
            }
            if sql_looks_like_show_wal_level(&sql) {
                return handle_show_wal_level(stream, state);
            }
            if sql_looks_like_set_wal_level(&sql) {
                return handle_set_wal_level(stream, state, &sql);
            }
            // Per-query allocation budget, shared by both execution paths below.
            let budget = usize::try_from(
                state
                    .limits
                    .max_query_bytes
                    .unwrap_or(DEFAULT_MAX_QUERY_BYTES),
            )
            .unwrap_or(usize::MAX);
            // Read-only fast path under the shared lock. If the engine reports
            // `WriteRequired`, we fall through to the write path.
            if !*in_tx && sql_is_read_only(&sql) {
                let cancel_flag = Arc::new(AtomicBool::new(false));
                let watchdog = spawn_query_watchdog(state, &cancel_flag);
                let engine = state
                    .engine
                    .read()
                    .map_err(|_| std::io::Error::other("engine rwlock poisoned"))?;
                alloc_budget::reset_query_budget(budget, &cancel_flag);
                let result = engine.execute_readonly_with_cancel(
                    &sql,
                    spg_engine::CancelToken::from_flag(&cancel_flag),
                );
                alloc_budget::clear_query_budget();
                drop(engine);
                watchdog.cancel();
                if !matches!(&result, Err(EngineError::WriteRequired)) {
                    return emit_result(stream, result);
                }
            }
            let acting = current_role(*role);
            if sql_is_user_mgmt(&sql) {
                if !acting.can_manage_users() {
                    return write_frame(
                        stream,
                        &build_error_response(
                            "permission denied: user management requires admin role",
                        ),
                    );
                }
            } else if !acting.can_write() {
                return write_frame(
                    stream,
                    &build_error_response(
                        "permission denied: write requires admin or readwrite role",
                    ),
                );
            }
            if let Some(backup_intent) = parse_backup_intent(&sql) {
                if !acting.can_manage_users() {
                    return write_frame(
                        stream,
                        &build_error_response("permission denied: BACKUP requires admin role"),
                    );
                }
                return run_backup_command(stream, state, &backup_intent);
            }
            if parse_checkpoint_intent(&sql) {
                if !acting.can_manage_users() {
                    return write_frame(
                        stream,
                        &build_error_response("permission denied: CHECKPOINT requires admin role"),
                    );
                }
                return run_checkpoint_command(stream, state);
            }
            if parse_compact_cold_segments_intent(&sql) {
                if !acting.can_manage_users() {
                    return write_frame(
                        stream,
                        &build_error_response(
                            "permission denied: COMPACT COLD SEGMENTS requires admin role",
                        ),
                    );
                }
                return run_compact_cold_segments_command(stream, state);
            }
            // Auto-commit statements go through the group-commit queue when a WAL is configured.
            let needs_wrap = !*in_tx && state.wal.is_some() && !sql_is_tx_control(&sql);
            // Quota preflight: reject before the engine mutates anything, so a
            // quota failure cannot leave engine state ahead of the WAL.
            if let Some(quota) = state.chaos.wal_quota_bytes
                && let Some(wal_path) = &state.wal_path
                && !state.chaos.disable_wal_preflight
            {
                let cur = fs::metadata(wal_path).map_or(0, |m| m.len());
                let needed = if needs_wrap {
                    wal_v3_auto_commit_size(&sql)
                } else {
                    // `append_wal` writes a v2 frame: 4-byte length + 4-byte CRC + SQL.
                    8 + sql.len() as u64
                };
                if cur.saturating_add(needed) > quota {
                    return write_frame(
                        stream,
                        &build_error_response(&format!(
                            "wal quota exceeded: cur={cur} + {needed} > quota={quota} (SPG_FAIL_WAL_QUOTA_BYTES)"
                        )),
                    );
                }
            }
            let cancel_flag = Arc::new(AtomicBool::new(false));
            let watchdog = spawn_query_watchdog(state, &cancel_flag);
            let (result, wal_result, snapshot) = if needs_wrap {
                let (ack_tx, ack_rx) = mpsc::sync_channel::<CommitResult>(1);
                let task = CommitTask {
                    sql: sql.clone(),
                    cancel_flag: Arc::clone(&cancel_flag),
                    ack: ack_tx,
                };
                let became_leader = enqueue_commit_task(state, task);
                if became_leader {
                    run_leader_commit_round(state);
                }
                let CommitResult {
                    result,
                    wal_outcome,
                } = ack_rx.recv().map_err(|_| {
                    std::io::Error::other(
                        "commit barrier: ack channel closed before result arrived",
                    )
                })?;
                *in_tx = state.engine.read().is_ok_and(|e| e.in_transaction());
                (result, wal_outcome, None)
            } else {
                let mut engine = state
                    .engine
                    .write()
                    .map_err(|_| std::io::Error::other("engine rwlock poisoned"))?;
                alloc_budget::reset_query_budget(budget, &cancel_flag);
                let result = engine
                    .execute_with_cancel(&sql, spg_engine::CancelToken::from_flag(&cancel_flag));
                alloc_budget::clear_query_budget();
                let was_command_ok = matches!(result, Ok(QueryResult::CommandOk { .. }));
                let wal_result = if was_command_ok && state.wal.is_some() {
                    append_wal(state, &sql)
                } else {
                    Ok(())
                };
                *in_tx = engine.in_transaction();
                // Without a WAL, catalog-changing statements are persisted by
                // rewriting the snapshot.
                let snapshot = if state.db_path.is_some() && state.wal.is_none() {
                    match &result {
                        Ok(QueryResult::CommandOk {
                            modified_catalog: true,
                            ..
                        }) => Some(engine.snapshot()),
                        _ => None,
                    }
                } else {
                    None
                };
                drop(engine);
                (result, wal_result, snapshot)
            };
            watchdog.cancel();
            if let (Some(bytes), Some(path)) = (snapshot.as_ref(), state.db_path.as_deref()) {
                if let Err(e) = write_atomic(path, bytes) {
                    let _ = write_frame(
                        stream,
                        &build_error_response(&format!("snapshot write failed: {e}")),
                    );
                    return Err(e);
                }
                let paths_snapshot = state
                    .cold_segment_paths
                    .lock()
                    .map(|g| g.clone())
                    .unwrap_or_default();
                let wal_len = state
                    .wal_path
                    .as_deref()
                    .and_then(|p| fs::metadata(p).ok())
                    .map_or(0, |m| m.len());
                write_manifest_alongside(path, bytes, &paths_snapshot, wal_len);
            }
            if let Err(e) = wal_result {
                let _ = write_frame(
                    stream,
                    &build_error_response(&format!("WAL append failed: {e}")),
                );
                return Err(e);
            }
            let modified_catalog = matches!(
                result,
                Ok(QueryResult::CommandOk {
                    modified_catalog: true,
                    ..
                })
            );
            if state.audit_path.is_some()
                && modified_catalog
                && let Err(e) = append_audit(state, &sql)
            {
                let _ = write_frame(
                    stream,
                    &build_error_response(&format!("audit append failed: {e}")),
                );
                return Err(e);
            }
            if modified_catalog {
                reconcile_subscriptions(state);
            }
            emit_result(stream, result)
        }
        Op::Pong
        | Op::RowDescription
        | Op::DataRow
        | Op::DataRowBatch
        | Op::CommandComplete
        | Op::ErrorResponse
        | Op::StatsResponse => write_frame(
            stream,
            &Frame::error("client → server opcode not accepted on this side"),
        ),
        Op::Error => write_frame(
            stream,
            &Frame::error("clients should not send Error frames"),
        ),
    }
}
fn render_stats(state: &ServerState) -> std::io::Result<String> {
use std::fmt::Write as _;
let engine = state
.engine
.read()
.map_err(|_| std::io::Error::other("engine rwlock poisoned"))?;
let audit = state
.audit_log
.lock()
.map_err(|_| std::io::Error::other("audit mutex poisoned"))?;
let catalog = engine.catalog();
let mut out = String::new();
writeln!(out, "spg_version={}", env!("CARGO_PKG_VERSION")).unwrap();
writeln!(out, "tables={}", catalog.table_count()).unwrap();
for i in 0..catalog.table_count() {
let _ = i; }
writeln!(out, "in_transaction={}", engine.in_transaction()).unwrap();
writeln!(out, "audit_entries={}", audit.len()).unwrap();
writeln!(
out,
"db_path={}",
state
.db_path
.as_deref()
.map_or("<in-memory>".to_string(), |p| p.display().to_string())
)
.unwrap();
writeln!(
out,
"audit_path={}",
state
.audit_path
.as_deref()
.map_or("<disabled>".to_string(), |p| p.display().to_string())
)
.unwrap();
writeln!(
out,
"wal_path={}",
state
.wal_path
.as_deref()
.map_or("<disabled>".to_string(), |p| p.display().to_string())
)
.unwrap();
Ok(out)
}
/// Recognises `BACKUP TO '<path>' [INCREMENTAL SINCE <n>]`.
///
/// Keywords are ASCII case-insensitive and may be separated by any whitespace.
/// The path keeps its original case and cannot contain a single quote.
/// Trailing `;` characters are ignored.
/// Returns `None` for anything else; such statements fall through to the engine.
fn parse_backup_intent(sql: &str) -> Option<BackupIntent> {
    /// Consumes `kw` only when it is followed by at least one whitespace
    /// character. Returns the remainder with leading whitespace removed.
    fn keyword<'a>(s: &'a str, kw: &str) -> Option<&'a str> {
        let rest = s.strip_prefix(kw)?;
        rest.starts_with(char::is_whitespace)
            .then(|| rest.trim_start())
    }

    let trimmed = sql.trim().trim_end_matches(';').trim();
    let lower = trimmed.to_ascii_lowercase();
    let after_prefix = keyword(keyword(&lower, "backup")?, "to")?;
    // ASCII lowercasing preserves byte offsets. So an offset into `lower` is
    // also a valid offset into `trimmed`, which keeps the path's original case.
    let prefix_consumed = lower.len() - after_prefix.len();
    let after_open = trimmed[prefix_consumed..].strip_prefix('\'')?;
    let close = after_open.find('\'')?;
    let path = after_open[..close].to_string();
    let tail = after_open[close + 1..].trim().to_ascii_lowercase();
    if tail.is_empty() {
        return Some(BackupIntent::Full { path });
    }
    let since_str = keyword(keyword(&tail, "incremental")?, "since")?;
    let since: u64 = since_str.parse().ok()?;
    Some(BackupIntent::Incremental { path, since })
}
/// A parsed `BACKUP TO …` request.
#[derive(Debug, Clone, PartialEq, Eq)]
enum BackupIntent {
    /// `BACKUP TO '<path>'`: a full backup written to `path`.
    Full { path: String },
    /// `BACKUP TO '<path>' INCREMENTAL SINCE <since>`.
    /// `since` is passed unchanged to `backup::take_incremental_backup`.
    Incremental { path: String, since: u64 },
}
/// Returns `true` when the statement is exactly `CHECKPOINT`.
///
/// The match is ASCII case-insensitive and ignores surrounding whitespace and
/// trailing semicolons.
fn parse_checkpoint_intent(sql: &str) -> bool {
    const KEYWORD: &str = "checkpoint";
    let mut body = sql.trim();
    while let Some(shorter) = body.strip_suffix(';') {
        body = shorter;
    }
    KEYWORD.eq_ignore_ascii_case(body.trim())
}
/// Executes `CHECKPOINT`. The steps are:
/// 1. write an engine snapshot to `db_path`;
/// 2. write the manifest next to it, recording a WAL length of 0;
/// 3. truncate and sync the WAL (when one is configured).
///
/// The command is refused when no `db_path` is configured or a transaction is
/// open. A failure in any persistence step is reported to the client as an
/// error frame.
///
/// # Errors
/// Returns socket write errors, and poisoning of the engine or WAL lock.
fn run_checkpoint_command(stream: &mut TcpStream, state: &ServerState) -> std::io::Result<()> {
    let Some(db_path) = state.db_path.as_deref() else {
        return write_frame(
            stream,
            &build_error_response("CHECKPOINT requires a db_path (server started without one)"),
        );
    };
    // Read the cold-segment map before taking the engine lock, so that lock is
    // never held while waiting on `cold_segment_paths`.
    let cold_paths = state
        .cold_segment_paths
        .lock()
        .map(|g| g.clone())
        .unwrap_or_default();
    let outcome: Result<(), String> = 'checkpoint: {
        // Hold the engine write lock until the WAL has been truncated. If it
        // were released after taking the snapshot, a direct-path writer could
        // execute and append to the WAL in between. The truncate below would
        // then discard a committed statement the snapshot does not contain.
        // NOTE(review): the group-commit leader applies statements under the
        // engine lock but appends them to the WAL only after releasing it
        // (`run_leader_commit_round`). A record from that window can still be
        // written after the truncate. Closing it needs coordination with the
        // commit queue.
        let engine = state
            .engine
            .write()
            .map_err(|_| std::io::Error::other("engine rwlock poisoned"))?;
        if engine.in_transaction() {
            break 'checkpoint Err(
                "CHECKPOINT refused: an open transaction is in flight".to_string(),
            );
        }
        let snapshot_bytes = engine.snapshot();
        if let Err(e) = write_atomic(db_path, &snapshot_bytes) {
            break 'checkpoint Err(format!("CHECKPOINT snapshot write failed: {e}"));
        }
        write_manifest_alongside(db_path, &snapshot_bytes, &cold_paths, 0);
        if let Some(wal_mutex) = state.wal.as_ref() {
            let wal_lock = wal_mutex
                .lock()
                .map_err(|_| std::io::Error::other("WAL mutex poisoned"))?;
            if let Err(e) = wal_lock.set_len(0) {
                break 'checkpoint Err(format!("CHECKPOINT WAL truncate failed: {e}"));
            }
            if let Err(e) = wal_lock.sync_data() {
                break 'checkpoint Err(format!("CHECKPOINT WAL sync failed: {e}"));
            }
        }
        drop(engine);
        Ok(())
    };
    // All locks are released at this point; reply without holding them.
    match outcome {
        Ok(()) => write_frame(stream, &build_command_complete(0)),
        Err(msg) => write_frame(stream, &build_error_response(&msg)),
    }
}
/// Returns `true` for the statement `COMPACT COLD SEGMENTS`.
///
/// Words are matched ASCII case-insensitively and may be separated by any
/// whitespace. Trailing semicolons are ignored, and no extra words are allowed.
fn parse_compact_cold_segments_intent(sql: &str) -> bool {
    const EXPECTED: [&str; 3] = ["compact", "cold", "segments"];
    let stmt = sql.trim().trim_end_matches(';').trim();
    let words: Vec<&str> = stmt.split_whitespace().collect();
    words.len() == EXPECTED.len()
        && words
            .iter()
            .zip(EXPECTED)
            .all(|(word, kw)| word.eq_ignore_ascii_case(kw))
}
/// Target size in bytes for merged cold segments.
///
/// Read once from `SPG_COMPACTION_TARGET_SEGMENT_BYTES`, falling back to the
/// engine's default when it is unset or unparsable. The value is then cached
/// for the life of the process.
fn compaction_target_bytes() -> u64 {
    static TARGET: std::sync::OnceLock<u64> = std::sync::OnceLock::new();
    let resolved = TARGET.get_or_init(|| {
        let from_env = parse_env_u64("SPG_COMPACTION_TARGET_SEGMENT_BYTES");
        from_env.unwrap_or(spg_engine::COMPACTION_TARGET_DEFAULT_BYTES)
    });
    *resolved
}
fn run_compact_cold_segments_command(
stream: &mut TcpStream,
state: &ServerState,
) -> std::io::Result<()> {
let target = compaction_target_bytes();
let reports = {
let mut engine = state
.engine
.write()
.map_err(|_| std::io::Error::other("engine rwlock poisoned"))?;
if engine.in_transaction() {
return write_frame(
stream,
&build_error_response(
"COMPACT COLD SEGMENTS refused: an open transaction is in flight",
),
);
}
match engine.compact_cold_segments_with_target(target) {
Ok(r) => r,
Err(e) => {
return write_frame(
stream,
&build_error_response(&format!("COMPACT COLD SEGMENTS failed: {e:?}")),
);
}
}
};
let merged_count = reports.len();
if let Some(db_path) = state.db_path.as_deref() {
for (_tname, _iname, report) in &reports {
let Some(merged_id) = report.merged_segment_id else {
continue;
};
match persist_compact_merged_segment(db_path, merged_id, &report.merged_segment_bytes)
{
Ok(merged_path) => {
if let Ok(mut paths) = state.cold_segment_paths.lock() {
for src in &report.sources {
paths.remove(src);
}
paths.insert(merged_id, merged_path);
}
}
Err(e) => {
eprintln!(
"spg-server: COMPACT persist of merged segment {merged_id} failed: {e}"
);
}
}
}
state.metrics.cold_segments.store(
state
.engine
.read()
.ok()
.map(|e| e.catalog().cold_segment_count() as u64)
.unwrap_or(0),
std::sync::atomic::Ordering::Relaxed,
);
}
write_frame(stream, &build_command_complete(merged_count as u64))
}
fn persist_compact_merged_segment(
db_path: &Path,
merged_id: u32,
merged_segment_bytes: &[u8],
) -> std::io::Result<PathBuf> {
let parent = db_path.parent().unwrap_or_else(|| Path::new("."));
let stem = db_path
.file_stem()
.unwrap_or_else(|| std::ffi::OsStr::new("db"))
.to_string_lossy();
let seg_dir = parent.join(format!("{stem}.spg")).join("segments");
fs::create_dir_all(&seg_dir)?;
let final_path = seg_dir.join(format!("seg_{merged_id}.spg"));
let tmp_path = seg_dir.join(format!("seg_{merged_id}.spg.tmp"));
let bytes_to_write = if std::env::var("SPG_SEGMENT_COMPRESSION")
.map_or(true, |v| !v.eq_ignore_ascii_case("none"))
{
spg_storage::wrap_v2_envelope(merged_segment_bytes.to_vec(), true)
} else {
merged_segment_bytes.to_vec()
};
fs::write(&tmp_path, &bytes_to_write)?;
fs::rename(&tmp_path, &final_path)?;
Ok(final_path)
}
/// Runs a parsed `BACKUP` request and reports the outcome to the client.
///
/// On success, the WAL position returned by the backup routine is sent in a
/// `CommandComplete` frame. On failure, an error frame prefixed with
/// `backup failed:` is sent.
fn run_backup_command(
    stream: &mut TcpStream,
    state: &ServerState,
    intent: &BackupIntent,
) -> std::io::Result<()> {
    let outcome = match intent {
        BackupIntent::Incremental { path, since } => {
            backup::take_incremental_backup(state, Path::new(path), *since)
        }
        BackupIntent::Full { path } => backup::take_full_backup(state, Path::new(path)),
    };
    match outcome {
        Err(failure) => write_frame(
            stream,
            &build_error_response(&format!("backup failed: {failure}")),
        ),
        Ok(position) => write_frame(stream, &build_command_complete(position)),
    }
}
/// Bit 31 of a WAL frame's length word. When set, the frame carries a CRC
/// after the length word (v2 and later); the low bits hold the payload length.
pub(crate) const WAL_V2_SENTINEL: u32 = 1 << 31;
/// Bit 30 of the length word. Together with [`WAL_V2_SENTINEL`], it marks a
/// v3 frame, which adds a one-byte record type after the CRC.
pub(crate) const WAL_V3_FLAG: u32 = 1 << 30;
/// Both marker bits: the high bits of every v3 frame's length word.
pub(crate) const WAL_V3_SENTINEL: u32 = WAL_V3_FLAG | WAL_V2_SENTINEL;
/// Whether WAL appends may skip `fsync` (`SPG_SYNCHRONOUS_COMMIT`).
///
/// The setting is disabled when the variable, trimmed and lowercased, is
/// `off`, `false` or `0`. An unset or non-UTF-8 value keeps synchronous commit
/// on. The answer is cached after the first call.
pub(crate) fn synchronous_commit_disabled() -> bool {
    static DISABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *DISABLED.get_or_init(|| {
        let Ok(raw) = std::env::var("SPG_SYNCHRONOUS_COMMIT") else {
            return false;
        };
        let normalized = raw.trim().to_lowercase();
        ["off", "false", "0"].contains(&normalized.as_str())
    })
}
/// v3 record type: one auto-committed SQL statement, stored as raw UTF-8.
pub(crate) const WAL_V3_TYPE_AUTO_COMMIT_SQL: u8 = 1;
/// v3 record type: durability marker whose payload is the marker's own byte
/// offset, encoded as an 8-byte little-endian `u64`.
pub(crate) const WAL_V3_TYPE_DURABILITY_CHECKPOINT: u8 = 2;
/// v3 record type: compressed SQL. The payload is an algorithm byte followed
/// by the compressed statement.
pub(crate) const WAL_V3_TYPE_COMPRESSED_SQL: u8 = 3;
/// Algorithm byte for LZSS-compressed SQL payloads.
pub(crate) const WAL_COMPRESS_ALGO_LZSS: u8 = 1;
/// Default minimum SQL length, in bytes, before compression is attempted.
pub(crate) const WAL_COMPRESS_MIN_BYTES: usize = 256;
fn encode_wal_record(sql: &str) -> std::io::Result<Vec<u8>> {
let len = u32::try_from(sql.len())
.map_err(|_| std::io::Error::other("SQL too large for WAL entry"))?;
if len & WAL_V2_SENTINEL != 0 {
return Err(std::io::Error::other(
"SQL byte count would alias the v4.37 WAL framing sentinel (≥ 2 GiB)",
));
}
let crc = spg_crypto::crc32::crc32(sql.as_bytes());
let mut entry = Vec::with_capacity(8 + sql.len());
entry.extend_from_slice(&(len | WAL_V2_SENTINEL).to_le_bytes());
entry.extend_from_slice(&crc.to_le_bytes());
entry.extend_from_slice(sql.as_bytes());
Ok(entry)
}
/// Encodes a v3 WAL frame.
///
/// Layout: `[len | WAL_V3_SENTINEL : u32 LE][crc : u32 LE][type_tag : u8][payload]`.
/// The CRC covers `type_tag` followed by `payload`.
///
/// # Errors
/// Fails when the payload length does not fit in a `u32` or would set either
/// sentinel bit (1 GiB or more).
fn encode_wal_v3_record(type_tag: u8, payload: &[u8]) -> std::io::Result<Vec<u8>> {
    let len = u32::try_from(payload.len())
        .map_err(|_| std::io::Error::other("WAL v3 payload too large"))?;
    if len & (WAL_V2_SENTINEL | WAL_V3_FLAG) != 0 {
        return Err(std::io::Error::other(
            "WAL v3 payload size would alias the v4.41 sentinel bits (≥ 1 GiB)",
        ));
    }
    let mut entry = Vec::with_capacity(9 + payload.len());
    entry.extend_from_slice(&(len | WAL_V3_SENTINEL).to_le_bytes());
    // Placeholder for the CRC; patched once the covered bytes are in place.
    entry.extend_from_slice(&[0u8; 4]);
    entry.push(type_tag);
    entry.extend_from_slice(payload);
    // `[type_tag | payload]` is contiguous at offset 8, so hash it in place
    // rather than copying it into a scratch buffer first.
    let crc = spg_crypto::crc32::crc32(&entry[8..]);
    entry[4..8].copy_from_slice(&crc.to_le_bytes());
    Ok(entry)
}
pub(crate) fn encode_wal_auto_commit_sql_metrics(
sql: &str,
metrics: &observability::Metrics,
) -> std::io::Result<Vec<u8>> {
use std::sync::atomic::Ordering;
let raw_len = sql.len() as u64;
metrics
.wal_bytes_uncompressed_in
.fetch_add(raw_len, Ordering::Relaxed);
let threshold = wal_compression_min_bytes();
if !wal_compression_enabled() || sql.len() < threshold {
let out = encode_wal_v3_record(WAL_V3_TYPE_AUTO_COMMIT_SQL, sql.as_bytes())?;
metrics
.wal_bytes_compressed_out
.fetch_add(out.len() as u64, Ordering::Relaxed);
return Ok(out);
}
let compressed = spg_crypto::lzss::compress(sql.as_bytes());
if compressed.len() + 1 >= sql.len() {
let out = encode_wal_v3_record(WAL_V3_TYPE_AUTO_COMMIT_SQL, sql.as_bytes())?;
metrics
.wal_bytes_compressed_out
.fetch_add(out.len() as u64, Ordering::Relaxed);
return Ok(out);
}
let mut payload = Vec::with_capacity(1 + compressed.len());
payload.push(WAL_COMPRESS_ALGO_LZSS);
payload.extend_from_slice(&compressed);
let out = encode_wal_v3_record(WAL_V3_TYPE_COMPRESSED_SQL, &payload)?;
metrics
.wal_bytes_compressed_out
.fetch_add(out.len() as u64, Ordering::Relaxed);
Ok(out)
}
/// Encodes `sql` as a v3 WAL frame, compressing it when worthwhile.
///
/// LZSS compression is attempted only when WAL compression is enabled and the
/// SQL is at least `wal_compression_min_bytes()` long. The compressed frame is
/// used only if algorithm byte + compressed bytes is strictly shorter than the
/// raw SQL. Otherwise a plain auto-commit frame is emitted.
#[allow(dead_code)]
pub(crate) fn encode_wal_auto_commit_sql(sql: &str) -> std::io::Result<Vec<u8>> {
    let raw = sql.as_bytes();
    let min_len = wal_compression_min_bytes();
    let worth_trying = wal_compression_enabled() && raw.len() >= min_len;
    let shrunk = if worth_trying {
        Some(spg_crypto::lzss::compress(raw)).filter(|c| c.len() + 1 < raw.len())
    } else {
        None
    };
    match shrunk {
        Some(compressed) => {
            let mut payload = Vec::with_capacity(1 + compressed.len());
            payload.push(WAL_COMPRESS_ALGO_LZSS);
            payload.extend_from_slice(&compressed);
            encode_wal_v3_record(WAL_V3_TYPE_COMPRESSED_SQL, &payload)
        }
        None => encode_wal_v3_record(WAL_V3_TYPE_AUTO_COMMIT_SQL, raw),
    }
}
pub(crate) fn wal_compression_min_bytes() -> usize {
static CHECKED: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
*CHECKED.get_or_init(|| {
std::env::var("SPG_COMPRESSION_MIN_BYTES")
.ok()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(WAL_COMPRESS_MIN_BYTES)
})
}
/// Whether WAL payload compression is enabled.
///
/// Compression is on unless `SPG_WAL_COMPRESSION` equals `none` (ASCII
/// case-insensitive). The answer is cached after the first call.
pub(crate) fn wal_compression_enabled() -> bool {
    static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *ENABLED.get_or_init(|| match std::env::var("SPG_WAL_COMPRESSION") {
        Ok(setting) => !setting.eq_ignore_ascii_case("none"),
        Err(_) => true,
    })
}
/// Size in bytes of an uncompressed v3 auto-commit frame for `sql`.
///
/// Used as an upper bound by the quota preflight; compression only ever
/// produces smaller frames.
fn wal_v3_auto_commit_size(sql: &str) -> u64 {
    // length word + CRC + type byte
    const FRAME_OVERHEAD: u64 = 4 + 4 + 1;
    FRAME_OVERHEAD + sql.len() as u64
}
/// Encodes a v3 durability-checkpoint frame.
///
/// The payload is `byte_offset` as an 8-byte little-endian integer.
fn encode_durability_marker(byte_offset: u64) -> std::io::Result<Vec<u8>> {
    let offset_le: [u8; 8] = byte_offset.to_le_bytes();
    encode_wal_v3_record(WAL_V3_TYPE_DURABILITY_CHECKPOINT, &offset_le)
}
/// Appends a durability-checkpoint frame to the WAL and syncs it.
///
/// The marker records its own starting offset, i.e. the WAL length before the
/// append. That offset is returned, or `0` when no WAL is configured. The
/// chaos quota and the free-space water-mark are enforced under the WAL lock.
/// The sync runs after the lock is released when a separate sync handle
/// exists.
///
/// # Errors
/// Returns `StorageFull` for a quota or water-mark violation. Also returns
/// I/O errors from metadata, write or sync, and WAL-lock poisoning.
fn append_durability_marker(state: &ServerState) -> std::io::Result<u64> {
    let Some(wal_mutex) = state.wal.as_ref() else {
        return Ok(0);
    };
    let pre_marker_offset = {
        let mut wal = wal_mutex
            .lock()
            .map_err(|_| std::io::Error::other("wal mutex poisoned"))?;
        let pre_marker_offset = wal.metadata()?.len();
        let entry = encode_durability_marker(pre_marker_offset)?;
        if let Some(quota) = state.chaos.wal_quota_bytes
            && pre_marker_offset.saturating_add(entry.len() as u64) > quota
        {
            return Err(std::io::Error::new(
                std::io::ErrorKind::StorageFull,
                format!(
                    "wal quota exceeded by durability marker: cur={pre_marker_offset} + {} > quota={quota}",
                    entry.len()
                ),
            ));
        }
        if let Some(min_free) = state.limits.wal_min_free_bytes
            && let Some(wal_path) = state.wal_path.as_deref()
        {
            let free = wal_volume_free_bytes(wal_path)?;
            if free < min_free {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::StorageFull,
                    format!(
                        "WAL volume below water-mark for durability marker: free={free} < SPG_WAL_MIN_FREE_BYTES={min_free}"
                    ),
                ));
            }
        }
        if let Err(e) = wal.write_all(&entry) {
            // Drop any partially written marker bytes. Otherwise later appends
            // land behind a torn frame that replay cannot parse. Seeking back
            // too keeps a non-append file handle from leaving a hole.
            let rollback = wal.set_len(pre_marker_offset).and_then(|()| {
                std::io::Seek::seek(&mut *wal, std::io::SeekFrom::Start(pre_marker_offset))
            });
            if let Err(rb) = rollback {
                eprintln!(
                    "spg-server: WAL rollback to {pre_marker_offset} bytes after failed durability marker write failed: {rb}"
                );
            }
            return Err(e);
        }
        pre_marker_offset
    };
    if let Some(sync_handle) = state.wal_sync_clone.as_ref() {
        sync_handle.sync_data()?;
    } else {
        let wal = wal_mutex
            .lock()
            .map_err(|_| std::io::Error::other("wal mutex poisoned"))?;
        wal.sync_data()?;
    }
    Ok(pre_marker_offset)
}
/// Appends a group of pre-encoded v3 frames to the WAL as one write.
///
/// Steps:
/// 1. Enforce the chaos quota and the free-space water-mark.
/// 2. Write the concatenated frames.
/// 3. Sync, unless synchronous commit is disabled.
/// 4. Mirror the bytes to the optional tee file (best-effort, logged on failure).
///
/// If the write or the sync fails, the WAL is truncated back to its prior
/// length. The caller reverts the engine on error, so the group's records must
/// not survive to be replayed.
///
/// # Errors
/// Returns `StorageFull` for quota or water-mark violations. Also returns
/// I/O errors from metadata, write or sync, and WAL-lock poisoning.
fn append_wal_v3_group(state: &ServerState, entries: &[Vec<u8>]) -> std::io::Result<()> {
    let Some(wal) = state.wal.as_ref() else {
        return Ok(());
    };
    if entries.is_empty() {
        return Ok(());
    }
    let batched: Vec<u8> = entries.concat();
    let mut f = wal
        .lock()
        .map_err(|_| std::io::Error::other("wal mutex poisoned"))?;
    // Needed both for the quota check and for rolling back a failed append.
    let pre_len = f.metadata()?.len();
    if let Some(quota) = state.chaos.wal_quota_bytes
        && pre_len.saturating_add(batched.len() as u64) > quota
    {
        return Err(std::io::Error::new(
            std::io::ErrorKind::StorageFull,
            format!(
                "wal quota exceeded: cur={pre_len} + {} > quota={quota} (SPG_FAIL_WAL_QUOTA_BYTES)",
                batched.len()
            ),
        ));
    }
    if let Some(min_free) = state.limits.wal_min_free_bytes
        && let Some(wal_path) = state.wal_path.as_deref()
    {
        let free = wal_volume_free_bytes(wal_path)?;
        if free < min_free {
            return Err(std::io::Error::new(
                std::io::ErrorKind::StorageFull,
                format!(
                    "WAL volume below water-mark: free={free} < SPG_WAL_MIN_FREE_BYTES={min_free}"
                ),
            ));
        }
    }
    let mut outcome = f.write_all(&batched);
    if outcome.is_ok() && !synchronous_commit_disabled() {
        outcome = f.sync_data();
    }
    if let Err(e) = outcome {
        // Remove the torn or unsynced tail so the WAL still matches the
        // rolled-back engine state and stays parseable for replay.
        let rollback = f.set_len(pre_len).and_then(|()| {
            std::io::Seek::seek(&mut *f, std::io::SeekFrom::Start(pre_len))
        });
        if let Err(rb) = rollback {
            eprintln!(
                "spg-server: WAL rollback to {pre_len} bytes after failed group append failed: {rb}"
            );
        }
        return Err(e);
    }
    if let Some(tee_path) = wal_tee_path()
        && let Err(e) = append_to_tee(tee_path, &batched)
    {
        eprintln!("spg-server: WAL tee append to {tee_path:?} failed: {e}");
    }
    Ok(())
}
/// Optional file that receives a copy of every group-committed WAL batch.
///
/// Read once from `SPG_WAL_TEE_PATH`. An unset or empty variable disables the tee.
fn wal_tee_path() -> Option<&'static str> {
    static TEE: std::sync::OnceLock<Option<String>> = std::sync::OnceLock::new();
    let configured = TEE.get_or_init(|| match env::var("SPG_WAL_TEE_PATH") {
        Ok(path) if !path.is_empty() => Some(path),
        _ => None,
    });
    configured.as_deref()
}
/// Appends `bytes` to the file at `path`, creating the file if needed.
fn append_to_tee(path: &str, bytes: &[u8]) -> std::io::Result<()> {
    OpenOptions::new()
        .append(true)
        .create(true)
        .open(path)?
        .write_all(bytes)
}
/// Produces an owned copy of an I/O error.
///
/// `std::io::Error` is not `Clone`. The copy keeps the error kind and the
/// rendered message, but not the underlying source.
fn clone_io_err(e: &std::io::Error) -> std::io::Error {
    let kind = e.kind();
    let message = format!("{e}");
    std::io::Error::new(kind, message)
}
/// Maximum number of queued auto-commit statements a leader folds into one
/// WAL group.
///
/// Read from `SPG_COMMIT_GROUP_MAX`, with `DEFAULT_COMMIT_GROUP_MAX` as the
/// default. The value is clamped to at least 1. A value of 0 would make the
/// leader drain nothing from a non-empty queue, so it would spin forever while
/// every waiting connection blocks on its ack.
fn commit_group_max() -> usize {
    parse_env_usize("SPG_COMMIT_GROUP_MAX")
        .unwrap_or(DEFAULT_COMMIT_GROUP_MAX)
        .max(1)
}
/// How long, in microseconds, a commit leader waits for more statements to
/// join its group (`SPG_COMMIT_DELAY_US`). Defaults to 0, i.e. no waiting.
fn commit_delay_us() -> u64 {
    parse_env_u64("SPG_COMMIT_DELAY_US").unwrap_or_default()
}
/// Queues `task` for group commit.
///
/// Returns `true` when the caller has just become the leader and must call
/// `run_leader_commit_round`. Returns `false` when a leader is already active.
///
/// # Panics
/// Panics if the commit-queue mutex is poisoned.
fn enqueue_commit_task(state: &ServerState, task: CommitTask) -> bool {
    let mut queue = state
        .commit_queue
        .lock()
        .expect("commit queue mutex poisoned");
    queue.pending.push_back(task);
    let should_lead = !queue.leader_active;
    queue.leader_active = true;
    should_lead
}
/// Drains the commit queue in groups until it is empty, acting as the
/// group-commit leader.
///
/// For each group:
/// 1. Take the engine write lock and snapshot the catalog as a pre-image.
/// 2. Run every task in its own transaction: BEGIN, statement, COMMIT.
///    Failures are acked immediately and rolled back.
/// 3. Append the surviving statements' WAL frames in one write.
/// 4. If that append fails, restore the pre-image.
/// 5. Ack every prepared task with its result and the shared WAL outcome.
///    Statements are published to pubsub only when the WAL append succeeded.
///
/// If the engine lock is poisoned, leadership is abandoned and every
/// still-queued task is dropped. Dropping a task's ack sender makes its
/// waiting connection fail instead of blocking forever.
///
/// # Panics
/// Panics if the commit-queue mutex is poisoned while collecting a group.
#[allow(clippy::too_many_lines)]
fn run_leader_commit_round(state: &ServerState) {
    struct Prepared {
        task: CommitTask,
        result: QueryResult,
        wal_bytes: Vec<u8>,
    }

    /// Gives up leadership after an unrecoverable lock poisoning.
    fn abandon_queue(state: &ServerState) {
        if let Ok(mut q) = state.commit_queue.lock() {
            // No future leader can make progress against a poisoned engine
            // lock. Dropping the queued tasks closes their ack channels rather
            // than stranding the waiters.
            q.pending.clear();
            q.leader_active = false;
        }
    }

    // Defensive clamp: a group size of 0 would drain nothing and spin forever.
    let group_max = commit_group_max().max(1);
    let delay_us = commit_delay_us();
    loop {
        let group: Vec<CommitTask> = {
            let mut q = state
                .commit_queue
                .lock()
                .expect("commit queue mutex poisoned");
            // Optional batching delay: yield (releasing the queue lock) so
            // followers can enqueue, until the group is full or time runs out.
            if delay_us > 0 && q.pending.len() < group_max {
                let deadline = Instant::now() + Duration::from_micros(delay_us);
                while q.pending.len() < group_max && Instant::now() < deadline {
                    drop(q);
                    thread::yield_now();
                    q = state
                        .commit_queue
                        .lock()
                        .expect("commit queue mutex poisoned");
                }
            }
            if q.pending.is_empty() {
                q.leader_active = false;
                return;
            }
            let take = q.pending.len().min(group_max);
            q.pending.drain(..take).collect()
        };
        let mut prepared: Vec<Prepared> = Vec::with_capacity(group.len());
        let pre_image: Option<spg_storage::Catalog> = {
            let Ok(mut engine) = state.engine.write() else {
                drop(group);
                abandon_queue(state);
                return;
            };
            let pre = engine.catalog().clone();
            for task in group {
                let tx_id = engine.alloc_tx_id();
                if let Err(e) = engine.execute_in("BEGIN", tx_id) {
                    let _ = task.ack.send(CommitResult {
                        result: Err(e),
                        wal_outcome: Ok(()),
                    });
                    continue;
                }
                let exec_res = engine.execute_in_with_cancel(
                    &task.sql,
                    tx_id,
                    spg_engine::CancelToken::from_flag(&task.cancel_flag),
                );
                let was_command_ok = matches!(exec_res, Ok(QueryResult::CommandOk { .. }));
                if !was_command_ok {
                    let _ = engine.execute_in("ROLLBACK", tx_id);
                    let _ = task.ack.send(CommitResult {
                        result: exec_res,
                        wal_outcome: Ok(()),
                    });
                    continue;
                }
                let wal_bytes = match encode_wal_auto_commit_sql_metrics(&task.sql, &state.metrics) {
                    Ok(b) => b,
                    Err(e) => {
                        let _ = engine.execute_in("ROLLBACK", tx_id);
                        let _ = task.ack.send(CommitResult {
                            result: exec_res,
                            wal_outcome: Err(e),
                        });
                        continue;
                    }
                };
                if let Err(e) = engine.execute_in("COMMIT", tx_id) {
                    let _ = engine.execute_in("ROLLBACK", tx_id);
                    let _ = task.ack.send(CommitResult {
                        result: Err(e),
                        wal_outcome: Ok(()),
                    });
                    continue;
                }
                prepared.push(Prepared {
                    task,
                    result: exec_res.expect("CommandOk checked above"),
                    wal_bytes,
                });
            }
            if prepared.is_empty() { None } else { Some(pre) }
        };
        if prepared.is_empty() {
            continue;
        }
        let entries: Vec<Vec<u8>> = prepared.iter().map(|p| p.wal_bytes.clone()).collect();
        let wal_outcome: std::io::Result<()> = append_wal_v3_group(state, &entries);
        // The group is already applied in the engine. If it cannot be made
        // durable, revert the catalog to its pre-group image.
        if wal_outcome.is_err()
            && let Some(pre) = pre_image
        {
            if let Ok(mut engine) = state.engine.write() {
                engine.replace_catalog(pre);
            } else {
                drop(prepared);
                abandon_queue(state);
                return;
            }
        }
        let wal_ok = wal_outcome.is_ok();
        for p in prepared {
            let cloned_wal = match &wal_outcome {
                Ok(()) => Ok(()),
                Err(e) => Err(clone_io_err(e)),
            };
            if wal_ok {
                pubsub::publish_sql(&p.task.sql);
            }
            let _ = p.task.ack.send(CommitResult {
                result: Ok(p.result),
                wal_outcome: cloned_wal,
            });
        }
    }
}
/// Appends `sql` to the WAL as a v2 frame and syncs it.
///
/// Syncing is skipped when synchronous commit is disabled. The chaos quota and
/// the free-space water-mark are enforced first. If the write or the sync
/// fails, the WAL is truncated back to its prior length, so a torn frame
/// cannot block replay of later records.
///
/// # Errors
/// Returns `StorageFull` for quota or water-mark violations. Also returns
/// framing errors, I/O errors, and WAL-lock poisoning.
fn append_wal(state: &ServerState, sql: &str) -> std::io::Result<()> {
    let Some(wal) = state.wal.as_ref() else {
        return Ok(());
    };
    let entry = encode_wal_record(sql)?;
    let mut f = wal
        .lock()
        .map_err(|_| std::io::Error::other("wal mutex poisoned"))?;
    // Needed both for the quota check and for rolling back a failed append.
    let pre_len = f.metadata()?.len();
    if let Some(quota) = state.chaos.wal_quota_bytes
        && pre_len.saturating_add(entry.len() as u64) > quota
    {
        return Err(std::io::Error::new(
            std::io::ErrorKind::StorageFull,
            format!(
                "wal quota exceeded: cur={pre_len} + {} > quota={quota} (SPG_FAIL_WAL_QUOTA_BYTES)",
                entry.len()
            ),
        ));
    }
    if let Some(min_free) = state.limits.wal_min_free_bytes
        && let Some(wal_path) = state.wal_path.as_deref()
    {
        let free = wal_volume_free_bytes(wal_path)?;
        if free < min_free {
            return Err(std::io::Error::new(
                std::io::ErrorKind::StorageFull,
                format!(
                    "WAL volume below water-mark: free={free} < SPG_WAL_MIN_FREE_BYTES={min_free}"
                ),
            ));
        }
    }
    let mut outcome = f.write_all(&entry);
    if outcome.is_ok() && !synchronous_commit_disabled() {
        outcome = f.sync_data();
    }
    if let Err(e) = outcome {
        // Truncate the torn or unsynced frame and reposition the handle, so the
        // next append starts at the old end of the log.
        let rollback = f.set_len(pre_len).and_then(|()| {
            std::io::Seek::seek(&mut *f, std::io::SeekFrom::Start(pre_len))
        });
        if let Err(rb) = rollback {
            eprintln!(
                "spg-server: WAL rollback to {pre_len} bytes after failed append failed: {rb}"
            );
        }
        return Err(e);
    }
    Ok(())
}
/// Returns the bytes available to unprivileged users on the file system
/// holding `path`, computed as `statvfs` `f_bavail * f_frsize` (saturating).
///
/// # Errors
/// Returns `InvalidInput` if `path` contains an interior NUL byte. Returns the
/// OS error when `statvfs` fails.
#[allow(unsafe_code, clippy::cast_lossless, clippy::useless_conversion)]
fn wal_volume_free_bytes(path: &Path) -> std::io::Result<u64> {
    use std::os::unix::ffi::OsStrExt;
    // `CString::new` rejects interior NULs. A hand-built buffer would silently
    // truncate the path at the first NUL and query a different file system.
    let c_path = std::ffi::CString::new(path.as_os_str().as_bytes())
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
    // SAFETY: `libc::statvfs` is a plain C struct of integer fields, so the
    // all-zero bit pattern is a valid value for it.
    let mut stat: libc::statvfs = unsafe { std::mem::zeroed() };
    // SAFETY: `c_path` is a valid NUL-terminated string that outlives the call,
    // and `stat` is a live, writable `statvfs` for the kernel to fill in.
    let rc = unsafe { libc::statvfs(c_path.as_ptr(), &raw mut stat) };
    if rc != 0 {
        return Err(std::io::Error::last_os_error());
    }
    // Field widths vary by platform, hence the casts (and the lint allowances).
    let bavail = stat.f_bavail as u64;
    let frsize = stat.f_frsize as u64;
    Ok(bavail.saturating_mul(frsize))
}
fn dispatch_v3_record(
tag: u8,
payload: &[u8],
frame_off: usize,
engine: &mut Engine,
) -> std::io::Result<bool> {
match tag {
WAL_V3_TYPE_AUTO_COMMIT_SQL => {
let sql = core::str::from_utf8(payload).map_err(|_| {
std::io::Error::other("v3 auto_commit_sql payload has non-UTF-8 SQL")
})?;
engine
.execute(sql)
.map_err(|e| std::io::Error::other(format!("WAL replay rejected {sql:?}: {e}")))?;
Ok(true)
}
WAL_V3_TYPE_COMPRESSED_SQL => {
if payload.is_empty() {
return Err(std::io::Error::other(format!(
"WAL compressed_sql at offset {frame_off}: empty payload"
)));
}
let algo = payload[0];
let compressed = &payload[1..];
let raw_bytes = match algo {
WAL_COMPRESS_ALGO_LZSS => spg_crypto::lzss::decompress(compressed).map_err(|e| {
std::io::Error::other(format!(
"WAL compressed_sql at offset {frame_off}: LZSS decompress failed: {e:?}"
))
})?,
other => {
return Err(std::io::Error::other(format!(
"WAL compressed_sql at offset {frame_off}: unknown algo byte {other:#04x}"
)));
}
};
let sql = core::str::from_utf8(&raw_bytes).map_err(|_| {
std::io::Error::other(format!(
"WAL compressed_sql at offset {frame_off}: decompressed bytes are not valid UTF-8"
))
})?;
engine
.execute(sql)
.map_err(|e| std::io::Error::other(format!("WAL replay rejected {sql:?}: {e}")))?;
Ok(true)
}
WAL_V3_TYPE_DURABILITY_CHECKPOINT => {
if payload.len() != 8 {
return Err(std::io::Error::other(format!(
"WAL durability_checkpoint at offset {frame_off} has {}-byte payload (expected 8)",
payload.len()
)));
}
let arr: [u8; 8] = payload.try_into().expect("checked len above");
let recorded_off = u64::from_le_bytes(arr);
let frame_off_u64 = frame_off as u64;
if recorded_off != frame_off_u64 {
eprintln!(
"spg-server: WAL durability_checkpoint at offset {frame_off} carries recorded_off={recorded_off} — possible WAL relocation; treating marker as no-op"
);
}
Ok(false)
}
other => Err(std::io::Error::other(format!(
"WAL v3 unknown type byte {other:#04x} at offset {frame_off} — refusing to replay"
))),
}
}
/// Replays a WAL image into `engine` and returns the number of statements
/// applied.
///
/// Three frame formats are handled:
/// - v1: `[len][sql]`, no CRC;
/// - v2: `[len|bit31][crc][sql]`;
/// - v3: `[len|bit31|bit30][crc][type][payload]`, where the CRC also covers the
///   type byte.
///
/// A truncated frame at the tail is logged and ends replay successfully, with
/// the statements applied so far.
///
/// # Errors
/// A CRC mismatch is fatal, as are invalid SQL, rejected statements, and v3
/// record errors.
fn replay_wal_bytes(bytes: &[u8], engine: &mut Engine) -> std::io::Result<usize> {
    let mut cur = 0;
    let mut applied = 0usize;
    while cur < bytes.len() {
        if bytes.len() - cur < 4 {
            eprintln!(
                "spg-server: WAL truncated at offset {cur} (need 4-byte length, have {})",
                bytes.len() - cur
            );
            break;
        }
        let frame_off = cur;
        let len_arr: [u8; 4] = bytes[cur..cur + 4].try_into().expect("checked");
        let raw_len = u32::from_le_bytes(len_arr);
        cur += 4;
        let is_v2 = raw_len & WAL_V2_SENTINEL != 0;
        let is_v3 = is_v2 && (raw_len & WAL_V3_FLAG != 0);
        let len_mask = if is_v3 {
            !(WAL_V2_SENTINEL | WAL_V3_FLAG)
        } else {
            !WAL_V2_SENTINEL
        };
        let len = (raw_len & len_mask) as usize;
        let expected_crc = if is_v2 {
            if bytes.len() - cur < 4 {
                eprintln!(
                    "spg-server: v2/v3 WAL truncated at offset {cur} (need 4-byte CRC, have {})",
                    bytes.len() - cur
                );
                break;
            }
            let crc_arr: [u8; 4] = bytes[cur..cur + 4].try_into().expect("checked");
            cur += 4;
            Some(u32::from_le_bytes(crc_arr))
        } else {
            None
        };
        let v3_type_tag = if is_v3 {
            if bytes.len() - cur < 1 {
                eprintln!(
                    "spg-server: v3 WAL truncated at offset {cur} (need 1-byte type, have 0)"
                );
                break;
            }
            let t = bytes[cur];
            cur += 1;
            Some(t)
        } else {
            None
        };
        if cur + len > bytes.len() {
            eprintln!("spg-server: WAL entry truncated (payload_len={len}) — dropping tail");
            break;
        }
        let payload = &bytes[cur..cur + len];
        if let Some(expected) = expected_crc {
            // For v3 frames the type byte sits immediately before the payload
            // and is covered by the CRC. Hashing the contiguous `[tag|payload]`
            // slice in place avoids a per-record scratch copy.
            let crc_start = if v3_type_tag.is_some() { cur - 1 } else { cur };
            let actual = spg_crypto::crc32::crc32(&bytes[crc_start..cur + len]);
            if actual != expected {
                return Err(std::io::Error::other(format!(
                    "WAL CRC mismatch at offset {frame_off} (expected={expected:#010x}, computed={actual:#010x}, payload_len={len}) — corruption detected, refusing to replay"
                )));
            }
        }
        let count_as_applied = if let Some(tag) = v3_type_tag {
            dispatch_v3_record(tag, payload, frame_off, engine)?
        } else {
            let sql = core::str::from_utf8(payload)
                .map_err(|_| std::io::Error::other("WAL entry has non-UTF-8 SQL"))?;
            engine
                .execute(sql)
                .map_err(|e| std::io::Error::other(format!("WAL replay rejected {sql:?}: {e}")))?;
            true
        };
        cur += len;
        if count_as_applied {
            applied += 1;
        }
    }
    Ok(applied)
}
pub(crate) fn append_audit_pub(state: &ServerState, sql: &str) -> std::io::Result<()> {
append_audit(state, sql)
}
/// Appends `sql`, stamped with the current wall-clock time in milliseconds
/// since the Unix epoch, to the in-memory audit log.
///
/// When `state.audit_path` is configured, the newly appended entry is also
/// encoded and appended to that file. The in-memory append happens first, so
/// a file error leaves the entry in memory but not on disk. The file is opened
/// in append mode without `create`, so it must already exist.
///
/// # Errors
/// Fails if the audit mutex is poisoned, or if opening or writing the audit
/// file fails.
fn append_audit(state: &ServerState, sql: &str) -> std::io::Result<()> {
    let since_epoch_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis())
        .unwrap_or(0);
    // Saturate rather than wrap if the millisecond count ever exceeds u64.
    let ts_ms: u64 = since_epoch_ms.try_into().unwrap_or(u64::MAX);
    let Ok(mut audit) = state.audit_log.lock() else {
        return Err(std::io::Error::other("audit mutex poisoned"));
    };
    audit.append(sql.to_string(), ts_ms);
    let Some(path) = state.audit_path.as_deref() else {
        return Ok(());
    };
    // The entry just appended is the last one in the log.
    let newest = audit.len() - 1;
    let mut encoded = Vec::new();
    audit.encode_entry_to(newest, &mut encoded);
    OpenOptions::new()
        .append(true)
        .open(path)?
        .write_all(&encoded)
}
/// Current wall-clock time as microseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch, and saturates
/// at `i64::MAX` if the microsecond count does not fit in an `i64`.
pub(crate) fn wall_clock_micros() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_micros()).unwrap_or(i64::MAX),
        Err(_) => 0,
    }
}
/// Seeds an initial admin account from the environment on a database that has
/// no users yet.
///
/// Does nothing in any of these cases:
/// - the engine already has users;
/// - `SPG_ADMIN_PASSWORD` is unset, not valid Unicode, or empty.
///
/// Otherwise it creates a `Role::Admin` user with a fresh random salt. The
/// user name comes from `SPG_ADMIN_USER`, falling back to `"admin"` when that
/// variable is unset or empty.
///
/// When `db_path` is given, the engine snapshot is persisted with
/// `write_atomic`. On success, a manifest with no cold segments and a WAL
/// baseline of 0 is written next to it. A persistence failure is only logged
/// as a warning.
///
/// # Errors
/// Fails if the salt cannot be read or the engine rejects the new user.
fn bootstrap_admin_from_env(engine: &mut Engine, db_path: Option<&Path>) -> std::io::Result<()> {
    let has_users = !engine.users().is_empty();
    if has_users {
        return Ok(());
    }
    let pw = match env::var("SPG_ADMIN_PASSWORD") {
        Ok(secret) if !secret.is_empty() => secret,
        _ => return Ok(()),
    };
    let user = match env::var("SPG_ADMIN_USER") {
        Ok(name) if !name.is_empty() => name,
        _ => "admin".to_string(),
    };
    let salt = random_salt()?;
    if let Err(e) = engine.create_user(&user, &pw, Role::Admin, salt) {
        return Err(std::io::Error::other(format!(
            "bootstrap admin {user:?}: {e}"
        )));
    }
    eprintln!("spg-server: bootstrapped admin user {user:?} from SPG_ADMIN_PASSWORD");
    let Some(p) = db_path else {
        return Ok(());
    };
    let snapshot = engine.snapshot();
    match write_atomic(p, &snapshot) {
        Ok(()) => write_manifest_alongside(p, &snapshot, &BTreeMap::new(), 0),
        Err(e) => eprintln!("spg-server: warning — failed to persist bootstrap admin: {e}"),
    }
    Ok(())
}
/// Drop guard that times one statement.
///
/// When dropped, it prints a single-line JSON `slow_query` event to stderr if
/// the statement ran for at least the configured `slow_query_log_ms`. With no
/// threshold configured the guard never logs.
struct SlowLogGuard<'a> {
    /// Threshold in microseconds; `None` disables logging.
    threshold_us: Option<u64>,
    /// Statement text, JSON-escaped into the event.
    sql: &'a str,
    /// Session role; reported as `"unauth"` when absent.
    role: Option<Role>,
    /// Moment the guard was created.
    start: Instant,
}
impl<'a> SlowLogGuard<'a> {
    /// Starts timing `sql` now.
    ///
    /// The slow-log threshold is read from `state` and converted from
    /// milliseconds to microseconds, saturating on overflow.
    fn new(state: &ServerState, sql: &'a str, role: Option<Role>) -> Self {
        let threshold_ms = state.limits.slow_query_log_ms;
        let threshold_us = threshold_ms.map(|ms| ms.saturating_mul(1000));
        let start = Instant::now();
        Self {
            threshold_us,
            sql,
            role,
            start,
        }
    }
}
impl Drop for SlowLogGuard<'_> {
    fn drop(&mut self) {
        if let Some(threshold_us) = self.threshold_us {
            // Saturate instead of truncating absurdly long durations.
            let elapsed_us: u64 = self
                .start
                .elapsed()
                .as_micros()
                .try_into()
                .unwrap_or(u64::MAX);
            if elapsed_us >= threshold_us {
                let mut sql_escaped = String::with_capacity(self.sql.len() + 16);
                json_escape_into(self.sql, &mut sql_escaped);
                let role_str = self.role.map_or("unauth", Role::as_str);
                eprintln!(
                    r#"{{"event":"slow_query","sql":"{sql_escaped}","elapsed_us":{elapsed_us},"role":"{role_str}","threshold_us":{threshold_us}}}"#
                );
            }
        }
    }
}
/// Appends `s` to `out`, escaped for embedding inside a JSON string literal.
/// The surrounding quotes are not added, and `out` is never cleared.
///
/// Escaping rules:
/// - `"` and `\` are backslash-escaped;
/// - control characters with a short JSON escape use it (`\n`, `\r`, `\t`,
///   `\b`, `\f`);
/// - every other character below U+0020 becomes a lowercase `\u00XX` escape;
/// - everything else, including non-ASCII, is copied through unchanged.
fn json_escape_into(s: &str, out: &mut String) {
    use std::fmt::Write as _;
    for ch in s.chars() {
        let short_escape = match ch {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            '\u{0008}' => Some("\\b"),
            '\u{000c}' => Some("\\f"),
            _ => None,
        };
        if let Some(escaped) = short_escape {
            out.push_str(escaped);
        } else if u32::from(ch) < 0x20 {
            // Writing into a String cannot fail.
            let _ = write!(out, "\\u{:04x}", u32::from(ch));
        } else {
            out.push(ch);
        }
    }
}
/// Handle to a query watchdog.
///
/// The flag is shared with the timer thread, when one was spawned, and tells
/// it the guarded query has finished.
struct Watchdog {
    /// Set to `true` once the guarded query is done.
    completed: Arc<AtomicBool>,
}
impl Watchdog {
    /// Marks the guarded query as finished.
    ///
    /// A timer thread that observes this stops polling. Calling it more than
    /// once is harmless.
    fn cancel(&self) {
        let done: &AtomicBool = &self.completed;
        done.store(true, Ordering::Release);
    }
}
fn spawn_query_watchdog(state: &ServerState, cancel_flag: &Arc<AtomicBool>) -> Watchdog {
let completed = Arc::new(AtomicBool::new(false));
let timeout_dur = state
.limits
.query_timeout_ms
.map(std::time::Duration::from_millis);
let cpu_dur = state
.limits
.max_query_ns
.map(std::time::Duration::from_nanos);
let total = match (timeout_dur, cpu_dur) {
(Some(a), Some(b)) => Some(a.min(b)),
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(None, None) => None,
};
let Some(total) = total else {
return Watchdog { completed };
};
let cancel_flag = Arc::clone(cancel_flag);
let completed_for_thread = Arc::clone(&completed);
thread::spawn(move || {
let slice = (total / 50).max(std::time::Duration::from_micros(100));
let start = std::time::Instant::now();
while start.elapsed() < total {
if completed_for_thread.load(Ordering::Acquire) {
return;
}
thread::sleep(slice);
}
cancel_flag.store(true, Ordering::Release);
});
Watchdog { completed }
}
/// Returns 16 bytes read from `/dev/urandom`.
///
/// In this file the result is used as the salt passed to `create_user`.
///
/// # Errors
/// Propagates any failure to open or fully read `/dev/urandom`, e.g. on
/// platforms that lack it.
fn random_salt() -> std::io::Result<[u8; 16]> {
    let mut urandom = File::open("/dev/urandom")?;
    let mut salt = [0u8; 16];
    urandom.read_exact(&mut salt).map(|()| salt)
}
/// Like `random_salt`, but treats missing entropy as fatal rather than
/// letting a user be created with a predictable salt.
///
/// # Panics
/// Panics if `/dev/urandom` cannot be opened or fully read.
fn urandom_salt_or_panic() -> [u8; 16] {
    match random_salt() {
        Ok(salt) => salt,
        Err(e) => panic!(
            "/dev/urandom unreadable — refusing to create users without entropy: {e:?}"
        ),
    }
}
/// Replaces `path` with `data` by writing a temp file and renaming it over
/// `path`, so readers see either the old file or the complete new one.
///
/// The temp file lives in `path`'s directory, because the rename must stay on
/// one filesystem. It is named `.spg-tmp-<pid>-<nanos>-<seq>`; the
/// per-process sequence number keeps concurrent writers in the same process
/// from colliding.
///
/// The data is fsynced before the rename, and the directory is synced
/// best-effort afterwards. This way a crash cannot surface the new name
/// pointing at an empty or partially written file.
///
/// # Errors
/// Returns any error from creating, writing, syncing, or renaming the temp
/// file. Once the temp file has been created, it is removed on every failure,
/// so failed writes (e.g. on a full disk) do not leave stray files behind.
fn write_atomic(path: &Path, data: &[u8]) -> std::io::Result<()> {
    static TMP_SEQ: AtomicUsize = AtomicUsize::new(0);
    let dir = path.parent().unwrap_or_else(|| Path::new("."));
    let pid = process::id();
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.subsec_nanos());
    let seq = TMP_SEQ.fetch_add(1, Ordering::Relaxed);
    let tmp = dir.join(format!(".spg-tmp-{pid}-{nanos}-{seq}"));
    // `create_new`: never truncate (and later delete) a file we did not create.
    let mut file = OpenOptions::new().write(true).create_new(true).open(&tmp)?;
    let staged = file.write_all(data).and_then(|()| file.sync_all());
    drop(file);
    if let Err(e) = staged.and_then(|()| fs::rename(&tmp, path)) {
        let _ = fs::remove_file(&tmp);
        return Err(e);
    }
    // Persist the rename itself. This is best-effort because not every
    // platform allows opening a directory this way. A bare file name has an
    // empty parent, which means the current directory.
    let sync_dir = if dir.as_os_str().is_empty() {
        Path::new(".")
    } else {
        dir
    };
    if let Ok(d) = File::open(sync_dir) {
        let _ = d.sync_all();
    }
    Ok(())
}
/// Writes the catalog manifest for `db_path` at
/// `manifest::manifest_path(db_path)`, creating its directory if needed.
///
/// The manifest records:
/// - the CRC32 of `snapshot_bytes`;
/// - `wal_baseline_offset`;
/// - one `(id, path, CRC32)` entry per cold segment in `cold_segment_paths`
///   whose file could be read.
///
/// Unreadable segments are logged and left out. Every failure (mkdir,
/// segment read, manifest write) is reported on stderr only; nothing is
/// returned to the caller.
pub(crate) fn write_manifest_alongside(
    db_path: &Path,
    snapshot_bytes: &[u8],
    cold_segment_paths: &BTreeMap<u32, PathBuf>,
    wal_baseline_offset: u64,
) {
    let manifest_file = manifest::manifest_path(db_path);
    if let Some(parent) = manifest_file.parent() {
        if let Err(e) = fs::create_dir_all(parent) {
            eprintln!(
                "spg-server: manifest dir {} mkdir failed: {e}",
                parent.display()
            );
            return;
        }
    }
    // Fingerprint each readable cold segment; skip (and log) the rest.
    let mut cold_segments = Vec::with_capacity(cold_segment_paths.len());
    for (segment_id, path) in cold_segment_paths {
        match fs::read(path) {
            Ok(contents) => cold_segments.push(manifest::ColdSegmentEntry {
                segment_id: *segment_id,
                path: path.clone(),
                crc32: spg_crypto::crc32::crc32(&contents),
            }),
            Err(e) => eprintln!(
                "spg-server: manifest skip segment {segment_id}: read {} failed: {e}",
                path.display()
            ),
        }
    }
    let serialized = manifest::CatalogManifest {
        catalog_crc32: spg_crypto::crc32::crc32(snapshot_bytes),
        cold_segments,
        wal_baseline_offset,
    }
    .serialize();
    if let Err(e) = write_atomic(&manifest_file, &serialized) {
        eprintln!(
            "spg-server: manifest write to {} failed: {e}",
            manifest_file.display()
        );
    }
}
/// Loads the catalog manifest stored next to `db_path` and preloads the cold
/// segments it lists into `engine`'s catalog.
///
/// Returns the manifest's `wal_baseline_offset`, or 0 when the manifest is
/// missing, unreadable, malformed, or does not match the snapshot.
///
/// The manifest is trusted only if its `catalog_crc32` equals the CRC32 of
/// `snapshot_bytes`. Otherwise the function logs, leaves the catalog
/// untouched, and returns 0, so the caller falls back to WAL-only replay.
///
/// Segment files are read in parallel via `prefetch`. Each one is then checked
/// against its recorded CRC32 and loaded into a clone of the catalog, which
/// replaces the engine's catalog at the end. Segments that fail to read, fail
/// the CRC check, or fail to load are logged and skipped. Loaded segments are
/// recorded in `cold_segment_paths`, and their count is stored in the
/// thread-local `PREFETCH_HITS_BOOT`.
fn load_manifest_and_preload_cold(
    engine: &mut Engine,
    db_path: &Path,
    snapshot_bytes: &[u8],
    cold_segment_paths: &mut BTreeMap<u32, PathBuf>,
) -> u64 {
    let mp = manifest::manifest_path(db_path);
    if !mp.exists() {
        return 0;
    }
    let bytes = match fs::read(&mp) {
        Ok(b) => b,
        Err(e) => {
            eprintln!("spg-server: manifest read {} failed: {e}", mp.display());
            return 0;
        }
    };
    let m = match manifest::CatalogManifest::deserialize(&bytes) {
        Ok(m) => m,
        Err(e) => {
            eprintln!("spg-server: manifest {} rejected: {e}", mp.display());
            return 0;
        }
    };
    let snapshot_crc = spg_crypto::crc32::crc32(snapshot_bytes);
    if snapshot_crc != m.catalog_crc32 {
        eprintln!(
            "spg-server: manifest {} catalog CRC mismatch (expected={:#010x}, file={:#010x}); \
             falling back to WAL-only replay",
            mp.display(),
            m.catalog_crc32,
            snapshot_crc,
        );
        return 0;
    }
    let mut cat = engine.catalog().clone();
    let mut loaded: usize = 0;
    let mut skipped: usize = 0;
    let paths_to_read: Vec<(u32, PathBuf)> = m
        .cold_segments
        .iter()
        .map(|e| (e.segment_id, e.path.clone()))
        .collect();
    let workers = prefetch::worker_count_from_env();
    // Own the read results so each segment's bytes (and original I/O error)
    // can be moved out rather than copied; segments may be large.
    let mut read_map: BTreeMap<u32, std::io::Result<Vec<u8>>> =
        prefetch::parallel_read_segments(&paths_to_read, workers, None)
            .into_iter()
            .collect();
    let mut prefetch_hits: u64 = 0;
    for entry in &m.cold_segments {
        // Each id's result is consumed once; a duplicate manifest entry for the
        // same id finds nothing and is skipped.
        let bytes_result = read_map.remove(&entry.segment_id).unwrap_or_else(|| {
            Err(std::io::Error::other(format!(
                "no prefetch result for segment {}",
                entry.segment_id
            )))
        });
        match bytes_result {
            Ok(seg_bytes) => {
                let computed = spg_crypto::crc32::crc32(&seg_bytes);
                if computed != entry.crc32 {
                    eprintln!(
                        "spg-server: manifest skip segment {}: CRC mismatch ({} != {})",
                        entry.segment_id, computed, entry.crc32
                    );
                    skipped += 1;
                    continue;
                }
                match cat.load_segment_bytes_at(entry.segment_id, seg_bytes) {
                    Ok(()) => {
                        cold_segment_paths.insert(entry.segment_id, entry.path.clone());
                        loaded += 1;
                        prefetch_hits += 1;
                    }
                    Err(e) => {
                        eprintln!(
                            "spg-server: manifest segment {} load failed: {e}",
                            entry.segment_id
                        );
                        skipped += 1;
                    }
                }
            }
            Err(e) => {
                eprintln!(
                    "spg-server: manifest skip segment {}: read {} failed: {e}",
                    entry.segment_id,
                    entry.path.display()
                );
                skipped += 1;
            }
        }
    }
    engine.replace_catalog(cat);
    PREFETCH_HITS_BOOT.with(|cell| cell.set(prefetch_hits));
    eprintln!(
        "spg-server: manifest {} loaded {loaded} cold segment(s), skipped {skipped}; wal_baseline_offset={}",
        mp.display(),
        m.wal_baseline_offset,
    );
    m.wal_baseline_offset
}
fn emit_result(
stream: &mut TcpStream,
result: Result<QueryResult, EngineError>,
) -> std::io::Result<()> {
match result {
Ok(QueryResult::CommandOk { affected, .. }) => {
write_frame(stream, &build_command_complete(affected as u64))
}
Ok(QueryResult::Rows { columns, rows }) => {
let descs = columns
.iter()
.map(column_schema_to_desc)
.collect::<Vec<_>>();
let rd =
build_row_description(&descs).map_err(|e| std::io::Error::other(e.to_string()))?;
let mut out: Vec<u8> = Vec::with_capacity(
spg_wire::FRAME_HEADER_LEN + rd.payload.len() + rows.len() * 64 + 16,
);
encode(&rd, &mut out).map_err(|e| std::io::Error::other(e.to_string()))?;
if rows.len() <= 1 {
for row in rows {
let wire = row_to_wire(&row);
let frame =
build_data_row(&wire).map_err(|e| std::io::Error::other(e.to_string()))?;
encode(&frame, &mut out).map_err(|e| std::io::Error::other(e.to_string()))?;
}
} else {
let wire_rows: Vec<Vec<WireValue>> = rows.iter().map(row_to_wire).collect();
for chunk in wire_rows.chunks(BATCH_ROWS_PER_FRAME) {
let frame = build_data_row_batch(chunk)
.map_err(|e| std::io::Error::other(e.to_string()))?;
encode(&frame, &mut out).map_err(|e| std::io::Error::other(e.to_string()))?;
}
}
let cc = build_command_complete(0);
encode(&cc, &mut out).map_err(|e| std::io::Error::other(e.to_string()))?;
stream.write_all(&out)
}
Err(e) => write_frame(stream, &build_error_response(&e.to_string())),
Ok(_) => write_frame(stream, &build_error_response("unexpected QueryResult variant")),
}
}
fn column_schema_to_desc(c: &ColumnSchema) -> ColumnDesc {
ColumnDesc {
name: c.name.clone(),
ty: data_type_to_wire(c.ty),
nullable: c.nullable,
}
}
/// Maps a storage column type to the wire type advertised in RowDescription.
///
/// Booleans, vectors, floats, and the integer types have dedicated wire
/// types. Every other listed storage type (text-like, numeric, temporal, JSON,
/// bytes, text arrays) is advertised as text.
const fn data_type_to_wire(t: DataType) -> WireType {
    match t {
        DataType::Bool => WireType::Bool,
        DataType::Vector { .. } => WireType::Vector,
        DataType::Float => WireType::Float,
        DataType::BigInt => WireType::BigInt,
        DataType::SmallInt | DataType::Int => WireType::Int,
        DataType::Text
        | DataType::Varchar(_)
        | DataType::Char(_)
        | DataType::Numeric { .. }
        | DataType::Date
        | DataType::Timestamp
        | DataType::Timestamptz
        | DataType::Interval
        | DataType::Json
        | DataType::Jsonb
        | DataType::Bytes
        | DataType::TextArray => WireType::Text,
    }
}
/// Converts each value of a storage row to its wire form, preserving column
/// order.
fn row_to_wire(r: &Row) -> Vec<WireValue> {
    let mut wire = Vec::new();
    wire.extend(r.values.iter().map(value_to_wire));
    wire
}
/// Maps one storage value to its wire representation.
///
/// - Scalars with a native wire type (null, bool, integers, float) map
///   directly; `SmallInt` is widened to the wire `Int`.
/// - All vector encodings are sent as plain `f32` vectors.
/// - Numeric, date, timestamp, interval, bytea, and text-array values are
///   rendered to text with the engine's formatters.
/// - Any variant not listed is sent as its `Debug` rendering.
fn value_to_wire(v: &Value) -> WireValue {
    match v {
        Value::Null => WireValue::Null,
        Value::Bool(flag) => WireValue::Bool(*flag),
        Value::SmallInt(small) => WireValue::Int(i32::from(*small)),
        Value::Int(int) => WireValue::Int(*int),
        Value::BigInt(big) => WireValue::BigInt(*big),
        Value::Float(x) => WireValue::Float(*x),
        Value::Text(text) | Value::Json(text) => WireValue::Text(text.clone()),
        Value::Vector(floats) => WireValue::Vector(floats.clone()),
        Value::HalfVector(half) => WireValue::Vector(half.to_f32_vec()),
        Value::Sq8Vector(quantized) => {
            WireValue::Vector(spg_storage::quantize::dequantize(quantized))
        }
        Value::Numeric { scaled, scale } => {
            WireValue::Text(spg_engine::eval::format_numeric(*scaled, *scale))
        }
        Value::Date(day) => WireValue::Text(spg_engine::eval::format_date(*day)),
        Value::Timestamp(ts) => WireValue::Text(spg_engine::eval::format_timestamp(*ts)),
        Value::Interval { months, micros } => {
            WireValue::Text(spg_engine::eval::format_interval(*months, *micros))
        }
        Value::Bytes(bytes) => WireValue::Text(spg_engine::eval::format_bytea_hex(bytes)),
        Value::TextArray(items) => WireValue::Text(spg_engine::eval::format_text_array(items)),
        other => WireValue::Text(format!("{other:?}")),
    }
}
/// Encodes a single `frame` and writes it to `stream` in one `write_all`.
///
/// # Errors
/// Returns encoding failures (as `std::io::Error` carrying the encoder's
/// message) and socket write errors.
fn write_frame(stream: &mut TcpStream, frame: &Frame) -> std::io::Result<()> {
    let mut buf: Vec<u8> = Vec::with_capacity(32);
    if let Err(e) = encode(frame, &mut buf) {
        return Err(std::io::Error::other(e.to_string()));
    }
    stream.write_all(&buf)
}
// Tests for the WAL v3 durability-marker record. They cover its encoded frame
// layout, and how `replay_wal_bytes` treats markers: they are accepted, they
// do not count as applied records, and a marker with a wrong-length payload is
// rejected with a descriptive error.
#[cfg(test)]
mod wal_v3_durability_marker_tests {
use super::{
Engine, WAL_V2_SENTINEL, WAL_V3_FLAG, WAL_V3_SENTINEL, WAL_V3_TYPE_AUTO_COMMIT_SQL,
WAL_V3_TYPE_DURABILITY_CHECKPOINT, encode_durability_marker, encode_wal_v3_record,
replay_wal_bytes,
};
// Pins the marker frame layout: 4-byte length word + 4-byte CRC + 1-byte
// type tag + 8-byte little-endian u64 payload = 17 bytes. The type tag is
// therefore at index 8 and the payload at 9..17.
#[test]
fn durability_marker_frame_shape_pins_v3_wire() {
let bytes = encode_durability_marker(0x1234_5678).unwrap();
assert_eq!(bytes.len(), 17, "marker frame must be 17 bytes");
let raw_len = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
// Strip the version marker bits to recover the payload length.
let len_field = raw_len & !(WAL_V2_SENTINEL | WAL_V3_FLAG);
assert_eq!(len_field, 8, "marker payload is 8 bytes (the u64 offset)");
assert_eq!(
raw_len & WAL_V3_SENTINEL,
WAL_V3_SENTINEL,
"marker must carry v3 sentinel bits",
);
assert_eq!(
bytes[8], WAL_V3_TYPE_DURABILITY_CHECKPOINT,
"type byte must be 0x02",
);
let offset = u64::from_le_bytes(bytes[9..17].try_into().unwrap());
assert_eq!(offset, 0x1234_5678, "payload echoes the offset arg");
}
// A stream made only of markers replays successfully and reports zero
// applied records.
#[test]
fn replay_skips_durability_markers_and_does_not_increment_applied() {
let mut stream = Vec::new();
stream.extend_from_slice(&encode_durability_marker(0).unwrap());
stream.extend_from_slice(&encode_durability_marker(17).unwrap());
stream.extend_from_slice(&encode_durability_marker(34).unwrap());
let mut engine = Engine::new();
let applied = replay_wal_bytes(&stream, &mut engine).expect("replay must accept markers");
assert_eq!(applied, 0, "markers do not count as applied records");
}
// Interleaves a marker between two SQL records. The marker records the end
// offset of the first record. Replay must step past the marker's payload to
// reach and apply the second CREATE TABLE.
#[test]
fn replay_mixes_sql_and_markers_advancing_cursor_correctly() {
let mut stream = Vec::new();
let create_a =
encode_wal_v3_record(WAL_V3_TYPE_AUTO_COMMIT_SQL, b"CREATE TABLE a (id INT)").unwrap();
let create_b =
encode_wal_v3_record(WAL_V3_TYPE_AUTO_COMMIT_SQL, b"CREATE TABLE b (id INT)").unwrap();
let marker_off = create_a.len() as u64;
let marker = encode_durability_marker(marker_off).unwrap();
stream.extend_from_slice(&create_a);
stream.extend_from_slice(&marker);
stream.extend_from_slice(&create_b);
let mut engine = Engine::new();
let applied =
replay_wal_bytes(&stream, &mut engine).expect("mixed stream must replay cleanly");
assert_eq!(
applied, 2,
"two CREATE TABLEs applied; marker doesn't count"
);
}
// A durability-checkpoint record with a 4-byte payload (instead of the 8-byte
// offset) must fail replay. The error must name the marker type and the bad
// payload size.
#[test]
fn replay_rejects_marker_with_wrong_payload_length() {
let bad =
encode_wal_v3_record(WAL_V3_TYPE_DURABILITY_CHECKPOINT, &0u32.to_le_bytes()).unwrap();
let mut engine = Engine::new();
let err = replay_wal_bytes(&bad, &mut engine).expect_err("4-byte payload must error");
let msg = err.to_string();
assert!(
msg.contains("durability_checkpoint") && msg.contains("4-byte payload"),
"error message should name the malformed marker: got {msg:?}",
);
}
}