use std::collections::HashMap;
use sodiumoxide::{self, crypto::box_};
use crate::{
crypto::{encrypt::EncryptKeyPair, ByteObject},
mask::{config::MaskConfig, object::MaskObject},
settings::{MaskSettings, ModelSettings, PetSettings},
state_machine::{
events::{EventPublisher, EventSubscriber},
phases::PhaseName,
},
CoordinatorPublicKey,
};
/// Parameters that describe a round.
///
/// `CoordinatorState::new` builds the initial value from the freshly generated key pair,
/// the `sum`/`update` values of `PetSettings` and an all-zero seed.
// NOTE(review): field order is part of the serde representation and of the `Debug` output;
// do not reorder fields without checking every (de)serializer of this type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RoundParameters {
/// Coordinator public key; `new` sets it to the public half of the generated `EncryptKeyPair`.
pub pk: CoordinatorPublicKey,
/// Copied from `PetSettings::sum` (presumably the sum-task selection probability — confirm).
pub sum: f64,
/// Copied from `PetSettings::update` (presumably the update-task selection probability — confirm).
pub update: f64,
/// Seed of the round; `new` initialises it with `RoundSeed::zeroed()`.
pub seed: RoundSeed,
}
/// State held by the coordinator; constructed by `CoordinatorState::new`.
#[derive(Debug)]
pub struct CoordinatorState {
/// Encryption key pair; `new` generates it and copies its public key into `round_params.pk`.
pub keys: EncryptKeyPair,
/// Parameters of the current round.
pub round_params: RoundParameters,
/// Copied from `PetSettings::min_sum_count`.
pub min_sum_count: usize,
/// Copied from `PetSettings::min_update_count`.
pub min_update_count: usize,
/// Copied from `PetSettings::min_sum_time` (time unit not visible here — confirm).
pub min_sum_time: u64,
/// Copied from `PetSettings::min_update_time` (time unit not visible here — confirm).
pub min_update_time: u64,
/// Copied from `PetSettings::max_sum_time` (time unit not visible here — confirm).
pub max_sum_time: u64,
/// Copied from `PetSettings::max_update_time` (time unit not visible here — confirm).
pub max_update_time: u64,
/// Copied from `PetSettings::expected_participants`.
pub expected_participants: usize,
/// Masking configuration, converted via `From<MaskSettings>` in `new`.
pub mask_config: MaskConfig,
/// Copied from `ModelSettings::size` (presumably the number of model parameters — confirm).
pub model_size: usize,
/// Publishing half of the event channel created by `EventPublisher::init` in `new`.
pub events: EventPublisher,
}
impl CoordinatorState {
pub fn new(
pet_settings: PetSettings,
mask_settings: MaskSettings,
model_settings: ModelSettings,
) -> (Self, EventSubscriber) {
let keys = EncryptKeyPair::generate();
let round_params = RoundParameters {
pk: keys.public,
sum: pet_settings.sum,
update: pet_settings.update,
seed: RoundSeed::zeroed(),
};
let phase = PhaseName::Idle;
let (publisher, subscriber) =
EventPublisher::init(keys.clone(), round_params.clone(), phase);
let coordinator_state = Self {
keys,
round_params,
events: publisher,
min_sum_count: pet_settings.min_sum_count,
min_update_count: pet_settings.min_update_count,
min_sum_time: pet_settings.min_sum_time,
min_update_time: pet_settings.min_update_time,
max_sum_time: pet_settings.max_sum_time,
max_update_time: pet_settings.max_update_time,
expected_participants: pet_settings.expected_participants,
mask_config: mask_settings.into(),
model_size: model_settings.size,
};
(coordinator_state, subscriber)
}
}
/// Seed of a round: a newtype around sodiumoxide's `box_::Seed`.
///
/// `CoordinatorState::new` starts with `RoundSeed::zeroed()`; how later seeds are produced
/// is not visible in this file.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoundSeed(box_::Seed);
impl ByteObject for RoundSeed {
    /// Byte length of a round seed; identical to sodiumoxide's `box_::SEEDBYTES`.
    const LENGTH: usize = box_::SEEDBYTES;

    /// Wraps `bytes` as a round seed.
    ///
    /// Yields `None` whenever `box_::Seed::from_slice` rejects the input (sodiumoxide
    /// accepts only slices of exactly `SEEDBYTES` bytes).
    fn from_slice(bytes: &[u8]) -> Option<Self> {
        let inner = box_::Seed::from_slice(bytes)?;
        Some(Self(inner))
    }

    /// Returns a seed whose bytes are all zero.
    fn zeroed() -> Self {
        let zeros = [0_u8; Self::LENGTH];
        Self(box_::Seed(zeros))
    }

    /// Borrows the raw seed bytes.
    fn as_slice(&self) -> &[u8] {
        let Self(inner) = self;
        inner.as_ref()
    }
}
/// Map from a [`MaskObject`] to a `usize` counter
/// (presumably how many participants submitted that mask — confirm against its users).
pub type MaskDict = HashMap<MaskObject, usize>;