use crate::crypto::UnsignedTorus;
use crate::math::tensor::{AsMutTensor, AsRefTensor};
use crate::math::torus::{FromTorus, IntoTorus};
use crate::numeric::{FloatingPoint, Numeric};
use super::{Cleartext, CleartextList, Plaintext, PlaintextList};
/// Conversion between raw cleartext values and their encoded plaintext form.
///
/// `Enc` is the scalar type stored in the encoded [`Plaintext`]; the scalar
/// type of the raw [`Cleartext`] is chosen by the implementor through
/// [`Encoder::Raw`].
pub trait Encoder<Enc: Numeric> {
    /// Scalar type of the raw (not yet encoded) values.
    type Raw: Numeric;

    /// Encodes a single raw cleartext into a plaintext.
    fn encode(&self, raw: Cleartext<Self::Raw>) -> Plaintext<Enc>;

    /// Decodes a single plaintext back into a raw cleartext.
    fn decode(&self, encoded: Plaintext<Enc>) -> Cleartext<Self::Raw>;

    /// List counterpart of [`Encoder::encode`].
    ///
    /// `raw` is the read-only input list and `encoded` is the mutable output
    /// list.
    fn encode_list<RawCont, EncCont>(
        &self,
        encoded: &mut PlaintextList<EncCont>,
        raw: &CleartextList<RawCont>,
    ) where
        PlaintextList<EncCont>: AsMutTensor<Element = Enc>,
        CleartextList<RawCont>: AsRefTensor<Element = Self::Raw>;

    /// List counterpart of [`Encoder::decode`].
    ///
    /// `encoded` is the read-only input list and `raw` is the mutable output
    /// list.
    fn decode_list<RawCont, EncCont>(
        &self,
        raw: &mut CleartextList<RawCont>,
        encoded: &PlaintextList<EncCont>,
    ) where
        PlaintextList<EncCont>: AsRefTensor<Element = Enc>,
        CleartextList<RawCont>: AsMutTensor<Element = Self::Raw>;
}
/// Encoder for floating-point values based on an affine rescaling.
///
/// Encoding feeds `(raw - offset) / delta` to `FromTorus::from_torus`;
/// decoding takes the value produced by `IntoTorus::into_torus`, multiplies
/// it by `delta` and then adds `offset`.
pub struct RealEncoder<T>
where
    T: FloatingPoint,
{
    /// Value subtracted from the raw input before scaling on encode, and
    /// added back last on decode.
    pub offset: T,
    /// Scale factor: the divisor on encode and the multiplier on decode.
    pub delta: T,
}
impl<RawScalar, EncScalar> Encoder<EncScalar> for RealEncoder<RawScalar>
where
EncScalar: UnsignedTorus + FromTorus<RawScalar> + IntoTorus<RawScalar>,
RawScalar: FloatingPoint,
{
type Raw = RawScalar;
fn encode(&self, raw: Cleartext<RawScalar>) -> Plaintext<EncScalar> {
Plaintext(<EncScalar as FromTorus<RawScalar>>::from_torus(
(raw.0 - self.offset) / self.delta,
))
}
fn decode(&self, encoded: Plaintext<EncScalar>) -> Cleartext<RawScalar> {
let mut e: RawScalar = encoded.0.into_torus();
e *= self.delta;
e += self.offset;
Cleartext(e)
}
fn encode_list<RawCont, EncCont>(
&self,
encoded: &mut PlaintextList<EncCont>,
raw: &CleartextList<RawCont>,
) where
CleartextList<RawCont>: AsRefTensor<Element = RawScalar>,
PlaintextList<EncCont>: AsMutTensor<Element = EncScalar>,
{
encoded
.as_mut_tensor()
.fill_with_one(raw.as_tensor(), |r| self.encode(Cleartext(*r)).0);
}
fn decode_list<RawCont, EncCont>(
&self,
raw: &mut CleartextList<RawCont>,
encoded: &PlaintextList<EncCont>,
) where
CleartextList<RawCont>: AsMutTensor<Element = RawScalar>,
PlaintextList<EncCont>: AsRefTensor<Element = EncScalar>,
{
raw.as_mut_tensor()
.fill_with_one(encoded.as_tensor(), |e| self.decode(Plaintext(*e)).0);
}
}