use burn_backend::{
DType, Distribution, ElementConversion, ExecutionError, IntDType, Shape, Slice, TensorData,
ops::{FloatTensorOps, IntTensorOps},
tensor::{Bool, BoolTensor, Device, FloatTensor, IntElem, IntTensor},
};
use crate::{
Candle, CandleDevice, CandleTensor, IntoDType,
element::{CandleElement, FloatCandleElement, IntCandleElement},
};
use super::base::{cpu_random, expand, permute, sign, unfold};
impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
super::base::empty(shape, device, dtype.into_dtype())
}
async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
super::base::into_data(tensor)
}
fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
match data.dtype {
DType::I64 => super::base::from_data::<i64>(data, device),
DType::U32 => super::base::from_data::<u32>(data, device),
DType::U8 => super::base::from_data::<u8>(data, device),
_ => unimplemented!("Unsupported dtype for `int_from_data`"),
}
}
fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
super::base::device(tensor)
}
fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
super::base::to_device(tensor, device)
}
fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
super::base::reshape(tensor, shape)
}
fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
super::base::slice_with_steps(tensor, slices)
}
fn int_slice_assign(
tensor: IntTensor<Self>,
slices: &[Slice],
value: IntTensor<Self>,
) -> IntTensor<Self> {
super::base::slice_assign(tensor, slices, value)
}
fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap())
}
fn int_mask_where(
tensor: IntTensor<Self>,
mask: BoolTensor<Self>,
source: IntTensor<Self>,
) -> IntTensor<Self> {
super::base::mask_where_broadcasted(tensor, mask, source)
}
fn int_mask_fill(
tensor: IntTensor<Self>,
mask: BoolTensor<Self>,
value: IntElem<Self>,
) -> IntTensor<Self> {
CandleTensor::new(
mask.tensor
.where_cond(
&super::candle_utils::fill_like::<I>(value, &tensor.tensor),
&tensor.tensor,
)
.unwrap(),
)
}
fn int_gather(
dim: usize,
tensor: IntTensor<Self>,
indices: IntTensor<Self>,
) -> IntTensor<Self> {
let tensor = tensor.tensor.contiguous().unwrap();
let indices = indices.tensor.contiguous().unwrap();
CandleTensor::new(tensor.gather(&indices, dim).unwrap())
}
fn int_scatter_add(
dim: usize,
tensor: IntTensor<Self>,
indices: IntTensor<Self>,
value: IntTensor<Self>,
) -> IntTensor<Self> {
CandleTensor::new(
tensor
.tensor
.scatter_add(&indices.tensor, &value.tensor, dim)
.unwrap(),
)
}
fn int_select(
tensor: IntTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
) -> IntTensor<Self> {
CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())
}
fn int_select_add(
tensor: IntTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
value: IntTensor<Self>,
) -> IntTensor<Self> {
CandleTensor::new(
tensor
.tensor
.index_add(&indices.tensor, &value.tensor, dim)
.unwrap(),
)
}
fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
super::base::cat(tensors, dim)
}
fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
let (lhs_broadcast, rhs_broadcast) =
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())
}
fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
CandleTensor::new(lhs.tensor.eq(rhs).unwrap())
}
fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
let (lhs_broadcast, rhs_broadcast) =
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
CandleTensor::new(lhs_broadcast.gt(&rhs_broadcast).unwrap())
}
fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
CandleTensor::new(
lhs.tensor
.gt(&super::candle_utils::fill_like::<I>(rhs, &lhs.tensor))
.unwrap(),
)
}
fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
let (lhs_broadcast, rhs_broadcast) =
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
CandleTensor::new(lhs_broadcast.ge(&rhs_broadcast).unwrap())
}
fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
CandleTensor::new(
lhs.tensor
.ge(&super::candle_utils::fill_like::<I>(rhs, &lhs.tensor))
.unwrap(),
)
}
fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
let (lhs_broadcast, rhs_broadcast) =
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
CandleTensor::new(lhs_broadcast.lt(&rhs_broadcast).unwrap())
}
fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
CandleTensor::new(
lhs.tensor
.lt(&super::candle_utils::fill_like::<I>(rhs, &lhs.tensor))
.unwrap(),
)
}
fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
let (lhs_broadcast, rhs_broadcast) =
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
CandleTensor::new(lhs_broadcast.le(&rhs_broadcast).unwrap())
}
fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
CandleTensor::new(
lhs.tensor
.le(&super::candle_utils::fill_like::<I>(rhs, &lhs.tensor))
.unwrap(),
)
}
fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap())
}
fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
CandleTensor::new((lhs.tensor + rhs.elem::<f64>()).unwrap())
}
fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap())
}
fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
CandleTensor::new((lhs.tensor - rhs.elem::<f64>()).unwrap())
}
fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap())
}
fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
CandleTensor::new((lhs.tensor * rhs.elem::<f64>()).unwrap())
}
fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap())
}
fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
panic!("Not supported by Candle")
}
fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
CandleTensor::new(
(lhs.tensor.clone()
- lhs
.tensor
.broadcast_div(&rhs.tensor)
.unwrap()
.broadcast_mul(&rhs.tensor)
.unwrap())
.unwrap(),
)
}
fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
panic!("Not supported by Candle")
}
fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
CandleTensor::new(
candle_core::Tensor::zeros(shape.dims, dtype.into_dtype(), &(device.clone()).into())
.unwrap(),
)
}
fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
CandleTensor::new(
candle_core::Tensor::ones(shape.dims, dtype.into_dtype(), &(device.clone()).into())
.unwrap(),
)
}
fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
let sum = tensor.tensor.sum_all().unwrap().to_scalar::<I>().unwrap();
CandleTensor::from_data::<I>(
TensorData::new([sum].into(), [1]),
Self::int_device(&tensor),
)
}
fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
}
fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
todo!(
"prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)"
)
}
fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
todo!(
"prod_int is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)"
)
}
fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
panic!("Not supported by Candle")
}
fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
let dtype = tensor.tensor.dtype();
let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();
let result_float = tensor_float.cumsum(dim).unwrap();
CandleTensor::new(result_float.to_dtype(dtype).unwrap())
}
fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
let dtype = tensor.tensor.dtype();
let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();
let result_float = super::utils::cumulative_with_op(&tensor_float, dim, |prev, curr| {
prev.broadcast_mul(curr)
});
CandleTensor::new(result_float.to_dtype(dtype).unwrap())
}
fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
let dtype = tensor.tensor.dtype();
let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();
let result_float = super::utils::cumulative_with_op(&tensor_float, dim, |prev, curr| {
prev.broadcast_minimum(curr)
});
CandleTensor::new(result_float.to_dtype(dtype).unwrap())
}
fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
prev.broadcast_maximum(curr)
});
CandleTensor::new(result)
}
fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
CandleTensor::new(
tensor
.tensor
.argmax_keepdim(dim)
.unwrap()
.to_dtype(I::DTYPE)
.unwrap(),
)
}
fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
CandleTensor::new(
tensor
.tensor
.argmin_keepdim(dim)
.unwrap()
.to_dtype(I::DTYPE)
.unwrap(),
)
}
fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
match tensor.tensor.dtype() {
candle_core::DType::U8 | candle_core::DType::U32 => tensor,
candle_core::DType::I64 => CandleTensor::new(
tensor
.tensor
.to_dtype(F::DTYPE)
.unwrap()
.abs()
.unwrap()
.to_dtype(candle_core::DType::I64)
.unwrap(),
),
_ => unreachable!(),
}
}
fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
super::base::swap_dims(tensor, dim1, dim2)
}
fn int_random(
shape: Shape,
distribution: Distribution,
device: &Device<Self>,
) -> IntTensor<Self> {
if let CandleDevice::Cpu = device {
let distribution = if distribution == Distribution::Default {
Distribution::Uniform(0.0, 255.0)
} else {
distribution
};
return Self::int_from_data(cpu_random::<I>(shape, distribution), device);
}
let shape = shape.dims;
let device = &(device.clone()).into();
match distribution {
Distribution::Default => CandleTensor::new(
candle_core::Tensor::rand(0.elem::<F>(), 255.elem::<F>(), shape, device)
.unwrap()
.to_dtype(I::DTYPE)
.unwrap(),
),
Distribution::Bernoulli(prob) => CandleTensor::new(
candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape.clone(), device)
.unwrap()
.to_dtype(I::DTYPE)
.unwrap()
.lt(&super::candle_utils::fill(prob, shape, I::DTYPE, device))
.unwrap()
.to_dtype(I::DTYPE)
.unwrap(),
),
Distribution::Uniform(from, to) => CandleTensor::new(
candle_core::Tensor::rand(from.elem::<F>(), to.elem::<F>(), shape, device).unwrap(),
),
Distribution::Normal(mean, std) => CandleTensor::new(
candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), shape, device)
.unwrap(),
),
}
}
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
super::base::permute(tensor, axes)
}
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
super::base::flip(tensor, axes)
}
fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
expand(tensor, shape)
}
fn int_unfold(
tensor: IntTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> IntTensor<Self> {
unfold(tensor, dim, size, step)
}
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
sign(tensor)
}
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_and is not implemented for Candle IntTensor");
}
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor");
}
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_or is not implemented for Candle IntTensor");
}
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor");
}
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_xor is not implemented for Candle IntTensor");
}
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor");
}
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_not is not implemented for Candle IntTensor");
}
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor");
}
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor");
}
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor");
}
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor");
}
fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
let lhs = Self::int_into_float(lhs);
let rhs = Self::int_into_float(rhs);
let out = Self::float_matmul(lhs, rhs);
Self::float_into_int(out)
}
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
let dtype = dtype.into_dtype();
if tensor.tensor.dtype() == dtype {
tensor
} else {
CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())
}
}
}