moose 0.2.2

Encrypted learning and data processing framework
Documentation
//! Abstraction layer for high-level logical tensors

use crate::computation::{HasShortName, Placed, Placement};
#[cfg(feature = "compile")]
use crate::computation::{PartiallySymbolicType, SymbolicType};
use crate::error::Result;
#[cfg(feature = "compile")]
use crate::execution::symbolic::Symbolic;
use crate::types::*;
use derive_more::Display;
use serde::{Deserialize, Serialize};
#[cfg(feature = "compile")]
use std::convert::TryFrom;

mod ops;

#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Copy, Clone, Debug, Display)]
pub enum TensorDType {
    #[display(fmt = "Fixed64({}, {})", integral_precision, fractional_precision)]
    Fixed64 {
        integral_precision: u32,
        fractional_precision: u32,
    },
    #[display(fmt = "Fixed128({}, {})", integral_precision, fractional_precision)]
    Fixed128 {
        integral_precision: u32,
        fractional_precision: u32,
    },
    Float32,
    Float64,
    Bool,
    Uint64,
    Unknown,
}

#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Copy, Clone, Debug, Display)]
pub enum TensorShape {
    Host,
    Replicated,
    Additive,
    Mirrored,
    Unknown,
}

impl HasShortName for TensorDType {
    fn short_name(&self) -> &str {
        match self {
            TensorDType::Fixed64 { .. } => "Fixed64",
            TensorDType::Fixed128 { .. } => "Fixed128",
            TensorDType::Float32 => "Float32",
            TensorDType::Float64 => "Float64",
            TensorDType::Bool => "Bool",
            TensorDType::Uint64 => "Uint64",
            TensorDType::Unknown => "Unknown",
        }
    }
}

