use rten_tensor::prelude::*;
use rten_tensor::storage::Alloc;
use rten_tensor::{CowTensor, TensorView};
/// Conversion between signed and unsigned 8-bit representations that shifts
/// the value range by 128.
///
/// For scalars, `i8 -> u8` maps `-128 -> 0`, `0 -> 128` and `127 -> 255`, and
/// `u8 -> i8` is the exact inverse. Converting a type to itself is the
/// identity. Containers (tensor views, `Vec`s) apply the scalar conversion
/// element-wise.
pub trait ShiftCast<T> {
    /// Convert `self` into a `T`.
    fn shift_cast(self) -> T;

    /// Like [`shift_cast`](ShiftCast::shift_cast), but with an allocator that
    /// an implementation may use for any buffer the result needs.
    ///
    /// The provided default ignores `_alloc` and simply forwards to
    /// `shift_cast`. Implementations that allocate can override it.
    fn shift_cast_in(self, _alloc: impl Alloc) -> T
    where
        Self: Sized,
    {
        <Self as ShiftCast<T>>::shift_cast(self)
    }
}
// Implements `ShiftCast<T> for T` as the identity conversion, so generic code
// can call `shift_cast` even when source and target types already agree.
macro_rules! impl_noop_cast {
    ($t:ty) => {
        impl ShiftCast<$t> for $t {
            fn shift_cast(self) -> $t {
                // Same type on both sides: nothing to convert.
                self
            }
        }
    };
}
// `i8 -> i8` is the identity conversion.
impl_noop_cast! { i8 }
/// Shift an `i8` into the `u8` range by adding 128: `-128 -> 0`, `0 -> 128`,
/// `127 -> 255`.
impl ShiftCast<u8> for i8 {
    fn shift_cast(self) -> u8 {
        // Reinterpret the two's-complement bits, then flip the top bit. Adding
        // 0x80 modulo 256 only toggles bit 7, so this is the +128 offset
        // without any possibility of overflow.
        let bits = self as u8;
        bits ^ 0x80
    }
}
// `u8 -> u8` is the identity conversion.
impl_noop_cast! { u8 }
/// Shift a `u8` into the `i8` range by subtracting 128: `0 -> -128`,
/// `128 -> 0`, `255 -> 127`. This is the inverse of the `i8 -> u8` cast.
impl ShiftCast<i8> for u8 {
    fn shift_cast(self) -> i8 {
        // Flip the top bit first, then reinterpret the bits as two's
        // complement. Equivalent to subtracting 128 with no overflow.
        let flipped = self ^ 0x80;
        flipped as i8
    }
}
/// Same-element-type "shift cast" for tensor views: no element conversion is
/// performed. The view is returned as a `CowTensor` via `as_cow`.
///
/// NOTE(review): presumably this yields the borrowed variant of `CowTensor`
/// (no copy of the data) — confirm against rten-tensor's `as_cow`.
/// `shift_cast_in` is not overridden, so the trait's default (which ignores
/// the allocator) applies.
impl<'a, T> ShiftCast<CowTensor<'a, T>> for TensorView<'a, T> {
fn shift_cast(self) -> CowTensor<'a, T> {
self.as_cow()
}
}
/// Element-wise `i8 -> u8` shift cast of a tensor view.
///
/// Every element is converted with the scalar `i8 -> u8` shift cast. The
/// mapped tensor is returned via `into_cow`.
impl<'a> ShiftCast<CowTensor<'a, u8>> for TensorView<'a, i8> {
    fn shift_cast(self) -> CowTensor<'a, u8> {
        let shifted = self.map(|val: &i8| -> u8 { (*val).shift_cast() });
        shifted.into_cow()
    }

    /// Same as `shift_cast`, but the output buffer is obtained from `alloc`.
    fn shift_cast_in(self, alloc: impl Alloc) -> CowTensor<'a, u8> {
        let shifted = self.map_in(alloc, |val: &i8| -> u8 { (*val).shift_cast() });
        shifted.into_cow()
    }
}
/// Element-wise `u8 -> i8` shift cast of a tensor view (the inverse of the
/// `i8 -> u8` tensor cast).
///
/// Every element is converted with the scalar `u8 -> i8` shift cast. The
/// mapped tensor is returned via `into_cow`.
impl<'a> ShiftCast<CowTensor<'a, i8>> for TensorView<'a, u8> {
    fn shift_cast(self) -> CowTensor<'a, i8> {
        let shifted = self.map(|val: &u8| -> i8 { (*val).shift_cast() });
        shifted.into_cow()
    }

    /// Same as `shift_cast`, but the output buffer is obtained from `alloc`.
    fn shift_cast_in(self, alloc: impl Alloc) -> CowTensor<'a, i8> {
        let shifted = self.map_in(alloc, |val: &u8| -> i8 { (*val).shift_cast() });
        shifted.into_cow()
    }
}
/// Element-wise shift cast of a `Vec`, consuming the input.
///
/// Any element conversion `T -> U` with a `ShiftCast` impl is supported,
/// including the identity casts. `shift_cast_in` keeps the trait default,
/// so the allocator is ignored.
impl<T, U> ShiftCast<Vec<U>> for Vec<T>
where
    T: ShiftCast<U>,
{
    fn shift_cast(self) -> Vec<U> {
        // Consuming iterator + collect keeps the form that std can turn into
        // an in-place collection when `T` and `U` share a layout.
        self.into_iter()
            .map(<T as ShiftCast<U>>::shift_cast)
            .collect::<Vec<U>>()
    }
}
#[cfg(test)]
mod tests {
    use rten_tensor::Tensor;
    use rten_tensor::prelude::*;

