Skip to main content

cliffy_protocols/
consensus.rs

1//! Geometric consensus protocol implementations
2
use crate::{geometric_mean, serde_ga3, GeometricCRDT, OperationType};
use cliffy_core::GA3;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;
10
11/// A message in the consensus protocol
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ConsensusMessage {
14    pub sender_id: Uuid,
15    #[serde(with = "serde_ga3")]
16    pub proposal: GA3,
17    pub round: u64,
18    pub message_type: MessageType,
19}
20
21/// Types of consensus messages
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub enum MessageType {
24    Propose(#[serde(with = "serde_ga3")] GA3),
25    Vote(bool, #[serde(with = "serde_ga3")] GA3),
26    Commit(#[serde(with = "serde_ga3")] GA3),
27    Sync(GeometricCRDT),
28}
29
30/// A geometric consensus protocol implementation
31pub struct GeometricConsensus {
32    node_id: Uuid,
33    current_round: u64,
34    /// Proposals collected per round (for future Byzantine fault tolerance)
35    #[allow(dead_code)]
36    proposals: Arc<RwLock<HashMap<u64, Vec<GA3>>>>,
37    /// Votes collected per round (for future Byzantine fault tolerance)
38    #[allow(dead_code)]
39    votes: Arc<RwLock<HashMap<u64, HashMap<Uuid, bool>>>>,
40    /// Committed states per round (for state recovery)
41    #[allow(dead_code)]
42    committed_states: Arc<RwLock<HashMap<u64, GA3>>>,
43    message_sender: broadcast::Sender<ConsensusMessage>,
44    message_receiver: broadcast::Receiver<ConsensusMessage>,
45    crdt_state: Arc<RwLock<GeometricCRDT>>,
46}
47
48impl GeometricConsensus {
49    /// Create a new consensus protocol instance
50    pub fn new(node_id: Uuid, initial_state: GA3) -> Self {
51        let (sender, receiver) = broadcast::channel(1000);
52        let crdt = GeometricCRDT::new(node_id, initial_state);
53
54        Self {
55            node_id,
56            current_round: 0,
57            proposals: Arc::new(RwLock::new(HashMap::new())),
58            votes: Arc::new(RwLock::new(HashMap::new())),
59            committed_states: Arc::new(RwLock::new(HashMap::new())),
60            message_sender: sender,
61            message_receiver: receiver,
62            crdt_state: Arc::new(RwLock::new(crdt)),
63        }
64    }
65
66    /// Propose a value for consensus
67    pub async fn propose(&mut self, value: GA3) -> Result<(), Box<dyn std::error::Error>> {
68        let round = self.current_round;
69        self.current_round += 1;
70
71        let message = ConsensusMessage {
72            sender_id: self.node_id,
73            proposal: value.clone(),
74            round,
75            message_type: MessageType::Propose(value),
76        };
77
78        self.message_sender.send(message)?;
79        Ok(())
80    }
81
82    /// Compute consensus from a set of proposals using geometric algebra
83    pub async fn geometric_consensus(
84        &self,
85        proposals: &[GA3],
86        threshold: f64,
87    ) -> Result<GA3, Box<dyn std::error::Error>> {
88        if proposals.is_empty() {
89            return Ok(GA3::zero());
90        }
91
92        // Compute geometric mean of all proposals
93        let consensus_value = geometric_mean(proposals);
94
95        // Check if consensus meets threshold (based on geometric distance)
96        let max_distance = proposals
97            .iter()
98            .map(|proposal| {
99                let diff = &consensus_value - proposal;
100                diff.magnitude()
101            })
102            .fold(0.0_f64, |acc, dist| acc.max(dist));
103
104        if max_distance <= threshold {
105            Ok(consensus_value)
106        } else {
107            // No consensus reached, return weighted geometric mean
108            self.weighted_geometric_consensus(proposals).await
109        }
110    }
111
112    /// Compute weighted geometric consensus based on magnitudes
113    async fn weighted_geometric_consensus(
114        &self,
115        proposals: &[GA3],
116    ) -> Result<GA3, Box<dyn std::error::Error>> {
117        // Weight proposals by their geometric magnitude
118        let weights: Vec<f64> = proposals.iter().map(|p| p.magnitude()).collect();
119
120        let total_weight: f64 = weights.iter().sum();
121
122        if total_weight == 0.0 {
123            return Ok(GA3::zero());
124        }
125
126        // Simple weighted average for now
127        let mut result = GA3::zero();
128        for (proposal, weight) in proposals.iter().zip(weights.iter()) {
129            let scaled: Vec<f64> = proposal
130                .as_slice()
131                .iter()
132                .map(|&c| c * weight / total_weight)
133                .collect();
134            let scaled_mv = GA3::from_slice(&scaled);
135            result = &result + &scaled_mv;
136        }
137
138        Ok(result)
139    }
140
141    /// Run a full consensus round
142    pub async fn run_consensus_round(
143        &mut self,
144        proposal: GA3,
145        participants: &[Uuid],
146    ) -> Result<Option<GA3>, Box<dyn std::error::Error>> {
147        let round = self.current_round;
148
149        // Phase 1: Propose
150        self.propose(proposal.clone()).await?;
151
152        // Collect proposals from all participants
153        let mut received_proposals = vec![proposal];
154        let mut proposal_count = 1;
155
156        while proposal_count < participants.len() {
157            if let Ok(message) = self.message_receiver.recv().await {
158                if message.round == round {
159                    if let MessageType::Propose(prop) = message.message_type {
160                        received_proposals.push(prop);
161                        proposal_count += 1;
162                    }
163                }
164            }
165        }
166
167        // Phase 2: Vote
168        let consensus_candidate = self.geometric_consensus(&received_proposals, 0.1).await?;
169        let vote_message = ConsensusMessage {
170            sender_id: self.node_id,
171            proposal: consensus_candidate.clone(),
172            round,
173            message_type: MessageType::Vote(true, consensus_candidate.clone()),
174        };
175
176        self.message_sender.send(vote_message)?;
177
178        // Collect votes
179        let mut votes = HashMap::new();
180        votes.insert(self.node_id, true);
181
182        while votes.len() < participants.len() {
183            if let Ok(message) = self.message_receiver.recv().await {
184                if message.round == round {
185                    if let MessageType::Vote(vote, _) = message.message_type {
186                        votes.insert(message.sender_id, vote);
187                    }
188                }
189            }
190        }
191
192        // Phase 3: Commit if majority votes yes
193        let yes_votes = votes.values().filter(|&&v| v).count();
194        if yes_votes > participants.len() / 2 {
195            let commit_message = ConsensusMessage {
196                sender_id: self.node_id,
197                proposal: consensus_candidate.clone(),
198                round,
199                message_type: MessageType::Commit(consensus_candidate.clone()),
200            };
201
202            self.message_sender.send(commit_message)?;
203
204            // Update CRDT state
205            let mut crdt_guard = self.crdt_state.write().await;
206            let op =
207                crdt_guard.create_operation(consensus_candidate.clone(), OperationType::Addition);
208            crdt_guard.apply_operation(op);
209
210            Ok(Some(consensus_candidate))
211        } else {
212            Ok(None)
213        }
214    }
215
216    /// Sync CRDT state with another node
217    pub async fn sync_crdt_state(
218        &self,
219        _other_node: Uuid,
220    ) -> Result<(), Box<dyn std::error::Error>> {
221        let crdt_guard = self.crdt_state.read().await;
222        let sync_message = ConsensusMessage {
223            sender_id: self.node_id,
224            proposal: crdt_guard.state.clone(),
225            round: self.current_round,
226            message_type: MessageType::Sync(crdt_guard.clone()),
227        };
228
229        self.message_sender.send(sync_message)?;
230        Ok(())
231    }
232
233    /// Handle an incoming sync message
234    pub async fn handle_sync_message(&self, crdt_state: GeometricCRDT) {
235        let mut local_crdt = self.crdt_state.write().await;
236        *local_crdt = local_crdt.merge(&crdt_state);
237    }
238
239    /// Get the current consensus state
240    pub async fn get_current_state(&self) -> GA3 {
241        let crdt_guard = self.crdt_state.read().await;
242        crdt_guard.state.clone()
243    }
244}
245
246/// Compute the lattice join (least upper bound) of two multivectors
247pub fn lattice_join(a: &GA3, b: &GA3) -> GA3 {
248    let a_coeffs = a.as_slice();
249    let b_coeffs = b.as_slice();
250
251    let result_coeffs: Vec<f64> = a_coeffs
252        .iter()
253        .zip(b_coeffs.iter())
254        .map(|(&ac, &bc)| ac.max(bc))
255        .collect();
256
257    GA3::from_slice(&result_coeffs)
258}
259
260/// Compute the lattice meet (greatest lower bound) of two multivectors
261pub fn lattice_meet(a: &GA3, b: &GA3) -> GA3 {
262    let a_coeffs = a.as_slice();
263    let b_coeffs = b.as_slice();
264
265    let result_coeffs: Vec<f64> = a_coeffs
266        .iter()
267        .zip(b_coeffs.iter())
268        .map(|(&ac, &bc)| ac.min(bc))
269        .collect();
270
271    GA3::from_slice(&result_coeffs)
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_geometric_consensus_simple() {
        let proposals = vec![GA3::scalar(1.0), GA3::scalar(2.0), GA3::scalar(4.0)];

        let node_id = Uuid::new_v4();
        let consensus = GeometricConsensus::new(node_id, GA3::zero());

        let result = consensus
            .geometric_consensus(&proposals, 5.0) // Use larger threshold for test
            .await
            .unwrap();

        // Result should be some weighted average
        assert!(result.scalar_part() > 0.0);
    }

    /// An empty proposal set must yield the zero multivector rather than failing.
    #[tokio::test]
    async fn test_geometric_consensus_empty_is_zero() {
        let consensus = GeometricConsensus::new(Uuid::new_v4(), GA3::zero());

        let result = consensus.geometric_consensus(&[], 0.1).await.unwrap();

        assert!(result.as_slice().iter().all(|&c| c == 0.0));
    }

    /// With this node as the only participant, the round needs no peer messages
    /// and commits on its own yes vote.
    #[tokio::test]
    async fn test_single_participant_round_commits() {
        let node_id = Uuid::new_v4();
        let mut consensus = GeometricConsensus::new(node_id, GA3::zero());

        let committed = consensus
            .run_consensus_round(GA3::scalar(3.0), &[node_id])
            .await
            .unwrap();

        assert!(committed.is_some());
    }

    #[test]
    fn test_lattice_operations() {
        let a = GA3::from_slice(&[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let b = GA3::from_slice(&[2.0, 1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let join = lattice_join(&a, &b);
        let meet = lattice_meet(&a, &b);

        assert!((join.scalar_part() - 2.0).abs() < 1e-10);
        assert!((join.get(1) - 2.0).abs() < 1e-10);
        assert!((join.get(2) - 4.0).abs() < 1e-10);

        assert!((meet.scalar_part() - 1.0).abs() < 1e-10);
        assert!((meet.get(1) - 1.0).abs() < 1e-10);
        assert!((meet.get(2) - 3.0).abs() < 1e-10);
    }
}