impl HasShortName for TensorShape {
    fn short_name(&self) -> &str {
        match self {
            TensorShape::Host => "Host",
            TensorShape::Replicated => "Replicated",
            TensorShape::Additive => "Additive",
            TensorShape::Mirrored => "Mirrored",
            TensorShape::Unknown => "Unknown",
        }
    }
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub enum AbstractTensor<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T> {
    Fixed64(Fixed64T),
    Fixed128(Fixed128T),
    Float32(Float32T),
    Float64(Float64T),
    Bool(BoolT),
    Uint64(Uint64T),
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub enum AbstractShape<HostS, RepS> {
    Host(HostS),
    Replicated(RepS),
}

impl<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>
    AbstractTensor<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>
{
    pub(crate) fn ty_desc(&self) -> String {
        match self {
            AbstractTensor::Fixed64(_) => "Tensor(Fixed64)",
            AbstractTensor::Fixed128(_) => "Tensor(Fixed128)",
            AbstractTensor::Float32(_) => "Tensor(Float32)",
            AbstractTensor::Float64(_) => "Tensor(Float64)",
            AbstractTensor::Bool(_) => "Tensor(Bool)",
            AbstractTensor::Uint64(_) => "Tensor(Uint64T)",
        }
        .to_string()
    }
}

impl<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T> Placed
    for AbstractTensor<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>
where
    Fixed64T: Placed,
    Fixed64T::Placement: Into<Placement>,
    Fixed128T: Placed,
    Fixed128T::Placement: Into<Placement>,
    Float32T: Placed,
    Float32T::Placement: Into<Placement>,
    Float64T: Placed,
    Float64T::Placement: Into<Placement>,
    BoolT: Placed,
    BoolT::Placement: Into<Placement>,
    Uint64T: Placed,
    Uint64T::Placement: Into<Placement>,
{
    type Placement = Placement;

    fn placement(&self) -> Result<Self::Placement> {
        match self {
            AbstractTensor::Fixed64(x) => Ok(x.placement()?.into()),
            AbstractTensor::Fixed128(x) => Ok(x.placement()?.into()),
            AbstractTensor::Float32(x) => Ok(x.placement()?.into()),
            AbstractTensor::Float64(x) => Ok(x.placement()?.into()),
            AbstractTensor::Bool(x) => Ok(x.placement()?.into()),
            AbstractTensor::Uint64(x) => Ok(x.placement()?.into()),
        }
    }
}

#[cfg(feature = "compile")]
impl PartiallySymbolicType for Tensor {
    #[allow(clippy::type_complexity)]
    type Type = AbstractTensor<
        <Fixed64Tensor as SymbolicType>::Type,
        <Fixed128Tensor as SymbolicType>::Type,
        <Float32Tensor as SymbolicType>::Type,
        <Float64Tensor as SymbolicType>::Type,
        <BooleanTensor as SymbolicType>::Type,
        <Uint64Tensor as SymbolicType>::Type,
    >;
}

#[cfg(feature = "compile")]
impl<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>
    From<AbstractTensor<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>>
    for Symbolic<AbstractTensor<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>>
where
    Fixed64T: Placed<Placement = Placement>,
    Fixed128T: Placed<Placement = Placement>,
    Float32T: Placed<Placement = Placement>,
    Float64T: Placed<Placement = Placement>,
    BoolT: Placed<Placement = Placement>,
    Uint64T: Placed<Placement = Placement>,
{
    fn from(x: AbstractTensor<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>) -> Self {
        Symbolic::Concrete(x)
    }
}

#[cfg(feature = "compile")]
impl<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>
    TryFrom<Symbolic<AbstractTensor<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>>>
    for AbstractTensor<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>
where
    Fixed64T: Placed<Placement = Placement>,
    Fixed128T: Placed<Placement = Placement>,
    Float32T: Placed<Placement = Placement>,
    Float64T: Placed<Placement = Placement>,
    BoolT: Placed<Placement = Placement>,
    Uint64T: Placed<Placement = Placement>,
{
    type Error = ();
    fn try_from(
        v: Symbolic<AbstractTensor<Fixed64T, Fixed128T, Float32T, Float64T, BoolT, Uint64T>>,
    ) -> std::result::Result<Self, ()> {
        match v {
            Symbolic::Concrete(x) => Ok(x),
            _ => Err(()),
        }
    }
}

#[cfg(feature = "compile")]
impl PartiallySymbolicType for Shape {
    type Type =
        AbstractShape<<HostShape as SymbolicType>::Type, <ReplicatedShape as SymbolicType>::Type>;
}

impl<HostS, RepS> Placed for AbstractShape<HostS, RepS>
where
    HostS: Placed,
    HostS::Placement: Into<Placement>,
    RepS: Placed,
    RepS::Placement: Into<Placement>,
{
    type Placement = Placement;

    fn placement(&self) -> Result<Self::Placement> {
        match self {
            AbstractShape::Host(sh) => Ok(sh.placement()?.into()),
            AbstractShape::Replicated(sh) => Ok(sh.placement()?.into()),
        }
    }
}

#[cfg(feature = "compile")]
impl<HostS, RepS> From<AbstractShape<HostS, RepS>> for Symbolic<AbstractShape<HostS, RepS>>
where
    HostS: Placed<Placement = Placement>,
    RepS: Placed<Placement = Placement>,
{
    fn from(x: AbstractShape<HostS, RepS>) -> Self {
        Symbolic::Concrete(x)
    }
}

#[cfg(feature = "compile")]
impl<HostS, RepS> TryFrom<Symbolic<AbstractShape<HostS, RepS>>> for AbstractShape<HostS, RepS>
where
    HostS: Placed<Placement = Placement>,
    RepS: Placed<Placement = Placement>,
{
    type Error = ();
    fn try_from(v: Symbolic<AbstractShape<HostS, RepS>>) -> std::result::Result<Self, ()> {
        match v {
            Symbolic::Concrete(x) => Ok(x),
            _ => Err(()),
        }
    }
}