use dfdx::{
shapes::{Const, Rank1, Rank2, Rank3},
tensor::{AsArray, OnesTensor, SampleTensor, Tensor, TensorFrom, ZerosTensor},
tensor_ops::RealizeTo,
};
/// Device backend used by `main` to allocate and fill every tensor.
///
/// Exactly one of the two aliases below is compiled in. When the crate is
/// built with the `cuda` feature it is `dfdx::tensor::Cuda`.
#[cfg(feature = "cuda")]
type Device = dfdx::tensor::Cuda;
/// Fallback when the `cuda` feature is off: the `dfdx::tensor::Cpu` device.
#[cfg(not(feature = "cuda"))]
type Device = dfdx::tensor::Cpu;
/// Walks through the ways to create tensors with dfdx and read their data back.
///
/// Every tensor is created through `dev`, an instance of the `Device` alias.
/// The alias is the CPU device, or the CUDA device when the `cuda` feature is on.
/// Most bindings are discarded with `let _`. Their type annotations are what
/// choose the shape, dtype and device of each constructor call.
fn main() {
// A device is required to create tensors; `Device` is the cfg-selected backend.
let dev: Device = Device::default();
// `Tensor` is generic over (1) shape, (2) element type and (3) device.
// Here: a rank-1 tensor of 5 `f32` values built from a fixed-size array.
let _: Tensor<Rank1<5>, f32, Device> = dev.tensor([1.0, 2.0, 3.0, 4.0, 5.0]);
// `zeros` / `ones` take their compile-time shape from the annotation alone.
let _: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
let _: Tensor<Rank3<1, 2, 3>, f32, _> = dev.ones();
// The `*_like` variants take a shape *value* instead.
// Each dimension can be a compile-time `Const<N>` or a runtime `usize`, and they can be mixed.
let _: Tensor<Rank2<2, 3>, f32, _> = dev.zeros_like(&(Const::<2>, Const::<3>));
let _: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const::<3>));
let _: Tensor<(usize, usize), f32, _> = dev.zeros_like(&(2, 4));
let _: Tensor<(usize, usize, usize), f32, _> = dev.ones_like(&(3, 4, 5));
// A bare `Const` is inferred as `Const<5>` from the annotated result type.
let _: Tensor<(usize, usize, Const<5>), f32, _> = dev.ones_like(&(3, 4, Const));
// `try_realize` converts between shapes with runtime and compile-time dims.
// Here (usize, usize) becomes (usize, Const<3>). The conversion is fallible and
// checked at runtime; `expect` panics if the dims of `a` are incompatible.
let a: Tensor<(usize, usize), f32, _> = dev.zeros_like(&(2, 3));
let _: Tensor<(usize, Const<3>), f32, _> = a.try_realize().expect("`a` should have 3 columns");
// The element type is the second generic parameter; non-f32 dtypes work the same way.
let _: Tensor<Rank2<2, 3>, f64, _> = dev.zeros();
let _: Tensor<Rank2<2, 3>, usize, _> = dev.zeros();
let _: Tensor<Rank2<2, 3>, i16, _> = dev.zeros();
// Random tensors. `sample_normal` / `sample_uniform` presumably use dfdx's default
// normal / uniform distributions (TODO confirm ranges in the dfdx docs).
// `sample` accepts any `rand_distr` distribution; here it is `Uniform::new(-1.0, 1.0)`.
// NOTE(review): all draws go through `dev`. If the device owns the RNG, the values
// of `a` depend on the samples drawn before it.
let _: Tensor<Rank3<2, 3, 4>, f32, Device> = dev.sample_normal();
let _: Tensor<Rank3<2, 3, 4>, f32, Device> = dev.sample_uniform();
let a: Tensor<Rank3<2, 3, 4>, f32, Device> = dev.sample(rand_distr::Uniform::new(-1.0, 1.0));
// Random tensors with runtime (`usize`) dims use the `*_like` samplers and a shape value.
let _: Tensor<(usize, usize), f32, _> = dev.sample_uniform_like(&(1, 2));
let _: Tensor<(usize, usize, usize), f32, _> = dev.sample_normal_like(&(1, 2, 3));
// Any distribution plus a non-float dtype: `u64` values drawn from `StandardGeometric`.
let _: Tensor<(usize, usize, usize, usize), u64, _> =
dev.sample_like(&(1, 2, 3, 4), rand_distr::StandardGeometric);
// A compile-time-shaped tensor can be built from a flat `Vec`.
// The Vec holds 6 values = 1 * 2 * 3; presumably the length must match (confirm in dfdx docs).
let _: Tensor<Rank3<1, 2, 3>, f32, _> = dev.tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
// A runtime-shaped tensor needs the shape passed next to the data: `(data, shape)`.
let _: Tensor<(usize,), f32, _> = dev.tensor((vec![1.0, 2.0], (2,)));
// `AsArray::array` returns the data as an owned nested array that mirrors the shape.
let a_data: [[[f32; 4]; 3]; 2] = a.array();
println!("a={a_data:?}");
// `as_vec` returns the same data as one flat `Vec` (element order presumably row-major; confirm).
let a_data: Vec<f32> = a.as_vec();
println!("a={a_data:?}");
// Tensors can be cloned; the copy holds the same values as the original.
let a_copy = a.clone();
assert_eq!(a_copy.array(), a.array());
}