use serde::{Deserialize, Serialize};
use std::{
error::Error,
fmt,
fmt::{Debug, Display},
fs,
marker::PhantomData,
ops::Div,
str::FromStr,
};
use anyhow::Result;
use itertools::iproduct;
use itertools::Itertools;
use num_traits::{
pow,
sign::{abs, Signed},
Float, Num, NumAssign, NumAssignOps, NumAssignRef, NumOps, One, Zero,
};
use rand::{distributions::uniform::SampleUniform, Rng};
use rayon::prelude::*;
use std::iter::{Product, Sum};
/// Dimension sizes of a tensor in row-major order,
/// e.g. `vec![rows, cols]` for a matrix.
pub type Shape = Vec<usize>;
/// Swaps the values behind two mutable references.
///
/// Thin wrapper over [`std::mem::swap`]; the previous manual
/// temp-variable dance did the same thing by hand.
fn swap(lhs: &mut usize, rhs: &mut usize) {
    std::mem::swap(lhs, rhs);
}
// Row-major linear offset of element (i, j) in a matrix with `ncols` columns.
// Each argument is parenthesized: without the parens `at!(i + 1, j, c)`
// expanded to `i + 1 * c + j`, silently computing the wrong offset for any
// compound expression.
macro_rules! at {
    ($i:expr, $j:expr, $ncols:expr) => {
        ($i) * ($ncols) + ($j)
    };
}
// Flattens a multi-dimensional index into a row-major offset, walking from
// the innermost (fastest-varying) dimension outwards.
//
// On a rank mismatch this now yields `usize::MAX` so that any subsequent
// bounds check (`offset >= size`) fails loudly; the previous sentinel was
// the literal `32`, which can be a *valid* offset and silently addressed
// the wrong element. Arguments are bound once to avoid double evaluation.
macro_rules! index {
    ($indexes:expr, $dimensions:expr) => {{
        let indexes = &$indexes;
        let dimensions = &$dimensions;
        if indexes.len() != dimensions.len() {
            usize::MAX
        } else {
            let mut stride = 1;
            let mut result = 0;
            for (&index, &dimension) in indexes.iter().rev().zip(dimensions.iter().rev()) {
                result += index * stride;
                stride *= dimension;
            }
            result
        }
    }};
}
// Expands a flat row-major offset into one index per dimension — the inverse
// of `index!`. Divides out dimensions from the innermost (fastest-varying)
// outwards, then reverses into natural order.
macro_rules! index_list {
    ($single_index:expr, $dimensions:expr) => {{
        let mut indexes = Vec::with_capacity($dimensions.len());
        let mut remaining_index = $single_index;
        for &dimension in $dimensions.iter().rev() {
            indexes.push(remaining_index % dimension);
            remaining_index /= dimension;
        }
        indexes.reverse();
        indexes
    }};
}
/// Errors produced by tensor construction, parsing, and arithmetic.
#[derive(Debug, PartialEq)]
pub enum TensorError {
    /// Shape product did not match the data length during construction.
    TensorCreationError,
    /// An index addressed an element outside the tensor's bounds.
    TensorIndexOutOfBoundsError,
    /// Matrix multiplication operands were not of the form M x N @ N x P.
    MatrixMultiplicationDimensionMismatchError,
    /// Element-wise operation attempted on tensors of different shapes.
    TensorDimensionMismatchError,
    /// Text input could not be parsed into a tensor.
    TensorParseError,
    /// Division by a tensor containing at least one zero element.
    TensorDivideByZeroError,
    /// The file at the contained path could not be read.
    TensorFileReadError(&'static str),
}
impl std::fmt::Display for TensorError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TensorError::TensorCreationError => {
write!(f, "There was an error creating the tensor.")
}
TensorError::TensorIndexOutOfBoundsError => {
write!(f, "The indexes are out of bounds for the matrix")
}
TensorError::MatrixMultiplicationDimensionMismatchError => {
write!(
f,
"The two matrices supplied are not on the form M x N @ N x P"
)
}
TensorError::TensorDimensionMismatchError => {
write!(f, "The tensors provided are both not on the form M x N")
}
TensorError::TensorParseError => write!(f, "Failed to parse tensor from file"),
TensorError::TensorDivideByZeroError => write!(f, "Tried to divide by zero"),
TensorError::TensorFileReadError(path) => {
write!(f, "Could not read file from path: {}", path)
}
}
}
}
impl<'a, T> std::error::Error for Tensor<'a, T>
where
T: TensorElement,
<T as FromStr>::Err: Error + 'static,
Vec<T>: IntoParallelIterator,
Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
/// A dense, n-dimensional tensor storing its elements in row-major order.
///
/// NOTE(review): carrying the trait bounds on the struct itself (rather than
/// only on the `impl` blocks) forces every mention of the type to repeat
/// them — confirm they are all needed here before relaxing.
#[derive(Clone, PartialEq, PartialOrd, Debug, Serialize, Deserialize)]
pub struct Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Elements in row-major (C) order; `data.len() == shape.iter().product()`.
    pub data: Vec<T>,
    /// Extent of each dimension.
    pub shape: Shape,
    /// Number of dimensions — cached copy of `shape.len()`.
    pub ndims: usize,
    /// Anchors the otherwise-unused lifetime parameter `'a`.
    _lifetime: PhantomData<&'a T>,
}
/// Marker trait bundling every capability the tensor operations rely on:
/// float math (`Float`), signed arithmetic, ordering, iterator folds
/// (`Sum`/`Product`), formatting, parsing, threading (`Send`/`Sync`), and
/// uniform random sampling. Implemented below for `f32` and `f64`.
pub trait TensorElement:
    Copy
    + Clone
    + PartialOrd
    + Signed
    + Float
    + Sum
    + Product
    + Display
    + Debug
    + FromStr
    + Default
    + One
    + PartialEq
    + Zero
    + Send
    + Sync
    + Sized
    + Num
    + NumOps
    + NumAssignOps
    + NumAssignRef
    + NumAssign
    + SampleUniform
{
}
// SAFETY: `Tensor` only contains `Vec<T>` and `PhantomData<&'a T>`, and
// `TensorElement` already requires `T: Send + Sync`, so the compiler would
// derive `Send` automatically — this manual impl asserts nothing the auto
// trait would not. NOTE(review): consider deleting it and relying on the
// auto-derived `Send`.
unsafe impl<'a, T> Send for Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
// SAFETY: as with `Send` above — `TensorElement: Send + Sync` means the
// auto-derived `Sync` would already hold for this field set; the manual impl
// is redundant but sound. NOTE(review): candidate for removal.
unsafe impl<'a, T> Sync for Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
// The built-in float types satisfy every `TensorElement` super-trait.
impl TensorElement for f32 {}
impl TensorElement for f64 {}
impl<'a, T> FromStr for Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    type Err = TensorError;

    /// Parses a 2-D tensor from whitespace-separated numbers, one matrix row
    /// per text line. The column count is taken from the first line.
    ///
    /// # Errors
    /// * [`TensorError::TensorParseError`] when any token fails to parse as
    ///   `T` (the previous version `unwrap`ped and panicked instead).
    /// * [`TensorError::TensorCreationError`] when later rows have a
    ///   different length than the first (shape product != element count).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut data: Vec<T> = Vec::new();
        let mut rows = 0usize;
        let mut cols = 0usize;
        // Single pass over the input: collect elements and count rows/cols
        // (the old code re-scanned the string three times and panicked on
        // empty input via `.nth(0).unwrap()`).
        for line in s.trim().lines() {
            let before = data.len();
            for token in line.split_whitespace() {
                let value = token
                    .parse::<T>()
                    .map_err(|_| TensorError::TensorParseError)?;
                data.push(value);
            }
            if rows == 0 {
                cols = data.len() - before;
            }
            rows += 1;
        }
        // `new` validates that rows * cols == data.len().
        Self::new(data, vec![rows, cols])
    }
}
impl<'a, T> Display for Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Formats the tensor as its shape followed by the elements.
    ///
    /// 2-D tensors are printed one matrix row per line; every other rank
    /// falls back to a flat row-major listing. (The previous implementation
    /// wrote nothing at all, making `Display` useless.)
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "shape: {:?}", self.shape)?;
        if self.ndims == 2 {
            // `max(1)` guards `chunks` against a zero-column shape.
            let cols = self.shape[1].max(1);
            for row in self.data.chunks(cols) {
                for value in row {
                    write!(f, "{} ", value)?;
                }
                writeln!(f)?;
            }
            Ok(())
        } else {
            for value in &self.data {
                write!(f, "{} ", value)?;
            }
            writeln!(f)
        }
    }
}
impl<'a, T> Default for Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Returns a 3 x 3 identity matrix.
    ///
    /// NOTE(review): the 3 x 3 size is arbitrary — an empty tensor may be a
    /// less surprising default; confirm no caller relies on this value.
    fn default() -> Self {
        Self::eye(3)
    }
}
impl<'a, T> Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Creates a tensor from row-major `data` with the given `shape`.
    ///
    /// # Errors
    /// [`TensorError::TensorCreationError`] if the product of the shape's
    /// dimensions does not equal `data.len()`.
    pub fn new(data: Vec<T>, shape: Shape) -> Result<Self, TensorError> {
        if shape.iter().product::<usize>() != data.len() {
            return Err(TensorError::TensorCreationError);
        }
        Ok(Self {
            data,
            ndims: shape.len(),
            shape,
            _lifetime: PhantomData,
        })
    }

    /// Creates a tensor of `shape` with every element set to `value`.
    pub fn init(value: T, shape: Shape) -> Self {
        Self::from_shape(value, shape)
    }

    /// Creates a `size` x `size` identity matrix.
    pub fn eye(size: usize) -> Self {
        let mut data: Vec<T> = vec![T::zero(); size * size];
        (0..size).for_each(|i| data[i * size + i] = T::one());
        Self::new(data, vec![size, size]).unwrap()
    }

    /// Alias for [`Self::eye`].
    pub fn identity(size: usize) -> Self {
        Self::eye(size)
    }

    /// Creates a tensor by copying the elements of `arr`.
    ///
    /// # Errors
    /// [`TensorError::TensorCreationError`] if `shape` does not match
    /// `arr.len()` (validated by `new`; the previous duplicate check here
    /// was redundant).
    pub fn from_slice(arr: &[T], shape: Shape) -> Result<Self, TensorError> {
        Self::new(arr.to_owned(), shape)
    }

    /// Creates a tensor of zeros.
    pub fn zeros(shape: Shape) -> Self {
        Self::from_shape(T::zero(), shape)
    }

    /// Creates a tensor of ones.
    pub fn ones(shape: Shape) -> Self {
        Self::from_shape(T::one(), shape)
    }

    /// Creates a tensor of zeros with the same shape as `other`.
    pub fn zeros_like(other: &Self) -> Self {
        Self::from_shape(T::zero(), other.shape.clone())
    }

    /// Creates a tensor of ones with the same shape as `other`.
    pub fn ones_like(other: &Self) -> Self {
        Self::from_shape(T::one(), other.shape.clone())
    }

    /// Creates a tensor shaped like `tensor` with elements drawn uniformly
    /// from `[0, 1]`.
    pub fn random_like(tensor: &Self) -> Self {
        Self::randomize_range(T::zero(), T::one(), tensor.shape.clone())
    }

    /// Creates a tensor with elements drawn uniformly from `[start, end]`.
    pub fn randomize_range(start: T, end: T, shape: Shape) -> Self {
        let mut rng = rand::thread_rng();
        let len: usize = shape.iter().product();
        let data: Vec<T> = (0..len).map(|_| rng.gen_range(start..=end)).collect();
        Self::new(data, shape).unwrap()
    }

    /// Creates a tensor with elements drawn uniformly from `[0, 1]`.
    pub fn randomize(shape: Shape) -> Self {
        Self::randomize_range(T::zero(), T::one(), shape)
    }

    /// Reads and parses a whitespace-separated 2-D tensor from `path`.
    ///
    /// # Errors
    /// [`TensorError::TensorFileReadError`] when the file cannot be read,
    /// [`TensorError::TensorParseError`] when its contents do not parse.
    pub fn from_file(path: &'static str) -> Result<Self, TensorError> {
        let data =
            fs::read_to_string(path).map_err(|_| TensorError::TensorFileReadError(path))?;
        data.parse::<Self>()
            .map_err(|_| TensorError::TensorParseError)
    }

    /// Shared constructor backing `init` / `zeros` / `ones`: fills `shape`
    /// with copies of `value`.
    fn from_shape(value: T, shape: Shape) -> Self {
        let len: usize = shape.iter().product();
        Self::new(vec![value; len], shape).unwrap()
    }
}
/// Axis selector for the row/column reductions (`sum`, `prod`).
pub enum Dimension {
    /// Reduce along a row.
    Row = 0,
    /// Reduce along a column.
    Col = 1,
}
impl<'a, T> Tensor<'a, T>
where
    T: TensorElement + Div<Output = T> + Sum<T>,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Reinterprets the data with a new shape of the same total size; on a
    /// size mismatch the old shape is kept and a message is printed.
    pub fn reshape(&mut self, new_shape: Shape) {
        if new_shape.iter().product::<usize>() != self.size() {
            println!("Can not reshape.. Keeping old dimensions for now");
            return;
        }
        // Fixed: `ndims` was previously left stale after a rank-changing
        // reshape (e.g. [2, 2] -> [4] kept ndims == 2).
        self.ndims = new_shape.len();
        self.shape = new_shape;
    }

    /// Removes every dimension of extent 1. Note that an all-ones shape
    /// collapses to an empty shape (a 0-dim tensor).
    pub fn squeeze(&mut self) {
        self.shape.retain(|&dim| dim > 1);
        self.ndims = self.shape.len()
    }

    /// Total number of elements.
    pub fn size(&self) -> usize {
        self.shape.iter().product()
    }

    /// Converts `idx` into a row-major offset, validating both the rank and
    /// every coordinate against its dimension. Returns `None` when the index
    /// does not address an element.
    fn flat_index(&self, idx: &[usize]) -> Option<usize> {
        if idx.len() != self.shape.len() {
            return None;
        }
        let mut offset = 0usize;
        let mut stride = 1usize;
        for (&i, &dim) in idx.iter().rev().zip(self.shape.iter().rev()) {
            if i >= dim {
                return None;
            }
            offset += i * stride;
            stride *= dim;
        }
        Some(offset)
    }

    /// Returns the element at `idx`, or `None` for an invalid index.
    /// (Previously a rank mismatch or an out-of-range coordinate whose flat
    /// offset happened to stay in bounds silently returned the wrong
    /// element.)
    pub fn get(&self, idx: Shape) -> Option<T> {
        self.flat_index(&idx).map(|i| self.data[i])
    }

    /// Collects a `dy` x `dx` window whose top-left corner is given by the
    /// last two entries of `start_idx`; out-of-bounds cells are skipped.
    ///
    /// # Panics
    /// Panics if `start_idx` has fewer than two entries.
    pub fn get_vec_slice(&self, start_idx: Shape, dy: usize, dx: usize) -> Vec<T> {
        let start_row = start_idx[start_idx.len() - 2];
        let start_col = start_idx[start_idx.len() - 1];
        iproduct!(start_row..start_row + dy, start_col..start_col + dx)
            .filter_map(|(i, j)| self.get(vec![i, j]))
            .collect()
    }

    /// TODO: extract an arbitrary-rank sub-tensor.
    fn get_sub_tensor(&self, start_idx: Shape, size: Shape) -> Vec<T> {
        unimplemented!()
    }

    /// Sets the element at `idx`, printing an error and leaving the tensor
    /// untouched when the index is invalid.
    ///
    /// Fixed: the previous bounds check compared the *product of the
    /// coordinates* with the size, so an index such as `[0, 5]` in a 4 x 4
    /// matrix passed the check and overwrote the wrong element.
    pub fn set(&mut self, idx: Shape, value: T) {
        match self.flat_index(&idx) {
            Some(i) => self.data[i] = value,
            None => eprintln!("Error: Index out of bounds. Keeping old value"),
        }
    }

    /// Sets every index in `indexes` to `value`.
    pub fn set_many(&mut self, indexes: Vec<Shape>, value: T) {
        indexes.into_iter().for_each(|idx| self.set(idx, value));
    }

    /// Largest element.
    ///
    /// # Panics
    /// Panics on an empty tensor or incomparable elements (NaN).
    pub fn max(&self) -> T {
        *self
            .data
            .par_iter()
            .max_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap()
    }

    /// Smallest element; same panic conditions as [`Self::max`].
    pub fn min(&self) -> T {
        *self
            .data
            .par_iter()
            .min_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap()
    }

    /// Sum of all elements.
    pub fn cumsum(&self) -> T {
        self.data.par_iter().copied().sum()
    }

    /// Product of all elements.
    pub fn cumprod(&self) -> T {
        self.data.par_iter().copied().product()
    }

    /// Arithmetic mean of all elements (0/0 = NaN for an empty tensor,
    /// matching the previous behaviour).
    pub fn avg(&self) -> T {
        // `T::from` converts the count directly instead of the old loop that
        // accumulated `+ one()` once per element.
        let count = T::from(self.data.len()).expect("element count must be representable in T");
        let total: T = self.data.par_iter().copied().sum::<T>();
        total / count
    }

    /// Alias for [`Self::avg`].
    pub fn mean(&self) -> T {
        self.avg()
    }

    /// Median element.
    ///
    /// # Panics
    /// Panics on an empty tensor or incomparable elements (NaN).
    pub fn median(&self) -> T {
        let half = self.data.len() / 2;
        if self.data.len() % 2 == 0 {
            // Even count: average the two central elements of the sorted data.
            self.data
                .iter()
                .sorted_by(|a, b| a.partial_cmp(b).unwrap())
                .skip(half - 1)
                .take(2)
                .copied()
                .sum::<T>()
                / T::from(2).unwrap()
        } else {
            // Odd count: the middle element of the sorted data.
            self.data
                .iter()
                .sorted_by(|a, b| a.partial_cmp(b).unwrap())
                .nth(half)
                .copied()
                .unwrap()
        }
    }

    /// TODO: reduce one row or column with `+`.
    fn sum(&self, rowcol: usize, dimension: Dimension) -> T {
        unimplemented!()
    }

    /// TODO: reduce one row or column with `*`.
    fn prod(&self, rowcol: usize, dimension: Dimension) -> T {
        unimplemented!()
    }
}
impl<'a, T> Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Shared shape-checked element-wise combinator backing
    /// `add` / `sub` / `sub_abs` / `mul` / `div`.
    fn zip_with<F>(&self, other: &Self, op: F) -> Result<Self, TensorError>
    where
        F: Fn(T, T) -> T,
    {
        if self.shape != other.shape {
            return Err(TensorError::TensorDimensionMismatchError);
        }
        let data = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&x, &y)| op(x, y))
            .collect_vec();
        Ok(Self::new(data, self.shape.clone()).unwrap())
    }

    /// Applies `op` to every element in parallel, producing a new tensor of
    /// the same shape.
    fn map_elements<F>(&self, op: F) -> Self
    where
        F: Fn(T) -> T + Sync + Send,
    {
        let data: Vec<T> = self.data.par_iter().map(|&e| op(e)).collect();
        Self::new(data, self.shape.clone()).unwrap()
    }

    /// Element-wise addition. Errors if the shapes differ.
    pub fn add(&self, other: &Self) -> Result<Self, TensorError> {
        self.zip_with(other, |x, y| x + y)
    }

    /// Element-wise subtraction. Errors if the shapes differ.
    pub fn sub(&self, other: &Self) -> Result<Self, TensorError> {
        self.zip_with(other, |x, y| x - y)
    }

    /// Element-wise absolute difference. Errors if the shapes differ.
    pub fn sub_abs(&self, other: &Self) -> Result<Self, TensorError> {
        self.zip_with(other, |x, y| if x > y { x - y } else { y - x })
    }

    /// Element-wise (Hadamard) product. Errors if the shapes differ.
    pub fn mul(&self, other: &Self) -> Result<Self, TensorError> {
        self.zip_with(other, |x, y| x * y)
    }

    /// Element-wise division.
    ///
    /// # Errors
    /// [`TensorError::TensorDivideByZeroError`] if `other` contains a zero,
    /// [`TensorError::TensorDimensionMismatchError`] if the shapes differ.
    pub fn div(&self, other: &Self) -> Result<Self, TensorError> {
        if other.any(|e| e == &T::zero()) {
            return Err(TensorError::TensorDivideByZeroError);
        }
        self.zip_with(other, |x, y| x / y)
    }

    /// Adds `val` to every element.
    pub fn add_val(&self, val: T) -> Self {
        self.map_elements(|e| e + val)
    }

    /// Subtracts `val` from every element.
    pub fn sub_val(&self, val: T) -> Self {
        self.map_elements(|e| e - val)
    }

    /// Multiplies every element by `val`.
    pub fn mul_val(&self, val: T) -> Self {
        self.map_elements(|e| e * val)
    }

    /// Divides every element by `val`.
    pub fn div_val(&self, val: T) -> Self {
        self.map_elements(|e| e / val)
    }

    /// Element-wise logarithm in base `base`.
    pub fn log(&self, base: T) -> Self {
        self.map_elements(|e| e.log(base))
    }

    /// Element-wise natural logarithm.
    pub fn ln(&self) -> Self {
        self.map_elements(|e| e.ln())
    }

    /// Element-wise hyperbolic tangent.
    pub fn tanh(&self) -> Self {
        self.map_elements(|e| e.tanh())
    }

    /// Element-wise hyperbolic sine.
    /// Fixed: this previously called `tanh` instead of `sinh`.
    pub fn sinh(&self) -> Self {
        self.map_elements(|e| e.sinh())
    }

    /// Element-wise hyperbolic cosine.
    pub fn cosh(&self) -> Self {
        self.map_elements(|e| e.cosh())
    }

    /// Raises every element to the integer power `val`.
    pub fn pow(&self, val: usize) -> Self {
        self.map_elements(|e| pow(e, val))
    }

    /// Element-wise absolute value.
    pub fn abs(&self) -> Self {
        self.map_elements(|e| abs(e))
    }

    /// In-place element-wise `+=`. NOTE(review): shapes are not validated —
    /// `zip` silently stops at the shorter buffer; consider checking like
    /// `add` does.
    pub fn add_self(&mut self, other: &Self) {
        self.data
            .par_iter_mut()
            .zip(&other.data)
            .for_each(|(a, b)| *a += *b);
    }

    /// In-place element-wise `-=` (same shape caveat as `add_self`).
    pub fn sub_self(&mut self, other: &Self) {
        self.data
            .par_iter_mut()
            .zip(&other.data)
            .for_each(|(a, b)| *a -= *b);
    }

    /// In-place element-wise `*=` (same shape caveat as `add_self`).
    pub fn mul_self(&mut self, other: &Self) {
        self.data
            .par_iter_mut()
            .zip(&other.data)
            .for_each(|(a, b)| *a *= *b);
    }

    /// In-place element-wise `/=` (same shape caveat as `add_self`).
    pub fn div_self(&mut self, other: &Self) {
        self.data
            .par_iter_mut()
            .zip(&other.data)
            .for_each(|(a, b)| *a /= *b);
    }

    /// In-place element-wise absolute value.
    pub fn abs_self(&mut self) {
        self.data.par_iter_mut().for_each(|e| *e = abs(*e))
    }

    /// In-place `+= val` on every element.
    pub fn add_val_self(&mut self, val: T) {
        self.data.par_iter_mut().for_each(|e| *e += val);
    }

    /// In-place `-= val` on every element.
    pub fn sub_val_self(&mut self, val: T) {
        self.data.par_iter_mut().for_each(|e| *e -= val);
    }

    /// In-place `*= val` on every element.
    pub fn mul_val_self(&mut self, val: T) {
        self.data.par_iter_mut().for_each(|e| *e *= val);
    }

    /// In-place `/= val` on every element.
    pub fn div_val_self(&mut self, val: T) {
        self.data.par_iter_mut().for_each(|e| *e /= val);
    }

    /// Rank-dispatched multiplication:
    /// * 1-D x 1-D -> element-wise product (NOTE(review): a dot product may
    ///   be the intended semantics here — confirm with callers)
    /// * 1-D x 0-D and 0-D x 1-D -> scalar scaling of the vector
    /// * 2-D x 2-D -> matrix product via [`Self::mm`]
    ///
    /// # Errors
    /// Dimension-mismatch errors for incompatible or unsupported ranks.
    /// (Previously valid 2-D input fell through to `Ok(Self::default())` —
    /// a 3 x 3 identity matrix — and 0-D x 1-D scaled the wrong operand.)
    pub fn matmul(&self, other: &Self) -> Result<Self, TensorError> {
        if self.ndims == 1 && other.ndims == 1 {
            return self.mul(other);
        }
        if self.ndims == 1 && other.ndims == 0 {
            return Ok(self.mul_val(other.data[0]));
        }
        if self.ndims == 0 && other.ndims == 1 {
            // Fixed: scale the vector operand, not the scalar one.
            return Ok(other.mul_val(self.data[0]));
        }
        if self.ndims == 2 && other.ndims == 2 {
            if self.shape[1] != other.shape[0] {
                return Err(TensorError::MatrixMultiplicationDimensionMismatchError);
            }
            return self.mm(other);
        }
        // Ranks above 2 are not supported.
        Err(TensorError::TensorDimensionMismatchError)
    }

    /// Matrix product of two 2-D tensors (M x N @ N x P -> M x P); falls
    /// back to [`Self::matmul`] for non-2-D operands.
    ///
    /// # Errors
    /// [`TensorError::MatrixMultiplicationDimensionMismatchError`] when the
    /// inner dimensions disagree (this check was previously missing here).
    pub fn mm(&self, other: &Self) -> Result<Self, TensorError> {
        if self.ndims != 2 || other.ndims != 2 {
            return self.matmul(other);
        }
        if self.shape[1] != other.shape[0] {
            return Err(TensorError::MatrixMultiplicationDimensionMismatchError);
        }
        let (r1, c1, c2) = (self.shape[0], self.shape[1], other.shape[1]);
        let mut data = vec![T::zero(); r1 * c2];
        // `other` is indexed directly; the old detour through
        // `transpose_copy` relied on an in-place transpose that was only
        // correct for square matrices.
        for i in 0..r1 {
            for j in 0..c2 {
                // Each inner product is parallelized across `k`.
                data[at!(i, j, c2)] = (0..c1)
                    .into_par_iter()
                    .map(|k| self.data[at!(i, k, c1)] * other.data[at!(k, j, c2)])
                    .sum();
            }
        }
        // Fixed: the result shape is [rows of self, cols of other]
        // (previously reversed to [c2, r1]).
        Self::new(data, vec![r1, c2])
    }

    /// Transposes the last two dimensions in place (applied per matrix for
    /// rank > 2).
    ///
    /// Fixed: the previous pairwise element swap was only correct for square
    /// matrices; this version rebuilds the buffer and handles any
    /// rectangular shape.
    pub fn transpose(&mut self) {
        if self.ndims < 2 {
            eprintln!("Error: You need at least a 2D tensor (matrix) to transpose.");
            return;
        }
        let nrows = self.shape[self.ndims - 2];
        let ncols = self.shape[self.ndims - 1];
        let mat_size = nrows * ncols;
        if mat_size > 0 {
            let mut transposed = vec![T::zero(); self.data.len()];
            for (batch, chunk) in self.data.chunks(mat_size).enumerate() {
                let base = batch * mat_size;
                for i in 0..nrows {
                    for j in 0..ncols {
                        transposed[base + j * nrows + i] = chunk[i * ncols + j];
                    }
                }
            }
            self.data = transposed;
        }
        let last = self.ndims - 1;
        self.shape.swap(last - 1, last);
    }

    /// Alias for [`Self::transpose`].
    pub fn t(&mut self) {
        self.transpose()
    }

    /// Returns a transposed copy, leaving `self` untouched.
    pub fn transpose_copy(&self) -> Self {
        let mut res = self.clone();
        res.transpose();
        res
    }

    /// TODO: compute the dominant eigenvalue.
    pub fn eigenvalue(&self) -> T {
        todo!()
    }
}
impl<'a, T> Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Counts how many elements satisfy `pred` (evaluated in parallel).
    pub fn count_where<F>(&'a self, pred: F) -> usize
    where
        F: Fn(&T) -> bool + Sync,
    {
        self.data.par_iter().filter(|&elem| pred(elem)).count()
    }

    /// Sums the elements that satisfy `pred` (evaluated in parallel).
    pub fn sum_where<F>(&self, pred: F) -> T
    where
        F: Fn(&T) -> bool + Sync,
    {
        self.data
            .par_iter()
            .copied()
            .filter(|elem| pred(elem))
            .sum::<T>()
    }

    /// Applies `pred` to every element in place, sequentially.
    pub fn set_where<P>(&mut self, mut pred: P)
    where
        P: FnMut(&mut T) + Sync + Send,
    {
        self.data.iter_mut().for_each(|element| pred(element));
    }

    /// Returns `true` if at least one element satisfies `pred`.
    pub fn any<F>(&self, pred: F) -> bool
    where
        F: Fn(&T) -> bool + Sync + Send,
    {
        self.data.par_iter().any(|elem| pred(elem))
    }

    /// Returns `true` if every element satisfies `pred`.
    pub fn all<F>(&self, pred: F) -> bool
    where
        F: Fn(&T) -> bool + Sync + Send,
    {
        self.data.par_iter().all(|elem| pred(elem))
    }

    /// Returns the multi-dimensional index of the first element (in
    /// row-major order) that satisfies `pred`, or `None` when no element
    /// matches.
    pub fn find<F>(&self, pred: F) -> Option<Shape>
    where
        F: Fn(&T) -> bool + Sync,
    {
        self.data
            .iter()
            .position(|elem| pred(elem))
            .map(|idx| index_list!(idx, self.shape))
    }

    /// Returns the multi-dimensional indexes of every element satisfying
    /// `pred`, or `None` when no element matches.
    pub fn find_all<F>(&self, pred: F) -> Option<Vec<Shape>>
    where
        F: Fn(&T) -> bool + Sync,
    {
        let matches: Vec<Shape> = self
            .data
            .par_iter()
            .enumerate()
            .filter(|&(_, elem)| pred(elem))
            .map(|(idx, _)| index_list!(idx, self.shape))
            .collect();
        if matches.is_empty() {
            None
        } else {
            Some(matches)
        }
    }
}