use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A single agent's vote in a [`Consensus`] round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Vote {
    /// Identifier of the voting agent.
    pub agent_id: Uuid,
    /// `true` counts this vote on the agree side, `false` on the disagree side.
    pub agrees: bool,
    /// The agent's confidence in this vote. It is averaged into `Consensus::confidence`
    /// and used as the vote's weight under `ConsensusProtocol::WeightedConfidence`.
    /// No range is enforced here; presumably `0.0..=1.0` (TODO confirm with callers).
    pub confidence: f64,
    /// Optional free-text justification; [`Vote::new`] leaves it `None`.
    pub reasoning: Option<String>,
}
impl Vote {
    /// Creates a vote whose `agent_id` is derived deterministically from the
    /// string `agent_id`.
    ///
    /// The UUID bytes are the first 16 bytes of the SHA-256 digest of the
    /// string's UTF-8 bytes. The same string therefore always maps to the same
    /// [`Uuid`]. The UUID version/variant bits are not adjusted, so the result
    /// does not claim any particular RFC 4122 version.
    ///
    /// `confidence` is stored as given (no clamping), and `reasoning` starts as `None`.
    pub fn new(agent_id: &str, agrees: bool, confidence: f64) -> Self {
        use sha2::{Digest, Sha256};

        let digest = Sha256::digest(agent_id.as_bytes());
        // A SHA-256 digest is 32 bytes; keep only the leading 16 for the UUID.
        let mut id_bytes = [0u8; 16];
        id_bytes.copy_from_slice(&digest[..16]);

        let agent_id = Uuid::from_bytes(id_bytes);
        Self {
            agent_id,
            agrees,
            confidence,
            reasoning: None,
        }
    }
}
/// Rule used by [`Consensus::evaluate`] to turn the collected votes into a decision.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConsensusProtocol {
    /// Agree when more than half of the votes agree. An exact tie does not
    /// reach consensus.
    Majority,
    /// Agree when more than 66% of the votes agree, and disagree when fewer
    /// than 34% agree. Anything in between does not reach consensus.
    SuperMajority,
    /// A verdict only when every vote agrees or every vote disagrees.
    Unanimous,
    /// Votes are weighted by their `confidence`. Agree wins when agreeing votes
    /// carry more than half of the total confidence. A non-positive total
    /// confidence does not reach consensus.
    WeightedConfidence,
}
/// A consensus round: the votes collected so far and the result of the most
/// recent [`Consensus::evaluate`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Consensus {
    /// Rule applied by [`Consensus::evaluate`].
    pub protocol: ConsensusProtocol,
    /// Votes collected for this round ([`Consensus::add_vote`] appends here).
    pub votes: Vec<Vote>,
    /// Whether the last evaluation concluded that consensus was reached;
    /// `false` for a new round.
    pub reached: bool,
    /// Outcome written by `evaluate`: `Some(true)` for agree and `Some(false)` for
    /// disagree. `None` for a new round or when the protocol produced no outcome.
    /// Check `reached` before relying on this value.
    pub decision: Option<bool>,
    /// Mean of the votes' `confidence` values, as computed by `evaluate`;
    /// `0.0` for a new round.
    pub confidence: f64,
}
impl Consensus {
    /// Creates a round for `protocol` with no votes, `reached == false`,
    /// `decision == None` and `confidence == 0.0`.
    pub fn new(protocol: ConsensusProtocol) -> Self {
        Self {
            protocol,
            votes: Vec::new(),
            reached: false,
            decision: None,
            confidence: 0.0,
        }
    }

    /// Appends `vote` to the round.
    ///
    /// The result fields are not updated until [`Consensus::evaluate`] is called.
    pub fn add_vote(&mut self, vote: Vote) {
        self.votes.push(vote);
    }

