use std::sync::{Arc, Mutex};
use graphrefly_core::{Core, CoreFull, Message};
use graphrefly_graph::{Graph, GraphObserveAllReactive, GraphPersistSnapshot, NodeSlice};
use graphrefly_structures::{BaseChange, Lifecycle, Version};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::{PhaseStat, RestoreError, RestoreResult, StorageError};
use crate::tier::{KvStorageTier, SnapshotStorageTier};
use crate::wal::{
graph_wal_prefix, verify_wal_frame_checksum, wal_frame_checksum, wal_frame_key, WALFrame,
WalTag, REPLAY_ORDER,
};
/// Format version used by this module: stamped into
/// `GraphCheckpointRecord::format_version` when a full baseline is written, and
/// used as the `Version::Counter` value of every change envelope produced by
/// `decompose_diff_to_frames`.
pub const SNAPSHOT_VERSION: u64 = 1;
/// A full-graph checkpoint as saved to (and loaded from) a snapshot tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphCheckpointRecord {
/// Graph name, copied from the snapshot. Restore derives the WAL key prefix from it.
pub name: String,
/// Record mode. This file only ever writes `"full"`, and baseline loading
/// treats any other value as a missing baseline.
pub mode: String,
/// The complete graph snapshot captured at checkpoint time.
pub snapshot: GraphPersistSnapshot,
/// WAL sequence number of the tier when the checkpoint was written. On
/// restore, frames with `frame_seq <= seq` are skipped.
pub seq: u64,
/// Wall-clock timestamp (nanoseconds) taken when the record was built.
pub timestamp_ns: u64,
/// Set to `SNAPSHOT_VERSION` by the writer in this file.
pub format_version: u64,
}
/// Difference between two `GraphPersistSnapshot`s, as produced by `diff_snapshots`.
#[derive(Debug, Clone)]
pub struct GraphSnapshotDiff {
/// Names of nodes present in the newer snapshot but absent from the older one.
pub nodes_added: Vec<String>,
/// Slices of the added nodes. `diff_snapshots` pushes to this vector in
/// lockstep with `nodes_added`, so entries pair up by position.
pub nodes_added_slices: Vec<NodeSlice>,
/// Names of nodes present in the older snapshot but absent from the newer one.
pub nodes_removed: Vec<String>,
/// Nodes present in both snapshots whose `value` differs.
pub value_changes: Vec<ValueChange>,
/// Subgraph names present only in the newer snapshot.
pub subgraphs_added: Vec<String>,
/// Subgraph names present only in the older snapshot.
pub subgraphs_removed: Vec<String>,
}
impl GraphSnapshotDiff {
    /// Returns `true` when the diff records no added/removed nodes, no value
    /// changes and no added/removed subgraphs.
    ///
    /// `nodes_added_slices` is not consulted; it mirrors `nodes_added`.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let entry_counts = [
            self.nodes_added.len(),
            self.nodes_removed.len(),
            self.value_changes.len(),
            self.subgraphs_added.len(),
            self.subgraphs_removed.len(),
        ];
        entry_counts.iter().all(|&count| count == 0)
    }
}
/// A single node whose value differs between two snapshots.
#[derive(Debug, Clone)]
pub struct ValueChange {
/// Name of the node (the key it is stored under in the snapshot's `nodes` map).
pub path: String,
/// The node's value in the newer snapshot. `decompose_diff_to_frames` emits
/// `node.set` for `Some` and `node.invalidate` for `None`.
pub to: Option<Value>,
}
/// Computes what changed between `before` and `after`.
///
/// * A node only in `after` is reported as added (name and slice are recorded
///   at the same position in their respective vectors).
/// * A node in both whose `value` differs is reported as a value change.
/// * A node only in `before` is reported as removed.
/// * Subgraphs are compared by key only: added or removed, never "changed".
///
/// Entries appear in the iteration order of the underlying maps.
#[must_use]
pub fn diff_snapshots(
    before: &GraphPersistSnapshot,
    after: &GraphPersistSnapshot,
) -> GraphSnapshotDiff {
    let mut diff = GraphSnapshotDiff {
        nodes_added: Vec::new(),
        nodes_added_slices: Vec::new(),
        nodes_removed: Vec::new(),
        value_changes: Vec::new(),
        subgraphs_added: Vec::new(),
        subgraphs_removed: Vec::new(),
    };

    // One pass over the newer snapshot classifies each node as added,
    // changed, or untouched.
    for (path, current) in &after.nodes {
        match before.nodes.get(path) {
            None => {
                diff.nodes_added.push(path.clone());
                diff.nodes_added_slices.push(current.clone());
            }
            Some(previous) if previous.value != current.value => {
                diff.value_changes.push(ValueChange {
                    path: path.clone(),
                    to: current.value.clone(),
                });
            }
            Some(_) => {}
        }
    }

    diff.nodes_removed.extend(
        before
            .nodes
            .keys()
            .filter(|path| !after.nodes.contains_key(*path))
            .cloned(),
    );
    diff.subgraphs_added.extend(
        after
            .subgraphs
            .keys()
            .filter(|path| !before.subgraphs.contains_key(*path))
            .cloned(),
    );
    diff.subgraphs_removed.extend(
        before
            .subgraphs
            .keys()
            .filter(|path| !after.subgraphs.contains_key(*path))
            .cloned(),
    );

    diff
}
/// Intermediate form of a WAL frame inside `decompose_diff_to_frames`: the
/// frame's content before a sequence number and checksum have been assigned.
struct DecomposedFrame {
/// Lifecycle the resulting frame is tagged with.
lifecycle: Lifecycle,
/// Node or subgraph name the change refers to.
path: String,
/// Change envelope carrying the JSON payload.
change: BaseChange<Value>,
}
pub fn decompose_diff_to_frames(
diff: &GraphSnapshotDiff,
timestamp_ns: u64,
base_seq: u64,
) -> Result<(Vec<WALFrame<Value>>, u64), StorageError> {
let mut decomposed = Vec::new();
let wrap = |structure: &str, lifecycle: Lifecycle, payload: Value| -> BaseChange<Value> {
BaseChange {
structure: structure.to_owned(),
version: Version::Counter(SNAPSHOT_VERSION),
t_ns: timestamp_ns,
seq: None,
lifecycle,
change: payload,
}
};
for (i, name) in diff.nodes_added.iter().enumerate() {
let slice = &diff.nodes_added_slices[i];
let payload = serde_json::json!({
"kind": "graph.add",
"nodeId": name,
"slice": serde_json::to_value(slice).map_err(|e|
StorageError::Codec(crate::codec::CodecError::Encode(e.to_string()))
)?,
});
decomposed.push(DecomposedFrame {
lifecycle: Lifecycle::Spec,
path: name.clone(),
change: wrap("graph.spec", Lifecycle::Spec, payload),
});
}
for name in &diff.nodes_removed {
let payload = serde_json::json!({
"kind": "graph.remove",
"nodeId": name,
});
decomposed.push(DecomposedFrame {
lifecycle: Lifecycle::Spec,
path: name.clone(),
change: wrap("graph.spec", Lifecycle::Spec, payload),
});
}
for name in &diff.subgraphs_added {
let payload = serde_json::json!({
"kind": "graph.mount",
"path": name,
"subgraphId": name,
});
decomposed.push(DecomposedFrame {
lifecycle: Lifecycle::Spec,
path: name.clone(),
change: wrap("graph.spec", Lifecycle::Spec, payload),
});
}
for name in &diff.subgraphs_removed {
let payload = serde_json::json!({
"kind": "graph.unmount",
"path": name,
});
decomposed.push(DecomposedFrame {
lifecycle: Lifecycle::Spec,
path: name.clone(),
change: wrap("graph.spec", Lifecycle::Spec, payload),
});
}
for vc in &diff.value_changes {
let payload = if let Some(ref value) = vc.to {
serde_json::json!({
"kind": "node.set",
"path": vc.path,
"value": value,
})
} else {
serde_json::json!({
"kind": "node.invalidate",
"path": vc.path,
})
};
decomposed.push(DecomposedFrame {
lifecycle: Lifecycle::Data,
path: vc.path.clone(),
change: wrap("graph.value", Lifecycle::Data, payload),
});
}
let mut seq = base_seq;
let mut frames = Vec::with_capacity(decomposed.len());
for d in decomposed {
seq += 1;
let mut frame = WALFrame {
t: WalTag,
lifecycle: d.lifecycle,
path: d.path,
change: d.change,
frame_seq: seq,
frame_t_ns: timestamp_ns,
checksum: String::new(),
format_version: 1,
};
frame.checksum = wal_frame_checksum(&frame)?;
frames.push(frame);
}
Ok((frames, seq))
}
/// Per-tier bookkeeping kept by an attached `StorageHandle`.
struct TierState {
/// Destination for full baseline checkpoints.
snapshot_tier: Box<dyn SnapshotStorageTier<GraphCheckpointRecord>>,
/// Optional destination for WAL delta frames. When `None`, every flush
/// writes a full baseline.
wal_tier: Option<Box<dyn KvStorageTier<WALFrame<Value>>>>,
/// Key prefix under which this graph's WAL frames are stored.
wal_prefix: String,
/// Last WAL sequence number used. Seeded at attach time from the highest
/// numeric trailing key segment found under `wal_prefix`.
seq: u64,
/// Running total of observed events that have been flushed; used to detect
/// when a compaction boundary is crossed.
flush_count: u64,
/// A full baseline is forced each time `flush_count` crosses a multiple of
/// this value; `0` disables that trigger.
compact_every: u32,
/// When `0`, the snapshot tier is flushed right after each baseline save.
snapshot_debounce_ms: u32,
/// When `0`, the WAL tier is flushed right after each batch of frames.
wal_debounce_ms: u32,
/// Snapshot recorded by the last successful flush; `None` forces a full baseline.
last_snapshot: Option<GraphPersistSnapshot>,
/// Set by `StorageHandle::dispose`; disposed tiers are skipped by all flushes.
disposed: bool,
}
/// One persistence destination handed to `attach_snapshot_storage`: a snapshot
/// tier plus an optional WAL tier.
pub struct AttachTierPair {
/// Tier that receives full baseline checkpoints.
pub snapshot: Box<dyn SnapshotStorageTier<GraphCheckpointRecord>>,
/// Tier that receives WAL delta frames. With `None`, every flush writes a
/// full baseline instead.
pub wal: Option<Box<dyn KvStorageTier<WALFrame<Value>>>>,
}
/// Predicate over an observed node path; events for which it returns `false`
/// do not trigger persistence.
pub type PathFilter = Box<dyn Fn(&str) -> bool + Send + Sync>;
/// Callback invoked once per storage error collected during a deferred flush.
pub type ErrorCallback = Box<dyn Fn(&StorageError) + Send + Sync>;
/// Optional behaviour for `attach_snapshot_storage`. The default has no
/// filter and no error callback.
#[derive(Default)]
pub struct AttachOptions {
/// When set, only events whose path passes the filter trigger persistence.
pub filter: Option<PathFilter>,
/// When set, receives every error produced while flushing tiers. Without
/// it, those errors are dropped.
pub on_error: Option<ErrorCallback>,
}
/// Handle returned by `attach_snapshot_storage`; owns the observe subscription
/// and shares the per-tier state with the subscription's sink.
pub struct StorageHandle {
/// Per-tier state, shared with the deferred flush closure.
state: Arc<Mutex<Vec<TierState>>>,
/// The live subscription; becomes `None` once `detach` has taken it.
observe: Mutex<Option<GraphObserveAllReactive>>,
}
impl StorageHandle {
    /// Marks every tier as disposed so subsequent flushes skip it.
    ///
    /// The observe subscription is left in place; use [`StorageHandle::detach`]
    /// to remove it as well. A poisoned state mutex is recovered, not propagated.
    pub fn dispose(&self) {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .iter_mut()
            .for_each(|tier| tier.disposed = true);
    }

