use crate::{AdapterError, DomainInfo, PredicateInfo, SymbolTable};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
/// String-backed identifier for a node taking part in synchronization.
///
/// In this file it is used as the key of the per-node counters inside
/// `VectorClock` and as the `origin` of a `SyncEvent`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeId(String);
impl NodeId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
pub fn random() -> Self {
use std::sync::atomic::{AtomicU64, Ordering};
static COUNTER: AtomicU64 = AtomicU64::new(0);
let id = COUNTER.fetch_add(1, Ordering::SeqCst);
Self(format!("node-{}", id))
}
}
impl std::fmt::Display for NodeId {
    /// Writes the raw identifier string, with no decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Vector clock: one logical counter per node.
///
/// Note that the derived `PartialEq` compares the underlying map entry by
/// entry, so an explicit `0` counter is not equal to a missing counter even
/// though `get` reports `0` for both.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct VectorClock {
/// Counter per node; a node without an entry is read as `0`.
clocks: HashMap<NodeId, u64>,
}
impl VectorClock {
    /// Creates a clock with no recorded counters; every node reads as `0`.
    pub fn new() -> Self {
        Self {
            clocks: HashMap::new(),
        }
    }

    /// Advances the counter of `node` by one (a missing counter becomes `1`).
    pub fn increment(&mut self, node: &NodeId) {
        *self.clocks.entry(node.clone()).or_insert(0) += 1;
    }

    /// Returns the counter recorded for `node`, or `0` when there is none.
    pub fn get(&self, node: &NodeId) -> u64 {
        self.clocks.get(node).copied().unwrap_or(0)
    }

    /// Folds `other` into `self`, keeping the larger counter for every node
    /// that appears in `other`.
    pub fn merge(&mut self, other: &VectorClock) {
        for (node, &theirs) in &other.clocks {
            match self.clocks.get_mut(node) {
                Some(mine) => *mine = (*mine).max(theirs),
                // Only clone the key when a new entry is actually needed.
                None => {
                    self.clocks.insert(node.clone(), theirs);
                }
            }
        }
    }

    /// Returns `true` when `self` is strictly earlier than `other`: no
    /// counter of `self` exceeds the matching counter of `other`, and at
    /// least one is strictly smaller. Missing counters count as `0`.
    pub fn happens_before(&self, other: &VectorClock) -> bool {
        let any_ahead = self
            .clocks
            .iter()
            .any(|(node, &mine)| mine > other.get(node));
        if any_ahead {
            return false;
        }
        // A counter can only be strictly behind if `other` holds a positive
        // value for that node, so scanning `other` is sufficient.
        other
            .clocks
            .iter()
            .any(|(node, &theirs)| self.get(node) < theirs)
    }

    /// Returns `true` when neither clock happens before the other and the
    /// two clocks differ in at least one counter.
    ///
    /// The comparison treats a missing counter as `0`, consistently with
    /// `get` and `happens_before`, so a clock holding an explicit zero entry
    /// is not reported as concurrent with an otherwise identical clock that
    /// lacks the entry.
    pub fn is_concurrent(&self, other: &VectorClock) -> bool {
        !self.happens_before(other) && !other.happens_before(self) && !self.same_counters(other)
    }