    /// Recomputes `reached`, `decision` and `confidence` from the current `votes`.
    ///
    /// `decision` is `Some(true)` (agree) or `Some(false)` (disagree) when the
    /// protocol reaches a verdict, and `None` otherwise. `reached` is always
    /// `decision.is_some()`. Per protocol:
    ///
    /// * `Majority`: the side with strictly more votes wins; a tie yields no decision.
    /// * `SuperMajority`: agree if more than 66% of votes agree, disagree if fewer
    ///   than 34% agree, otherwise no decision.
    /// * `Unanimous`: a verdict only when all votes agree or all disagree.
    /// * `WeightedConfidence`: each side is weighted by the sum of its votes'
    ///   `confidence`, and the heavier side wins. An equal split or a non-positive
    ///   (or NaN) total confidence yields no decision.
    ///
    /// `confidence` is the arithmetic mean of every vote's `confidence`,
    /// regardless of protocol. With no votes, all three result fields are reset
    /// to their initial values (`false`, `None`, `0.0`).
    pub fn evaluate(&mut self) {
        // Agree-ratio thresholds for `SuperMajority`; both comparisons are strict,
        // so 2 of 3 (0.666…) agrees while exactly 0.66 does not.
        const SUPER_MAJORITY_AGREE: f64 = 0.66;
        const SUPER_MAJORITY_DISAGREE: f64 = 0.34;

        if self.votes.is_empty() {
            // Reset instead of returning early, so results from an earlier,
            // non-empty vote set (e.g. before `votes` was cleared) are not left behind.
            self.reached = false;
            self.decision = None;
            self.confidence = 0.0;
            return;
        }

        let total = self.votes.len();
        let agrees = self.votes.iter().filter(|v| v.agrees).count();
        let disagrees = total - agrees;

        let decision = match self.protocol {
            // Exact integer comparison; a tie has no decision, consistent with
            // the other protocols' "not reached => None" convention.
            ConsensusProtocol::Majority => {
                if agrees > disagrees {
                    Some(true)
                } else if agrees < disagrees {
                    Some(false)
                } else {
                    None
                }
            }
            ConsensusProtocol::SuperMajority => {
                let agree_ratio = agrees as f64 / total as f64;
                if agree_ratio > SUPER_MAJORITY_AGREE {
                    Some(true)
                } else if agree_ratio < SUPER_MAJORITY_DISAGREE {
                    Some(false)
                } else {
                    None
                }
            }
            ConsensusProtocol::Unanimous => {
                if agrees == total {
                    Some(true)
                } else if agrees == 0 {
                    Some(false)
                } else {
                    None
                }
            }
            ConsensusProtocol::WeightedConfidence => {
                // Sum both sides in a single pass over the votes.
                let (weighted_agree, weighted_disagree) =
                    self.votes
                        .iter()
                        .fold((0.0_f64, 0.0_f64), |(agree, disagree), v| {
                            if v.agrees {
                                (agree + v.confidence, disagree)
                            } else {
                                (agree, disagree + v.confidence)
                            }
                        });
                // `> 0.0` also rejects NaN totals.
                if weighted_agree + weighted_disagree > 0.0 {
                    if weighted_agree > weighted_disagree {
                        Some(true)
                    } else if weighted_agree < weighted_disagree {
                        Some(false)
                    } else {
                        None
                    }
                } else {
                    None
                }
            }
        };

        self.reached = decision.is_some();
        self.decision = decision;
        // `total` is non-zero here because the empty case returned above.
        self.confidence = self.votes.iter().map(|v| v.confidence).sum::<f64>() / total as f64;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vote with a random agent id and no reasoning.
    fn ballot(agrees: bool, confidence: f64) -> Vote {
        Vote {
            agent_id: Uuid::new_v4(),
            agrees,
            confidence,
            reasoning: None,
        }
    }

    /// Adds `ballots` to a fresh round under `protocol`, evaluates it and returns it.
    fn tally(protocol: ConsensusProtocol, ballots: Vec<Vote>) -> Consensus {
        let mut round = Consensus::new(protocol);
        ballots.into_iter().for_each(|b| round.add_vote(b));
        round.evaluate();
        round
    }

    #[test]
    fn test_majority_consensus() {
        let round = tally(
            ConsensusProtocol::Majority,
            vec![ballot(true, 0.9), ballot(true, 0.8), ballot(false, 0.7)],
        );
        assert!(round.reached);
        assert_eq!(round.decision, Some(true));
    }

    #[test]
    fn test_unanimous_fails() {
        let round = tally(
            ConsensusProtocol::Unanimous,
            vec![ballot(true, 0.9), ballot(false, 0.8)],
        );
        assert!(!round.reached);
        assert_eq!(round.decision, None);
    }
}