xaynet 0.9.0

The Xayn Network project is building a privacy layer for machine learning so that AI projects can meet compliance requirements such as GDPR and CCPA. The approach relies on Federated Learning as an enabling technology that allows production AI applications to be fully privacy compliant.
Documentation
//! Loading and validation of settings.
//!
//! Values defined in the configuration file can be overridden by environment variables. Examples of
//! configuration files can be found in the `configs/` directory located in the repository root.

use std::{fmt, path::PathBuf};

use config::{Config, ConfigError, Environment};
use serde::de::{self, Deserializer, Visitor};
use thiserror::Error;
use tracing_subscriber::filter::EnvFilter;
use validator::{Validate, ValidationError, ValidationErrors};

use crate::mask::config::{BoundType, DataType, GroupType, MaskConfig, ModelType};

#[derive(Error, Debug)]
/// An error related to loading and validation of settings.
///
/// Each variant wraps the underlying error and includes it in its display message. Thanks to the
/// `#[from]` attributes, both underlying error types convert into this enum, so they can be
/// propagated with the `?` operator.
pub enum SettingsError {
    /// The configuration sources could not be loaded or deserialized into the settings.
    #[error("configuration loading failed: {0}")]
    Loading(#[from] ConfigError),
    /// The settings were loaded but did not pass validation.
    #[error("validation failed: {0}")]
    Validation(#[from] ValidationErrors),
}

#[derive(Debug, Validate, Deserialize)]
/// The combined settings.
///
/// Each section in the configuration file corresponds to the identically named settings field.
///
/// Only the fields marked with `#[validate]` (`api` and `pet`) take part in the nested validation
/// of the combined settings; the remaining fields carry no validation attribute here.
pub struct Settings {
    #[validate]
    /// The REST API settings (`[api]` section).
    pub api: ApiSettings,
    #[validate]
    /// The PET protocol settings (`[pet]` section).
    pub pet: PetSettings,
    /// The masking settings (`[mask]` section).
    pub mask: MaskSettings,
    /// The logging settings (`[log]` section).
    pub log: LoggingSettings,
    /// The model settings (`[model]` section).
    pub model: ModelSettings,
}

impl Settings {
    /// Loads and validates the settings via a configuration file.
    ///
    /// # Errors
    /// Fails when the loading of the configuration file or its validation failed.
    pub fn new(path: PathBuf) -> Result<Self, SettingsError> {
        let settings: Settings = Self::load(path)?;
        settings.validate()?;
        Ok(settings)
    }

    fn load(path: PathBuf) -> Result<Self, ConfigError> {
        let mut config = Config::new();
        config.merge(config::File::from(path))?;
        config.merge(Environment::with_prefix("xaynet").separator("__"))?;
        config.try_into()
    }
}

#[derive(Debug, Validate, Deserialize, Clone, Copy)]
#[validate(schema(function = "validate_pet"))]
/// PET protocol settings.
///
/// In addition to the per-field range checks declared below, the schema-level function
/// `validate_pet` checks the phase time ranges (`min_* <= max_*`) and the `sum` / `update`
/// fractions.
pub struct PetSettings {
    #[validate(range(min = 1))]
    /// The minimal number of participants selected for computing the unmasking sum. The value must
    /// be greater than or equal to `1` (i.e. `min_sum_count >= 1`), otherwise the PET protocol will
    /// be broken.
    ///
    /// This parameter should only be used to enforce security constraints. To control the expected
    /// number of sum participants, the `sum` fraction should be adjusted wrt the total number of
    /// `expected_participants`.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [pet]
    /// min_sum_count = 1
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_PET__MIN_SUM_COUNT=1
    /// ```
    pub min_sum_count: usize,

    #[validate(range(min = 3))]
    /// The minimal number of participants selected for submitting an updated local model for
    /// aggregation. The value must be greater than or equal to `3` (i.e. `min_update_count >= 3`),
    /// otherwise the PET protocol will be broken.
    ///
    /// This parameter should only be used to enforce security constraints. To control the expected
    /// number of update participants, the `update` fraction should be adjusted wrt the total number
    /// of `expected_participants`.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [pet]
    /// min_update_count = 3
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_PET__MIN_UPDATE_COUNT=3
    /// ```
    pub min_update_count: usize,

    /// The minimum amount of time reserved for processing messages in the `sum`
    /// and `sum2` phases, in seconds. Must not exceed `max_sum_time`.
    ///
    /// Defaults to 0 i.e. `sum` and `sum2` phases end *as soon as*
    /// [`min_sum_count`] messages have been processed. Set this higher to allow
    /// for the possibility of more than [`min_sum_count`] messages to be
    /// processed in the `sum` and `sum2` phases.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [pet]
    /// min_sum_time = 5
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_PET__MIN_SUM_TIME=5
    /// ```
    pub min_sum_time: u64,

    /// The minimum amount of time reserved for processing messages in the
    /// `update` phase, in seconds. Must not exceed `max_update_time`.
    ///
    /// Defaults to 0 i.e. `update` phase ends *as soon as* [`min_update_count`]
    /// messages have been processed. Set this higher to allow for the
    /// possibility of more than [`min_update_count`] messages to be processed
    /// in the `update` phase.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [pet]
    /// min_update_time = 10
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_PET__MIN_UPDATE_TIME=10
    /// ```
    pub min_update_time: u64,

    /// The maximum amount of time permitted for processing messages in the `sum`
    /// and `sum2` phases, in seconds. Must not be less than `min_sum_time`.
    ///
    /// Defaults to a large number (effectively 1 week). Set this lower to allow
    /// for the processing of [`min_sum_count`] messages to time-out sooner in
    /// the `sum` and `sum2` phases.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [pet]
    /// max_sum_time = 30
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_PET__MAX_SUM_TIME=30
    /// ```
    pub max_sum_time: u64,

    /// The maximum amount of time permitted for processing messages in the
    /// `update` phase, in seconds. Must not be less than `min_update_time`.
    ///
    /// Defaults to a large number (effectively 1 week). Set this lower to allow
    /// for the processing of [`min_update_count`] messages to time-out sooner
    /// in the `update` phase.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [pet]
    /// max_update_time = 60
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_PET__MAX_UPDATE_TIME=60
    /// ```
    pub max_update_time: u64,

    /// The expected fraction of participants selected for computing the unmasking sum. The value
    /// must be between `0` and `1` (i.e. `0 < sum < 1`).
    ///
    /// Additionally, it is enforced that `0 < sum + update - sum*update < 1` to avoid pathological
    /// cases of deadlocks.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [pet]
    /// sum = 0.01
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_PET__SUM=0.01
    /// ```
    pub sum: f64,

    /// The expected fraction of participants selected for submitting an updated local model for
    /// aggregation. The value must be between `0` and `1` (i.e. `0 < update < 1`).
    ///
    /// Additionally, it is enforced that `0 < sum + update - sum*update < 1` to avoid pathological
    /// cases of deadlocks.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [pet]
    /// update = 0.01
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_PET__UPDATE=0.01
    /// ```
    pub update: f64,

    #[validate(range(min = 1))]
    /// The total number of participants that are expected by the coordinator. The value must be a
    /// positive integer (i.e. `expected_participants >= 1`).
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [pet]
    /// expected_participants = 10
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_PET__EXPECTED_PARTICIPANTS=10
    /// ```
    pub expected_participants: usize,
}

impl Default for PetSettings {
    /// Returns the default PET settings: the smallest permitted participant counts, no minimum
    /// phase durations, a maximum phase duration of one week, and 10 expected participants.
    fn default() -> Self {
        // One week expressed in seconds (604800).
        const ONE_WEEK_IN_SECONDS: u64 = 7 * 24 * 60 * 60;

        Self {
            min_sum_count: 1,
            min_update_count: 3,
            min_sum_time: 0,
            min_update_time: 0,
            max_sum_time: ONE_WEEK_IN_SECONDS,
            max_update_time: ONE_WEEK_IN_SECONDS,
            sum: 0.01,
            update: 0.1,
            expected_participants: 10,
        }
    }
}

/// Runs the schema-level checks of the PET settings.
///
/// The phase time ranges are checked first; the fractions are only checked if the time ranges
/// are valid. The first failing check determines the returned error.
fn validate_pet(s: &PetSettings) -> Result<(), ValidationError> {
    validate_phase_times(s).and_then(|()| validate_fractions(s))
}

/// Ensures that, for both the sum and the update phase, the minimum time does not exceed the
/// maximum time.
fn validate_phase_times(s: &PetSettings) -> Result<(), ValidationError> {
    let sum_range_ok = s.min_sum_time <= s.max_sum_time;
    let update_range_ok = s.min_update_time <= s.max_update_time;

    if !(sum_range_ok && update_range_ok) {
        return Err(ValidationError::new("invalid phase time range(s)"));
    }
    Ok(())
}

/// Guards against pathological cases of deadlocks caused by the selection fractions.
///
/// `sum`, `update` and the combined value `sum + update - sum * update` must each lie strictly
/// between `0` and `1`; otherwise a "starvation" validation error is returned.
fn validate_fractions(s: &PetSettings) -> Result<(), ValidationError> {
    // Strict bounds on both sides; a NaN fails both comparisons and is therefore rejected.
    let in_open_unit_interval = |fraction: f64| 0. < fraction && fraction < 1.;
    let combined = s.sum + s.update - s.sum * s.update;

    if in_open_unit_interval(s.sum)
        && in_open_unit_interval(s.update)
        && in_open_unit_interval(combined)
    {
        Ok(())
    } else {
        Err(ValidationError::new("starvation"))
    }
}

#[derive(Debug, Validate, Deserialize, Clone, Copy)]
/// REST API settings.
///
/// Corresponds to the `[api]` section of the configuration file. No field-level validation
/// rules are declared on this struct.
pub struct ApiSettings {
    /// The address to which the REST API should be bound, given as an IP address and a port.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [api]
    /// bind_address = "0.0.0.0:8081"
    /// # or
    /// bind_address = "127.0.0.1:8081"
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_API__BIND_ADDRESS=127.0.0.1:8081
    /// ```
    pub bind_address: std::net::SocketAddr,
}

#[derive(Debug, Validate, Deserialize, Clone, Copy)]
/// Masking settings.
///
/// Corresponds to the `[mask]` section of the configuration file. The four fields map
/// one-to-one onto the fields of [`MaskConfig`], into which these settings can be converted.
pub struct MaskSettings {
    /// The order of the finite group.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [mask]
    /// group_type = "Integer"
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_MASK__GROUP_TYPE=Integer
    /// ```
    pub group_type: GroupType,

    /// The data type of the numbers to be masked.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [mask]
    /// data_type = "F32"
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_MASK__DATA_TYPE=F32
    /// ```
    pub data_type: DataType,

    /// The bounds of the numbers to be masked.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [mask]
    /// bound_type = "B0"
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_MASK__BOUND_TYPE=B0
    /// ```
    pub bound_type: BoundType,

    /// The maximum number of models to be aggregated.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [mask]
    /// model_type = "M3"
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_MASK__MODEL_TYPE=M3
    /// ```
    pub model_type: ModelType,
}

impl Default for MaskSettings {
    /// Returns the default masking settings: a prime group, `F32` data, `B0` bounds and the `M3`
    /// model type.
    fn default() -> Self {
        let group_type = GroupType::Prime;
        let data_type = DataType::F32;
        let bound_type = BoundType::B0;
        let model_type = ModelType::M3;

        Self { group_type, data_type, bound_type, model_type }
    }
}

impl From<MaskSettings> for MaskConfig {
    fn from(
        MaskSettings {
            group_type,
            data_type,
            bound_type,
            model_type,
        }: MaskSettings,
    ) -> MaskConfig {
        MaskConfig {
            group_type,
            data_type,
            bound_type,
            model_type,
        }
    }
}

#[derive(Debug, Deserialize)]
/// Model settings.
///
/// Corresponds to the `[model]` section of the configuration file. This struct does not derive
/// `Validate`, so no validation rules are attached to it here.
pub struct ModelSettings {
    /// The expected size of the model. The model size corresponds to the number of elements.
    /// This value is used to validate the uniform length of the submitted models/masks.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [model]
    /// size = 100
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_MODEL__SIZE=100
    /// ```
    pub size: usize,
}

#[derive(Debug, Deserialize)]
/// Logging settings.
///
/// Corresponds to the `[log]` section of the configuration file.
pub struct LoggingSettings {
    /// A comma-separated list of logging directives. More information about logging directives
    /// can be found [here].
    ///
    /// The value is given as a string and parsed into an [`EnvFilter`] during deserialization
    /// (via `deserialize_env_filter`); a string that is not a valid filter is rejected with a
    /// deserialization error.
    ///
    /// # Examples
    ///
    /// **TOML**
    /// ```text
    /// [log]
    /// filter = "info"
    /// ```
    ///
    /// **Environment variable**
    /// ```text
    /// XAYNET_LOG__FILTER=info
    /// ```
    ///
    /// [here]: https://docs.rs/tracing-subscriber/0.2.6/tracing_subscriber/filter/struct.EnvFilter.html#directives
    #[serde(deserialize_with = "deserialize_env_filter")]
    pub filter: EnvFilter,
}

/// Deserializes a string of tracing filter directives into an [`EnvFilter`].
///
/// # Errors
/// Fails with an "invalid value" deserialization error if the string cannot be parsed as a
/// filter; the parse error itself is discarded.
fn deserialize_env_filter<'de, D>(deserializer: D) -> Result<EnvFilter, D::Error>
where
    D: Deserializer<'de>,
{
    // Visitor that only accepts strings and parses them as filter directives.
    struct DirectivesVisitor;

    impl<'de> Visitor<'de> for DirectivesVisitor {
        type Value = EnvFilter;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a valid tracing filter directive: https://docs.rs/tracing-subscriber/0.2.6/tracing_subscriber/filter/struct.EnvFilter.html#directives")
        }

        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            match EnvFilter::try_new(value) {
                Ok(filter) => Ok(filter),
                // Report the offending input together with this visitor's expectation.
                Err(_) => Err(E::invalid_value(de::Unexpected::Str(value), &self)),
            }
        }
    }

    deserializer.deserialize_str(DirectivesVisitor)
}