mod alloc_budget;
mod backup;
mod flusher;
mod freezer;
mod manifest;
mod observability;
mod prefetch;
mod pubsub;
// Per-thread counter of cold-tier prefetch hits recorded before the shared
// `ServerState` (and its metrics) exists. `run` drains the boot thread's
// value into `metrics.cold_prefetch_hits` right after building the state.
// NOTE(review): the writer is not visible in this file section — presumably
// the `prefetch` module bumps it during snapshot restore; confirm.
thread_local! {
static PREFETCH_HITS_BOOT: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
}
mod commands;
mod mysqlwire;
mod pgwire;
mod replication;
mod scram;
mod wal;
mod wire;
use std::collections::{BTreeMap, VecDeque};
use std::env;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::path::{Path, PathBuf};
use std::process;
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU8, AtomicUsize, Ordering};
use std::sync::mpsc::{self, SyncSender};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use spg_audit::AuditLog;
use spg_engine::{Engine, EngineError, QueryResult, Role};
use spg_storage::{ColumnSchema, DataType, Row, Value};
use spg_wire::{
Frame, FrameError, Op, build_error_response, build_stats_response, decode, parse_auth,
parse_auth_user, parse_query,
};
pub(crate) use commands::*;
pub(crate) use wal::*;
use wire::{emit_result, write_frame};
/// Process-wide allocator: every heap allocation goes through
/// `alloc_budget::BudgetAllocator`.
/// NOTE(review): presumably this enforces or tracks a memory budget (per the
/// module name) — confirm against `alloc_budget`.
#[global_allocator]
static GLOBAL_ALLOC: alloc_budget::BudgetAllocator = alloc_budget::BudgetAllocator;
/// Per-query byte cap (256 MiB) handed to the engine when
/// `SPG_MAX_QUERY_BYTES` is not set.
const DEFAULT_MAX_QUERY_BYTES: u64 = 256 * 1024 * 1024;
/// Listen address used when neither the first CLI argument nor `SPG_ADDR`
/// supplies one.
const DEFAULT_ADDR: &str = "127.0.0.1:5544";
/// Size of the stack buffer each connection reads into per `read` call.
const READ_CHUNK: usize = 4096;
/// NOTE(review): not used in this section — presumably the number of result
/// rows packed into one wire frame; confirm in `wire`.
const BATCH_ROWS_PER_FRAME: usize = 256;
/// Poll interval for the shutdown waker thread and the connection drain loop.
const SHUTDOWN_POLL: Duration = Duration::from_millis(50);
/// Drain deadline used when `SPG_SHUTDOWN_DEADLINE_SEC` is not set.
const DEFAULT_SHUTDOWN_DEADLINE_SEC: u64 = 30;
/// NOTE(review): not used in this section — presumably the maximum number of
/// commits batched by one group-commit leader; confirm.
const DEFAULT_COMMIT_GROUP_MAX: usize = 16;
/// Set by the SIGTERM/SIGINT handler; the accept loop and the shutdown waker
/// poll it to begin a graceful drain.
static SHUTDOWN_FLAG: AtomicBool = AtomicBool::new(false);
/// Operator-configured resource limits, read from `SPG_*` environment
/// variables in `main`. Each field is `None` when its variable is unset,
/// unparsable, or zero (the `parse_env_*` helpers drop zero).
#[derive(Debug, Default, Clone, Copy)]
struct Limits {
/// `SPG_MAX_CONNECTIONS`: cap enforced by `ConnectionGuard::try_claim`.
max_connections: Option<usize>,
/// `SPG_MAX_QUERY_ROWS`: forwarded to `Engine::with_max_query_rows`.
max_query_rows: Option<usize>,
/// `SPG_MAX_QUERY_BYTES`: forwarded to `Engine::with_max_query_bytes`;
/// `DEFAULT_MAX_QUERY_BYTES` applies when unset.
max_query_bytes: Option<u64>,
/// `SPG_QUERY_TIMEOUT_MS` (consumer not visible in this section).
query_timeout_ms: Option<u64>,
/// `SPG_MAX_QUERY_NS` (consumer not visible in this section).
max_query_ns: Option<u64>,
/// `SPG_IDLE_TIMEOUT_SEC`: applied as the socket read timeout in `handle`.
idle_timeout_sec: Option<u64>,
/// `SPG_SLOW_QUERY_LOG_MS` (consumer not visible in this section).
slow_query_log_ms: Option<u64>,
/// `SPG_WAL_MIN_FREE_BYTES` (consumer not visible in this section).
wal_min_free_bytes: Option<u64>,
/// `SPG_SHUTDOWN_DEADLINE_SEC`: drain deadline, default
/// `DEFAULT_SHUTDOWN_DEADLINE_SEC`.
shutdown_deadline_sec: Option<u64>,
}
/// Chaos-testing knobs read at startup in `run`.
#[derive(Debug, Default, Clone, Copy)]
struct ChaosKnobs {
/// `SPG_FAIL_WAL_QUOTA_BYTES` (positive integer, else `None`).
/// NOTE(review): presumably an artificial WAL size quota that triggers write
/// failures — confirm in the WAL module.
wal_quota_bytes: Option<u64>,
/// `true` when `SPG_DISABLE_WAL_PREFLIGHT` is set to anything but "" or "0".
disable_wal_preflight: bool,
}
/// Outcome sent back to a committer over `CommitTask::ack`: the engine's
/// result for the statement plus the outcome of its WAL write.
struct CommitResult {
result: Result<QueryResult, EngineError>,
wal_outcome: std::io::Result<()>,
}
/// One statement waiting in `ServerState::commit_queue`.
/// NOTE(review): the queue consumer (presumably a group-commit leader) is
/// not visible in this section.
struct CommitTask {
/// SQL text to execute and log.
sql: String,
/// Cancellation flag shared with the submitting connection.
cancel_flag: Arc<AtomicBool>,
/// Channel on which the submitter waits for its `CommitResult`.
ack: SyncSender<CommitResult>,
}
/// Contents of `ServerState::commit_queue`: FIFO of pending commits and
/// whether some thread is currently acting as the leader draining it.
struct CommitQueueState {
pending: VecDeque<CommitTask>,
leader_active: bool,
}
/// Process-wide server state, built once in `run` and shared via `Arc` by the
/// accept loop, per-connection threads, extra listeners and background workers.
pub(crate) struct ServerState {
/// The SQL engine; readers take the read lock, mutations the write lock.
pub(crate) engine: RwLock<Engine>,
/// Snapshot file restored at boot (also required for follower mode).
db_path: Option<PathBuf>,
/// In-memory hash-chained audit log (entries carry `prev_hash` / `hash`).
audit_log: Mutex<AuditLog>,
/// On-disk audit log file; created with a header at boot when missing.
audit_path: Option<PathBuf>,
/// WAL file opened for append; `None` when no WAL path is configured.
wal: Option<Mutex<File>>,
/// `File::try_clone` of the WAL handle (`None` if cloning failed).
/// NOTE(review): presumably used to fsync without holding `wal` — confirm.
wal_sync_clone: Option<Arc<File>>,
/// WAL file path (also required for follower mode).
wal_path: Option<PathBuf>,
/// Pending commit tasks plus the leader flag (see `CommitQueueState`).
commit_queue: Mutex<CommitQueueState>,
/// Non-empty `SPG_PASSWORD`; when set, new connections start with no role.
password: Option<String>,
/// Resource limits read from the environment.
limits: Limits,
/// Live connection count, maintained by `ConnectionGuard`.
active_connections: AtomicUsize,
/// Server metrics.
metrics: Arc<observability::Metrics>,
/// Spawned subscription workers: name -> stop flag (stored `true` to stop).
sub_workers: Mutex<BTreeMap<String, Arc<AtomicBool>>>,
/// Identifier persisted in a `.cluster_id` sidecar next to the WAL/db file.
pub(crate) cluster_id: u64,
/// Current `effective_wal_level` (`WAL_LEVEL_REPLICA` / `WAL_LEVEL_LOGICAL`).
pub(crate) wal_level: AtomicU8,
/// Chaos-testing knobs.
chaos: ChaosKnobs,
/// Replication progress shared with the replication module.
pub(crate) lag_state: Arc<replication::LagState>,
/// Cold-segment preload specs parsed from `SPG_PRELOAD_COLD_SEGMENT`.
cold_preload: Vec<ColdPreloadSpec>,
/// `true` once no preload spec is waiting (immediately if none configured).
cold_preload_done: AtomicBool,
/// `SPG_HOT_TIER_BYTES`, default `DEFAULT_HOT_TIER_BYTES`.
pub(crate) hot_tier_byte_budget: u64,
/// Cold segment id -> backing file path (filled at restore and by preload).
pub(crate) cold_segment_paths: Mutex<BTreeMap<u32, PathBuf>>,
/// Per-connection activity records, read by `activity_snapshot`.
pub(crate) connections: RwLock<Vec<Arc<ConnState>>>,
}
/// Live per-connection activity record, exposed through `activity_snapshot`.
pub(crate) struct ConnState {
    /// Connection identifier reported as `pid`.
    pub(crate) pid: u32,
    /// Authenticated user name.
    pub(crate) user: String,
    /// Connection start time, microseconds since the Unix epoch.
    pub(crate) started_at_us: i64,
    /// SQL text currently (or most recently) executing.
    pub(crate) current_sql: RwLock<String>,
    /// Wait-event code; see [`ConnState::wait_event_str`].
    pub(crate) wait_event: AtomicU8,
    /// Start of the current query in epoch microseconds; 0 means "idle".
    pub(crate) last_query_start_us: AtomicI64,
    /// Whether the connection has an open transaction.
    pub(crate) in_transaction: AtomicBool,
    /// Client-supplied application name.
    pub(crate) application_name: RwLock<String>,
}

impl ConnState {
    /// Microseconds since the current query started, or 0 when idle.
    /// Clamped at 0 so clock skew never yields a negative duration.
    pub(crate) fn elapsed_us(&self) -> i64 {
        match self.last_query_start_us.load(Ordering::Relaxed) {
            0 => 0,
            started => {
                let now_us = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map_or(0, |d| d.as_micros() as i64);
                (now_us - started).max(0)
            }
        }
    }

