use crate::device::Device;
use crate::dtype::Element;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::gpu_dispatch::GpuBufferHandle;
use crate::storage::TensorStorage;
/// Integer element type that can be stored in an [`IntTensor`].
///
/// Conversions go through `i64` so generic code can do arithmetic without
/// knowing the concrete element type.
pub trait IntElement:
Element + Copy + Send + Sync + 'static + std::fmt::Debug + std::fmt::Display
{
/// Width of the element type in bits.
const BITS: u32;
/// Short name of the element type, used in error and display messages.
fn dtype_name() -> &'static str;
/// Converts `v` to `Self`, or returns `None` when the implementation rejects it.
fn try_from_i64(v: i64) -> Option<Self>;
/// Converts `self` to an `i64`.
fn to_i64(self) -> i64;
}
/// `i32` as an [`IntTensor`] element: 32 bits wide, narrowed from `i64` with a range check.
impl IntElement for i32 {
    const BITS: u32 = 32;

    fn dtype_name() -> &'static str {
        "i32"
    }

    /// Returns `Some` only when `v` survives a round trip through `i32` unchanged,
    /// i.e. when it lies in `i32::MIN..=i32::MAX`.
    fn try_from_i64(v: i64) -> Option<Self> {
        let narrowed = v as i32;
        if i64::from(narrowed) == v {
            Some(narrowed)
        } else {
            None
        }
    }

    /// Sign-extends the value to 64 bits.
    fn to_i64(self) -> i64 {
        i64::from(self)
    }
}
/// `i64` as an [`IntTensor`] element: both conversions are the identity.
impl IntElement for i64 {
const BITS: u32 = 64;
fn dtype_name() -> &'static str {
"i64"
}
// Every `i64` is representable, so this never returns `None`.
fn try_from_i64(v: i64) -> Option<Self> {
Some(v)
}
fn to_i64(self) -> i64 {
self
}
}
/// A tensor of integer elements: a storage buffer plus a shape.
///
/// An empty `shape` denotes a scalar holding one element (see `scalar`).
#[derive(Debug)]
pub struct IntTensor<I: IntElement> {
// Backing buffer; its `device()` decides whether the tensor is on CPU or CUDA.
storage: TensorStorage<I>,
// Extent of each dimension.
shape: Vec<usize>,
}
/// Hand-written `Clone`: clones the storage and the shape field by field.
impl<I: IntElement> Clone for IntTensor<I> {
    fn clone(&self) -> Self {
        let Self { storage, shape } = self;
        Self {
            storage: storage.clone(),
            shape: shape.clone(),
        }
    }
}
impl<I: IntElement> IntTensor<I> {
/// Builds a CPU tensor that takes ownership of `data`, interpreted with `shape`.
///
/// An empty `shape` denotes a scalar and therefore expects exactly one element.
///
/// # Errors
/// Returns [`FerrotorchError::ShapeMismatch`] when `data.len()` differs from the
/// product of `shape`, or when that product does not fit in `usize`.
pub fn from_vec(data: Vec<I>, shape: Vec<usize>) -> FerrotorchResult<Self> {
    // A zero-sized dimension makes the element count 0 whatever the other
    // dimensions are. Otherwise multiply with overflow checking: a plain
    // `product()` wraps in release builds and could make a bogus shape appear
    // to match `data.len()`. The product of an empty shape is 1 (scalar).
    let checked = if shape.contains(&0) {
        Some(0)
    } else {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    };
    let expected = match checked {
        Some(n) => n,
        None => {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "IntTensor::from_vec: prod(shape) overflows usize for shape {:?}",
                    shape
                ),
            });
        }
    };
    if data.len() != expected {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "IntTensor::from_vec: data.len()={} != prod(shape)={} for shape {:?}",
                data.len(),
                expected,
                shape
            ),
        });
    }
    Ok(Self {
        storage: TensorStorage::cpu(data),
        shape,
    })
}
/// Builds a CPU tensor by copying `data` and `shape`; validation is done by `from_vec`.
pub fn from_slice(data: &[I], shape: &[usize]) -> FerrotorchResult<Self> {
    let owned_data: Vec<I> = data.to_vec();
    let owned_shape: Vec<usize> = shape.to_vec();
    Self::from_vec(owned_data, owned_shape)
}
/// Creates a CPU tensor of the given shape with every element set to zero.
///
/// An empty `shape` yields a scalar holding a single zero.
pub fn zeros(shape: &[usize]) -> Self {
    // The product of an empty iterator is 1, which is the scalar element count.
    let numel = shape.iter().product::<usize>();
    let fill = I::try_from_i64(0).expect("0 fits in any IntElement");
    let buffer = vec![fill; numel];
    Self {
        storage: TensorStorage::cpu(buffer),
        shape: shape.to_vec(),
    }
}
/// Creates the 1-D CPU tensor `[0, 1, ..., n - 1]`.
///
/// # Errors
/// Returns [`FerrotorchError::InvalidArgument`] for the first index that cannot
/// be represented in `I`.
pub fn arange(n: usize) -> FerrotorchResult<Self> {
    let mut data: Vec<I> = Vec::with_capacity(n);
    for i in 0..n {
        // `ok_or_else` builds the error message only on failure; `ok_or` would
        // format and allocate a `String` for every element.
        let value = I::try_from_i64(i as i64).ok_or_else(|| FerrotorchError::InvalidArgument {
            message: format!(
                "IntTensor::arange: {i} out of range for {}",
                I::dtype_name()
            ),
        })?;
        data.push(value);
    }
    Self::from_vec(data, vec![n])
}
/// Wraps a single value as a zero-dimensional CPU tensor (empty shape, one element).
pub fn scalar(v: I) -> Self {
    let storage = TensorStorage::cpu(vec![v]);
    Self {
        storage,
        shape: Vec::new(),
    }
}
/// Returns the extent of each dimension; empty for a scalar.
pub fn shape(&self) -> &[usize] {
&self.shape
}
/// Returns the number of elements, as reported by the storage length.
pub fn numel(&self) -> usize {
self.storage.len()
}
/// Returns the number of dimensions (0 for a scalar).
pub fn ndim(&self) -> usize {
self.shape.len()
}
/// Returns the device reported by the underlying storage.
#[inline]
pub fn device(&self) -> Device {
self.storage.device()
}
/// Returns `true` when the tensor's device reports itself as CUDA.
#[inline]
pub fn is_cuda(&self) -> bool {
self.device().is_cuda()
}
/// Borrows the elements as a slice.
///
/// # Errors
/// Propagates whatever `TensorStorage::try_as_slice` returns
/// (presumably an error for storage that is not host-resident — confirm against `TensorStorage`).
pub fn data(&self) -> FerrotorchResult<&[I]> {
self.storage.try_as_slice()
}
/// Returns the element type name, e.g. `"i32"` or `"i64"`.
pub fn dtype_name(&self) -> &'static str {
I::dtype_name()
}
/// Borrows the GPU buffer handle backing this tensor.
///
/// # Errors
/// Returns [`FerrotorchError::InvalidArgument`] when the storage has no GPU handle.
pub fn gpu_handle(&self) -> FerrotorchResult<&GpuBufferHandle> {
    // Build the error lazily so the success path does not allocate the message.
    self.storage
        .gpu_handle()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "IntTensor is not on a CUDA GPU".into(),
        })
}
/// Returns this tensor placed on `device`.
///
/// * Same device: returns `self.clone()`.
/// * CPU -> CUDA: copies the host data and uploads it via `TensorStorage::on_device`.
/// * CUDA -> CPU: reads the buffer back through the GPU backend and copies the
///   bytes into a freshly allocated `Vec<I>`.
/// * CUDA -> another CUDA device: staged through the CPU.
///
/// # Errors
/// * [`FerrotorchError::DeviceUnavailable`] when no GPU backend is registered.
/// * [`FerrotorchError::InvalidArgument`] when the readback size is not a multiple
///   of `size_of::<I>()`, or when the device pair is not supported.
/// * Any error propagated from the storage or the backend.
pub fn to(&self, device: Device) -> FerrotorchResult<IntTensor<I>> {
    if self.device() == device {
        return Ok(self.clone());
    }
    match (self.device(), device) {
        (Device::Cpu, Device::Cuda(_)) => {
            let data = self.data()?.to_vec();
            let storage = TensorStorage::on_device(data, device)?;
            Ok(Self {
                storage,
                shape: self.shape.clone(),
            })
        }
        (Device::Cuda(_), Device::Cpu) => {
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let handle = self.gpu_handle()?;
            let bytes = backend.gpu_to_cpu(handle)?;
            let elem_size = std::mem::size_of::<I>();
            if bytes.len() % elem_size != 0 {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "IntTensor::to(Cpu): D2H readback of {} bytes is not a multiple \
                         of size_of::<{}>()={elem_size}",
                        bytes.len(),
                        I::dtype_name()
                    ),
                });
            }
            // Copy into a properly allocated `Vec<I>` rather than reinterpreting the
            // byte vector's allocation. Rebuilding a `Vec<I>` from a `Vec<u8>` pointer
            // via `Vec::from_raw_parts` is undefined behaviour: the allocation has
            // alignment 1, and its capacity need not be a multiple of `elem_size`,
            // so the layout used on deallocation would not match.
            let len = bytes.len() / elem_size;
            let mut data: Vec<I> = Vec::with_capacity(len);
            // SAFETY:
            // - `bytes` holds `len * elem_size` initialized bytes (checked above).
            // - `data` has capacity for `len` elements, i.e. `len * elem_size` bytes,
            //   and is a separate allocation, so the ranges do not overlap.
            // - Both pointers are accessed as `u8`, which has no alignment requirement.
            // - Assumes every `IntElement` is a plain integer for which any bit pattern
            //   is a valid value; this holds for the `i32` and `i64` impls in this file.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    bytes.as_ptr().cast::<u8>(),
                    data.as_mut_ptr().cast::<u8>(),
                    len * elem_size,
                );
                data.set_len(len);
            }
            Ok(Self {
                storage: TensorStorage::cpu(data),
                shape: self.shape.clone(),
            })
        }
        (Device::Cuda(_), Device::Cuda(_)) => {
            // Different CUDA devices (equal devices returned early): go via the host.
            let cpu = self.to(Device::Cpu)?;
            cpu.to(device)
        }
        (from, to) => Err(FerrotorchError::InvalidArgument {
            message: format!(
                "IntTensor::to: unsupported device transfer {from:?} -> {to:?} \
                 (Phase 2a supports CPU <-> CUDA only)"
            ),
        }),
    }
}
/// Converts every element to the integer type `J`, keeping the shape.
///
/// `cast_gpu` is tried first; when it returns `Some`, that result is returned
/// as-is. Otherwise the elements are converted on the host one by one.
///
/// # Errors
/// Returns [`FerrotorchError::InvalidArgument`] naming the first element that does
/// not fit in `J`, or propagates the error from `data()`.
pub fn cast<J: IntElement>(&self) -> FerrotorchResult<IntTensor<J>> {
    if let Some(result) = self.cast_gpu::<J>() {
        return result;
    }
    let data = self.data()?;
    let mut out: Vec<J> = Vec::with_capacity(data.len());
    for (i, &v) in data.iter().enumerate() {
        // `ok_or_else` formats the message only for the failing element; `ok_or`
        // would allocate an error `String` for every element converted.
        let converted =
            J::try_from_i64(v.to_i64()).ok_or_else(|| FerrotorchError::InvalidArgument {
                message: format!(
                    "IntTensor::cast: element {i} = {v} out of range for {}",
                    J::dtype_name()
                ),
            })?;
        out.push(converted);
    }
    IntTensor::<J>::from_vec(out, self.shape.clone())
}
/// Returns a tensor with the same storage (cloned) viewed with a new `shape`.
///
/// An empty `shape` denotes a scalar and requires the tensor to hold one element.
///
/// # Errors
/// Returns [`FerrotorchError::ShapeMismatch`] when the element count of `shape`
/// differs from the current element count, or does not fit in `usize`.
pub fn reshape(&self, shape: &[usize]) -> FerrotorchResult<Self> {
    // A zero-sized dimension gives 0 elements whatever the other dimensions are.
    // Otherwise multiply with overflow checking, because a wrapped product could
    // coincide with the current element count in release builds.
    let checked = if shape.contains(&0) {
        Some(0)
    } else {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    };
    let new_total = match checked {
        Some(n) => n,
        None => {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "IntTensor::reshape: numel of new shape {shape:?} overflows usize"
                ),
            });
        }
    };
    let cur = self.storage.len();
    if new_total != cur {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "IntTensor::reshape: new shape {shape:?} (numel {new_total}) != current numel {cur}"
            ),
        });
    }
    Ok(Self {
        storage: self.storage.clone(),
        shape: shape.to_vec(),
    })
}
/// Wraps an existing GPU buffer as a tensor with the given `shape`.
///
/// In debug builds, asserts that the handle's dtype tag equals `I::dtype()`.
/// The shape is not checked against the buffer here.
pub(crate) fn from_gpu_handle(handle: GpuBufferHandle, shape: Vec<usize>) -> Self {
    debug_assert_eq!(
        handle.dtype(),
        I::dtype(),
        "from_gpu_handle: handle dtype tag must match IntElement"
    );
    let storage = TensorStorage::gpu(handle);
    Self { storage, shape }
}
/// Validates the operands of a binary op: same device first, then same shape.
///
/// `op` is only used to label the shape-mismatch message.
fn check_binary(&self, other: &IntTensor<I>, op: &'static str) -> FerrotorchResult<()> {
    let (lhs_device, rhs_device) = (self.device(), other.device());
    if lhs_device != rhs_device {
        return Err(FerrotorchError::DeviceMismatch {
            expected: lhs_device,
            got: rhs_device,
        });
    }
    if self.shape == other.shape {
        return Ok(());
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "IntTensor::{op}: operand shapes differ {:?} vs {:?} \
             (broadcasting is out of scope for Phase 2b — same shape only)",
            self.shape, other.shape
        ),
    })
}
/// Shared driver for element-wise binary operations.
///
/// After `check_binary` passes, CUDA tensors are handed to the `gpu` callback
/// together with the backend and both buffer handles; CPU tensors are combined
/// pairwise with `f`. The result has the same shape as `self`.
fn binary_op(
    &self,
    other: &IntTensor<I>,
    op: &'static str,
    gpu: impl FnOnce(
        &dyn crate::gpu_dispatch::GpuBackend,
        &GpuBufferHandle,
        &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle>,
    f: impl Fn(I, I) -> I,
) -> FerrotorchResult<IntTensor<I>> {
    self.check_binary(other, op)?;

    if !self.is_cuda() {
        // Host path.
        let lhs = self.data()?;
        let rhs = other.data()?;
        let out: Vec<I> = lhs
            .iter()
            .copied()
            .zip(rhs.iter().copied())
            .map(|(x, y)| f(x, y))
            .collect();
        return IntTensor::from_vec(out, self.shape.clone());
    }

    // Device path: the backend must exist before the handles are looked up.
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let lhs_handle = self.gpu_handle()?;
    let rhs_handle = other.gpu_handle()?;
    let result = gpu(backend, lhs_handle, rhs_handle)?;
    Ok(IntTensor::from_gpu_handle(result, self.shape.clone()))
}
/// Shared driver for element-wise unary operations.
///
/// CUDA tensors are handed to the `gpu` callback with the backend and buffer
/// handle; CPU tensors are mapped element by element with `f`. The shape is kept.
fn unary_op(
    &self,
    gpu: impl FnOnce(
        &dyn crate::gpu_dispatch::GpuBackend,
        &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle>,
    f: impl Fn(I) -> I,
) -> FerrotorchResult<IntTensor<I>> {
    if !self.is_cuda() {
        // Host path.
        let input = self.data()?;
        let out: Vec<I> = input.iter().copied().map(|x| f(x)).collect();
        return IntTensor::from_vec(out, self.shape.clone());
    }

    // Device path.
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let input_handle = self.gpu_handle()?;
    let result = gpu(backend, input_handle)?;
    Ok(IntTensor::from_gpu_handle(result, self.shape.clone()))
}
/// Shared driver for full reductions to a scalar tensor.
///
/// * `empty` is the result for a tensor with no elements; `None` makes an
///   empty reduction an `InvalidArgument` error labelled with `op`.
/// * CUDA tensors with elements are reduced by the `gpu` callback.
/// * CPU tensors are folded left-to-right with `f`, seeded with the first element.
fn reduce_op(
    &self,
    op: &'static str,
    gpu: impl FnOnce(
        &dyn crate::gpu_dispatch::GpuBackend,
        &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle>,
    empty: Option<I>,
    f: impl Fn(I, I) -> I,
) -> FerrotorchResult<IntTensor<I>> {
    // Result for an empty tensor, shared by the device and host paths.
    let on_empty = || -> FerrotorchResult<IntTensor<I>> {
        match empty {
            Some(id) => Ok(IntTensor::scalar(id)),
            None => Err(FerrotorchError::InvalidArgument {
                message: format!("IntTensor::{op}: reduction of an empty tensor is undefined"),
            }),
        }
    };

    if self.is_cuda() {
        // Empty tensors are resolved here and never reach the `gpu` callback.
        if self.numel() == 0 {
            return on_empty();
        }
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let input_handle = self.gpu_handle()?;
        let reduced = gpu(backend, input_handle)?;
        return Ok(IntTensor::from_gpu_handle(reduced, Vec::new()));
    }

    let values = self.data()?;
    match values.split_first() {
        Some((&first, rest)) => {
            let total = rest.iter().fold(first, |acc, &x| f(acc, x));
            Ok(IntTensor::scalar(total))
        }
        None => on_empty(),
    }
}
/// Element-wise wrapping addition of two tensors with the same shape and device.
///
/// On the CPU the sum wraps at the element width (two's complement), e.g.
/// `i32::MAX + 1 == i32::MIN`. CUDA tensors are dispatched to the backend's `int_add`.
///
/// # Errors
/// Propagates device/shape mismatch errors and backend errors from `binary_op`.
pub fn add(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
    self.binary_op(
        other,
        "add",
        |b, x, y| b.int_add(x, y),
        |x, y| {
            let wide = x.to_i64().wrapping_add(y.to_i64());
            // The sum of two 32-bit values never wraps in `i64`, so it must be
            // truncated to 32 bits explicitly. Without this an overflowing sum
            // fails `try_from_i64` and the fallback silently returns `x`.
            let wrapped = match I::BITS {
                32 => i64::from(wide as i32),
                _ => wide,
            };
            I::try_from_i64(wrapped).unwrap_or(x)
        },
    )
}
/// Element-wise wrapping subtraction of two tensors with the same shape and device.
///
/// On the CPU the difference wraps at the element width (two's complement), e.g.
/// `i32::MIN - 1 == i32::MAX`. CUDA tensors are dispatched to the backend's `int_sub`.
///
/// # Errors
/// Propagates device/shape mismatch errors and backend errors from `binary_op`.
pub fn sub(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
    self.binary_op(
        other,
        "sub",
        |b, x, y| b.int_sub(x, y),
        |x, y| {
            let wide = x.to_i64().wrapping_sub(y.to_i64());
            // The difference of two 32-bit values never wraps in `i64`, so it must
            // be truncated to 32 bits explicitly. Without this an overflowing
            // result fails `try_from_i64` and the fallback silently returns `x`.
            let wrapped = match I::BITS {
                32 => i64::from(wide as i32),
                _ => wide,
            };
            I::try_from_i64(wrapped).unwrap_or(x)
        },
    )
}
pub fn mul(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
self.binary_op(
other,
"mul",
|b, x, y| b.int_mul(x, y),
|x, y| int_wrapping_mul(x, y),
)
}
/// Element-wise negation.
///
/// CUDA tensors go to the backend's `int_neg`; on the CPU each value is negated
/// in `i64` and converted back.
pub fn neg(&self) -> FerrotorchResult<IntTensor<I>> {
    self.unary_op(
        |b, x| b.int_neg(x),
        |x| {
            let negated = x.to_i64().wrapping_neg();
            // For `i32::MIN` the negation (2^31) does not fit, so the fallback
            // returns `x` itself, which equals the two's-complement wrapped result.
            I::try_from_i64(negated).unwrap_or(x)
        },
    )
}
/// Element-wise floor division of two tensors with the same shape and device.
///
/// The CPU path uses `int_floor_div_ref`; CUDA tensors go to the backend's `int_floor_div`.
pub fn floor_div(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
self.binary_op(
other,
"floor_div",
|b, x, y| b.int_floor_div(x, y),
int_floor_div_ref,
)
}
/// Element-wise remainder of two tensors with the same shape and device.
///
/// The CPU path uses `int_remainder_ref`; CUDA tensors go to the backend's `int_remainder`.
pub fn remainder(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
self.binary_op(
other,
"remainder",
|b, x, y| b.int_remainder(x, y),
int_remainder_ref,
)
}
/// Element-wise bitwise AND of two tensors with the same shape and device.
///
/// CUDA tensors go to the backend's `int_bitand`; on the CPU the operands are
/// combined in `i64` and converted back.
pub fn bitand(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
    self.binary_op(
        other,
        "bitand",
        |b, x, y| b.int_bitand(x, y),
        |x, y| {
            let bits = x.to_i64() & y.to_i64();
            I::try_from_i64(bits).unwrap_or(x)
        },
    )
}
/// Element-wise bitwise OR of two tensors with the same shape and device.
///
/// CUDA tensors go to the backend's `int_bitor`; on the CPU the operands are
/// combined in `i64` and converted back.
pub fn bitor(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
    self.binary_op(
        other,
        "bitor",
        |b, x, y| b.int_bitor(x, y),
        |x, y| {
            let bits = x.to_i64() | y.to_i64();
            I::try_from_i64(bits).unwrap_or(x)
        },
    )
}
/// Element-wise bitwise XOR of two tensors with the same shape and device.
///
/// CUDA tensors go to the backend's `int_bitxor`; on the CPU the operands are
/// combined in `i64` and converted back.
pub fn bitxor(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
    self.binary_op(
        other,
        "bitxor",
        |b, x, y| b.int_bitxor(x, y),
        |x, y| {
            let bits = x.to_i64() ^ y.to_i64();
            I::try_from_i64(bits).unwrap_or(x)
        },
    )
}
/// Element-wise bitwise NOT.
///
/// The CPU path uses `int_bitnot_ref`; CUDA tensors go to the backend's `int_bitnot`.
pub fn bitnot(&self) -> FerrotorchResult<IntTensor<I>> {
self.unary_op(|b, x| b.int_bitnot(x), int_bitnot_ref)
}
/// Element-wise left shift of `self` by the amounts in `other`.
///
/// The CPU path uses `int_shl_ref`; CUDA tensors go to the backend's `int_shl`.
pub fn shl(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
self.binary_op(other, "shl", |b, x, y| b.int_shl(x, y), int_shl_ref)
}
/// Element-wise right shift of `self` by the amounts in `other`.
///
/// The CPU path uses `int_shr_ref`; CUDA tensors go to the backend's `int_shr`.
pub fn shr(&self, other: &IntTensor<I>) -> FerrotorchResult<IntTensor<I>> {
self.binary_op(other, "shr", |b, x, y| b.int_shr(x, y), int_shr_ref)
}
/// Sums all elements into a scalar tensor; an empty tensor sums to 0.
///
/// On the CPU the running total wraps at the element width (two's complement).
/// CUDA tensors are dispatched to the backend's `int_sum`.
///
/// # Errors
/// Propagates storage and backend errors from `reduce_op`.
pub fn sum(&self) -> FerrotorchResult<IntTensor<I>> {
    self.reduce_op(
        "sum",
        |b, x| b.int_sum(x),
        Some(I::try_from_i64(0).expect("0 fits any IntElement")),
        |acc, x| {
            let wide = acc.to_i64().wrapping_add(x.to_i64());
            // A 32-bit sum never wraps in `i64`, so truncate it explicitly.
            // Without this an overflowing step fails `try_from_i64` and the
            // fallback keeps `acc`, silently dropping the addend.
            let wrapped = match I::BITS {
                32 => i64::from(wide as i32),
                _ => wide,
            };
            I::try_from_i64(wrapped).unwrap_or(acc)
        },
    )
}
/// Multiplies all elements into a scalar tensor; an empty tensor yields 1.
///
/// The CPU path folds with `int_wrapping_mul`; CUDA tensors go to the backend's `int_prod`.
pub fn prod(&self) -> FerrotorchResult<IntTensor<I>> {
self.reduce_op(
"prod",
|b, x| b.int_prod(x),
Some(I::try_from_i64(1).expect("1 fits any IntElement")),
int_wrapping_mul,
)
}
/// Returns the smallest element as a scalar tensor.
///
/// Reducing an empty tensor is an error (no identity is supplied). CUDA tensors
/// go to the backend's `int_min`.
pub fn min(&self) -> FerrotorchResult<IntTensor<I>> {
    self.reduce_op(
        "min",
        |b, x| b.int_min(x),
        None,
        // Keep the running value unless the candidate is strictly smaller.
        |acc, x| {
            if acc.to_i64() <= x.to_i64() {
                acc
            } else {
                x
            }
        },
    )
}
/// Returns the largest element as a scalar tensor.
///
/// Reducing an empty tensor is an error (no identity is supplied). CUDA tensors
/// go to the backend's `int_max`.
pub fn max(&self) -> FerrotorchResult<IntTensor<I>> {
    self.reduce_op(
        "max",
        |b, x| b.int_max(x),
        None,
        // Keep the running value unless the candidate is strictly larger.
        |acc, x| {
            if acc.to_i64() >= x.to_i64() {
                acc
            } else {
                x
            }
        },
    )
}
}
/// Multiplies `x` and `y`, wrapping at 32 bits when `I::BITS == 32` and at
/// 64 bits otherwise. Falls back to `x` if the result cannot be converted to `I`.
fn int_wrapping_mul<I: IntElement>(x: I, y: I) -> I {
    let (lhs, rhs) = (x.to_i64(), y.to_i64());
    let product = if I::BITS == 32 {
        i64::from((lhs as i32).wrapping_mul(rhs as i32))
    } else {
        lhs.wrapping_mul(rhs)
    };
    I::try_from_i64(product).unwrap_or(x)
}
/// CPU reference for floor division (quotient rounded toward negative infinity).
///
/// Division by zero yields 0 instead of panicking. If the quotient cannot be
/// converted to `I`, `x` is returned.
fn int_floor_div_ref<I: IntElement>(x: I, y: I) -> I {
    let (dividend, divisor) = (x.to_i64(), y.to_i64());
    let floored = if divisor == 0 {
        0
    } else {
        let truncated = dividend.wrapping_div(divisor);
        let rem = dividend.wrapping_rem(divisor);
        // Truncating division rounds toward zero; step down by one when the
        // division is inexact and the remainder's sign differs from the divisor's.
        let needs_step_down = rem != 0 && (rem < 0) != (divisor < 0);
        if needs_step_down {
            truncated.wrapping_sub(1)
        } else {
            truncated
        }
    };
    I::try_from_i64(floored).unwrap_or(x)
}
/// CPU reference for the remainder whose sign follows the divisor.
///
/// A zero divisor yields 0 instead of panicking. If the result cannot be
/// converted to `I`, `x` is returned.
fn int_remainder_ref<I: IntElement>(x: I, y: I) -> I {
    let (dividend, divisor) = (x.to_i64(), y.to_i64());
    let adjusted = if divisor == 0 {
        0
    } else {
        let rem = dividend.wrapping_rem(divisor);
        // The truncated remainder takes the dividend's sign; shift it by the
        // divisor when it is non-zero and its sign differs from the divisor's.
        let signs_differ = rem != 0 && (rem < 0) != (divisor < 0);
        if signs_differ {
            rem.wrapping_add(divisor)
        } else {
            rem
        }
    };
    I::try_from_i64(adjusted).unwrap_or(x)
}
/// CPU reference for bitwise NOT, computed at 32 bits when `I::BITS == 32`
/// and at 64 bits otherwise. Falls back to `x` if conversion to `I` fails.
fn int_bitnot_ref<I: IntElement>(x: I) -> I {
    let wide = x.to_i64();
    let flipped = if I::BITS == 32 {
        i64::from(!(wide as i32))
    } else {
        !wide
    };
    I::try_from_i64(flipped).unwrap_or(x)
}
/// CPU reference for left shift.
///
/// The shift amount is `y` reduced to its low bits with the mask `I::BITS - 1`,
/// so it is always in `0..I::BITS`. Falls back to `x` if conversion to `I` fails.
fn int_shl_ref<I: IntElement>(x: I, y: I) -> I {
    let amount = (y.to_i64() as u32) & (I::BITS - 1);
    let value = x.to_i64();
    let shifted = if I::BITS == 32 {
        i64::from((value as i32).wrapping_shl(amount))
    } else {
        value.wrapping_shl(amount)
    };
    I::try_from_i64(shifted).unwrap_or(x)
}
/// CPU reference for right shift on the signed value (arithmetic shift).
///
/// The shift amount is `y` reduced to its low bits with the mask `I::BITS - 1`,
/// so it is always in `0..I::BITS`. Falls back to `x` if conversion to `I` fails.
fn int_shr_ref<I: IntElement>(x: I, y: I) -> I {
    let amount = (y.to_i64() as u32) & (I::BITS - 1);
    let value = x.to_i64();
    let shifted = if I::BITS == 32 {
        i64::from((value as i32).wrapping_shr(amount))
    } else {
        value.wrapping_shr(amount)
    };
    I::try_from_i64(shifted).unwrap_or(x)
}
/// One-line summary: element type, shape, element count and device (no element values).
impl<I: IntElement> std::fmt::Display for IntTensor<I> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dtype = I::dtype_name();
        let len = self.storage.len();
        let device = self.device();
        write!(
            f,
            "IntTensor<{}>(shape={:?}, len={}, device={:?})",
            dtype, self.shape, len, device,
        )
    }
}
// Unit tests for the CPU-side constructors, accessors, `cast` and `reshape`.
// No test here exercises a CUDA device.
#[cfg(test)]
mod tests {
use super::*;
// `from_vec` keeps shape, element count and element order.
#[test]
fn from_vec_basic() {
let t = IntTensor::<i32>::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
assert_eq!(t.shape(), &[2, 2]);
assert_eq!(t.numel(), 4);
assert_eq!(t.data().unwrap(), &[1, 2, 3, 4]);
}
// Three elements cannot fill a 2x2 shape.
#[test]
fn from_vec_shape_mismatch_errors() {
let err = IntTensor::<i32>::from_vec(vec![1, 2, 3], vec![2, 2]).unwrap_err();
assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
}
// `zeros` allocates prod(shape) elements, all zero.
#[test]
fn zeros_correct_size() {
let t = IntTensor::<i64>::zeros(&[3, 4]);
assert_eq!(t.numel(), 12);
assert!(t.data().unwrap().iter().all(|&x| x == 0));
}
// `arange(n)` yields 0..n in order.
#[test]
fn arange_sequence() {
let t = IntTensor::<i32>::arange(5).unwrap();
assert_eq!(t.data().unwrap(), &[0, 1, 2, 3, 4]);
}
// NOTE(review): despite its name this test does not call `arange`; it only
// checks that `i32::try_from_i64` rejects an out-of-range value.
#[test]
fn arange_oob_for_i32() {
assert!(i32::try_from_i64(i64::MAX).is_none());
}
// Narrowing cast succeeds when every value fits in the target type.
#[test]
fn cast_i64_to_i32_in_range() {
let t = IntTensor::<i64>::from_vec(vec![1, -1, 100], vec![3]).unwrap();
let c = t.cast::<i32>().unwrap();
assert_eq!(c.data().unwrap(), &[1, -1, 100]);
assert_eq!(c.dtype_name(), "i32");
}
// Narrowing cast reports an error instead of truncating.
#[test]
fn cast_i64_to_i32_out_of_range_errors() {
let t = IntTensor::<i64>::from_vec(vec![i64::MAX], vec![1]).unwrap();
let err = t.cast::<i32>().unwrap_err();
assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
}
// `reshape` changes only the shape; element order is untouched.
#[test]
fn reshape_preserves_data() {
let t = IntTensor::<i32>::from_vec(vec![1, 2, 3, 4, 5, 6], vec![6]).unwrap();
let r = t.reshape(&[2, 3]).unwrap();
assert_eq!(r.shape(), &[2, 3]);
assert_eq!(r.data().unwrap(), &[1, 2, 3, 4, 5, 6]);
}
// A 3x2 shape needs 6 elements; the tensor has 4.
#[test]
fn reshape_size_mismatch_errors() {
let t = IntTensor::<i32>::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
let err = t.reshape(&[3, 2]).unwrap_err();
assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
}
// A scalar has an empty shape and exactly one element.
#[test]
fn scalar_constructor() {
let t = IntTensor::<i64>::scalar(42);
assert_eq!(t.shape(), &[] as &[usize]);
assert_eq!(t.numel(), 1);
assert_eq!(t.data().unwrap()[0], 42);
}
// `dtype_name` reflects the element type parameter.
#[test]
fn dtype_name_reports_i32_or_i64() {
let t32 = IntTensor::<i32>::scalar(0);
let t64 = IntTensor::<i64>::scalar(0);
assert_eq!(t32.dtype_name(), "i32");
assert_eq!(t64.dtype_name(), "i64");
}
// Tensors built by the CPU constructors report `Device::Cpu` and have no GPU handle.
#[test]
fn cpu_tensor_reports_cpu_device() {
let t = IntTensor::<i32>::arange(4).unwrap();
assert_eq!(t.device(), Device::Cpu);
assert!(!t.is_cuda());
assert!(t.gpu_handle().is_err());
}
// Cloning a CPU tensor keeps both data and device.
#[test]
fn clone_preserves_cpu_data() {
let t = IntTensor::<i32>::arange(4).unwrap();
let t2 = t.clone();
assert_eq!(t2.data().unwrap(), &[0, 1, 2, 3]);
assert_eq!(t2.device(), Device::Cpu);
}
}