use std::io::{Read, Seek, SeekFrom, Write};
use std::net::{TcpListener, TcpStream};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use spg_engine::Engine;
use spg_sql::ast::PublicationScope;
use crate::ServerState;
/// Handshake magic selecting `Protocol::V1`: raw WAL bytes, no framing.
const MAGIC_V1: &[u8; 8] = b"SPGREPL\x01";
/// Handshake magic selecting `Protocol::Sub`: publication-filtered framed stream.
pub(crate) const MAGIC_SUB: &[u8; 8] = b"SPGSUB\x01\x00";
/// Handshake magic selecting `Protocol::V2`: framed stream (WAL / STATUS / segment chunks).
const MAGIC_V2: &[u8; 8] = b"SPGREPL\x02";
/// Frame type: payload is WAL bytes.
const FRAME_TYPE_WAL: u8 = 0x00;
/// Frame type: 16-byte payload — primary WAL length then wall-clock micros (both `u64` LE).
const FRAME_TYPE_STATUS: u8 = 0x01;
/// Frame type: 8-byte `u64` LE count of WAL bytes the filtered stream omitted.
const FRAME_TYPE_SKIP: u8 = 0x02;
/// Frame type: one chunk of a cold segment file (layout in `encode_segment_chunk_payload`).
const FRAME_TYPE_SEGMENT_FILE_CHUNK: u8 = 0x03;
/// Body size the sender uses when splitting a segment file into chunks.
const SEGMENT_CHUNK_SIZE_BYTES: usize = 4 * 1024 * 1024;
/// Largest declared chunk body `decode_segment_chunk` accepts.
const SEGMENT_CHUNK_HEADER_MAX_BYTES: u32 = 16 * 1024 * 1024;
/// Fixed chunk header size: segment_id, chunk_seq, chunk_total, body length (4 × `u32` LE).
const SEGMENT_CHUNK_HEADER_LEN: usize = 16;
/// Serializes one segment-file chunk into a `FRAME_TYPE_SEGMENT_FILE_CHUNK` payload.
///
/// Layout: four little-endian `u32` words — `segment_id`, `chunk_seq`,
/// `chunk_total`, body length — followed by the body bytes.
fn encode_segment_chunk_payload(
    segment_id: u32,
    chunk_seq: u32,
    chunk_total: u32,
    chunk_bytes: &[u8],
) -> Vec<u8> {
    // The visible caller (`forward_cold_segments`) passes bodies of at most
    // SEGMENT_CHUNK_SIZE_BYTES, so the length fits in a u32.
    let body_len = chunk_bytes.len() as u32;
    let header_words = [segment_id, chunk_seq, chunk_total, body_len];

    let mut payload = Vec::with_capacity(SEGMENT_CHUNK_HEADER_LEN + chunk_bytes.len());
    for word in header_words {
        payload.extend_from_slice(&word.to_le_bytes());
    }
    payload.extend_from_slice(chunk_bytes);
    payload
}
/// Decoded view of a `FRAME_TYPE_SEGMENT_FILE_CHUNK` payload; `body` borrows from
/// the frame buffer it was decoded from.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SegmentChunk<'a> {
    /// Identifier of the cold segment this chunk belongs to.
    segment_id: u32,
    /// Zero-based position of this chunk; `decode_segment_chunk` checks it is `< chunk_total`.
    chunk_seq: u32,
    /// Number of chunks the segment was split into; `decode_segment_chunk` rejects 0.
    chunk_total: u32,
    /// The chunk's data bytes.
    body: &'a [u8],
}
/// Parses and validates a `FRAME_TYPE_SEGMENT_FILE_CHUNK` payload.
///
/// # Errors
/// Rejects payloads shorter than the fixed header, declared body lengths above
/// `SEGMENT_CHUNK_HEADER_MAX_BYTES` or different from the actual tail length,
/// and `chunk_seq` values outside `0..chunk_total`.
fn decode_segment_chunk(payload: &[u8]) -> std::io::Result<SegmentChunk<'_>> {
    let invalid = |msg: String| std::io::Error::other(msg);

    if payload.len() < SEGMENT_CHUNK_HEADER_LEN {
        return Err(invalid(format!(
            "segment-chunk: payload too short ({} < {SEGMENT_CHUNK_HEADER_LEN})",
            payload.len()
        )));
    }
    let (header, body) = payload.split_at(SEGMENT_CHUNK_HEADER_LEN);
    let word = |at: usize| {
        u32::from_le_bytes([header[at], header[at + 1], header[at + 2], header[at + 3]])
    };
    let segment_id = word(0);
    let chunk_seq = word(4);
    let chunk_total = word(8);
    let chunk_bytes = word(12);

    if chunk_bytes > SEGMENT_CHUNK_HEADER_MAX_BYTES {
        return Err(invalid(format!(
            "segment-chunk: chunk_bytes {chunk_bytes} > cap {SEGMENT_CHUNK_HEADER_MAX_BYTES}"
        )));
    }
    let body_len = chunk_bytes as usize;
    if body.len() != body_len {
        return Err(invalid(format!(
            "segment-chunk: declared chunk_bytes {body_len} ≠ payload tail {}",
            body.len()
        )));
    }
    // `chunk_seq >= chunk_total` alone would also reject chunk_total == 0;
    // both are spelled out for clarity.
    if chunk_total == 0 || chunk_seq >= chunk_total {
        return Err(invalid(format!(
            "segment-chunk: out-of-range chunk_seq {chunk_seq} / chunk_total {chunk_total}"
        )));
    }
    Ok(SegmentChunk {
        segment_id,
        chunk_seq,
        chunk_total,
        body,
    })
}
/// How long a WAL tailer sleeps after a read returns no new bytes.
const TAIL_POLL: Duration = Duration::from_millis(50);
/// Minimum cadence of STATUS frames while the WAL is idle.
const STATUS_INTERVAL: Duration = Duration::from_millis(50);
/// Pause between follower / subscription reconnect attempts.
const RECONNECT_DELAY: Duration = Duration::from_millis(500);
/// Replication progress counters shared across threads; all start at zero.
#[derive(Debug, Default)]
pub struct LagState {
    /// Primary WAL length in bytes, as last reported by a STATUS frame.
    pub primary_pos: AtomicU64,
    /// Offset in the primary's WAL up to which this follower has applied records.
    pub follower_applied_pos: AtomicU64,
    /// Primary wall clock (µs since the Unix epoch) from the last STATUS frame.
    pub primary_wall_time_us: AtomicU64,
}
/// Binds `addr` and serves every incoming replication connection on its own thread.
///
/// Returns the address the listener actually bound.
///
/// # Errors
/// Fails if binding, querying the local address, or spawning the accept thread fails.
pub fn spawn_master_listener(
    addr: &str,
    state: Arc<ServerState>,
) -> std::io::Result<std::net::SocketAddr> {
    let listener = TcpListener::bind(addr)?;
    let local = listener.local_addr()?;
    thread::Builder::new()
        .name("spg-repl-listener".into())
        .spawn(move || {
            // Failed accepts are skipped (`flatten`) so the listener keeps running.
            for stream in listener.incoming().flatten() {
                let state = Arc::clone(&state);
                let spawned = thread::Builder::new()
                    .name("spg-repl-stream".into())
                    .spawn(move || {
                        if let Err(e) = serve_follower(stream, &state) {
                            eprintln!("spg-server: replication stream ended: {e}");
                        }
                    });
                // A failed spawn drops the closure and with it the connection;
                // report it rather than discarding the error silently.
                if let Err(e) = spawned {
                    eprintln!("spg-server: failed to spawn replication stream thread: {e}");
                }
            }
        })?;
    Ok(local)
}
/// Serves one replication connection accepted by `spawn_master_listener`.
///
/// Handshake: the peer sends 16 bytes — an 8-byte magic (`MAGIC_V1`, `MAGIC_V2`
/// or `MAGIC_SUB`) followed by a `u64` LE start offset.
///
/// * `MAGIC_SUB`: requires `wal_level == WAL_LEVEL_LOGICAL`, reads the publication
///   list and the subscriber's `u64` cluster id, replies with the effective start
///   offset and this server's cluster id, then streams publication-filtered WAL.
/// * `MAGIC_V1` / `MAGIC_V2`: with start offset 0 the reply is a snapshot length,
///   the snapshot bytes and the WAL position captured with it; otherwise only a
///   zero snapshot length. V2 then forwards cold segment files and tails the WAL
///   in framed form; V1 tails raw bytes.
///
/// Returns `Ok(())` without streaming when no WAL path is configured, or when a
/// subscriber reports our own cluster id.
///
/// # Errors
/// Bad magic, a non-logical wal_level for subscribers, I/O failures, or a
/// poisoned engine lock.
fn serve_follower(mut stream: TcpStream, state: &ServerState) -> std::io::Result<()> {
    // Bounds how long the handshake reads below may block.
    stream.set_read_timeout(Some(Duration::from_secs(30)))?;
    let mut hs = [0u8; 16];
    stream.read_exact(&mut hs)?;
    let protocol = if &hs[..8] == MAGIC_V1 {
        Protocol::V1
    } else if &hs[..8] == MAGIC_V2 {
        Protocol::V2
    } else if &hs[..8] == MAGIC_SUB {
        Protocol::Sub
    } else {
        return Err(std::io::Error::other("bad replication magic"));
    };
    let start_offset = u64::from_le_bytes(hs[8..16].try_into().unwrap());
    if matches!(protocol, Protocol::Sub) {
        if state.wal_level.load(Ordering::Acquire) != crate::WAL_LEVEL_LOGICAL {
            return Err(std::io::Error::other(
                "MAGIC_SUB rejected: effective_wal_level must be `logical`",
            ));
        }
        let publication_names = read_publication_list(&mut stream)?;
        let mut sub_cluster_buf = [0u8; 8];
        stream.read_exact(&mut sub_cluster_buf)?;
        let subscriber_cluster_id = u64::from_le_bytes(sub_cluster_buf);
        let filter = build_publication_filter(state, &publication_names);
        // Start offset 0 means "from now": begin at the current end of the WAL.
        let effective_start = if start_offset == 0 {
            current_wal_len(state)?
        } else {
            start_offset
        };
        // The reply goes out before the loop check below so the subscriber also
        // learns our cluster id (subscribe_once compares it with its own).
        stream.write_all(&effective_start.to_le_bytes())?;
        stream.write_all(&state.cluster_id.to_le_bytes())?;
        stream.flush()?;
        if subscriber_cluster_id == state.cluster_id {
            eprintln!(
                "spg-server: rejecting MAGIC_SUB connection — peer cluster_id matches own ({})",
                state.cluster_id
            );
            return Ok(());
        }
        let Some(wal_path) = state.wal_path.clone() else {
            return Ok(());
        };
        // Subscribers always return here, so the `Protocol::Sub` arm of the final
        // match below is never reached.
        return tail_wal_v2_filtered(stream, &wal_path, effective_start, filter);
    }
    let (snapshot, wal_position) = if start_offset == 0 {
        capture_snapshot(state)?
    } else {
        (Vec::new(), start_offset)
    };
    if start_offset == 0 {
        // NOTE(review): the position is written even when the snapshot is empty,
        // but follow_once reads it only when snap_len > 0, so an empty snapshot
        // would leave 8 stray bytes in the stream. Presumably Engine::snapshot
        // never returns an empty buffer — confirm.
        let snap_len = u64::try_from(snapshot.len()).unwrap_or(u64::MAX);
        stream.write_all(&snap_len.to_le_bytes())?;
        if !snapshot.is_empty() {
            stream.write_all(&snapshot)?;
        }
        stream.write_all(&wal_position.to_le_bytes())?;
    } else {
        // Resuming follower: announce "no snapshot".
        stream.write_all(&0_u64.to_le_bytes())?;
    }
    stream.flush()?;
    if matches!(protocol, Protocol::V2) {
        forward_cold_segments(&mut stream, state)?;
    }
    let Some(wal_path) = state.wal_path.clone() else {
        return Ok(());
    };
    match protocol {
        Protocol::V1 => tail_wal_v1(stream, &wal_path, wal_position),
        Protocol::V2 | Protocol::Sub => tail_wal_v2(stream, &wal_path, wal_position),
    }
}
/// Reassembly state for one cold segment being received chunk by chunk.
struct SegmentReceiveState {
    /// Concatenated chunk bodies received so far.
    bytes: Vec<u8>,
    /// `chunk_total` from the first chunk seen for this segment.
    expected_total: u32,
    /// Sequence number of the next chunk expected (chunks must arrive in order).
    next_seq: u32,
    /// The segment file already exists locally: chunks are counted but not kept.
    skip: bool,
}
/// Feeds one chunk into the per-segment reassembly buffers.
///
/// The first chunk of a segment creates its state and checks whether the segment
/// file already exists locally; if it does, the remaining chunks are only
/// counted. Returns `Some(bytes)` once the last chunk of a non-skipped segment
/// arrives, and `None` otherwise. Finished segments, skipped or not, are removed
/// from `segment_buffers`.
///
/// # Errors
/// A chunk whose `chunk_total` differs from the first chunk's, or which arrives
/// out of order (non-skipped segments only).
fn absorb_segment_chunk(
    segment_buffers: &mut std::collections::BTreeMap<u32, SegmentReceiveState>,
    chunk: &SegmentChunk<'_>,
    db_path: &Path,
) -> std::io::Result<Option<Vec<u8>>> {
    let id = chunk.segment_id;
    let slot = segment_buffers
        .entry(id)
        .or_insert_with(|| SegmentReceiveState {
            bytes: Vec::new(),
            expected_total: chunk.chunk_total,
            next_seq: 0,
            skip: cold_segment_file_already_present(db_path, id),
        });

    if slot.skip {
        slot.next_seq = slot.next_seq.saturating_add(1);
    } else {
        if chunk.chunk_total != slot.expected_total {
            return Err(std::io::Error::other(format!(
                "segment {}: chunk_total {} ≠ first-seen {}",
                id, chunk.chunk_total, slot.expected_total
            )));
        }
        if chunk.chunk_seq != slot.next_seq {
            return Err(std::io::Error::other(format!(
                "segment {}: chunk_seq {} out of order (expected {})",
                id, chunk.chunk_seq, slot.next_seq
            )));
        }
        slot.bytes.extend_from_slice(chunk.body);
        slot.next_seq += 1;
    }

    let finished = slot.next_seq >= slot.expected_total;
    let skipped = slot.skip;
    if !finished {
        return Ok(None);
    }
    let done = segment_buffers.remove(&id).expect("just inserted");
    Ok(if skipped { None } else { Some(done.bytes) })
}
/// Reports whether cold segment `segment_id` already exists on disk for `db_path`.
///
/// The expected location is `<db dir>/<db stem>.spg/segments/seg_<id>.spg`. A
/// missing parent falls back to `.` and a missing file stem to `db`.
fn cold_segment_file_already_present(db_path: &Path, segment_id: u32) -> bool {
    let dir = match db_path.parent() {
        Some(parent) => parent,
        None => Path::new("."),
    };
    let stem = db_path
        .file_stem()
        .map_or_else(|| "db".into(), std::ffi::OsStr::to_string_lossy);
    let mut segment_file = dir.join(format!("{stem}.spg"));
    segment_file.push("segments");
    segment_file.push(format!("seg_{segment_id}.spg"));
    segment_file.exists()
}
/// Persists a fully reassembled cold segment on the follower and registers it.
///
/// The steps are:
/// 1. Write `seg_<id>.spg.tmp` under `<db dir>/<db stem>.spg/segments/`.
/// 2. Flush it to stable storage.
/// 3. Rename it to `seg_<id>.spg`.
/// 4. Hand the bytes to the engine.
/// 5. Record the path in `state.cold_segment_paths`.
/// 6. Rewrite the manifest next to `db_path` and refresh the cold-segment metric.
///
/// # Errors
/// Fails on filesystem errors while writing the segment, on a poisoned engine
/// lock, or when the engine rejects the segment. In the last two cases the
/// renamed file is removed again. Otherwise `cold_segment_file_already_present`
/// would skip the segment on every future handshake even though the engine
/// never received it. A failure to read `db_path` for the manifest refresh is
/// logged, not returned.
fn commit_received_segment(
    segment_id: u32,
    bytes: Vec<u8>,
    db_path: &Path,
    state: &ServerState,
) -> std::io::Result<()> {
    let parent = db_path.parent().unwrap_or_else(|| Path::new("."));
    let stem = db_path
        .file_stem()
        .unwrap_or_else(|| std::ffi::OsStr::new("db"))
        .to_string_lossy();
    let seg_dir = parent.join(format!("{stem}.spg")).join("segments");
    std::fs::create_dir_all(&seg_dir)?;
    let final_path = seg_dir.join(format!("seg_{segment_id}.spg"));
    let tmp_path = seg_dir.join(format!("seg_{segment_id}.spg.tmp"));
    {
        // Make the contents durable before the rename publishes them. Without
        // this, a crash can leave a zero-length or torn file at `final_path`,
        // which the existence check would then trust.
        let mut tmp = std::fs::File::create(&tmp_path)?;
        tmp.write_all(&bytes)?;
        tmp.sync_all()?;
    }
    std::fs::rename(&tmp_path, &final_path)?;
    // The engine guard is a temporary of this `let` and is released before the
    // cold_segment_paths lock is taken below.
    let received = match state.engine.write() {
        Ok(mut eng) => eng
            .receive_cold_segment(segment_id, bytes)
            .map(|_| ())
            .map_err(|e| {
                std::io::Error::other(format!(
                    "follower receive_cold_segment(id={segment_id}): {e:?}"
                ))
            }),
        Err(_) => Err(std::io::Error::other("engine lock poisoned")),
    };
    if let Err(e) = received {
        // Undo the publish so the next handshake re-fetches this segment.
        let _ = std::fs::remove_file(&final_path);
        return Err(e);
    }
    if let Ok(mut paths) = state.cold_segment_paths.lock() {
        paths.insert(segment_id, final_path);
    }
    let snap_bytes = match std::fs::read(db_path) {
        Ok(b) => b,
        Err(e) => {
            eprintln!("spg-server: follower manifest refresh skipped — db_path read failed: {e}");
            return Ok(());
        }
    };
    let cold_paths = state
        .cold_segment_paths
        .lock()
        .map(|g| g.clone())
        .unwrap_or_default();
    crate::write_manifest_alongside(db_path, &snap_bytes, &cold_paths, 0);
    if let Ok(eng) = state.engine.read() {
        state.metrics.cold_segments.store(
            eng.catalog().cold_segment_count() as u64,
            std::sync::atomic::Ordering::Relaxed,
        );
    }
    Ok(())
}
fn forward_cold_segments(stream: &mut TcpStream, state: &ServerState) -> std::io::Result<()> {
let snapshot: Vec<(u32, std::path::PathBuf)> = {
let Ok(paths) = state.cold_segment_paths.lock() else {
eprintln!(
"spg-server: cold_segment_paths mutex poisoned; \
skipping segment forwarding for this follower"
);
return Ok(());
};
paths.iter().map(|(id, p)| (*id, p.clone())).collect()
};
for (segment_id, path) in snapshot {
let bytes = match std::fs::read(&path) {
Ok(b) => b,
Err(e) => {
eprintln!(
"spg-server: segment forwarding skip seg {segment_id}: read {} failed: {e}",
path.display()
);
continue;
}
};
let total_chunks = bytes.len().div_ceil(SEGMENT_CHUNK_SIZE_BYTES).max(1);
let total_u32 = u32::try_from(total_chunks).map_err(|_| {
std::io::Error::other(format!(
"segment {segment_id}: chunk_total {total_chunks} exceeds u32::MAX"
))
})?;
for seq in 0..total_chunks {
let start = seq * SEGMENT_CHUNK_SIZE_BYTES;
let end = (start + SEGMENT_CHUNK_SIZE_BYTES).min(bytes.len());
let chunk = &bytes[start..end];
let payload = encode_segment_chunk_payload(segment_id, seq as u32, total_u32, chunk);
write_frame(stream, FRAME_TYPE_SEGMENT_FILE_CHUNK, &payload)?;
}
}
Ok(())
}
/// Reads the subscriber's requested publication names from the handshake.
///
/// Wire format: a `u16` LE count, then for each name a `u16` LE byte length
/// followed by that many UTF-8 bytes.
///
/// Generic over any [`Read`] source. The existing `&mut TcpStream` call site is
/// unchanged, and the parser can also run over in-memory buffers or other transports.
///
/// # Errors
/// Propagates read errors (`UnexpectedEof` on a truncated list) and rejects
/// names that are not valid UTF-8.
fn read_publication_list<R: Read + ?Sized>(stream: &mut R) -> std::io::Result<Vec<String>> {
    /// Reads one little-endian `u16`.
    fn read_u16<Src: Read + ?Sized>(src: &mut Src) -> std::io::Result<u16> {
        let mut buf = [0u8; 2];
        src.read_exact(&mut buf)?;
        Ok(u16::from_le_bytes(buf))
    }

    let count = usize::from(read_u16(stream)?);
    let mut names = Vec::with_capacity(count);
    for _ in 0..count {
        let len = usize::from(read_u16(stream)?);
        let mut raw = vec![0u8; len];
        // `read_exact` on an empty buffer is a no-op, so zero-length names need no special case.
        stream.read_exact(&mut raw)?;
        let name = String::from_utf8(raw)
            .map_err(|e| std::io::Error::other(format!("publication name not UTF-8: {e}")))?;
        names.push(name);
    }
    Ok(names)
}
/// Classification of a WAL record's SQL for publication filtering.
#[derive(Debug, Clone, PartialEq, Eq)]
enum OwnerKind {
    /// INSERT / UPDATE / DELETE whose target table name was extracted.
    Dml(String),
    /// Anything else; the filtered tailer never forwards these records.
    Skip,
}
/// Decides which table owners a logical subscriber receives: the union of what
/// every requested publication covers.
#[derive(Debug, Clone)]
struct PublicationFilter {
    /// Set when any requested publication is `PublicationScope::AllTables`.
    any_all_tables: bool,
    /// Tables named by `PublicationScope::ForTables` publications.
    allow: std::collections::HashSet<String>,
    /// One exclusion set per `PublicationScope::AllTablesExcept` publication.
    deny_sets: Vec<std::collections::HashSet<String>>,
}
impl PublicationFilter {
    /// A filter that lets every owner through.
    fn accept_all() -> Self {
        Self {
            any_all_tables: true,
            allow: Default::default(),
            deny_sets: Vec::new(),
        }
    }