    /// Human-readable label for the wait-event code (empty for none/unknown).
    pub(crate) fn wait_event_str(&self) -> &'static str {
        // Index = wait-event code; code 0 and anything past the table map to "".
        const LABELS: [&str; 4] = ["", "write_lock", "fsync", "group_commit"];
        let code = usize::from(self.wait_event.load(Ordering::Relaxed));
        LABELS.get(code).copied().unwrap_or("")
    }
}
/// Global handle to the server state, set once in `run`. It lets the plain
/// `fn` providers registered with the engine (`activity_snapshot`,
/// `audit_chain_snapshot`, `audit_verify_snapshot`) reach shared state.
pub(crate) static ACTIVITY_STATE: std::sync::OnceLock<Arc<ServerState>> =
std::sync::OnceLock::new();
/// Copies the audit log into engine-facing rows, with hashes hex-encoded.
///
/// Returns an empty list before the server state is published or when the
/// audit-log mutex is poisoned. Counters that do not fit in `i64` saturate to
/// `i64::MAX`.
pub(crate) fn audit_chain_snapshot() -> Vec<spg_engine::AuditRow> {
    let Some(log) = ACTIVITY_STATE
        .get()
        .and_then(|state| state.audit_log.lock().ok())
    else {
        return Vec::new();
    };
    let mut rows = Vec::new();
    for entry in log.entries().iter() {
        rows.push(spg_engine::AuditRow {
            seq: i64::try_from(entry.seq).unwrap_or(i64::MAX),
            ts_ms: i64::try_from(entry.ts_ms).unwrap_or(i64::MAX),
            prev_hash_hex: hex_encode(&entry.prev_hash),
            entry_hash_hex: hex_encode(&entry.hash),
            sql: entry.sql.clone(),
        });
    }
    rows
}
/// Verifies the audit chain and returns `(count, first_bad_seq)`.
///
/// * chain intact -> `(entry_count, -1)`
/// * broken link / hash mismatch / bad UTF-8 at `seq` -> `(seq, seq)`
/// * any other verification error -> `(0, 0)`
/// * state unavailable or mutex poisoned -> `(0, -1)`
pub(crate) fn audit_verify_snapshot() -> (i64, i64) {
    const UNAVAILABLE: (i64, i64) = (0, -1);
    let Some(log) = ACTIVITY_STATE
        .get()
        .and_then(|state| state.audit_log.lock().ok())
    else {
        return UNAVAILABLE;
    };
    let entry_count = log.entries().len() as i64;
    let verdict = log.verify();
    match verdict {
        Ok(()) => (entry_count, -1),
        Err(
            spg_audit::AuditError::BrokenChain { seq }
            | spg_audit::AuditError::HashMismatch { seq }
            | spg_audit::AuditError::InvalidUtf8 { seq },
        ) => {
            let bad_at = i64::try_from(seq).unwrap_or(i64::MAX);
            (bad_at, bad_at)
        }
        Err(_) => (0, 0),
    }
}
pub(crate) fn log_slow_query(sql: &str, elapsed_us: u64) {
let elapsed_str = elapsed_us.to_string();
observability::log_event(
"warn",
"slow_query",
&[("elapsed_us", &elapsed_str), ("sql", sql)],
);
}
/// Lower-case hexadecimal encoding of `bytes`, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(
        bytes
            .iter()
            .flat_map(|&byte| [byte >> 4, byte & 0x0f])
            .map(|nibble| char::from(DIGITS[usize::from(nibble)])),
    );
    encoded
}
pub(crate) fn activity_snapshot() -> Vec<spg_engine::ActivityRow> {
let Some(state) = ACTIVITY_STATE.get() else {
return Vec::new();
};
let Ok(conns) = state.connections.read() else {
return Vec::new();
};
conns
.iter()
.map(|c| {
let current_sql = c.current_sql.read().map(|g| g.clone()).unwrap_or_default();
let application_name = c
.application_name
.read()
.map(|g| g.clone())
.unwrap_or_default();
spg_engine::ActivityRow {
pid: c.pid,
user: c.user.clone(),
started_at_us: c.started_at_us,
current_sql,
wait_event: c.wait_event_str().to_string(),
elapsed_us: c.elapsed_us(),
in_transaction: c.in_transaction.load(Ordering::Relaxed),
application_name,
}
})
.collect()
}
/// Hot-tier byte budget (4 GiB) used when `SPG_HOT_TIER_BYTES` is unset.
pub(crate) const DEFAULT_HOT_TIER_BYTES: u64 = 4 * 1024 * 1024 * 1024;
/// One `table:index:path` entry from `SPG_PRELOAD_COLD_SEGMENT`.
struct ColdPreloadSpec {
/// Target table name.
table: String,
/// Index on `table` that receives the cold row locators.
index: String,
/// Segment file to load.
path: PathBuf,
/// Set once `try_lazy_preload_cold` has handled this spec, whether it
/// succeeded or failed, so it is never retried.
loaded: AtomicBool,
}
/// Converts an optional path argument into a `PathBuf`; `None`, `""` and the
/// placeholder `"-"` all mean "not provided".
fn parse_optional_path(arg: Option<String>) -> Option<PathBuf> {
    match arg {
        Some(text) if !text.is_empty() && text != "-" => Some(PathBuf::from(text)),
        _ => None,
    }
}
/// Resolves a path from the CLI value, falling back to environment variable
/// `env_key` when the CLI value is absent, empty or `"-"`.
fn resolve_path(cli: Option<String>, env_key: &str) -> Option<PathBuf> {
    if let Some(path) = parse_optional_path(cli) {
        return Some(path);
    }
    let from_env = env::var(env_key).ok();
    parse_optional_path(from_env)
}
/// Process entry point.
///
/// Positional arguments, each optional and in this order: listen address,
/// snapshot path, audit-log path, WAL path. The address falls back to
/// `SPG_ADDR` then `DEFAULT_ADDR`. Each path falls back to `SPG_DB`,
/// `SPG_AUDIT` or `SPG_WAL` when its argument is missing, empty or `-`.
/// `--replay-only` may appear anywhere and is not counted as a positional
/// argument; it replays the WAL offline and exits instead of serving.
fn main() {
    let (flags, positional): (Vec<String>, Vec<String>) =
        env::args().skip(1).partition(|a| a == "--replay-only");
    let replay_only = !flags.is_empty();
    let mut positional = positional.into_iter();

    let addr = positional
        .next()
        .or_else(|| env::var("SPG_ADDR").ok())
        .unwrap_or_else(|| DEFAULT_ADDR.to_string());
    let db_path = resolve_path(positional.next(), "SPG_DB");
    let audit_path = resolve_path(positional.next(), "SPG_AUDIT");
    let wal_path = resolve_path(positional.next(), "SPG_WAL");
    let password = env::var("SPG_PASSWORD").ok().filter(|s| !s.is_empty());
    let limits = Limits {
        max_connections: parse_env_usize("SPG_MAX_CONNECTIONS"),
        max_query_rows: parse_env_usize("SPG_MAX_QUERY_ROWS"),
        max_query_bytes: parse_env_u64("SPG_MAX_QUERY_BYTES"),
        query_timeout_ms: parse_env_u64("SPG_QUERY_TIMEOUT_MS"),
        max_query_ns: parse_env_u64("SPG_MAX_QUERY_NS"),
        idle_timeout_sec: parse_env_u64("SPG_IDLE_TIMEOUT_SEC"),
        slow_query_log_ms: parse_env_u64("SPG_SLOW_QUERY_LOG_MS"),
        wal_min_free_bytes: parse_env_u64("SPG_WAL_MIN_FREE_BYTES"),
        shutdown_deadline_sec: parse_env_u64("SPG_SHUTDOWN_DEADLINE_SEC"),
    };
    install_shutdown_handlers();

    if replay_only {
        match run_replay_only(db_path, wal_path) {
            Ok(()) => eprintln!("spg-server: --replay-only complete; exiting 0"),
            Err(e) => {
                eprintln!("spg-server: replay-only fatal: {e}");
                process::exit(1);
            }
        }
        return;
    }

    if let Err(e) = run(&addr, db_path, audit_path, wal_path, password, limits) {
        eprintln!("spg-server: fatal: {e}");
        process::exit(1);
    }
}
/// Offline mode for `--replay-only`.
///
/// Restores the snapshot at `db_path` when that file exists (otherwise starts
/// from an empty engine), replays the whole WAL at `wal_path` when it exists,
/// then calls `Engine::snapshot` and discards the result. This function
/// performs no file writes itself and never binds a listener.
///
/// NOTE(review): unlike `run`, this path does not honour the manifest WAL
/// baseline, `SPG_REPLAY_UPTO`, or the post-replay auto-rollback of a
/// trailing open transaction. Confirm that replaying the full WAL over the
/// snapshot is intended here.
///
/// # Errors
/// Propagates file-read errors, snapshot-restore failures and WAL replay
/// errors.
fn run_replay_only(db_path: Option<PathBuf>, wal_path: Option<PathBuf>) -> std::io::Result<()> {
    let existing_snapshot = db_path.as_deref().filter(|p| p.exists());
    let mut engine = if let Some(p) = existing_snapshot {
        let envelope = fs::read(p)?;
        let restored = Engine::restore_envelope(&envelope)
            .map_err(|e| std::io::Error::other(format!("restore: {e}")))?;
        eprintln!(
            "spg-server: --replay-only restored {} table(s) from {}",
            restored.catalog().table_count(),
            p.display()
        );
        restored
    } else {
        Engine::new()
    };

    match wal_path.as_deref() {
        Some(w) if w.exists() => {
            let wal_bytes = fs::read(w)?;
            let applied = replay_wal_bytes(&wal_bytes, &mut engine)?;
            eprintln!(
                "spg-server: --replay-only applied {applied} WAL record(s) from {}",
                w.display()
            );
        }
        Some(w) => eprintln!(
            "spg-server: --replay-only WAL path {} doesn't exist; nothing to replay",
            w.display()
        ),
        None => {}
    }

    let _ = engine.snapshot();
    Ok(())
}
/// Reads `env_key` as a positive `usize`, ignoring surrounding whitespace.
/// Returns `None` when the variable is unset, unparsable or zero.
fn parse_env_usize(env_key: &str) -> Option<usize> {
    let raw = env::var(env_key).ok()?;
    match raw.trim().parse::<usize>() {
        Ok(n) if n > 0 => Some(n),
        _ => None,
    }
}
/// Reads `env_key` as a positive `u64`, ignoring surrounding whitespace.
/// Returns `None` when the variable is unset, unparsable or zero.
fn parse_env_u64(env_key: &str) -> Option<u64> {
    let value = env::var(env_key).ok()?.trim().parse::<u64>().ok()?;
    (value != 0).then_some(value)
}
/// Parses `SPG_PRELOAD_COLD_SEGMENT`, a `;`-separated list of
/// `table:index:path` entries (whitespace around each field is ignored; the
/// path may itself contain `:`).
///
/// Malformed entries (fewer than three fields, or an empty field) are logged
/// and skipped. Returns an empty list when the variable is unset.
fn parse_cold_preload_env() -> Vec<ColdPreloadSpec> {
    let raw = match env::var("SPG_PRELOAD_COLD_SEGMENT") {
        Ok(value) => value,
        Err(_) => return Vec::new(),
    };
    let specs: Vec<ColdPreloadSpec> = raw
        .split(';')
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .filter_map(|entry| {
            let mut fields = entry.splitn(3, ':');
            let (Some(table), Some(index), Some(path)) =
                (fields.next(), fields.next(), fields.next())
            else {
                eprintln!(
                    "spg-server: SPG_PRELOAD_COLD_SEGMENT entry {entry:?} \
                     ignored — expected `table:index:path`"
                );
                return None;
            };
            let (table, index, path) = (table.trim(), index.trim(), path.trim());
            if table.is_empty() || index.is_empty() || path.is_empty() {
                eprintln!(
                    "spg-server: SPG_PRELOAD_COLD_SEGMENT entry {entry:?} \
                     ignored — empty table / index / path"
                );
                return None;
            }
            Some(ColdPreloadSpec {
                table: table.to_string(),
                index: index.to_string(),
                path: PathBuf::from(path),
                loaded: AtomicBool::new(false),
            })
        })
        .collect();
    if !specs.is_empty() {
        eprintln!(
            "spg-server: cold-tier preload queue has {} spec(s); each one \
             will load on the first Op::Query after its table + index \
             both exist",
            specs.len()
        );
    }
    specs
}
/// Path of the `.cluster_id` sidecar file: the WAL path (preferred) or else
/// the db path, with `.cluster_id` appended to its file name.
///
/// Returns `None` when neither path is configured.
// The `clippy::too_many_lines` allow that used to sit here described the
// cold-preload routine and did not belong on this small function.
fn cluster_id_sidecar_path(wal_path: Option<&Path>, db_path: Option<&Path>) -> Option<PathBuf> {
    let base = wal_path.or(db_path)?;
    let mut name = base
        .file_name()
        .map(std::ffi::OsStr::to_os_string)
        .unwrap_or_default();
    name.push(".cluster_id");
    Some(base.with_file_name(name))
}
fn load_or_generate_cluster_id(wal_path: Option<&Path>, db_path: Option<&Path>) -> u64 {
if let Some(p) = cluster_id_sidecar_path(wal_path, db_path) {
if p.exists()
&& let Ok(bytes) = std::fs::read(&p)
&& bytes.len() == 8
{
return u64::from_le_bytes(bytes.try_into().unwrap());
}
let id = generate_cluster_id();
if let Err(e) = std::fs::write(&p, id.to_le_bytes()) {
eprintln!(
"spg-server: cluster_id sidecar write to {} failed: {e} — \
keeping in-memory id (cycle detection won't survive restart)",
p.display()
);
}
id
} else {
generate_cluster_id()
}
}
/// Generates a pseudo-random 64-bit cluster id.
///
/// Seeds from wall-clock nanoseconds (an LCG-style multiply plus the process
/// id), then scrambles with the SplitMix64 finalizer. Not cryptographically
/// secure; it only needs to differ between nodes.
fn generate_cluster_id() -> u64 {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |d| u64::try_from(d.as_nanos()).unwrap_or(u64::MAX));
    let seed = nanos
        .wrapping_mul(6364136223846793005)
        .wrapping_add(u64::from(std::process::id()));
    let mixed = [(30u32, 0xbf58476d1ce4e5b9u64), (27, 0x94d049bb133111eb)]
        .iter()
        .fold(seed, |z, &(shift, mul)| (z ^ (z >> shift)).wrapping_mul(mul));
    mixed ^ (mixed >> 31)
}
/// Auto-ANALYZE sweep interval in milliseconds from
/// `SPG_AUTO_ANALYZE_INTERVAL_MS` (default 30 000). `0` disables the worker.
///
/// Surrounding whitespace is ignored, consistent with the other `SPG_*`
/// numeric settings (`parse_env_u64`). An unset or unparsable value yields
/// the default.
pub(crate) fn auto_analyze_interval_ms() -> u64 {
    std::env::var("SPG_AUTO_ANALYZE_INTERVAL_MS")
        .ok()
        .and_then(|s| s.trim().parse::<u64>().ok())
        .unwrap_or(30_000)
}
/// Starts the background auto-ANALYZE thread unless the configured interval
/// is zero.
///
/// A failure to create the thread is logged instead of silently ignored, so
/// operators can tell why statistics never refresh.
pub(crate) fn spawn_auto_analyze_worker(state: Arc<ServerState>) {
    let interval = std::time::Duration::from_millis(auto_analyze_interval_ms());
    if interval.is_zero() {
        return;
    }
    let spawned = thread::Builder::new()
        .name("spg-auto-analyze".into())
        .spawn(move || {
            run_auto_analyze_loop(state, interval);
        });
    if let Err(e) = spawned {
        eprintln!("spg-server: failed to spawn auto-analyze worker: {e}");
    }
}
/// How often the auto-ANALYZE thread wakes to check whether a sweep is due.
const AUTO_ANALYZE_TICK: std::time::Duration = std::time::Duration::from_millis(200);

