use super::*;
use crate::computation::*;
use crate::error::Result;
use crate::execution::Session;
use crate::host::{HostFixedTensor, HostPlacement};
use crate::kernels::*;
use crate::replicated::{RepFixedTensor, RepTensor, ReplicatedPlacement};
impl MirrorOp {
    /// Mirrors a single host value onto a three-party mirrored placement.
    ///
    /// Two clones of `x` plus `x` itself become the three slots of the
    /// resulting [`Mir3Tensor`], which is then placed on `plc`.
    pub(crate) fn kernel<S: Session, HostT>(
        sess: &S,
        plc: &Mirrored3Placement,
        x: HostT,
    ) -> Result<Mir3Tensor<HostT>>
    where
        HostT: Clone,
        Mirrored3Placement: PlacementPlace<S, Mir3Tensor<HostT>>,
    {
        let copy_a = x.clone();
        let copy_b = x.clone();
        let replicated = Mir3Tensor {
            values: [copy_a, copy_b, x],
        };
        let placed = plc.place(sess, replicated);
        Ok(placed)
    }

    /// Mirrors a host fixed-point tensor onto a three-party mirrored placement.
    ///
    /// Only the underlying ring tensor is mirrored (via `PlacementMirror`);
    /// the fractional and integral precisions are copied unchanged.
    pub(crate) fn fixed_kernel<S: Session, HostRingT, MirRingT>(
        sess: &S,
        plc: &Mirrored3Placement,
        x: HostFixedTensor<HostRingT>,
    ) -> Result<MirFixedTensor<MirRingT>>
    where
        Mirrored3Placement: PlacementMirror<S, HostRingT, MirRingT>,
    {
        let mirrored_ring = plc.mirror(sess, &x.tensor);
        let result = MirFixedTensor {
            integral_precision: x.integral_precision,
            fractional_precision: x.fractional_precision,
            tensor: mirrored_ring,
        };
        Ok(result)
    }
}
impl DemirrorOp {
    /// Collapses a mirrored tensor down to a single value on `receiver`.
    ///
    /// If `receiver` is the first or second host of the tensor's mirrored
    /// placement, that host's copy is returned as-is. Otherwise the third
    /// copy is explicitly placed on `receiver` and returned.
    ///
    /// # Errors
    /// Propagates any error from determining the placement of `x`.
    pub(crate) fn kernel<S: Session, R: Clone>(
        sess: &S,
        receiver: &HostPlacement,
        x: Mir3Tensor<R>,
    ) -> Result<R>
    where
        HostPlacement: PlacementPlace<S, R>,
        R: Placed<Placement = HostPlacement>,
    {
        let mirrored_plc = x.placement()?;
        let Mir3Tensor {
            values: [first, second, third],
        } = x;
        let (host_a, host_b, _host_c) = mirrored_plc.host_placements();

        let chosen = if receiver == &host_a {
            first
        } else if receiver == &host_b {
            second
        } else {
            // Fall back to the third copy; placing it on `receiver` covers
            // both the case where `receiver` is the third host and the case
            // where it is not part of the mirrored placement at all.
            receiver.place(sess, third)
        };
        Ok(chosen)
    }

    /// Demirrors a mirrored fixed-point tensor onto `receiver`.
    ///
    /// The underlying ring tensor is demirrored via `PlacementDemirror`; the
    /// fractional and integral precisions are carried over unchanged.
    pub(crate) fn fixed_kernel<S: Session, MirRingT, HostRingT>(
        sess: &S,
        receiver: &HostPlacement,
        x: MirFixedTensor<MirRingT>,
    ) -> Result<HostFixedTensor<HostRingT>>
    where
        HostPlacement: PlacementDemirror<S, MirRingT, HostRingT>,
    {
        Ok(HostFixedTensor {
            tensor: receiver.demirror(sess, &x.tensor),
            fractional_precision: x.fractional_precision,
            integral_precision: x.integral_precision,
        })
    }
}
impl RingFixedpointEncodeOp {
    /// Fixed-point encodes each slot of a mirrored float tensor into a ring tensor.
    ///
    /// Every host of `plc` encodes its own copy with the same
    /// `scaling_base`/`scaling_exp`; the outputs keep the original slot order.
    pub(crate) fn mir_kernel<S: Session, HostFloatT, HostRingT>(
        sess: &S,
        plc: &Mirrored3Placement,
        scaling_base: u64,
        scaling_exp: u32,
        x: Mir3Tensor<HostFloatT>,
    ) -> Result<Mir3Tensor<HostRingT>>
    where
        HostPlacement: PlacementRingFixedpointEncode<S, HostFloatT, HostRingT>,
    {
        let (host_a, host_b, host_c) = plc.host_placements();
        let [in_a, in_b, in_c] = &x.values;
        // Array elements are evaluated left to right, so the encode calls are
        // issued in player order.
        let values = [
            host_a.fixedpoint_ring_encode(sess, scaling_base, scaling_exp, in_a),
            host_b.fixedpoint_ring_encode(sess, scaling_base, scaling_exp, in_b),
            host_c.fixedpoint_ring_encode(sess, scaling_base, scaling_exp, in_c),
        ];
        Ok(Mir3Tensor { values })
    }
}
impl RingFixedpointDecodeOp {
    /// Fixed-point decodes each slot of a mirrored ring tensor back into a float tensor.
    ///
    /// Every host of `plc` decodes its own copy with the same
    /// `scaling_base`/`scaling_exp`; the outputs keep the original slot order.
    pub(crate) fn mir_kernel<S: Session, HostRingT, HostFloatT>(
        sess: &S,
        plc: &Mirrored3Placement,
        scaling_base: u64,
        scaling_exp: u32,
        x: Mir3Tensor<HostRingT>,
    ) -> Result<Mir3Tensor<HostFloatT>>
    where
        HostPlacement: PlacementRingFixedpointDecode<S, HostRingT, HostFloatT>,
    {
        let (host_a, host_b, host_c) = plc.host_placements();
        let [in_a, in_b, in_c] = &x.values;
        // Array elements are evaluated left to right, so the decode calls are
        // issued in player order.
        let values = [
            host_a.fixedpoint_ring_decode(sess, scaling_base, scaling_exp, in_a),
            host_b.fixedpoint_ring_decode(sess, scaling_base, scaling_exp, in_b),
            host_c.fixedpoint_ring_decode(sess, scaling_base, scaling_exp, in_c),
        ];
        Ok(Mir3Tensor { values })
    }
}
impl ShareOp {
    /// Secret-shares a mirrored fixed-point tensor onto a replicated placement.
    ///
    /// The underlying mirrored ring tensor is shared via `PlacementShare`;
    /// the fractional and integral precisions are copied over unchanged.
    pub(crate) fn fixed_mir_kernel<S: Session, MirRingT, RepRingT>(
        sess: &S,
        plc: &ReplicatedPlacement,
        x: MirFixedTensor<MirRingT>,
    ) -> Result<RepFixedTensor<RepRingT>>
    where
        ReplicatedPlacement: PlacementShare<S, MirRingT, RepRingT>,
    {
        let shared_ring = plc.share(sess, &x.tensor);
        let result = RepFixedTensor {
            integral_precision: x.integral_precision,
            fractional_precision: x.fractional_precision,
            tensor: shared_ring,
        };
        Ok(result)
    }