    /// Counter-wise equality in which a missing counter equals `0`.
    fn same_counters(&self, other: &VectorClock) -> bool {
        let covers = |a: &VectorClock, b: &VectorClock| {
            a.clocks.iter().all(|(node, &value)| b.get(node) == value)
        };
        covers(self, other) && covers(other, self)
    }
}
impl Default for VectorClock {
fn default() -> Self {
Self::new()
}
}
/// Kind of change carried by a `SyncEvent`.
///
/// `SynchronizationManager::apply_event` dispatches on this value. In this
/// file only `DomainAdded`, `PredicateAdded` and `VariableAdded` have handlers
/// that touch the symbol table; the handlers for the other variants report
/// `ApplyResult::Applied` without changing anything.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SyncChangeType {
DomainAdded,
DomainModified,
DomainRemoved,
PredicateAdded,
PredicateModified,
PredicateRemoved,
VariableAdded,
VariableRemoved,
}
/// A single change notification exchanged between nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncEvent {
/// Identifier built by `SyncEvent::new` as
/// `<origin>-<origin's counter in clock>-<milliseconds since the Unix epoch>`;
/// the manager uses it to skip events it has already seen.
pub id: String,
/// Node that created the event.
pub origin: NodeId,
/// Vector clock supplied by the creator of the event.
pub clock: VectorClock,
/// What kind of change this event describes.
pub change_type: SyncChangeType,
/// Name of the affected domain, predicate or variable.
pub entity_name: String,
/// Optional payload. The manager in this file sends JSON for added
/// domains/predicates, sends `None` for removed domains, and expects
/// `"<variable>:<domain>"` for `VariableAdded`.
pub entity_data: Option<String>,
/// Creation time in milliseconds since the Unix epoch.
pub timestamp: u64,
}
impl SyncEvent {
    /// Creates an event stamped with the current wall-clock time.
    ///
    /// The identifier has the form
    /// `<origin>-<origin's counter in clock>-<timestamp>`. The wall clock is
    /// read exactly once, so the millisecond value embedded in `id` always
    /// equals the `timestamp` field.
    pub fn new(
        origin: NodeId,
        clock: VectorClock,
        change_type: SyncChangeType,
        entity_name: String,
        entity_data: Option<String>,
    ) -> Self {
        let timestamp = Self::current_timestamp();
        let id = format!("{}-{}-{}", origin.as_str(), clock.get(&origin), timestamp);
        Self {
            id,
            origin,
            clock,
            change_type,
            entity_name,
            entity_data,
            timestamp,
        }
    }

    /// Milliseconds elapsed since the Unix epoch.
    ///
    /// A system clock set before the epoch yields `0` instead of panicking,
    /// so creating an event can never abort the process.
    fn current_timestamp() -> u64 {
        let elapsed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        // `as_millis` returns u128; the value fits in u64 for any realistic date.
        elapsed.as_millis() as u64
    }
}
/// Strategy used when an incoming "added" event names a domain or predicate
/// that already exists locally.
///
/// As implemented in this file: `FirstWriteWins` makes the event come back as
/// `ApplyResult::Ignored`, `Manual` as `ApplyResult::ManualRequired`, and
/// `LastWriteWins`, `Merge` and `VectorClock` as
/// `ApplyResult::ConflictResolved`. None of the conflict branches visible here
/// modifies the existing table entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConflictResolution {
LastWriteWins,
FirstWriteWins,
Manual,
Merge,
VectorClock,
}
/// Outcome of `SynchronizationManager::apply_event`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ApplyResult {
/// The event was applied; counted in `events_applied`.
Applied,
/// The event was a duplicate or was skipped; counted in `events_ignored`.
Ignored,
/// A conflict was detected and resolved by the active strategy; counted in
/// both `conflicts_resolved` and `events_applied`.
ConflictResolved,
/// A conflict was detected under the `Manual` strategy; counted in
/// `manual_resolutions_required`.
ManualRequired,
}
/// Transport abstraction used by `SynchronizationManager::synchronize`.
pub trait SyncProtocol: Send + Sync {
/// Sends `event` to the single node `target`.
fn send_event(&self, target: &NodeId, event: &SyncEvent) -> Result<(), AdapterError>;
/// Sends `event` to every node reachable through this protocol.
fn broadcast_event(&self, event: &SyncEvent) -> Result<(), AdapterError>;
/// Returns the events that have arrived for the caller.
fn receive_events(&self) -> Result<Vec<SyncEvent>, AdapterError>;
}
/// `SyncProtocol` implementation backed by one in-process queue.
///
/// Cloning the protocol clones the `Arc`, so all clones share the same queue.
#[derive(Debug, Clone)]
pub struct InMemorySyncProtocol {
/// Shared FIFO of events that have been sent but not yet received.
events: Arc<RwLock<VecDeque<SyncEvent>>>,
}
impl InMemorySyncProtocol {
pub fn new() -> Self {
Self {
events: Arc::new(RwLock::new(VecDeque::new())),
}
}
}
impl Default for InMemorySyncProtocol {
fn default() -> Self {
Self::new()
}
}
impl SyncProtocol for InMemorySyncProtocol {
    /// Appends a copy of `event` to the shared queue.
    ///
    /// There is only one queue, so the target node is not consulted and this
    /// behaves exactly like `broadcast_event`.
    fn send_event(&self, _target: &NodeId, event: &SyncEvent) -> Result<(), AdapterError> {
        self.broadcast_event(event)
    }

