Module xaynet::mask

Masking, aggregation and unmasking of models.

Models

A Model is a collection of weights/parameters which are represented as finite numerical values (i.e. rational numbers) of arbitrary precision. As such, a model in itself is not bound to any particular primitive data type, but it can be created from those and converted back into them.

Currently, the primitive data types f32, f64, i32 and i64 are supported; more types might be supported in the future.

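use xaynet::mask::{FromPrimitives, IntoPrimitives, Model};

// convert primitive weights into a model and back again
let weights = vec![0_f32; 10];
let model = Model::from_primitives_bounded(weights.into_iter());
assert_eq!(
    model.into_primitives_unchecked().collect::<Vec<f32>>(),
    vec![0_f32; 10],
);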
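Representing weights as rational numbers loses nothing when starting from floats, because every finite f32 is itself an exact rational of the form m * 2^e. The standalone sketch below illustrates this with a hypothetical helper decompose_f32, which is not part of xaynet and does not reflect how Model stores its weights internally. NaN and infinities have no rational value and are rejected.

```rust
/// Hypothetical helper (not part of xaynet): splits a finite f32 into an odd
/// integer mantissa m and an exponent e such that x == m * 2^e exactly.
/// NaN and infinities have no rational value, so they yield None.
pub fn decompose_f32(x: f32) -> Option<(i64, i32)> {
    if !x.is_finite() {
        return None;
    }
    if x == 0.0 {
        // +0.0 and -0.0 both represent the rational number 0.
        return Some((0, 0));
    }
    let bits = x.to_bits();
    let sign: i64 = if bits >> 31 == 1 { -1 } else { 1 };
    let exponent_bits = ((bits >> 23) & 0xff) as i32;
    let fraction = i64::from(bits & 0x007f_ffff);
    // Normal numbers have an implicit leading 1 bit; subnormals do not.
    let (mantissa, exponent) = if exponent_bits == 0 {
        (fraction, -149)
    } else {
        (fraction | 0x0080_0000, exponent_bits - 150)
    };
    // Reduce to lowest terms by moving trailing zero bits into the exponent.
    let shift = mantissa.trailing_zeros();
    Some((sign * (mantissa >> shift), exponent + shift as i32))
}
```

For example, 0.75 decomposes into (3, -2), i.e. 3 * 2^-2.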

Masking configurations

Masking, aggregating and unmasking models without losing information requires certain knowledge about the models, which is configured via the MaskConfig. Each masking configuration consists of a group type, a data type, a bound type and a model type. A masking configuration is usually chosen for the specific machine learning use case as part of setting up the XayNet federated learning platform.

Currently, each of these types is restricted to a fixed set of variants, but we aim to generalize this into more flexible masking configurations that allow a more fine-grained tradeoff between representability and performance.

Group type

The GroupType describes the order of the finite group in which the masked model weights are embedded. The smaller the gap between the largest possible embedded weight and the group order, the less information about the masks can theoretically be observed. On the other hand, specific group orders can provide higher performance, so the choice is always a tradeoff between security and performance. The group type variants are:

  • Integer: no gap but potentially slowest performance.
  • Prime: usually small gap with higher performance.
  • Power2: usually higher gap with potentially highest performance.
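To make the gap concrete, the sketch below computes a group order for each variant, assuming the embedded weights lie in 0..=max_embedded and the order must be strictly greater than max_embedded. The function names and the toy value used in the example are hypothetical; xaynet derives its actual group orders from the full masking configuration using arbitrary-precision integers.

```rust
// Hypothetical sketch of the three group order strategies described above.

/// Integer: exactly the smallest order that fits, so there is no gap.
pub fn integer_order(max_embedded: u64) -> u64 {
    max_embedded + 1
}

/// Trial-division primality test; adequate only for small toy values.
fn is_prime(n: u64) -> bool {
    if n < 2 {
        return false;
    }
    let mut d = 2;
    while d * d <= n {
        if n % d == 0 {
            return false;
        }
        d += 1;
    }
    true
}

/// Prime: the smallest prime that fits, usually leaving a small gap.
pub fn prime_order(max_embedded: u64) -> u64 {
    let mut n = max_embedded + 1;
    while !is_prime(n) {
        n += 1;
    }
    n
}

/// Power2: the smallest power of two that fits, usually leaving a larger gap.
pub fn power2_order(max_embedded: u64) -> u64 {
    (max_embedded + 1).next_power_of_two()
}

/// Gap between a group order and the smallest order that would fit.
pub fn gap(order: u64, max_embedded: u64) -> u64 {
    order - (max_embedded + 1)
}
```

For max_embedded = 2000, the Integer order is 2001 (gap 0), the Prime order is 2003 (gap 2) and the Power2 order is 2048 (gap 47).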

Data type

The DataType describes the original primitive data type of the model weights. Together with the bound type, it determines how many decimal places of the model weights are preserved during masking, aggregation and unmasking:

  • F32: 10 decimal places for bounded model weights and 45 decimal places for unbounded.
  • F64: 20 decimal places for bounded model weights and 324 decimal places for unbounded.
  • I32 and I64: 10 decimal places (required for scaled aggregation).

Currently, the primitive data types f32, f64, i32 and i64 are supported via the corresponding data type variants.
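Preserving d decimal places can be pictured as fixed-point scaling: a weight w becomes the integer round(w * 10^d). The sketch below shows this with hypothetical helpers to_fixed and from_fixed, which are not xaynet functions. Because it uses f64 and i64, it is only accurate for small d and small magnitudes; the much larger decimal counts listed for unbounded weights are far beyond what such fixed-width types can hold.

```rust
// Hypothetical sketch: fixed-point scaling with a given number of decimal places.

/// Scales a weight to an integer, keeping `decimals` decimal places.
pub fn to_fixed(weight: f64, decimals: u32) -> i64 {
    (weight * 10f64.powi(decimals as i32)).round() as i64
}

/// Converts a fixed-point integer back into a floating point weight.
pub fn from_fixed(value: i64, decimals: u32) -> f64 {
    value as f64 / 10f64.powi(decimals as i32)
}
```

For example, to_fixed(0.123456789012, 10) yields 1234567890, so every digit beyond the tenth decimal place is dropped.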

Bound type

The BoundType describes the absolute bounds on all model weights. The smaller the bounds on the model weights, the fewer bytes are required to represent the masked model weights. The bounds are enforced on the model weights before masking to prevent information loss during masking, aggregation and unmasking. The bound type variants are:

  • B0: all model weights are absolutely bounded by 1.
  • B2: all model weights are absolutely bounded by 100.
  • B4: all model weights are absolutely bounded by 10,000.
  • B6: all model weights are absolutely bounded by 1,000,000.
  • Bmax: all model weights are absolutely bounded by their primitive data type's absolute maximum value.
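One way to picture enforcing a bound B before masking: clamp each weight to [-B, B], then shift and scale it into the non-negative integer range 0..=2 * B * 10^d, so that it can be embedded into a finite group. The clamping, the shift and the helper names embed and max_embedded are assumptions made for illustration; this is not xaynet's code.

```rust
// Hypothetical sketch: enforce a bound, then map a weight to a non-negative integer.

/// Clamps `weight` to [-bound, bound], then shifts and scales it into the
/// integer range 0..=max_embedded(bound, decimals).
pub fn embed(weight: f64, bound: f64, decimals: u32) -> u64 {
    let clamped = weight.clamp(-bound, bound);
    ((clamped + bound) * 10f64.powi(decimals as i32)).round() as u64
}

/// Largest value `embed` can return for the given bound and decimal places.
pub fn max_embedded(bound: f64, decimals: u32) -> u64 {
    (2.0 * bound * 10f64.powi(decimals as i32)).round() as u64
}
```

A weight of 5.0 under a bound of 1 is clamped to 1.0 and therefore embeds to the same value as 1.0.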

Model type

The ModelType describes the maximum number of masked models that can be aggregated without information loss. The smaller this number, the fewer bytes are required to represent the masked model weights. The model type variants are:

  • M3: at most 1,000 masked models may be aggregated.
  • M6: at most 1,000,000 masked models may be aggregated.
  • M9: at most 1,000,000,000 masked models may be aggregated.
  • M12: at most 1,000,000,000,000 masked models may be aggregated.
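The effect of the bound type and the model type on the representation size can be estimated as follows. The sketch assumes each embedded weight lies in 0..=2 * B * 10^d and that up to M of them are summed, so an aggregated weight must hold values up to 2 * B * 10^d * M. The enums Bound and Models mirror the variants listed above (Bmax is omitted), but they and bytes_per_weight are hypothetical, and xaynet's real byte counts may differ.

```rust
// Hypothetical mirrors of some BoundType and ModelType variants (Bmax omitted).
#[derive(Clone, Copy, Debug)]
pub enum Bound { B0, B2, B4, B6 }

#[derive(Clone, Copy, Debug)]
pub enum Models { M3, M6, M9, M12 }

fn bound_value(bound: Bound) -> u128 {
    match bound {
        Bound::B0 => 1,
        Bound::B2 => 100,
        Bound::B4 => 10_000,
        Bound::B6 => 1_000_000,
    }
}

fn model_count(models: Models) -> u128 {
    match models {
        Models::M3 => 1_000,
        Models::M6 => 1_000_000,
        Models::M9 => 1_000_000_000,
        Models::M12 => 1_000_000_000_000,
    }
}

/// Bytes needed for one aggregated weight, whose value can reach
/// 2 * bound * 10^decimals * models. Returns None if that exceeds u128.
pub fn bytes_per_weight(bound: Bound, decimals: u32, models: Models) -> Option<u32> {
    let max_sum = 10u128
        .checked_pow(decimals)?
        .checked_mul(2 * bound_value(bound))?
        .checked_mul(model_count(models))?;
    // Number of significant bits, rounded up to whole bytes.
    let bits = u128::BITS - max_sum.leading_zeros();
    Some((bits + 7) / 8)
}
```

With B0, 10 decimal places and M3 this gives 6 bytes per weight; switching to M6 or to B2 raises it to 7 bytes. With 45 decimal places the value no longer fits in a u128 and the function returns None.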

Masking, aggregation and unmasking

Local models should be masked (i.e. encrypted) before they are sent anywhere, to protect the potentially sensitive information they have learned from local data. The masking must allow masked models to be aggregated while they are still masked (i.e. homomorphic encryption). The aggregated masked model can then be unmasked without jeopardizing the secrecy of personal information, provided that the model generalizes sufficiently.
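Assuming additive masking, the homomorphic property can be written out explicitly. Here x_i is an embedded weight of the i-th local model, m_i its random mask, q the group order and n the number of models; the notation is ours, not xaynet's:

$$
\tilde{x}_i = (x_i + m_i) \bmod q,
\qquad
\Big( \sum_{i=1}^{n} \tilde{x}_i - \sum_{i=1}^{n} m_i \Big) \bmod q
= \Big( \sum_{i=1}^{n} x_i \Big) \bmod q
= \sum_{i=1}^{n} x_i \quad \text{if } \sum_{i=1}^{n} x_i < q.
$$

The final equality requires the sum of the embedded weights to stay below q, which is what the bound type and the model type are meant to guarantee.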

Masking

A Model can be masked with a Masker, which requires a MaskConfig. During masking, the model weights are scaled, then embedded as elements of the chosen finite group and finally masked with randomly generated elements of that same group. The scalar makes it possible to implement different aggregation strategies, for example federated averaging. The masked model is returned as a MaskObject, together with a MaskSeed from which the mask used for masking can be derived.

use xaynet::mask::{
    BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, Model, ModelType,
};

// create local models and a fitting masking configuration
let number_weights = 10;
let scalar = 0.5;
let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());
let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());
let config = MaskConfig {
    group_type: GroupType::Prime,
    data_type: DataType::F32,
    bound_type: BoundType::B0,
    model_type: ModelType::M3,
};

// mask the local models; each call returns the mask seed and the masked model
let (local_mask_seed_1, masked_local_model_1) = Masker::new(config).mask(scalar, local_model_1);
let (local_mask_seed_2, masked_local_model_2) = Masker::new(config).mask(scalar, local_model_2);

// derive the masks of the local masked models from their seeds
let local_mask_1 = local_mask_seed_1.derive_mask(number_weights, config);
let local_mask_2 = local_mask_seed_2.derive_mask(number_weights, config);
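The sketch below illustrates the idea behind a mask seed: the same seed always yields the same mask, so only the seed has to be kept in order to derive the mask later. It assumes the weights have already been scaled and embedded as integers below the group order. The SplitMix64 generator and the functions derive_mask and mask are hypothetical stand-ins, not xaynet's implementation; SplitMix64 is not cryptographically secure and must not be used for real masking.

```rust
// Hypothetical sketch (not xaynet's API or PRNG). SplitMix64 is NOT a
// cryptographically secure generator; it only demonstrates determinism.

/// One SplitMix64 step: advances `state` and returns the next pseudo-random u64.
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Derives `len` mask elements in 0..order from `seed`. The modulo reduction
/// is slightly biased; a real implementation would sample uniformly.
pub fn derive_mask(seed: u64, len: usize, order: u64) -> Vec<u64> {
    let mut state = seed;
    (0..len).map(|_| splitmix64(&mut state) % order).collect()
}

/// Masks embedded weights (each below `order`) by adding the seed-derived
/// mask elementwise modulo the group order.
pub fn mask(embedded: &[u64], seed: u64, order: u64) -> Vec<u64> {
    let m = derive_mask(seed, embedded.len(), order);
    embedded
        .iter()
        .zip(&m)
        .map(|(&x, &mi)| ((x as u128 + mi as u128) % order as u128) as u64)
        .collect()
}
```

Subtracting the re-derived mask modulo the order recovers the embedded weights, which is what unmasking relies on.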

Aggregation

Masked models can be aggregated via an Aggregation, and so can masks. An aggregated masked model can only be unmasked with the aggregate of the masks of all models that went into it. Each aggregation step should be validated beforehand, so that it can be performed safely with respect to the chosen masking configuration without loss of information.

use xaynet::mask::{Aggregation, MaskObject};

// aggregate the local masks, validating each one first
let mut mask_aggregator = Aggregation::new(config, number_weights);
if mask_aggregator.validate_aggregation(&local_mask_1).is_ok() {
    mask_aggregator.aggregate(local_mask_1);
}
if mask_aggregator.validate_aggregation(&local_mask_2).is_ok() {
    mask_aggregator.aggregate(local_mask_2);
}
let global_mask: MaskObject = mask_aggregator.into();

// aggregate the local masked models, validating each one first
let mut model_aggregator = Aggregation::new(config, number_weights);
if model_aggregator.validate_aggregation(&masked_local_model_1).is_ok() {
    model_aggregator.aggregate(masked_local_model_1);
}
if model_aggregator.validate_aggregation(&masked_local_model_2).is_ok() {
    model_aggregator.aggregate(masked_local_model_2);
}
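Conceptually, an aggregation adds masked models (or masks) elementwise modulo the group order and rejects inputs that would break the masking configuration. The hypothetical Aggregator below is not xaynet's Aggregation type; it checks three assumed conditions before aggregating: matching length, the maximum model count not yet reached, and every element lying inside the group.

```rust
// Hypothetical sketch of a validating aggregator for masks or masked models.

#[derive(Debug, PartialEq)]
pub enum AggError {
    LengthMismatch,
    TooManyModels,
    OutOfGroup,
}

pub struct Aggregator {
    order: u64,
    max_models: usize,
    count: usize,
    sum: Vec<u64>,
}

impl Aggregator {
    pub fn new(order: u64, max_models: usize, len: usize) -> Self {
        Aggregator { order, max_models, count: 0, sum: vec![0; len] }
    }

    /// Checks that `object` can be aggregated without breaking the configuration.
    pub fn validate(&self, object: &[u64]) -> Result<(), AggError> {
        if object.len() != self.sum.len() {
            return Err(AggError::LengthMismatch);
        }
        if self.count >= self.max_models {
            return Err(AggError::TooManyModels);
        }
        if object.iter().any(|&v| v >= self.order) {
            return Err(AggError::OutOfGroup);
        }
        Ok(())
    }

    /// Adds `object` elementwise modulo the group order; call `validate` first.
    pub fn aggregate(&mut self, object: &[u64]) {
        for (s, &v) in self.sum.iter_mut().zip(object) {
            *s = ((*s as u128 + v as u128) % self.order as u128) as u64;
        }
        self.count += 1;
    }

    pub fn count(&self) -> usize {
        self.count
    }

    pub fn sum(&self) -> &[u64] {
        &self.sum
    }
}
```

Keeping validation separate from aggregation mirrors the validate-then-aggregate pattern of the example above.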

Unmasking

A masked model can be unmasked with its corresponding mask via an Aggregation. Unmasking should always be validated beforehand, so that it can be performed safely with respect to the chosen masking configuration without loss of information.

// unmask the aggregated masked model with the aggregated mask
if model_aggregator.validate_unmasking(&global_mask).is_ok() {
    let global_model = model_aggregator.unmask(global_mask);
    // each global weight is 0.5 * 0 + 0.5 * 1 = 0.5
    assert_eq!(
        global_model,
        Model::from_primitives_bounded(vec![0.5_f32; number_weights].into_iter()),
    );
}
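Unmasking subtracts the aggregated mask modulo the group order and then undoes the embedding. The sketch assumes each of the n models contributed round((s * w + B) * 10^d) per weight, where s is its scalar, so an unmasked value v decodes to v / 10^d - n * B. The helpers unmask and decode are hypothetical and not part of xaynet.

```rust
// Hypothetical sketch: remove the aggregated mask, then undo the embedding.

/// Computes (masked_sum - mask_sum) mod order elementwise.
pub fn unmask(masked_sum: &[u64], mask_sum: &[u64], order: u64) -> Vec<u64> {
    let q = order as u128;
    masked_sum
        .iter()
        .zip(mask_sum)
        .map(|(&x, &m)| ((x as u128 % q + q - m as u128 % q) % q) as u64)
        .collect()
}

/// Undoes the embedding of `count` aggregated weights, each of which was
/// embedded as round((scalar * weight + bound) * 10^decimals).
pub fn decode(value: u64, count: u64, bound: f64, decimals: u32) -> f64 {
    value as f64 / 10f64.powi(decimals as i32) - count as f64 * bound
}
```

Replaying the example above with toy parameters (B = 1, d = 2, order 401, scalar 0.5 and masks 300 and 250), the two models embed to 100 and 150, the unmasked sum is 250 and decode returns 0.5, the same global weight as in the xaynet example.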

Structs

Aggregation

An aggregator for masks and masked models.

EncryptedMaskSeed

An encrypted mask seed.

InvalidMaskObjectError

Errors related to invalid mask objects.

MaskConfig

A masking configuration.

MaskConfigBuffer

A buffer for serialized masking configurations.

MaskObject

A mask object which represents either a mask or a masked model.

MaskObjectBuffer

A buffer for serialized mask objects.

MaskSeed

A seed to generate a mask.

Masker

A masker for models.

Model

A numerical representation of a machine learning model.

ModelCastError

Errors related to model conversion into primitives.

PrimitiveCastError

Errors related to model conversion from primitives.

Enums

AggregationError

Errors related to the aggregation of masks and models.

BoundType

The bounds of the numerical values.

DataType

The original primitive data type of the numerical values to be masked.

GroupType

The order of the finite group.

InvalidMaskConfigError

Errors related to invalid masking configurations.

ModelType

The maximum number of models to be aggregated.

UnmaskingError

Errors related to the unmasking of models.

Traits

FromPrimitives

An interface to convert a collection of primitive values into an iterator of numerical values.

IntoPrimitives

An interface to convert a collection of numerical values into an iterator of primitive values.