mod cuda;
mod ops;
mod shape;
mod nn_ops;
pub use cuda::*;
pub use nn_ops::RnnParams;
use std::ffi::{c_char, c_void, CStr};
use std::fmt;
use std::ptr;
use std::sync::atomic::{AtomicU64, Ordering};
use flodl_sys::{self as ffi, FlodlTensor};
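/// Count of live `Tensor` handles, incremented in `Tensor::from_raw` and
/// decremented in `Drop`; handy for leak checks in tests.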
pub(super) static LIVE_TENSOR_COUNT: AtomicU64 = AtomicU64::new(0);
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum DType {
Float16 = ffi::FLODL_FLOAT16,
BFloat16 = ffi::FLODL_BFLOAT16,
Float32 = ffi::FLODL_FLOAT32,
Float64 = ffi::FLODL_FLOAT64,
Int32 = ffi::FLODL_INT32,
Int64 = ffi::FLODL_INT64,
}
impl DType {
fn from_raw(v: i32) -> Self {
match v {
ffi::FLODL_FLOAT16 => DType::Float16,
ffi::FLODL_BFLOAT16 => DType::BFloat16,
ffi::FLODL_FLOAT32 => DType::Float32,
ffi::FLODL_FLOAT64 => DType::Float64,
ffi::FLODL_INT32 => DType::Int32,
ffi::FLODL_INT64 => DType::Int64,
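            // Unknown dtype codes from the FFI conservatively fall back to Float32.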
_ => DType::Float32,
}
}
pub fn element_size(self) -> usize {
match self {
DType::Float16 | DType::BFloat16 => 2,
DType::Float32 | DType::Int32 => 4,
DType::Float64 | DType::Int64 => 8,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
CPU,
CUDA(u8),
}
impl Device {
pub(crate) fn to_ffi(self) -> (i32, i32) {
match self {
Device::CPU => (ffi::FLODL_CPU, 0),
Device::CUDA(idx) => (ffi::FLODL_CUDA, idx as i32),
}
}
pub(crate) fn from_ffi(device_type: i32, device_index: i32) -> Self {
match device_type {
ffi::FLODL_CUDA => Device::CUDA(device_index as u8),
_ => Device::CPU,
}
}
pub fn is_cuda(&self) -> bool {
matches!(self, Device::CUDA(_))
}
pub fn index(&self) -> u8 {
match self {
Device::CPU => 0,
Device::CUDA(idx) => *idx,
}
}
}
impl fmt::Display for Device {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Device::CPU => write!(f, "cpu"),
Device::CUDA(0) => write!(f, "cuda"),
Device::CUDA(idx) => write!(f, "cuda:{}", idx),
}
}
}
#[derive(Debug, Clone)]
pub struct TensorError(String);
impl TensorError {
pub fn new(msg: &str) -> Self {
TensorError(msg.to_string())
}
pub fn is_cuda_oom(&self) -> bool {
self.0.contains("out of memory")
}
}
impl fmt::Display for TensorError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl std::error::Error for TensorError {}
pub type Result<T> = std::result::Result<T, TensorError>;
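/// Converts an error string returned by the C API into a `Result`, copying
/// the message into Rust and freeing the C allocation.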
pub(crate) fn check_err(err: *mut c_char) -> Result<()> {
if err.is_null() {
Ok(())
} else {
let msg = unsafe { CStr::from_ptr(err) }
.to_string_lossy()
.into_owned();
unsafe { ffi::flodl_free_string(err) };
Err(TensorError(msg))
}
}
#[derive(Debug, Clone, Copy)]
pub struct TensorOptions {
pub dtype: DType,
pub device: Device,
}
impl Default for TensorOptions {
fn default() -> Self {
Self {
dtype: DType::Float32,
device: Device::CPU,
}
}
}
pub struct Tensor {
pub(crate) handle: FlodlTensor,
}
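// SAFETY: assumption — a `FlodlTensor` handle can be moved and shared across
// threads (libtorch-style handles synchronize on the C/C++ side). If the
// flodl backend is not thread-safe, these impls are unsound.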
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
impl Drop for Tensor {
fn drop(&mut self) {
if !self.handle.is_null() {
LIVE_TENSOR_COUNT.fetch_sub(1, Ordering::Relaxed);
unsafe { ffi::flodl_free_tensor(self.handle) };
}
}
}
impl Clone for Tensor {
    fn clone(&self) -> Self {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_shallow_clone(self.handle, &mut handle) };
        if let Err(e) = check_err(err) {
            panic!("tensor clone failed: {}", e);
        }
        Self::from_raw(handle)
    }
}
impl Tensor {
pub(crate) fn from_raw(handle: FlodlTensor) -> Self {
debug_assert!(!handle.is_null());
LIVE_TENSOR_COUNT.fetch_add(1, Ordering::Relaxed);
Self { handle }
}
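    /// # Safety
    /// `handle` must be a valid, non-null tensor handle whose ownership is
    /// transferred to the returned `Tensor` (it is freed on drop).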
pub(crate) unsafe fn from_raw_handle(handle: FlodlTensor) -> Self {
Self::from_raw(handle)
}
pub(crate) fn raw(&self) -> FlodlTensor {
self.handle
}
pub fn zeros(shape: &[i64], opts: TensorOptions) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_zeros(
shape.as_mut_ptr(),
shape.len() as i32,
opts.dtype as i32,
dt, di,
&mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn ones(shape: &[i64], opts: TensorOptions) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_ones(
shape.as_mut_ptr(),
shape.len() as i32,
opts.dtype as i32,
dt, di,
&mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn from_f32(data: &[f32], shape: &[i64], device: Device) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = device.to_ffi();
let err = unsafe {
ffi::flodl_from_blob(
data.as_ptr() as *mut c_void,
shape.as_mut_ptr(),
shape.len() as i32,
DType::Float32 as i32,
dt, di,
&mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn from_f64(data: &[f64], shape: &[i64], device: Device) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = device.to_ffi();
let err = unsafe {
ffi::flodl_from_blob(
data.as_ptr() as *mut c_void,
shape.as_mut_ptr(),
shape.len() as i32,
DType::Float64 as i32,
dt, di,
&mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn from_i64(data: &[i64], shape: &[i64], device: Device) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = device.to_ffi();
let err = unsafe {
ffi::flodl_from_blob(
data.as_ptr() as *mut c_void,
shape.as_mut_ptr(),
shape.len() as i32,
DType::Int64 as i32,
dt, di,
&mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn zeros_like(t: &Tensor) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_zeros_like(t.handle, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn ones_like(t: &Tensor) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_ones_like(t.handle, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn full_like(t: &Tensor, value: f64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_full_like(t.handle, value, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn rand_like(t: &Tensor) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_rand_like(t.handle, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn randn_like(t: &Tensor) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_randn_like(t.handle, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn rand(shape: &[i64], opts: TensorOptions) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_rand(
shape.as_mut_ptr(), shape.len() as i32,
opts.dtype as i32, dt, di,
&mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn randn(shape: &[i64], opts: TensorOptions) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_randn(
shape.as_mut_ptr(), shape.len() as i32,
opts.dtype as i32, dt, di,
&mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn linspace(start: f64, end: f64, steps: i64, opts: TensorOptions) -> Result<Self> {
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_linspace(start, end, steps, opts.dtype as i32, dt, di, &mut handle)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn arange(start: f64, end: f64, step: f64, opts: TensorOptions) -> Result<Self> {
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_arange(start, end, step, opts.dtype as i32, dt, di, &mut handle)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn eye(n: i64, opts: TensorOptions) -> Result<Self> {
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_eye(n, opts.dtype as i32, dt, di, &mut handle)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn full(shape: &[i64], value: f64, opts: TensorOptions) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_full(
shape.as_mut_ptr(), shape.len() as i32, value,
opts.dtype as i32, dt, di, &mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn randperm(n: i64, opts: TensorOptions) -> Result<Self> {
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_randperm(n, opts.dtype as i32, dt, di, &mut handle)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn randint(low: i64, high: i64, shape: &[i64], opts: TensorOptions) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_randint(
low, high,
shape.as_mut_ptr(), shape.len() as i32,
opts.dtype as i32, dt, di,
&mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn empty(shape: &[i64], opts: TensorOptions) -> Result<Self> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = opts.device.to_ffi();
let err = unsafe {
ffi::flodl_empty(
shape.as_mut_ptr(), shape.len() as i32,
opts.dtype as i32, dt, di,
&mut handle,
)
};
check_err(err)?;
Ok(Self::from_raw(handle))
}
pub fn one_hot(&self, num_classes: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_one_hot(self.handle, num_classes, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn bernoulli(&self) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_bernoulli(self.handle, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn ndim(&self) -> usize {
unsafe { ffi::flodl_ndim(self.handle) as usize }
}
pub fn shape(&self) -> Vec<i64> {
let n = self.ndim();
(0..n)
.map(|i| unsafe { ffi::flodl_shape(self.handle, i as i32) })
.collect()
}
pub fn numel(&self) -> i64 {
unsafe { ffi::flodl_numel(self.handle) }
}
pub fn nbytes(&self) -> usize {
self.numel() as usize * self.dtype().element_size()
}
pub fn dtype(&self) -> DType {
DType::from_raw(unsafe { ffi::flodl_dtype(self.handle) })
}
pub fn device(&self) -> Device {
let dt = unsafe { ffi::flodl_device_type(self.handle) };
let di = unsafe { ffi::flodl_device_index(self.handle) };
Device::from_ffi(dt, di)
}
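    /// Copies the tensor's data into a `Vec<f32>`. The byte count passed to
    /// `flodl_copy_data` is sized for f32 elements, so this is intended for
    /// Float32 tensors.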
pub fn to_f32_vec(&self) -> Result<Vec<f32>> {
let n = self.numel() as usize;
let mut buf = vec![0f32; n];
let bytes = (n * 4) as i64;
let err = unsafe {
ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, bytes)
};
check_err(err)?;
Ok(buf)
}
pub fn to_f64_vec(&self) -> Result<Vec<f64>> {
if self.dtype() == DType::Float64 {
let n = self.numel() as usize;
let mut buf = vec![0.0f64; n];
let bytes = (n * 8) as i64;
let err = unsafe {
ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, bytes)
};
check_err(err)?;
Ok(buf)
} else {
let f32s = self.to_f32_vec()?;
Ok(f32s.into_iter().map(|v| v as f64).collect())
}
}
pub fn to_i64_vec(&self) -> Result<Vec<i64>> {
let n = self.numel() as usize;
let mut buf = vec![0i64; n];
let bytes = (n * 8) as i64;
let err = unsafe {
ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, bytes)
};
check_err(err)?;
Ok(buf)
}
pub fn item(&self) -> Result<f64> {
if self.numel() != 1 {
return Err(TensorError::new(&format!(
"item() requires exactly 1 element, got {} (shape {:?})",
self.numel(), self.shape()
)));
}
if self.dtype() == DType::Float64 {
let mut buf = [0.0f64; 1];
let err = unsafe {
ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, 8)
};
check_err(err)?;
Ok(buf[0])
} else {
let mut buf = [0.0f32; 1];
let err = unsafe {
ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, 4)
};
check_err(err)?;
Ok(buf[0] as f64)
}
}
pub fn to_device(&self, device: Device) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = device.to_ffi();
let err = unsafe { ffi::flodl_to_device(self.handle, dt, di, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn to_device_of(&self, other: &Tensor) -> Result<Tensor> {
let target = other.device();
if self.device() == target {
return Ok(self.clone());
}
self.to_device(target)
}
pub fn to_device_async(&self, device: Device) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let (dt, di) = device.to_ffi();
let err = unsafe { ffi::flodl_to_device_async(self.handle, dt, di, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn set_requires_grad(&self, requires_grad: bool) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_set_requires_grad(self.handle, requires_grad as i32, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn requires_grad(&self) -> bool {
unsafe { ffi::flodl_requires_grad(self.handle) != 0 }
}
pub fn backward(&self) -> Result<()> {
let err = unsafe { ffi::flodl_backward(self.handle) };
check_err(err)
}
pub fn grad(&self) -> Option<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_grad(self.handle, &mut handle) };
if !err.is_null() {
unsafe { ffi::flodl_free_string(err) };
return None;
}
if handle.is_null() {
None
} else {
Some(Tensor::from_raw(handle))
}
}
pub fn set_grad(&self, grad: &Tensor) -> Result<()> {
let err = unsafe { ffi::flodl_set_grad(self.handle, grad.handle) };
check_err(err)
}
pub fn zero_grad(&self) -> Result<()> {
let err = unsafe { ffi::flodl_zero_grad(self.handle) };
check_err(err)
}
pub fn zero_grad_set_to_none(&self) {
unsafe { ffi::flodl_zero_grad_set_to_none(self.handle) }
}
pub fn clip_grad_norm_fused(params: &[Tensor], max_norm: f64) -> Result<f64> {
if params.is_empty() {
return Ok(0.0);
}
let mut handles: Vec<FlodlTensor> = params.iter().map(|t| t.handle).collect();
let mut total_norm: f64 = 0.0;
let err = unsafe {
ffi::flodl_clip_grad_norm(
handles.as_mut_ptr(),
handles.len() as i32,
max_norm,
&mut total_norm,
)
};
check_err(err)?;
Ok(total_norm)
}
pub fn is_leaf(&self) -> bool {
unsafe { ffi::flodl_is_leaf(self.handle) != 0 }
}
pub fn ensure_grad_accumulator(&self) -> Result<Option<GradAccumulatorHandle>> {
        let mut handle: *mut c_void = ptr::null_mut();
let err = unsafe { ffi::flodl_ensure_grad_accumulator(self.handle, &mut handle) };
check_err(err)?;
if handle.is_null() {
Ok(None)
} else {
Ok(Some(GradAccumulatorHandle { handle }))
}
}
pub fn autograd_node_count(&self) -> i64 {
unsafe { ffi::flodl_autograd_node_count(self.handle) }
}
pub fn detach(&self) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_detach(self.handle, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn detach_(&self) -> Result<()> {
let err = unsafe { ffi::flodl_detach_(self.handle) };
check_err(err)
}
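    // The in-place ops below take `&self`: mutation happens behind the FFI
    // handle, which provides interior mutability from Rust's point of view.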
pub fn add_(&self, other: &Tensor) -> Result<()> {
let err = unsafe { ffi::flodl_add_(self.handle, other.handle) };
check_err(err)
}
pub fn sub_(&self, other: &Tensor) -> Result<()> {
let err = unsafe { ffi::flodl_sub_(self.handle, other.handle) };
check_err(err)
}
pub fn mul_scalar_(&self, scalar: f64) -> Result<()> {
let err = unsafe { ffi::flodl_mul_scalar_(self.handle, scalar) };
check_err(err)
}
pub fn add_scalar_(&self, scalar: f64) -> Result<()> {
let err = unsafe { ffi::flodl_add_scalar_(self.handle, scalar) };
check_err(err)
}
pub fn zero_(&self) -> Result<()> {
let err = unsafe { ffi::flodl_zero_(self.handle) };
check_err(err)
}
pub fn mul_(&self, other: &Tensor) -> Result<()> {
let err = unsafe { ffi::flodl_mul_(self.handle, other.handle) };
check_err(err)
}
pub fn div_scalar_(&self, scalar: f64) -> Result<()> {
let err = unsafe { ffi::flodl_div_scalar_(self.handle, scalar) };
check_err(err)
}
pub fn div_(&self, other: &Tensor) -> Result<()> {
let err = unsafe { ffi::flodl_div_(self.handle, other.handle) };
check_err(err)
}
pub fn fill_(&self, value: f64) -> Result<()> {
let err = unsafe { ffi::flodl_fill_(self.handle, value) };
check_err(err)
}
pub fn copy_(&self, src: &Tensor, non_blocking: bool) -> Result<()> {
let err = unsafe { ffi::flodl_copy_(self.handle, src.handle, non_blocking as i32) };
check_err(err)
}
#[allow(clippy::too_many_arguments)]
pub fn adam_step(
&self, grad: &Tensor, m: &Tensor, v: &Tensor,
lr: f64, beta1: f64, beta2: f64, eps: f64,
weight_decay: f64, step: i64,
) -> Result<()> {
let err = unsafe {
ffi::flodl_adam_step(
self.handle, grad.handle, m.handle, v.handle,
lr, beta1, beta2, eps, weight_decay, step,
)
};
check_err(err)
}
#[allow(clippy::too_many_arguments)]
pub fn adam_step_batched(
params: &[Tensor], grads: &[Tensor], ms: &[Tensor], vs: &[Tensor],
lrs: &mut [f64], beta1: f64, beta2: f64, eps: f64,
weight_decay: f64, step: i64,
) -> Result<()> {
let count = params.len() as i32;
let mut p_handles: Vec<FlodlTensor> = params.iter().map(|t| t.handle).collect();
let mut g_handles: Vec<FlodlTensor> = grads.iter().map(|t| t.handle).collect();
let mut m_handles: Vec<FlodlTensor> = ms.iter().map(|t| t.handle).collect();
let mut v_handles: Vec<FlodlTensor> = vs.iter().map(|t| t.handle).collect();
let err = unsafe {
ffi::flodl_adam_step_batched(
p_handles.as_mut_ptr(), g_handles.as_mut_ptr(),
m_handles.as_mut_ptr(), v_handles.as_mut_ptr(),
lrs.as_mut_ptr(), count,
beta1, beta2, eps, weight_decay, step,
)
};
check_err(err)
}
#[allow(clippy::too_many_arguments)]
pub fn fused_adam_(
params: &[Tensor], grads: &[Tensor], exp_avgs: &[Tensor], exp_avg_sqs: &[Tensor],
lr: f64, beta1: f64, beta2: f64, eps: f64,
weight_decay: f64, step: i64,
grad_scale: Option<&Tensor>, found_inf: Option<&Tensor>,
) -> Result<()> {
if params.is_empty() { return Ok(()); }
let count = params.len() as i32;
let mut p = Self::handles(params);
let mut g = Self::handles(grads);
let mut m = Self::handles(exp_avgs);
let mut v = Self::handles(exp_avg_sqs);
let gs = grad_scale.map_or(ptr::null_mut(), |t| t.handle);
let fi = found_inf.map_or(ptr::null_mut(), |t| t.handle);
let err = unsafe {
ffi::flodl_fused_adam_(
p.as_mut_ptr(), g.as_mut_ptr(), m.as_mut_ptr(), v.as_mut_ptr(),
count, lr, beta1, beta2, eps, weight_decay, step, gs, fi,
)
};
check_err(err)
}
#[allow(clippy::too_many_arguments)]
pub fn fused_adamw_(
params: &[Tensor], grads: &[Tensor], exp_avgs: &[Tensor], exp_avg_sqs: &[Tensor],
lr: f64, beta1: f64, beta2: f64, eps: f64,
weight_decay: f64, step: i64,
grad_scale: Option<&Tensor>, found_inf: Option<&Tensor>,
) -> Result<()> {
if params.is_empty() { return Ok(()); }
let count = params.len() as i32;
let mut p = Self::handles(params);
let mut g = Self::handles(grads);
let mut m = Self::handles(exp_avgs);
let mut v = Self::handles(exp_avg_sqs);
let gs = grad_scale.map_or(ptr::null_mut(), |t| t.handle);
let fi = found_inf.map_or(ptr::null_mut(), |t| t.handle);
let err = unsafe {
ffi::flodl_fused_adamw_(
p.as_mut_ptr(), g.as_mut_ptr(), m.as_mut_ptr(), v.as_mut_ptr(),
count, lr, beta1, beta2, eps, weight_decay, step, gs, fi,
)
};
check_err(err)
}
fn handles(tensors: &[Tensor]) -> Vec<FlodlTensor> {
tensors.iter().map(|t| t.handle).collect()
}
pub fn foreach_add_scalar_(tensors: &[Tensor], scalar: f64) -> Result<()> {
if tensors.is_empty() { return Ok(()); }
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let err = unsafe {
ffi::flodl_foreach_add_scalar_(handles.as_mut_ptr(), handles.len() as i32, scalar)
};
check_err(err)
}
pub fn foreach_mul_scalar_(tensors: &[Tensor], scalar: f64) -> Result<()> {
if tensors.is_empty() { return Ok(()); }
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let err = unsafe {
ffi::flodl_foreach_mul_scalar_(handles.as_mut_ptr(), handles.len() as i32, scalar)
};
check_err(err)
}
pub fn foreach_zero_(tensors: &[Tensor]) -> Result<()> {
if tensors.is_empty() { return Ok(()); }
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let err = unsafe {
ffi::flodl_foreach_zero_(handles.as_mut_ptr(), handles.len() as i32)
};
check_err(err)
}
pub fn foreach_add_list_(tensors1: &[Tensor], tensors2: &[Tensor], alpha: f64) -> Result<()> {
if tensors1.is_empty() { return Ok(()); }
assert_eq!(tensors1.len(), tensors2.len(), "foreach_add_list_: list length mismatch");
let mut h1: Vec<FlodlTensor> = tensors1.iter().map(|t| t.handle).collect();
let mut h2: Vec<FlodlTensor> = tensors2.iter().map(|t| t.handle).collect();
let err = unsafe {
ffi::flodl_foreach_add_list_(
h1.as_mut_ptr(), h2.as_mut_ptr(), h1.len() as i32, alpha,
)
};
check_err(err)
}
pub fn foreach_norm(tensors: &[Tensor], ord: f64) -> Result<Vec<Tensor>> {
if tensors.is_empty() { return Ok(vec![]); }
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let mut results: Vec<FlodlTensor> = vec![ptr::null_mut(); tensors.len()];
let err = unsafe {
ffi::flodl_foreach_norm(
handles.as_mut_ptr(), handles.len() as i32, ord,
results.as_mut_ptr(),
)
};
check_err(err)?;
Ok(results.into_iter().map(Tensor::from_raw).collect())
}
pub fn foreach_lerp_scalar_(tensors1: &[Tensor], tensors2: &[Tensor], weight: f64) -> Result<()> {
if tensors1.is_empty() { return Ok(()); }
assert_eq!(tensors1.len(), tensors2.len(), "foreach_lerp_scalar_: list length mismatch");
let mut h1: Vec<FlodlTensor> = tensors1.iter().map(|t| t.handle).collect();
let mut h2: Vec<FlodlTensor> = tensors2.iter().map(|t| t.handle).collect();
let err = unsafe {
ffi::flodl_foreach_lerp_scalar_(
h1.as_mut_ptr(), h2.as_mut_ptr(), h1.len() as i32, weight,
)
};
check_err(err)
}
pub fn foreach_sqrt_(tensors: &[Tensor]) -> Result<()> {
if tensors.is_empty() { return Ok(()); }
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let err = unsafe {
ffi::flodl_foreach_sqrt_(handles.as_mut_ptr(), handles.len() as i32)
};
check_err(err)
}
pub fn pin_memory(&self) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_pin_memory(self.handle, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn is_pinned(&self) -> bool {
unsafe { ffi::flodl_is_pinned(self.handle) != 0 }
}
pub fn to_channels_last(&self) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_to_channels_last(self.handle, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn is_channels_last(&self) -> bool {
unsafe { ffi::flodl_is_channels_last(self.handle) != 0 }
}
pub fn is_contiguous(&self) -> bool {
unsafe { ffi::flodl_is_contiguous(self.handle) != 0 }
}
}
pub struct GradAccumulatorHandle {
    handle: *mut c_void,
}
unsafe impl Send for GradAccumulatorHandle {}
unsafe impl Sync for GradAccumulatorHandle {}
impl Drop for GradAccumulatorHandle {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe { ffi::flodl_grad_accumulator_delete(self.handle) };
            self.handle = ptr::null_mut();
}
}
}
impl fmt::Debug for Tensor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Tensor({:?}, {:?}, {:?})",
self.shape(),
self.dtype(),
self.device()
)
}
}
#[cfg(test)]
pub fn test_device() -> Device {
use std::sync::Once;
static PRINT: Once = Once::new();
let dev = if cfg!(feature = "cuda") && cuda_available() { Device::CUDA(0) } else { Device::CPU };
PRINT.call_once(|| eprintln!("\n*** flodl test device: {} ***\n", dev));
dev
}
#[cfg(test)]
pub fn test_opts() -> TensorOptions {
TensorOptions { dtype: DType::Float32, device: test_device() }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_zeros() {
let t = Tensor::zeros(&[2, 3], test_opts()).unwrap();
assert_eq!(t.shape(), vec![2, 3]);
assert_eq!(t.dtype(), DType::Float32);
assert_eq!(t.device(), test_device());
assert_eq!(t.numel(), 6);
let data = t.to_f32_vec().unwrap();
assert_eq!(data, vec![0.0; 6]);
}
#[test]
fn test_nbytes() {
let f32_t = Tensor::zeros(&[2, 3], test_opts()).unwrap();
assert_eq!(f32_t.nbytes(), 6 * 4);
let f64_t = Tensor::zeros(&[2, 3], TensorOptions { dtype: DType::Float64, device: test_device() }).unwrap();
assert_eq!(f64_t.nbytes(), 6 * 8);
let i64_t = Tensor::from_i64(&[1, 2, 3], &[3], test_device()).unwrap();
        assert_eq!(i64_t.nbytes(), 3 * 8);
    }
#[test]
fn test_from_f32() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
assert_eq!(t.shape(), vec![3]);
let data = t.to_f32_vec().unwrap();
assert_eq!(data, vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_drop_frees_memory() {
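        // Smoke test: allocate a large tensor and drop it at once; Drop must
        // route through flodl_free_tensor without crashing or double-freeing.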
let _ = Tensor::zeros(&[1000, 1000], test_opts()).unwrap();
}
#[test]
fn test_debug_format() {
let t = Tensor::zeros(&[2, 3], test_opts()).unwrap();
let s = format!("{:?}", t);
assert!(s.contains("[2, 3]"));
assert!(s.contains("Float32"));
}
#[test]
fn test_ones_from_f64_from_i64() {
let o = Tensor::ones(&[2, 3], test_opts()).unwrap();
assert_eq!(o.to_f32_vec().unwrap(), vec![1.0; 6]);
let f = Tensor::from_f64(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
assert_eq!(f.dtype(), DType::Float64);
assert_eq!(f.to_f64_vec().unwrap(), vec![1.0, 2.0, 3.0]);
let i = Tensor::from_i64(&[10, 20, 30], &[3], test_device()).unwrap();
assert_eq!(i.dtype(), DType::Int64);
assert_eq!(i.to_i64_vec().unwrap(), vec![10, 20, 30]);
}
#[test]
fn test_eye_full() {
let eye = Tensor::eye(3, test_opts()).unwrap();
assert_eq!(eye.shape(), vec![3, 3]);
let data = eye.to_f32_vec().unwrap();
assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
let f = Tensor::full(&[2, 3], 7.0, test_opts()).unwrap();
assert_eq!(f.shape(), vec![2, 3]);
assert_eq!(f.to_f32_vec().unwrap(), vec![7.0; 6]);
}
#[test]
fn test_zeros_like_ones_like() {
let t = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
let zl = Tensor::zeros_like(&t).unwrap();
assert_eq!(zl.to_f32_vec().unwrap(), vec![0.0, 0.0]);
assert_eq!(zl.dtype(), DType::Float32);
let ol = Tensor::ones_like(&t).unwrap();
assert_eq!(ol.to_f32_vec().unwrap(), vec![1.0, 1.0]);
}
#[test]
fn test_from_i64_device() {
let t = Tensor::from_i64(&[1, 2, 3], &[3], test_device()).unwrap();
assert_eq!(t.device(), test_device());
assert_eq!(t.dtype(), DType::Int64);
assert_eq!(t.to_i64_vec().unwrap(), vec![1, 2, 3]);
}
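    // to_f64_vec on a non-Float64 tensor takes the f32 path and widens each
    // element, so these exact binary fractions must round-trip unchanged.
    #[test]
    fn test_to_f64_vec_from_f32() {
        let t = Tensor::from_f32(&[1.5, -2.0], &[2], test_device()).unwrap();
        assert_eq!(t.to_f64_vec().unwrap(), vec![1.5, -2.0]);
    }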
#[test]
fn test_pin_memory() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], Device::CPU).unwrap();
assert!(!t.is_pinned(), "regular CPU tensor should not be pinned");
if cuda_available() {
let pinned = t.pin_memory().unwrap();
assert!(pinned.is_pinned(), "pin_memory() result should be pinned");
assert_eq!(pinned.device(), Device::CPU, "pinned tensor should stay on CPU");
assert_eq!(pinned.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0],
"data should be preserved after pinning");
} else {
assert!(t.pin_memory().is_err(),
"pin_memory should fail without CUDA");
}
}
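    // to_device_of returns a shallow clone when the devices already match;
    // the data must still be readable through the new handle.
    #[test]
    fn test_to_device_of_same_device() {
        let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
        let b = Tensor::ones(&[2], test_opts()).unwrap();
        let moved = a.to_device_of(&b).unwrap();
        assert_eq!(moved.device(), b.device());
        assert_eq!(moved.to_f32_vec().unwrap(), vec![1.0, 2.0]);
    }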
#[test]
fn test_channels_last() {
let t = Tensor::randn(&[1, 3, 4, 4], test_opts()).unwrap();
assert!(!t.is_channels_last());
let cl = t.to_channels_last().unwrap();
assert!(cl.is_channels_last());
        assert_eq!(cl.shape(), vec![1, 3, 4, 4]);
    }
#[test]
fn test_adam_step_basic() {
let param = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
let grad = Tensor::from_f32(&[0.5, 0.5], &[2], test_device()).unwrap();
let m = Tensor::zeros(&[2], test_opts()).unwrap();
let v = Tensor::zeros(&[2], test_opts()).unwrap();
param.adam_step(&grad, &m, &v, 0.001, 0.9, 0.999, 1e-8, 0.0, 1).unwrap();
let p = param.to_f32_vec().unwrap();
assert!(p[0] < 1.0, "param[0] should decrease");
assert!(p[1] < 2.0, "param[1] should decrease");
let m_data = m.to_f32_vec().unwrap();
let v_data = v.to_f32_vec().unwrap();
assert!(m_data[0] > 0.0, "m should be updated");
assert!(v_data[0] > 0.0, "v should be updated");
}
#[test]
fn test_device_enum_basics() {
assert_eq!(Device::CPU, Device::CPU);
assert_eq!(Device::CUDA(0), Device::CUDA(0));
assert_ne!(Device::CUDA(0), Device::CUDA(1));
assert_ne!(Device::CPU, Device::CUDA(0));
assert!(!Device::CPU.is_cuda());
assert!(Device::CUDA(0).is_cuda());
assert!(Device::CUDA(1).is_cuda());
assert_eq!(Device::CPU.index(), 0);
assert_eq!(Device::CUDA(0).index(), 0);
assert_eq!(Device::CUDA(1).index(), 1);
}
#[test]
fn test_device_display() {
assert_eq!(format!("{}", Device::CPU), "cpu");
assert_eq!(format!("{}", Device::CUDA(0)), "cuda");
assert_eq!(format!("{}", Device::CUDA(1)), "cuda:1");
}
#[test]
fn test_device_ffi_roundtrip() {
let devices = [Device::CPU, Device::CUDA(0), Device::CUDA(1), Device::CUDA(7)];
for dev in &devices {
let (dt, di) = dev.to_ffi();
let back = Device::from_ffi(dt, di);
assert_eq!(*dev, back, "FFI roundtrip failed for {:?}", dev);
}
}
#[test]
fn test_device_hash() {
use std::collections::HashSet;
let mut set = HashSet::new();
set.insert(Device::CPU);
set.insert(Device::CUDA(0));
set.insert(Device::CUDA(1));
assert_eq!(set.len(), 3);
assert!(set.contains(&Device::CPU));
assert!(set.contains(&Device::CUDA(0)));
assert!(set.contains(&Device::CUDA(1)));
}
#[test]
fn test_tensor_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<Tensor>();
}
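    // A freshly created tensor should be a leaf that neither requires grad
    // nor has one; this assumes torch-like autograd defaults in the backend.
    #[test]
    fn test_grad_defaults() {
        let t = Tensor::zeros(&[2], test_opts()).unwrap();
        assert!(!t.requires_grad());
        assert!(t.grad().is_none());
        assert!(t.is_leaf());
    }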
#[test]
#[ignore]
fn test_manual_seed_reproducible() {
let opts = test_opts();
manual_seed(123);
let a = Tensor::randn(&[4, 4], opts).unwrap().to_f32_vec().unwrap();
manual_seed(123);
let b = Tensor::randn(&[4, 4], opts).unwrap().to_f32_vec().unwrap();
assert_eq!(a, b);
}
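    // linspace/arange sanity checks, assuming torch-style semantics: linspace
    // includes both endpoints, arange covers the half-open interval [start, end).
    #[test]
    fn test_linspace_arange() {
        let l = Tensor::linspace(0.0, 1.0, 5, test_opts()).unwrap();
        assert_eq!(l.to_f32_vec().unwrap(), vec![0.0, 0.25, 0.5, 0.75, 1.0]);
        let a = Tensor::arange(0.0, 5.0, 1.0, test_opts()).unwrap();
        assert_eq!(a.to_f32_vec().unwrap(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    }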
#[test]
    fn test_fused_adamw_matches_adam_step() {
let dev = test_device();
let opts = test_opts();
manual_seed(42);
let p1 = Tensor::randn(&[4, 3], opts).unwrap();
let p2 = Tensor::from_f32(&p1.to_f32_vec().unwrap(), &[4, 3], dev).unwrap();
let g = Tensor::randn(&[4, 3], opts).unwrap();
let m1 = Tensor::zeros(&[4, 3], opts).unwrap();
let m2 = Tensor::zeros(&[4, 3], opts).unwrap();
let v1 = Tensor::zeros(&[4, 3], opts).unwrap();
let v2 = Tensor::zeros(&[4, 3], opts).unwrap();
let lr = 0.001;
let beta1 = 0.9;
let beta2 = 0.999;
let eps = 1e-8;
let wd = 0.01;
p1.adam_step(&g, &m1, &v1, lr, beta1, beta2, eps, wd, 1).unwrap();
Tensor::fused_adamw_(
std::slice::from_ref(&p2), std::slice::from_ref(&g),
std::slice::from_ref(&m2), std::slice::from_ref(&v2),
lr, beta1, beta2, eps, wd, 1, None, None,
).unwrap();
let p1_data = p1.to_f32_vec().unwrap();
let p2_data = p2.to_f32_vec().unwrap();
for (i, (a, b)) in p1_data.iter().zip(&p2_data).enumerate() {
            assert!((a - b).abs() < 1e-5,
                "param mismatch at {}: adam_step={}, fused={}", i, a, b);
}
let m1_data = m1.to_f32_vec().unwrap();
let m2_data = m2.to_f32_vec().unwrap();
for (i, (a, b)) in m1_data.iter().zip(&m2_data).enumerate() {
            assert!((a - b).abs() < 1e-6,
                "m mismatch at {}: adam_step={}, fused={}", i, a, b);
}
}
#[test]
fn test_fused_adam_no_weight_decay() {
let opts = test_opts();
let p = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4], test_device()).unwrap();
let g = Tensor::from_f32(&[0.1, 0.2, 0.3, 0.4], &[4], test_device()).unwrap();
let m = Tensor::zeros(&[4], opts).unwrap();
let v = Tensor::zeros(&[4], opts).unwrap();
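        // With weight_decay = 0.0, AdamW coincides with Adam, so the AdamW
        // entry point is a valid stand-in for plain fused Adam here.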
Tensor::fused_adamw_(
std::slice::from_ref(&p), std::slice::from_ref(&g),
std::slice::from_ref(&m), std::slice::from_ref(&v),
0.001, 0.9, 0.999, 1e-8, 0.0, 1, None, None,
).unwrap();
let p_data = p.to_f32_vec().unwrap();
let orig = [1.0f32, 2.0, 3.0, 4.0];
for (i, &o) in orig.iter().enumerate() {
assert!((p_data[i] - (o - 0.001)).abs() < 1e-4,
"p[{}]: got {}, expected ~{}", i, p_data[i], o - 0.001);
}
}
#[test]
fn test_fused_adam_multi_step() {
let opts = test_opts();
let p = Tensor::from_f32(&[5.0], &[1], test_device()).unwrap();
let g = Tensor::from_f32(&[1.0], &[1], test_device()).unwrap();
let m = Tensor::zeros(&[1], opts).unwrap();
let v = Tensor::zeros(&[1], opts).unwrap();
for step in 1..=10 {
Tensor::fused_adamw_(
std::slice::from_ref(&p), std::slice::from_ref(&g),
std::slice::from_ref(&m), std::slice::from_ref(&v),
0.01, 0.9, 0.999, 1e-8, 0.0, step, None, None,
).unwrap();
}
let p_data = p.to_f32_vec().unwrap();
assert!(p_data[0] < 5.0, "param should decrease: got {}", p_data[0]);
let m_data = m.to_f32_vec().unwrap();
assert!((m_data[0] - 0.6513).abs() < 0.01,
"m after 10 steps: got {}", m_data[0]);
}
#[test]
fn test_fused_adam_empty_is_noop() {
Tensor::fused_adamw_(&[], &[], &[], &[], 0.001, 0.9, 0.999, 1e-8, 0.0, 1, None, None).unwrap();
Tensor::fused_adam_(&[], &[], &[], &[], 0.001, 0.9, 0.999, 1e-8, 0.0, 1, None, None).unwrap();
}
#[test]
fn test_foreach_add_scalar() {
let dev = test_device();
let a = Tensor::from_f32(&[1.0, 2.0], &[2], dev).unwrap();
let b = Tensor::from_f32(&[3.0, 4.0, 5.0], &[3], dev).unwrap();
Tensor::foreach_add_scalar_(&[a.clone(), b.clone()], 10.0).unwrap();
assert_eq!(a.to_f32_vec().unwrap(), vec![11.0, 12.0]);
assert_eq!(b.to_f32_vec().unwrap(), vec![13.0, 14.0, 15.0]);
}
#[test]
fn test_foreach_mul_scalar() {
let dev = test_device();
let a = Tensor::from_f32(&[2.0, 3.0], &[2], dev).unwrap();
let b = Tensor::from_f32(&[4.0, 5.0], &[2], dev).unwrap();
Tensor::foreach_mul_scalar_(&[a.clone(), b.clone()], 0.5).unwrap();
assert_eq!(a.to_f32_vec().unwrap(), vec![1.0, 1.5]);
assert_eq!(b.to_f32_vec().unwrap(), vec![2.0, 2.5]);
}
#[test]
fn test_foreach_zero() {
let dev = test_device();
let a = Tensor::from_f32(&[1.0, 2.0], &[2], dev).unwrap();
let b = Tensor::from_f32(&[3.0, 4.0], &[2], dev).unwrap();
Tensor::foreach_zero_(&[a.clone(), b.clone()]).unwrap();
assert_eq!(a.to_f32_vec().unwrap(), vec![0.0, 0.0]);
assert_eq!(b.to_f32_vec().unwrap(), vec![0.0, 0.0]);
}
#[test]
fn test_foreach_add_list() {
let dev = test_device();
let a = Tensor::from_f32(&[1.0, 2.0], &[2], dev).unwrap();
let b = Tensor::from_f32(&[10.0, 20.0], &[2], dev).unwrap();
let x = Tensor::from_f32(&[0.5, 0.5], &[2], dev).unwrap();
let y = Tensor::from_f32(&[1.0, 1.0], &[2], dev).unwrap();
Tensor::foreach_add_list_(
&[a.clone(), b.clone()],
&[x, y],
2.0,
).unwrap();
assert_eq!(a.to_f32_vec().unwrap(), vec![2.0, 3.0]);
assert_eq!(b.to_f32_vec().unwrap(), vec![12.0, 22.0]);
}
#[test]
fn test_foreach_norm() {
let dev = test_device();
let a = Tensor::from_f32(&[3.0, 4.0], &[2], dev).unwrap();
let b = Tensor::from_f32(&[1.0, 0.0], &[1, 2], dev).unwrap();
let norms = Tensor::foreach_norm(&[a, b], 2.0).unwrap();
assert_eq!(norms.len(), 2);
let n0: f64 = norms[0].item().unwrap();
let n1: f64 = norms[1].item().unwrap();
assert!((n0 - 5.0).abs() < 1e-5, "norm of [3,4] should be 5, got {}", n0);
assert!((n1 - 1.0).abs() < 1e-5, "norm of [1,0] should be 1, got {}", n1);
}
#[test]
fn test_foreach_lerp_scalar() {
let dev = test_device();
let a = Tensor::from_f32(&[0.0, 10.0], &[2], dev).unwrap();
let b = Tensor::from_f32(&[10.0, 0.0], &[2], dev).unwrap();
let a_target = Tensor::from_f32(&[10.0, 10.0], &[2], dev).unwrap();
let b_target = Tensor::from_f32(&[10.0, 10.0], &[2], dev).unwrap();
Tensor::foreach_lerp_scalar_(
&[a.clone(), b.clone()],
&[a_target, b_target],
0.5,
).unwrap();
assert_eq!(a.to_f32_vec().unwrap(), vec![5.0, 10.0]);
assert_eq!(b.to_f32_vec().unwrap(), vec![10.0, 5.0]);
}
#[test]
fn test_foreach_sqrt() {
let dev = test_device();
let a = Tensor::from_f32(&[4.0, 9.0], &[2], dev).unwrap();
let b = Tensor::from_f32(&[16.0, 25.0], &[2], dev).unwrap();
Tensor::foreach_sqrt_(&[a.clone(), b.clone()]).unwrap();
assert_eq!(a.to_f32_vec().unwrap(), vec![2.0, 3.0]);
assert_eq!(b.to_f32_vec().unwrap(), vec![4.0, 5.0]);
}
#[test]
fn test_foreach_empty_list_is_noop() {
Tensor::foreach_add_scalar_(&[], 1.0).unwrap();
Tensor::foreach_mul_scalar_(&[], 1.0).unwrap();
Tensor::foreach_zero_(&[]).unwrap();
Tensor::foreach_add_list_(&[], &[], 1.0).unwrap();
assert!(Tensor::foreach_norm(&[], 2.0).unwrap().is_empty());
Tensor::foreach_lerp_scalar_(&[], &[], 0.5).unwrap();
Tensor::foreach_sqrt_(&[]).unwrap();
}
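    // The empty-slice early return in clip_grad_norm_fused is pure Rust, so
    // it holds regardless of backend; is_cuda_oom is a plain substring check.
    #[test]
    fn test_clip_grad_norm_empty_and_error_helpers() {
        assert_eq!(Tensor::clip_grad_norm_fused(&[], 1.0).unwrap(), 0.0);
        let err = TensorError::new("CUDA error: out of memory");
        assert!(err.is_cuda_oom());
        assert_eq!(format!("{}", err), "CUDA error: out of memory");
    }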
#[test]
fn test_full_like() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
let fl = Tensor::full_like(&t, 7.0).unwrap();
assert_eq!(fl.to_f32_vec().unwrap(), vec![7.0, 7.0, 7.0]);
assert_eq!(fl.dtype(), DType::Float32);
}
#[test]
fn test_rand_like_randn_like() {
let t = Tensor::ones(&[3, 4], test_opts()).unwrap();
let rl = Tensor::rand_like(&t).unwrap();
assert_eq!(rl.shape(), vec![3, 4]);
let data = rl.to_f32_vec().unwrap();
assert!(data.iter().all(|&v| (0.0..1.0).contains(&v)));
let nl = Tensor::randn_like(&t).unwrap();
assert_eq!(nl.shape(), vec![3, 4]);
}
#[test]
fn test_randint() {
let mut opts = test_opts();
opts.dtype = DType::Int64;
let t = Tensor::randint(0, 10, &[100], opts).unwrap();
assert_eq!(t.shape(), vec![100]);
let data = t.to_i64_vec().unwrap();
assert!(data.iter().all(|&v| (0..10).contains(&v)));
}
#[test]
fn test_empty() {
let t = Tensor::empty(&[2, 3], test_opts()).unwrap();
assert_eq!(t.shape(), vec![2, 3]);
assert_eq!(t.dtype(), DType::Float32);
}
#[test]
fn test_one_hot() {
let t = Tensor::from_i64(&[0, 1, 2], &[3], test_device()).unwrap();
let oh = t.one_hot(4).unwrap();
assert_eq!(oh.shape(), vec![3, 4]);
let data = oh.to_f32_vec().unwrap();
assert_eq!(&data[0..4], &[1.0, 0.0, 0.0, 0.0]);
assert_eq!(&data[4..8], &[0.0, 1.0, 0.0, 0.0]);
assert_eq!(&data[8..12], &[0.0, 0.0, 1.0, 0.0]);
}
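    // item() is defined only for single-element tensors; the shape check is
    // on the Rust side, so the error path is backend-independent.
    #[test]
    fn test_item() {
        let scalar = Tensor::full(&[1], 3.5, test_opts()).unwrap();
        assert!((scalar.item().unwrap() - 3.5).abs() < 1e-6);
        let vec2 = Tensor::zeros(&[2], test_opts()).unwrap();
        assert!(vec2.item().is_err());
    }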
#[test]
fn test_bernoulli() {
let probs = Tensor::from_f32(&[0.0, 1.0, 0.0, 1.0], &[4], test_device()).unwrap();
let samples = probs.bernoulli().unwrap();
assert_eq!(samples.shape(), vec![4]);
let data = samples.to_f32_vec().unwrap();
assert!((data[0] - 0.0).abs() < 1e-5);
assert!((data[1] - 1.0).abs() < 1e-5);
assert!((data[2] - 0.0).abs() < 1e-5);
assert!((data[3] - 1.0).abs() < 1e-5);
}
#[test]
fn test_is_contiguous() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
assert!(t.is_contiguous());
}
#[test]
fn test_mul_inplace() {
let a = Tensor::from_f32(&[2.0, 3.0], &[2], test_device()).unwrap();
let b = Tensor::from_f32(&[4.0, 5.0], &[2], test_device()).unwrap();
a.mul_(&b).unwrap();
assert_eq!(a.to_f32_vec().unwrap(), vec![8.0, 15.0]);
}
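    // In-place add/sub and the scalar variants, assuming the usual
    // element-wise semantics behind the flodl_* entry points.
    #[test]
    fn test_add_sub_scalar_inplace() {
        let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
        let b = Tensor::from_f32(&[10.0, 20.0], &[2], test_device()).unwrap();
        a.add_(&b).unwrap();
        assert_eq!(a.to_f32_vec().unwrap(), vec![11.0, 22.0]);
        a.sub_(&b).unwrap();
        assert_eq!(a.to_f32_vec().unwrap(), vec![1.0, 2.0]);
        a.add_scalar_(1.0).unwrap();
        a.mul_scalar_(2.0).unwrap();
        assert_eq!(a.to_f32_vec().unwrap(), vec![4.0, 6.0]);
        a.zero_().unwrap();
        assert_eq!(a.to_f32_vec().unwrap(), vec![0.0, 0.0]);
    }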
#[test]
fn test_div_scalar_inplace() {
let t = Tensor::from_f32(&[6.0, 9.0], &[2], test_device()).unwrap();
t.div_scalar_(3.0).unwrap();
let data = t.to_f32_vec().unwrap();
assert!((data[0] - 2.0).abs() < 1e-5);
assert!((data[1] - 3.0).abs() < 1e-5);
}
#[test]
fn test_div_inplace() {
let a = Tensor::from_f32(&[8.0, 15.0], &[2], test_device()).unwrap();
let b = Tensor::from_f32(&[4.0, 5.0], &[2], test_device()).unwrap();
a.div_(&b).unwrap();
let data = a.to_f32_vec().unwrap();
assert!((data[0] - 2.0).abs() < 1e-5);
assert!((data[1] - 3.0).abs() < 1e-5);
}
#[test]
fn test_fill_inplace() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
t.fill_(42.0).unwrap();
assert_eq!(t.to_f32_vec().unwrap(), vec![42.0, 42.0, 42.0]);
}
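    // copy_ overwrites self with src's contents; non_blocking only matters
    // for cross-device copies, so false is the safe choice here.
    #[test]
    fn test_copy_inplace() {
        let dst = Tensor::zeros(&[3], test_opts()).unwrap();
        let src = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
        dst.copy_(&src, false).unwrap();
        assert_eq!(dst.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0]);
    }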
#[test]
fn test_probe_device_cpu() {
assert!(probe_device(Device::CPU).is_ok());
}
#[test]
#[ignore = "GPU probe needs CUDA; run with: make cuda-test-all"]
fn test_probe_device_cuda() {
if !test_device().is_cuda() { return; }
assert!(probe_device(Device::CUDA(0)).is_ok());
}
#[test]
#[ignore = "GPU diagnostics need CUDA; run with: make cuda-test-all"]
fn test_cuda_devices_has_compute_capability() {
if !test_device().is_cuda() { return; }
let devices = cuda_devices();
assert!(!devices.is_empty());
for info in &devices {
assert!(info.sm_major > 0, "compute capability should be detected");
eprintln!(" CUDA({}) {} {} {:.1}GB",
info.index, info.name, info.sm_version(),
info.total_memory as f64 / (1024.0 * 1024.0 * 1024.0));
}
}
#[test]
#[ignore = "GPU diagnostics need CUDA; run with: make cuda-test-all"]
fn test_usable_cuda_devices() {
if !test_device().is_cuda() { return; }
let usable = usable_cuda_devices();
assert!(!usable.is_empty(), "at least one device should be usable");
assert!(usable.contains(&Device::CUDA(0)));
}
}