/// Body of the auto-ANALYZE thread: every `interval` (checked at
/// `AUTO_ANALYZE_TICK` granularity) it runs one sweep. Never returns.
fn run_auto_analyze_loop(state: Arc<ServerState>, interval: std::time::Duration) {
    let mut sweep_started = std::time::Instant::now();
    loop {
        thread::sleep(AUTO_ANALYZE_TICK);
        if sweep_started.elapsed() >= interval {
            sweep_started = std::time::Instant::now();
            analyze_pending_tables(&state);
        }
    }
}

/// One sweep: asks the engine which tables need ANALYZE, then runs
/// `ANALYZE <table>` for each one that still exists. Each table takes the
/// engine write lock separately so readers can interleave. Failures are
/// logged; a poisoned lock ends the sweep early.
fn analyze_pending_tables(state: &ServerState) {
    let pending: Vec<String> = match state.engine.read() {
        Ok(eng) => eng.tables_needing_analyze(),
        Err(_) => return,
    };
    for table in &pending {
        let Ok(mut eng) = state.engine.write() else {
            return;
        };
        if eng.catalog().get(table).is_some() {
            let statement = format!("ANALYZE {}", quote_ident_simple(table));
            if let Err(e) = eng.execute(&statement) {
                eprintln!("spg-server: auto-analyze {table:?} failed: {e}");
            }
        }
    }
}
/// Renders `name` as a SQL identifier for generated statements.
///
/// A name made only of lower-case ASCII letters, digits and `_`, and not
/// starting with a digit, is emitted bare. Anything else is double-quoted
/// with embedded `"` doubled. That includes empty names, spaces,
/// punctuation, non-ASCII text and upper-case letters. Upper-case names are
/// quoted because an unquoted `Orders` would be case-folded by a
/// PostgreSQL-style parser and miss a table created as `"Orders"`; quoting
/// the exact stored name resolves it in either case.
fn quote_ident_simple(name: &str) -> String {
    let is_plain = !name.is_empty()
        && !name.starts_with(|c: char| c.is_ascii_digit())
        && name
            .chars()
            .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_');
    if is_plain {
        name.to_string()
    } else {
        format!("\"{}\"", name.replace('"', "\"\""))
    }
}
pub(crate) fn reconcile_subscriptions(state: &Arc<ServerState>) {
use std::collections::BTreeMap;
let want: BTreeMap<String, (String, bool)> = {
let Ok(eng) = state.engine.read() else {
return;
};
eng.subscriptions()
.iter()
.map(|(n, s)| (n.clone(), (s.conn_str.clone(), s.enabled)))
.collect()
};
let Ok(mut workers) = state.sub_workers.lock() else {
return;
};
let stale: Vec<String> = workers
.keys()
.filter(|n| !want.contains_key(n.as_str()))
.cloned()
.collect();
for name in stale {
if let Some(flag) = workers.remove(&name) {
flag.store(true, std::sync::atomic::Ordering::Release);
}
}
for (name, (conn_str, enabled)) in &want {
if !enabled {
continue;
}
if workers.contains_key(name) {
continue;
}
let flag = Arc::new(AtomicBool::new(false));
let flag_for_worker = Arc::clone(&flag);
let state_for_worker = Arc::clone(state);
let name_clone = name.clone();
let conn_clone = conn_str.clone();
let thread_name = format!("spg-sub-{}", name.chars().take(20).collect::<String>());
thread::Builder::new()
.name(thread_name)
.spawn(move || {
replication::run_subscription_worker(
name_clone,
conn_clone,
state_for_worker,
flag_for_worker,
);
})
.ok();
workers.insert(name.clone(), flag);
}
}
/// Loads pending cold-tier segments (from `SPG_PRELOAD_COLD_SEGMENT`) whose
/// target table and index now exist.
///
/// For each ready spec:
/// 1. The segment file is read outside any engine lock.
/// 2. Under the engine write lock it is parsed into a *clone* of the catalog.
/// 3. Its primary keys are registered as cold locators on the named index.
/// 4. The clone replaces the live catalog.
///
/// Any failure is logged and the spec is marked loaded so it is not retried.
/// Once no spec is waiting on its table/index, `cold_preload_done`
/// short-circuits later calls. Returns early, leaving specs pending, if the
/// engine lock is poisoned.
///
/// Two hazards are guarded against:
/// * concurrent callers. `loaded` is re-checked after taking the write lock,
///   so a segment is never registered twice;
/// * out-of-range keys in the segment file. A key that does not fit in `i64`
///   is reported instead of panicking, which would poison the engine lock.
#[allow(
    clippy::too_many_lines,
    reason = "single-purpose preload routine; splitting hurts readability more than the line count helps"
)]
pub(crate) fn try_lazy_preload_cold(state: &ServerState) {
    if state.cold_preload_done.load(Ordering::Relaxed) {
        return;
    }
    let mut still_pending = 0usize;
    for spec in &state.cold_preload {
        if spec.loaded.load(Ordering::Relaxed) {
            continue;
        }
        let ready = {
            let Ok(engine) = state.engine.read() else {
                return;
            };
            let cat = engine.catalog();
            cat.get(&spec.table)
                .is_some_and(|t| t.indices().iter().any(|i| i.name == spec.index))
        };
        if !ready {
            still_pending += 1;
            continue;
        }
        let bytes = match std::fs::read(&spec.path) {
            Ok(b) => b,
            Err(e) => {
                eprintln!(
                    "spg-server: cold preload {}:{} from {} failed: {e}; \
                     marking loaded to avoid retry storm",
                    spec.table,
                    spec.index,
                    spec.path.display()
                );
                spec.loaded.store(true, Ordering::Relaxed);
                continue;
            }
        };
        let Ok(mut engine) = state.engine.write() else {
            return;
        };
        // Another connection may have completed this spec while we were
        // reading the file. Its store to `loaded` was made while it held this
        // write lock, so it is visible now. Skip to avoid a double load.
        if spec.loaded.load(Ordering::Relaxed) {
            continue;
        }
        let mut cat = engine.catalog().clone();
        let seg_id = match cat.load_segment_bytes(bytes) {
            Ok(id) => id,
            Err(e) => {
                eprintln!(
                    "spg-server: cold preload {}:{} parse failed: {e}",
                    spec.table, spec.index
                );
                spec.loaded.store(true, Ordering::Relaxed);
                continue;
            }
        };
        let pairs: Vec<(spg_storage::IndexKey, spg_storage::RowLocator)> = {
            let Some(seg) = cat.cold_segment(seg_id) else {
                eprintln!(
                    "spg-server: cold preload {}:{} segment_id {seg_id} \
                     vanished after load — should be impossible",
                    spec.table, spec.index
                );
                spec.loaded.store(true, Ordering::Relaxed);
                continue;
            };
            // Keys come from an on-disk file: convert fallibly rather than
            // panic while the engine write lock is held.
            let converted: Result<Vec<(spg_storage::IndexKey, spg_storage::RowLocator)>, _> = seg
                .scan()
                .map(|(key, _payload)| {
                    i64::try_from(key).map(|pk| {
                        (
                            spg_storage::IndexKey::Int(pk),
                            spg_storage::RowLocator::Cold {
                                segment_id: seg_id,
                                page_offset: 0,
                            },
                        )
                    })
                })
                .collect();
            match converted {
                Ok(pairs) => pairs,
                Err(_) => {
                    eprintln!(
                        "spg-server: cold preload {}:{} segment contains a primary key \
                         outside the i64 range; skipping",
                        spec.table, spec.index
                    );
                    spec.loaded.store(true, Ordering::Relaxed);
                    continue;
                }
            }
        };
        let pairs_count = pairs.len();
        let Some(table_mut) = cat.get_mut(&spec.table) else {
            eprintln!(
                "spg-server: cold preload {}:{} table disappeared mid-load",
                spec.table, spec.index
            );
            spec.loaded.store(true, Ordering::Relaxed);
            continue;
        };
        if let Err(e) = table_mut.register_cold_locators(&spec.index, pairs) {
            eprintln!(
                "spg-server: cold preload {}:{} register_cold_locators failed: {e}",
                spec.table, spec.index
            );
            spec.loaded.store(true, Ordering::Relaxed);
            continue;
        }
        engine.replace_catalog(cat);
        spec.loaded.store(true, Ordering::Relaxed);
        if let Ok(mut paths) = state.cold_segment_paths.lock() {
            paths.insert(seg_id, spec.path.clone());
        }
        eprintln!(
            "spg-server: cold preload {}:{} loaded {} row(s) from {}",
            spec.table,
            spec.index,
            pairs_count,
            spec.path.display()
        );
    }
    if still_pending == 0 {
        state.cold_preload_done.store(true, Ordering::Relaxed);
    }
}
#[allow(clippy::too_many_lines)] fn run(
addr: &str,
db_path: Option<PathBuf>,
audit_path: Option<PathBuf>,
wal_path: Option<PathBuf>,
password: Option<String>,
limits: Limits,
) -> std::io::Result<()> {
let mut cold_segment_paths: BTreeMap<u32, PathBuf> = BTreeMap::new();
let (base_engine, manifest_wal_baseline) =
restore_engine(db_path.as_deref(), &mut cold_segment_paths)?;
let mut engine = base_engine
.with_clock(wall_clock_micros)
.with_salt_fn(urandom_salt_or_panic);
if let Some(n) = limits.max_query_rows {
engine = engine.with_max_query_rows(n);
}
let engine_byte_cap: Option<usize> = match limits.max_query_bytes {
Some(0) => None,
Some(n) => usize::try_from(n).ok(),
None => usize::try_from(DEFAULT_MAX_QUERY_BYTES).ok(),
};
if let Some(n) = engine_byte_cap {
engine = engine.with_max_query_bytes(n);
}
let audit_log = match &audit_path {
Some(p) if p.exists() => {
let bytes = fs::read(p)?;
let log = AuditLog::deserialize(&bytes).map_err(|e| {
std::io::Error::other(format!("audit log {} rejected: {e}", p.display()))
})?;
eprintln!(
"spg-server: verified audit log {} ({} entries)",
p.display(),
log.len()
);
log
}
Some(p) => {
fs::write(p, AuditLog::header_bytes())?;
eprintln!("spg-server: started fresh audit log at {}", p.display());
AuditLog::new()
}
None => AuditLog::new(),
};
replay_wal_into_engine(&mut engine, wal_path.as_deref(), manifest_wal_baseline)?;
bootstrap_admin_from_env(&mut engine, db_path.as_deref())?;
let (wal, wal_sync_clone) = open_wal_for_append(wal_path.as_deref())?;
let auth_msg = if password.is_some() {
" (AUTH required)"
} else {
""
};
let chaos = ChaosKnobs {
wal_quota_bytes: parse_env_u64("SPG_FAIL_WAL_QUOTA_BYTES"),
disable_wal_preflight: env::var("SPG_DISABLE_WAL_PREFLIGHT")
.ok()
.is_some_and(|s| !s.is_empty() && s != "0"),
};
let cold_preload = parse_cold_preload_env();
let cold_preload_done = AtomicBool::new(cold_preload.is_empty());
let hot_tier_byte_budget =
parse_env_u64("SPG_HOT_TIER_BYTES").unwrap_or(DEFAULT_HOT_TIER_BYTES);
let cluster_id = load_or_generate_cluster_id(wal_path.as_deref(), db_path.as_deref());
let state = Arc::new(ServerState {
engine: RwLock::new(engine),
db_path,
audit_log: Mutex::new(audit_log),
audit_path,
wal,
wal_sync_clone,
wal_path,
commit_queue: Mutex::new(CommitQueueState {
pending: VecDeque::new(),
leader_active: false,
}),
password,
limits,
active_connections: AtomicUsize::new(0),
metrics: Arc::new(observability::Metrics::default()),
chaos,
lag_state: Arc::new(replication::LagState::default()),
cold_preload,
cold_preload_done,
hot_tier_byte_budget,
cold_segment_paths: Mutex::new(cold_segment_paths),
sub_workers: Mutex::new(BTreeMap::new()),
cluster_id,
wal_level: AtomicU8::new(parse_wal_level_env()),
connections: RwLock::new(Vec::new()),
});
let _ = ACTIVITY_STATE.set(Arc::clone(&state));
PREFETCH_HITS_BOOT.with(|cell| {
let hits = cell.take();
if hits > 0 {
state
.metrics
.cold_prefetch_hits
.store(hits, std::sync::atomic::Ordering::Relaxed);
}
});
if let Ok(mut e) = state.engine.write() {
let prev = std::mem::replace(&mut *e, Engine::new());
let slow_us: u64 = std::env::var("SPG_SLOW_QUERY_THRESHOLD_MS")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(100)
* 1_000;
*e = prev
.with_activity_provider(activity_snapshot)
.with_audit_providers(audit_chain_snapshot, audit_verify_snapshot)
.with_slow_query_log(slow_us, log_slow_query);
if let Ok(s) = std::env::var("SPG_PLAN_CACHE_MAX")
&& let Ok(n) = s.parse::<usize>()
{
e.set_plan_cache_max(n);
}
}
reconcile_subscriptions(&state);
spawn_auto_analyze_worker(Arc::clone(&state));
let listener = TcpListener::bind(addr)?;
let local = listener.local_addr()?;
eprintln!("spg-server: listening on {local}{auth_msg}");
spawn_optional_listeners(&state);
if freezer::spawn(Arc::clone(&state)).is_none() {
eprintln!("spg-server: freezer disabled via SPG_FREEZER_DISABLE");
}
if flusher::spawn(Arc::clone(&state)).is_none() {
}
spawn_shutdown_waker(&listener)?;
for stream in listener.incoming() {
if SHUTDOWN_FLAG.load(Ordering::Acquire) {
drop(stream); break;
}
let mut stream = stream?;
let guard = ConnectionGuard::try_claim(&state);
let Some(guard) = guard else {
let peer = stream.peer_addr().ok();
let _ = write_frame(
&mut stream,
&build_error_response(&format!(
"max_connections reached ({} active)",
state.limits.max_connections.unwrap_or(0)
)),
);
eprintln!("spg-server: rejected {peer:?}: max_connections reached");
continue;
};
let state_for_thread = Arc::clone(&state);
thread::spawn(move || {
let _guard = guard; let peer = stream.peer_addr().ok();
if let Err(e) = handle(stream, &state_for_thread) {
eprintln!("spg-server: conn {peer:?}: {e}");
}
});
}
drain_connections(&state);
Ok(())
}
/// Builds the starting engine from the snapshot at `db_path`.
///
/// Returns the engine plus the manifest's WAL baseline offset: the number of
/// leading WAL bytes already reflected in the snapshot, or 0 when there is
/// no snapshot. A missing or absent path yields a fresh engine. When a
/// snapshot is restored, `load_manifest_and_preload_cold` also fills
/// `cold_segment_paths`.
///
/// # Errors
/// Fails if the snapshot cannot be read or its envelope is rejected.
fn restore_engine(
    db_path: Option<&Path>,
    cold_segment_paths: &mut BTreeMap<u32, PathBuf>,
) -> std::io::Result<(Engine, u64)> {
    let Some(p) = db_path else {
        return Ok((Engine::new(), 0));
    };
    if !p.exists() {
        eprintln!(
            "spg-server: db file {} does not exist yet — starting fresh",
            p.display()
        );
        return Ok((Engine::new(), 0));
    }
    let bytes = fs::read(p)?;
    let path_str = p.display();
    let mut engine = Engine::restore_envelope(&bytes)
        .map_err(|e| std::io::Error::other(format!("restore from {path_str}: {e}")))?;
    eprintln!(
        "spg-server: restored {} table(s), {} user(s) from {path_str}",
        engine.catalog().table_count(),
        engine.users().len()
    );
    let baseline: u64 =
        load_manifest_and_preload_cold(&mut engine, p, &bytes, cold_segment_paths);
    Ok((engine, baseline))
}
/// Replays the WAL at `wal_path` into `engine`, or creates an empty WAL file
/// if the path is configured but missing.
///
/// * `SPG_REPLAY_UPTO=<offset>` truncates replay at that byte offset
///   (point-in-time recovery).
/// * `manifest_wal_baseline` bytes at the front are skipped because the
///   snapshot already contains them. If the baseline exceeds the (possibly
///   truncated) WAL length, the whole WAL is replayed as a safety net.
/// * A WAL that ends inside a transaction is rolled back afterwards.
///
/// # Errors
/// Propagates I/O errors, replay errors, and a failed post-replay rollback.
fn replay_wal_into_engine(
    engine: &mut Engine,
    wal_path: Option<&Path>,
    manifest_wal_baseline: u64,
) -> std::io::Result<()> {
    let Some(p) = wal_path else {
        return Ok(());
    };
    if !p.exists() {
        fs::write(p, b"")?;
        eprintln!("spg-server: started fresh WAL at {}", p.display());
        return Ok(());
    }

    let mut bytes = fs::read(p)?;

    let replay_upto = env::var("SPG_REPLAY_UPTO")
        .ok()
        .and_then(|s| s.trim().parse::<u64>().ok());
    if let Some(upto) = replay_upto {
        let limit = usize::try_from(upto).unwrap_or(usize::MAX);
        if bytes.len() > limit {
            eprintln!(
                "spg-server: PITR — truncating WAL replay at offset {upto} \
                 (of {} total bytes)",
                bytes.len()
            );
            bytes.truncate(limit);
        }
    }

    let baseline = usize::try_from(manifest_wal_baseline).unwrap_or(usize::MAX);
    if baseline > bytes.len() {
        eprintln!(
            "spg-server: manifest WAL baseline {manifest_wal_baseline} exceeds file size {}; \
             replaying from start as a safety net",
            bytes.len()
        );
    } else if baseline > 0 {
        eprintln!(
            "spg-server: manifest skip — WAL replay starts at offset {manifest_wal_baseline} \
             (of {} total bytes)",
            bytes.len()
        );
        bytes.drain(..baseline);
    }

    let applied = replay_wal_bytes(&bytes, engine)?;
    eprintln!(
        "spg-server: replayed {} WAL entries from {}",
        applied,
        p.display()
    );
    if engine.in_transaction() {
        eprintln!("spg-server: WAL ended mid-transaction — auto-rollback");
        engine
            .execute("ROLLBACK")
            .map_err(|e| std::io::Error::other(format!("post-replay rollback: {e}")))?;
    }
    Ok(())
}
/// Opens the existing WAL file for appending.
///
/// Returns the append handle behind a mutex plus, when `File::try_clone`
/// succeeds, a second shared handle to the same file. Returns `(None, None)`
/// when no WAL path is configured.
///
/// # Errors
/// Fails (with the path in the message) if the file cannot be opened; it
/// must already exist.
#[allow(clippy::type_complexity)]
fn open_wal_for_append(
    wal_path: Option<&Path>,
) -> std::io::Result<(Option<Mutex<File>>, Option<Arc<File>>)> {
    let Some(path) = wal_path else {
        return Ok((None, None));
    };
    let appender = OpenOptions::new()
        .append(true)
        .open(path)
        .map_err(|e| std::io::Error::other(format!("open WAL {} for append: {e}", path.display())))?;
    let syncer = match appender.try_clone() {
        Ok(dup) => Some(Arc::new(dup)),
        Err(_) => None,
    };
    Ok((Some(Mutex::new(appender)), syncer))
}
/// Starts the auxiliary endpoints enabled through environment variables.
/// Each is skipped when its variable is unset or empty.
///
/// * `SPG_PG_ADDR`: PostgreSQL wire-protocol listener.
/// * `SPG_MYSQLWIRE_ADDR`: MySQL wire-protocol listener.
/// * `SPG_HTTP_ADDR`: HTTP observability endpoint.
/// * `SPG_REPL_ADDR`: replication master listener.
/// * `SPG_FOLLOW_OF`: follower thread replicating from that master. This
///   requires both a db path and a WAL path.
///
/// Start-up failures are logged, not fatal. The follower message now reports
/// a thread-spawn failure instead of claiming "started as follower"
/// regardless.
fn spawn_optional_listeners(state: &Arc<ServerState>) {
    if let Ok(pg_addr) = env::var("SPG_PG_ADDR")
        && !pg_addr.is_empty()
    {
        match pgwire::spawn_listener(&pg_addr, Arc::clone(state)) {
            Ok(pg_local) => eprintln!("spg-server: pg-wire listening on {pg_local}"),
            Err(e) => eprintln!("spg-server: pg-wire failed to start on {pg_addr}: {e}"),
        }
    }
    if let Ok(my_addr) = env::var("SPG_MYSQLWIRE_ADDR")
        && !my_addr.is_empty()
    {
        match mysqlwire::spawn_listener(&my_addr, Arc::clone(state)) {
            Ok(my_local) => eprintln!("spg-server: mysql-wire listening on {my_local}"),
            Err(e) => eprintln!("spg-server: mysql-wire failed to start on {my_addr}: {e}"),
        }
    }
    if let Ok(http_addr) = env::var("SPG_HTTP_ADDR")
        && !http_addr.is_empty()
    {
        match observability::spawn_http(&http_addr, Arc::clone(state)) {
            Ok(http_local) => eprintln!("spg-server: http listening on {http_local}"),
            Err(e) => eprintln!("spg-server: http failed to start on {http_addr}: {e}"),
        }
    }
    if let Ok(repl_addr) = env::var("SPG_REPL_ADDR")
        && !repl_addr.is_empty()
    {
        match replication::spawn_master_listener(&repl_addr, Arc::clone(state)) {
            Ok(repl_local) => {
                eprintln!("spg-server: replication listening on {repl_local}");
            }
            Err(e) => {
                eprintln!("spg-server: replication failed to start on {repl_addr}: {e}");
            }
        }
    }
    if let Ok(master_addr) = env::var("SPG_FOLLOW_OF")
        && !master_addr.is_empty()
    {
        if let (Some(db), Some(wal)) = (state.db_path.clone(), state.wal_path.clone()) {
            let state_for_follower = Arc::clone(state);
            let spawned = thread::Builder::new()
                .name("spg-follower".into())
                .spawn(move || {
                    replication::run_follower(master_addr, db, wal, state_for_follower);
                });
            match spawned {
                Ok(_) => eprintln!("spg-server: started as follower"),
                Err(e) => eprintln!("spg-server: failed to spawn follower thread: {e}"),
            }
        } else {
            eprintln!(
                "spg-server: SPG_FOLLOW_OF set but db_path or wal_path missing — \
                 follower mode requires both"
            );
        }
    }
}
/// Spawns a thread that waits for [`SHUTDOWN_FLAG`] and then connects once to
/// the listener's own address. This wakes the accept loop blocked in
/// `incoming()` so it can observe the flag and exit.
///
/// # Errors
/// Fails if the listener's local address is unavailable or the thread cannot
/// be spawned.
fn spawn_shutdown_waker(listener: &TcpListener) -> std::io::Result<()> {
    let wake_addr = listener.local_addr()?;
    let waker = move || {
        loop {
            if SHUTDOWN_FLAG.load(Ordering::Acquire) {
                break;
            }
            thread::sleep(SHUTDOWN_POLL);
        }
        let _ = TcpStream::connect(wake_addr);
    };
    thread::Builder::new()
        .name("spg-shutdown-waker".into())
        .spawn(waker)
        .map(drop)
}
/// Waits for live connections to finish after shutdown was requested.
///
/// Polls `active_connections` every [`SHUTDOWN_POLL`] and returns once it
/// reaches zero or the deadline (`SPG_SHUTDOWN_DEADLINE_SEC`, default
/// [`DEFAULT_SHUTDOWN_DEADLINE_SEC`]) elapses. Progress is logged.
fn drain_connections(state: &ServerState) {
    let deadline_sec = state
        .limits
        .shutdown_deadline_sec
        .unwrap_or(DEFAULT_SHUTDOWN_DEADLINE_SEC);
    let budget = Duration::from_secs(deadline_sec);
    let started = Instant::now();
    eprintln!(
        "spg-server: shutdown signal received — draining {} connection(s), deadline {}s",
        state.active_connections.load(Ordering::Acquire),
        deadline_sec,
    );
    loop {
        match state.active_connections.load(Ordering::Acquire) {
            0 => {
                eprintln!("spg-server: drained — exiting 0");
                return;
            }
            active if started.elapsed() >= budget => {
                eprintln!(
                    "spg-server: drain deadline hit with {active} connection(s) still active — exiting 0"
                );
                return;
            }
            _ => thread::sleep(SHUTDOWN_POLL),
        }
    }
}
#[allow(unsafe_code)]
/// Installs SIGTERM and SIGINT handlers that only set [`SHUTDOWN_FLAG`].
/// The accept loop in `run` and the shutdown-waker thread poll that flag to
/// stop accepting and drain connections.
///
/// NOTE(review): the return value of `libc::signal` (`SIG_ERR` on failure)
/// is not checked, so a failed installation goes unnoticed.
fn install_shutdown_handlers() {
// Handler body is a single store to a static atomic; it takes no locks and
// allocates nothing, keeping it safe to run in signal context.
extern "C" fn handler(_sig: libc::c_int) {
SHUTDOWN_FLAG.store(true, Ordering::Release);
}
// SAFETY: `signal` is called with valid signal numbers and the address of an
// `extern "C" fn(c_int)`, which is the handler shape `signal` expects. The
// handler only touches a `'static` `AtomicBool`, so it stays valid for the
// whole process lifetime and does nothing unsafe when run asynchronously.
unsafe {
libc::signal(libc::SIGTERM, handler as *const () as libc::sighandler_t);
libc::signal(libc::SIGINT, handler as *const () as libc::sighandler_t);
}
}
struct ConnectionGuard {
state: Arc<ServerState>,
}
impl ConnectionGuard {
fn try_claim(state: &Arc<ServerState>) -> Option<Self> {
let max = state.limits.max_connections;
loop {
let current = state.active_connections.load(Ordering::Acquire);
if let Some(cap) = max
&& current >= cap
{
return None;
}
if state
.active_connections
.compare_exchange_weak(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return Some(Self {
state: Arc::clone(state),
});
}
}
}
}
impl Drop for ConnectionGuard {
fn drop(&mut self) {
self.state.active_connections.fetch_sub(1, Ordering::AcqRel);
}
}
/// Serves one client on the native frame protocol until EOF, idle timeout, a
/// frame-decoding error, or an I/O error.
///
/// Incoming bytes accumulate in `buf` and every complete frame is passed to
/// `dispatch`. A partial frame (`ShortHeader` / `ShortPayload`) waits for
/// more input. When `idle_timeout_sec` is configured it becomes the socket
/// read timeout; on expiry the client gets an error frame and the connection
/// closes cleanly.
///
/// A read interrupted by a signal (`ErrorKind::Interrupted`) is retried.
/// The server installs SIGTERM/SIGINT handlers, and socket reads with a
/// receive timeout are not auto-restarted after a handler runs, so this
/// previously tore down healthy connections.
///
/// # Errors
/// Returns the underlying I/O error, or an `Other` error wrapping a frame
/// decode failure (after sending the client an error frame).
fn handle(mut stream: TcpStream, state: &Arc<ServerState>) -> std::io::Result<()> {
    let _ = stream.set_nodelay(true);
    if let Some(secs) = state.limits.idle_timeout_sec {
        let _ = stream.set_read_timeout(Some(std::time::Duration::from_secs(secs)));
    }
    let mut buf: Vec<u8> = Vec::with_capacity(READ_CHUNK);
    let mut chunk = [0u8; READ_CHUNK];
    let mut role = initial_role(state)?;
    let mut in_tx = false;
    loop {
        let n = match stream.read(&mut chunk) {
            Ok(n) => n,
            // Not a client problem: retry, as `Read::read_exact` does.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            // The read timeout surfaces as WouldBlock or TimedOut depending on
            // the platform.
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
                ) =>
            {
                let _ = write_frame(
                    &mut stream,
                    &build_error_response("idle timeout reached, closing connection"),
                );
                return Ok(());
            }
            Err(e) => return Err(e),
        };
        if n == 0 {
            return Ok(());
        }
        buf.extend_from_slice(&chunk[..n]);
        // Drain every complete frame now buffered before reading again.
        loop {
            match decode(&buf) {
                Ok((frame, consumed)) => {
                    buf.drain(..consumed);
                    dispatch(&mut stream, &frame, state, &mut role, &mut in_tx)?;
                }
                Err(FrameError::ShortHeader | FrameError::ShortPayload) => break,
                Err(e) => {
                    let _ = write_frame(&mut stream, &build_error_response(&e.to_string()));
                    return Err(std::io::Error::other(e.to_string()));
                }
            }
        }
    }
}
/// Role assigned to a new connection before any AUTH frame.
///
/// With no users defined and no `SPG_PASSWORD`, the server is open and every
/// connection starts as `Admin`. Otherwise the connection starts with no
/// role and must authenticate.
///
/// # Errors
/// Fails if the engine lock is poisoned.
fn initial_role(state: &ServerState) -> std::io::Result<Option<Role>> {
    let users_exist = state
        .engine
        .read()
        .map(|engine| !engine.users().is_empty())
        .map_err(|_| std::io::Error::other("engine rwlock poisoned"))?;
    let auth_required = users_exist || state.password.is_some();
    Ok((!auth_required).then_some(Role::Admin))
}
/// Unwraps the connection's role where the caller has already rejected
/// unauthenticated connections.
///
/// # Panics
/// Panics if `role` is `None`, which indicates a dispatch bug.
fn current_role(role: Option<Role>) -> Role {
    match role {
        Some(granted) => granted,
        None => panic!("dispatch already gated on role.is_some()"),
    }
}
/// Heuristic: is `sql` a `CREATE USER …` / `DROP USER …` statement?
///
/// Matches case-insensitively when the first word is `create` or `drop` and
/// the second word starts with `user`. Words may be separated by any
/// whitespace. The previous version only accepted a space after the verb,
/// so `CREATE\nUSER` or `DROP\tUSER` slipped past this check.
fn sql_is_user_mgmt(sql: &str) -> bool {
    let mut words = sql.split_whitespace();
    let is_create_or_drop = words
        .next()
        .is_some_and(|w| w.eq_ignore_ascii_case("create") || w.eq_ignore_ascii_case("drop"));
    is_create_or_drop
        && words
            .next()
            .is_some_and(|w| w.get(..4).is_some_and(|head| head.eq_ignore_ascii_case("user")))
}
/// Whether the first word of `sql` (delimited by whitespace or `;`) is a
/// transaction-control keyword, compared case-insensitively.
fn sql_is_tx_control(sql: &str) -> bool {
    const TX_KEYWORDS: [&str; 7] = [
        "begin",
        "start",
        "commit",
        "rollback",
        "savepoint",
        "release",
        "end",
    ];
    let leading = sql
        .trim_start()
        .split(|c: char| c.is_whitespace() || c == ';')
        .next()
        .unwrap_or("");
    TX_KEYWORDS.iter().any(|kw| leading.eq_ignore_ascii_case(kw))
}
/// Cheap check for a read-only statement. It skips leading space, tab, CR
/// and LF, takes the run of ASCII letters that follows, and matches it
/// case-insensitively against `select` or `show`.
fn sql_is_read_only(sql: &str) -> bool {
    let body = sql.trim_start_matches(|c: char| matches!(c, ' ' | '\t' | '\n' | '\r'));
    let keyword_len = body.bytes().take_while(u8::is_ascii_alphabetic).count();
    // ASCII letters are single bytes, so this slice ends on a char boundary.
    let keyword = &body[..keyword_len];
    ["select", "show"]
        .iter()
        .any(|kw| keyword.eq_ignore_ascii_case(kw))
}
/// `effective_wal_level` code for `replica` (also the default).
pub(crate) const WAL_LEVEL_REPLICA: u8 = 0;
/// `effective_wal_level` code for `logical`.
pub(crate) const WAL_LEVEL_LOGICAL: u8 = 1;

