use crate::clock::CompactTimestamp;
use crate::error::{CRDTError, CRDTResult};
use crate::memory::{MemoryConfig, NodeId};
use crate::traits::{BoundedCRDT, CRDT, RealTimeCRDT};
#[cfg(feature = "hardware-atomic")]
use core::cell::UnsafeCell;
#[cfg(feature = "hardware-atomic")]
use core::sync::atomic::{AtomicU8, AtomicU32, Ordering};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Last-writer-wins register holding at most one value of type `T`.
///
/// A write replaces the stored one when its timestamp is newer; on equal
/// timestamps the write from the higher node id wins.
///
/// With the `hardware-atomic` feature the timestamp and winning node id live in
/// atomics and the value in an `UnsafeCell`, which lets `set` take `&self`.
/// Without it all fields are plain values and `set` takes `&mut self`.
#[derive(Debug)]
pub struct LWWRegister<T, C: MemoryConfig> {
    // Current value; `None` until the first accepted write.
    #[cfg(not(feature = "hardware-atomic"))]
    current_value: Option<T>,
    // Timestamp of the write that produced `current_value`.
    #[cfg(not(feature = "hardware-atomic"))]
    current_timestamp: CompactTimestamp,
    // Node id of the write that produced `current_value` (used for tie-breaks).
    #[cfg(not(feature = "hardware-atomic"))]
    current_node_id: NodeId,
    // Atomic-mode value storage, accessed through raw pointers from `&self`.
    #[cfg(feature = "hardware-atomic")]
    current_value: UnsafeCell<Option<T>>,
    // Atomic-mode timestamp; stored as `u32` (`set` truncates the `u64` input).
    #[cfg(feature = "hardware-atomic")]
    current_timestamp: AtomicU32,
    // Atomic-mode winning node id.
    #[cfg(feature = "hardware-atomic")]
    current_node_id: AtomicU8,
    // Id of this replica; stamped onto writes made through `set`.
    node_id: NodeId,
    // `C` is only used at the type level (e.g. `C::MAX_NODES` in `validate`).
    _phantom: core::marker::PhantomData<C>,
}
// SAFETY (review): `get`, `is_empty`, `eq`, `Clone` and `Serialize` read the
// `UnsafeCell` and hand out `&T` from a shared reference. Sharing the register
// across threads therefore requires `T: Sync`, in addition to `T: Send`: `set`
// moves a `T` into the cell from whichever thread calls it.
//
// NOTE(review): these bounds are necessary but not sufficient. `set` takes
// `&self` and, after winning the timestamp CAS, writes `*self.current_value.get()`
// with no mutual exclusion. Two threads that win the CAS back-to-back can write
// the cell concurrently, and a `&T` obtained from `get` can be overwritten while
// it is still borrowed. Soundness needs a lock or seqlock around the value.
#[cfg(feature = "hardware-atomic")]
unsafe impl<T, C: MemoryConfig> Sync for LWWRegister<T, C>
where
    T: Send + Sync,
    C: Send + Sync,
{
}
impl<T, C: MemoryConfig> Clone for LWWRegister<T, C>
where
    T: Clone,
{
    /// Returns an independent copy of the register.
    ///
    /// The stored value is cloned; the timestamp, winning node id and local
    /// node id are copied as-is.
    fn clone(&self) -> Self {
        #[cfg(not(feature = "hardware-atomic"))]
        {
            // Every field except the value is `Copy`, so struct-update syntax
            // copies them straight out of `*self`.
            Self {
                current_value: self.current_value.clone(),
                ..*self
            }
        }
        #[cfg(feature = "hardware-atomic")]
        {
            // SAFETY: reads the cell through a shared reference; relies on the
            // same assumption as `get` (no concurrent `set` on `self`).
            let value_copy = unsafe { Option::clone(&*self.current_value.get()) };
            let stamp = self.current_timestamp.load(Ordering::Relaxed);
            let winner = self.current_node_id.load(Ordering::Relaxed);
            Self {
                current_value: UnsafeCell::new(value_copy),
                current_timestamp: AtomicU32::new(stamp),
                current_node_id: AtomicU8::new(winner),
                node_id: self.node_id,
                _phantom: core::marker::PhantomData,
            }
        }
    }
}
// Inherent API: construction, local writes and read accessors.
impl<T, C: MemoryConfig> LWWRegister<T, C>
where
    T: Clone + PartialEq,
{
    /// Creates an empty register owned by replica `node_id`.
    ///
    /// The register starts with no value, a zero timestamp and a winning node
    /// id of `0`.
    pub fn new(node_id: NodeId) -> Self {
        #[cfg(not(feature = "hardware-atomic"))]
        {
            Self {
                current_value: None,
                current_timestamp: CompactTimestamp::zero(),
                current_node_id: 0,
                node_id,
                _phantom: core::marker::PhantomData,
            }
        }
        #[cfg(feature = "hardware-atomic")]
        {
            Self {
                current_value: UnsafeCell::new(None),
                current_timestamp: AtomicU32::new(0),
                current_node_id: AtomicU8::new(0),
                node_id,
                _phantom: core::marker::PhantomData,
            }
        }
    }
    /// Writes `value` stamped with `timestamp` and this register's `node_id`.
    ///
    /// The write is applied only if it wins under LWW ordering (see
    /// `should_update`); a losing write is silently dropped. Always returns
    /// `Ok(())`.
    #[cfg(not(feature = "hardware-atomic"))]
    pub fn set(&mut self, value: T, timestamp: u64) -> CRDTResult<()> {
        let new_timestamp = CompactTimestamp::new(timestamp);
        if self.should_update(&new_timestamp, self.node_id) {
            self.current_value = Some(value);
            self.current_timestamp = new_timestamp;
            self.current_node_id = self.node_id;
        }
        Ok(())
    }
    /// `hardware-atomic` variant of `set`: takes `&self` and claims the write
    /// with a compare-and-swap on the stored timestamp.
    ///
    /// A losing write is silently dropped. Always returns `Ok(())`.
    #[cfg(feature = "hardware-atomic")]
    pub fn set(&self, value: T, timestamp: u64) -> CRDTResult<()> {
        // NOTE(review): `as u32` silently truncates timestamps above `u32::MAX`,
        // which can invert LWW ordering after the wrap — confirm callers stay in range.
        let new_timestamp_u32 = timestamp as u32;
        loop {
            let current_timestamp = self.current_timestamp.load(Ordering::Relaxed);
            let current_node_id = self.current_node_id.load(Ordering::Relaxed);
            // A stored timestamp of 0 is treated as "empty" here, whereas the
            // non-atomic path checks `current_value.is_none()`.
            let should_update = if current_timestamp == 0 {
                true
            } else if new_timestamp_u32 > current_timestamp {
                true
            } else if new_timestamp_u32 == current_timestamp {
                // Tie-break on node id.
                self.node_id > current_node_id
            } else {
                false
            };
            if !should_update {
                return Ok(());
            }
            // NOTE(review): the CAS compares only the timestamp. A concurrent
            // writer that already replaced the node id at the same timestamp is
            // not detected, so the tie-break above may use a stale node id.
            match self.current_timestamp.compare_exchange_weak(
                current_timestamp,
                new_timestamp_u32,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    // NOTE(review): timestamp, node id and value are published by
                    // three separate Relaxed operations. Readers can observe a
                    // mix of old and new state.
                    self.current_node_id.store(self.node_id, Ordering::Relaxed);
                    // NOTE(review): unsynchronized write through the cell. It can
                    // race with another winning `set` or invalidate a live `&T`
                    // returned by `get`.
                    unsafe {
                        *self.current_value.get() = Some(value);
                    }
                    break;
                }
                Err(_) => {
                    // Timestamp changed (or spurious failure): re-read and retry.
                    continue;
                }
            }
        }
        Ok(())
    }
    /// Returns a reference to the current value, or `None` if nothing has been
    /// written yet.
    pub fn get(&self) -> Option<&T> {
        #[cfg(not(feature = "hardware-atomic"))]
        {
            self.current_value.as_ref()
        }
        #[cfg(feature = "hardware-atomic")]
        {
            // NOTE(review): the returned reference is not protected against a
            // later `set(&self)` overwriting the cell while it is borrowed.
            unsafe { (*self.current_value.get()).as_ref() }
        }
    }
    /// Returns the timestamp of the winning write (zero when empty).
    pub fn timestamp(&self) -> CompactTimestamp {
        #[cfg(not(feature = "hardware-atomic"))]
        {
            self.current_timestamp
        }
        #[cfg(feature = "hardware-atomic")]
        {
            CompactTimestamp::new(self.current_timestamp.load(Ordering::Relaxed) as u64)
        }
    }
    /// Returns the node id of the winning write (`0` when empty).
    pub fn current_node(&self) -> NodeId {
        #[cfg(not(feature = "hardware-atomic"))]
        {
            self.current_node_id
        }
        #[cfg(feature = "hardware-atomic")]
        {
            self.current_node_id.load(Ordering::Relaxed)
        }
    }
    /// Returns `true` if no value has been stored.
    pub fn is_empty(&self) -> bool {
        #[cfg(not(feature = "hardware-atomic"))]
        {
            self.current_value.is_none()
        }
        #[cfg(feature = "hardware-atomic")]
        {
            unsafe { (*self.current_value.get()).is_none() }
        }
    }
    /// Decides whether a write stamped `(new_timestamp, new_node_id)` replaces
    /// the current contents.
    ///
    /// The rule is: always replace when the register is empty; otherwise
    /// replace when the timestamp is newer, with ties going to the strictly
    /// higher node id.
    #[cfg(not(feature = "hardware-atomic"))]
    fn should_update(&self, new_timestamp: &CompactTimestamp, new_node_id: NodeId) -> bool {
        if self.current_value.is_none() {
            return true;
        }
        match new_timestamp.cmp(&self.current_timestamp) {
            core::cmp::Ordering::Greater => true,
            core::cmp::Ordering::Less => false,
            core::cmp::Ordering::Equal => {
                new_node_id > self.current_node_id
            }
        }
    }
}
#[cfg(feature = "serde")]
impl<T, C: MemoryConfig> Serialize for LWWRegister<T, C>
where
    T: Serialize + Clone + PartialEq,
{
    /// Serializes the register as a 4-field struct, in this order:
    /// `current_value`, `current_timestamp` (as `u64`), `current_node_id` and
    /// `node_id`.
    ///
    /// Both storage modes emit the same field names, order and types.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeStruct;
        let mut state = serializer.serialize_struct("LWWRegister", 4)?;
        #[cfg(not(feature = "hardware-atomic"))]
        {
            state.serialize_field("current_value", &self.current_value)?;
            state.serialize_field("current_timestamp", &self.current_timestamp.as_u64())?;
            state.serialize_field("current_node_id", &self.current_node_id)?;
        }
        #[cfg(feature = "hardware-atomic")]
        {
            // SAFETY: reads the cell through a shared reference; relies on the
            // same assumption as `get` (no concurrent `set` on `self`).
            let current_value = unsafe { &*self.current_value.get() };
            let current_timestamp = u64::from(self.current_timestamp.load(Ordering::Relaxed));
            let current_node_id = self.current_node_id.load(Ordering::Relaxed);
            state.serialize_field("current_value", current_value)?;
            // These were mis-encoded as `¤t_…`; they must borrow the locals.
            state.serialize_field("current_timestamp", &current_timestamp)?;
            state.serialize_field("current_node_id", &current_node_id)?;
        }
        state.serialize_field("node_id", &self.node_id)?;
        state.end()
    }
}
#[cfg(feature = "serde")]
impl<'de, T, C: MemoryConfig> Deserialize<'de> for LWWRegister<T, C>
where
    T: Deserialize<'de> + Clone + PartialEq,
{
    /// Deserializes a register written by the `Serialize` impl.
    ///
    /// Accepts two input shapes:
    /// - map-shaped input, from self-describing formats such as JSON;
    /// - sequence-shaped input, from formats such as bincode/postcard that
    ///   encode structs as the fixed-order tuple `current_value`,
    ///   `current_timestamp`, `current_node_id`, `node_id`.
    ///
    /// # Errors
    /// Returns the deserializer's error on missing, duplicate or unknown map
    /// keys, on a sequence shorter than four elements, or on a field of the
    /// wrong type.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        use core::fmt;
        use serde::de::{self, MapAccess, SeqAccess, Visitor};

        // Map keys, matched against the snake_case names written by `Serialize`.
        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        enum Field {
            CurrentValue,
            CurrentTimestamp,
            CurrentNodeId,
            NodeId,
        }

        struct LWWRegisterVisitor<T, C: MemoryConfig> {
            _phantom: core::marker::PhantomData<(T, C)>,
        }

        impl<T, C: MemoryConfig> LWWRegisterVisitor<T, C> {
            // Builds the register from decoded fields in this build's storage layout.
            fn assemble(
                current_value: Option<T>,
                current_timestamp: u64,
                current_node_id: NodeId,
                node_id: NodeId,
            ) -> LWWRegister<T, C> {
                #[cfg(not(feature = "hardware-atomic"))]
                {
                    LWWRegister {
                        current_value,
                        current_timestamp: CompactTimestamp::new(current_timestamp),
                        current_node_id,
                        node_id,
                        _phantom: core::marker::PhantomData,
                    }
                }
                #[cfg(feature = "hardware-atomic")]
                {
                    // NOTE(review): `as u32` silently truncates timestamps above
                    // `u32::MAX`, matching the truncation in `set`.
                    LWWRegister {
                        current_value: UnsafeCell::new(current_value),
                        current_timestamp: AtomicU32::new(current_timestamp as u32),
                        current_node_id: AtomicU8::new(current_node_id),
                        node_id,
                        _phantom: core::marker::PhantomData,
                    }
                }
            }
        }

        impl<'de, T, C: MemoryConfig> Visitor<'de> for LWWRegisterVisitor<T, C>
        where
            T: Deserialize<'de> + Clone + PartialEq,
        {
            type Value = LWWRegister<T, C>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct LWWRegister")
            }

            // Positional input: non-self-describing formats hand structs to
            // `visit_seq` in field-declaration order.
            fn visit_seq<V>(self, mut seq: V) -> Result<LWWRegister<T, C>, V::Error>
            where
                V: SeqAccess<'de>,
            {
                let current_value = seq
                    .next_element::<Option<T>>()?
                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
                let current_timestamp = seq
                    .next_element::<u64>()?
                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
                let current_node_id = seq
                    .next_element::<NodeId>()?
                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
                let node_id = seq
                    .next_element::<NodeId>()?
                    .ok_or_else(|| de::Error::invalid_length(3, &self))?;
                Ok(Self::assemble(
                    current_value,
                    current_timestamp,
                    current_node_id,
                    node_id,
                ))
            }

            fn visit_map<V>(self, mut map: V) -> Result<LWWRegister<T, C>, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut current_value = None;
                let mut current_timestamp = None;
                let mut current_node_id = None;
                let mut node_id = None;
                while let Some(key) = map.next_key::<Field>()? {
                    match key {
                        Field::CurrentValue => {
                            if current_value.is_some() {
                                return Err(de::Error::duplicate_field("current_value"));
                            }
                            current_value = Some(map.next_value::<Option<T>>()?);
                        }
                        Field::CurrentTimestamp => {
                            if current_timestamp.is_some() {
                                return Err(de::Error::duplicate_field("current_timestamp"));
                            }
                            current_timestamp = Some(map.next_value::<u64>()?);
                        }
                        Field::CurrentNodeId => {
                            if current_node_id.is_some() {
                                return Err(de::Error::duplicate_field("current_node_id"));
                            }
                            current_node_id = Some(map.next_value::<NodeId>()?);
                        }
                        Field::NodeId => {
                            if node_id.is_some() {
                                return Err(de::Error::duplicate_field("node_id"));
                            }
                            node_id = Some(map.next_value::<NodeId>()?);
                        }
                    }
                }
                let current_value =
                    current_value.ok_or_else(|| de::Error::missing_field("current_value"))?;
                let current_timestamp = current_timestamp
                    .ok_or_else(|| de::Error::missing_field("current_timestamp"))?;
                let current_node_id =
                    current_node_id.ok_or_else(|| de::Error::missing_field("current_node_id"))?;
                let node_id = node_id.ok_or_else(|| de::Error::missing_field("node_id"))?;
                Ok(Self::assemble(
                    current_value,
                    current_timestamp,
                    current_node_id,
                    node_id,
                ))
            }
        }

        // Field order; must match both `Serialize` and the `visit_seq` order.
        const FIELDS: &[&str] = &[
            "current_value",
            "current_timestamp",
            "current_node_id",
            "node_id",
        ];
        deserializer.deserialize_struct(
            "LWWRegister",
            FIELDS,
            LWWRegisterVisitor {
                _phantom: core::marker::PhantomData,
            },
        )
    }
}
impl<T, C: MemoryConfig> CRDT<C> for LWWRegister<T, C>
where
    T: Clone + PartialEq + core::fmt::Debug,
{
    type Error = CRDTError;

    /// Merges `other` into `self`, keeping the LWW winner.
    ///
    /// The winner is the write with the newer timestamp; on equal timestamps
    /// the higher node id wins. An empty `other` leaves `self` unchanged.
    /// Always returns `Ok(())`.
    fn merge(&mut self, other: &Self) -> CRDTResult<()> {
        #[cfg(not(feature = "hardware-atomic"))]
        {
            if let Some(ref other_value) = other.current_value {
                if self.should_update(&other.current_timestamp, other.current_node_id) {
                    self.current_value = Some(other_value.clone());
                    self.current_timestamp = other.current_timestamp;
                    self.current_node_id = other.current_node_id;
                }
            }
        }
        #[cfg(feature = "hardware-atomic")]
        {
            // SAFETY: `other` is only borrowed shared; this read relies on the
            // same assumption as `get` (no concurrent `set` on `other`).
            let other_value = unsafe { &*other.current_value.get() };
            if let Some(other_value) = other_value {
                let other_timestamp = other.current_timestamp.load(Ordering::Relaxed);
                let other_node_id = other.current_node_id.load(Ordering::Relaxed);
                // `&mut self` guarantees exclusive access to `self`. The fields
                // can be read and written through `get_mut`, with no CAS retry
                // loop and no `unsafe`.
                let current_timestamp = *self.current_timestamp.get_mut();
                let current_node_id = *self.current_node_id.get_mut();
                // Timestamp 0 marks an empty register in this storage mode.
                let should_update = current_timestamp == 0
                    || other_timestamp > current_timestamp
                    || (other_timestamp == current_timestamp && other_node_id > current_node_id);
                if should_update {
                    *self.current_timestamp.get_mut() = other_timestamp;
                    *self.current_node_id.get_mut() = other_node_id;
                    *self.current_value.get_mut() = Some(other_value.clone());
                }
            }
        }
        Ok(())
    }

    /// State equality: value, timestamp and winning node id must all match.
    /// The local `node_id` is not compared.
    fn eq(&self, other: &Self) -> bool {
        #[cfg(not(feature = "hardware-atomic"))]
        {
            self.current_value == other.current_value
                && self.current_timestamp == other.current_timestamp
                && self.current_node_id == other.current_node_id
        }
        #[cfg(feature = "hardware-atomic")]
        {
            unsafe {
                (*self.current_value.get()) == (*other.current_value.get())
                    && self.current_timestamp.load(Ordering::Relaxed)
                        == other.current_timestamp.load(Ordering::Relaxed)
                    && self.current_node_id.load(Ordering::Relaxed)
                        == other.current_node_id.load(Ordering::Relaxed)
            }
        }
    }

    /// Inline size of the register (`size_of::<Self>()`); heap memory owned by
    /// `T`, if any, is not included.
    fn size_bytes(&self) -> usize {
        core::mem::size_of::<Self>()
    }

    /// Checks that both the local node id and the winning node id are below
    /// `C::MAX_NODES`.
    ///
    /// # Errors
    /// Returns [`CRDTError::InvalidNodeId`] if either id is out of range.
    fn validate(&self) -> CRDTResult<()> {
        if self.node_id as usize >= C::MAX_NODES || self.current_node() as usize >= C::MAX_NODES
        {
            return Err(CRDTError::InvalidNodeId);
        }
        Ok(())
    }

    /// Cheap hash over the winning write's timestamp and node id.
    ///
    /// An empty register hashes to 0. The value itself is not hashed.
    fn state_hash(&self) -> u32 {
        let mut hash = 0u32;
        #[cfg(not(feature = "hardware-atomic"))]
        {
            if let Some(ref _value) = self.current_value {
                hash ^= self.current_timestamp.as_u64() as u32;
                hash ^= (self.current_node_id as u32) << 16;
            }
        }
        #[cfg(feature = "hardware-atomic")]
        {
            unsafe {
                if let Some(_value) = &*self.current_value.get() {
                    hash ^= self.current_timestamp.load(Ordering::Relaxed) as u32;
                    hash ^= (self.current_node_id.load(Ordering::Relaxed) as u32) << 16;
                }
            }
        }
        hash
    }

    /// Any two registers can be merged.
    fn can_merge(&self, _other: &Self) -> bool {
        true
    }
}
impl<T, C: MemoryConfig> BoundedCRDT<C> for LWWRegister<T, C>
where
    T: Clone + PartialEq + core::fmt::Debug,
{
    // The register is stored inline, so its bound is simply its own size.
    const MAX_SIZE_BYTES: usize = core::mem::size_of::<Self>();
    // A register holds at most one value.
    const MAX_ELEMENTS: usize = 1;

    /// Returns `size_of::<Self>()`; heap memory owned by `T` is not counted.
    fn memory_usage(&self) -> usize {
        Self::MAX_SIZE_BYTES
    }

    /// Returns 1 if a value is stored, otherwise 0.
    fn element_count(&self) -> usize {
        usize::from(!self.is_empty())
    }

    /// Nothing to reclaim in a single-slot register; always reports 0 bytes.
    fn compact(&mut self) -> CRDTResult<usize> {
        Ok(0)
    }

    /// A new element fits only while the single slot is still empty.
    fn can_add_element(&self) -> bool {
        self.is_empty()
    }
}
impl<T, C: MemoryConfig> RealTimeCRDT<C> for LWWRegister<T, C>
where
T: Clone + PartialEq + core::fmt::Debug,
{
const MAX_MERGE_CYCLES: u32 = 100;
const MAX_VALIDATE_CYCLES: u32 = 50;
const MAX_SERIALIZE_CYCLES: u32 = 75;
fn merge_bounded(&mut self, other: &Self) -> CRDTResult<()> {
self.merge(other)
}
fn validate_bounded(&self) -> CRDTResult<()> {
self.validate()
}
fn remaining_budget(&self) -> Option<u32> {
None
}
fn set_budget(&mut self, _cycles: u32) {
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::DefaultConfig;

    #[test]
    fn test_new_register() {
        let register = LWWRegister::<i32, DefaultConfig>::new(1);
        assert!(register.is_empty());
        assert_eq!(register.get(), None);
        assert_eq!(register.node_id, 1);
    }

    #[test]
    fn test_set_and_get() {
        let mut register = LWWRegister::<i32, DefaultConfig>::new(1);
        assert!(register.set(42, 1000).is_ok());
        assert_eq!(register.get(), Some(&42));
        assert!(!register.is_empty());
        assert_eq!(register.current_node(), 1);
    }

    #[test]
    fn test_lww_semantics() {
        let mut register = LWWRegister::<i32, DefaultConfig>::new(1);
        register.set(10, 1000).unwrap();
        assert_eq!(register.get(), Some(&10));
        register.set(20, 2000).unwrap();
        assert_eq!(register.get(), Some(&20));
        // An older write must not replace the newer value.
        register.set(30, 500).unwrap();
        assert_eq!(register.get(), Some(&20));
    }

    #[test]
    fn test_merge() {
        let mut register1 = LWWRegister::<i32, DefaultConfig>::new(1);
        let mut register2 = LWWRegister::<i32, DefaultConfig>::new(2);
        register1.set(10, 1000).unwrap();
        register2.set(20, 2000).unwrap();
        register1.merge(&register2).unwrap();
        assert_eq!(register1.get(), Some(&20));
        // Merging an older write leaves the newer value in place.
        let mut register3 = LWWRegister::<i32, DefaultConfig>::new(3);
        register3.set(30, 500).unwrap();
        register1.merge(&register3).unwrap();
        assert_eq!(register1.get(), Some(&20));
    }

    #[test]
    fn test_merge_with_empty() {
        let mut empty = LWWRegister::<i32, DefaultConfig>::new(1);
        let mut filled = LWWRegister::<i32, DefaultConfig>::new(2);
        filled.set(7, 1000).unwrap();
        // Merging an empty register is a no-op.
        filled.merge(&empty).unwrap();
        assert_eq!(filled.get(), Some(&7));
        // Merging into an empty register adopts the other's write and stamp.
        empty.merge(&filled).unwrap();
        assert_eq!(empty.get(), Some(&7));
        assert_eq!(empty.current_node(), 2);
        assert_eq!(empty.timestamp().as_u64(), 1000);
    }

    #[test]
    fn test_tiebreaker() {
        let mut register1 = LWWRegister::<i32, DefaultConfig>::new(1);
        let mut register2 = LWWRegister::<i32, DefaultConfig>::new(2);
        register1.set(10, 1000).unwrap();
        register2.set(20, 1000).unwrap();
        register1.merge(&register2).unwrap();
        // Equal timestamps: the higher node id (2) wins.
        assert_eq!(register1.get(), Some(&20));
    }

    #[test]
    fn test_bounded_crdt() {
        let register = LWWRegister::<i32, DefaultConfig>::new(1);
        assert_eq!(register.element_count(), 0);
        assert!(register.memory_usage() > 0);
        assert!(register.can_add_element());
    }

    #[test]
    fn test_validation() {
        let register = LWWRegister::<i32, DefaultConfig>::new(1);
        assert!(register.validate().is_ok());
    }

    #[test]
    fn test_real_time_crdt() {
        let mut register1 = LWWRegister::<i32, DefaultConfig>::new(1);
        let register2 = LWWRegister::<i32, DefaultConfig>::new(2);
        assert!(register1.merge_bounded(&register2).is_ok());
        assert!(register1.validate_bounded().is_ok());
    }

    // NOTE(review): despite the names, these tests do not run a serializer.
    // No format crate is visible here, so they only check state that a
    // round-trip would need to preserve.
    #[cfg(all(test, feature = "serde"))]
    mod serde_tests {
        use super::*;

        #[test]
        fn test_serialize_deserialize() {
            let mut register = LWWRegister::<i32, DefaultConfig>::new(1);
            register.set(42, 1000).unwrap();
            let mut other = LWWRegister::<i32, DefaultConfig>::new(2);
            other.set(100, 2000).unwrap();
            register.merge(&other).unwrap();
            assert_eq!(register.get(), Some(&100));
            assert_eq!(register.current_node(), 2);
            assert_eq!(register.timestamp().as_u64(), 2000);
        }

        #[test]
        fn test_atomic_vs_standard_compatibility() {
            let mut register = LWWRegister::<i32, DefaultConfig>::new(1);
            register.set(42, 1000).unwrap();
            assert_eq!(register.get(), Some(&42));
            assert_eq!(register.current_node(), 1);
            assert_eq!(register.timestamp().as_u64(), 1000);
        }

        #[test]
        fn test_empty_register_serialization() {
            let register = LWWRegister::<i32, DefaultConfig>::new(1);
            assert!(register.is_empty());
            assert_eq!(register.get(), None);
            assert_eq!(register.current_node(), 0);
        }
    }
}