    /// Appends a copy of `event` to the shared queue.
    ///
    /// Fails with `AdapterError::InvalidOperation` if the lock is poisoned.
    fn broadcast_event(&self, event: &SyncEvent) -> Result<(), AdapterError> {
        let mut queue = match self.events.write() {
            Ok(guard) => guard,
            Err(poisoned) => {
                return Err(AdapterError::InvalidOperation(format!(
                    "Lock poisoned: {}",
                    poisoned
                )))
            }
        };
        queue.push_back(event.clone());
        Ok(())
    }

    /// Removes and returns every queued event, oldest first.
    ///
    /// Fails with `AdapterError::InvalidOperation` if the lock is poisoned.
    fn receive_events(&self) -> Result<Vec<SyncEvent>, AdapterError> {
        let mut queue = match self.events.write() {
            Ok(guard) => guard,
            Err(poisoned) => {
                return Err(AdapterError::InvalidOperation(format!(
                    "Lock poisoned: {}",
                    poisoned
                )))
            }
        };
        let drained: Vec<SyncEvent> = queue.drain(..).collect();
        Ok(drained)
    }
}
/// Observer notified by `SynchronizationManager` while it processes incoming
/// events.
pub trait EventListener: Send + Sync {
/// Called at the start of `apply_event`, before the duplicate check.
fn on_event_received(&self, event: &SyncEvent);
/// Called after a non-duplicate event has been processed, with its outcome.
fn on_event_applied(&self, event: &SyncEvent, result: &ApplyResult);
/// Called when an added domain or predicate already exists locally;
/// `conflict_type` is `"domain_already_exists"` or
/// `"predicate_already_exists"` in this file.
fn on_conflict_detected(&self, event: &SyncEvent, conflict_type: &str);
}
/// Counters maintained by `SynchronizationManager`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SyncStatistics {
/// Local changes queued for sending (counted when queued, not when broadcast).
pub events_sent: usize,
/// Calls to `apply_event`, including duplicates.
pub events_received: usize,
/// Events whose outcome was `Applied` or `ConflictResolved`.
pub events_applied: usize,
/// Duplicate events plus events whose outcome was `Ignored`.
pub events_ignored: usize,
/// Incoming "added" events whose domain or predicate already existed.
pub conflicts_detected: usize,
/// Events whose outcome was `ConflictResolved`.
pub conflicts_resolved: usize,
/// Events whose outcome was `ManualRequired`.
pub manual_resolutions_required: usize,
}
/// Owns a `SymbolTable` and keeps it in step with other nodes by producing
/// and consuming `SyncEvent`s.
pub struct SynchronizationManager {
/// Identity of the local node; used as the origin of locally created events.
node_id: NodeId,
/// The symbol table being synchronized.
table: SymbolTable,
/// Local vector clock; advanced on local changes and merged with the clock
/// of incoming events.
clock: VectorClock,
/// Strategy consulted when an incoming "added" event hits an existing entry.
resolution_strategy: ConflictResolution,
/// Locally created events that have not been broadcast yet.
pending_events: VecDeque<SyncEvent>,
/// Ids of events already created locally or applied; used to drop duplicates.
applied_events: HashSet<String>,
/// Observers notified while incoming events are processed.
listeners: Vec<Arc<dyn EventListener>>,
/// Running counters.
stats: SyncStatistics,
/// Peers added through `register_node`; not read by the non-test code in
/// this file.
known_nodes: HashSet<NodeId>,
}
impl SynchronizationManager {
pub fn new(node_id: NodeId, table: SymbolTable) -> Self {
let mut clock = VectorClock::new();
clock.increment(&node_id);
Self {
node_id,
table,
clock,
resolution_strategy: ConflictResolution::VectorClock,
pending_events: VecDeque::new(),
applied_events: HashSet::new(),
listeners: Vec::new(),
stats: SyncStatistics::default(),
known_nodes: HashSet::new(),
}
}
pub fn set_resolution_strategy(&mut self, strategy: ConflictResolution) {
self.resolution_strategy = strategy;
}
pub fn add_listener(&mut self, listener: Arc<dyn EventListener>) {
self.listeners.push(listener);
}
pub fn register_node(&mut self, node_id: NodeId) {
self.known_nodes.insert(node_id);
}
pub fn table(&self) -> &SymbolTable {
&self.table
}
pub fn table_mut(&mut self) -> &mut SymbolTable {
&mut self.table
}
pub fn pending_events(&self) -> Vec<SyncEvent> {
self.pending_events.iter().cloned().collect()
}
pub fn clear_pending_events(&mut self) {
self.pending_events.clear();
}
pub fn statistics(&self) -> &SyncStatistics {
&self.stats
}
pub fn add_domain(&mut self, domain: DomainInfo) -> Result<(), AdapterError> {
let name = domain.name.clone();
self.table
.add_domain(domain.clone())
.map_err(|e| AdapterError::InvalidOperation(format!("Add domain failed: {}", e)))?;
self.clock.increment(&self.node_id);
let entity_data = serde_json::to_string(&domain)
.map_err(|e| AdapterError::InvalidOperation(format!("Serialization error: {}", e)))?;
let event = SyncEvent::new(
self.node_id.clone(),
self.clock.clone(),
SyncChangeType::DomainAdded,
name,
Some(entity_data),
);
self.pending_events.push_back(event.clone());
self.applied_events.insert(event.id.clone());
self.stats.events_sent += 1;
Ok(())
}
pub fn add_predicate(&mut self, predicate: PredicateInfo) -> Result<(), AdapterError> {
let name = predicate.name.clone();
self.table
.add_predicate(predicate.clone())
.map_err(|e| AdapterError::InvalidOperation(format!("Add predicate failed: {}", e)))?;
self.clock.increment(&self.node_id);
let entity_data = serde_json::to_string(&predicate)
.map_err(|e| AdapterError::InvalidOperation(format!("Serialization error: {}", e)))?;
let event = SyncEvent::new(
self.node_id.clone(),
self.clock.clone(),
SyncChangeType::PredicateAdded,
name,
Some(entity_data),
);
self.pending_events.push_back(event.clone());
self.applied_events.insert(event.id.clone());
self.stats.events_sent += 1;
Ok(())
}
pub fn remove_domain(&mut self, name: &str) -> Result<(), AdapterError> {
if self.table.get_domain(name).is_none() {
return Err(AdapterError::DomainNotFound(name.to_string()));
}
self.clock.increment(&self.node_id);
let event = SyncEvent::new(
self.node_id.clone(),
self.clock.clone(),
SyncChangeType::DomainRemoved,
name.to_string(),
None,
);
self.pending_events.push_back(event.clone());
self.applied_events.insert(event.id.clone());
self.stats.events_sent += 1;
Ok(())
}
pub fn apply_event(&mut self, event: SyncEvent) -> Result<ApplyResult, AdapterError> {
self.stats.events_received += 1;
for listener in &self.listeners {
listener.on_event_received(&event);
}
if self.applied_events.contains(&event.id) {
self.stats.events_ignored += 1;
return Ok(ApplyResult::Ignored);
}
self.clock.merge(&event.clock);
let result = match event.change_type {
SyncChangeType::DomainAdded => self.apply_domain_added(&event)?,
SyncChangeType::DomainModified => self.apply_domain_modified(&event)?,
SyncChangeType::DomainRemoved => self.apply_domain_removed(&event)?,
SyncChangeType::PredicateAdded => self.apply_predicate_added(&event)?,
SyncChangeType::PredicateModified => self.apply_predicate_modified(&event)?,
SyncChangeType::PredicateRemoved => self.apply_predicate_removed(&event)?,
SyncChangeType::VariableAdded => self.apply_variable_added(&event)?,
SyncChangeType::VariableRemoved => self.apply_variable_removed(&event)?,
};
self.applied_events.insert(event.id.clone());
match result {
ApplyResult::Applied => self.stats.events_applied += 1,
ApplyResult::Ignored => self.stats.events_ignored += 1,
ApplyResult::ConflictResolved => {
self.stats.conflicts_resolved += 1;
self.stats.events_applied += 1;
}
ApplyResult::ManualRequired => self.stats.manual_resolutions_required += 1,
}
for listener in &self.listeners {
listener.on_event_applied(&event, &result);
}
Ok(result)
}
fn apply_domain_added(&mut self, event: &SyncEvent) -> Result<ApplyResult, AdapterError> {
if let Some(_existing) = self.table.get_domain(&event.entity_name) {
self.stats.conflicts_detected += 1;
for listener in &self.listeners {
listener.on_conflict_detected(event, "domain_already_exists");
}
match self.resolution_strategy {
ConflictResolution::LastWriteWins => {
Ok(ApplyResult::ConflictResolved)
}
ConflictResolution::FirstWriteWins => {
Ok(ApplyResult::Ignored)
}
ConflictResolution::Manual => Ok(ApplyResult::ManualRequired),
ConflictResolution::Merge | ConflictResolution::VectorClock => {
Ok(ApplyResult::ConflictResolved)
}
}
} else {
let entity_data = event
.entity_data
.as_ref()
.ok_or_else(|| AdapterError::InvalidOperation("Missing entity data".to_string()))?;
let domain: DomainInfo = serde_json::from_str(entity_data).map_err(|e| {
AdapterError::InvalidOperation(format!("Deserialization error: {}", e))
})?;
self.table
.add_domain(domain)
.map_err(|e| AdapterError::InvalidOperation(format!("Add domain failed: {}", e)))?;
Ok(ApplyResult::Applied)
}
}
fn apply_domain_modified(&mut self, _event: &SyncEvent) -> Result<ApplyResult, AdapterError> {
Ok(ApplyResult::Applied)
}
fn apply_domain_removed(&mut self, _event: &SyncEvent) -> Result<ApplyResult, AdapterError> {
Ok(ApplyResult::Applied)
}
fn apply_predicate_added(&mut self, event: &SyncEvent) -> Result<ApplyResult, AdapterError> {
if self.table.get_predicate(&event.entity_name).is_some() {
self.stats.conflicts_detected += 1;
for listener in &self.listeners {
listener.on_conflict_detected(event, "predicate_already_exists");
}
match self.resolution_strategy {
ConflictResolution::FirstWriteWins => Ok(ApplyResult::Ignored),
ConflictResolution::Manual => Ok(ApplyResult::ManualRequired),
_ => Ok(ApplyResult::ConflictResolved),
}
} else {
let entity_data = event
.entity_data
.as_ref()
.ok_or_else(|| AdapterError::InvalidOperation("Missing entity data".to_string()))?;
let predicate: PredicateInfo = serde_json::from_str(entity_data).map_err(|e| {
AdapterError::InvalidOperation(format!("Deserialization error: {}", e))
})?;
self.table.add_predicate(predicate).map_err(|e| {
AdapterError::InvalidOperation(format!("Add predicate failed: {}", e))
})?;
Ok(ApplyResult::Applied)
}
}
fn apply_predicate_modified(
&mut self,
_event: &SyncEvent,
) -> Result<ApplyResult, AdapterError> {
Ok(ApplyResult::Applied)
}
fn apply_predicate_removed(&mut self, _event: &SyncEvent) -> Result<ApplyResult, AdapterError> {
Ok(ApplyResult::Applied)
}
fn apply_variable_added(&mut self, event: &SyncEvent) -> Result<ApplyResult, AdapterError> {
if let Some(data) = &event.entity_data {
let parts: Vec<&str> = data.split(':').collect();
if parts.len() == 2 {
let var_name = parts[0];
let domain_name = parts[1];
if self.table.bind_variable(var_name, domain_name).is_err() {
Ok(ApplyResult::Ignored)
} else {
Ok(ApplyResult::Applied)
}
} else {
Err(AdapterError::InvalidOperation(
"Invalid variable data format".to_string(),
))
}
} else {
Err(AdapterError::InvalidOperation(
"Missing entity data for variable".to_string(),
))
}
}
fn apply_variable_removed(&mut self, _event: &SyncEvent) -> Result<ApplyResult, AdapterError> {
Ok(ApplyResult::Applied)
}
pub fn synchronize<P: SyncProtocol>(&mut self, protocol: &P) -> Result<(), AdapterError> {
for event in &self.pending_events {
protocol.broadcast_event(event)?;
}
self.pending_events.clear();
let events = protocol.receive_events()?;
for event in events {
self.apply_event(event)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // A NodeId exposes the string it was built from via `as_str` and `Display`.
    #[test]
    fn test_node_id_creation() {
        let node = NodeId::new("test-node");
        assert_eq!(node.as_str(), "test-node");
        assert_eq!(node.to_string(), "test-node");
    }

    // Consecutive generated ids differ.
    #[test]
    fn test_node_id_random() {
        let node1 = NodeId::random();
        let node2 = NodeId::random();
        assert_ne!(node1, node2);
    }

    // Counters start at 0 and advance independently per node.
    #[test]
    fn test_vector_clock_basics() {
        let mut clock = VectorClock::new();
        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");
        assert_eq!(clock.get(&node1), 0);
        clock.increment(&node1);
        assert_eq!(clock.get(&node1), 1);
        clock.increment(&node1);
        assert_eq!(clock.get(&node1), 2);
        clock.increment(&node2);
        assert_eq!(clock.get(&node2), 1);
    }

    // Merging keeps existing counters and picks up counters only the other side has.
    #[test]
    fn test_vector_clock_merge() {
        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");
        let mut clock1 = VectorClock::new();
        clock1.increment(&node1);
        clock1.increment(&node1);
        let mut clock2 = VectorClock::new();
        clock2.increment(&node2);
        clock1.merge(&clock2);
        assert_eq!(clock1.get(&node1), 2);
        assert_eq!(clock1.get(&node2), 1);
    }

    // A smaller counter on the same node is strictly earlier, and not vice versa.
    #[test]
    fn test_vector_clock_happens_before() {
        let node1 = NodeId::new("node1");
        let mut clock1 = VectorClock::new();
        clock1.increment(&node1);
        let mut clock2 = VectorClock::new();
        clock2.increment(&node1);
        clock2.increment(&node1);
        assert!(clock1.happens_before(&clock2));
        assert!(!clock2.happens_before(&clock1));
    }

    // Clocks advanced on different nodes are concurrent in both directions.
    #[test]
    fn test_vector_clock_concurrent() {
        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");
        let mut clock1 = VectorClock::new();
        clock1.increment(&node1);
        let mut clock2 = VectorClock::new();
        clock2.increment(&node2);
        assert!(clock1.is_concurrent(&clock2));
        assert!(clock2.is_concurrent(&clock1));
    }

    // A new event keeps its origin and entity name and gets a non-zero timestamp.
    #[test]
    fn test_sync_event_creation() {
        let node = NodeId::new("test-node");
        let clock = VectorClock::new();
        let event = SyncEvent::new(
            node.clone(),
            clock,
            SyncChangeType::DomainAdded,
            "Person".to_string(),
            Some("{}".to_string()),
        );
        assert_eq!(event.origin, node);
        assert_eq!(event.entity_name, "Person");
        assert!(event.timestamp > 0);
    }

    // An event sent through the in-memory protocol is returned by the next receive.
    #[test]
    fn test_in_memory_protocol() {
        let protocol = InMemorySyncProtocol::new();
        let node = NodeId::new("test");
        let clock = VectorClock::new();
        let event = SyncEvent::new(
            node.clone(),
            clock,
            SyncChangeType::DomainAdded,
            "Person".to_string(),
            None,
        );
        protocol.send_event(&node, &event).expect("unwrap");
        let received = protocol.receive_events().expect("unwrap");
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].entity_name, "Person");
    }

