use super::Tensor;
use super::properties::is_broadcastable;
use mlua::prelude::*;
impl Tensor {
#[tracing::instrument(skip_all)]
pub fn exp(&self) -> Result<Self, LuaError> {
Ok(Self(self.0.exp()))
}
}
#[tracing::instrument(skip_all)]
pub fn add(Tensor(this): &Tensor, other: LuaValue) -> Result<Tensor, LuaError> {
let new = match other {
LuaValue::UserData(user_data) => {
let Tensor(oth) = user_data.borrow::<Tensor>()?.to_owned();
if !is_broadcastable(this.shape(), oth.shape()) {
return Err(LuaError::external(format!(
"Shape {:?} not broadcastable to {:?}",
this.shape(),
oth.shape()
)));
}
this + oth
}
LuaValue::Number(n) => this + (n as f32),
LuaValue::Integer(i) => this + (i as f32),
_ => return Err(LuaError::external("Expected either number or Tensor.")),
};
Ok(Tensor(new))
}
#[tracing::instrument(skip_all)]
pub fn sub(Tensor(this): &Tensor, other: LuaValue) -> Result<Tensor, LuaError> {
let new = match other {
LuaValue::UserData(user_data) => {
let Tensor(oth) = user_data.borrow::<Tensor>()?.to_owned();
if !is_broadcastable(oth.shape(), this.shape()) {
return Err(LuaError::external(format!(
"Shape {:?} not broadcastable to {:?}",
this.shape(),
oth.shape()
)));
}
this - oth
}
LuaValue::Number(n) => this - (n as f32),
LuaValue::Integer(i) => this - (i as f32),
_ => return Err(LuaError::external("Expected either number or Tensor.")),
};
Ok(Tensor(new))
}
#[tracing::instrument(skip_all)]
pub fn mul(Tensor(this): &Tensor, other: LuaValue) -> Result<Tensor, LuaError> {
let new = match other {
LuaValue::UserData(user_data) => {
let Tensor(oth) = user_data.borrow::<Tensor>()?.to_owned();
if !is_broadcastable(this.shape(), oth.shape()) {
return Err(LuaError::external(format!(
"Shape {:?} not broadcastable to {:?}",
this.shape(),
oth.shape()
)));
}
this * oth
}
LuaValue::Number(n) => this * (n as f32),
LuaValue::Integer(i) => this * (i as f32),
_ => return Err(LuaError::external("Expected either number or Tensor.")),
};
Ok(Tensor(new))
}
#[tracing::instrument(skip_all)]
pub fn div(Tensor(this): &Tensor, other: LuaValue) -> Result<Tensor, LuaError> {
let new = match other {
LuaValue::UserData(user_data) => {
let Tensor(oth) = user_data.borrow::<Tensor>()?.to_owned();
if !is_broadcastable(oth.shape(), this.shape()) {
return Err(LuaError::external(format!(
"Shape {:?} not broadcastable to {:?}",
this.shape(),
oth.shape()
)));
}
this / oth
}
LuaValue::Number(n) => this / (n as f32),
LuaValue::Integer(i) => this / (i as f32),
_ => return Err(LuaError::external("Expected either number or Tensor.")),
};
Ok(Tensor(new))
}
#[cfg(test)]
mod tests {
use super::*;
use mlua::prelude::{Lua, LuaValue};
fn tensor(data: Vec<f32>, shape: &[usize]) -> Tensor {
Tensor(ndarray::ArrayD::from_shape_vec(shape, data).unwrap())
}
fn lua_number(n: f64) -> LuaValue {
LuaValue::Number(n)
}
fn lua_tensor(t: Tensor, lua: &Lua) -> LuaValue {
mlua::Value::UserData(lua.create_userdata(t).unwrap())
}
macro_rules! generate_ops_test {
($mod_name:ident, $op:tt, $rust_fn:ident, $lua_op:expr) => {
mod $mod_name {
#[test]
fn test_binding() {
use crate::transforms::tensor::load_env;
use super::Tensor;
use super::$rust_fn;
use ndarray::Array2;
use mlua::prelude::{LuaValue, LuaFunction};
let lua = load_env();
let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
let arr2 = arr1.clone();
let gold_val = $rust_fn(
&arr1,
LuaValue::UserData(lua.create_userdata(arr2.clone()).unwrap())
).expect("Failed to compute");
let result: Tensor = lua.load(format!("return function(x, y) return x {} y end", $lua_op))
.eval::<LuaFunction>()
.unwrap()
.call((arr1, arr2))
.expect("Binding failed");
assert_eq!(result, gold_val);
}
#[test]
fn test_tensor() {
use crate::transforms::tensor::load_env;
use super::Tensor;
use ndarray::Array2;
use mlua::prelude::LuaValue;
use super::$rust_fn;
let lua = load_env();
let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
let arr2 = arr1.clone();
let val = LuaValue::UserData(lua.create_userdata(arr1.clone()).unwrap());
let result = $rust_fn(&arr1, val).unwrap();
let gold = &arr1.0 $op &arr2.0;
assert_eq!(gold, result.0);
}
#[test]
fn test_number() {
use super::Tensor;
use ndarray::Array2;
use mlua::prelude::LuaValue;
use super::$rust_fn;
let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
let gold_sum = &arr1.0 $op Array2::<f32>::from_elem((3, 3), 5.0);
let result = $rust_fn(&arr1, LuaValue::Number(5.0)).unwrap();
assert_eq!(gold_sum, result.0);
}
#[test]
fn test_integer() {
use super::Tensor;
use ndarray::Array2;
use mlua::prelude::LuaValue;
use super::$rust_fn;
let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
let gold_sum = &arr1.0 $op Array2::<f32>::from_elem((3, 3), 5.0);
let result = $rust_fn(&arr1, LuaValue::Integer(5)).unwrap();
assert_eq!(gold_sum, result.0);
}
#[test]
fn test_bad_dtype() {
use super::Tensor;
use ndarray::Array2;
use mlua::prelude::{LuaValue, LuaError};
use super::$rust_fn;
let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
let result: Result<Tensor, LuaError> = $rust_fn(&arr1, LuaValue::Boolean(false));
assert!(result.is_err());
}
}
}
}
generate_ops_test!(
test_addition, +, add, "+"
);
generate_ops_test!(
test_subtraction, -, sub, "-"
);
generate_ops_test!(
test_multiplication, *, mul, "*"
);
generate_ops_test!(
test_division, /, div, "/"
);
#[test]
fn test_add_broadcast_success() {
let lua = Lua::new();
let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]);
let b = tensor(vec![10., 20., 30.], &[3]);
let res = add(&a, lua_tensor(b, &lua)).unwrap();
assert_eq!(
res.0,
ndarray::arr2(&[[11., 22., 33.], [14., 25., 36.]]).into_dyn()
);
}
#[test]
fn test_add_broadcast_failure() {
let lua = Lua::new();
let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]);
let b = tensor(vec![1., 2.], &[2]);
let err = add(&a, lua_tensor(b, &lua)).unwrap_err();
let msg = format!("{err}");
assert!(msg.contains("not broadcastable"), "Got: {msg}");
}
#[test]
fn test_sub_broadcast_success() {
let lua = Lua::new();
let a = tensor(vec![1., 2., 3.], &[3, 1]);
let b = tensor(vec![1., 10., 100.], &[3]);
let res = sub(&a, lua_tensor(b, &lua)).unwrap();
assert_eq!(
res.0,
ndarray::arr2(&[[0., -9., -99.], [1., -8., -98.], [2., -7., -97.]]).into_dyn()
);
}
#[test]
fn test_sub_broadcast_failure() {
let lua = Lua::new();
let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[3, 2]);
let b = tensor(vec![1., 2., 3.], &[3]);
let err = sub(&a, lua_tensor(b, &lua)).unwrap_err();
assert!(format!("{err}").contains("not broadcastable"));
}
#[test]
fn test_mul_broadcast_success() {
let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]);
let res = mul(&a, lua_number(2.0)).unwrap();
assert_eq!(
res.0,
ndarray::arr2(&[[2., 4., 6.], [8., 10., 12.]]).into_dyn()
);
}
#[test]
fn test_mul_broadcast_shape_success() {
let lua = Lua::new();
let a = tensor(vec![1., 2., 3., 4.], &[4, 1]);
let b = tensor(vec![10., 20., 30.], &[1, 3]);
let res = mul(&a, lua_tensor(b, &lua)).unwrap();
assert_eq!(
res.0,
ndarray::arr2(&[
[10., 20., 30.],
[20., 40., 60.],
[30., 60., 90.],
[40., 80., 120.]
])
.into_dyn()
);
}
#[test]
fn test_mul_broadcast_fail() {
let lua = Lua::new();
let a = tensor(vec![1., 2., 3., 4.], &[2, 2]);
let b = tensor(vec![1., 2., 3.], &[3]);
let err = mul(&a, lua_tensor(b, &lua)).unwrap_err();
assert!(format!("{err}").contains("not broadcastable"));
}
#[test]
fn test_div_broadcast_success() {
let lua = Lua::new();
let a = tensor((1..=9).map(|x| x as f32).collect(), &[3, 3]);
let b = tensor(vec![1., 2., 3.], &[3]);
let res = div(&a, lua_tensor(b, &lua)).unwrap();
assert_eq!(
res.0,
ndarray::arr2(&[
[1.0 / 1., 2.0 / 2., 3.0 / 3.],
[4.0 / 1., 5.0 / 2., 6.0 / 3.],
[7.0 / 1., 8.0 / 2., 9.0 / 3.],
])
.into_dyn()
);
}
#[test]
fn test_div_broadcast_fail() {
let lua = Lua::new();
let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]);
let b = tensor(vec![1., 2.], &[2]);
let err = div(&a, lua_tensor(b, &lua)).unwrap_err();
assert!(format!("{err}").contains("not broadcastable"));
}
#[test]
fn test_exp() {
use ndarray::Array2;
let arr = Array2::ones((3, 3)).into_dyn();
let tensor = Tensor(arr.clone());
assert_eq!(tensor.exp().unwrap(), Tensor(arr.mapv(f32::exp)));
}
#[test]
fn test_exp_empty() {
let tensor = Tensor(ndarray::array![[[]]].into_dyn());
let Tensor(exp) = tensor.exp().unwrap();
assert!(exp.is_empty());
}
}