/// Boot-time WAL level from `SPG_WAL_LEVEL`, case-insensitive and ignoring
/// surrounding whitespace. Unset or empty means replica; an unknown value
/// logs a warning and also falls back to replica.
fn parse_wal_level_env() -> u8 {
    match std::env::var("SPG_WAL_LEVEL")
        .ok()
        .map(|s| s.trim().to_ascii_lowercase())
        .as_deref()
    {
        None | Some("" | "replica") => WAL_LEVEL_REPLICA,
        Some("logical") => WAL_LEVEL_LOGICAL,
        Some(other) => {
            eprintln!(
                "spg-server: SPG_WAL_LEVEL={other:?} unknown — defaulting to replica. \
                 Valid values: replica, logical"
            );
            WAL_LEVEL_REPLICA
        }
    }
}

/// Display name for a WAL-level code; unknown codes read as `replica`.
pub(crate) fn wal_level_label(v: u8) -> &'static str {
    match v {
        WAL_LEVEL_LOGICAL => "logical",
        _ => "replica",
    }
}

/// Whether `sql` starts (case-insensitively, after leading whitespace) with
/// `set effective_wal_level`.
fn sql_looks_like_set_wal_level(sql: &str) -> bool {
    let trimmed = sql.trim_start().to_ascii_lowercase();
    trimmed.starts_with("set effective_wal_level")
}