    /// Secret-shares a mirrored ring tensor onto a replicated placement.
    ///
    /// Only the copy in the first slot is shared; the remaining two copies
    /// are ignored. NOTE(review): this relies on the mirrored copies holding
    /// the same value.
    pub(crate) fn ring_mir_kernel<S: Session, HostRingT, RepRingT>(
        sess: &S,
        plc: &ReplicatedPlacement,
        x: Mir3Tensor<HostRingT>,
    ) -> Result<RepRingT>
    where
        HostRingT: Clone,
        ReplicatedPlacement: PlacementShare<S, HostRingT, RepRingT>,
        HostPlacement: PlacementPlace<S, HostRingT>,
        HostRingT: Placed<Placement = HostPlacement>,
    {
        let Mir3Tensor {
            values: [leading, ..],
        } = x;
        let shared = plc.share(sess, &leading);
        Ok(shared)
    }
}
impl RevealOp {
    /// Reveals a replicated ring tensor to every host of a mirrored placement.
    ///
    /// Each of the three hosts behind `mir` reveals `x` independently, and
    /// the three revealed host tensors become the slots of the returned
    /// [`Mir3Tensor`] in player order.
    pub(crate) fn mir_ring_kernel<S: Session, HostRingT: Clone>(
        sess: &S,
        mir: &Mirrored3Placement,
        x: RepTensor<HostRingT>,
    ) -> Result<Mir3Tensor<HostRingT>>
    where
        RepTensor<HostRingT>: CanonicalType,
        <RepTensor<HostRingT> as CanonicalType>::Type: KnownType<S>,
        RepTensor<HostRingT>: Into<m!(c!(RepTensor<HostRingT>))>,
        HostPlacement: PlacementReveal<S, m!(c!(RepTensor<HostRingT>)), HostRingT>,
    {
        let (player0, player1, player2) = mir.host_placements();
        // `reveal` only borrows its input, so convert `x` once and share the
        // converted value across all three reveals. This avoids cloning the
        // whole replicated tensor twice just to repeat the same conversion.
        let x: m!(c!(RepTensor<HostRingT>)) = x.into();
        let x0 = player0.reveal(sess, &x);
        let x1 = player1.reveal(sess, &x);
        let x2 = player2.reveal(sess, &x);
        Ok(Mir3Tensor {
            values: [x0, x1, x2],
        })
    }

    /// Reveals a replicated fixed-point tensor onto a mirrored placement.
    ///
    /// The underlying replicated ring tensor is revealed via
    /// `PlacementReveal` on `receiver`; the fractional and integral
    /// precisions are carried over unchanged.
    pub(crate) fn mir_fixed_kernel<S: Session, RepRingT, MirRingT>(
        sess: &S,
        receiver: &Mirrored3Placement,
        xe: RepFixedTensor<RepRingT>,
    ) -> Result<MirFixedTensor<MirRingT>>
    where
        Mirrored3Placement: PlacementReveal<S, RepRingT, MirRingT>,
    {
        let x = receiver.reveal(sess, &xe.tensor);
        Ok(MirFixedTensor {
            tensor: x,
            fractional_precision: xe.fractional_precision,
            integral_precision: xe.integral_precision,
        })
    }
}