use arc_swap::ArcSwap;
use rusqlite::{Connection, OpenFlags, Result as SqliteResult};
use std::collections::HashMap;
use std::sync::Arc;
/// Identifier used for every node in the adjacency maps.
pub type NodeId = i64;

/// An immutable copy of a graph's adjacency maps, stamped with the time it was built.
///
/// A node is considered part of the snapshot only if it is a key of `outgoing`; keys that
/// appear solely in `incoming` are not counted by [`SnapshotState::node_count`] or
/// [`SnapshotState::contains_node`].
#[derive(Debug, Clone)]
pub struct SnapshotState {
    /// Per-node outgoing adjacency lists.
    pub outgoing: HashMap<NodeId, Vec<NodeId>>,
    /// Per-node incoming adjacency lists.
    pub incoming: HashMap<NodeId, Vec<NodeId>>,
    /// Wall-clock time at which this snapshot was built.
    pub created_at: std::time::SystemTime,
}

impl SnapshotState {
    /// Builds a snapshot by deep-copying both adjacency maps, then recording the current
    /// wall-clock time.
    pub fn new(
        outgoing: &HashMap<NodeId, Vec<NodeId>>,
        incoming: &HashMap<NodeId, Vec<NodeId>>,
    ) -> Self {
        let outgoing = outgoing.clone();
        let incoming = incoming.clone();
        Self {
            outgoing,
            incoming,
            created_at: std::time::SystemTime::now(),
        }
    }

    /// Number of nodes, i.e. the number of keys in `outgoing`.
    pub fn node_count(&self) -> usize {
        self.outgoing.len()
    }

    /// Total number of entries across all outgoing adjacency lists.
    pub fn edge_count(&self) -> usize {
        self.outgoing.values().map(Vec::len).sum()
    }

    /// Whether `node_id` is a key of `outgoing`.
    pub fn contains_node(&self, node_id: NodeId) -> bool {
        self.get_outgoing(node_id).is_some()
    }

    /// Outgoing adjacency list of `node_id`, or `None` if it is not a key of `outgoing`.
    pub fn get_outgoing(&self, node_id: NodeId) -> Option<&Vec<NodeId>> {
        self.outgoing.get(&node_id)
    }

    /// Incoming adjacency list of `node_id`, or `None` if it is not a key of `incoming`.
    pub fn get_incoming(&self, node_id: NodeId) -> Option<&Vec<NodeId>> {
        self.incoming.get(&node_id)
    }
}
#[derive(Debug)]
pub struct SnapshotManager {
current: ArcSwap<SnapshotState>,
}
impl SnapshotManager {
pub fn new() -> Self {
let initial_state = SnapshotState::new(&HashMap::new(), &HashMap::new());
Self {
current: ArcSwap::new(Arc::new(initial_state)),
}
}
pub fn with_state(
outgoing: &HashMap<NodeId, Vec<NodeId>>,
incoming: &HashMap<NodeId, Vec<NodeId>>,
) -> Self {
let initial_state = SnapshotState::new(outgoing, incoming);
Self {
current: ArcSwap::new(Arc::new(initial_state)),
}
}
pub fn update_snapshot(
&self,
outgoing: &HashMap<NodeId, Vec<NodeId>>,
incoming: &HashMap<NodeId, Vec<NodeId>>,
) {
let new_state = SnapshotState::new(outgoing, incoming);
#[cfg(debug_assertions)]
{
assert_eq!(
new_state.node_count(),
outgoing.len(),
"Snapshot state node count mismatch"
);
assert_eq!(
new_state.edge_count(),
outgoing.values().map(|v| v.len()).sum::<usize>(),
"Snapshot state edge count mismatch"
);
}
self.current.store(Arc::new(new_state));
}
pub fn acquire_snapshot(&self) -> Arc<SnapshotState> {
let state = self.current.load();
let snapshot = Arc::clone(&state);
#[cfg(debug_assertions)]
{
let node_count = snapshot.node_count();
let edge_count = snapshot.edge_count();
assert!(node_count <= 10_000_000, "Suspiciously large node count");
assert!(edge_count <= 100_000_000, "Suspiciously large edge count");
}
snapshot
}
pub fn current_snapshot(&self) -> Arc<SnapshotState> {
self.current.load().clone()
}
}
impl Default for SnapshotManager {
fn default() -> Self {
Self::new()
}
}
/// A read-side view that pairs an in-memory [`SnapshotState`] with a read-only SQLite
/// connection.
///
/// NOTE(review): the connection is opened independently of `state` and no transaction is
/// started here, so queries issued through [`GraphSnapshot::connection`] observe the database
/// as it is when they run, not necessarily as it was when `state` was captured. Confirm
/// whether callers rely on the two being consistent.
pub struct GraphSnapshot {
    state: Arc<SnapshotState>,
    conn: Connection,
}