    /// Disposes all tiers, then detaches the observe subscription from `core`.
    ///
    /// The subscription is taken out of the handle, so a second call only
    /// repeats the (idempotent) dispose step.
    pub fn detach(&self, core: &Core) {
        self.dispose();
        // Take the subscription first so the mutex guard is released before
        // `detach` runs.
        let subscription = self
            .observe
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(mut subscription) = subscription {
            subscription.detach(core);
        }
    }

    /// Flushes the snapshot tier and (if present) the WAL tier of every
    /// non-disposed tier state.
    ///
    /// All tiers are attempted even after a failure. The state lock is
    /// re-acquired for each tier instead of being held across the whole loop.
    ///
    /// # Errors
    /// Returns the first error encountered (snapshot flush before WAL flush,
    /// tiers in index order); later errors are discarded.
    pub fn flush_all(&self) -> Result<(), StorageError> {
        let tier_count = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len();

        let mut first_err: Option<StorageError> = None;
        for idx in 0..tier_count {
            let (snapshot_err, wal_err) = {
                let mut states = self
                    .state
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                match states.get_mut(idx) {
                    Some(tier) if !tier.disposed => (
                        tier.snapshot_tier.flush().err(),
                        tier.wal_tier.as_ref().and_then(|wal| wal.flush().err()),
                    ),
                    // Missing or disposed tier: nothing to flush.
                    _ => continue,
                }
            };
            // Keep whichever error came first; `or` never overwrites a `Some`.
            first_err = first_err.or(snapshot_err).or(wal_err);
        }
        first_err.map_or(Ok(()), Err)
    }
}
/// Subscribes to all of `graph`'s nodes and persists the graph to each tier
/// pair whenever qualifying messages are observed.
///
/// For each pair, the initial WAL sequence number is the largest numeric
/// trailing `/`-separated key segment found under the graph's WAL prefix
/// (a failed `list` is ignored and leaves the sequence at 0). Tier settings
/// fall back to `compact_every = 10` and debounce values of `0` when the tier
/// reports none.
///
/// Observed events are coalesced: the first qualifying event posts one
/// deferred closure; further events before it runs only bump a pending
/// counter. The deferred closure takes a full snapshot and flushes every
/// non-disposed tier, then reports collected errors to `options.on_error`.
///
/// Returns a `StorageHandle` owning the subscription and the shared tier state.
#[must_use = "the returned StorageHandle owns the observe subscription; \
call StorageHandle::detach(core) to unsubscribe and stop \
persistence (D246 r3 — no RAII Drop)"]
pub fn attach_snapshot_storage(
core: &Core,
graph: &Graph,
pairs: Vec<AttachTierPair>,
options: AttachOptions,
) -> StorageHandle {
let graph_name = graph.name();
let wal_prefix = graph_wal_prefix(&graph_name);
let mut states = Vec::with_capacity(pairs.len());
for pair in pairs {
// Resume numbering after the highest sequence already present in the WAL.
let mut high_seq: u64 = 0;
if let Some(ref wal) = pair.wal {
if let Ok(keys) = wal.list(&wal_prefix) {
for key in keys {
// Keys whose last segment is not a u64 are ignored.
if let Some(seg) = key.rsplit('/').next() {
if let Ok(s) = seg.parse::<u64>() {
high_seq = high_seq.max(s);
}
}
}
}
}
let compact_every = pair.snapshot.compact_every().unwrap_or(10);
let snapshot_debounce = pair.snapshot.debounce_ms().unwrap_or(0);
let wal_debounce = pair.wal.as_ref().and_then(|w| w.debounce_ms()).unwrap_or(0);
states.push(TierState {
snapshot_tier: pair.snapshot,
wal_tier: pair.wal,
wal_prefix: wal_prefix.clone(),
seq: high_seq,
flush_count: 0,
compact_every,
snapshot_debounce_ms: snapshot_debounce,
wal_debounce_ms: wal_debounce,
last_snapshot: None,
disposed: false,
});
}
let shared_states = Arc::new(Mutex::new(states));
let states_for_sink = shared_states.clone();
let filter = options.filter;
let on_error = options.on_error.map(Arc::new);
let graph_for_sink = graph.clone();
let deferred = core.defer_queue();
// `scheduled` is true while a deferred flush is queued but has not started;
// `pending_count` counts the events coalesced into that flush.
let scheduled = std::rc::Rc::new(std::cell::Cell::new(false));
let pending_count = std::rc::Rc::new(std::cell::Cell::new(0usize));
let mut observe = graph.observe_all_reactive();
observe.subscribe(core, move |path: &str, messages: &[Message]| {
// Only batches containing at least one message with tier in 3..6 trigger
// persistence.
let dominated_by_tier = messages.iter().any(|m| {
let t = m.tier();
(3..6).contains(&t)
});
if !dominated_by_tier {
return;
}
if let Some(ref f) = filter {
if !f(path) {
return;
}
}
pending_count.set(pending_count.get().saturating_add(1));
// A flush is already queued; it will pick up this event via the counter.
if scheduled.get() {
return; }
let graph_for_defer = graph_for_sink.clone();
let states_for_defer = states_for_sink.clone();
let on_error = on_error.clone();
let sched = std::rc::Rc::clone(&scheduled);
let pc = std::rc::Rc::clone(&pending_count);
sched.set(true);
// NOTE(review): the result of `post` is discarded. If `post` can fail without
// running the closure, `scheduled` would stay `true` and later events would
// never schedule another flush — confirm against the defer queue's contract.
let _ = deferred.post(Box::new(move |cf: &dyn CoreFull| {
// Clear the flag first so events observed from here on schedule a new flush.
sched.set(false);
let count = pc.replace(0);
let snapshot = graph_for_defer.snapshot_full(cf);
// Flush under the lock, but collect errors so the callback runs after the
// lock has been released.
let collected_errors: Vec<StorageError> = {
let mut states = states_for_defer
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let mut errs = Vec::new();
for s in states.iter_mut() {
if s.disposed {
continue;
}
if let Err(e) = flush_tier(s, &snapshot, count) {
errs.push(e);
}
}
errs
};
if !collected_errors.is_empty() {
if let Some(ref cb) = on_error {
for e in &collected_errors {
cb(e);
}
}
}
}));
});
StorageHandle {
state: shared_states,
observe: Mutex::new(Some(observe)),
}
}
/// Persists `snapshot` to one tier, choosing between a full baseline and a
/// WAL delta.
///
/// A full baseline is written when the tier has no WAL, when nothing has been
/// flushed yet, or when adding `count` to the running flush counter crosses a
/// multiple of `compact_every` (a `compact_every` of 0 disables that trigger).
/// Otherwise a delta against the last flushed snapshot is appended to the WAL.
///
/// On success the tier remembers `snapshot` as its last flushed snapshot. On
/// error it does not, but the flush counter has already been advanced.
fn flush_tier(
    s: &mut TierState,
    snapshot: &GraphPersistSnapshot,
    count: usize,
) -> Result<(), StorageError> {
    let previous_count = s.flush_count;
    s.flush_count = previous_count.saturating_add(count as u64);

    let crossed_compaction_boundary = if s.compact_every == 0 {
        false
    } else {
        let every = u64::from(s.compact_every);
        previous_count / every < s.flush_count / every
    };
    let needs_baseline =
        s.wal_tier.is_none() || s.last_snapshot.is_none() || crossed_compaction_boundary;

    if needs_baseline {
        write_full_baseline(s, snapshot)?;
    } else {
        write_wal_delta(s, snapshot)?;
    }

    s.last_snapshot = Some(snapshot.clone());
    Ok(())
}
/// Saves `snapshot` to the tier's snapshot store as a `"full"` checkpoint
/// stamped with the tier's current WAL sequence number and the wall-clock time.
///
/// The snapshot tier is flushed immediately only when its debounce is 0.
///
/// # Errors
/// Propagates any error from the tier's `save` or `flush`.
fn write_full_baseline(
    s: &mut TierState,
    snapshot: &GraphPersistSnapshot,
) -> Result<(), StorageError> {
    let now_ns = graphrefly_core::wall_clock_ns();
    s.snapshot_tier.save(GraphCheckpointRecord {
        name: snapshot.name.clone(),
        mode: String::from("full"),
        snapshot: snapshot.clone(),
        seq: s.seq,
        timestamp_ns: now_ns,
        format_version: SNAPSHOT_VERSION,
    })?;

    let flush_immediately = s.snapshot_debounce_ms == 0;
    if flush_immediately {
        s.snapshot_tier.flush()?;
    }
    Ok(())
}
fn write_wal_delta(s: &mut TierState, snapshot: &GraphPersistSnapshot) -> Result<(), StorageError> {
let last = s
.last_snapshot
.as_ref()
.expect("caller ensures last_snapshot is Some");
let diff = diff_snapshots(last, snapshot);
if diff.is_empty() {
return Ok(());
}
let timestamp_ns = graphrefly_core::wall_clock_ns();
let (frames, next_seq) = decompose_diff_to_frames(&diff, timestamp_ns, s.seq)?;
if let Some(ref wal) = s.wal_tier {
for frame in &frames {
let key = wal_frame_key(&s.wal_prefix, frame.frame_seq);
wal.save(&key, frame.clone())?;
}
if s.wal_debounce_ms == 0 {
wal.flush()?;
}
}
s.seq = next_seq;
Ok(())
}
/// What restore should do with a WAL frame whose checksum does not verify.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TornWritePolicy {
/// Drop the frame, count it as skipped, and keep going.
Skip,
/// Fail the restore with `RestoreError::TornWriteMidStream`.
Abort,
}
/// Callback deciding the policy for a frame that failed verification. It
/// receives the frame's sequence number and a reason string (this file only
/// passes `"checksum-mismatch"`).
pub type OnTornWrite = Box<dyn Fn(u64, &str) -> TornWritePolicy + Send + Sync>;
/// Inputs to `restore_snapshot`.
pub struct RestoreOptions<'a> {
/// Tier the full baseline checkpoint is loaded from.
pub snapshot_tier: &'a dyn SnapshotStorageTier<GraphCheckpointRecord>,
/// Tier the WAL frames are listed and loaded from.
pub wal_tier: &'a dyn KvStorageTier<WALFrame<Value>>,
/// When set, frames with a sequence number above this value are not replayed.
pub target_seq: Option<u64>,
/// Decides how to handle frames that fail checksum verification. When
/// `None`, a bad final frame is skipped and a bad earlier frame aborts.
pub on_torn_write: Option<OnTornWrite>,
}
/// Restores `graph` from persisted state: loads and applies the full baseline,
/// then replays the WAL frames written after it.
///
/// Steps, in order: load and apply the baseline, collect WAL frames newer than
/// the baseline (bounded by `opts.target_seq`), verify their checksums, and
/// replay the verified frames grouped by lifecycle.
///
/// # Errors
/// Returns the first `RestoreError` from baseline loading, frame collection or
/// frame verification. The replay step itself does not report errors.
pub fn restore_snapshot(
    core: &Core,
    graph: &Graph,
    opts: &RestoreOptions<'_>,
) -> Result<RestoreResult, RestoreError> {
    let baseline = load_baseline(core, graph, opts)?;
    let candidates = collect_wal_frames(opts, &baseline.name, baseline.seq)?;
    let (verified, skipped) = verify_frames(candidates, opts.on_torn_write.as_ref())?;
    let outcome = replay_by_lifecycle(core, graph, &verified, baseline.seq, skipped);
    Ok(outcome)
}
/// Loads the baseline checkpoint from the snapshot tier and applies it to
/// `graph`.
///
/// # Errors
/// * `RestoreError::PhaseFailed` if the tier fails to load, or the graph
///   rejects the snapshot.
/// * `RestoreError::BaselineMissing` if the tier holds no record, or the
///   record's mode is not `"full"`.
fn load_baseline(
    core: &Core,
    graph: &Graph,
    opts: &RestoreOptions<'_>,
) -> Result<GraphCheckpointRecord, RestoreError> {
    let stored = match opts.snapshot_tier.load() {
        Ok(record) => record,
        Err(e) => {
            return Err(RestoreError::PhaseFailed {
                lifecycle: Lifecycle::Spec,
                frame_seq: 0,
                message: format!("baseline load failed: {e}"),
            });
        }
    };
    let Some(baseline) = stored else {
        return Err(RestoreError::BaselineMissing);
    };
    // Only full checkpoints are usable as a restore base.
    if baseline.mode != "full" {
        return Err(RestoreError::BaselineMissing);
    }

    if let Err(e) = graph.restore(core, &baseline.snapshot) {
        return Err(RestoreError::PhaseFailed {
            lifecycle: Lifecycle::Spec,
            frame_seq: 0,
            message: format!("baseline restore failed: {e}"),
        });
    }
    Ok(baseline)
}
/// Loads the WAL frames that must be replayed on top of a baseline.
///
/// A key's sequence number is its last `/`-separated segment parsed as `u64`;
/// an unparsable segment counts as 0 and is therefore never newer than the
/// baseline. Frames are kept when their sequence is greater than
/// `baseline_seq` and, if `opts.target_seq` is set, not greater than it. Keys
/// for which the tier returns no frame are ignored.
///
/// # Returns
/// The loaded frames sorted by `frame_seq` ascending.
///
/// # Errors
/// `RestoreError::PhaseFailed` if listing keys or loading a frame fails.
fn collect_wal_frames(
    opts: &RestoreOptions<'_>,
    graph_name: &str,
    baseline_seq: u64,
) -> Result<Vec<WALFrame<Value>>, RestoreError> {
    let prefix = graph_wal_prefix(graph_name);
    let keys = match opts.wal_tier.list(&prefix) {
        Ok(keys) => keys,
        Err(e) => {
            return Err(RestoreError::PhaseFailed {
                lifecycle: Lifecycle::Spec,
                frame_seq: 0,
                message: format!("WAL frame enumeration failed: {e}"),
            });
        }
    };

    let mut frames: Vec<WALFrame<Value>> = Vec::new();
    for key in keys {
        let frame_seq = key
            .rsplit('/')
            .next()
            .and_then(|segment| segment.parse::<u64>().ok())
            .unwrap_or(0);

        let newer_than_baseline = frame_seq > baseline_seq;
        let within_target = opts.target_seq.map_or(true, |target| frame_seq <= target);
        if !(newer_than_baseline && within_target) {
            continue;
        }

        match opts.wal_tier.load(&key) {
            Ok(Some(frame)) => frames.push(frame),
            Ok(None) => {}
            Err(e) => {
                return Err(RestoreError::PhaseFailed {
                    lifecycle: Lifecycle::Data,
                    frame_seq,
                    message: format!("WAL frame load failed: {e}"),
                });
            }
        }
    }

    frames.sort_by_key(|frame| frame.frame_seq);
    Ok(frames)
}
/// Splits collected frames into those with a valid checksum and a count of
/// skipped ones.
///
/// A frame whose checksum check returns `false` or fails with an error is
/// treated as torn. The policy for a torn frame comes from `on_torn_write`
/// when provided; otherwise the final frame is skipped and any earlier frame
/// aborts.
///
/// # Returns
/// `(verified_frames, skipped_count)`, with verified frames in input order.
///
/// # Errors
/// `RestoreError::TornWriteMidStream` when the policy for a torn frame is
/// `TornWritePolicy::Abort`.
fn verify_frames(
    collected: Vec<WALFrame<Value>>,
    on_torn_write: Option<&OnTornWrite>,
) -> Result<(Vec<WALFrame<Value>>, u64), RestoreError> {
    let total = collected.len();
    let mut intact = Vec::new();
    let mut skipped_count: u64 = 0;

    for (index, frame) in collected.into_iter().enumerate() {
        let checksum_ok = verify_wal_frame_checksum(&frame).unwrap_or(false);
        if checksum_ok {
            intact.push(frame);
            continue;
        }

        let policy = match on_torn_write {
            Some(decide) => decide(frame.frame_seq, "checksum-mismatch"),
            // Default: tolerate a torn tail, reject corruption mid-stream.
            None if index + 1 == total => TornWritePolicy::Skip,
            None => TornWritePolicy::Abort,
        };

        if policy == TornWritePolicy::Abort {
            return Err(RestoreError::TornWriteMidStream {
                frame_seq: frame.frame_seq,
                reason: "checksum-mismatch".to_owned(),
            });
        }
        skipped_count += 1;
    }

    Ok((intact, skipped_count))
}
/// Replays verified WAL frames onto `graph`, one lifecycle phase at a time in
/// `REPLAY_ORDER`.
///
/// Frames are bucketed by lifecycle, keeping their incoming order within each
/// bucket. Each non-empty bucket is applied inside a single `graph.batch`
/// call. Frames whose lifecycle does not appear in `REPLAY_ORDER` are ignored.
///
/// Frames are bucketed by reference (no per-frame deep clone), and the number
/// of buckets follows `REPLAY_ORDER`, not a hard-coded length.
///
/// # Parameters
/// * `verified` - frames that passed checksum verification.
/// * `baseline_seq` - sequence number of the baseline; lower bound for `final_seq`.
/// * `skipped` - number of frames dropped during verification, passed through
///   to the result.
///
/// # Returns
/// A `RestoreResult` with the number of replayed and skipped frames, the
/// highest sequence number replayed (or `baseline_seq` if none), and one
/// `PhaseStat` per non-empty phase.
fn replay_by_lifecycle(
    core: &Core,
    graph: &Graph,
    verified: &[WALFrame<Value>],
    baseline_seq: u64,
    skipped: u64,
) -> RestoreResult {
    // One bucket per replay phase, in replay order.
    let mut grouped: Vec<Vec<&WALFrame<Value>>> =
        REPLAY_ORDER.iter().map(|_| Vec::new()).collect();
    for frame in verified {
        let phase_index = REPLAY_ORDER
            .iter()
            .position(|lifecycle| frame.lifecycle == *lifecycle);
        if let Some(idx) = phase_index {
            grouped[idx].push(frame);
        }
    }

    let mut phases = Vec::new();
    let mut replayed: u64 = 0;
    let mut final_seq: u64 = baseline_seq;
    for (lifecycle, life_frames) in REPLAY_ORDER.iter().zip(&grouped) {
        if life_frames.is_empty() {
            continue;
        }
        let frame_count = life_frames.len() as u64;
        let max_seq = life_frames.iter().map(|f| f.frame_seq).max().unwrap_or(0);
        // Apply the whole phase as one batch.
        graph.batch(core, || {
            for &frame in life_frames {
                apply_wal_frame(core, graph, frame);
            }
        });
        replayed += frame_count;
        final_seq = final_seq.max(max_seq);
        phases.push(PhaseStat {
            lifecycle: *lifecycle,
            frames: frame_count,
        });
    }

    RestoreResult {
        replayed_frames: replayed,
        skipped_frames: skipped,
        final_seq,
        phases,
    }
}
/// Applies one WAL frame to `graph`, best-effort.
///
/// Handled payload kinds:
/// * Spec / `graph.add` - creates a state node named `nodeId`, seeded with the
///   slice's `value` if present. Only slices whose `type` is `"state"` are
///   created, and an existing node is left untouched.
/// * Spec / `graph.remove` - removes the node named `nodeId` if it resolves.
/// * Data / `node.set` - sets the node at `path` to `value` if it resolves.
/// * Data / `node.invalidate` - invalidates the node at `path` if it resolves.
///
/// Every other kind, and all `Ownership` frames, are ignored. Missing or
/// empty identifiers make the frame a no-op, and results of `graph.state` /
/// `graph.remove` are deliberately discarded.
fn apply_wal_frame(core: &Core, graph: &Graph, frame: &WALFrame<Value>) {
    /// Reads a string field from a JSON payload, yielding `""` when the field
    /// is absent or not a string.
    fn text_field<'a>(payload: &'a Value, key: &str) -> &'a str {
        payload.get(key).and_then(Value::as_str).unwrap_or("")
    }

    let payload = &frame.change.change;
    let kind = text_field(payload, "kind");

    match frame.lifecycle {
        Lifecycle::Spec => match kind {
            "graph.add" => {
                let node_id = text_field(payload, "nodeId");
                if node_id.is_empty() {
                    return;
                }
                // Never recreate a node that already exists.
                if graph.try_resolve(node_id).is_some() {
                    return;
                }
                let slice = payload.get("slice");
                let declared_type = slice
                    .and_then(|s| s.get("type"))
                    .and_then(Value::as_str);
                if declared_type != Some("state") {
                    return;
                }
                let handle = match slice.and_then(|s| s.get("value")) {
                    Some(initial) => core.binding_ptr().deserialize_value(initial.clone()),
                    None => graphrefly_core::NO_HANDLE,
                };
                let _ = graph.state(core, node_id, Some(handle));
            }
            "graph.remove" => {
                let node_id = text_field(payload, "nodeId");
                if node_id.is_empty() || graph.try_resolve(node_id).is_none() {
                    return;
                }
                let _ = graph.remove(core, node_id);
            }
            _ => {}
        },
        Lifecycle::Data => match kind {
            "node.set" => {
                let path = text_field(payload, "path");
                let Some(value) = payload.get("value") else {
                    return;
                };
                if path.is_empty() || graph.try_resolve(path).is_none() {
                    return;
                }
                let handle = core.binding_ptr().deserialize_value(value.clone());
                graph.set(core, path, handle);
            }
            "node.invalidate" => {
                let path = text_field(payload, "path");
                if path.is_empty() {
                    return;
                }
                if let Some(id) = graph.try_resolve(path) {
                    graph.invalidate(core, id);
                }
            }
            _ => {}
        },
        Lifecycle::Ownership => {}
    }
}