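// nt_memory/coordination/consensus.rs

//! Simple consensus engine (Raft-inspired, but with no leader election or
//! replicated log): agents are registered, proposals are submitted and voted
//! on, and a proposal is decided by weighted majority once enough votes have
//! been cast to meet its quorum.
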
use std::collections::HashMap;
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use uuid::Uuid;

9/// Proposal for consensus
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Proposal {
12    /// Proposal ID
13    pub id: String,
14
15    /// Proposer agent ID
16    pub proposer: String,
17
18    /// Proposal data
19    pub data: serde_json::Value,
20
21    /// Required quorum (0.0 - 1.0)
22    pub quorum: f64,
23}
24
25/// Vote on proposal
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct Vote {
28    /// Proposal ID
29    pub proposal_id: String,
30
31    /// Voter agent ID
32    pub voter: String,
33
34    /// Approve or reject
35    pub approve: bool,
36
37    /// Vote weight (default 1.0)
38    pub weight: f64,
39}
40
/// Outcome of evaluating the votes cast on a proposal.
///
/// Derives `PartialEq` so callers can compare outcomes directly
/// (`==`, `assert_eq!`) instead of pattern-matching.
#[derive(Debug, Clone, PartialEq)]
pub enum ConsensusResult {
    /// Quorum reached and the weighted approval rate is at least 50%.
    Approved,

    /// Quorum reached but the weighted approval rate is below 50%.
    Rejected,

    /// Not enough votes have been cast yet to decide.
    Pending {
        /// Weighted share of approving votes so far (0.0 when no weight has
        /// been cast).
        approval_rate: f64,
        /// Number of additional votes required to reach quorum.
        votes_needed: usize,
    },
}

57/// Proposal state
58#[derive(Debug, Clone)]
59struct ProposalState {
60    proposal: Proposal,
61    votes: Vec<Vote>,
62    created_at: std::time::Instant,
63}
64
65impl ProposalState {
66    fn calculate_result(&self, total_agents: usize) -> ConsensusResult {
67        let total_weight: f64 = self.votes.iter().map(|v| v.weight).sum();
68        let approval_weight: f64 = self
69            .votes
70            .iter()
71            .filter(|v| v.approve)
72            .map(|v| v.weight)
73            .sum();
74
75        let approval_rate = if total_weight > 0.0 {
76            approval_weight / total_weight
77        } else {
78            0.0
79        };
80
81        let votes_received = self.votes.len();
82        let quorum_votes = (total_agents as f64 * self.proposal.quorum).ceil() as usize;
83
84        if votes_received >= quorum_votes {
85            if approval_rate >= 0.5 {
86                ConsensusResult::Approved
87            } else {
88                ConsensusResult::Rejected
89            }
90        } else {
91            ConsensusResult::Pending {
92                approval_rate,
93                votes_needed: quorum_votes - votes_received,
94            }
95        }
96    }
97}
98
99/// Consensus engine
100pub struct ConsensusEngine {
101    /// Active proposals
102    proposals: Arc<RwLock<HashMap<String, ProposalState>>>,
103
104    /// Registered agents
105    agents: Arc<RwLock<HashMap<String, AgentInfo>>>,
106}
107
/// A registered agent as recorded by `ConsensusEngine::register_agent`.
///
/// NOTE(review): neither field is read anywhere in this file — confirm
/// whether `weight` was meant to scale the agent's votes.
#[derive(Clone, Debug)]
struct AgentInfo {
    /// Agent ID (same value as the map key it is stored under).
    id: String,
    /// Weight supplied at registration.
    weight: f64,
}

114impl ConsensusEngine {
115    /// Create new consensus engine
116    pub fn new() -> Self {
117        Self {
118            proposals: Arc::new(RwLock::new(HashMap::new())),
119            agents: Arc::new(RwLock::new(HashMap::new())),
120        }
121    }
122
123    /// Register agent
124    pub async fn register_agent(&self, agent_id: String, weight: f64) {
125        let mut agents = self.agents.write().await;
126        agents.insert(
127            agent_id.clone(),
128            AgentInfo {
129                id: agent_id,
130                weight,
131            },
132        );
133    }
134
135    /// Submit proposal
136    pub async fn submit_proposal(&self, proposal: Proposal) -> String {
137        let id = Uuid::new_v4().to_string();
138
139        let state = ProposalState {
140            proposal: Proposal {
141                id: id.clone(),
142                ..proposal
143            },
144            votes: Vec::new(),
145            created_at: std::time::Instant::now(),
146        };
147
148        let mut proposals = self.proposals.write().await;
149        proposals.insert(id.clone(), state);
150
151        tracing::debug!("Proposal submitted: {}", id);
152
153        id
154    }
155
156    /// Vote on proposal
157    pub async fn vote(&self, vote: Vote) -> anyhow::Result<ConsensusResult> {
158        let mut proposals = self.proposals.write().await;
159
160        let state = proposals
161            .get_mut(&vote.proposal_id)
162            .ok_or_else(|| anyhow::anyhow!("Proposal not found"))?;
163
164        // Check if agent already voted
165        if state.votes.iter().any(|v| v.voter == vote.voter) {
166            return Err(anyhow::anyhow!("Agent already voted"));
167        }
168
169        state.votes.push(vote);
170
171        // Calculate result
172        let agents = self.agents.read().await;
173        let result = state.calculate_result(agents.len());
174
175        Ok(result)
176    }
177
178    /// Get proposal result
179    pub async fn get_result(&self, proposal_id: &str) -> Option<ConsensusResult> {
180        let proposals = self.proposals.read().await;
181        let agents = self.agents.read().await;
182
183        proposals
184            .get(proposal_id)
185            .map(|state| state.calculate_result(agents.len()))
186    }
187
188    /// Get proposal details
189    pub async fn get_proposal(&self, proposal_id: &str) -> Option<Proposal> {
190        let proposals = self.proposals.read().await;
191        proposals.get(proposal_id).map(|s| s.proposal.clone())
192    }
193
194    /// List all proposals
195    pub async fn list_proposals(&self) -> Vec<Proposal> {
196        let proposals = self.proposals.read().await;
197        proposals.values().map(|s| s.proposal.clone()).collect()
198    }
199
200    /// Cleanup old proposals
201    pub async fn cleanup_old(&self, max_age: std::time::Duration) {
202        let mut proposals = self.proposals.write().await;
203        proposals.retain(|_, state| state.created_at.elapsed() < max_age);
204    }
205
206    /// Get agent count
207    pub async fn agent_count(&self) -> usize {
208        self.agents.read().await.len()
209    }
210}
211
212impl Default for ConsensusEngine {
213    fn default() -> Self {
214        Self::new()
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a proposal from `agent1` requiring the given quorum.
    fn proposal(quorum: f64) -> Proposal {
        Proposal {
            id: String::new(),
            proposer: "agent1".to_string(),
            data: serde_json::json!({"action": "test"}),
            quorum,
        }
    }

    /// Builds a weight-1.0 ballot from `voter` on `proposal_id`.
    fn ballot(proposal_id: &str, voter: &str, approve: bool) -> Vote {
        Vote {
            proposal_id: proposal_id.to_string(),
            voter: voter.to_string(),
            approve,
            weight: 1.0,
        }
    }

    /// Registers `agent1..=agentN`, each with weight 1.0.
    async fn engine_with_agents(n: usize) -> ConsensusEngine {
        let engine = ConsensusEngine::new();
        for i in 1..=n {
            engine.register_agent(format!("agent{}", i), 1.0).await;
        }
        engine
    }

    #[tokio::test]
    async fn test_consensus_approval() {
        let engine = engine_with_agents(3).await;

        // Exactly 2/3, so 2 of the 3 agents reach quorum. (A quorum of 0.67
        // would need ceil(3 * 0.67) = ceil(2.01) = 3 votes.)
        let proposal_id = engine.submit_proposal(proposal(2.0 / 3.0)).await;

        let result1 = engine
            .vote(ballot(&proposal_id, "agent1", true))
            .await
            .unwrap();
        assert!(matches!(
            result1,
            ConsensusResult::Pending { votes_needed: 1, .. }
        ));

        let result2 = engine
            .vote(ballot(&proposal_id, "agent2", true))
            .await
            .unwrap();
        assert!(matches!(result2, ConsensusResult::Approved));
    }

    #[tokio::test]
    async fn test_consensus_rejection() {
        let engine = engine_with_agents(3).await;

        // 0.67 of 3 agents rounds up to all 3 votes.
        let proposal_id = engine.submit_proposal(proposal(0.67)).await;

        // Vote (1 approve, 2 reject)
        engine
            .vote(ballot(&proposal_id, "agent1", true))
            .await
            .unwrap();

        let partial = engine
            .vote(ballot(&proposal_id, "agent2", false))
            .await
            .unwrap();
        assert!(matches!(
            partial,
            ConsensusResult::Pending { votes_needed: 1, .. }
        ));

        let result = engine
            .vote(ballot(&proposal_id, "agent3", false))
            .await
            .unwrap();
        assert!(matches!(result, ConsensusResult::Rejected));
    }

    #[tokio::test]
    async fn test_duplicate_vote() {
        let engine = engine_with_agents(1).await;

        let proposal_id = engine.submit_proposal(proposal(0.5)).await;

        // First vote decides the proposal (1 of 1 agents).
        let first = engine
            .vote(ballot(&proposal_id, "agent1", true))
            .await
            .unwrap();
        assert!(matches!(first, ConsensusResult::Approved));

        // Duplicate vote should fail and leave the outcome untouched.
        let result = engine
            .vote(ballot(&proposal_id, "agent1", false))
            .await;
        assert!(result.is_err());
        assert!(matches!(
            engine.get_result(&proposal_id).await,
            Some(ConsensusResult::Approved)
        ));
    }
}