use crate::{geometric_mean, serde_ga3, GeometricCRDT, OperationType};
use cliffy_core::GA3;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;
10
/// A single message exchanged between consensus nodes over the broadcast channel.
///
/// Every message carries a `proposal` multivector next to its `message_type`.
/// Within this module `proposal` always mirrors the payload being sent:
/// - the proposed value for `Propose`;
/// - the consensus candidate for `Vote` and `Commit`;
/// - the CRDT's current state for `Sync`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsensusMessage {
    /// Identifier of the node that sent this message.
    pub sender_id: Uuid,
    /// Multivector associated with this message (see the type-level docs).
    // GA3 is (de)serialized through the crate's `serde_ga3` helper module —
    // presumably because GA3 has no serde impls of its own; TODO confirm.
    #[serde(with = "serde_ga3")]
    pub proposal: GA3,
    /// Consensus round this message belongs to. `run_consensus_round` ignores
    /// messages whose round differs from the one it is running.
    pub round: u64,
    /// Protocol phase and phase-specific payload.
    pub message_type: MessageType,
}
20
/// Phase of the consensus protocol a [`ConsensusMessage`] belongs to, with its payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageType {
    /// A node's proposed value for the round.
    Propose(#[serde(with = "serde_ga3")] GA3),
    /// A vote on a consensus candidate: the `bool` is the approval flag and the
    /// multivector is the candidate being voted on. Nodes in this module always
    /// vote `true` for the candidate they computed themselves.
    Vote(bool, #[serde(with = "serde_ga3")] GA3),
    /// Announces the candidate a node committed after seeing a majority of yes votes.
    Commit(#[serde(with = "serde_ga3")] GA3),
    /// Full snapshot of the sender's CRDT.
    // NOTE(review): presumably meant to be fed to `handle_sync_message` on
    // receipt; nothing in this file dispatches it there automatically — confirm.
    Sync(GeometricCRDT),
}
29
/// One participant in a geometric-algebra consensus protocol.
///
/// A node broadcasts proposals, votes and commits on a tokio broadcast channel.
/// Committed values are applied to a locally owned [`GeometricCRDT`].
pub struct GeometricConsensus {
    /// This node's identifier; stamped as `sender_id` on every message it sends.
    node_id: Uuid,
    /// Round number the next `propose` call will use; `propose` increments it.
    current_round: u64,
    // Initialised in `new` but never read or written elsewhere in this file;
    // `dead_code` is allowed so the field does not trigger an unused-field warning.
    // NOTE(review): presumably reserved for per-round proposal bookkeeping — confirm or remove.
    #[allow(dead_code)]
    proposals: Arc<RwLock<HashMap<u64, Vec<GA3>>>>,
    // Same situation as `proposals`: constructed but unused in this file.
    // NOTE(review): presumably per-round votes keyed by voter id — confirm or remove.
    #[allow(dead_code)]
    votes: Arc<RwLock<HashMap<u64, HashMap<Uuid, bool>>>>,
    // Same situation as `proposals`: constructed but unused in this file.
    // NOTE(review): presumably committed value per round — confirm or remove.
    #[allow(dead_code)]
    committed_states: Arc<RwLock<HashMap<u64, GA3>>>,
    /// Sending half of the broadcast channel created in `new`.
    message_sender: broadcast::Sender<ConsensusMessage>,
    /// Receiver created together with `message_sender`. It therefore also
    /// observes every message this node broadcasts itself.
    message_receiver: broadcast::Receiver<ConsensusMessage>,
    /// Replicated state; committed consensus values are applied to it as
    /// `OperationType::Addition` operations, and sync messages are merged into it.
    crdt_state: Arc<RwLock<GeometricCRDT>>,
}
47
48impl GeometricConsensus {
49 pub fn new(node_id: Uuid, initial_state: GA3) -> Self {
51 let (sender, receiver) = broadcast::channel(1000);
52 let crdt = GeometricCRDT::new(node_id, initial_state);
53
54 Self {
55 node_id,
56 current_round: 0,
57 proposals: Arc::new(RwLock::new(HashMap::new())),
58 votes: Arc::new(RwLock::new(HashMap::new())),
59 committed_states: Arc::new(RwLock::new(HashMap::new())),
60 message_sender: sender,
61 message_receiver: receiver,
62 crdt_state: Arc::new(RwLock::new(crdt)),
63 }
64 }
65
66 pub async fn propose(&mut self, value: GA3) -> Result<(), Box<dyn std::error::Error>> {
68 let round = self.current_round;
69 self.current_round += 1;
70
71 let message = ConsensusMessage {
72 sender_id: self.node_id,
73 proposal: value.clone(),
74 round,
75 message_type: MessageType::Propose(value),
76 };
77
78 self.message_sender.send(message)?;
79 Ok(())
80 }
81
82 pub async fn geometric_consensus(
84 &self,
85 proposals: &[GA3],
86 threshold: f64,
87 ) -> Result<GA3, Box<dyn std::error::Error>> {
88 if proposals.is_empty() {
89 return Ok(GA3::zero());
90 }
91
92 let consensus_value = geometric_mean(proposals);
94
95 let max_distance = proposals
97 .iter()
98 .map(|proposal| {
99 let diff = &consensus_value - proposal;
100 diff.magnitude()
101 })
102 .fold(0.0_f64, |acc, dist| acc.max(dist));
103
104 if max_distance <= threshold {
105 Ok(consensus_value)
106 } else {
107 self.weighted_geometric_consensus(proposals).await
109 }
110 }
111
112 async fn weighted_geometric_consensus(
114 &self,
115 proposals: &[GA3],
116 ) -> Result<GA3, Box<dyn std::error::Error>> {
117 let weights: Vec<f64> = proposals.iter().map(|p| p.magnitude()).collect();
119
120 let total_weight: f64 = weights.iter().sum();
121
122 if total_weight == 0.0 {
123 return Ok(GA3::zero());
124 }
125
126 let mut result = GA3::zero();
128 for (proposal, weight) in proposals.iter().zip(weights.iter()) {
129 let scaled: Vec<f64> = proposal
130 .as_slice()
131 .iter()
132 .map(|&c| c * weight / total_weight)
133 .collect();
134 let scaled_mv = GA3::from_slice(&scaled);
135 result = &result + &scaled_mv;
136 }
137
138 Ok(result)
139 }
140
141 pub async fn run_consensus_round(
143 &mut self,
144 proposal: GA3,
145 participants: &[Uuid],
146 ) -> Result<Option<GA3>, Box<dyn std::error::Error>> {
147 let round = self.current_round;
148
149 self.propose(proposal.clone()).await?;
151
152 let mut received_proposals = vec![proposal];
154 let mut proposal_count = 1;
155
156 while proposal_count < participants.len() {
157 if let Ok(message) = self.message_receiver.recv().await {
158 if message.round == round {
159 if let MessageType::Propose(prop) = message.message_type {
160 received_proposals.push(prop);
161 proposal_count += 1;
162 }
163 }
164 }
165 }
166
167 let consensus_candidate = self.geometric_consensus(&received_proposals, 0.1).await?;
169 let vote_message = ConsensusMessage {
170 sender_id: self.node_id,
171 proposal: consensus_candidate.clone(),
172 round,
173 message_type: MessageType::Vote(true, consensus_candidate.clone()),
174 };
175
176 self.message_sender.send(vote_message)?;
177
178 let mut votes = HashMap::new();
180 votes.insert(self.node_id, true);
181
182 while votes.len() < participants.len() {
183 if let Ok(message) = self.message_receiver.recv().await {
184 if message.round == round {
185 if let MessageType::Vote(vote, _) = message.message_type {
186 votes.insert(message.sender_id, vote);
187 }
188 }
189 }
190 }
191
192 let yes_votes = votes.values().filter(|&&v| v).count();
194 if yes_votes > participants.len() / 2 {
195 let commit_message = ConsensusMessage {
196 sender_id: self.node_id,
197 proposal: consensus_candidate.clone(),
198 round,
199 message_type: MessageType::Commit(consensus_candidate.clone()),
200 };
201
202 self.message_sender.send(commit_message)?;
203
204 let mut crdt_guard = self.crdt_state.write().await;
206 let op =
207 crdt_guard.create_operation(consensus_candidate.clone(), OperationType::Addition);
208 crdt_guard.apply_operation(op);
209
210 Ok(Some(consensus_candidate))
211 } else {
212 Ok(None)
213 }
214 }
215
216 pub async fn sync_crdt_state(
218 &self,
219 _other_node: Uuid,
220 ) -> Result<(), Box<dyn std::error::Error>> {
221 let crdt_guard = self.crdt_state.read().await;
222 let sync_message = ConsensusMessage {
223 sender_id: self.node_id,
224 proposal: crdt_guard.state.clone(),
225 round: self.current_round,
226 message_type: MessageType::Sync(crdt_guard.clone()),
227 };
228
229 self.message_sender.send(sync_message)?;
230 Ok(())
231 }
232
233 pub async fn handle_sync_message(&self, crdt_state: GeometricCRDT) {
235 let mut local_crdt = self.crdt_state.write().await;
236 *local_crdt = local_crdt.merge(&crdt_state);
237 }
238
239 pub async fn get_current_state(&self) -> GA3 {
241 let crdt_guard = self.crdt_state.read().await;
242 crdt_guard.state.clone()
243 }
244}
245
246pub fn lattice_join(a: &GA3, b: &GA3) -> GA3 {
248 let a_coeffs = a.as_slice();
249 let b_coeffs = b.as_slice();
250
251 let result_coeffs: Vec<f64> = a_coeffs
252 .iter()
253 .zip(b_coeffs.iter())
254 .map(|(&ac, &bc)| ac.max(bc))
255 .collect();
256
257 GA3::from_slice(&result_coeffs)
258}
259
260pub fn lattice_meet(a: &GA3, b: &GA3) -> GA3 {
262 let a_coeffs = a.as_slice();
263 let b_coeffs = b.as_slice();
264
265 let result_coeffs: Vec<f64> = a_coeffs
266 .iter()
267 .zip(b_coeffs.iter())
268 .map(|(&ac, &bc)| ac.min(bc))
269 .collect();
270
271 GA3::from_slice(&result_coeffs)
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-10;

    /// True when `actual` is within [`EPS`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[tokio::test]
    async fn test_geometric_consensus_simple() {
        let inputs = [GA3::scalar(1.0), GA3::scalar(2.0), GA3::scalar(4.0)];
        let node = GeometricConsensus::new(Uuid::new_v4(), GA3::zero());

        // A generous threshold: the plain geometric mean should be accepted.
        let agreed = node.geometric_consensus(&inputs, 5.0).await.unwrap();

        assert!(agreed.scalar_part() > 0.0);
    }

    #[test]
    fn test_lattice_operations() {
        let lhs = GA3::from_slice(&[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let rhs = GA3::from_slice(&[2.0, 1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let upper = lattice_join(&lhs, &rhs);
        assert!(close(upper.scalar_part(), 2.0));
        assert!(close(upper.get(1), 2.0));
        assert!(close(upper.get(2), 4.0));

        let lower = lattice_meet(&lhs, &rhs);
        assert!(close(lower.scalar_part(), 1.0));
        assert!(close(lower.get(1), 1.0));
        assert!(close(lower.get(2), 3.0));
    }
}