use super::*;
/// Call shape for the `add` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T`, `U` – types of the first and second operand.
/// * `O` – type of the value handed back.
///
/// The trait only fixes the signature; what `add` computes is up to each
/// implementation.
pub trait PlacementAdd<S, T, U, O>
where
    S: Session,
{
    /// Runs `add` on `x` and `y` in the context of `sess`, producing an `O`.
    fn add(&self, sess: &S, x: &T, y: &U) -> O;
}
// Kernel table for `AddOp`, exposed through `PlacementAdd::add`.
//
// Entry shape: (placement type, (lhs type, rhs type) -> result type => [tag] kernel fn).
// The tags seen here are [concrete], [runtime] and [hybrid]; how each tag
// changes the generated code is decided inside `modelled_kernel!`, whose
// definition is not visible from this part of the file.
//
// Placements covered: HostPlacement, ReplicatedPlacement, AdditivePlacement.
// Mixed-operand combinations (Mirrored3*/Replicated* and Host*/Additive*) are
// listed once per argument order, each order with its own kernel function.
modelled_kernel! {
PlacementAdd::add, AddOp,
[
(HostPlacement, (Tensor, Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(ReplicatedPlacement, (Tensor, Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
(HostPlacement, (Float32Tensor, Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor, Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
(ReplicatedPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(HostPlacement, (HostFixed64Tensor, HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor, HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
// Host float/int tensors all share `host_kernel`; host ring tensors use `ring_kernel`.
(HostPlacement, (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor, HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt8Tensor, HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor, HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor, HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor, HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostRing64Tensor, HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor, HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring_kernel),
// Replicated fixed-point tensors, including mixes with Mirrored3 fixed-point operands.
(ReplicatedPlacement, (ReplicatedFixed64Tensor, ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor, ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor, Mirrored3Fixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_mirfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor, Mirrored3Fixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_mirfixed_kernel),
(ReplicatedPlacement, (Mirrored3Fixed64Tensor, ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::mirfixed_repfixed_kernel),
(ReplicatedPlacement, (Mirrored3Fixed128Tensor, ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::mirfixed_repfixed_kernel),
// Replicated ring/bit tensors, including mixes with Mirrored3 operands.
(ReplicatedPlacement, (ReplicatedRing64Tensor, ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor, ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedBitTensor, ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (Mirrored3Ring64Tensor, ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::mir_rep_kernel),
(ReplicatedPlacement, (Mirrored3Ring128Tensor, ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::mir_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor, Mirrored3Ring64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_mir_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor, Mirrored3Ring128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_mir_kernel),
(ReplicatedPlacement, (Mirrored3BitTensor, ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::mir_rep_kernel),
(ReplicatedPlacement, (ReplicatedBitTensor, Mirrored3BitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_mir_kernel),
// Additive tensors; entries mixing an additive operand with a host operand carry the [hybrid] tag.
(AdditivePlacement, (AdditiveRing64Tensor, AdditiveRing64Tensor) -> AdditiveRing64Tensor => [concrete] Self::adt_adt_kernel),
(AdditivePlacement, (AdditiveRing128Tensor, AdditiveRing128Tensor) -> AdditiveRing128Tensor => [concrete] Self::adt_adt_kernel),
(AdditivePlacement, (AdditiveBitTensor, AdditiveBitTensor) -> AdditiveBitTensor => [concrete] Self::adt_adt_kernel),
(AdditivePlacement, (AdditiveRing64Tensor, HostRing64Tensor) -> AdditiveRing64Tensor => [hybrid] Self::adt_host_kernel),
(AdditivePlacement, (AdditiveRing128Tensor, HostRing128Tensor) -> AdditiveRing128Tensor => [hybrid] Self::adt_host_kernel),
(AdditivePlacement, (AdditiveBitTensor, HostBitTensor) -> AdditiveBitTensor => [hybrid] Self::adt_host_kernel),
(AdditivePlacement, (HostRing64Tensor, AdditiveRing64Tensor) -> AdditiveRing64Tensor => [hybrid] Self::host_adt_kernel),
(AdditivePlacement, (HostRing128Tensor, AdditiveRing128Tensor) -> AdditiveRing128Tensor => [hybrid] Self::host_adt_kernel),
(AdditivePlacement, (HostBitTensor, AdditiveBitTensor) -> AdditiveBitTensor => [hybrid] Self::host_adt_kernel),
]
}
/// Call shape for the `sub` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T`, `U` – types of the first and second operand.
/// * `O` – type of the value handed back.
///
/// Only the signature is fixed here; the computation is supplied by
/// implementations.
pub trait PlacementSub<S, T, U, O>
where
    S: Session,
{
    /// Runs `sub` on `x` and `y` in the context of `sess`, producing an `O`.
    fn sub(&self, sess: &S, x: &T, y: &U) -> O;
}
// Kernel table for `SubOp`, exposed through `PlacementSub::sub`.
//
// Entry shape: (placement type, (lhs type, rhs type) -> result type => [tag] kernel fn).
// Tags used: [concrete], [runtime], [hybrid]; their meaning is defined by
// `modelled_kernel!`, which is not visible from this part of the file.
//
// NOTE(review): compared with the `AddOp` table in this file, there are no
// Mirrored3Fixed*, Mirrored3Bit* or AdditiveBit*/HostBit* mixed entries here —
// confirm that omission is intentional.
modelled_kernel! {
PlacementSub::sub, SubOp,
[
(HostPlacement, (Tensor, Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Float32Tensor, Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor, Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFixed64Tensor, HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor, HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor, HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostRing64Tensor, HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor, HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring_kernel),
// Replicated placement entries.
(ReplicatedPlacement, (Tensor, Tensor) -> Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor, ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor, ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor, ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor, ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedBitTensor, ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (Mirrored3Ring64Tensor, ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::mir_rep_kernel),
(ReplicatedPlacement, (Mirrored3Ring128Tensor, ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::mir_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor, Mirrored3Ring64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_mir_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor, Mirrored3Ring128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_mir_kernel),
// Additive placement entries; additive/host mixes carry the [hybrid] tag.
(AdditivePlacement, (AdditiveRing64Tensor, AdditiveRing64Tensor) -> AdditiveRing64Tensor => [concrete] Self::adt_adt_kernel),
(AdditivePlacement, (AdditiveRing128Tensor, AdditiveRing128Tensor) -> AdditiveRing128Tensor => [concrete] Self::adt_adt_kernel),
(AdditivePlacement, (AdditiveBitTensor, AdditiveBitTensor) -> AdditiveBitTensor => [concrete] Self::adt_adt_kernel),
(AdditivePlacement, (AdditiveRing64Tensor, HostRing64Tensor) -> AdditiveRing64Tensor => [hybrid] Self::adt_host_kernel),
(AdditivePlacement, (AdditiveRing128Tensor, HostRing128Tensor) -> AdditiveRing128Tensor => [hybrid] Self::adt_host_kernel),
(AdditivePlacement, (HostRing64Tensor, AdditiveRing64Tensor) -> AdditiveRing64Tensor => [hybrid] Self::host_adt_kernel),
(AdditivePlacement, (HostRing128Tensor, AdditiveRing128Tensor) -> AdditiveRing128Tensor => [hybrid] Self::host_adt_kernel),
// Host integer tensors, sharing `host_kernel` with the host float entries above.
(HostPlacement, (HostInt8Tensor, HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor, HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor, HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor, HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
]
}
/// Call shape for the unary `neg` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementNeg<S, T, O>
where
    S: Session,
{
    /// Runs `neg` on `x` in the context of `sess`, producing an `O`.
    fn neg(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `NegOp`, exposed through `PlacementNeg::neg`.
//
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel fn).
// Tags used: [runtime] for the host entries, [concrete] for the replicated
// ones; tag semantics live in `modelled_kernel!` (definition not visible here).
// Bit tensors get dedicated kernels (`bit_kernel`, `rep_bit_kernel`) distinct
// from the ring kernels.
modelled_kernel! {
PlacementNeg::neg, NegOp,
[
(HostPlacement, (HostBitTensor) -> HostBitTensor => [runtime] Self::bit_kernel),
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_bit_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_rep_kernel),
]
}
/// Call shape for the `mul` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T`, `U` – types of the first and second operand.
/// * `O` – type of the value handed back.
///
/// Only the signature is fixed here; the computation is supplied by
/// implementations.
pub trait PlacementMul<S, T, U, O>
where
    S: Session,
{
    /// Runs `mul` on `x` and `y` in the context of `sess`, producing an `O`.
    fn mul(&self, sess: &S, x: &T, y: &U) -> O;
}
// Kernel table for `MulOp`, exposed through `PlacementMul::mul`.
//
// Entry shape: (placement type, (lhs type, rhs type) -> result type => [tag] kernel fn).
// Tags used: [concrete], [runtime], [hybrid]; their meaning is defined by
// `modelled_kernel!`, which is not visible from this part of the file.
//
// The two `Tensor` entries additionally carry `attributes[sig]`; presumably this
// passes the op's signature to the kernel — confirm against the macro definition.
// The AdditivePlacement entries here are all additive/host mixes; no
// additive-by-additive entry is registered in this table.
modelled_kernel! {
PlacementMul::mul, MulOp,
[
(HostPlacement, (Tensor, Tensor) -> Tensor => [concrete] attributes[sig] Self::logical_host_kernel),
(ReplicatedPlacement, (Tensor, Tensor) -> Tensor => [concrete] attributes[sig] Self::logical_rep_kernel),
(HostPlacement, (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor, HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt8Tensor, HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor, HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor, HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor, HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostRing64Tensor, HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor, HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (Float32Tensor, Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor, Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (HostFixed64Tensor, HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor, HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
// Replicated placement entries, including Mirrored3 mixes in both argument orders.
(ReplicatedPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor, ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor, ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedBitTensor, ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (Mirrored3Ring128Tensor, ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::mir_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor, Mirrored3Ring128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_mir_kernel),
(ReplicatedPlacement, (Mirrored3Ring64Tensor, ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::mir_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor, Mirrored3Ring64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_mir_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor, ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor, ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor, Mirrored3Fixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_mirfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor, Mirrored3Fixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_mirfixed_kernel),
(ReplicatedPlacement, (Mirrored3Fixed64Tensor, ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::mirfixed_repfixed_kernel),
(ReplicatedPlacement, (Mirrored3Fixed128Tensor, ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::mirfixed_repfixed_kernel),
// Additive/host mixes, one entry per argument order.
(AdditivePlacement, (HostRing64Tensor, AdditiveRing64Tensor) -> AdditiveRing64Tensor => [hybrid] Self::host_adt_kernel),
(AdditivePlacement, (AdditiveRing64Tensor, HostRing64Tensor) -> AdditiveRing64Tensor => [hybrid] Self::adt_host_kernel),
(AdditivePlacement, (AdditiveRing128Tensor, HostRing128Tensor) -> AdditiveRing128Tensor => [hybrid] Self::adt_host_kernel),
(AdditivePlacement, (HostRing128Tensor, AdditiveRing128Tensor) -> AdditiveRing128Tensor => [hybrid] Self::host_adt_kernel),
(AdditivePlacement, (AdditiveBitTensor, HostBitTensor) -> AdditiveBitTensor => [hybrid] Self::adt_host_kernel),
(AdditivePlacement, (HostBitTensor, AdditiveBitTensor) -> AdditiveBitTensor => [hybrid] Self::host_adt_kernel),
]
}
/// Call shape for the `div` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T`, `U` – types of the first and second operand.
/// * `O` – type of the value handed back.
///
/// Only the signature is fixed here; the computation is supplied by
/// implementations.
pub trait PlacementDiv<S, T, U, O>
where
    S: Session,
{
    /// Runs `div` on `x` and `y` in the context of `sess`, producing an `O`.
    fn div(&self, sess: &S, x: &T, y: &U) -> O;
}
// Kernel table for `DivOp`, exposed through `PlacementDiv::div`.
//
// Entry shape: (placement type, (lhs type, rhs type) -> result type => [tag] kernel fn).
// Tags used: [concrete] and [runtime]; their meaning is defined by
// `modelled_kernel!`, which is not visible from this part of the file.
// Only HostPlacement and ReplicatedPlacement appear in this table.
modelled_kernel! {
PlacementDiv::div, DivOp,
[
(HostPlacement, (Tensor, Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Float32Tensor, Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor, Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFixed64Tensor, HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor, HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostRing64Tensor, HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor, HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor, HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
// Replicated placement entries.
(ReplicatedPlacement, (Tensor, Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor, ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor, ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::rep_rep_kernel),
// Host integer tensors, sharing `host_kernel` with the host float entries above.
(HostPlacement, (HostInt8Tensor, HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor, HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor, HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor, HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
]
}
/// Call shape for the `dot` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T`, `U` – types of the first and second operand.
/// * `O` – type of the value handed back.
///
/// Only the signature is fixed here; the computation is supplied by
/// implementations.
pub trait PlacementDot<S, T, U, O>
where
    S: Session,
{
    /// Runs `dot` on `x` and `y` in the context of `sess`, producing an `O`.
    fn dot(&self, sess: &S, x: &T, y: &U) -> O;
}
// Kernel table for `DotOp`, exposed through `PlacementDot::dot`.
//
// Entry shape: (placement type, (lhs type, rhs type) -> result type => [tag] kernel fn).
// Tags used: [concrete] and [runtime]; their meaning is defined by
// `modelled_kernel!`, which is not visible from this part of the file.
// The two `Tensor` entries additionally carry `attributes[sig]`; presumably this
// passes the op's signature to the kernel — confirm against the macro definition.
modelled_kernel! {
PlacementDot::dot, DotOp,
[
(HostPlacement, (Tensor, Tensor) -> Tensor => [concrete] attributes[sig] Self::logical_host_kernel),
(HostPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_on_host_kernel),
(HostPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_on_host_kernel),
(HostPlacement, (Float32Tensor, Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor, Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor, HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFixed64Tensor, HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor, HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostRing64Tensor, HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor, HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring_kernel),
// Replicated placement entries.
(ReplicatedPlacement, (Tensor, Tensor) -> Tensor => [concrete] attributes[sig] Self::logical_rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor, Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_on_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor, Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_on_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor, ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor, ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor, ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor, ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_rep_kernel),
]
}
/// Call shape for the `shl` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementShl<S, T, O>
where
    S: Session,
{
    /// Runs `shl` on `x` with the given `amount` in the context of `sess`,
    /// producing an `O`.
    fn shl(&self, sess: &S, amount: usize, x: &T) -> O;
}
// Kernel table for `ShlOp`, exposed through `PlacementShl::shl`.
//
// The op is declared with one attribute, `amount: usize`, written in braces
// after the op name.
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel fn).
// Tags used: [concrete] and [runtime]; their meaning is defined by
// `modelled_kernel!`, which is not visible from this part of the file.
// Only 64- and 128-bit ring tensors are registered, for all three placements.
modelled_kernel! {
PlacementShl::shl, ShlOp{amount: usize},
[
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_kernel),
(AdditivePlacement, (AdditiveRing64Tensor) -> AdditiveRing64Tensor => [concrete] Self::adt_kernel),
(AdditivePlacement, (AdditiveRing128Tensor) -> AdditiveRing128Tensor => [concrete] Self::adt_kernel),
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring_kernel),
]
}
/// Call shape for the `shr` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementShr<S, T, O>
where
    S: Session,
{
    /// Runs `shr` on `x` with the given `amount` in the context of `sess`,
    /// producing an `O`.
    fn shr(&self, sess: &S, amount: usize, x: &T) -> O;
}
// Kernel table for `ShrOp`, exposed through `PlacementShr::shr`.
//
// The op is declared with one attribute, `amount: usize`.
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel fn).
// Only HostPlacement ring tensors are registered here, both tagged [runtime]
// (tag semantics are defined by `modelled_kernel!`, not visible here).
modelled_kernel! {
PlacementShr::shr, ShrOp{amount: usize},
[
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring_kernel),
]
}
/// Call shape for the unary `sqrt` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementSqrt<S, T, O>
where
    S: Session,
{
    /// Runs `sqrt` on `x` in the context of `sess`, producing an `O`.
    fn sqrt(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `SqrtOp`, exposed through `PlacementSqrt::sqrt`.
//
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel fn).
// Tags used: [concrete], [runtime] and [transparent]; their meaning is defined
// by `modelled_kernel!`, which is not visible from this part of the file.
// Host entries cover float tensors only; fixed-point tensors appear only under
// ReplicatedPlacement.
modelled_kernel! {
PlacementSqrt::sqrt, SqrtOp,
[
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [transparent] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [transparent] Self::rep_rep_kernel),
]
}
/// Call shape for the `add_n` operation, which takes a slice of operands.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – element type of the operand slice.
/// * `O` – type of the value handed back.
pub trait PlacementAddN<S, T, O>
where
    S: Session,
{
    /// Runs `add_n` over the slice `x` in the context of `sess`, producing an `O`.
    fn add_n(&self, sess: &S, x: &[T]) -> O;
}
// Kernel table for `AddNOp`, exposed through `PlacementAddN::add_n`.
//
// Entry shape: (placement type, vec[element type] -> result type => [tag] kernel fn).
// The `vec[...]` input form is used instead of a parenthesised operand list;
// it lines up with the trait method taking a slice `&[T]`.
// Tags used: [concrete] and [runtime]; their meaning is defined by
// `modelled_kernel!`, which is not visible from this part of the file.
modelled_kernel! {
PlacementAddN::add_n, AddNOp,
[
(HostPlacement, vec[Tensor] -> Tensor => [concrete] Self::host_logical_kernel),
(HostPlacement, vec[Float32Tensor] -> Float32Tensor => [concrete] Self::float_kernel),
(HostPlacement, vec[Float64Tensor] -> Float64Tensor => [concrete] Self::float_kernel),
(HostPlacement, vec[HostFloat32Tensor] -> HostFloat32Tensor => [runtime] Self::host_float_kernel),
(HostPlacement, vec[HostFloat64Tensor] -> HostFloat64Tensor => [runtime] Self::host_float_kernel),
(HostPlacement, vec[Fixed64Tensor] -> Fixed64Tensor => [concrete] Self::fixed_kernel),
(HostPlacement, vec[Fixed128Tensor] -> Fixed128Tensor => [concrete] Self::fixed_kernel),
(HostPlacement, vec[HostFixed64Tensor] -> HostFixed64Tensor => [concrete] Self::host_fixed_kernel),
(HostPlacement, vec[HostFixed128Tensor] -> HostFixed128Tensor => [concrete] Self::host_fixed_kernel),
(HostPlacement, vec[HostRing64Tensor] -> HostRing64Tensor => [runtime] Self::host_kernel),
(HostPlacement, vec[HostRing128Tensor] -> HostRing128Tensor => [runtime] Self::host_kernel),
// Replicated placement entries.
(ReplicatedPlacement, vec[ReplicatedRing64Tensor] -> ReplicatedRing64Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, vec[ReplicatedRing128Tensor] -> ReplicatedRing128Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, vec[ReplicatedFixed64Tensor] -> ReplicatedFixed64Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, vec[ReplicatedFixed128Tensor] -> ReplicatedFixed128Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, vec[Fixed64Tensor] -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, vec[Fixed128Tensor] -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, vec[Tensor] -> Tensor => [concrete] Self::logical_rep_kernel),
]
}
/// Call shape for the `sum` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementSum<S, T, O>
where
    S: Session,
{
    /// Runs `sum` on `x` in the context of `sess`, producing an `O`.
    ///
    /// `axis` is optional; how `None` is treated is up to the implementation.
    fn sum(&self, sess: &S, axis: Option<usize>, x: &T) -> O;
}
// Kernel table for `SumOp`, exposed through `PlacementSum::sum`.
//
// The op is declared with one attribute, `axis: Option<usize>`.
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel fn).
// Tags used: [concrete] and [runtime]; their meaning is defined by
// `modelled_kernel!`, which is not visible from this part of the file.
modelled_kernel! {
PlacementSum::sum, SumOp{axis: Option<usize>},
[
(HostPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::fixed_hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::fixed_hostfixed_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_float_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_float_kernel),
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
// Replicated placement entries.
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::fixed_repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::fixed_repfixed_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_ring_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_ring_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
]
}
/// Call shape for the unary `pow2` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementPow2<S, T, O>
where
    S: Session,
{
    /// Runs `pow2` on `x` in the context of `sess`, producing an `O`.
    fn pow2(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `Pow2Op`, exposed through `PlacementPow2::pow2`.
//
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel fn).
// Only ReplicatedPlacement with fixed-point tensor types is registered, all
// tagged [concrete] (tag semantics are defined by `modelled_kernel!`, not
// visible here).
modelled_kernel! {
PlacementPow2::pow2, Pow2Op,
[
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::rep_rep_kernel),
]
}
/// Call shape for the unary `exp` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementExp<S, T, O>
where
    S: Session,
{
    /// Runs `exp` on `x` in the context of `sess`, producing an `O`.
    fn exp(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `ExpOp`, exposed through `PlacementExp::exp`.
//
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel fn).
// Tags used: [runtime], [concrete] and [transparent]; their meaning is defined
// by `modelled_kernel!`, which is not visible from this part of the file.
// Host entries cover float tensors and `Tensor`; fixed-point tensors appear
// only under ReplicatedPlacement.
modelled_kernel! {
PlacementExp::exp, ExpOp,
[
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_kernel),
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [transparent] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [transparent] Self::rep_rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
]
}
/// Call shape for the unary `sigmoid` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementSigmoid<S, T, O>
where
    S: Session,
{
    /// Runs `sigmoid` on `x` in the context of `sess`, producing an `O`.
    fn sigmoid(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `SigmoidOp`, exposed through `PlacementSigmoid::sigmoid`.
//
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel fn).
// Tags used: [runtime], [concrete] and [transparent]; their meaning is defined
// by `modelled_kernel!`, which is not visible from this part of the file.
// Host entries cover float tensors and `Tensor`; fixed-point tensors appear
// only under ReplicatedPlacement.
modelled_kernel! {
PlacementSigmoid::sigmoid, SigmoidOp,
[
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [transparent] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [transparent] Self::rep_rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
]
}
/// Call shape for the `mean` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementMean<S, T, O>
where
    S: Session,
{
    /// Runs `mean` on `x` in the context of `sess`, producing an `O`.
    ///
    /// `axis` is optional; how `None` is treated is up to the implementation.
    fn mean(&self, sess: &S, axis: Option<u32>, x: &T) -> O;
}
// Kernel table for `MeanOp`, exposed through `PlacementMean::mean`.
//
// The op is declared with one attribute, `axis: Option<u32>`.
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel).
// Most entries name a kernel function directly; the two `Tensor` entries use
// the `custom |op| { ... }` form to build the kernel closure by hand.
// Tags used: [concrete] and [runtime]; their meaning is defined by
// `modelled_kernel!`, which is not visible from this part of the file.
modelled_kernel! {
PlacementMean::mean, MeanOp{axis: Option<u32>},
[
// Custom builder: reads `sig` and `axis` off the op once, then returns a boxed
// closure that forwards both, along with the session, placement and operand,
// to `logical_host_kernel`.
(HostPlacement, (Tensor) -> Tensor => [concrete] custom |op| {
let sig = op.sig;
let axis = op.axis;
Ok(Box::new(move |sess, plc, x| {
Self::logical_host_kernel(sess, plc, sig, axis, x)
}))
}),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
// Same custom-builder shape as the host `Tensor` entry, targeting
// `logical_rep_kernel` instead.
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] custom |op| {
let sig = op.sig;
let axis = op.axis;
Ok(Box::new(move |sess, plc, x| {
Self::logical_rep_kernel(sess, plc, sig, axis, x)
}))
}),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
]
}
/// Call shape for the `mean_as_fixedpoint` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementMeanAsFixedpoint<S, T, O>
where
    S: Session,
{
    /// Runs `mean_as_fixedpoint` on `x` in the context of `sess`, producing an `O`.
    ///
    /// Besides the optional `axis`, the call carries two scaling parameters,
    /// `scaling_base` and `scaling_exp`; how they are combined is left to the
    /// implementation.
    fn mean_as_fixedpoint(
        &self,
        sess: &S,
        axis: Option<u32>,
        scaling_base: u64,
        scaling_exp: u32,
        x: &T,
    ) -> O;
}
// Kernel table for `RingFixedpointMeanOp`, exposed through
// `PlacementMeanAsFixedpoint::mean_as_fixedpoint`.
//
// The op is declared with three attributes: `axis: Option<u32>`,
// `scaling_base: u64` and `scaling_exp: u32`, matching the trait method's
// extra parameters.
// Entry shape: (placement type, (operand type) -> result type => [tag] kernel fn).
// The two host entries use separate kernels per ring width (`ring64_kernel`,
// `ring128_kernel`), while both replicated entries share `rep_kernel`.
// Tag semantics ([runtime]/[concrete]) are defined by `modelled_kernel!`,
// which is not visible from this part of the file.
modelled_kernel! {
PlacementMeanAsFixedpoint::mean_as_fixedpoint, RingFixedpointMeanOp{axis: Option<u32>, scaling_base: u64, scaling_exp: u32},
[
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring64_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring128_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_kernel),
]
}
/// Call shape for the `maximum` operation, which takes a slice of operands.
///
/// * `S` – session type (must implement `Session`).
/// * `TS` – element type of the operand slice.
/// * `O` – type of the value handed back.
pub trait PlacementMaximum<S, TS, O>
where
    S: Session,
{
    /// Runs `maximum` over the slice `x` in the context of `sess`, producing an `O`.
    fn maximum(&self, sess: &S, x: &[TS]) -> O;
}
// Kernel table for `MaximumOp`, exposed through `PlacementMaximum::maximum`.
//
// Entry shape: (placement type, vec[element type] -> result type => [tag] kernel fn).
// The `vec[...]` input form is used instead of a parenthesised operand list;
// it lines up with the trait method taking a slice `&[TS]`.
// Tags used: [runtime], [concrete] and [transparent]; their meaning is defined
// by `modelled_kernel!`, which is not visible from this part of the file.
modelled_kernel! {
PlacementMaximum::maximum, MaximumOp,
[
(HostPlacement, vec[HostFloat32Tensor] -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, vec[HostFloat64Tensor] -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, vec[HostRing64Tensor] -> HostRing64Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, vec[HostRing128Tensor] -> HostRing128Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, vec[Fixed64Tensor] -> Fixed64Tensor => [concrete] Self::fixed_lowering_kernel),
(HostPlacement, vec[Fixed128Tensor] -> Fixed128Tensor => [concrete] Self::fixed_lowering_kernel),
(HostPlacement, vec[Float32Tensor] -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, vec[Float64Tensor] -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, vec[HostFixed64Tensor] -> HostFixed64Tensor => [concrete] Self::host_fixed_kernel),
(HostPlacement, vec[HostFixed128Tensor] -> HostFixed128Tensor => [concrete] Self::host_fixed_kernel),
(HostPlacement, vec[Tensor] -> Tensor => [concrete] Self::logical_host_kernel),
// Replicated placement entries.
(ReplicatedPlacement, vec[ReplicatedRing64Tensor] -> ReplicatedRing64Tensor => [transparent] Self::kernel),
(ReplicatedPlacement, vec[ReplicatedRing128Tensor] -> ReplicatedRing128Tensor => [transparent] Self::kernel),
(ReplicatedPlacement, vec[Fixed64Tensor] -> Fixed64Tensor => [concrete] Self::fixed_kernel),
(ReplicatedPlacement, vec[Fixed128Tensor] -> Fixed128Tensor => [concrete] Self::fixed_kernel),
(ReplicatedPlacement, vec[ReplicatedFixed64Tensor] -> ReplicatedFixed64Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, vec[ReplicatedFixed128Tensor] -> ReplicatedFixed128Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, vec[Tensor] -> Tensor => [concrete] Self::rep_logical_kernel),
]
}
/// Call shape for the unary `abs` operation.
///
/// * `S` – session type (must implement `Session`).
/// * `T` – operand type.
/// * `O` – type of the value handed back.
pub trait PlacementAbs<S, T, O>
where
    S: Session,
{
    /// Runs `abs` on `x` in the context of `sess`, producing an `O`.
    fn abs(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `AbsOp`, tied to `PlacementAbs::abs`.
//
// Each entry reads:
//   (placement, (input type) -> output type => [mode] kernel function)
// In every entry of this table the output type equals the input type.
// Mode tags used in this table: `concrete`, `runtime`, `transparent`.
//
// NOTE(review): the exact expansion is defined by `modelled_kernel!`, which is
// not visible in this chunk; presumably one `PlacementAbs` implementation is
// generated per entry — confirm against the macro definition.
modelled_kernel! {
PlacementAbs::abs, AbsOp,
[
// Entries for `HostPlacement`.
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
// Entries for `ReplicatedPlacement`.
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::rep_logical_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [transparent] Self::rep_ring_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [transparent] Self::rep_ring_kernel),
]
}
/// Placement-level `relu` operation.
///
/// Type parameters:
/// * `S` - session type the operation is evaluated in.
/// * `T` - type of the input.
/// * `O` - type of the result.
// NOTE(review): presumably the rectified linear unit applied to `x`;
// confirm against the kernels.
pub trait PlacementRelu<S, T, O>
where
    S: Session,
{
    /// Applies the operation to `x` within `sess`, returning an `O`.
    fn relu(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `ReluOp`, tied to `PlacementRelu::relu`.
//
// Each entry reads:
//   (placement, (input type) -> output type => [mode] kernel function)
// In every entry of this table the output type equals the input type.
// Mode tags used in this table: `concrete`, `runtime`, `transparent`.
//
// NOTE(review): the exact expansion is defined by `modelled_kernel!`, which is
// not visible in this chunk; presumably one `PlacementRelu` implementation is
// generated per entry — confirm against the macro definition.
modelled_kernel! {
PlacementRelu::relu, ReluOp,
[
// Entries for `HostPlacement`.
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
// Entries for `ReplicatedPlacement`.
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::rep_logical_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [transparent] Self::rep_ring_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [transparent] Self::rep_ring_kernel),
]
}
/// Placement-level `sign` operation.
///
/// Type parameters:
/// * `S` - session type the operation is evaluated in.
/// * `T` - type of the input.
/// * `O` - type of the result.
// NOTE(review): presumably yields the sign of `x`; the encoding of the result
// is not visible from this declaration — confirm against the kernels.
pub trait PlacementSign<S, T, O>
where
    S: Session,
{
    /// Applies the operation to `x` within `sess`, returning an `O`.
    fn sign(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `SignOp`, tied to `PlacementSign::sign`.
//
// Each entry reads:
//   (placement, (input type) -> output type => [mode] kernel function)
// Only `HostPlacement` ring tensors are listed here, and each ring width has
// its own kernel function (`ring64_kernel` / `ring128_kernel`).
//
// NOTE(review): the exact expansion is defined by `modelled_kernel!`, which is
// not visible in this chunk — confirm against the macro definition.
modelled_kernel! {
PlacementSign::sign, SignOp,
[
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring64_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring128_kernel),
]
}
/// Placement-level `inverse` operation.
///
/// Type parameters:
/// * `S` - session type the operation is evaluated in.
/// * `T` - type of the input.
/// * `O` - type of the result.
// NOTE(review): whether this is a matrix inverse or an element-wise reciprocal
// cannot be determined from this declaration — confirm against the kernels.
pub trait PlacementInverse<S, T, O>
where
    S: Session,
{
    /// Applies the operation to `x` within `sess`, returning an `O`.
    fn inverse(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `InverseOp`, tied to `PlacementInverse::inverse`.
//
// Each entry reads:
//   (placement, (input type) -> output type => [mode] kernel function)
// This table lists `HostPlacement` entries only (logical and floating-point
// tensor types); no `ReplicatedPlacement` entry appears here.
//
// NOTE(review): the exact expansion is defined by `modelled_kernel!`, which is
// not visible in this chunk — confirm against the macro definition.
modelled_kernel! {
PlacementInverse::inverse, InverseOp,
[
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
]
}
/// Placement-level `log` operation.
///
/// Type parameters:
/// * `S` - session type the operation is evaluated in.
/// * `T` - type of the input.
/// * `O` - type of the result.
// NOTE(review): presumably the natural logarithm of `x`; the base is not
// visible from this declaration — confirm against the kernels.
pub trait PlacementLog<S, T, O>
where
    S: Session,
{
    /// Applies the operation to `x` within `sess`, returning an `O`.
    fn log(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `LogOp`, tied to `PlacementLog::log`.
//
// Each entry reads:
//   (placement, (input type) -> output type => [mode] kernel function)
// In every entry of this table the output type equals the input type.
// Host entries cover the logical and floating-point tensor types; replicated
// entries cover the logical and fixed-point tensor types.
// Mode tags used in this table: `concrete`, `runtime`, `transparent`.
//
// NOTE(review): the exact expansion is defined by `modelled_kernel!`, which is
// not visible in this chunk — confirm against the macro definition.
modelled_kernel! {
PlacementLog::log, LogOp,
[
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [transparent] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [transparent] Self::rep_rep_kernel),
]
}
/// Placement-level `log2` operation.
///
/// Type parameters:
/// * `S` - session type the operation is evaluated in.
/// * `T` - type of the input.
/// * `O` - type of the result.
// NOTE(review): presumably the base-2 logarithm of `x`; confirm against the kernels.
pub trait PlacementLog2<S, T, O>
where
    S: Session,
{
    /// Applies the operation to `x` within `sess`, returning an `O`.
    fn log2(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `Log2Op`, tied to `PlacementLog2::log2`.
//
// Each entry reads:
//   (placement, (input type) -> output type => [mode] kernel function)
// In every entry of this table the output type equals the input type.
// Host entries cover the logical and floating-point tensor types; replicated
// entries cover the logical and fixed-point tensor types.
// Mode tags used in this table: `concrete`, `runtime`.
//
// NOTE(review): the exact expansion is defined by `modelled_kernel!`, which is
// not visible in this chunk — confirm against the macro definition.
modelled_kernel! {
PlacementLog2::log2, Log2Op,
[
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::rep_rep_kernel),
]
}
/// Placement-level `argmax` operation.
///
/// Type parameters:
/// * `S` - session type the operation is evaluated in.
/// * `T` - type of the input.
/// * `O` - type of the result.
// NOTE(review): presumably returns the index of the largest entry of `x` along
// `axis`; the role of `upmost_index` is not visible from this declaration
// (possibly the extent of that axis) — confirm against the kernels.
pub trait PlacementArgmax<S, T, O>
where
    S: Session,
{
    /// Applies the operation to `x` within `sess`, parameterised by `axis`
    /// and `upmost_index`, returning an `O`.
    fn argmax(&self, sess: &S, axis: usize, upmost_index: usize, x: &T) -> O;
}
// Kernel table for `ArgmaxOp`, tied to `PlacementArgmax::argmax`.
//
// The operator carries two `usize` attributes, `axis` and `upmost_index`,
// matching the extra parameters of `PlacementArgmax::argmax`.
//
// Each entry reads:
//   (placement, (input type) -> output type => [mode] kernel function)
// Fixed-point inputs map to unsigned 64-bit outputs (`HostUint64Tensor`,
// `Uint64Tensor`, `ReplicatedUint64Tensor`); the logical `Tensor` entries map
// `Tensor` to `Tensor`.
// Mode tags used in this table: `hybrid`, `concrete`.
//
// NOTE(review): the exact expansion is defined by `modelled_kernel!`, which is
// not visible in this chunk — confirm against the macro definition.
modelled_kernel! {
PlacementArgmax::argmax, ArgmaxOp{axis: usize, upmost_index: usize},
[
(HostPlacement, (HostFixed64Tensor) -> HostUint64Tensor => [hybrid] Self::host_fixed_uint_kernel),
(HostPlacement, (HostFixed128Tensor) -> HostUint64Tensor => [hybrid] Self::host_fixed_uint_kernel),
(HostPlacement, (Fixed64Tensor) -> Uint64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor) -> Uint64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Uint64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Uint64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedUint64Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedUint64Tensor => [concrete] Self::rep_fixed_kernel),
]
}
// Kernel table for `RingFixedpointArgmaxOp`. It is tied to the same trait
// method as `ArgmaxOp` (`PlacementArgmax::argmax`) and carries the same two
// `usize` attributes, `axis` and `upmost_index`.
//
// Each entry reads:
//   (placement, (input type) -> output type => [mode] kernel function)
// Every entry produces a 64-bit ring tensor, including those whose input is a
// 128-bit ring tensor.
// Mode tags used in this table: `runtime`, `transparent`.
//
// NOTE(review): the exact expansion is defined by `modelled_kernel!`, which is
// not visible in this chunk — confirm against the macro definition.
modelled_kernel! {
PlacementArgmax::argmax, RingFixedpointArgmaxOp{axis: usize, upmost_index: usize},
[
(HostPlacement, (HostRing128Tensor) -> HostRing64Tensor => [runtime] Self::host_ring128_kernel),
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::host_ring64_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [transparent] Self::rep_ring_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing64Tensor => [transparent] Self::rep_ring_kernel),
]
}