use super::*;
use crate::host::SyncKey;
pub trait PlacementDeriveSeed<S: Session, KeyT, SeedT> {
fn derive_seed(&self, sess: &S, sync_key: SyncKey, key: &KeyT) -> SeedT;
}
modelled_kernel! {
PlacementDeriveSeed::derive_seed, DeriveSeedOp{sync_key: SyncKey},
[
(HostPlacement, (HostPrfKey) -> HostSeed => [runtime] Self::kernel),
]
}
pub trait PlacementSample<S: Session, ShapeT, O> {
fn sample(&self, sess: &S, max_value: Option<u64>, shape: &ShapeT) -> O;
}
pub trait PlacementSampleUniform<S: Session, ShapeT, O> {
fn sample_uniform(&self, sess: &S, shape: &ShapeT) -> O;
}
modelled_kernel! {
PlacementSample::sample, SampleOp{max_value: Option<u64>},
[
(HostPlacement, (HostShape) -> HostBitTensor => [runtime] Self::bit_kernel),
(HostPlacement, (HostShape) -> HostRing64Tensor => [runtime] Self::ring64_kernel),
(HostPlacement, (HostShape) -> HostRing128Tensor => [runtime] Self::ring128_kernel),
]
}
impl<S: Session, ShapeT, O, P> PlacementSampleUniform<S, ShapeT, O> for P
where
P: PlacementSample<S, ShapeT, O>,
{
fn sample_uniform(&self, sess: &S, shape: &ShapeT) -> O {
self.sample(sess, None, shape)
}
}
pub trait PlacementSampleBits<S: Session, ShapeT, O> {
fn sample_bits(&self, sess: &S, shape: &ShapeT) -> O;
}
impl<S: Session, ShapeT, O, P> PlacementSampleBits<S, ShapeT, O> for P
where
P: PlacementSample<S, ShapeT, O>,
{
fn sample_bits(&self, sess: &S, shape: &ShapeT) -> O {
self.sample(sess, Some(1), shape)
}
}
pub trait PlacementSampleSeeded<S: Session, ShapeT, SeedT, O> {
fn sample_seeded(&self, sess: &S, max_value: Option<u64>, shape: &ShapeT, seed: &SeedT) -> O;
}
modelled_kernel! {
PlacementSampleSeeded::sample_seeded, SampleSeededOp{max_value: Option<u64>},
[
(HostPlacement, (HostShape, HostSeed) -> HostBitTensor => [runtime] Self::bit_kernel),
(HostPlacement, (HostShape, HostSeed) -> HostRing64Tensor => [runtime] Self::ring64_kernel),
(HostPlacement, (HostShape, HostSeed) -> HostRing128Tensor => [runtime] Self::ring128_kernel),
]
}
pub trait PlacementSampleUniformSeeded<S: Session, ShapeT, SeedT, O> {
fn sample_uniform_seeded(&self, sess: &S, shape: &ShapeT, seed: &SeedT) -> O;
}
impl<S: Session, ShapeT, SeedT, O, P> PlacementSampleUniformSeeded<S, ShapeT, SeedT, O> for P
where
P: PlacementSampleSeeded<S, ShapeT, SeedT, O>,
{
fn sample_uniform_seeded(&self, sess: &S, shape: &ShapeT, seed: &SeedT) -> O {
self.sample_seeded(sess, None, shape, seed)
}
}
pub trait PlacementSampleBitsSeeded<S: Session, ShapeT, SeedT, O> {
fn sample_bits_seeded(&self, sess: &S, shape: &ShapeT, seed: &SeedT) -> O;
}
impl<S: Session, ShapeT, SeedT, O, P> PlacementSampleBitsSeeded<S, ShapeT, SeedT, O> for P
where
P: PlacementSampleSeeded<S, ShapeT, SeedT, O>,
{
fn sample_bits_seeded(&self, sess: &S, shape: &ShapeT, seed: &SeedT) -> O {
self.sample_seeded(sess, Some(1), shape, seed)
}
}