pub(crate) mod cache;
pub(crate) mod cpu;
#[cfg(feature = "cuda")]
pub(crate) mod cuda;
mod ghost;
mod gradients;
mod masks;
#[cfg(feature = "numpy")]
pub(crate) mod numpy;
#[cfg(feature = "numpy")]
pub use numpy::NumpyDtype;
#[cfg(feature = "safetensors")]
pub mod safetensors;
mod tensorlike;
mod unique_id;
pub(crate) mod storage_traits;
mod tensor_impls;
pub(crate) use ghost::GhostTensor;
pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage};
pub(crate) use tensorlike::Tensorlike;
pub use cpu::{Cpu, CpuError};
#[cfg(not(feature = "cuda"))]
pub type AutoDevice = Cpu;
#[cfg(feature = "cuda")]
pub(crate) use cuda::launch_cfg;
#[cfg(feature = "cuda")]
pub use cuda::{Cuda, CudaError};
#[cfg(feature = "cuda")]
pub type AutoDevice = Cuda;
pub use storage_traits::{AsArray, CopySlice, TensorFrom, TensorFromVec, TensorToArray};
pub use storage_traits::{Cache, HasErr, RandomU64, Storage, Synchronize};
pub use storage_traits::{OnesTensor, SampleTensor, TriangleTensor, ZerosTensor};
pub use tensor_impls::{PutTape, SplitTape, Tensor, Trace, WithEmptyTape};
pub use tensor_impls::{Tensor0D, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D};
pub(crate) use unique_id::unique_id;
pub use unique_id::UniqueId;
pub use gradients::{Gradients, Merge, NoneTape, OwnedTape, Tape};
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shapes::*;
    use crate::tests::*;
    use std::collections::HashSet;

    /// Every newly created tensor must get an id that differs from all ids
    /// handed out before it, including one taken directly from `unique_id()`.
    #[test]
    fn test_id() {
        let dev: TestDevice = Default::default();
        let mut ids: HashSet<UniqueId> = Default::default();
        ids.insert(unique_id());

        // `HashSet::insert` returns `false` when the value is already present, so
        // each assert both checks the id is fresh and records it for later checks.
        let x: Tensor<Rank0, f32, _> = dev.zeros();
        assert!(ids.insert(x.id));
        let x: Tensor<Rank0, f32, _> = dev.zeros();
        assert!(ids.insert(x.id));
        let x: Tensor<Rank1<5>, f32, _> = dev.zeros();
        assert!(ids.insert(x.id));
        let x: Tensor<Rank2<3, 2>, f32, _> = dev.ones();
        assert!(ids.insert(x.id));
        let x: Tensor<Rank3<4, 3, 2>, f32, _> = dev.sample(rand_distr::Standard);
        assert!(ids.insert(x.id));
    }

    /// Cloning a tensor keeps its id.
    #[test]
    fn test_ids_with_clone() {
        let dev: TestDevice = Default::default();
        let t1: Tensor<Rank1<32>, f32, _> = dev.zeros();
        let t2 = t1.clone();
        assert_eq!(t1.id, t2.id);
    }

    /// Splitting the tape off a tensor and putting it back keeps the tensor's id.
    #[test]
    fn test_ids_with_split_and_put() {
        let dev: TestDevice = Default::default();
        let t1: Tensor<Rank1<32>, f32, _> = dev.zeros();
        let t1_id = t1.id;
        let (t2, tape) = t1.split_tape();
        assert_eq!(t2.id, t1_id);
        let t3 = t2.put_tape(tape);
        assert_eq!(t3.id, t1_id);
    }

    /// `zeros` fills every element with 0.
    #[test]
    fn test_zeros() {
        let dev: TestDevice = Default::default();
        let x: Tensor<Rank2<3, 2>, f32, _> = dev.zeros();
        assert_eq!(x.array(), [[0.0; 2]; 3]);
    }

    /// `ones` fills every element with 1.
    #[test]
    fn test_ones() {
        let dev: TestDevice = Default::default();
        let x: Tensor<Rank2<3, 2>, f32, _> = dev.ones();
        assert_eq!(x.array(), [[1.0; 2]; 3]);
    }

    /// A nested array survives a round trip through `tensor` and `array`.
    #[test]
    fn test_convert_array() {
        let dev: TestDevice = Default::default();
        let a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let t = dev.tensor(a);
        assert_eq!(t.array(), a);
    }

    /// `copy_from` fills a tensor from a flat slice in row-major order.
    #[test]
    fn test_convert_slice() {
        let dev: TestDevice = Default::default();
        let data = [1.0, 2.0, 3.0, 4.0];
        let mut t: Tensor<Rank2<2, 2>, f32, _> = dev.zeros();
        t.copy_from(&data);
        assert_eq!(t.array(), [[1.0, 2.0], [3.0, 4.0]]);
    }

    /// Every value drawn by `sample_uniform` lies in the half-open range [0, 1).
    #[test]
    fn fuzz_test_rand() {
        let dev: TestDevice = Default::default();
        let t: Tensor<Rank1<1000>, f32, _> = dev.sample_uniform();
        for v in t.as_vec() {
            assert!((0.0..1.0).contains(&v));
        }
    }

    /// `sample_normal` produces the requested number of finite values, and their
    /// mean is close to zero.
    #[test]
    fn test_sample_normal() {
        let dev: TestDevice = Default::default();
        let t: Tensor<Rank1<1000>, f32, _> = dev.sample_normal();
        let values = t.as_vec();
        assert_eq!(values.len(), 1000);
        assert!(values.iter().all(|v| v.is_finite()));
        // Assumes `sample_normal` draws from a standard normal (mean 0, std 1).
        // For 1000 draws the sample mean then has a standard deviation of about
        // 0.032, so this loose bound only trips on a grossly wrong distribution.
        let mean = values.iter().sum::<f32>() / values.len() as f32;
        assert!(mean.abs() < 0.25, "sample mean {} too far from 0", mean);
    }

    /// `upper_tri(a, d)` keeps `a` at position (i, j) of the last two dimensions
    /// when `j >= i + d`, with `d = 0` when `None` is passed, and writes zero
    /// elsewhere. Leading dimensions are batched. Rank-0 and rank-1 shapes behave
    /// like a single row, as the assertions below encode.
    #[test]
    fn test_upper_tri() {
        let dev: TestDevice = Default::default();
        let a: TestDtype = NumCast::from(42.0).unwrap();
        let z = TestDtype::zero();
        assert_eq!(dev.upper_tri::<Rank0>(a, None).array(), a);
        assert_eq!(dev.upper_tri::<Rank0>(a, 1).array(), z);
        assert_eq!(dev.upper_tri::<Rank1<3>>(a, None).array(), [a, a, a]);
        assert_eq!(dev.upper_tri::<Rank1<3>>(a, 1).array(), [z, a, a]);
        assert_eq!(
            dev.upper_tri::<Rank2<3, 4>>(a, None).array(),
            [[a, a, a, a], [z, a, a, a], [z, z, a, a]]
        );
        assert_eq!(
            dev.upper_tri::<Rank2<3, 1>>(a, None).array(),
            [[a], [z], [z]]
        );
        assert_eq!(dev.upper_tri::<Rank2<3, 1>>(a, 1).array(), [[z], [z], [z]]);
        assert_eq!(dev.upper_tri::<Rank2<3, 1>>(a, -1).array(), [[a], [a], [z]]);
        assert_eq!(
            dev.upper_tri::<Rank2<4, 4>>(a, -1).array(),
            [[a, a, a, a], [a, a, a, a], [z, a, a, a], [z, z, a, a]]
        );
        assert_eq!(
            dev.upper_tri::<Rank2<4, 4>>(a, -2).array(),
            [[a, a, a, a], [a, a, a, a], [a, a, a, a], [z, a, a, a]]
        );
        assert_eq!(
            dev.upper_tri::<Rank2<4, 3>>(a, 1).array(),
            [[z, a, a], [z, z, a], [z, z, z], [z, z, z]]
        );
        assert_eq!(
            dev.upper_tri::<Rank3<2, 5, 5>>(a, None).array(),
            [[
                [a, a, a, a, a],
                [z, a, a, a, a],
                [z, z, a, a, a],
                [z, z, z, a, a],
                [z, z, z, z, a]
            ]; 2]
        );
        assert_eq!(
            dev.upper_tri::<Rank3<4, 5, 5>>(a, 2).array(),
            [[
                [z, z, a, a, a],
                [z, z, z, a, a],
                [z, z, z, z, a],
                [z, z, z, z, z],
                [z, z, z, z, z]
            ]; 4]
        );
        assert_eq!(
            dev.upper_tri::<Rank4<3, 4, 5, 6>>(a, None).array(),
            [[[
                [a, a, a, a, a, a],
                [z, a, a, a, a, a],
                [z, z, a, a, a, a],
                [z, z, z, a, a, a],
                [z, z, z, z, a, a]
            ]; 4]; 3]
        );
    }

    /// `lower_tri(a, d)` keeps `a` at position (i, j) of the last two dimensions
    /// when `j <= i + d`, with `d = 0` when `None` is passed, and writes zero
    /// elsewhere. Leading dimensions are batched. Rank-0 and rank-1 shapes behave
    /// like a single row, as the assertions below encode.
    #[test]
    fn test_lower_tri() {
        let dev: TestDevice = Default::default();
        let a: TestDtype = NumCast::from(42.0).unwrap();
        let z = TestDtype::zero();
        assert_eq!(dev.lower_tri::<Rank0>(a, None).array(), a);
        assert_eq!(dev.lower_tri::<Rank0>(a, -1).array(), z);
        assert_eq!(dev.lower_tri::<Rank1<3>>(a, None).array(), [a, z, z]);
        assert_eq!(dev.lower_tri::<Rank1<3>>(a, 1).array(), [a, a, z]);
        assert_eq!(
            dev.lower_tri::<Rank2<3, 4>>(a, None).array(),
            [[a, z, z, z], [a, a, z, z], [a, a, a, z]]
        );
        assert_eq!(
            dev.lower_tri::<Rank2<3, 1>>(a, None).array(),
            [[a], [a], [a]]
        );
        assert_eq!(dev.lower_tri::<Rank2<3, 1>>(a, 1).array(), [[a], [a], [a]]);
        assert_eq!(dev.lower_tri::<Rank2<3, 1>>(a, -1).array(), [[z], [a], [a]]);
        assert_eq!(
            dev.lower_tri::<Rank2<4, 4>>(a, -1).array(),
            [[z, z, z, z], [a, z, z, z], [a, a, z, z], [a, a, a, z]]
        );
        assert_eq!(
            dev.lower_tri::<Rank2<4, 4>>(a, -2).array(),
            [[z, z, z, z], [z, z, z, z], [a, z, z, z], [a, a, z, z]]
        );
        assert_eq!(
            dev.lower_tri::<Rank2<4, 3>>(a, 1).array(),
            [[a, a, z], [a, a, a], [a, a, a], [a, a, a]]
        );
        assert_eq!(
            dev.lower_tri::<Rank3<2, 5, 5>>(a, None).array(),
            [[
                [a, z, z, z, z],
                [a, a, z, z, z],
                [a, a, a, z, z],
                [a, a, a, a, z],
                [a, a, a, a, a]
            ]; 2]
        );
        assert_eq!(
            dev.lower_tri::<Rank3<4, 5, 5>>(a, 2).array(),
            [[
                [a, a, a, z, z],
                [a, a, a, a, z],
                [a, a, a, a, a],
                [a, a, a, a, a],
                [a, a, a, a, a]
            ]; 4]
        );
        assert_eq!(
            dev.lower_tri::<Rank4<3, 4, 5, 6>>(a, None).array(),
            [[[
                [a, z, z, z, z, z],
                [a, a, z, z, z, z],
                [a, a, a, z, z, z],
                [a, a, a, a, z, z],
                [a, a, a, a, a, z]
            ]; 4]; 3]
        );
    }
}