    /// `true` if at least one requested publication covers `owner`.
    fn accepts_owner(&self, owner: &str) -> bool {
        self.any_all_tables
            || self.allow.contains(owner)
            || self
                .deny_sets
                .iter()
                .any(|excluded| !excluded.contains(owner))
    }
}
/// Resolves the subscriber's requested publication names into a [`PublicationFilter`].
///
/// An empty request accepts everything. Unknown publication names are logged
/// and contribute nothing; every known one widens the filter according to its
/// `PublicationScope`.
fn build_publication_filter(state: &ServerState, names: &[String]) -> PublicationFilter {
    if names.is_empty() {
        return PublicationFilter::accept_all();
    }
    let Ok(eng) = state.engine.read() else {
        // NOTE(review): a poisoned engine lock fails *open* (every table is
        // streamed). Failing closed would make the subscriber skip — and advance
        // past — data it should receive, so neither is clearly safe; refusing the
        // connection would need the caller's cooperation.
        return PublicationFilter::accept_all();
    };
    let pubs = eng.publications();
    let mut filter = PublicationFilter {
        any_all_tables: false,
        allow: std::collections::HashSet::new(),
        deny_sets: Vec::new(),
    };
    for name in names {
        match pubs.get(name) {
            None => eprintln!(
                "spg-server: subscriber requested unknown publication {name:?} — \
                 contributes no records"
            ),
            Some(PublicationScope::AllTables) => filter.any_all_tables = true,
            Some(PublicationScope::ForTables(tables)) => {
                for table in tables {
                    filter.allow.insert(table.clone());
                }
            }
            Some(PublicationScope::AllTablesExcept(tables)) => {
                filter.deny_sets.push(tables.iter().cloned().collect());
            }
        }
    }
    filter
}
/// Best-effort extraction of the target table from a DML statement.
///
/// Recognizes `INSERT INTO <t>`, `UPDATE <t>` and `DELETE FROM <t>`. Keywords
/// are matched ASCII case-insensitively and leading whitespace is ignored. The
/// result is `OwnerKind::Dml(<t>)` with quotes and punctuation stripped.
/// Everything else yields `OwnerKind::Skip`: other statements, a missing
/// keyword or target, and empty input.
fn extract_owner_from_sql(sql: &str) -> OwnerKind {
    let text = sql.trim_start();
    let verb_len = text
        .bytes()
        .position(|b| b.is_ascii_whitespace())
        .unwrap_or(text.len());
    if verb_len == 0 {
        return OwnerKind::Skip;
    }
    let (verb, remainder) = text.split_at(verb_len);
    let after_verb = remainder.trim_start();

    // Find the text that should begin with the table name.
    let target = if eq_ci(verb, b"INSERT") {
        let (keyword, tail) = split_token(after_verb);
        if !eq_ci(keyword, b"INTO") {
            return OwnerKind::Skip;
        }
        tail.trim_start()
    } else if eq_ci(verb, b"UPDATE") {
        after_verb
    } else if eq_ci(verb, b"DELETE") {
        let (keyword, tail) = split_token(after_verb);
        if !eq_ci(keyword, b"FROM") {
            return OwnerKind::Skip;
        }
        tail.trim_start()
    } else {
        return OwnerKind::Skip;
    };

    match split_ident_token(target) {
        ("", _) => OwnerKind::Skip,
        (owner, _) => OwnerKind::Dml(strip_ident_punct(owner)),
    }
}
/// ASCII case-insensitive comparison of `a` against an uppercase byte pattern.
///
/// `b_upper` must already be uppercase; lowercase letters in it never match.
fn eq_ci(a: &str, b_upper: &[u8]) -> bool {
    a.len() == b_upper.len()
        && a
            .bytes()
            .zip(b_upper)
            .all(|(lhs, &rhs)| lhs.to_ascii_uppercase() == rhs)
}
/// Splits `s` at its first ASCII whitespace character.
///
/// Returns `(token, rest)`, where `rest` still starts with that whitespace; with
/// no whitespace the whole input is the token and `rest` is empty.
fn split_token(s: &str) -> (&str, &str) {
    match s.find(|c: char| c.is_ascii_whitespace()) {
        Some(at) => s.split_at(at),
        None => (s, ""),
    }
}
/// Splits `s` before the first character that ends an identifier: ASCII
/// whitespace, `(`, `,` or `;`.
///
/// Returns `(identifier, rest)`; `rest` starts with the terminator, or is empty
/// when none occurs.
fn split_ident_token(s: &str) -> (&str, &str) {
    let ends_ident = |c: char| c.is_ascii_whitespace() || matches!(c, '(' | ',' | ';');
    s.find(ends_ident).map_or((s, ""), |at| s.split_at(at))
}
/// Strips decoration from a table-name token.
///
/// First removes trailing `(`, `;`, `,`, `"` and `'`, then leading `"` and `'`.
/// Never panics. A token made only of such characters (e.g. `"` or `""`)
/// yields an empty string; index-based trimming would instead produce a slice
/// whose start lies past its end.
fn strip_ident_punct(s: &str) -> String {
    s.trim_end_matches(|c: char| matches!(c, '(' | ';' | ',' | '"' | '\''))
        .trim_start_matches(|c: char| matches!(c, '"' | '\''))
        .to_string()
}
/// Current WAL length in bytes.
///
/// Returns 0 when no WAL path is configured or the file cannot be stat'ed. The
/// engine read lock is held while sampling, presumably so the length is not
/// read in the middle of an append made under the write lock (TODO confirm
/// against the WAL-append path).
///
/// # Errors
/// Fails only if the engine lock is poisoned.
fn current_wal_len(state: &ServerState) -> std::io::Result<u64> {
    match state.wal_path.as_ref() {
        None => Ok(0),
        Some(wal_path) => {
            let _engine_read = state
                .engine
                .read()
                .map_err(|_| std::io::Error::other("engine lock poisoned"))?;
            Ok(std::fs::metadata(wal_path)
                .map(|meta| meta.len())
                .unwrap_or(0))
        }
    }
}
/// Replication protocol selected by the handshake magic in `serve_follower`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Protocol {
    /// `MAGIC_V1`: raw WAL bytes, no framing.
    V1,
    /// `MAGIC_V2`: framed stream (WAL, STATUS and segment-chunk frames).
    V2,
    /// `MAGIC_SUB`: publication-filtered framed stream for logical subscribers.
    Sub,
}
/// Path of the follower's applied-position sidecar: the WAL's file name with
/// `.applied_pos` appended, in the same directory as the WAL.
fn applied_pos_sidecar_path(wal_path: &Path) -> PathBuf {
    let mut sidecar_name = std::ffi::OsString::new();
    if let Some(base) = wal_path.file_name() {
        sidecar_name.push(base);
    }
    sidecar_name.push(".applied_pos");
    match wal_path.parent() {
        Some(dir) => dir.join(&sidecar_name),
        None => PathBuf::from(sidecar_name),
    }
}
/// Temporary file used to write the applied-position sidecar before renaming
/// it into place: the WAL's file name with `.applied_pos.tmp` appended.
fn applied_pos_sidecar_tmp_path(wal_path: &Path) -> PathBuf {
    let mut tmp_name = std::ffi::OsString::new();
    if let Some(base) = wal_path.file_name() {
        tmp_name.push(base);
    }
    tmp_name.push(".applied_pos.tmp");
    match wal_path.parent() {
        Some(dir) => dir.join(&tmp_name),
        None => PathBuf::from(tmp_name),
    }
}
/// Reads the persisted applied position.
///
/// Returns `None` when the sidecar is missing, unreadable, or not exactly 8 bytes.
fn read_applied_pos_sidecar(wal_path: &Path) -> Option<u64> {
    let raw = std::fs::read(applied_pos_sidecar_path(wal_path)).ok()?;
    <[u8; 8]>::try_from(raw.as_slice())
        .ok()
        .map(u64::from_le_bytes)
}
fn write_applied_pos_sidecar(wal_path: &Path, pos: u64) -> std::io::Result<()> {
let tmp = applied_pos_sidecar_tmp_path(wal_path);
let dst = applied_pos_sidecar_path(wal_path);
std::fs::write(&tmp, pos.to_le_bytes())?;
std::fs::rename(&tmp, &dst)?;
Ok(())
}
/// Takes an engine snapshot together with the WAL length at that moment.
///
/// Both are sampled under the engine *write* lock, so they describe the same
/// point in time. The WAL position is 0 when there is no WAL path or file.
///
/// # Errors
/// Fails only if the engine lock is poisoned.
fn capture_snapshot(state: &ServerState) -> std::io::Result<(Vec<u8>, u64)> {
    let engine = state
        .engine
        .write()
        .map_err(|_| std::io::Error::other("engine lock poisoned"))?;
    let snapshot = engine.snapshot();
    let wal_position = state
        .wal_path
        .as_ref()
        .filter(|p| p.exists())
        .map_or(0, |p| std::fs::metadata(p).map_or(0, |m| m.len()));
    drop(engine);
    Ok((snapshot, wal_position))
}
/// V1 tailer: copies raw WAL bytes from `start_offset` onward to `stream`,
/// polling every `TAIL_POLL` when no new data is available.
///
/// Runs until an I/O error occurs (e.g. the follower disconnects).
fn tail_wal_v1(mut stream: TcpStream, wal_path: &Path, start_offset: u64) -> std::io::Result<()> {
    let mut wal = std::fs::File::open(wal_path)?;
    wal.seek(SeekFrom::Start(start_offset))?;
    let mut chunk = [0u8; 4096];
    loop {
        match wal.read(&mut chunk)? {
            0 => thread::sleep(TAIL_POLL),
            n => {
                stream.write_all(&chunk[..n])?;
                stream.flush()?;
            }
        }
    }
}
/// V2 tailer: forwards WAL bytes from `start_offset` as `FRAME_TYPE_WAL` frames.
///
/// A STATUS frame (current WAL length, wall-clock µs) follows every non-empty
/// read, and at least every `STATUS_INTERVAL` while idle.
///
/// Runs until an I/O error occurs (e.g. the follower disconnects).
fn tail_wal_v2(mut stream: TcpStream, wal_path: &Path, start_offset: u64) -> std::io::Result<()> {
    let mut wal = std::fs::File::open(wal_path)?;
    wal.seek(SeekFrom::Start(start_offset))?;
    let mut sent_through = start_offset;
    let mut chunk = [0u8; 4096];
    // Backdate so the very first iteration emits a STATUS frame.
    let mut last_status = std::time::Instant::now()
        .checked_sub(STATUS_INTERVAL)
        .unwrap_or_else(std::time::Instant::now);
    loop {
        let n = wal.read(&mut chunk)?;
        let got_data = n > 0;
        if got_data {
            write_frame(&mut stream, FRAME_TYPE_WAL, &chunk[..n])?;
            sent_through = sent_through.saturating_add(n as u64);
        }
        if got_data || last_status.elapsed() >= STATUS_INTERVAL {
            let primary_pos = std::fs::metadata(wal_path).map_or(sent_through, |m| m.len());
            let micros = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map_or(0, |d| d.as_micros());
            let wall_time_us = u64::try_from(micros).unwrap_or(0);
            let mut status = [0u8; 16];
            let (pos_half, time_half) = status.split_at_mut(8);
            pos_half.copy_from_slice(&primary_pos.to_le_bytes());
            time_half.copy_from_slice(&wall_time_us.to_le_bytes());
            write_frame(&mut stream, FRAME_TYPE_STATUS, &status)?;
            last_status = std::time::Instant::now();
        }
        if !got_data {
            thread::sleep(TAIL_POLL);
        }
    }
}
/// Subscriber tailer: streams WAL records from `start_offset`, forwarding only
/// records whose target table `filter` accepts.
///
/// WAL bytes are read in 4 KiB pieces into `pending` and split into whole
/// records using the record header:
/// * a 4-byte length word by default;
/// * 8 bytes when `WAL_V2_SENTINEL` is set (the receivers treat bytes 4..8 as a CRC);
/// * 9 bytes, with a type byte at offset 8, when `WAL_V3_FLAG` is also set.
///
/// Each accepted record is sent as its own `FRAME_TYPE_WAL` frame, header
/// included. A run of consecutive rejected records is collapsed into one
/// `FRAME_TYPE_SKIP` frame carrying the run's byte length, so the subscriber's
/// offset stays aligned with this WAL. A trailing partial record is kept in
/// `pending` for the next read.
///
/// A STATUS frame (WAL length, wall-clock µs) follows every non-empty read, and
/// at least every `STATUS_INTERVAL` while idle. Runs until an I/O error occurs.
// Long but linear: read → split/filter → status → poll.
#[allow(clippy::too_many_lines)]
fn tail_wal_v2_filtered(
    mut stream: TcpStream,
    wal_path: &Path,
    start_offset: u64,
    filter: PublicationFilter,
) -> std::io::Result<()> {
    let mut f = std::fs::File::open(wal_path)?;
    f.seek(SeekFrom::Start(start_offset))?;
    let mut current_offset = start_offset;
    let mut buf = [0u8; 4096];
    // Bytes read from the WAL but not yet consumed as whole records.
    let mut pending: Vec<u8> = Vec::with_capacity(4096);
    // Backdated so the first iteration emits a STATUS frame immediately.
    let mut last_status = std::time::Instant::now()
        .checked_sub(STATUS_INTERVAL)
        .unwrap_or_else(std::time::Instant::now);
    loop {
        let n = f.read(&mut buf)?;
        if n > 0 {
            pending.extend_from_slice(&buf[..n]);
            current_offset = current_offset.saturating_add(n as u64);
            let mut cur = 0usize;
            // Start (in `pending`) of the current run of rejected records, if any.
            let mut skip_run_start: Option<usize> = None;
            loop {
                if pending.len() - cur < 4 {
                    break;
                }
                let len_bytes: [u8; 4] = pending[cur..cur + 4].try_into().unwrap();
                let raw_len = u32::from_le_bytes(len_bytes);
                let is_v2 = raw_len & crate::WAL_V2_SENTINEL != 0;
                let is_v3 = is_v2 && (raw_len & crate::WAL_V3_FLAG != 0);
                let len_mask = if is_v3 {
                    !(crate::WAL_V2_SENTINEL | crate::WAL_V3_FLAG)
                } else {
                    !crate::WAL_V2_SENTINEL
                };
                let rec_len = (raw_len & len_mask) as usize;
                let header_len = if is_v3 {
                    9
                } else if is_v2 {
                    8
                } else {
                    4
                };
                let total = header_len + rec_len;
                // Incomplete record: wait for more bytes.
                if pending.len() - cur < total {
                    break;
                }
                let sql_bytes = &pending[cur + header_len..cur + header_len + rec_len];
                // Only v3 auto-commit SQL (or pre-v3 records) can be attributed to a table.
                let owner_kind = if is_v3 {
                    let type_byte = pending[cur + 8];
                    if type_byte == crate::WAL_V3_TYPE_AUTO_COMMIT_SQL {
                        match core::str::from_utf8(sql_bytes) {
                            Ok(s) => extract_owner_from_sql(s),
                            Err(_) => OwnerKind::Skip,
                        }
                    } else {
                        OwnerKind::Skip
                    }
                } else {
                    match core::str::from_utf8(sql_bytes) {
                        Ok(s) => extract_owner_from_sql(s),
                        Err(_) => OwnerKind::Skip,
                    }
                };
                let accept = match &owner_kind {
                    OwnerKind::Dml(owner) => filter.accepts_owner(owner),
                    OwnerKind::Skip => false,
                };
                if accept {
                    // Close the preceding skip run first so offsets stay in order.
                    if let Some(start) = skip_run_start.take() {
                        let skipped = (cur - start) as u64;
                        write_frame(&mut stream, FRAME_TYPE_SKIP, &skipped.to_le_bytes())?;
                    }
                    write_frame(&mut stream, FRAME_TYPE_WAL, &pending[cur..cur + total])?;
                } else if skip_run_start.is_none() {
                    skip_run_start = Some(cur);
                }
                cur += total;
            }
            // Flush a skip run that reaches the end of the complete records.
            if let Some(start) = skip_run_start.take() {
                let skipped = (cur - start) as u64;
                write_frame(&mut stream, FRAME_TYPE_SKIP, &skipped.to_le_bytes())?;
            }
            if cur > 0 {
                pending.drain(0..cur);
            }
        }
        if n > 0 || last_status.elapsed() >= STATUS_INTERVAL {
            let primary_pos = std::fs::metadata(wal_path).map_or(current_offset, |m| m.len());
            let wall_time_us = u64::try_from(
                SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map_or(0, |d| d.as_micros()),
            )
            .unwrap_or(0);
            let mut payload = [0u8; 16];
            payload[..8].copy_from_slice(&primary_pos.to_le_bytes());
            payload[8..].copy_from_slice(&wall_time_us.to_le_bytes());
            write_frame(&mut stream, FRAME_TYPE_STATUS, &payload)?;
            last_status = std::time::Instant::now();
        }
        if n == 0 {
            thread::sleep(TAIL_POLL);
        }
    }
}
/// Writes one replication frame and flushes.
///
/// A frame is a 5-byte header — `frame_type`, then the payload length as `u32`
/// LE — followed by `payload`.
///
/// Header and payload go out in a single `write_all`. Writing a tiny header and
/// then a small payload separately is the write-write-read pattern that, without
/// `TCP_NODELAY`, Nagle's algorithm combined with delayed ACKs can stall for tens
/// of milliseconds per frame.
///
/// Generic over any [`Write`] sink; existing `&mut TcpStream` callers are unchanged.
///
/// # Errors
/// `payload` longer than `u32::MAX` bytes, or a failed write or flush.
fn write_frame<W: Write + ?Sized>(
    stream: &mut W,
    frame_type: u8,
    payload: &[u8],
) -> std::io::Result<()> {
    let len = u32::try_from(payload.len())
        .map_err(|_| std::io::Error::other("replication frame payload too large"))?;
    let mut frame = Vec::with_capacity(5 + payload.len());
    frame.push(frame_type);
    frame.extend_from_slice(&len.to_le_bytes());
    frame.extend_from_slice(payload);
    stream.write_all(&frame)?;
    stream.flush()
}
/// Follower main loop that never returns.
///
/// Each session with `master_addr` runs until the stream ends or fails; the
/// outcome is logged, and the loop reconnects after `RECONNECT_DELAY`.
// Arguments are owned, presumably so callers can move them straight into a
// spawned thread (TODO confirm); the pass-by-value lint is silenced for that.
#[allow(clippy::needless_pass_by_value)]
pub fn run_follower(
    master_addr: String,
    db_path: PathBuf,
    wal_path: PathBuf,
    state: Arc<ServerState>,
) {
    loop {
        let outcome = follow_once(&master_addr, &db_path, &wal_path, &state);
        if let Err(e) = outcome {
            eprintln!("spg-server: follower error: {e} — retrying");
        } else {
            eprintln!("spg-server: follower disconnected — retrying");
        }
        thread::sleep(RECONNECT_DELAY);
    }
}
#[allow(clippy::too_many_lines)] fn follow_once(
master_addr: &str,
db_path: &Path,
wal_path: &Path,
state: &ServerState,
) -> std::io::Result<()> {
let mut stream = TcpStream::connect(master_addr)?;
stream.set_read_timeout(Some(Duration::from_secs(30)))?;
if state.lag_state.follower_applied_pos.load(Ordering::Acquire) == 0
&& let Some(persisted) = read_applied_pos_sidecar(wal_path)
&& persisted > 0
{
state
.lag_state
.follower_applied_pos
.store(persisted, Ordering::Release);
}
let stored_applied = state.lag_state.follower_applied_pos.load(Ordering::Acquire);
let start_offset: u64 = if db_path.exists() && stored_applied > 0 {
stored_applied
} else if db_path.exists() && wal_path.exists() {
let n = std::fs::metadata(wal_path).map_or(0, |m| m.len());
if n > 0 {
eprintln!(
"spg-server: follower sidecar .applied_pos missing — \
falling back to wal length {n}; this is byte-exact \
only if master's wal_position was 0 at first handshake"
);
}
n
} else {
0
};
let mut hs = Vec::with_capacity(16);
hs.extend_from_slice(MAGIC_V2);
hs.extend_from_slice(&start_offset.to_le_bytes());
stream.write_all(&hs)?;
stream.flush()?;
let mut len_buf = [0u8; 8];
stream.read_exact(&mut len_buf)?;
let snap_len = u64::from_le_bytes(len_buf);
let mut applied_offset = start_offset;
if snap_len > 0 {
let mut snap = vec![
0u8;
usize::try_from(snap_len).map_err(|_| {
std::io::Error::other("snapshot length exceeds usize range")
})?
];
stream.read_exact(&mut snap)?;
std::fs::write(db_path, &snap)?;
let mut pos_buf = [0u8; 8];
stream.read_exact(&mut pos_buf)?;
applied_offset = u64::from_le_bytes(pos_buf);
std::fs::write(wal_path, b"")?;
let new_engine = Engine::restore_envelope(&snap)
.map_err(|e| std::io::Error::other(format!("follower restore from snapshot: {e}")))?;
let mut g = state
.engine
.write()
.map_err(|_| std::io::Error::other("engine lock poisoned"))?;
*g = new_engine.with_clock(crate::wall_clock_micros);
}
state
.lag_state
.follower_applied_pos
.store(applied_offset, Ordering::Release);
if let Err(e) = write_applied_pos_sidecar(wal_path, applied_offset) {
eprintln!(
"spg-server: follower sidecar write failed at handshake offset {applied_offset}: {e}"
);
}
let mut wal_appender = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(wal_path)?;
let mut pending: Vec<u8> = Vec::with_capacity(4096);
let mut segment_buffers: std::collections::BTreeMap<u32, SegmentReceiveState> =
std::collections::BTreeMap::new();
loop {
let mut header = [0u8; 5];
if let Err(e) = stream.read_exact(&mut header) {
return if e.kind() == std::io::ErrorKind::UnexpectedEof {
Ok(()) } else {
Err(e)
};
}
let frame_type = header[0];
let payload_len = u32::from_le_bytes(header[1..].try_into().unwrap()) as usize;
let mut payload = vec![0u8; payload_len];
if payload_len > 0 {
stream.read_exact(&mut payload)?;
}
match frame_type {
FRAME_TYPE_SEGMENT_FILE_CHUNK => {
let chunk = decode_segment_chunk(&payload)?;
if let Some(committed_bytes) =
absorb_segment_chunk(&mut segment_buffers, &chunk, db_path)?
{
commit_received_segment(chunk.segment_id, committed_bytes, db_path, state)?;
}
}
FRAME_TYPE_WAL => {
wal_appender.write_all(&payload)?;
wal_appender.sync_data()?;
pending.extend_from_slice(&payload);
let mut cur = 0usize;
loop {
if pending.len() - cur < 4 {
break;
}
let len_bytes: [u8; 4] = pending[cur..cur + 4].try_into().unwrap();
let raw_len = u32::from_le_bytes(len_bytes);
let is_v2 = raw_len & crate::WAL_V2_SENTINEL != 0;
let is_v3 = is_v2 && (raw_len & crate::WAL_V3_FLAG != 0);
let len_mask = if is_v3 {
!(crate::WAL_V2_SENTINEL | crate::WAL_V3_FLAG)
} else {
!crate::WAL_V2_SENTINEL
};
let rec_len = (raw_len & len_mask) as usize;
let header_len = if is_v3 {
9
} else if is_v2 {
8
} else {
4
};
if pending.len() - cur < header_len + rec_len {
break;
}
let payload_off = cur + header_len;
let sql_bytes = &pending[payload_off..payload_off + rec_len];
if is_v2 {
let expected =
u32::from_le_bytes(pending[cur + 4..cur + 8].try_into().unwrap());
let actual = if is_v3 {
let type_byte = pending[cur + 8];
let mut buf = Vec::with_capacity(1 + sql_bytes.len());
buf.push(type_byte);
buf.extend_from_slice(sql_bytes);
spg_crypto::crc32::crc32(&buf)
} else {
spg_crypto::crc32::crc32(sql_bytes)
};
if actual != expected {
return Err(std::io::Error::other(format!(
"replicated WAL CRC mismatch at follower offset {} (expected={expected:#010x}, computed={actual:#010x}, payload_len={rec_len})",
applied_offset.saturating_add(cur as u64)
)));
}
}
if is_v3 {
let type_byte = pending[cur + 8];
match type_byte {
crate::WAL_V3_TYPE_AUTO_COMMIT_SQL => {}
other => {
return Err(std::io::Error::other(format!(
"replicated WAL v3 unknown type byte {other:#04x} at follower offset {} — refusing to apply",
applied_offset.saturating_add(cur as u64)
)));
}
}
}
let sql = core::str::from_utf8(sql_bytes).map_err(|_| {
std::io::Error::other("non-UTF-8 SQL in replicated WAL record")
})?;
{
let mut eng = state
.engine
.write()
.map_err(|_| std::io::Error::other("engine lock poisoned"))?;
if let Err(e) = eng.execute(sql) {
return Err(std::io::Error::other(format!(
"follower apply rejected {sql:?}: {e}"
)));
}
}
cur += header_len + rec_len;
applied_offset = applied_offset.saturating_add((header_len + rec_len) as u64);
}
if cur > 0 {
pending.drain(0..cur);
}
state
.lag_state
.follower_applied_pos
.store(applied_offset, Ordering::Release);
if let Err(e) = write_applied_pos_sidecar(wal_path, applied_offset) {
eprintln!(
"spg-server: follower sidecar write failed at offset {applied_offset}: {e}"
);
}
}
FRAME_TYPE_STATUS if payload.len() == 16 => {
let primary_pos = u64::from_le_bytes(payload[..8].try_into().unwrap());
let wall_time_us = u64::from_le_bytes(payload[8..].try_into().unwrap());
state
.lag_state
.primary_pos
.store(primary_pos, Ordering::Release);
state
.lag_state
.primary_wall_time_us
.store(wall_time_us, Ordering::Release);
}
_ => {
}
}
}
}
/// Read timeout on subscription sockets. `run_subscription_worker` also uses it
/// as the step at which it re-checks the shutdown flag while backing off.
const SUB_READ_TIMEOUT: Duration = Duration::from_millis(500);
/// Parses a whitespace-separated `key=value` connection string into `(host, port)`.
///
/// Keys are case-insensitive, unknown keys are ignored, and a repeated key keeps
/// its last value.
///
/// # Errors
/// A token without `=`, an unparsable port, or a missing `host` / `port` (host
/// is reported first).
fn parse_conn_str(s: &str) -> Result<(String, u16), String> {
    let (mut host, mut port) = (None::<String>, None::<u16>);
    for token in s.split_ascii_whitespace() {
        let (key, value) = token
            .split_once('=')
            .ok_or_else(|| format!("expected key=value token, got {token:?}"))?;
        if key.eq_ignore_ascii_case("host") {
            host = Some(value.to_owned());
        } else if key.eq_ignore_ascii_case("port") {
            let parsed = value
                .parse::<u16>()
                .map_err(|e| format!("bad port {value:?}: {e}"))?;
            port = Some(parsed);
        }
    }
    match (host, port) {
        (None, _) => Err("conn_str missing host=…".to_string()),
        (Some(_), None) => Err("conn_str missing port=…".to_string()),
        (Some(h), Some(p)) => Ok((h, p)),
    }
}
/// Subscription worker loop that runs until `shutdown` is set.
///
/// Each `subscribe_once` session is logged when it ends. The loop then backs
/// off for `RECONNECT_DELAY`, in `SUB_READ_TIMEOUT` slices, checking `shutdown`
/// before each slice.
pub fn run_subscription_worker(
    name: String,
    conn_str: String,
    state: Arc<ServerState>,
    shutdown: Arc<AtomicBool>,
) {
    let stopping = || shutdown.load(Ordering::Acquire);
    while !stopping() {
        match subscribe_once(&name, &conn_str, &state, &shutdown) {
            Ok(()) if stopping() => return,
            Ok(()) => eprintln!("spg-server: subscription {name:?} disconnected — retrying"),
            Err(e) => eprintln!("spg-server: subscription {name:?} error: {e} — retrying"),
        }
        let mut waited = Duration::ZERO;
        while waited < RECONNECT_DELAY {
            if stopping() {
                return;
            }
            thread::sleep(SUB_READ_TIMEOUT);
            waited += SUB_READ_TIMEOUT;
        }
    }
}
/// One subscription session against the publisher described by `conn_str`.
///
/// The session proceeds as follows:
/// 1. Read the subscription's resume position and publication list from the
///    engine; return `Ok(())` if the subscription no longer exists.
/// 2. Perform the `MAGIC_SUB` handshake, and abort with `REPLICATION_LOOP` if
///    the publisher reports our own cluster id.
/// 3. Apply incoming WAL records and SKIP frames, advancing the subscription
///    through `subscription_advance`.
///
/// Returns `Ok(())` in any of these cases:
/// * `shutdown` is set;
/// * `read_exact_with_shutdown` reports `false` (presumably shutdown or EOF —
///   see that function);
/// * `subscription_advance` returns `false`.
///
/// # Errors
/// Connection or handshake failures, `REPLICATION_LOOP`, CRC mismatch, unknown
/// v3 record types, non-UTF-8 SQL, non-tolerated apply errors, or a poisoned lock.
// Long but linear: handshake, then a single frame-dispatch loop.
#[allow(clippy::too_many_lines)]
fn subscribe_once(
    name: &str,
    conn_str: &str,
    state: &Arc<ServerState>,
    shutdown: &Arc<AtomicBool>,
) -> std::io::Result<()> {
    let (host, port) = parse_conn_str(conn_str).map_err(std::io::Error::other)?;
    let addr = format!("{host}:{port}");
    let mut stream = TcpStream::connect(&addr)?;
    stream.set_read_timeout(Some(SUB_READ_TIMEOUT))?;
    let (start_offset, requested_publications) = {
        let eng = state
            .engine
            .read()
            .map_err(|_| std::io::Error::other("engine lock poisoned"))?;
        match eng.subscriptions().get(name) {
            Some(s) => (s.last_received_pos, s.publications.clone()),
            None => return Ok(()),
        }
    };
    let mut hs = Vec::with_capacity(
        16 + 2
            + requested_publications
                .iter()
                .map(|p| 2 + p.len())
                .sum::<usize>()
            + 8,
    );
    hs.extend_from_slice(MAGIC_SUB);
    hs.extend_from_slice(&start_offset.to_le_bytes());
    let num_pubs = u16::try_from(requested_publications.len()).map_err(|_| {
        std::io::Error::other("subscription requests too many publications (max 65,535)")
    })?;
    hs.extend_from_slice(&num_pubs.to_le_bytes());
    for p in &requested_publications {
        let len = u16::try_from(p.len())
            .map_err(|_| std::io::Error::other("publication name too long (max 65,535 bytes)"))?;
        hs.extend_from_slice(&len.to_le_bytes());
        hs.extend_from_slice(p.as_bytes());
    }
    hs.extend_from_slice(&state.cluster_id.to_le_bytes());
    stream.write_all(&hs)?;
    stream.flush()?;
    let mut reply = [0u8; 16];
    // Honour a `false` return like every read below. The reply is incomplete in
    // that case and must not be interpreted as an offset or cluster id.
    if !read_exact_with_shutdown(&mut stream, &mut reply, shutdown)? {
        return Ok(());
    }
    // Publisher WAL offset of `pending[cur]` inside the WAL loop below.
    let mut applied_offset = u64::from_le_bytes(reply[..8].try_into().unwrap());
    let master_cluster_id = u64::from_le_bytes(reply[8..].try_into().unwrap());
    if master_cluster_id == state.cluster_id {
        eprintln!(
            "spg-server: subscription {name:?}: REPLICATION_LOOP — master cluster_id \
             {master_cluster_id} matches own; aborting link"
        );
        return Err(std::io::Error::other("REPLICATION_LOOP"));
    }
    let mut pending: Vec<u8> = Vec::with_capacity(4096);
    loop {
        if shutdown.load(Ordering::Acquire) {
            return Ok(());
        }
        let mut header = [0u8; 5];
        if !read_exact_with_shutdown(&mut stream, &mut header, shutdown)? {
            return Ok(());
        }
        let frame_type = header[0];
        let payload_len = u32::from_le_bytes(header[1..].try_into().unwrap()) as usize;
        let mut payload = vec![0u8; payload_len];
        if payload_len > 0 && !read_exact_with_shutdown(&mut stream, &mut payload, shutdown)? {
            return Ok(());
        }
        match frame_type {
            FRAME_TYPE_WAL => {
                pending.extend_from_slice(&payload);
                let mut cur = 0usize;
                loop {
                    if pending.len() - cur < 4 {
                        break;
                    }
                    let len_bytes: [u8; 4] = pending[cur..cur + 4].try_into().unwrap();
                    let raw_len = u32::from_le_bytes(len_bytes);
                    let is_v2 = raw_len & crate::WAL_V2_SENTINEL != 0;
                    let is_v3 = is_v2 && (raw_len & crate::WAL_V3_FLAG != 0);
                    let len_mask = if is_v3 {
                        !(crate::WAL_V2_SENTINEL | crate::WAL_V3_FLAG)
                    } else {
                        !crate::WAL_V2_SENTINEL
                    };
                    let rec_len = (raw_len & len_mask) as usize;
                    let header_len = if is_v3 {
                        9
                    } else if is_v2 {
                        8
                    } else {
                        4
                    };
                    if pending.len() - cur < header_len + rec_len {
                        break;
                    }
                    let payload_off = cur + header_len;
                    let sql_bytes = &pending[payload_off..payload_off + rec_len];
                    if is_v2 {
                        let expected =
                            u32::from_le_bytes(pending[cur + 4..cur + 8].try_into().unwrap());
                        let actual = if is_v3 {
                            let type_byte = pending[cur + 8];
                            let mut buf = Vec::with_capacity(1 + sql_bytes.len());
                            buf.push(type_byte);
                            buf.extend_from_slice(sql_bytes);
                            spg_crypto::crc32::crc32(&buf)
                        } else {
                            spg_crypto::crc32::crc32(sql_bytes)
                        };
                        if actual != expected {
                            return Err(std::io::Error::other(format!(
                                "subscription {name:?} WAL CRC mismatch at offset {} \
                                 (expected={expected:#010x}, computed={actual:#010x}, \
                                 payload_len={rec_len})",
                                applied_offset
                            )));
                        }
                    }
                    if is_v3 {
                        let type_byte = pending[cur + 8];
                        match type_byte {
                            crate::WAL_V3_TYPE_AUTO_COMMIT_SQL => {}
                            crate::WAL_V3_TYPE_DURABILITY_CHECKPOINT => {
                                cur += header_len + rec_len;
                                applied_offset =
                                    applied_offset.saturating_add((header_len + rec_len) as u64);
                                continue;
                            }
                            other => {
                                return Err(std::io::Error::other(format!(
                                    "subscription {name:?}: unknown WAL v3 type byte \
                                     {other:#04x} at offset {} — refusing to apply",
                                    applied_offset
                                )));
                            }
                        }
                    }
                    let sql = core::str::from_utf8(sql_bytes).map_err(|_| {
                        std::io::Error::other("non-UTF-8 SQL in subscribed WAL record")
                    })?;
                    let record_size = (header_len + rec_len) as u64;
                    // `applied_offset` is already this record's start offset (it is
                    // advanced per record), so the position after the record is
                    // `applied_offset + record_size`. Adding `cur` as well would
                    // double-count every earlier record in the same frame.
                    let new_pos = applied_offset.saturating_add(record_size);
                    {
                        let mut eng = state
                            .engine
                            .write()
                            .map_err(|_| std::io::Error::other("engine lock poisoned"))?;
                        if let Err(e) = eng.execute(sql) {
                            let msg = format!("{e:?}");
                            // Already-exists errors are tolerated, presumably because
                            // DDL may already have been applied locally.
                            let tolerant = msg.contains("DuplicateTable")
                                || msg.contains("DuplicateIndex")
                                || msg.contains("DuplicateUser")
                                || msg.contains("AlreadyExists");
                            if !tolerant {
                                return Err(std::io::Error::other(format!(
                                    "subscription {name:?} apply rejected {sql:?}: {msg}"
                                )));
                            }
                            eprintln!(
                                "spg-server: subscription {name:?} tolerating apply error \
                                 on {sql:?}: {msg}"
                            );
                        }
                        if !eng.subscription_advance(name, new_pos) {
                            return Ok(());
                        }
                    }
                    cur += header_len + rec_len;
                    applied_offset = new_pos;
                }
                if cur > 0 {
                    pending.drain(0..cur);
                }
            }
            FRAME_TYPE_STATUS => {
                // Publisher position/time are not tracked for subscriptions.
            }
            FRAME_TYPE_SKIP => {
                if payload.len() == 8 {
                    let skipped = u64::from_le_bytes(payload[..8].try_into().unwrap());
                    applied_offset = applied_offset.saturating_add(skipped);
                    let mut eng = state
                        .engine
                        .write()
                        .map_err(|_| std::io::Error::other("engine lock poisoned"))?;
                    if !eng.subscription_advance(name, applied_offset) {
                        return Ok(());
                    }
                }
            }
            _ => {
                // Unknown frame types are ignored.
            }
        }
    }
}
/// Fills `buf` completely from `stream`, checking `shutdown` before every read.
///
/// Returns `Ok(true)` once every byte of `buf` has been read. Returns
/// `Ok(false)` if the shutdown flag is observed set before the buffer is full,
/// or if the peer closes the connection (a read returning 0 bytes). In the
/// `Ok(false)` case `buf` may hold a partial prefix that the caller should
/// discard. An empty `buf` returns `Ok(true)` immediately.
///
/// `WouldBlock` / `TimedOut` are treated as "no data yet" and retried so the
/// shutdown flag gets re-checked. NOTE(review): this presumably relies on the
/// caller having set a read timeout on `stream`; on a non-blocking socket
/// without one this loop would spin — confirm at the call sites.
/// `Interrupted` (EINTR) is retried too, matching `Read::read_exact`.
///
/// # Errors
/// Any other I/O error from `stream.read` is returned unchanged.
fn read_exact_with_shutdown(
    stream: &mut TcpStream,
    buf: &mut [u8],
    shutdown: &Arc<AtomicBool>,
) -> std::io::Result<bool> {
    let mut got = 0usize;
    while got < buf.len() {
        if shutdown.load(Ordering::Acquire) {
            return Ok(false);
        }
        match stream.read(&mut buf[got..]) {
            // Peer closed the connection before the buffer was filled.
            Ok(0) => return Ok(false),
            Ok(n) => got += n,
            // Transient conditions: a timeout tick (loop back to re-check the
            // shutdown flag) or a signal-interrupted syscall.
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::WouldBlock
                        | std::io::ErrorKind::TimedOut
                        | std::io::ErrorKind::Interrupted
                ) => {}
            Err(e) => return Err(e),
        }
    }
    Ok(true)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Collects string literals into the owned `HashSet<String>` shape used by
    /// `PublicationFilter`'s allow / deny sets.
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Hand-assembles a segment-chunk header (segment_id, chunk_seq,
    /// chunk_total, chunk_bytes — each a little-endian u32) without going
    /// through the encoder, so malformed headers can be fed to the decoder.
    fn raw_chunk_header(
        segment_id: u32,
        chunk_seq: u32,
        chunk_total: u32,
        chunk_bytes: u32,
    ) -> Vec<u8> {
        let mut header = Vec::with_capacity(SEGMENT_CHUNK_HEADER_LEN);
        for word in [segment_id, chunk_seq, chunk_total, chunk_bytes] {
            header.extend_from_slice(&word.to_le_bytes());
        }
        header
    }

    /// Asserts that `sql` is attributed to the DML owner table `table`.
    fn assert_dml(sql: &str, table: &str) {
        assert_eq!(extract_owner_from_sql(sql), OwnerKind::Dml(table.to_string()));
    }

    #[test]
    fn extract_owner_insert_into_table() {
        let cases = [
            ("INSERT INTO foo VALUES (1)", "foo"),
            ("INSERT INTO \"Foo\" VALUES (1)", "Foo"),
            ("insert into bar(id) values (1)", "bar"),
        ];
        for (sql, table) in cases {
            assert_dml(sql, table);
        }
    }

    #[test]
    fn extract_owner_update_delete() {
        let cases = [
            ("UPDATE users SET x=1 WHERE id=2", "users"),
            ("DELETE FROM users WHERE id=2", "users"),
        ];
        for (sql, table) in cases {
            assert_dml(sql, table);
        }
    }

    #[test]
    fn extract_owner_ddl_is_skip() {
        // DDL, transaction control, session settings and replication
        // management statements carry no DML owner.
        let statements: &[&str] = &[
            "CREATE TABLE t (id INT)",
            "DROP TABLE t",
            "ALTER INDEX idx REBUILD",
            "TRUNCATE t",
            "BEGIN",
            "COMMIT",
            "ROLLBACK",
            "SAVEPOINT sp1",
            "RELEASE SAVEPOINT sp1",
            "SET search_path = public",
            "CREATE PUBLICATION p FOR ALL TABLES",
            "DROP PUBLICATION p",
            "CREATE SUBSCRIPTION s CONNECTION 'h=x' PUBLICATION p",
            "CREATE USER 'alice' WITH PASSWORD 'x'",
        ];
        for &sql in statements {
            assert_eq!(
                extract_owner_from_sql(sql),
                OwnerKind::Skip,
                "expected Skip for {sql:?}"
            );
        }
    }

    #[test]
    fn extract_owner_garbage_is_skip() {
        for sql in ["", " ", "INSERT VALUES (1)", "INSERT INTO"] {
            assert_eq!(extract_owner_from_sql(sql), OwnerKind::Skip);
        }
    }

    #[test]
    #[ignore]
    fn extract_owner_perf_under_200ns() {
        use std::hint::black_box;

        const ITERS: u32 = 10_000;
        let sql = "INSERT INTO some_table_name VALUES (1, 'hello world', 3.14)";
        let started = std::time::Instant::now();
        for _ in 0..ITERS {
            let owner = black_box(extract_owner_from_sql(black_box(sql)));
            black_box(owner);
        }
        let elapsed_ns = started.elapsed().as_nanos();
        let ns_per_call = elapsed_ns / u128::from(ITERS);
        eprintln!("extract_owner_from_sql: {ns_per_call} ns/call (budget ≤ 200 ns)");
        assert!(
            ns_per_call < 200,
            "owner scanner exceeded the v6.1.5 200 ns budget: {ns_per_call} ns/call"
        );
    }

    #[test]
    fn publication_filter_accept_all_matches_everything() {
        let filter = PublicationFilter::accept_all();
        for owner in ["t1", "anything"] {
            assert!(filter.accepts_owner(owner));
        }
    }

    #[test]
    fn publication_filter_for_tables_allow_list() {
        let filter = PublicationFilter {
            any_all_tables: false,
            allow: name_set(&["t1", "t3"]),
            deny_sets: Vec::new(),
        };
        for (owner, accepted) in [("t1", true), ("t2", false), ("t3", true)] {
            assert_eq!(filter.accepts_owner(owner), accepted);
        }
    }

    #[test]
    fn publication_filter_all_tables_except_deny_list() {
        let filter = PublicationFilter {
            any_all_tables: false,
            allow: HashSet::new(),
            deny_sets: vec![name_set(&["bad"])],
        };
        assert!(!filter.accepts_owner("bad"));
        assert!(filter.accepts_owner("good"));
    }

    #[test]
    fn publication_filter_or_combines_multiple_scopes() {
        let filter = PublicationFilter {
            any_all_tables: false,
            allow: name_set(&["t1"]),
            deny_sets: vec![name_set(&["bad"])],
        };
        for (owner, accepted) in [("t1", true), ("anything_else", true), ("bad", false)] {
            assert_eq!(filter.accepts_owner(owner), accepted);
        }
    }

    #[test]
    fn segment_chunk_encode_decode_round_trip() {
        let body: Vec<u8> = (0..=200u8).collect();
        let payload = encode_segment_chunk_payload(7, 1, 4, &body);
        let chunk = decode_segment_chunk(&payload).expect("decode ok");
        assert_eq!(
            (chunk.segment_id, chunk.chunk_seq, chunk.chunk_total),
            (7, 1, 4)
        );
        assert_eq!(chunk.body, &body[..]);
    }

    #[test]
    fn segment_chunk_decode_rejects_truncated_header() {
        let too_short = [0u8; 12];
        assert!(
            decode_segment_chunk(&too_short).is_err(),
            "header < 16 bytes must error"
        );
    }

    #[test]
    fn segment_chunk_decode_rejects_oversize_chunk() {
        // Declares a 32 MiB body, above the 16 MiB per-chunk cap.
        let payload = raw_chunk_header(0, 0, 1, 32 * 1024 * 1024);
        assert!(
            decode_segment_chunk(&payload).is_err(),
            "chunk_bytes > 16 MiB cap must error"
        );
    }

    #[test]
    fn segment_chunk_decode_rejects_size_tail_mismatch() {
        // Header promises 10 body bytes, only 5 follow.
        let mut payload = raw_chunk_header(0, 0, 1, 10);
        payload.extend_from_slice(&[0u8; 5]);
        assert!(
            decode_segment_chunk(&payload).is_err(),
            "declared chunk_bytes ≠ tail length must error"
        );
    }

    #[test]
    fn segment_chunk_decode_rejects_seq_out_of_range() {
        // chunk_seq = 5 with chunk_total = 3.
        let payload = raw_chunk_header(0, 5, 3, 0);
        assert!(
            decode_segment_chunk(&payload).is_err(),
            "chunk_seq ≥ chunk_total must error"
        );
    }
}