/// Whether `sql` is `SHOW effective_wal_level`, optionally followed by
/// whitespace, `;` or further text after a separator.
///
/// Any whitespace counts as a separator. Previously a statement ending in a
/// newline or tab (e.g. `"SHOW effective_wal_level\n"`) was not recognised.
fn sql_looks_like_show_wal_level(sql: &str) -> bool {
    let lower = sql.trim_start().to_ascii_lowercase();
    lower
        .strip_prefix("show effective_wal_level")
        .is_some_and(|rest| {
            rest.is_empty() || rest.starts_with(|c: char| c.is_whitespace() || c == ';')
        })
}

/// Parses `SET effective_wal_level {= | TO} <value>` (case-insensitive).
///
/// The value may be wrapped in single or double quotes and followed by `;`.
/// `TO` may be followed by any whitespace; previously only a single space
/// was accepted.
///
/// # Errors
/// Returns a human-readable message for a missing `=`/`TO` or an unknown
/// level name.
fn parse_set_wal_level_value(sql: &str) -> Result<u8, String> {
    let lower = sql.trim().to_ascii_lowercase();
    let rest = lower
        .strip_prefix("set effective_wal_level")
        .ok_or_else(|| "expected `set effective_wal_level …`".to_string())?
        .trim_start();
    let val_part = if let Some(r) = rest.strip_prefix('=') {
        r.trim()
    } else if let Some(r) = rest
        .strip_prefix("to")
        .filter(|r| r.starts_with(char::is_whitespace))
    {
        r.trim()
    } else {
        return Err("expected `=` or `TO` after effective_wal_level".to_string());
    };
    let value = val_part
        .trim_matches(|c: char| matches!(c, '\'' | '"' | ';'))
        .trim();
    match value {
        "replica" => Ok(WAL_LEVEL_REPLICA),
        "logical" => Ok(WAL_LEVEL_LOGICAL),
        other => Err(format!(
            "unknown effective_wal_level {other:?}; expected `replica` or `logical`"
        )),
    }
}
/// Applies `SET effective_wal_level …` to the server-wide `wal_level` and acknowledges
/// it with `CommandOk { affected: 1 }`.
///
/// A malformed statement is answered with an error frame, and the level is left unchanged.
fn handle_set_wal_level(
    stream: &mut TcpStream,
    state: &Arc<ServerState>,
    sql: &str,
) -> std::io::Result<()> {
    let level = match parse_set_wal_level_value(sql) {
        Ok(level) => level,
        Err(msg) => return write_frame(stream, &build_error_response(&msg)),
    };
    state.wal_level.store(level, Ordering::Release);
    let ack = spg_engine::QueryResult::CommandOk {
        affected: 1,
        modified_catalog: false,
    };
    emit_result(stream, Ok(ack))
}
/// Answers `SHOW effective_wal_level` with a one-row, one-column text result holding
/// the current server-wide level.
fn handle_show_wal_level(stream: &mut TcpStream, state: &Arc<ServerState>) -> std::io::Result<()> {
    let label = wal_level_label(state.wal_level.load(Ordering::Acquire));
    let rows = vec![Row::new(vec![Value::Text(label.to_string())])];
    let columns = vec![ColumnSchema::new("effective_wal_level", DataType::Text, false)];
    emit_result(stream, Ok(spg_engine::QueryResult::Rows { columns, rows }))
}
/// Cheap pre-filter for `WAIT …` statements: after leading whitespace, the text must
/// begin with `WAIT` (ASCII case-insensitive) immediately followed by ASCII whitespace.
/// The real parse happens afterwards; this only avoids parsing every query.
fn sql_looks_like_wait_for(sql: &str) -> bool {
    let bytes = sql.trim_start().as_bytes();
    match (bytes.get(..4), bytes.get(4)) {
        (Some(keyword), Some(sep)) => {
            keyword.eq_ignore_ascii_case(b"WAIT") && sep.is_ascii_whitespace()
        }
        _ => false,
    }
}
fn handle_wait_for_wal_position(
stream: &mut TcpStream,
state: &Arc<ServerState>,
target: u64,
timeout_ms: Option<u64>,
) -> std::io::Result<()> {
const POLL: std::time::Duration = std::time::Duration::from_millis(5);
let deadline =
timeout_ms.map(|ms| std::time::Instant::now() + std::time::Duration::from_millis(ms));
loop {
let current = state.lag_state.follower_applied_pos.load(Ordering::Acquire);
if current >= target {
return emit_result(
stream,
Ok(spg_engine::QueryResult::CommandOk {
affected: 1,
modified_catalog: false,
}),
);
}
if let Some(d) = deadline
&& std::time::Instant::now() >= d
{
return emit_result(
stream,
Ok(spg_engine::QueryResult::CommandOk {
affected: 0,
modified_catalog: false,
}),
);
}
std::thread::sleep(POLL);
}
}
/// Routes one decoded client frame to its handler and writes the response to `stream`.
///
/// Until `role` is set, only `PING`, `AUTH` and `AUTH USER` are accepted. A successful
/// `AUTH` sets `role` to `Role::Admin`, and a successful `AUTH USER` stores the role
/// returned by `verify_user`. `in_tx` is passed through to the query path, which may
/// update it.
///
/// # Errors
/// Returns `Err` when writing to the socket fails or the engine lock is poisoned.
// Long by design: one arm per wire opcode keeps the whole protocol surface in one match.
#[allow(clippy::too_many_lines)]
fn dispatch(
    stream: &mut TcpStream,
    frame: &Frame,
    state: &Arc<ServerState>,
    role: &mut Option<Role>,
    in_tx: &mut bool,
) -> std::io::Result<()> {
    if role.is_none() && !matches!(frame.op, Op::Ping | Op::Auth | Op::AuthUser) {
        return write_frame(
            stream,
            &build_error_response("authentication required: send AUTH first"),
        );
    }
    match frame.op {
        Op::Ping => write_frame(stream, &Frame::pong()),
        Op::Auth => {
            let candidate = match parse_auth(frame) {
                Ok(s) => s,
                Err(e) => return write_frame(stream, &build_error_response(&e.to_string())),
            };
            // Fail closed on a poisoned engine lock, as AUTH USER does below. Treating
            // poison as "no users" would let legacy password auth grant Admin.
            let users_exist = !state
                .engine
                .read()
                .map_err(|_| std::io::Error::other("engine rwlock poisoned"))?
                .users()
                .is_empty();
            if users_exist {
                return write_frame(
                    stream,
                    &build_error_response("RBAC active: use AUTH USER <name> <password>"),
                );
            }
            let ok = match &state.password {
                Some(pw) => {
                    // Compare without an early exit on the first differing byte, so
                    // response timing does not reveal how long a correct prefix is
                    // (best effort; the length is still observable).
                    let (given, expected) = (candidate.as_bytes(), pw.as_bytes());
                    given.len() == expected.len()
                        && given
                            .iter()
                            .zip(expected)
                            .fold(0u8, |acc, (a, b)| acc | (a ^ b))
                            == 0
                }
                None => true,
            };
            if ok {
                *role = Some(Role::Admin);
                write_frame(stream, &Frame::pong())
            } else {
                write_frame(stream, &build_error_response("AUTH: wrong password"))
            }
        }
        Op::AuthUser => {
            let (user, pw) = match parse_auth_user(frame) {
                Ok(t) => t,
                Err(e) => return write_frame(stream, &build_error_response(&e.to_string())),
            };
            let verified = state
                .engine
                .read()
                .map_err(|_| std::io::Error::other("engine rwlock poisoned"))?
                .verify_user(user, pw);
            match verified {
                Some(r) => {
                    *role = Some(r);
                    write_frame(stream, &Frame::pong())
                }
                None => write_frame(stream, &build_error_response("AUTH: invalid credentials")),
            }
        }
        Op::Stats => {
            let body = render_stats(state)?;
            write_frame(stream, &build_stats_response(&body))
        }
        Op::Query => handle_query_op(stream, frame, state, *role, in_tx),
        // Server → client opcodes are never valid coming from a client.
        Op::Pong
        | Op::RowDescription
        | Op::DataRow
        | Op::DataRowBatch
        | Op::CommandComplete
        | Op::ErrorResponse
        | Op::StatsResponse => write_frame(
            stream,
            &Frame::error("client → server opcode not accepted on this side"),
        ),
        Op::Error => write_frame(
            stream,
            &Frame::error("clients should not send Error frames"),
        ),
    }
}
fn handle_query_op(
stream: &mut TcpStream,
frame: &Frame,
state: &Arc<ServerState>,
role: Option<Role>,
in_tx: &mut bool,
) -> std::io::Result<()> {
state.metrics.queries_total.fetch_add(1, Ordering::Relaxed);
try_lazy_preload_cold(state);
let sql = match parse_query(frame) {
Ok(s) => s.to_string(),
Err(e) => {
state.metrics.errors_total.fetch_add(1, Ordering::Relaxed);
return write_frame(stream, &build_error_response(&e.to_string()));
}
};
let _slow_log = SlowLogGuard::new(state, &sql, role);
if sql_looks_like_wait_for(&sql)
&& let Ok(stmt) = spg_sql::parser::parse_statement(&sql)
&& let spg_sql::ast::Statement::WaitForWalPosition { pos, timeout_ms } = stmt
{
return handle_wait_for_wal_position(stream, state, pos, timeout_ms);
}
if sql_looks_like_show_wal_level(&sql) {
return handle_show_wal_level(stream, state);
}
if sql_looks_like_set_wal_level(&sql) {
return handle_set_wal_level(stream, state, &sql);
}
if !*in_tx && sql_is_read_only(&sql) {
let cancel_flag = Arc::new(AtomicBool::new(false));
let watchdog = spawn_query_watchdog(state, &cancel_flag);
let engine = state
.engine
.read()
.map_err(|_| std::io::Error::other("engine rwlock poisoned"))?;
let budget = usize::try_from(
state
.limits
.max_query_bytes
.unwrap_or(DEFAULT_MAX_QUERY_BYTES),
)
.unwrap_or(usize::MAX);
alloc_budget::reset_query_budget(budget, &cancel_flag);
let result = engine
.execute_readonly_with_cancel(&sql, spg_engine::CancelToken::from_flag(&cancel_flag));
alloc_budget::clear_query_budget();
drop(engine);
watchdog.cancel();
if !matches!(&result, Err(EngineError::WriteRequired)) {
return emit_result(stream, result);
}
}
let acting = current_role(role);
if sql_is_user_mgmt(&sql) {
if !acting.can_manage_users() {
return write_frame(
stream,
&build_error_response("permission denied: user management requires admin role"),
);
}
} else if !acting.can_write() {
return write_frame(
stream,
&build_error_response("permission denied: write requires admin or readwrite role"),
);
}
if let Some(backup_intent) = parse_backup_intent(&sql) {
if !acting.can_manage_users() {
return write_frame(
stream,
&build_error_response("permission denied: BACKUP requires admin role"),
);
}
return run_backup_command(stream, state, &backup_intent);
}
if parse_checkpoint_intent(&sql) {
if !acting.can_manage_users() {
return write_frame(
stream,
&build_error_response("permission denied: CHECKPOINT requires admin role"),
);
}
return run_checkpoint_command(stream, state);
}
if parse_compact_cold_segments_intent(&sql) {
if !acting.can_manage_users() {
return write_frame(
stream,
&build_error_response(
"permission denied: COMPACT COLD SEGMENTS requires admin role",
),
);
}
return run_compact_cold_segments_command(stream, state);
}
let needs_wrap = !*in_tx && state.wal.is_some() && !sql_is_tx_control(&sql);
if let Some(quota) = state.chaos.wal_quota_bytes
&& let Some(wal_path) = &state.wal_path
&& !state.chaos.disable_wal_preflight
{
let cur = fs::metadata(wal_path).map_or(0, |m| m.len());
let needed = if needs_wrap {
wal_v3_auto_commit_size(&sql)
} else {
4 + sql.len() as u64
};
if cur.saturating_add(needed) > quota {
return write_frame(
stream,
&build_error_response(&format!(
"wal quota exceeded: cur={cur} + {needed} > quota={quota} (SPG_FAIL_WAL_QUOTA_BYTES)"
)),
);
}
}
let cancel_flag = Arc::new(AtomicBool::new(false));
let watchdog = spawn_query_watchdog(state, &cancel_flag);
let (result, wal_result, snapshot) = if needs_wrap {
let (ack_tx, ack_rx) = mpsc::sync_channel::<CommitResult>(1);
let task = CommitTask {
sql: sql.clone(),
cancel_flag: Arc::clone(&cancel_flag),
ack: ack_tx,
};
let became_leader = enqueue_commit_task(state, task);
if became_leader {
run_leader_commit_round(state);
}
let CommitResult {
result,
wal_outcome,
} = ack_rx.recv().map_err(|_| {
std::io::Error::other("commit barrier: ack channel closed before result arrived")
})?;
*in_tx = state.engine.read().is_ok_and(|e| e.in_transaction());
(result, wal_outcome, None)
} else {
let mut engine = state
.engine
.write()
.map_err(|_| std::io::Error::other("engine rwlock poisoned"))?;
let budget = usize::try_from(
state
.limits
.max_query_bytes
.unwrap_or(DEFAULT_MAX_QUERY_BYTES),
)
.unwrap_or(usize::MAX);
alloc_budget::reset_query_budget(budget, &cancel_flag);
let result =
engine.execute_with_cancel(&sql, spg_engine::CancelToken::from_flag(&cancel_flag));
alloc_budget::clear_query_budget();
let was_command_ok = matches!(result, Ok(QueryResult::CommandOk { .. }));
let wal_result = if was_command_ok && state.wal.is_some() {
append_wal(state, &sql)
} else {
Ok(())
};
*in_tx = engine.in_transaction();
let snapshot = if state.db_path.is_some() && state.wal.is_none() {
match &result {
Ok(QueryResult::CommandOk {
modified_catalog: true,
..
}) => Some(engine.snapshot()),
_ => None,
}
} else {
None
};
drop(engine);
(result, wal_result, snapshot)
};
watchdog.cancel();
if let (Some(bytes), Some(path)) = (snapshot.as_ref(), state.db_path.as_deref())
&& let Err(e) = write_atomic(path, bytes)
{
let _ = write_frame(
stream,
&build_error_response(&format!("snapshot write failed: {e}")),
);
return Err(e);
}
if let (Some(bytes), Some(path)) = (snapshot.as_ref(), state.db_path.as_deref()) {
let paths_snapshot = state
.cold_segment_paths
.lock()
.map(|g| g.clone())
.unwrap_or_default();
let wal_len = state
.wal_path
.as_deref()
.and_then(|p| fs::metadata(p).ok())
.map_or(0, |m| m.len());
write_manifest_alongside(path, bytes, &paths_snapshot, wal_len);
}
if let Err(e) = wal_result {
let _ = write_frame(
stream,
&build_error_response(&format!("WAL append failed: {e}")),
);
return Err(e);
}
if state.audit_path.is_some()
&& matches!(
result,
Ok(QueryResult::CommandOk {
modified_catalog: true,
..
})
)
&& let Err(e) = append_audit(state, &sql)
{
let _ = write_frame(
stream,
&build_error_response(&format!("audit append failed: {e}")),
);
return Err(e);
}
if matches!(
result,
Ok(QueryResult::CommandOk {
modified_catalog: true,
..
})
) {
reconcile_subscriptions(state);
}
emit_result(stream, result)
}
/// Crate-visible wrapper around `append_audit`: records `sql` in the in-memory audit
/// log and, when an audit file is configured, appends the encoded entry to it.
///
/// # Errors
/// Propagates any error from `append_audit` (poisoned audit mutex or audit-file I/O).
pub(crate) fn append_audit_pub(state: &ServerState, sql: &str) -> std::io::Result<()> {
    append_audit(state, sql)
}
/// Appends `sql`, stamped with the current wall-clock time in milliseconds, to the
/// in-memory audit log. When `audit_path` is set, the newest entry is encoded and
/// appended to that file, which must already exist (it is opened in append mode without
/// `create`).
///
/// # Errors
/// Fails if the audit mutex is poisoned or the audit file cannot be opened or written.
fn append_audit(state: &ServerState, sql: &str) -> std::io::Result<()> {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_millis());
    let ts_ms = u64::try_from(millis).unwrap_or(u64::MAX);
    let mut log = match state.audit_log.lock() {
        Ok(guard) => guard,
        Err(_) => return Err(std::io::Error::other("audit mutex poisoned")),
    };
    log.append(sql.to_string(), ts_ms);
    let Some(path) = state.audit_path.as_deref() else {
        return Ok(());
    };
    let newest = log.len() - 1;
    let mut entry_bytes = Vec::new();
    log.encode_entry_to(newest, &mut entry_bytes);
    let mut file = OpenOptions::new().append(true).open(path)?;
    file.write_all(&entry_bytes)?;
    Ok(())
}
/// Microseconds since the Unix epoch from the system clock.
///
/// Returns `0` if the clock reads earlier than the epoch, and saturates at `i64::MAX`.
pub(crate) fn wall_clock_micros() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_micros()).unwrap_or(i64::MAX),
        Err(_) => 0,
    }
}
/// Creates an initial admin user from environment variables when the engine has no users.
///
/// Credential sources, in priority order (empty values count as unset):
/// - `SPG_ADMIN_PASSWORD`, with the user name taken from `SPG_ADMIN_USER`, else
///   `POSTGRES_USER`, else `"admin"`;
/// - `POSTGRES_PASSWORD`, with the user name taken from `POSTGRES_USER`, else `"postgres"`.
///
/// With neither password set, nothing is created. `POSTGRES_DB` is accepted but only
/// logged. If `db_path` is given, the catalog snapshot and its manifest are written
/// afterwards. A persistence failure is only warned about.
///
/// # Errors
/// Fails if no random salt can be read or the engine rejects the new user.
fn bootstrap_admin_from_env(engine: &mut Engine, db_path: Option<&Path>) -> std::io::Result<()> {
    if !engine.users().is_empty() {
        return Ok(());
    }
    let non_empty = |key: &str| env::var(key).ok().filter(|v| !v.is_empty());
    let pg_user = non_empty("POSTGRES_USER");
    let pg_pass = non_empty("POSTGRES_PASSWORD");
    if let Some(db) = non_empty("POSTGRES_DB") {
        eprintln!(
            "spg-server: POSTGRES_DB={db:?} accepted — SPG is single-database, \
             every connection sees the same catalog"
        );
    }
    let (pw, user) = if let Some(spg_pw) = non_empty("SPG_ADMIN_PASSWORD") {
        let user = non_empty("SPG_ADMIN_USER")
            .or(pg_user)
            .unwrap_or_else(|| "admin".to_string());
        (spg_pw, user)
    } else if let Some(pg_pw) = pg_pass {
        (pg_pw, pg_user.unwrap_or_else(|| "postgres".to_string()))
    } else {
        return Ok(());
    };
    let salt = random_salt()?;
    engine
        .create_user(&user, &pw, Role::Admin, salt)
        .map_err(|e| std::io::Error::other(format!("bootstrap admin {user:?}: {e}")))?;
    eprintln!("spg-server: bootstrapped admin user {user:?} from env");
    let Some(p) = db_path else {
        return Ok(());
    };
    let snapshot = engine.snapshot();
    match write_atomic(p, &snapshot) {
        Ok(()) => write_manifest_alongside(p, &snapshot, &BTreeMap::new(), 0),
        Err(e) => eprintln!("spg-server: warning — failed to persist bootstrap admin: {e}"),
    }
    Ok(())
}
/// Scope guard that times one query and, on drop, prints a JSON `slow_query` event to
/// stderr if it ran at least `slow_query_log_ms`. It does nothing when that limit is unset.
struct SlowLogGuard<'a> {
    /// Threshold in microseconds; `None` disables logging.
    threshold_us: Option<u64>,
    sql: &'a str,
    role: Option<Role>,
    start: Instant,
}
impl<'a> SlowLogGuard<'a> {
    /// Starts the timer. The millisecond limit is converted to microseconds (saturating).
    fn new(state: &ServerState, sql: &'a str, role: Option<Role>) -> Self {
        let threshold_us = state
            .limits
            .slow_query_log_ms
            .map(|ms| ms.saturating_mul(1000));
        let start = Instant::now();
        Self {
            threshold_us,
            sql,
            role,
            start,
        }
    }
}
impl Drop for SlowLogGuard<'_> {
    fn drop(&mut self) {
        if let Some(threshold_us) = self.threshold_us {
            let elapsed_us =
                u64::try_from(self.start.elapsed().as_micros()).unwrap_or(u64::MAX);
            if elapsed_us >= threshold_us {
                let mut sql_escaped = String::with_capacity(self.sql.len() + 16);
                json_escape_into(self.sql, &mut sql_escaped);
                let role_str = self.role.map_or("unauth", Role::as_str);
                eprintln!(
                    r#"{{"event":"slow_query","sql":"{sql_escaped}","elapsed_us":{elapsed_us},"role":"{role_str}","threshold_us":{threshold_us}}}"#
                );
            }
        }
    }
}
/// Appends `s` to `out` escaped for use inside a JSON string literal.
///
/// Quotes and backslashes are escaped. `\n`, `\r`, `\t`, backspace and form feed use
/// their short forms, and every other control character below U+0020 becomes `\u00XX`.
/// All remaining characters are copied unchanged.
fn json_escape_into(s: &str, out: &mut String) {
    use std::fmt::Write as _;
    for ch in s.chars() {
        let short = match ch {
            '"' => "\\\"",
            '\\' => "\\\\",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            '\u{0008}' => "\\b",
            '\u{000c}' => "\\f",
            _ if u32::from(ch) < 0x20 => {
                let _ = write!(out, "\\u{:04x}", u32::from(ch));
                continue;
            }
            _ => {
                out.push(ch);
                continue;
            }
        };
        out.push_str(short);
    }
}
/// Handle returned by `spawn_query_watchdog`. Calling `cancel` tells the watchdog
/// thread (if one was started) that the query finished, so it exits without raising
/// the query's cancel flag.
struct Watchdog {
    completed: Arc<AtomicBool>,
}
impl Watchdog {
    /// Marks the watched query as finished. Calling it more than once is harmless.
    fn cancel(&self) {
        let finished = &self.completed;
        finished.store(true, Ordering::Release);
    }
}
fn spawn_query_watchdog(state: &ServerState, cancel_flag: &Arc<AtomicBool>) -> Watchdog {
let completed = Arc::new(AtomicBool::new(false));
let timeout_dur = state
.limits
.query_timeout_ms
.map(std::time::Duration::from_millis);
let cpu_dur = state
.limits
.max_query_ns
.map(std::time::Duration::from_nanos);
let total = match (timeout_dur, cpu_dur) {
(Some(a), Some(b)) => Some(a.min(b)),
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(None, None) => None,
};
let Some(total) = total else {
return Watchdog { completed };
};
let cancel_flag = Arc::clone(cancel_flag);
let completed_for_thread = Arc::clone(&completed);
thread::spawn(move || {
let slice = (total / 50).max(std::time::Duration::from_micros(100));
let start = std::time::Instant::now();
while start.elapsed() < total {
if completed_for_thread.load(Ordering::Acquire) {
return;
}
thread::sleep(slice);
}
cancel_flag.store(true, Ordering::Release);
});
Watchdog { completed }
}
/// Reads 16 bytes from `/dev/urandom` for use as a password salt.
///
/// # Errors
/// Fails if `/dev/urandom` cannot be opened or yields fewer than 16 bytes.
fn random_salt() -> std::io::Result<[u8; 16]> {
    let mut salt = [0u8; 16];
    let mut source = File::open("/dev/urandom")?;
    source.read_exact(&mut salt)?;
    Ok(salt)
}
/// Like `random_salt`, but panics instead of returning an error. Refusing to continue
/// is preferable to creating users with a predictable salt.
///
/// # Panics
/// Panics if `/dev/urandom` cannot be read.
fn urandom_salt_or_panic() -> [u8; 16] {
    match random_salt() {
        Ok(salt) => salt,
        Err(err) => panic!(
            "/dev/urandom unreadable — refusing to create users without entropy: {err:?}"
        ),
    }
}
/// Atomically and durably replaces the file at `path` with `data`.
///
/// 1. The bytes are written to a uniquely named temp file in the same directory.
/// 2. The temp file is flushed to stable storage with `sync_all`.
/// 3. The temp file is renamed over `path`, so readers see either the old contents or
///    the complete new contents, never a torn file.
/// 4. The parent directory is then fsynced (best effort, skipped where a directory
///    cannot be opened) so the rename itself survives a crash.
///
/// # Errors
/// Returns any create, write, sync or rename failure. The temp file is removed on every
/// error path.
fn write_atomic(path: &Path, data: &[u8]) -> std::io::Result<()> {
    /// Creates `tmp`, writes `data` and flushes it to disk.
    fn write_synced(tmp: &Path, data: &[u8]) -> std::io::Result<()> {
        let mut f = File::create(tmp)?;
        f.write_all(data)?;
        f.sync_all()
    }
    // Separates concurrent calls within one process. pid + nanos alone can collide when
    // two threads write at once on a clock with coarse resolution.
    static TMP_SEQ: AtomicUsize = AtomicUsize::new(0);
    // `Path::parent` yields `Some("")` for a bare file name; use "." so the directory
    // fsync below has a real path to open.
    let dir = path
        .parent()
        .filter(|p| !p.as_os_str().is_empty())
        .unwrap_or_else(|| Path::new("."));
    let pid = process::id();
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.subsec_nanos());
    let seq = TMP_SEQ.fetch_add(1, Ordering::Relaxed);
    let tmp = dir.join(format!(".spg-tmp-{pid}-{nanos}-{seq}"));
    if let Err(e) = write_synced(&tmp, data).and_then(|()| fs::rename(&tmp, path)) {
        let _ = fs::remove_file(&tmp);
        return Err(e);
    }
    if let Ok(d) = File::open(dir) {
        let _ = d.sync_all();
    }
    Ok(())
}
/// Writes the catalog manifest for `db_path`, best effort.
///
/// The manifest records the CRC32 of `snapshot_bytes`, the id, path and CRC32 of every
/// cold segment in `cold_segment_paths`, and `wal_baseline_offset`. Every segment file
/// is read in full to compute its CRC. Segments that cannot be read are left out with
/// a warning. A failure to create the manifest directory or write the file is logged
/// to stderr and otherwise ignored.
pub(crate) fn write_manifest_alongside(
    db_path: &Path,
    snapshot_bytes: &[u8],
    cold_segment_paths: &BTreeMap<u32, PathBuf>,
    wal_baseline_offset: u64,
) {
    let mp = manifest::manifest_path(db_path);
    if let Some(dir) = mp.parent()
        && let Err(e) = fs::create_dir_all(dir)
    {
        eprintln!(
            "spg-server: manifest dir {} mkdir failed: {e}",
            dir.display()
        );
        return;
    }
    let mut cold_segments: Vec<manifest::ColdSegmentEntry> =
        Vec::with_capacity(cold_segment_paths.len());
    for (&segment_id, path) in cold_segment_paths {
        match fs::read(path) {
            Ok(bytes) => cold_segments.push(manifest::ColdSegmentEntry {
                segment_id,
                path: path.clone(),
                crc32: spg_crypto::crc32::crc32(&bytes),
            }),
            Err(e) => eprintln!(
                "spg-server: manifest skip segment {segment_id}: read {} failed: {e}",
                path.display()
            ),
        }
    }
    let catalog_crc32 = spg_crypto::crc32::crc32(snapshot_bytes);
    let serialized = manifest::CatalogManifest {
        catalog_crc32,
        cold_segments,
        wal_baseline_offset,
    }
    .serialize();
    if let Err(e) = write_atomic(&mp, &serialized) {
        eprintln!("spg-server: manifest write to {} failed: {e}", mp.display());
    }
}
/// Validates the catalog manifest for `db_path` against `snapshot_bytes`. If it matches,
/// every cold segment it lists is preloaded into `engine`'s catalog.
///
/// Segments are read in parallel with `prefetch::parallel_read_segments`. Each one is
/// then CRC-checked and loaded in manifest order. A segment that fails to read, fails
/// its CRC, or is rejected by the catalog is logged and skipped. Loaded segments are
/// recorded in `cold_segment_paths`. The updated catalog replaces the engine's, and the
/// number of loaded segments is published through the `PREFETCH_HITS_BOOT` thread-local.
/// If the manifest lists the same segment id twice, only the first entry receives the
/// prefetched bytes; the duplicate is skipped.
///
/// Returns the manifest's `wal_baseline_offset`. Returns `0` (WAL-only replay) when the
/// manifest is missing, unreadable, malformed, or was written for a different snapshot.
fn load_manifest_and_preload_cold(
    engine: &mut Engine,
    db_path: &Path,
    snapshot_bytes: &[u8],
    cold_segment_paths: &mut BTreeMap<u32, PathBuf>,
) -> u64 {
    let mp = manifest::manifest_path(db_path);
    if !mp.exists() {
        return 0;
    }
    let bytes = match fs::read(&mp) {
        Ok(b) => b,
        Err(e) => {
            eprintln!("spg-server: manifest read {} failed: {e}", mp.display());
            return 0;
        }
    };
    let m = match manifest::CatalogManifest::deserialize(&bytes) {
        Ok(m) => m,
        Err(e) => {
            eprintln!("spg-server: manifest {} rejected: {e}", mp.display());
            return 0;
        }
    };
    let snapshot_crc = spg_crypto::crc32::crc32(snapshot_bytes);
    if snapshot_crc != m.catalog_crc32 {
        eprintln!(
            "spg-server: manifest {} catalog CRC mismatch (expected={:#010x}, file={:#010x}); \
             falling back to WAL-only replay",
            mp.display(),
            m.catalog_crc32,
            snapshot_crc,
        );
        return 0;
    }
    let mut cat = engine.catalog().clone();
    let mut loaded: usize = 0;
    let mut skipped: usize = 0;
    let paths_to_read: Vec<(u32, PathBuf)> = m
        .cold_segments
        .iter()
        .map(|e| (e.segment_id, e.path.clone()))
        .collect();
    let workers = prefetch::worker_count_from_env();
    let prefetched = prefetch::parallel_read_segments(&paths_to_read, workers, None);
    // Owned map: each segment's bytes are moved into the catalog instead of cloned.
    // Cloning kept every cold segment in memory twice until boot finished.
    let mut read_map: BTreeMap<u32, std::io::Result<Vec<u8>>> =
        prefetched.into_iter().collect();
    let mut prefetch_hits: u64 = 0;
    for entry in &m.cold_segments {
        let bytes_result = read_map.remove(&entry.segment_id).unwrap_or_else(|| {
            Err(std::io::Error::other(format!(
                "no prefetch result for segment {}",
                entry.segment_id
            )))
        });
        match bytes_result {
            Ok(seg_bytes) => {
                let computed = spg_crypto::crc32::crc32(&seg_bytes);
                if computed != entry.crc32 {
                    eprintln!(
                        "spg-server: manifest skip segment {}: CRC mismatch ({} != {})",
                        entry.segment_id, computed, entry.crc32
                    );
                    skipped += 1;
                    continue;
                }
                match cat.load_segment_bytes_at(entry.segment_id, seg_bytes) {
                    Ok(()) => {
                        cold_segment_paths.insert(entry.segment_id, entry.path.clone());
                        loaded += 1;
                        prefetch_hits += 1;
                    }
                    Err(e) => {
                        eprintln!(
                            "spg-server: manifest segment {} load failed: {e}",
                            entry.segment_id
                        );
                        skipped += 1;
                    }
                }
            }
            Err(e) => {
                eprintln!(
                    "spg-server: manifest skip segment {}: read {} failed: {e}",
                    entry.segment_id,
                    entry.path.display()
                );
                skipped += 1;
            }
        }
    }
    engine.replace_catalog(cat);
    PREFETCH_HITS_BOOT.with(|cell| cell.set(prefetch_hits));
    eprintln!(
        "spg-server: manifest {} loaded {loaded} cold segment(s), skipped {skipped}; wal_baseline_offset={}",
        mp.display(),
        m.wal_baseline_offset,
    );
    m.wal_baseline_offset
}