#![deny(
unsafe_code,
clippy::all,
clippy::pedantic,
clippy::nursery,
clippy::cargo,
clippy::module_name_repetitions,
clippy::pattern_type_mismatch,
clippy::shadow_unrelated,
clippy::missing_inline_in_public_items
)]
#[cfg(feature = "nn")]
pub mod nn;
mod graph;
mod tensor;
use graph::{node::Node, tape::Tape};
use tensor::variable::Variable;
pub use tensor::Tensor;
/// Creates a `Variable` in which every element is set to `v`.
///
/// The underlying array is produced by `arrayfire::constant!` with the
/// dimensions `(R, C, L, B)`. It is wrapped in a declaration node and paired
/// with a default-constructed tape.
#[must_use]
#[inline]
pub fn fill<const B: u64, const L: u64, const R: u64, const C: u64>(
    v: f32,
) -> Variable<B, L, R, C> {
    let filled = arrayfire::constant!(v; R, C, L, B);
    let tape = Tape::default();
    Variable::new(tape, Node::declaration(filled))
}
/// Creates a `Variable` holding `v` times the array returned by
/// `arrayfire::identity` for the dimensions `(R, C, L, B)`.
///
/// The scaled array is wrapped in a declaration node and paired with a
/// default-constructed tape.
#[must_use]
#[inline]
pub fn eye<const B: u64, const L: u64, const R: u64, const C: u64>(v: f32) -> Variable<B, L, R, C> {
    let shape = arrayfire::dim4!(R, C, L, B);
    let scaled = v * arrayfire::identity::<f32>(shape);
    let tape = Tape::default();
    Variable::new(tape, Node::declaration(scaled))
}
/// Creates a `Variable` filled with random values generated by
/// `arrayfire::randu!` for the dimensions `(R, C, L, B)`.
///
/// The generated array is wrapped in a declaration node and paired with a
/// default-constructed tape.
#[must_use]
#[inline]
pub fn randu<const B: u64, const L: u64, const R: u64, const C: u64>() -> Variable<B, L, R, C> {
    let samples = arrayfire::randu!(R, C, L, B);
    let tape = Tape::default();
    Variable::new(tape, Node::declaration(samples))
}
/// Creates a `Variable` filled with random values generated by
/// `arrayfire::randn!` for the dimensions `(R, C, L, B)`.
///
/// The generated array is wrapped in a declaration node and paired with a
/// default-constructed tape.
#[must_use]
#[inline]
pub fn randn<const B: u64, const L: u64, const R: u64, const C: u64>() -> Variable<B, L, R, C> {
    let samples = arrayfire::randn!(R, C, L, B);
    let tape = Tape::default();
    Variable::new(tape, Node::declaration(samples))
}
/// Creates a `Variable` from caller-supplied element values.
///
/// `values` is handed to `arrayfire::Array::new` together with the dimensions
/// `(R, C, L, B)`. The resulting array is wrapped in a declaration node and
/// paired with a default-constructed tape.
///
/// NOTE(review): `ArrayFire` is documented as column-major, so `values` is
/// presumably interpreted in that order — confirm against the backend docs.
///
/// # Panics
///
/// Panics if `values.len()` is not exactly `R * C * L * B`. The backend is
/// given only a pointer and the dimensions, so a shorter slice would otherwise
/// be read past its end and a longer one silently truncated.
#[must_use]
#[inline]
pub fn custom<const B: u64, const L: u64, const R: u64, const C: u64>(
    values: &[f32],
) -> Variable<B, L, R, C> {
    let expected = R * C * L * B;
    // `usize` is never wider than 64 bits, so this cast cannot truncate.
    let provided = values.len() as u64;
    assert!(
        provided == expected,
        "custom: expected {} values for dimensions ({}, {}, {}, {}), got {}",
        expected,
        R,
        C,
        L,
        B,
        provided
    );
    let data = arrayfire::Array::new(values, arrayfire::dim4!(R, C, L, B));
    Variable::new(Tape::default(), Node::declaration(data))
}
#[cfg(test)]
mod tests {
    use crate as mu;
    use arrayfire::{abs, all_true_all, constant, dim4, identity, le, Array};
    // Presumably the trait that provides `.data()` on the returned variables.
    use mu::Tensor;

    /// Returns `true` when every element of `x` differs from the matching
    /// element of `y` by at most `1e-15`.
    pub(crate) fn equal_arrays(x: Array<f32>, y: Array<f32>) -> bool {
        all_true_all(&le(&abs(&(x - y)), &1e-15, false)).0
    }

    /// `fill` must produce an array where every element equals the given value.
    #[test]
    fn fill() {
        let x = mu::fill::<1, 2, 3, 4>(2.0);
        assert!(equal_arrays(x.data(), constant!(2.0; 3,4,2,1)));
    }

    /// `eye` must match a scaled `arrayfire::identity` of the same dimensions.
    #[test]
    fn eye() {
        let x = mu::eye::<1, 2, 3, 4>(2.0);
        assert!(equal_arrays(
            x.data(),
            identity::<f32>(dim4!(3, 4, 2, 1)) * 2.0f32
        ));
    }

    /// Every sample produced by `randu` must be at most 1.
    #[test]
    fn randu() {
        let x = mu::randu::<1, 2, 3, 4>();
        assert!(all_true_all(&le(&x.data(), &constant!(1.0; 3,4,2,1), false)).0);
    }

    /// Every sample produced by `randn` must have a sane magnitude.
    #[test]
    fn randn() {
        let x = mu::randn::<1, 2, 3, 4>();
        // Normal samples are unbounded: a one-sided bound of 3.0 is exceeded
        // by roughly 0.13% of draws, which made this test fail in a few
        // percent of runs over 24 elements. A magnitude bound of 10 is
        // effectively never exceeded while still catching garbage output.
        assert!(all_true_all(&le(&abs(&x.data()), &constant!(10.0; 3,4,2,1), false)).0);
    }

    /// `custom` must preserve the supplied value for a single-element array.
    #[test]
    fn custom() {
        let x = mu::custom::<1, 1, 1, 1>(&[1.0]);
        assert!(equal_arrays(x.data(), constant!(1.0;1,1,1,1)));
    }
}