    // A fresh manager remembers its node id and has sent nothing.
    #[test]
    fn test_sync_manager_creation() {
        let node = NodeId::new("test");
        let table = SymbolTable::new();
        let mgr = SynchronizationManager::new(node.clone(), table);
        assert_eq!(mgr.node_id, node);
        assert_eq!(mgr.stats.events_sent, 0);
    }

    // Adding a domain updates the table and queues exactly one event.
    #[test]
    fn test_sync_manager_add_domain() {
        let node = NodeId::new("test");
        let table = SymbolTable::new();
        let mut mgr = SynchronizationManager::new(node, table);
        let domain = DomainInfo::new("Person", 100);
        mgr.add_domain(domain).expect("unwrap");
        assert_eq!(mgr.pending_events().len(), 1);
        assert_eq!(mgr.stats.events_sent, 1);
        assert!(mgr.table().get_domain("Person").is_some());
    }

    // Adding a predicate over an existing domain queues exactly one event.
    #[test]
    fn test_sync_manager_add_predicate() {
        let node = NodeId::new("test");
        let mut table = SymbolTable::new();
        table
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        let mut mgr = SynchronizationManager::new(node, table);
        let predicate = PredicateInfo::new("knows", vec!["Person".to_string()]);
        mgr.add_predicate(predicate).expect("unwrap");
        assert_eq!(mgr.pending_events().len(), 1);
        assert_eq!(mgr.stats.events_sent, 1);
    }

