use std::{collections::HashMap, mem, ops::Index};
use elliptic_curve::Field;
use serde::Serialize;
use crate::{compat::CSCurve, protocol::Participant};
/// A duplicate-free set of participants, stored in ascending order, with a
/// hash-map lookup from each participant to its position.
#[derive(Clone, Debug, Serialize)]
pub struct ParticipantList {
    /// The participants, sorted ascending (see `new_vec`).
    participants: Vec<Participant>,
    /// Maps each participant to its position in `participants`.
    ///
    /// Not serialized; only `participants` appears in the serialized form.
    #[serde(skip_serializing)]
    indices: HashMap<Participant, usize>,
}
impl ParticipantList {
    /// Builds a list from an owned vector of participants.
    ///
    /// The vector is sorted in place, so the resulting list is in ascending
    /// order regardless of input order. Returns `None` if any participant
    /// occurs more than once.
    fn new_vec(mut participants: Vec<Participant>) -> Option<Self> {
        participants.sort();
        let indices: HashMap<_, _> = participants
            .iter()
            .enumerate()
            .map(|(i, &p)| (p, i))
            .collect();
        // Duplicates collapse onto a single map key, leaving the map smaller
        // than the vector.
        if indices.len() < participants.len() {
            return None;
        }
        Some(Self {
            participants,
            indices,
        })
    }

    /// Creates a list from a slice of participants.
    ///
    /// Returns `None` if the slice contains a duplicate. Input order does not
    /// matter; the list stores participants in ascending order.
    pub fn new(participants: &[Participant]) -> Option<Self> {
        Self::new_vec(participants.to_owned())
    }

    /// Returns the number of participants in the list.
    pub fn len(&self) -> usize {
        self.participants.len()
    }

    /// Returns `true` if the list contains no participants.
    pub fn is_empty(&self) -> bool {
        self.participants.is_empty()
    }

    /// Returns `true` if `participant` is a member of the list.
    pub fn contains(&self, participant: Participant) -> bool {
        self.indices.contains_key(&participant)
    }

    /// Iterates, in ascending order, over every participant except `me`.
    pub fn others(&self, me: Participant) -> impl Iterator<Item = Participant> + '_ {
        self.participants.iter().filter(move |x| **x != me).copied()
    }

    /// Returns the position of `participant` in the sorted list.
    ///
    /// # Panics
    ///
    /// Panics if `participant` is not a member of the list.
    pub fn index(&self, participant: Participant) -> usize {
        self.indices[&participant]
    }

    /// Computes the Lagrange coefficient of `p` evaluated at zero over this set.
    ///
    /// This is the product, over every other member `q`, of `q / (q - p)`,
    /// where `q` and `p` are the scalars produced by `Participant::scalar`.
    ///
    /// # Panics
    ///
    /// Panics if the denominator is zero, i.e. if some other member maps to
    /// the same scalar as `p`.
    pub fn lagrange<C: CSCurve>(&self, p: Participant) -> C::Scalar {
        let p_scalar = p.scalar::<C>();
        let mut top = C::Scalar::ONE;
        let mut bot = C::Scalar::ONE;
        for q in self.others(p) {
            let q_scalar = q.scalar::<C>();
            top *= q_scalar;
            bot *= q_scalar - p_scalar;
        }
        // NOTE(review): `bot` is nonzero as long as `Participant::scalar` is
        // injective on this list — confirm against its definition.
        top * bot.invert().unwrap()
    }

    /// Returns the participants present in both `self` and `others`, in
    /// ascending order.
    pub fn intersection(&self, others: &ParticipantList) -> Self {
        let common: Vec<Participant> = self
            .participants
            .iter()
            .copied()
            .filter(|&p| others.contains(p))
            .collect();
        // `common` is a subsequence of an already sorted, duplicate-free list,
        // so `new_vec` can never reject it.
        Self::new_vec(common)
            .expect("a subset of a duplicate-free list cannot contain duplicates")
    }
}
impl From<ParticipantList> for Vec<Participant> {
fn from(val: ParticipantList) -> Self {
val.participants
}
}
/// A map from each participant of a [`ParticipantList`] to an optional value
/// of type `T`, stored densely by the participant's position in the list.
///
/// Only `data` is serialized.
#[derive(Debug, Clone, Serialize)]
pub struct ParticipantMap<'a, T> {
    /// The participants this map is keyed by.
    #[serde(skip_serializing)]
    participants: &'a ParticipantList,
    /// One slot per participant, in list order; `None` until a value is put.
    data: Vec<Option<T>>,
    /// Number of slots in `data` that hold a value.
    #[serde(skip_serializing)]
    count: usize,
}
impl<'a, T> ParticipantMap<'a, T> {
    /// Creates an empty map with one vacant slot per participant in
    /// `participants`.
    pub fn new(participants: &'a ParticipantList) -> Self {
        // `vec![None; n]` would require `T: Clone`, so build the slots from an
        // iterator instead.
        let data = (0..participants.len()).map(|_| None).collect();
        Self {
            participants,
            data,
            count: 0,
        }
    }

    /// Returns `true` once every participant has a value.
    pub fn full(&self) -> bool {
        self.count == self.data.len()
    }

    /// Stores `data` for `participant`.
    ///
    /// The call is silently ignored if `participant` is not in the list, or if
    /// a value was already stored for it (the first value wins).
    pub fn put(&mut self, participant: Participant, data: T) {
        let i = match self.participants.indices.get(&participant) {
            None => return,
            Some(&i) => i,
        };
        let slot = &mut self.data[i];
        if slot.is_none() {
            *slot = Some(data);
            self.count += 1;
        }
    }
}
impl<'a, T> Index<Participant> for ParticipantMap<'a, T> {
    type Output = T;

    /// Returns the value stored for participant `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not in the underlying participant list, or if no
    /// value has been put for it yet.
    fn index(&self, index: Participant) -> &Self::Output {
        let slot = self.participants.index(index);
        self.data[slot]
            .as_ref()
            .expect("ParticipantMap: no value has been put for this participant")
    }
}
/// Tracks which participants of a [`ParticipantList`] have been seen at least
/// once.
#[derive(Debug, Clone)]
pub struct ParticipantCounter<'a> {
    /// The participants being tracked.
    participants: &'a ParticipantList,
    /// `seen[i]` is `true` once the participant at position `i` has been put.
    seen: Vec<bool>,
    /// How many participants have not been seen yet.
    counter: usize,
}
impl<'a> ParticipantCounter<'a> {
    /// Starts a counter in which no member of `participants` has been seen.
    pub fn new(participants: &'a ParticipantList) -> Self {
        let remaining = participants.len();
        Self {
            participants,
            seen: vec![false; remaining],
            counter: remaining,
        }
    }

    /// Marks `participant` as seen.
    ///
    /// Returns `true` only the first time a member of the list is put; returns
    /// `false` for repeats and for participants outside the list.
    pub fn put(&mut self, participant: Participant) -> bool {
        if let Some(&slot) = self.participants.indices.get(&participant) {
            let already_seen = mem::replace(&mut self.seen[slot], true);
            if !already_seen {
                self.counter -= 1;
            }
            !already_seen
        } else {
            false
        }
    }

    /// Forgets every participant seen so far.
    pub fn clear(&mut self) {
        self.seen.fill(false);
        self.counter = self.participants.len();
    }

    /// Returns `true` once every participant has been seen.
    pub fn full(&self) -> bool {
        self.counter == 0
    }
}