use std::num::NonZeroUsize;
#[cfg(feature = "flatbuffers")]
use flatbuffers::{FlatBufferBuilder, WIPOffset};
use rand::RngCore;
use thiserror::Error;
use crate::alloc::{Allocator, AllocatorError, ScopedAllocator, TryClone};
#[cfg(feature = "flatbuffers")]
use crate::flatbuffers as fb;
mod double_hadamard;
mod null;
mod padding_hadamard;
crate::utils::features! {
#![feature = "linalg"]
mod random_rotation;
}
mod utils;
#[cfg(test)]
#[cfg(not(miri))]
mod test_utils;
pub use double_hadamard::{DoubleHadamard, DoubleHadamardError};
pub use null::NullTransform;
pub use padding_hadamard::{PaddingHadamard, PaddingHadamardError};
pub use utils::TransformFailed;
crate::utils::features! {
#![feature = "linalg"]
pub use random_rotation::RandomRotation;
}
crate::utils::features! {
#![all(feature = "linalg", feature = "flatbuffers")]
pub use random_rotation::RandomRotationError;
}
crate::utils::features! {
#![feature = "flatbuffers"]
pub use null::NullTransformError;
}
/// Selects which concrete transform [`Transform::new`] constructs.
///
/// Every variant except [`TransformKind::Null`] carries the [`TargetDim`] that
/// [`Transform::new`] forwards to the corresponding constructor.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum TransformKind {
/// Build a [`PaddingHadamard`]. [`Transform::new`] requires an RNG for this kind.
PaddingHadamard { target_dim: TargetDim },
/// Build a [`DoubleHadamard`]. [`Transform::new`] requires an RNG for this kind.
DoubleHadamard { target_dim: TargetDim },
/// Build a [`NullTransform`]. [`Transform::new`] uses neither the RNG nor the
/// allocator for this kind.
Null,
/// Build a `RandomRotation`. [`Transform::new`] requires an RNG for this kind.
///
/// Only present when the `linalg` feature is enabled.
#[cfg(feature = "linalg")]
#[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
RandomRotation { target_dim: TargetDim },
}
/// Errors returned by [`Transform::new`].
#[derive(Debug, Clone, Error)]
pub enum NewTransformError {
/// No random number generator was supplied although the requested
/// [`TransformKind`] (carried in the payload) needs one.
#[error("random number generator is required for {0:?}")]
RngMissing(TransformKind),
/// An [`AllocatorError`], forwarded unchanged (its message is used as-is).
///
/// NOTE(review): presumably raised by the Hadamard constructors when
/// allocation fails -- confirm against their implementations.
#[error(transparent)]
AllocatorError(#[from] AllocatorError),
}
/// A transform of one of the supported kinds; every method dispatches to the
/// wrapped concrete transform.
///
/// `A` is the allocator type parameter of the two Hadamard variants; the
/// remaining variants do not mention it.
///
/// `PartialEq` is derived for test builds only.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub enum Transform<A>
where
A: Allocator,
{
/// Wraps a [`PaddingHadamard`] transform.
PaddingHadamard(PaddingHadamard<A>),
/// Wraps a [`DoubleHadamard`] transform.
DoubleHadamard(DoubleHadamard<A>),
/// Wraps a [`NullTransform`].
Null(NullTransform),
/// Wraps a `RandomRotation`; only present when the `linalg` feature is enabled.
#[cfg(feature = "linalg")]
#[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
RandomRotation(RandomRotation),
}
impl<A> Transform<A>
where
    A: Allocator,
{
    /// Builds the transform selected by `transform_kind` for inputs of dimension `dim`.
    ///
    /// `rng` is handed to the constructor of every kind except
    /// [`TransformKind::Null`]. `allocator` is handed to the two Hadamard
    /// constructors only; for the other kinds it is dropped unused.
    ///
    /// # Errors
    ///
    /// * [`NewTransformError::RngMissing`] when `rng` is `None` and the
    ///   requested kind is not [`TransformKind::Null`].
    /// * Any error reported by the Hadamard constructors, converted into
    ///   [`NewTransformError`].
    pub fn new(
        transform_kind: TransformKind,
        dim: NonZeroUsize,
        rng: Option<&mut dyn RngCore>,
        allocator: A,
    ) -> Result<Self, NewTransformError> {
        // Built lazily: only the arms that need an RNG ever construct this error.
        let missing_rng = || NewTransformError::RngMissing(transform_kind);

        let built = match transform_kind {
            TransformKind::Null => Self::Null(NullTransform::new(dim)),
            TransformKind::PaddingHadamard { target_dim } => {
                let rng = rng.ok_or_else(missing_rng)?;
                Self::PaddingHadamard(PaddingHadamard::new(dim, target_dim, rng, allocator)?)
            }
            TransformKind::DoubleHadamard { target_dim } => {
                let rng = rng.ok_or_else(missing_rng)?;
                Self::DoubleHadamard(DoubleHadamard::new(dim, target_dim, rng, allocator)?)
            }
            #[cfg(feature = "linalg")]
            TransformKind::RandomRotation { target_dim } => {
                let rng = rng.ok_or_else(missing_rng)?;
                Self::RandomRotation(RandomRotation::new(dim, target_dim, rng))
            }
        };

        Ok(built)
    }

    /// Input dimension reported by the wrapped transform (`dim()` for the
    /// null transform).
    pub(crate) fn input_dim(&self) -> usize {
        match self {
            Self::Null(inner) => inner.dim(),
            Self::PaddingHadamard(inner) => inner.input_dim(),
            Self::DoubleHadamard(inner) => inner.input_dim(),
            #[cfg(feature = "linalg")]
            Self::RandomRotation(inner) => inner.input_dim(),
        }
    }

    /// Output dimension reported by the wrapped transform (`dim()` for the
    /// null transform, i.e. the same value as [`Self::input_dim`]).
    pub(crate) fn output_dim(&self) -> usize {
        match self {
            Self::Null(inner) => inner.dim(),
            Self::PaddingHadamard(inner) => inner.output_dim(),
            Self::DoubleHadamard(inner) => inner.output_dim(),
            #[cfg(feature = "linalg")]
            Self::RandomRotation(inner) => inner.output_dim(),
        }
    }

    /// Forwards the `preserves_norms` answer of the wrapped transform.
    pub(crate) fn preserves_norms(&self) -> bool {
        match self {
            Self::Null(inner) => inner.preserves_norms(),
            Self::PaddingHadamard(inner) => inner.preserves_norms(),
            Self::DoubleHadamard(inner) => inner.preserves_norms(),
            #[cfg(feature = "linalg")]
            Self::RandomRotation(inner) => inner.preserves_norms(),
        }
    }

    /// Runs the wrapped transform's `transform_into` with `dst` and `src`.
    ///
    /// `allocator` is forwarded to the Hadamard variants only; the null and
    /// random-rotation variants do not receive it.
    ///
    /// # Errors
    ///
    /// Returns whatever [`TransformFailed`] the wrapped transform reports.
    pub(crate) fn transform_into(
        &self,
        dst: &mut [f32],
        src: &[f32],
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), TransformFailed> {
        match self {
            Self::Null(inner) => inner.transform_into(dst, src),
            Self::PaddingHadamard(inner) => inner.transform_into(dst, src, allocator),
            Self::DoubleHadamard(inner) => inner.transform_into(dst, src, allocator),
            #[cfg(feature = "linalg")]
            Self::RandomRotation(inner) => inner.transform_into(dst, src),
        }
    }
}
impl<A> TryClone for Transform<A>
where
    A: Allocator,
{
    /// Clones the wrapped transform and re-wraps it in the same variant.
    ///
    /// The Hadamard variants go through their own `try_clone` and may fail;
    /// the remaining variants use plain `clone`.
    ///
    /// # Errors
    ///
    /// Propagates the error of the wrapped Hadamard transform's `try_clone`.
    fn try_clone(&self) -> Result<Self, AllocatorError> {
        let duplicate = match self {
            Self::Null(inner) => Self::Null(inner.clone()),
            Self::PaddingHadamard(inner) => Self::PaddingHadamard(inner.try_clone()?),
            Self::DoubleHadamard(inner) => Self::DoubleHadamard(inner.try_clone()?),
            #[cfg(feature = "linalg")]
            Self::RandomRotation(inner) => Self::RandomRotation(inner.clone()),
        };
        Ok(duplicate)
    }
}
/// Errors returned by `Transform::try_unpack` when rebuilding a transform from
/// its flatbuffer representation.
///
/// Each wrapping variant forwards the error of the matching concrete
/// transform's `try_unpack` unchanged.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum TransformError {
/// Unpacking a [`PaddingHadamard`] payload failed.
#[error(transparent)]
PaddingHadamardError(#[from] PaddingHadamardError),
/// Unpacking a [`DoubleHadamard`] payload failed.
#[error(transparent)]
DoubleHadamardError(#[from] DoubleHadamardError),
/// Unpacking a [`NullTransform`] payload failed.
#[error(transparent)]
NullTransformError(#[from] NullTransformError),
/// Unpacking a `RandomRotation` payload failed; only present when the
/// `linalg` feature is enabled.
#[cfg(feature = "linalg")]
#[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
#[error(transparent)]
RandomRotationError(#[from] RandomRotationError),
/// The serialized union matched none of the transform types this build
/// checks for.
#[error("invalid transform kind")]
InvalidTransformKind,
}
#[cfg(feature = "flatbuffers")]
impl<A> Transform<A>
where
    A: Allocator,
{
    /// Serializes this transform into `buf` as a `fb::transforms::Transform`
    /// table and returns the offset of that table.
    ///
    /// The wrapped transform is packed first; its offset is stored in the
    /// union field together with the matching `TransformKind` tag.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::Transform<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        let (transform_type, payload) = match self {
            Self::Null(inner) => (
                fb::transforms::TransformKind::NullTransform,
                inner.pack(buf).as_union_value(),
            ),
            Self::PaddingHadamard(inner) => (
                fb::transforms::TransformKind::PaddingHadamard,
                inner.pack(buf).as_union_value(),
            ),
            Self::DoubleHadamard(inner) => (
                fb::transforms::TransformKind::DoubleHadamard,
                inner.pack(buf).as_union_value(),
            ),
            #[cfg(feature = "linalg")]
            Self::RandomRotation(inner) => (
                fb::transforms::TransformKind::RandomRotation,
                inner.pack(buf).as_union_value(),
            ),
        };

        let args = fb::transforms::TransformArgs {
            transform_type,
            transform: Some(payload),
        };
        fb::transforms::Transform::create(buf, &args)
    }

    /// Rebuilds a transform from its serialized form `proto`.
    ///
    /// The union accessors are probed one after another; the first one that
    /// yields a payload decides which concrete `try_unpack` runs. `alloc` is
    /// handed to the Hadamard unpackers only.
    ///
    /// # Errors
    ///
    /// * The concrete transform's unpack error, converted into [`TransformError`].
    /// * [`TransformError::InvalidTransformKind`] when no accessor yields a
    ///   payload (which includes a random-rotation payload in a build without
    ///   the `linalg` feature, since that accessor is not probed then).
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::Transform<'_>,
    ) -> Result<Self, TransformError> {
        if let Some(packed) = proto.transform_as_padding_hadamard() {
            let inner = PaddingHadamard::try_unpack(alloc, packed)?;
            return Ok(Self::PaddingHadamard(inner));
        }

        #[cfg(feature = "linalg")]
        if let Some(packed) = proto.transform_as_random_rotation() {
            let inner = RandomRotation::try_unpack(packed)?;
            return Ok(Self::RandomRotation(inner));
        }

        if let Some(packed) = proto.transform_as_double_hadamard() {
            let inner = DoubleHadamard::try_unpack(alloc, packed)?;
            return Ok(Self::DoubleHadamard(inner));
        }

        if let Some(packed) = proto.transform_as_null_transform() {
            let inner = NullTransform::try_unpack(packed)?;
            return Ok(Self::Null(inner));
        }

        Err(TransformError::InvalidTransformKind)
    }
}
/// Output-dimension request carried by [`TransformKind`] and forwarded by
/// [`Transform::new`] to the concrete transform constructors.
///
/// NOTE(review): the exact meaning of each variant is decided by those
/// constructors, which are not defined in this module; the per-variant notes
/// below are assumptions to confirm against them.
#[derive(Debug, Clone, Copy)]
pub enum TargetDim {
/// Presumably: keep the output dimension equal to the input dimension -- TODO confirm.
Same,
/// Presumably: let the transform pick its own preferred output dimension -- TODO confirm.
Natural,
/// Presumably: request exactly this (non-zero) output dimension -- TODO confirm.
Override(NonZeroUsize),
}
// Test builds only, and skipped under Miri: expands
// `test_utils::delegate_transformer!` for `Transform<GlobalAllocator>`.
// NOTE(review): what the macro generates is defined in `test_utils`; confirm
// there which items this invocation provides.
#[cfg(test)]
#[cfg(not(miri))]
test_utils::delegate_transformer!(Transform<crate::alloc::GlobalAllocator>);