    use super::{CowTensor, ShiftCast};

    #[test]
    fn test_shift_cast_scalar() {
        const LEN: usize = 5;
        let input = [-128i8, -1, 0, 1, 127];
        let expected = [0u8, 127, 128, 129, 255];

        let actual: [u8; LEN] = input.map(|x| x.shift_cast());
        assert_eq!(actual, expected);

        let actual_noop: [u8; LEN] = actual.map(|x| x.shift_cast());
        assert_eq!(actual_noop, expected);

        let actual_inverse: [i8; LEN] = expected.map(|x| x.shift_cast());
        assert_eq!(actual_inverse, input);

        let actual_inverse_noop: [i8; LEN] = input.map(|x| x.shift_cast());
        assert_eq!(actual_inverse_noop, input);
    }

    /// Exhaustively checks every `i8` value: the forward cast is a +128
    /// offset (so ordering is preserved), and the inverse cast restores the
    /// original value.
    #[test]
    fn test_shift_cast_scalar_roundtrip() {
        for x in i8::MIN..=i8::MAX {
            let shifted: u8 = x.shift_cast();
            assert_eq!(i16::from(shifted), i16::from(x) + 128);

            let restored: i8 = shifted.shift_cast();
            assert_eq!(restored, x);
        }
    }

    #[test]
    fn test_shift_cast_tensor() {
        let input = Tensor::from([-128i8, -1, 0, 1, 127]);
        let expected = Tensor::from([0u8, 127, 128, 129, 255]);

        let actual: CowTensor<u8> = input.view().shift_cast();
        assert_eq!(actual, expected);

        let noop_cast: CowTensor<u8> = actual.view().shift_cast();
        assert_eq!(noop_cast, actual);

        let actual_inverse: CowTensor<i8> = expected.view().shift_cast();
        assert_eq!(actual_inverse, input);

        // Same-type cast on the signed side should also be the identity.
        let inverse_noop: CowTensor<i8> = input.view().shift_cast();
        assert_eq!(inverse_noop, input);
    }

    #[test]
    fn test_shift_cast_vec() {
        let input: Vec<_> = [-128i8, -1, 0, 1, 127].into();
        let expected = [0u8, 127, 128, 129, 255];

        // `shift_cast` consumes the `Vec`, so clone to keep `input` for the
        // round-trip check below.
        let actual: Vec<u8> = input.clone().shift_cast();
        assert_eq!(actual, expected);

        let noop: Vec<u8> = actual.clone().shift_cast();
        assert_eq!(noop, actual);

        let inverse: Vec<i8> = actual.shift_cast();
        assert_eq!(inverse, input);

        let empty: Vec<u8> = Vec::<i8>::new().shift_cast();
        assert!(empty.is_empty());
    }
}