impl GraphSnapshot {
    /// Opens the database at `db_path` read-only and pairs it with `state`.
    ///
    /// The connection is opened with `SQLITE_OPEN_READ_ONLY | SQLITE_OPEN_NO_MUTEX`. It cannot
    /// write, it does not create a missing database file (`SQLITE_OPEN_CREATE` is not
    /// requested), and SQLite does not serialize access to it internally.
    ///
    /// Accepts anything path-like (`&str`, `String`, `&Path`, `PathBuf`, ...).
    ///
    /// # Errors
    /// Returns the `rusqlite` error if the database cannot be opened.
    pub fn new(
        state: Arc<SnapshotState>,
        db_path: impl AsRef<std::path::Path>,
    ) -> SqliteResult<Self> {
        let conn = Connection::open_with_flags(
            db_path,
            OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
        )?;
        Ok(Self { state, conn })
    }

    /// The shared in-memory snapshot this view was built from.
    pub fn state(&self) -> &Arc<SnapshotState> {
        &self.state
    }

    /// The read-only SQLite connection owned by this view.
    pub fn connection(&self) -> &Connection {
        &self.conn
    }

    /// Number of nodes in the snapshot (see [`SnapshotState::node_count`]).
    pub fn node_count(&self) -> usize {
        self.state.node_count()
    }

    /// Number of edges in the snapshot (see [`SnapshotState::edge_count`]).
    pub fn edge_count(&self) -> usize {
        self.state.edge_count()
    }

    /// Whether the snapshot contains `node_id` (see [`SnapshotState::contains_node`]).
    pub fn contains_node(&self, node_id: NodeId) -> bool {
        self.state.contains_node(node_id)
    }

    /// Outgoing adjacency list of `node_id` in the snapshot, if present.
    pub fn get_outgoing(&self, node_id: NodeId) -> Option<&Vec<NodeId>> {
        self.state.get_outgoing(node_id)
    }

    /// Incoming adjacency list of `node_id` in the snapshot, if present.
    pub fn get_incoming(&self, node_id: NodeId) -> Option<&Vec<NodeId>> {
        self.state.get_incoming(node_id)
    }

    /// Wall-clock time at which the underlying snapshot state was built.
    pub fn created_at(&self) -> std::time::SystemTime {
        self.state.created_at
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an adjacency map from `(node, neighbours)` entries.
    fn adjacency(entries: Vec<(NodeId, Vec<NodeId>)>) -> HashMap<NodeId, Vec<NodeId>> {
        entries.into_iter().collect()
    }

    #[test]
    fn test_snapshot_state_creation() {
        let outgoing = adjacency(vec![(1, vec![2, 3])]);
        let incoming = adjacency(vec![(1, vec![])]);

        let state = SnapshotState::new(&outgoing, &incoming);

        assert_eq!(state.node_count(), 1);
        assert_eq!(state.edge_count(), 2);
        assert!(state.contains_node(1));
        // Node 2 only appears as a neighbour, not as a key of `outgoing`.
        assert!(!state.contains_node(2));
    }

    #[test]
    fn test_snapshot_manager() {
        let mut outgoing = adjacency(vec![(1, vec![2])]);
        let mut incoming = adjacency(vec![(1, vec![])]);
        let manager = SnapshotManager::with_state(&outgoing, &incoming);

        let before = manager.acquire_snapshot();
        assert_eq!(before.node_count(), 1);
        assert!(before.contains_node(1));

        outgoing.insert(2, Vec::new());
        incoming.insert(2, vec![1]);
        manager.update_snapshot(&outgoing, &incoming);

        let after = manager.acquire_snapshot();
        assert_eq!(after.node_count(), 2);
        // A snapshot taken before the update must not observe it.
        assert_eq!(before.node_count(), 1);
    }
}