1use crate::serde_ga3;
4use crate::vector_clock::VectorClock;
5use cliffy_core::GA3;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use uuid::Uuid;
9
/// A replica of a multivector-valued CRDT: a `GA3` state driven by replicated
/// [`GeometricOperation`]s, with a vector clock tracking causality.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeometricCRDT {
    /// Current state, produced by applying operations in the order they are applied.
    #[serde(with = "serde_ga3")]
    pub state: GA3,
    /// Causality clock; ticked when this replica creates an operation and updated
    /// from the timestamp of every operation it applies.
    pub vector_clock: VectorClock,
    /// Identifier of this replica; stamped onto the operations it creates.
    pub node_id: Uuid,
    /// Every operation applied so far, keyed by `GeometricOperation::id`; an
    /// operation whose id is already present is skipped.
    pub operations: HashMap<u64, GeometricOperation>,
}
19
/// A single replicated update to a [`GeometricCRDT`]'s state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeometricOperation {
    /// Key under which replicas record this operation. An operation whose id is
    /// already recorded is ignored, so ids must be unique across all replicas.
    pub id: u64,
    /// Replica that created the operation.
    pub node_id: Uuid,
    /// The creating replica's vector clock, captured right after it ticked for this
    /// operation; used to order operations causally when replicas are merged.
    pub timestamp: VectorClock,
    /// Multivector operand combined with the state according to `operation_type`.
    #[serde(with = "serde_ga3")]
    pub transform: GA3,
    /// How `transform` is combined with the state.
    pub operation_type: OperationType,
}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub enum OperationType {
34 GeometricProduct,
35 Addition,
36 Sandwich,
37 Exponential,
38}
39
40impl GeometricCRDT {
41 pub fn new(node_id: Uuid, initial_state: GA3) -> Self {
43 Self {
44 state: initial_state,
45 vector_clock: VectorClock::new(),
46 node_id,
47 operations: HashMap::new(),
48 }
49 }
50
51 pub fn apply_operation(&mut self, operation: GeometricOperation) {
53 if self.operations.contains_key(&operation.id) {
54 return;
55 }
56
57 self.vector_clock.update(&operation.timestamp);
58 self.operations.insert(operation.id, operation.clone());
59
60 self.state = match operation.operation_type {
61 OperationType::GeometricProduct => self.state.geometric_product(&operation.transform),
62 OperationType::Addition => &self.state + &operation.transform,
63 OperationType::Sandwich => {
64 let rev = operation.transform.reverse();
66 operation
67 .transform
68 .geometric_product(&self.state)
69 .geometric_product(&rev)
70 }
71 OperationType::Exponential => operation.transform.exp().geometric_product(&self.state),
72 };
73 }
74
75 pub fn create_operation(
77 &mut self,
78 transform: GA3,
79 op_type: OperationType,
80 ) -> GeometricOperation {
81 self.vector_clock.tick(self.node_id);
82 let op_id = self.operations.len() as u64;
83
84 GeometricOperation {
85 id: op_id,
86 node_id: self.node_id,
87 timestamp: self.vector_clock.clone(),
88 transform,
89 operation_type: op_type,
90 }
91 }
92
93 pub fn merge(&mut self, other: &GeometricCRDT) -> GeometricCRDT {
95 let merged_clock = self.vector_clock.merge(&other.vector_clock);
96
97 let mut merged_ops = self.operations.clone();
98 for (id, op) in &other.operations {
99 if !merged_ops.contains_key(id) {
100 merged_ops.insert(*id, op.clone());
101 }
102 }
103
104 let mut sorted_ops: Vec<_> = merged_ops.values().cloned().collect();
106 sorted_ops.sort_by(|a, b| {
107 if a.timestamp.happens_before(&b.timestamp) {
108 std::cmp::Ordering::Less
109 } else if b.timestamp.happens_before(&a.timestamp) {
110 std::cmp::Ordering::Greater
111 } else {
112 a.id.cmp(&b.id) }
114 });
115
116 let mut result = GeometricCRDT::new(self.node_id, GA3::zero());
117 result.vector_clock = merged_clock;
118 result.operations = merged_ops;
119
120 for op in sorted_ops {
121 result.apply_operation(op);
122 }
123
124 result
125 }
126
127 pub fn geometric_join(&self, other: &GA3) -> GA3 {
131 let self_norm = self.state.magnitude();
132 let other_norm = other.magnitude();
133
134 if self_norm > other_norm {
135 self.state.clone()
136 } else if other_norm > self_norm {
137 other.clone()
138 } else {
139 geometric_mean(&[self.state.clone(), other.clone()])
141 }
142 }
143}
144
145pub fn geometric_mean(multivectors: &[GA3]) -> GA3 {
147 if multivectors.is_empty() {
148 return GA3::zero();
149 }
150
151 let n = multivectors.len() as f64;
152 let sum_logs: GA3 = multivectors
153 .iter()
154 .map(|mv| mv.exp()) .fold(GA3::zero(), |acc, log_mv| &acc + &log_mv);
156
157 let coeffs: Vec<f64> = sum_logs.as_slice().iter().map(|&c| c / n).collect();
159 GA3::from_slice(&coeffs)
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Applying a single local `Addition` adds the transform to the state.
    #[test]
    fn test_geometric_crdt_basic() {
        let mut replica = GeometricCRDT::new(Uuid::new_v4(), GA3::scalar(1.0));

        let op = replica.create_operation(GA3::scalar(2.0), OperationType::Addition);
        replica.apply_operation(op);

        let expected = 3.0;
        assert!((replica.state.scalar_part() - expected).abs() < 1e-10);
    }

    /// Merging in either direction yields the same scalar part.
    #[test]
    fn test_geometric_crdt_convergence() {
        let start = GA3::scalar(1.0);
        let mut left = GeometricCRDT::new(Uuid::new_v4(), start.clone());
        let mut right = GeometricCRDT::new(Uuid::new_v4(), start);

        let left_op = left.create_operation(GA3::scalar(2.0), OperationType::Addition);
        left.apply_operation(left_op);

        let right_op = right.create_operation(GA3::scalar(3.0), OperationType::Addition);
        right.apply_operation(right_op);

        let left_view = left.merge(&right);
        let right_view = right.merge(&left);

        let gap = (left_view.state.scalar_part() - right_view.state.scalar_part()).abs();
        assert!(gap < 1e-10);
    }

    /// Independent ticks are concurrent; receiving and then ticking orders the clocks.
    #[test]
    fn test_vector_clock_ordering() {
        let mut first = VectorClock::new();
        let mut second = VectorClock::new();

        let (node_a, node_b) = (Uuid::new_v4(), Uuid::new_v4());

        first.tick(node_a);
        second.tick(node_b);
        assert!(first.concurrent(&second));

        first.update(&second);
        first.tick(node_a);
        assert!(second.happens_before(&first));
    }
}