extern crate rand;
extern crate tensorism_gen;
use rand::Rng;
use std::{alloc::Layout, fmt::Display, ops::Index, slice};
pub use tensorism_gen::new_ndarray;
pub struct Reindexing1 {
input0_bound: usize,
output_bound: usize,
ptr: *mut usize,
}
impl Reindexing1 {
#[doc(hidden)]
pub fn new(
input0_bound: usize,
output_bound: usize,
mut f: impl FnMut(usize) -> usize,
) -> Self {
if input0_bound == 0 {
return Self {
input0_bound: 0,
output_bound,
ptr: std::ptr::null_mut(),
};
}
if output_bound == 0 {
panic!("Should not be 0")
}
let layout = Layout::array::<usize>(input0_bound).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
for i in 0..input0_bound {
ptr.add(i).write(f(i) % output_bound);
}
}
Self {
input0_bound,
output_bound,
ptr,
}
}
pub fn shuffle(bound: usize, rand: &mut impl Rng) -> Self {
if bound == 0 {
return Self {
input0_bound: 0,
output_bound: 0,
ptr: std::ptr::null_mut(),
};
}
let layout = Layout::array::<usize>(bound).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
for i in 0..bound {
ptr.add(i).write(i);
}
for upper_index in (1..bound).rev() {
let random_index = rand.random_range(0..=upper_index);
if random_index != upper_index {
ptr.add(random_index).swap(ptr.add(upper_index));
}
}
}
Self {
input0_bound: bound,
output_bound: bound,
ptr,
}
}
pub fn shuffle_partially(
input0_bound: usize,
output_bound: usize,
rand: &mut impl Rng,
) -> Self {
if input0_bound > output_bound {
panic!("input0_bound must be less than or equal to output_bound");
}
if input0_bound == 0 {
return Self {
input0_bound: 0,
output_bound,
ptr: std::ptr::null_mut(),
};
}
let layout = Layout::array::<usize>(input0_bound).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
for i in 0..input0_bound {
ptr.add(i).write(i);
}
for i in input0_bound..output_bound {
let j = rand.random_range(0..=i);
if j < input0_bound {
ptr.add(j).write(i);
}
}
for i in (1..input0_bound).rev() {
let j = rand.random_range(0..=i);
ptr.add(j).swap(ptr.add(i));
}
}
Self {
input0_bound,
output_bound,
ptr,
}
}
#[doc(hidden)]
pub fn get_input0_bound(r: &Self) -> usize {
r.input0_bound
}
#[doc(hidden)]
pub fn get_output_bound(r: &Self) -> usize {
r.output_bound
}
pub fn input_shape(&self) -> (usize,) {
(self.input0_bound,)
}
pub fn output_bound(&self) -> usize {
self.output_bound
}
#[doc(hidden)]
pub unsafe fn get_unchecked(r: &Self, i: usize) -> usize {
unsafe { r.ptr.add(i).read() }
}
}
impl PartialEq for Reindexing1 {
fn eq(&self, other: &Self) -> bool {
if self.input0_bound != other.input0_bound || self.output_bound != other.output_bound {
return false;
}
for i in 0..self.input0_bound {
if unsafe { self.ptr.add(i).read() } != unsafe { other.ptr.add(i).read() } {
return false;
}
}
true
}
}
impl Eq for Reindexing1 {}
impl Clone for Reindexing1 {
fn clone(&self) -> Self {
if self.ptr.is_null() {
return Self {
input0_bound: 0,
output_bound: self.output_bound,
ptr: std::ptr::null_mut(),
};
}
let layout = Layout::array::<usize>(self.input0_bound).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
for i in 0..self.input0_bound {
ptr.add(i).write(self.ptr.add(i).read());
}
}
Self {
input0_bound: self.input0_bound,
output_bound: self.output_bound,
ptr,
}
}
}
impl std::convert::AsRef<[usize]> for Reindexing1 {
fn as_ref(&self) -> &[usize] {
if self.ptr.is_null() {
return &[];
}
unsafe { slice::from_raw_parts(self.ptr, self.input0_bound) }
}
}
impl Index<usize> for Reindexing1 {
type Output = usize;
fn index(&self, index: usize) -> &Self::Output {
if index >= self.input0_bound {
panic!("Index out of bound.")
}
unsafe { self.ptr.add(index).as_ref().unwrap() }
}
}
impl Drop for Reindexing1 {
fn drop(&mut self) {
if self.ptr.is_null() {
return;
}
let layout = Layout::array::<usize>(self.input0_bound).unwrap();
unsafe { std::alloc::dealloc(self.ptr as *mut u8, layout) }
}
}
impl Display for Reindexing1 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?;
for i in 0..self.input0_bound {
if i != 0 {
write!(f, ", ")?;
}
write!(f, "{}↦{}", i, unsafe { self.ptr.add(i).read() })?;
}
write!(f, "]")
}
}
unsafe impl Send for Reindexing1 {}
unsafe impl Sync for Reindexing1 {}
pub struct Reindexing2 {
input0_bound: usize,
input1_bound: usize,
output_bound: usize,
ptr: *mut usize,
}
impl Reindexing2 {
#[doc(hidden)]
pub fn new(
input0_bound: usize,
input1_bound: usize,
output_bound: usize,
mut f: impl FnMut(usize, usize) -> usize,
) -> Self {
if output_bound == 0 {
panic!("Should not be 0")
}
let bound = input0_bound
.checked_mul(input1_bound)
.expect("Saturated value");
let layout = Layout::array::<usize>(bound).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
let mut i = 0;
for i0 in 0..input0_bound {
for i1 in 0..input1_bound {
ptr.add(i).write(f(i0, i1) % output_bound);
i += 1;
}
}
}
Self {
input0_bound,
input1_bound,
output_bound,
ptr,
}
}
pub fn shuffle(input0_bound: usize, input1_bound: usize, rand: &mut impl Rng) -> Self {
let bound = input0_bound
.checked_mul(input1_bound)
.expect("Saturated value");
let layout = Layout::array::<usize>(bound).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
for i in 0..bound {
ptr.add(i).write(i);
}
for i in 0..(bound - 2) {
let shrinking_bound = bound - i;
let last_valid_index = shrinking_bound - 1;
let random_index = rand.random_range(0..shrinking_bound);
if random_index < last_valid_index {
ptr.add(random_index).swap(ptr.add(last_valid_index));
}
}
}
Self {
input0_bound,
input1_bound,
output_bound: bound,
ptr,
}
}
#[doc(hidden)]
pub fn get_input0_bound(r: &Self) -> usize {
r.input0_bound
}
#[doc(hidden)]
pub fn get_input1_bound(r: &Self) -> usize {
r.input1_bound
}
#[doc(hidden)]
pub fn get_output_bound(r: &Self) -> usize {
r.output_bound
}
pub fn input_shape(&self) -> (usize, usize) {
(self.input0_bound, self.input1_bound)
}
pub fn output_bound(&self) -> usize {
self.output_bound
}
#[doc(hidden)]
pub unsafe fn get_unchecked(r: &Self, i0: usize, i1: usize) -> usize {
unsafe { r.ptr.add(i0 * r.input1_bound + i1).read() }
}
}
impl AsRef<[usize]> for Reindexing2 {
fn as_ref(&self) -> &[usize] {
unsafe { slice::from_raw_parts(self.ptr, self.input0_bound * self.input1_bound) }
}
}
impl Index<(usize, usize)> for Reindexing2 {
type Output = usize;
fn index(&self, index: (usize, usize)) -> &Self::Output {
if index.0 >= self.input0_bound || index.1 >= self.input1_bound {
panic!("Index out of bound.")
}
unsafe {
self.ptr
.add(index.0 * self.input1_bound + index.1)
.as_ref()
.unwrap()
}
}
}
impl Drop for Reindexing2 {
fn drop(&mut self) {
if self.ptr.is_null() {
return;
}
let layout = Layout::array::<usize>(self.input0_bound * self.input1_bound).unwrap();
unsafe { std::alloc::dealloc(self.ptr as *mut u8, layout) }
}
}
pub struct Reindexing3 {
input0_bound: usize,
input1_bound: usize,
input2_bound: usize,
output_bound: usize,
ptr: *mut usize,
}
impl Reindexing3 {
#[doc(hidden)]
pub fn new(
input0_bound: usize,
input1_bound: usize,
input2_bound: usize,
output_bound: usize,
mut f: impl FnMut(usize, usize, usize) -> usize,
) -> Self {
if output_bound == 0 {
panic!("Should not be 0")
}
let count = input0_bound
.checked_mul(input1_bound)
.expect("Saturated value")
.checked_mul(input2_bound)
.expect("Saturated value");
let layout = Layout::array::<usize>(count).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
let mut i = 0;
for i0 in 0..input0_bound {
for i1 in 0..input1_bound {
for i2 in 0..input2_bound {
ptr.add(i).write(f(i0, i1, i2) % output_bound);
i += 1;
}
}
}
}
Self {
input0_bound,
input1_bound,
input2_bound,
output_bound,
ptr,
}
}
pub fn shuffle(
input0_bound: usize,
input1_bound: usize,
input2_bound: usize,
rand: &mut impl Rng,
) -> Self {
let bound = input0_bound
.checked_mul(input1_bound)
.expect("Saturated value")
.checked_mul(input2_bound)
.expect("Saturated value");
let layout = Layout::array::<usize>(bound).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
for i in 0..bound {
ptr.add(i).write(i);
}
for upper_index in (1..bound).rev() {
let random_index = rand.random_range(0..=upper_index);
if random_index != upper_index {
ptr.add(random_index).swap(ptr.add(upper_index));
}
}
}
Self {
input0_bound,
input1_bound,
input2_bound,
output_bound: bound,
ptr,
}
}
#[doc(hidden)]
pub fn get_input0_bound(r: &Self) -> usize {
r.input0_bound
}
#[doc(hidden)]
pub fn get_input1_bound(r: &Self) -> usize {
r.input1_bound
}
#[doc(hidden)]
pub fn get_input2_bound(r: &Self) -> usize {
r.input2_bound
}
#[doc(hidden)]
pub fn get_output_bound(r: &Self) -> usize {
r.output_bound
}
pub fn input_shape(&self) -> (usize, usize, usize) {
(self.input0_bound, self.input1_bound, self.input2_bound)
}
pub fn output_bound(&self) -> usize {
self.output_bound
}
#[doc(hidden)]
pub unsafe fn get_unchecked(r: &Self, i0: usize, i1: usize, i2: usize) -> usize {
unsafe {
r.ptr
.add(i0 * r.input1_bound * r.input2_bound + i1 * r.input2_bound + i2)
.read()
}
}
}
impl AsRef<[usize]> for Reindexing3 {
fn as_ref(&self) -> &[usize] {
unsafe {
slice::from_raw_parts(
self.ptr,
self.input0_bound * self.input1_bound * self.input2_bound,
)
}
}
}
impl Index<(usize, usize, usize)> for Reindexing3 {
type Output = usize;
fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
if index.0 >= self.input0_bound
|| index.1 >= self.input1_bound
|| index.2 >= self.input2_bound
{
panic!("Index out of bound.")
}
let index = (index.0 * self.input1_bound + index.1) * self.input2_bound + index.2;
unsafe { self.ptr.add(index).as_ref().unwrap() }
}
}
impl Drop for Reindexing3 {
fn drop(&mut self) {
if self.ptr.is_null() {
return;
}
let layout =
Layout::array::<usize>(self.input0_bound * self.input1_bound * self.input2_bound)
.unwrap();
unsafe { std::alloc::dealloc(self.ptr as *mut u8, layout) }
}
}
pub struct Reindexing4 {
input0_bound: usize,
input1_bound: usize,
input2_bound: usize,
input3_bound: usize,
output_bound: usize,
ptr: *mut usize,
}
impl Reindexing4 {
#[doc(hidden)]
pub fn new(
input0_bound: usize,
input1_bound: usize,
input2_bound: usize,
input3_bound: usize,
output_bound: usize,
mut f: impl FnMut(usize, usize, usize, usize) -> usize,
) -> Self {
if output_bound == 0 {
panic!("Should not be 0")
}
let count = input0_bound
.checked_mul(input1_bound)
.expect("Saturated value")
.checked_mul(input2_bound)
.expect("Saturated value")
.checked_mul(input3_bound)
.expect("Saturated value");
let layout = Layout::array::<usize>(count).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
let mut i = 0;
for i0 in 0..input0_bound {
for i1 in 0..input1_bound {
for i2 in 0..input2_bound {
for i3 in 0..input3_bound {
ptr.add(i).write(f(i0, i1, i2, i3) % output_bound);
i += 1;
}
}
}
}
}
Self {
input0_bound,
input1_bound,
input2_bound,
input3_bound,
output_bound,
ptr,
}
}
pub fn shuffle(
input0_bound: usize,
input1_bound: usize,
input2_bound: usize,
input3_bound: usize,
rand: &mut impl Rng,
) -> Self {
let bound = input0_bound
.checked_mul(input1_bound)
.expect("Saturated value")
.checked_mul(input2_bound)
.expect("Saturated value")
.checked_mul(input3_bound)
.expect("Saturated value");
let layout = Layout::array::<usize>(bound).unwrap();
let ptr: *mut usize;
unsafe {
ptr = std::alloc::alloc(layout) as *mut usize;
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
for i in 0..bound {
ptr.add(i).write(i);
}
for upper_index in (1..bound).rev() {
let random_index = rand.random_range(0..=upper_index);
if random_index != upper_index {
ptr.add(random_index).swap(ptr.add(upper_index));
}
}
}
Self {
input0_bound,
input1_bound,
input2_bound,
input3_bound,
output_bound: bound,
ptr,
}
}
#[doc(hidden)]
pub fn get_input0_bound(r: &Self) -> usize {
r.input0_bound
}
#[doc(hidden)]
pub fn get_input1_bound(r: &Self) -> usize {
r.input1_bound
}
#[doc(hidden)]
pub fn get_input2_bound(r: &Self) -> usize {
r.input2_bound
}
#[doc(hidden)]
pub fn get_input3_bound(r: &Self) -> usize {
r.input3_bound
}
#[doc(hidden)]
pub fn get_output_bound(r: &Self) -> usize {
r.output_bound
}
pub fn input_shape(&self) -> (usize, usize, usize, usize) {
(
self.input0_bound,
self.input1_bound,
self.input2_bound,
self.input3_bound,
)
}
pub fn output_bound(&self) -> usize {
self.output_bound
}
#[doc(hidden)]
pub unsafe fn get_unchecked(r: &Self, i0: usize, i1: usize, i2: usize, i3: usize) -> usize {
unsafe {
r.ptr
.add(((i0 * r.input1_bound + i1) * r.input2_bound + i2) * r.input3_bound + i3)
.read()
}
}
}
impl AsRef<[usize]> for Reindexing4 {
fn as_ref(&self) -> &[usize] {
unsafe {
slice::from_raw_parts(
self.ptr,
self.input0_bound * self.input1_bound * self.input2_bound * self.input3_bound,
)
}
}
}
impl Index<(usize, usize, usize, usize)> for Reindexing4 {
type Output = usize;
fn index(&self, index: (usize, usize, usize, usize)) -> &Self::Output {
if index.0 >= self.input0_bound
|| index.1 >= self.input1_bound
|| index.2 >= self.input2_bound
|| index.3 >= self.input3_bound
{
panic!("Index out of bound.")
}
let index = ((index.0 * self.input1_bound + index.1) * self.input2_bound + index.2)
* self.input3_bound
+ index.3;
unsafe { self.ptr.add(index).as_ref().unwrap() }
}
}
impl Drop for Reindexing4 {
fn drop(&mut self) {
if self.ptr.is_null() {
return;
}
let layout = Layout::array::<usize>(
self.input0_bound * self.input1_bound * self.input2_bound * self.input3_bound,
)
.unwrap();
unsafe { std::alloc::dealloc(self.ptr as *mut u8, layout) }
}
}
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use super::*;
    const SEED: [u8; 32] = [
        7u8, 38u8, 19u8, 50u8, 11u8, 22u8, 33u8, 44u8, 133u8, 144u8, 155u8, 166u8, 177u8, 188u8,
        199u8, 200u8, 55u8, 66u8, 77u8, 88u8, 99u8, 100u8, 111u8, 122u8, 211u8, 222u8, 233u8,
        244u8, 255u8, 0u8, 2u8, 3u8,
    ];
    /// Asserts that `values` is a permutation of `0..values.len()`.
    fn assert_is_permutation(values: &[usize]) {
        let mut sorted = values.to_vec();
        sorted.sort_unstable();
        assert!(
            sorted.iter().copied().eq(0..values.len()),
            "not a permutation: {:?}",
            values
        );
    }
    #[test]
    fn test_reindexing1() {
        let reindexing = Reindexing1::new(16, 7, |i| i * 5 + 4);
        assert_eq!(16, Reindexing1::get_input0_bound(&reindexing));
        assert_eq!(7, Reindexing1::get_output_bound(&reindexing));
        assert_eq!(
            &[4, 2, 0, 5, 3, 1, 6, 4, 2, 0, 5, 3, 1, 6, 4, 2],
            reindexing.as_ref()
        );
        assert_eq!(2, reindexing[15])
    }
    #[should_panic]
    #[test]
    fn test_panicking_reindexing1() {
        let reindexing = Reindexing1::new(16, 7, |i| i * 5 + 4);
        let _ = reindexing[16];
    }
    #[test]
    fn test_display_and_clone_reindexing1() {
        let reindexing = Reindexing1::new(3, 2, |i| i);
        assert_eq!("[0↦0, 1↦1, 2↦0]", reindexing.to_string());
        let copy = reindexing.clone();
        assert!(copy == reindexing);
    }
    #[test]
    fn test_shuffle_reindexing1() {
        let mut rand = rand::rngs::StdRng::from_seed(SEED);
        let reindexing = Reindexing1::shuffle(22, &mut rand);
        assert_eq!(
            &[
                14, 2, 11, 15, 12, 4, 1, 18, 9, 16, 21, 0, 5, 10, 13, 19, 3, 20, 17, 7, 6, 8
            ],
            reindexing.as_ref()
        );
        let reindexing = Reindexing1::shuffle(22, &mut rand);
        assert_eq!(
            &[
                5, 11, 0, 3, 10, 18, 15, 4, 7, 13, 9, 16, 20, 1, 12, 8, 17, 19, 21, 6, 2, 14
            ],
            reindexing.as_ref()
        );
    }
    #[test]
    fn test_shuffle_partially_reindexing1() {
        let mut rand = rand::rngs::StdRng::from_seed(SEED);
        let reindexing = Reindexing1::shuffle_partially(7, 22, &mut rand);
        assert_eq!(&[9, 12, 17, 13, 5, 11, 1], reindexing.as_ref());
        let reindexing = Reindexing1::shuffle_partially(7, 22, &mut rand);
        assert_eq!(&[7, 6, 0, 16, 8, 19, 12], reindexing.as_ref());
        let reindexing = Reindexing1::shuffle_partially(7, 22, &mut rand);
        assert_eq!(&[2, 17, 8, 0, 12, 11, 20], reindexing.as_ref());
        let reindexing = Reindexing1::shuffle_partially(7, 22, &mut rand);
        assert_eq!(&[16, 7, 0, 19, 20, 4, 3], reindexing.as_ref());
        let reindexing = Reindexing1::shuffle_partially(7, 22, &mut rand);
        assert_eq!(&[16, 13, 18, 5, 2, 17, 0], reindexing.as_ref());
    }
    #[test]
    fn test_shuffle_partially_reindexing1_uniformity() {
        let mut rand = rand::rngs::StdRng::from_seed(SEED);
        let mut counts = [0; 22];
        const NUMBER_OF_SAMPLES: usize = 50_000;
        for _ in 0..NUMBER_OF_SAMPLES {
            let reindexing = Reindexing1::shuffle_partially(7, 22, &mut rand);
            for index in reindexing.as_ref() {
                counts[*index] += 1;
            }
        }
        let min = counts.iter().min().unwrap();
        let max = counts.iter().max().unwrap();
        assert!(
            max - min <= NUMBER_OF_SAMPLES / 100,
            "Counts are not uniform: min = {}, max = {}",
            min,
            max
        );
    }
    #[test]
    fn test_reindexing2() {
        let reindexing = Reindexing2::new(6, 4, 24, |i0, i1| 17 * (i0 * 4 + i1) + 11);
        assert_eq!(6, Reindexing2::get_input0_bound(&reindexing));
        assert_eq!(4, Reindexing2::get_input1_bound(&reindexing));
        assert_eq!(24, Reindexing2::get_output_bound(&reindexing));
        assert_eq!(
            &[
                11, 4, 21, 14, 7, 0, 17, 10, 3, 20, 13, 6, 23, 16, 9, 2, 19, 12, 5, 22, 15, 8, 1,
                18
            ],
            reindexing.as_ref(),
        );
        assert_eq!(18, reindexing[(5, 3)])
    }
    #[should_panic]
    #[test]
    fn test_panicking_reindexing2() {
        let reindexing = Reindexing2::new(6, 4, 24, |i0, i1| i0 * 4 + i1);
        let _ = reindexing[(5, 4)];
    }
    #[test]
    fn test_shuffle_reindexing2() {
        let mut rand = rand::rngs::StdRng::from_seed(SEED);
        let reindexing = Reindexing2::shuffle(3, 5, &mut rand);
        assert_eq!((3, 5), reindexing.input_shape());
        assert_eq!(15, reindexing.output_bound());
        let values: &[usize] = reindexing.as_ref();
        assert_is_permutation(values);
        // The seed pins the draws that place entries 2..15. The order of the first two
        // entries is decided only by the last Fisher–Yates step, so only their set is
        // asserted.
        assert_eq!(&[9, 3, 0, 2, 12, 7, 11, 1, 13, 10, 14, 4, 5], &values[2..]);
        let mut head = [values[0], values[1]];
        head.sort_unstable();
        assert_eq!([6, 8], head);
    }
    #[test]
    fn test_reindexing3() {
        let reindexing = Reindexing3::new(2, 3, 4, 1000, |i0, i1, i2| 100 * i0 + 10 * i1 + i2);
        assert_eq!(2, Reindexing3::get_input0_bound(&reindexing));
        assert_eq!(3, Reindexing3::get_input1_bound(&reindexing));
        assert_eq!(4, Reindexing3::get_input2_bound(&reindexing));
        assert_eq!(1000, Reindexing3::get_output_bound(&reindexing));
        assert_eq!(
            &[
                0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23, 100, 101, 102, 103, 110, 111, 112,
                113, 120, 121, 122, 123
            ],
            reindexing.as_ref()
        );
        assert_eq!(123, reindexing[(1, 2, 3)]);
        assert_eq!(123, unsafe { Reindexing3::get_unchecked(&reindexing, 1, 2, 3) });
    }
    #[should_panic]
    #[test]
    fn test_panicking_reindexing3() {
        let reindexing = Reindexing3::new(2, 3, 4, 24, |i0, i1, i2| i0 + i1 + i2);
        let _ = reindexing[(0, 3, 0)];
    }
    #[test]
    fn test_shuffle_reindexing3() {
        let mut rand = rand::rngs::StdRng::from_seed(SEED);
        let reindexing = Reindexing3::shuffle(2, 3, 4, &mut rand);
        assert_eq!((2, 3, 4), reindexing.input_shape());
        assert_eq!(24, reindexing.output_bound());
        assert_is_permutation(reindexing.as_ref());
    }
    #[test]
    fn test_reindexing4() {
        let reindexing = Reindexing4::new(2, 2, 2, 2, 10_000, |i0, i1, i2, i3| {
            1000 * i0 + 100 * i1 + 10 * i2 + i3
        });
        assert_eq!(2, Reindexing4::get_input3_bound(&reindexing));
        assert_eq!(10_000, Reindexing4::get_output_bound(&reindexing));
        assert_eq!(
            &[0, 1, 10, 11, 100, 101, 110, 111, 1000, 1001, 1010, 1011, 1100, 1101, 1110, 1111],
            reindexing.as_ref()
        );
        assert_eq!(1011, reindexing[(1, 0, 1, 1)]);
        assert_eq!(1011, unsafe {
            Reindexing4::get_unchecked(&reindexing, 1, 0, 1, 1)
        });
    }
    #[should_panic]
    #[test]
    fn test_panicking_reindexing4() {
        let reindexing = Reindexing4::new(2, 2, 2, 2, 16, |i0, i1, i2, i3| i0 + i1 + i2 + i3);
        let _ = reindexing[(0, 0, 0, 2)];
    }
    #[test]
    fn test_shuffle_reindexing4() {
        let mut rand = rand::rngs::StdRng::from_seed(SEED);
        let reindexing = Reindexing4::shuffle(2, 3, 1, 2, &mut rand);
        assert_eq!((2, 3, 1, 2), reindexing.input_shape());
        assert_eq!(12, reindexing.output_bound());
        assert_is_permutation(reindexing.as_ref());
    }
}