    // An event produced by one manager creates the domain on another.
    #[test]
    fn test_sync_manager_apply_domain_event() {
        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");
        let table1 = SymbolTable::new();
        let mut mgr1 = SynchronizationManager::new(node1, table1);
        let table2 = SymbolTable::new();
        let mut mgr2 = SynchronizationManager::new(node2, table2);
        let domain = DomainInfo::new("Person", 100);
        mgr1.add_domain(domain).expect("unwrap");
        let events = mgr1.pending_events();
        let result = mgr2.apply_event(events[0].clone()).expect("unwrap");
        assert_eq!(result, ApplyResult::Applied);
        assert!(mgr2.table().get_domain("Person").is_some());
        assert_eq!(mgr2.stats.events_received, 1);
        assert_eq!(mgr2.stats.events_applied, 1);
    }

    // Applying the same event twice ignores the second delivery.
    #[test]
    fn test_sync_manager_duplicate_event() {
        let node = NodeId::new("node1");
        let table = SymbolTable::new();
        let mut mgr = SynchronizationManager::new(node.clone(), table);
        let domain = DomainInfo::new("Person", 100);
        let event_data = serde_json::to_string(&domain).expect("unwrap");
        let mut clock = VectorClock::new();
        clock.increment(&node);
        let event = SyncEvent::new(
            node,
            clock,
            SyncChangeType::DomainAdded,
            "Person".to_string(),
            Some(event_data),
        );
        let result1 = mgr.apply_event(event.clone()).expect("unwrap");
        assert_eq!(result1, ApplyResult::Applied);
        let result2 = mgr.apply_event(event).expect("unwrap");
        assert_eq!(result2, ApplyResult::Ignored);
        assert_eq!(mgr.stats.events_ignored, 1);
    }

