use super::*;
/// Placement-dispatched `at_least_2d` operation.
///
/// Concrete implementations are produced by the `modelled_kernel!`
/// invocation for `AtLeast2DOp` in this file.
///
/// Type parameters: `S` is the session type (bounded by `Session`), `T` the
/// operand type and `O` the result type.
pub trait PlacementAtLeast2D<S, T, O>
where
    S: Session,
{
    /// Applies `at_least_2d` to `x` on this placement within session `sess`.
    ///
    /// NOTE(review): presumably mirrors `numpy.atleast_2d`, with
    /// `to_column_vector` choosing a column rather than a row layout for
    /// 1-D inputs — confirm against the kernels.
    fn at_least_2d(&self, sess: &S, to_column_vector: bool, x: &T) -> O;
}
// Kernel table for `PlacementAtLeast2D::at_least_2d`, backed by `AtLeast2DOp`.
// `AtLeast2DOp` is declared with a `to_column_vector: bool` attribute that
// matches the flag parameter of the trait method.
// Each entry reads `(placement, (input) -> output => [tag] kernel)`; the tag
// (`[concrete]`, `[runtime]`, ...) is interpreted by `modelled_kernel!`, so
// consult that macro for its exact meaning.
// This table only has `HostPlacement` entries (no replicated kernels).
modelled_kernel! {
PlacementAtLeast2D::at_least_2d, AtLeast2DOp{to_column_vector: bool},
[
// Non-`Host*` tensor types, each routed to its own `[concrete]` kernel.
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
// Host float and signed-integer tensors share the generic `host_kernel`.
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
]
}
/// Placement-dispatched `diag` operation.
///
/// Concrete implementations are produced by the `modelled_kernel!`
/// invocation for `DiagOp` in this file.
///
/// Type parameters: `S` is the session type (bounded by `Session`), `T` the
/// operand type and `O` the result type.
pub trait PlacementDiag<S, T, O>
where
    S: Session,
{
    /// Computes `diag` of `x` on this placement within session `sess`.
    ///
    /// NOTE(review): whether this extracts or builds a diagonal is decided by
    /// the kernels — confirm there.
    fn diag(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `PlacementDiag::diag`, backed by `DiagOp` (declared here
// without attributes).
// Each entry reads `(placement, (input) -> output => [tag] kernel)`; the tag
// is interpreted by `modelled_kernel!`, so consult that macro for its meaning.
// Every entry has identical input and output types.
modelled_kernel! {
PlacementDiag::diag, DiagOp,
[
// Host float, signed and unsigned integer tensors share `host_kernel`.
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostUint16Tensor) -> HostUint16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostUint32Tensor) -> HostUint32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostUint64Tensor) -> HostUint64Tensor => [runtime] Self::host_kernel),
// Host bit and ring tensors use dedicated kernels.
(HostPlacement, (HostBitTensor) -> HostBitTensor => [runtime] Self::bit_kernel),
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring_kernel),
// Replicated ring and bit tensors share `rep_kernel`.
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_kernel),
]
}
/// Placement-dispatched `expand_dims` operation.
///
/// Concrete implementations are produced by the `modelled_kernel!`
/// invocation for `ExpandDimsOp` in this file.
///
/// Type parameters: `S` is the session type (bounded by `Session`), `T` the
/// operand type and `O` the result type.
pub trait PlacementExpandDims<S, T, O>
where
    S: Session,
{
    /// Expands the dimensions of `x` on this placement within session `sess`.
    ///
    /// `axis` is taken by value and matches the `axis: Vec<usize>` attribute
    /// of `ExpandDimsOp`. NOTE(review): presumably the positions at which new
    /// unit dimensions are inserted (numpy `expand_dims` style) — confirm.
    fn expand_dims(&self, sess: &S, axis: Vec<usize>, x: &T) -> O;
}
// Kernel table for `PlacementExpandDims::expand_dims`, backed by
// `ExpandDimsOp`, which is declared with an `axis: Vec<usize>` attribute
// matching the trait method's `axis` parameter.
// Each entry reads `(placement, (input) -> output => [tag] kernel)`; the tag
// is interpreted by `modelled_kernel!`, so consult that macro for its meaning.
modelled_kernel! {
PlacementExpandDims::expand_dims, ExpandDimsOp{axis: Vec<usize>},
[
// Host placement, non-`Host*` tensor types, each with a `[concrete]` kernel.
(HostPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_host_kernel),
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
// Host placement, `Host*` tensor types. `host_int_float_kernel` is shared by
// the float and signed-integer tensors; no `HostUint*` entries are listed.
(HostPlacement, (HostBitTensor) -> HostBitTensor => [runtime] Self::host_bit_kernel),
(HostPlacement, (HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_int_float_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_int_float_kernel),
(HostPlacement, (HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_int_float_kernel),
(HostPlacement, (HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_int_float_kernel),
(HostPlacement, (HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_int_float_kernel),
(HostPlacement, (HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_int_float_kernel),
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::host_ring_kernel),
// Replicated placement: logical/fixed-point types plus replicated
// bit, fixed-point and ring tensors.
(ReplicatedPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_rep_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_kernel),
]
}
/// Placement-dispatched `squeeze` operation.
///
/// Concrete implementations are produced by the `modelled_kernel!`
/// invocation for `SqueezeOp` in this file.
///
/// Type parameters: `S` is the session type (bounded by `Session`), `T` the
/// operand type and `O` the result type.
pub trait PlacementSqueeze<S, T, O>
where
    S: Session,
{
    /// Squeezes `x` on this placement within session `sess`.
    ///
    /// `axis` matches the `axis: Option<usize>` attribute of `SqueezeOp`.
    /// NOTE(review): presumably `None` removes every unit dimension and
    /// `Some(i)` only dimension `i` — confirm against the kernels.
    fn squeeze(&self, sess: &S, axis: Option<usize>, x: &T) -> O;
}
// Kernel table for `PlacementSqueeze::squeeze`, backed by `SqueezeOp`, which
// is declared with an `axis: Option<usize>` attribute matching the trait
// method's `axis` parameter.
// Each entry reads `(placement, (input) -> output => [tag] kernel)`; the tag
// is interpreted by `modelled_kernel!`, so consult that macro for its meaning.
// NOTE(review): 8- and 16-bit host integer tensors (signed and unsigned) have
// no entry in this table.
modelled_kernel! {
PlacementSqueeze::squeeze, SqueezeOp{axis: Option<usize>}, [
// Host placement, `Host*` tensor types with `[runtime]` kernels.
(HostPlacement, (HostBitTensor) -> HostBitTensor => [runtime] Self::host_bit_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, (HostUint32Tensor) -> HostUint32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostUint64Tensor) -> HostUint64Tensor => [runtime] Self::host_kernel),
// Host placement, remaining tensor types with `[concrete]` kernels.
(HostPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_host_kernel),
(HostPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_kernel),
(HostPlacement, (HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, (Uint64Tensor) -> Uint64Tensor => [concrete] Self::u64_host_kernel),
// Replicated placement.
(ReplicatedPlacement, (ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedUint64Tensor) -> ReplicatedUint64Tensor => [concrete] Self::rep_uint_kernel),
(ReplicatedPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
(ReplicatedPlacement, (Uint64Tensor) -> Uint64Tensor => [concrete] Self::u64_rep_kernel),
]
}
/// Placement-dispatched concatenation of several operands.
///
/// Concrete implementations are produced by the `modelled_kernel!`
/// invocation for `ConcatOp` in this file.
///
/// Type parameters: `S` is the session type (bounded by `Session`), `TS` the
/// type of each operand and `O` the result type.
pub trait PlacementConcatenate<S, TS, O>
where
    S: Session,
{
    /// Concatenates all tensors in `xs` on this placement within session
    /// `sess`.
    ///
    /// `axis` matches the `axis: u32` attribute of `ConcatOp`.
    /// NOTE(review): presumably the axis along which `xs` are joined — confirm.
    fn concatenate(&self, sess: &S, axis: u32, xs: &[TS]) -> O;
}
// Kernel table for `PlacementConcatenate::concatenate`, backed by `ConcatOp`,
// which is declared with an `axis: u32` attribute matching the trait method's
// `axis` parameter.
// Entries use the `vec[input] -> output` form: the operation takes a variable
// number of operands, which corresponds to the `xs: &[TS]` slice in the trait.
// The bracketed tag is interpreted by `modelled_kernel!`, so consult that
// macro for its meaning.
modelled_kernel! {
PlacementConcatenate::concatenate, ConcatOp{axis: u32},
[
// Host placement, non-`Host*` tensor types.
(HostPlacement, vec[Tensor] -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, vec[Float32Tensor] -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, vec[Float64Tensor] -> Float64Tensor => [concrete] Self::float_host_kernel),
// Host placement, `Host*` tensor types.
(HostPlacement, vec[HostBitTensor] -> HostBitTensor => [runtime] Self::bit_kernel),
(HostPlacement, vec[HostFloat32Tensor] -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, vec[HostFloat64Tensor] -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, vec[HostInt8Tensor] -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, vec[HostInt16Tensor] -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, vec[HostInt32Tensor] -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, vec[HostInt64Tensor] -> HostInt64Tensor => [runtime] Self::host_kernel),
(HostPlacement, vec[HostRing64Tensor] -> HostRing64Tensor => [runtime] Self::ring_kernel),
(HostPlacement, vec[HostRing128Tensor] -> HostRing128Tensor => [runtime] Self::ring_kernel),
// Replicated placement.
(ReplicatedPlacement, vec[Tensor] -> Tensor => [concrete] Self::logical_rep_kernel),
(ReplicatedPlacement, vec[BooleanTensor] -> BooleanTensor => [concrete] Self::bool_rep_kernel),
(ReplicatedPlacement, vec[Fixed64Tensor] -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, vec[Fixed128Tensor] -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, vec[ReplicatedBitTensor] -> ReplicatedBitTensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, vec[ReplicatedFixed64Tensor] -> ReplicatedFixed64Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, vec[ReplicatedFixed128Tensor] -> ReplicatedFixed128Tensor => [concrete] Self::rep_fixed_kernel),
(ReplicatedPlacement, vec[ReplicatedRing64Tensor] -> ReplicatedRing64Tensor => [concrete] Self::rep_rep_kernel),
(ReplicatedPlacement, vec[ReplicatedRing128Tensor] -> ReplicatedRing128Tensor => [concrete] Self::rep_rep_kernel),
]
}
/// Placement-dispatched `transpose` operation.
///
/// Concrete implementations are produced by the `modelled_kernel!`
/// invocation for `TransposeOp` in this file.
///
/// Type parameters: `S` is the session type (bounded by `Session`), `T` the
/// operand type and `O` the result type.
pub trait PlacementTranspose<S, T, O>
where
    S: Session,
{
    /// Transposes `x` on this placement within session `sess`.
    ///
    /// `TransposeOp` carries no attributes, so no axis permutation is passed
    /// in. NOTE(review): presumably reverses the axes — confirm in the kernels.
    fn transpose(&self, sess: &S, x: &T) -> O;
}
// Kernel table for `PlacementTranspose::transpose`, backed by `TransposeOp`
// (declared here without attributes).
// Each entry reads `(placement, (input) -> output => [tag] kernel)`; the tag
// is interpreted by `modelled_kernel!`, so consult that macro for its meaning.
modelled_kernel! {
PlacementTranspose::transpose, TransposeOp,
[
// Host placement, `Host*` tensor types with `[runtime]` kernels.
(HostPlacement, (HostBitTensor) -> HostBitTensor => [runtime] Self::host_bit_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt8Tensor) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor) -> HostInt64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, (HostUint32Tensor) -> HostUint32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostUint64Tensor) -> HostUint64Tensor => [runtime] Self::host_kernel),
// Host placement, remaining tensor types with `[concrete]` kernels.
(HostPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_host_kernel),
(HostPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_kernel),
(HostPlacement, (HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (Tensor) -> Tensor => [concrete] Self::logical_host_kernel),
(HostPlacement, (Uint64Tensor) -> Uint64Tensor => [concrete] Self::u64_host_kernel),
// Replicated placement.
(ReplicatedPlacement, (ReplicatedBitTensor) -> ReplicatedBitTensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedUint64Tensor) -> ReplicatedUint64Tensor => [concrete] Self::rep_uint_kernel),
(ReplicatedPlacement, (BooleanTensor) -> BooleanTensor => [concrete] Self::bool_rep_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] Self::logical_rep_kernel),
(ReplicatedPlacement, (Uint64Tensor) -> Uint64Tensor => [concrete] Self::u64_rep_kernel),
]
}
/// Placement-dispatched query for the shape of a tensor.
///
/// Concrete implementations are produced by the `modelled_kernel!`
/// invocation for `ShapeOp` in this file. The shape type depends on the
/// placement and operand (e.g. `HostShape`, `ReplicatedShape`,
/// `AdditiveShape` or logical `Shape`).
///
/// Type parameters: `S` is the session type (bounded by `Session`), `T` the
/// operand type and `ShapeT` the returned shape type.
pub trait PlacementShape<S, T, ShapeT>
where
    S: Session,
{
    /// Returns the shape of `x` on this placement within session `sess`.
    fn shape(&self, sess: &S, x: &T) -> ShapeT;
}
// Kernel table for `PlacementShape::shape`, backed by `ShapeOp` (declared
// here without attributes).
// Each entry reads `(placement, (input) -> shape => [tag] kernel)`. The
// output shape type follows the placement: `Shape` for logical `Tensor`,
// `HostShape` on host, `ReplicatedShape` on replicated and `AdditiveShape` on
// additive placements. The tag (`[concrete]`, `[runtime]`, `[hybrid]`) is
// interpreted by `modelled_kernel!`, so consult that macro for its meaning.
// NOTE(review): host integer tensors (`HostInt*`/`HostUint*`) have no shape
// kernel here, although `PlacementReshape` lists them — confirm whether
// that is intentional.
modelled_kernel! {
PlacementShape::shape, ShapeOp,
[
// Host placement, logical/float/fixed-point tensor types.
(HostPlacement, (Tensor) -> Shape => [concrete] Self::host_logical_kernel),
(HostPlacement, (Float32Tensor) -> HostShape => [hybrid] Self::float_kernel),
(HostPlacement, (Float64Tensor) -> HostShape => [hybrid] Self::float_kernel),
(HostPlacement, (Fixed64Tensor) -> HostShape => [hybrid] Self::host_fixed_kernel),
(HostPlacement, (Fixed128Tensor) -> HostShape => [hybrid] Self::host_fixed_kernel),
(HostPlacement, (HostFixed64Tensor) -> HostShape => [hybrid] Self::host_hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor) -> HostShape => [hybrid] Self::host_hostfixed_kernel),
// Host placement, ring/bit/float `Host*` tensors with `[runtime]` kernels.
(HostPlacement, (HostRing64Tensor) -> HostShape => [runtime] Self::ring_kernel),
(HostPlacement, (HostRing128Tensor) -> HostShape => [runtime] Self::ring_kernel),
(HostPlacement, (HostBitTensor) -> HostShape => [runtime] Self::bit_kernel),
(HostPlacement, (HostFloat32Tensor) -> HostShape => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostShape => [runtime] Self::host_kernel),
// Replicated placement.
(ReplicatedPlacement, (Tensor) -> Shape => [concrete] Self::rep_logical_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> ReplicatedShape => [hybrid] Self::rep_fixed_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> ReplicatedShape => [hybrid] Self::rep_fixed_kernel),
(ReplicatedPlacement, (ReplicatedBitTensor) -> ReplicatedShape => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedShape => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedShape => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedShape => [hybrid] Self::rep_repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedShape => [hybrid] Self::rep_repfixed_kernel),
// Additive placement.
(AdditivePlacement, (AdditiveRing64Tensor) -> AdditiveShape => [concrete] Self::adt_kernel),
(AdditivePlacement, (AdditiveRing128Tensor) -> AdditiveShape => [concrete] Self::adt_kernel),
]
}
/// Placement-dispatched `reshape` operation.
///
/// Concrete implementations are produced by the `modelled_kernel!`
/// invocation for `ReshapeOp` in this file. The target shape is passed as an
/// operand rather than as an op attribute.
///
/// Type parameters: `S` is the session type (bounded by `Session`), `T` the
/// operand type, `ShapeT` the target-shape type and `O` the result type.
pub trait PlacementReshape<S, T, ShapeT, O>
where
    S: Session,
{
    /// Reshapes `x` to `shape` on this placement within session `sess`.
    fn reshape(&self, sess: &S, x: &T, shape: &ShapeT) -> O;
}
// Kernel table for `PlacementReshape::reshape`, backed by `ReshapeOp`
// (declared here without attributes; the target shape is a second operand).
// Each entry reads `(placement, (input, shape) -> output => [tag] kernel)`.
// The shape operand type follows the placement: `Shape` for logical `Tensor`,
// `HostShape` on host and `ReplicatedShape` on replicated placements. The tag
// is interpreted by `modelled_kernel!`, so consult that macro for its meaning.
modelled_kernel! {
PlacementReshape::reshape, ReshapeOp,
[
// Host placement, logical/float/fixed-point tensor types.
(HostPlacement, (Tensor, Shape) -> Tensor => [concrete] Self::host_logical_kernel),
(HostPlacement, (Float32Tensor, HostShape) -> Float32Tensor => [hybrid] Self::float_host_kernel),
(HostPlacement, (Float64Tensor, HostShape) -> Float64Tensor => [hybrid] Self::float_host_kernel),
(HostPlacement, (Fixed64Tensor, HostShape) -> Fixed64Tensor => [hybrid] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor, HostShape) -> Fixed128Tensor => [hybrid] Self::fixed_host_kernel),
(HostPlacement, (HostFixed64Tensor, HostShape) -> HostFixed64Tensor => [hybrid] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor, HostShape) -> HostFixed128Tensor => [hybrid] Self::hostfixed_kernel),
// Host placement, ring/bit and dense numeric `Host*` tensors.
(HostPlacement, (HostRing64Tensor, HostShape) -> HostRing64Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, (HostRing128Tensor, HostShape) -> HostRing128Tensor => [runtime] Self::host_ring_kernel),
(HostPlacement, (HostBitTensor, HostShape) -> HostBitTensor => [runtime] Self::host_bit_kernel),
(HostPlacement, (HostFloat32Tensor, HostShape) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor, HostShape) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt8Tensor, HostShape) -> HostInt8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt16Tensor, HostShape) -> HostInt16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt32Tensor, HostShape) -> HostInt32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostInt64Tensor, HostShape) -> HostInt64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostUint8Tensor, HostShape) -> HostUint8Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostUint16Tensor, HostShape) -> HostUint16Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostUint32Tensor, HostShape) -> HostUint32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostUint64Tensor, HostShape) -> HostUint64Tensor => [runtime] Self::host_kernel),
// Replicated placement.
(ReplicatedPlacement, (Tensor, Shape) -> Tensor => [concrete] Self::rep_logical_kernel),
(ReplicatedPlacement, (Fixed64Tensor, ReplicatedShape) -> Fixed64Tensor => [hybrid] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor, ReplicatedShape) -> Fixed128Tensor => [hybrid] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor, ReplicatedShape) -> ReplicatedFixed64Tensor => [hybrid] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor, ReplicatedShape) -> ReplicatedFixed128Tensor => [hybrid] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor, ReplicatedShape) -> ReplicatedRing64Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor, ReplicatedShape) -> ReplicatedRing128Tensor => [concrete] Self::rep_kernel),
]
}