    // Under FirstWriteWins a remote add for an existing domain is ignored
    // but still counted as a conflict.
    #[test]
    fn test_sync_manager_conflict_resolution() {
        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");
        let mut table = SymbolTable::new();
        table
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        let mut mgr = SynchronizationManager::new(node2.clone(), table);
        mgr.set_resolution_strategy(ConflictResolution::FirstWriteWins);
        let domain = DomainInfo::new("Person", 200);
        let event_data = serde_json::to_string(&domain).expect("unwrap");
        let mut clock = VectorClock::new();
        clock.increment(&node1);
        let event = SyncEvent::new(
            node1,
            clock,
            SyncChangeType::DomainAdded,
            "Person".to_string(),
            Some(event_data),
        );
        let result = mgr.apply_event(event).expect("unwrap");
        assert_eq!(result, ApplyResult::Ignored);
        assert_eq!(mgr.stats.conflicts_detected, 1);
    }

    // Replaying all pending events of one manager reproduces its domains on another.
    #[test]
    fn test_sync_manager_full_synchronization() {
        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");
        let table1 = SymbolTable::new();
        let mut mgr1 = SynchronizationManager::new(node1, table1);
        let table2 = SymbolTable::new();
        let mut mgr2 = SynchronizationManager::new(node2, table2);
        mgr1.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        mgr1.add_domain(DomainInfo::new("Place", 50))
            .expect("unwrap");
        let events = mgr1.pending_events();
        assert_eq!(events.len(), 2);
        for event in events {
            mgr2.apply_event(event).expect("unwrap");
        }
        assert!(mgr2.table().get_domain("Person").is_some());
        assert!(mgr2.table().get_domain("Place").is_some());
        assert_eq!(mgr2.stats.events_applied, 2);
    }

    // Every strategy can be selected, and the manager actually stores it.
    #[test]
    fn test_conflict_resolution_strategies() {
        for strategy in &[
            ConflictResolution::LastWriteWins,
            ConflictResolution::FirstWriteWins,
            ConflictResolution::Manual,
            ConflictResolution::Merge,
            ConflictResolution::VectorClock,
        ] {
            let node = NodeId::new("test");
            let table = SymbolTable::new();
            let mut mgr = SynchronizationManager::new(node, table);
            // A new manager defaults to the VectorClock strategy.
            assert_eq!(mgr.resolution_strategy, ConflictResolution::VectorClock);
            mgr.set_resolution_strategy(*strategy);
            assert_eq!(mgr.resolution_strategy, *strategy);
        }
    }

    // Registered peers are stored and can be looked up.
    #[test]
    fn test_register_nodes() {
        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");
        let node3 = NodeId::new("node3");
        let table = SymbolTable::new();
        let mut mgr = SynchronizationManager::new(node1, table);
        mgr.register_node(node2.clone());
        mgr.register_node(node3.clone());
        assert_eq!(mgr.known_nodes.len(), 2);
        assert!(mgr.known_nodes.contains(&node2));
        assert!(mgr.known_nodes.contains(&node3));
    }
}