use std::borrow::Borrow;
use burn_backend::{
DType, Distribution, ElementConversion, ExecutionError, FloatDType, Shape, Slice, TensorData,
bf16, f16,
ops::FloatTensorOps,
tensor::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor},
};
use candle_core::{Tensor, backend::BackendStorage, shape};
use crate::{
Candle, CandleDevice, CandleTensor, IntoDType,
element::{CandleElement, FloatCandleElement, IntCandleElement},
};
use super::base::{cpu_random, expand, permute, sign, unfold};
// Candle backend implementation of burn's `FloatTensorOps`.
//
// Most operations delegate either to `candle_core::Tensor` directly or to the
// shared helpers in `super::base` / `super::candle_utils`. Candle errors are
// unwrapped eagerly: a failure here means the backend cannot continue.
impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
    /// Creates a float tensor from raw data, dispatching on the source dtype.
    fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor {
        match data.dtype {
            DType::F64 => super::base::from_data::<f64>(data, device),
            DType::F32 => super::base::from_data::<f32>(data, device),
            DType::F16 => super::base::from_data::<f16>(data, device),
            DType::BF16 => super::base::from_data::<bf16>(data, device),
            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
        }
    }

    /// Samples a tensor of the given shape from `distribution` on `device`.
    ///
    /// On CPU the samples are generated host-side (`cpu_random`) for
    /// reproducibility with the rest of the backend; otherwise candle's own
    /// RNG kernels are used.
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<Self>,
    ) -> FloatTensor<Self> {
        if let CandleDevice::Cpu = device {
            return Self::float_from_data(cpu_random::<F>(shape, distribution), device);
        }
        let shape = shape.dims;
        let device = &(device.clone()).into();
        match distribution {
            // Default = uniform in [0, 1).
            Distribution::Default => CandleTensor::new(
                candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape, device)
                    .unwrap()
                    .to_dtype(F::DTYPE)
                    .unwrap(),
            ),
            // Bernoulli: 1.0 where a uniform sample falls below `prob`, else 0.0.
            Distribution::Bernoulli(prob) => CandleTensor::new(
                candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape.clone(), device)
                    .unwrap()
                    .to_dtype(F::DTYPE)
                    .unwrap()
                    .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device))
                    .unwrap()
                    .to_dtype(F::DTYPE)
                    .unwrap(),
            ),
            Distribution::Uniform(from, to) => CandleTensor::new(
                candle_core::Tensor::rand(from.elem::<F>(), to.elem::<F>(), shape, device).unwrap(),
            ),
            Distribution::Normal(mean, std) => CandleTensor::new(
                candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), shape, device)
                    .unwrap(),
            ),
        }
    }

    /// Reads the tensor back into host-side `TensorData`.
    async fn float_into_data(tensor: CandleTensor) -> Result<TensorData, ExecutionError> {
        super::base::into_data(tensor)
    }

    /// Returns the device the tensor lives on.
    fn float_device(tensor: &CandleTensor) -> Device<Self> {
        super::base::device(tensor)
    }

    /// Moves the tensor to another device.
    fn float_to_device(tensor: CandleTensor, device: &Device<Self>) -> CandleTensor {
        super::base::to_device(tensor, device)
    }

    /// Casts the float tensor to the backend's int element type (truncating).
    fn float_into_int(tensor: CandleTensor) -> IntTensor<Self> {
        CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap())
    }

    /// Allocates an uninitialized-content tensor of the given shape and dtype.
    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        super::base::empty(shape, device, dtype.into_dtype())
    }

    // --- Elementwise arithmetic (broadcasting tensor-tensor, affine scalar) ---

    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap())
    }

    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        CandleTensor::new((lhs.tensor + rhs.elem::<f64>()).unwrap())
    }

    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap())
    }

    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        CandleTensor::new((lhs.tensor - rhs.elem::<f64>()).unwrap())
    }

    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap())
    }

    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        CandleTensor::new((lhs.tensor * rhs.elem::<f64>()).unwrap())
    }

    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap())
    }

    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        CandleTensor::new((lhs.tensor / rhs.elem::<f64>()).unwrap())
    }

    /// Elementwise remainder with floored-division semantics:
    /// `lhs - floor(lhs / rhs) * rhs` (result has the sign of `rhs`).
    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(
            (lhs.tensor.clone()
                - lhs
                    .tensor
                    .broadcast_div(&rhs.tensor)
                    .unwrap()
                    .floor()
                    .unwrap()
                    .broadcast_mul(&rhs.tensor)
                    .unwrap())
            .unwrap(),
        )
    }

    /// Scalar variant of [`Self::float_remainder`], same floored semantics.
    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        let rhs_val = rhs.elem::<f64>();
        let division_result = (lhs.tensor.clone() / rhs_val).unwrap().floor().unwrap();
        let product = division_result * rhs_val;
        CandleTensor::new((lhs.tensor - product).unwrap())
    }

    /// Batched matrix multiplication with broadcasting.
    ///
    /// Candle's matmul requires contiguous inputs, so non-contiguous operands
    /// are materialized first.
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        let lhs_contiguous = if !lhs.tensor.is_contiguous() {
            lhs.tensor.contiguous().unwrap()
        } else {
            lhs.tensor
        };
        let rhs_contiguous = if !rhs.tensor.is_contiguous() {
            rhs.tensor.contiguous().unwrap()
        } else {
            rhs.tensor
        };
        CandleTensor::new(lhs_contiguous.broadcast_matmul(&rhs_contiguous).unwrap())
    }

    /// Cross product along `dim` (delegated to the shared helper).
    fn float_cross(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        dim: usize,
    ) -> FloatTensor<Self> {
        super::base::cross(lhs, rhs, dim)
    }

    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        super::base::swap_dims(tensor, dim1, dim2)
    }

    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        super::base::reshape(tensor, shape)
    }

    /// Gathers values along `dim` at the given indices.
    ///
    /// Both operands are made contiguous because candle's `gather` kernel
    /// requires contiguous layouts.
    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        let tensor = tensor.tensor.contiguous().unwrap();
        let indices = indices.tensor.contiguous().unwrap();
        CandleTensor::new(tensor.gather(&indices, dim).unwrap())
    }

    /// Adds `value` into `tensor` at `indices` along `dim` (scatter-add).
    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        CandleTensor::new(
            tensor
                .tensor
                .scatter_add(&indices.tensor, &value.tensor, dim)
                .unwrap(),
        )
    }

    /// Selects rows along `dim` by index (index_select).
    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())
    }

    /// Adds `value` into the rows of `tensor` selected by `indices` (index_add).
    fn float_select_add(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        CandleTensor::new(
            tensor
                .tensor
                .index_add(&indices.tensor, &value.tensor, dim)
                .unwrap(),
        )
    }

    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
        super::base::slice_with_steps(tensor, slices)
    }

    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        slices: &[Slice],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        super::base::slice_assign(tensor, slices, value)
    }

    /// Replaces elements of `tensor` with `value` where `mask` is true.
    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        super::base::mask_where_broadcasted(tensor, mask, value)
    }

    /// Scalar variant of [`Self::float_mask_where`]: fills masked elements
    /// with a constant by materializing a filled tensor first.
    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: FloatElem<Self>,
    ) -> FloatTensor<Self> {
        let value = super::candle_utils::fill_like::<F>(value, &tensor.tensor);
        super::base::mask_where_broadcasted(tensor, mask, CandleTensor::new(value))
    }

    // --- Comparisons. Tensor-tensor forms broadcast both sides to a common
    // shape first; scalar (`_elem`) forms compare against a filled tensor. ---

    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())
    }

    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        CandleTensor::new(lhs.tensor.eq(rhs).unwrap())
    }

    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.gt(&rhs_broadcast).unwrap())
    }

    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        CandleTensor::new(
            lhs.tensor
                .gt(&super::candle_utils::fill_like::<F>(rhs, &lhs.tensor))
                .unwrap(),
        )
    }

    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.ge(&rhs_broadcast).unwrap())
    }

    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        CandleTensor::new(
            lhs.tensor
                .ge(&super::candle_utils::fill_like::<F>(rhs, &lhs.tensor))
                .unwrap(),
        )
    }

    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.lt(&rhs_broadcast).unwrap())
    }

    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        CandleTensor::new(
            lhs.tensor
                .lt(&super::candle_utils::fill_like::<F>(rhs, &lhs.tensor))
                .unwrap(),
        )
    }

    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.le(&rhs_broadcast).unwrap())
    }

    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        CandleTensor::new(
            lhs.tensor
                .le(&super::candle_utils::fill_like::<F>(rhs, &lhs.tensor))
                .unwrap(),
        )
    }

    // --- Reductions ---

    /// Sums all elements into a rank-1 tensor of length 1 (round-trips the
    /// scalar through the host to build the result on the original device).
    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let sum = tensor.tensor.sum_all().unwrap().to_scalar::<F>().unwrap();
        CandleTensor::from_data::<F>(
            TensorData::new([sum].into(), [1]),
            Self::float_device(&tensor),
        )
    }

    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
    }

    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap())
    }

    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
    }

    /// Cumulative product along `dim`, built from the generic scan helper.
    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
            prev.broadcast_mul(curr)
        });
        CandleTensor::new(result)
    }

    /// Cumulative minimum along `dim`.
    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
            prev.broadcast_minimum(curr)
        });
        CandleTensor::new(result)
    }

    /// Cumulative maximum along `dim`.
    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
            prev.broadcast_maximum(curr)
        });
        CandleTensor::new(result)
    }

    // --- Elementwise math ---

    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.exp().unwrap())
    }

    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.log().unwrap())
    }

    /// `log(1 + x)`. NOTE(review): computed naively, so precision for tiny
    /// `x` is lower than a dedicated `log1p` kernel would give.
    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new((tensor.tensor + 1.).unwrap().log().unwrap())
    }

    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.powf(value.elem::<f64>()).unwrap())
    }

    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.sqrt().unwrap())
    }

    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.abs().unwrap())
    }

    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.cos().unwrap())
    }

    /// cosh(x) = (e^x + e^-x) / 2 — candle has no cosh primitive.
    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let exp_x = tensor.tensor.exp().unwrap();
        CandleTensor::new(((exp_x.clone() + exp_x.recip().unwrap()).unwrap() / 2.0).unwrap())
    }

    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.sin().unwrap())
    }

    /// sinh(x) = (e^x - e^-x) / 2 — candle has no sinh primitive.
    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let exp_x = tensor.tensor.exp().unwrap();
        CandleTensor::new(((exp_x.clone() - exp_x.recip().unwrap()).unwrap() / 2.0).unwrap())
    }

    /// tan(x) = sin(x) / cos(x) — candle has no tan primitive.
    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new((tensor.tensor.sin().unwrap() / tensor.tensor.cos().unwrap()).unwrap())
    }

    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.tanh().unwrap())
    }

    /// acos(x) = pi/2 - asin(x).
    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let neg_asin_x = Self::float_neg(Self::float_asin(tensor));
        Self::float_add_scalar(neg_asin_x, core::f64::consts::FRAC_PI_2.elem())
    }

    /// acosh(x) = ln(x + sqrt(x^2 - 1)), defined for x >= 1.
    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let x_squared = Self::float_powi_scalar(tensor.clone(), 2.elem());
        let x_sq_minus_one = Self::float_sub_scalar(x_squared, 1.elem());
        let sqrt_term = Self::float_sqrt(x_sq_minus_one);
        Self::float_log(Self::float_add(tensor, sqrt_term))
    }

    /// asin(x) = atan(x / sqrt(1 - x^2)).
    ///
    /// This relies on [`Self::float_atan`] being a *primitive* (it must not
    /// be defined in terms of `float_asin`, or the pair would recurse forever).
    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let x_squared = Self::float_powi_scalar(tensor.clone(), 2.elem());
        let one_minus_x_sq = Self::float_add_scalar(Self::float_neg(x_squared), 1.elem());
        let sqrt_term = Self::float_sqrt(one_minus_x_sq);
        Self::float_atan(Self::float_div(tensor, sqrt_term))
    }

    /// asinh(x) = ln(x + sqrt(x^2 + 1)).
    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let x_squared = Self::float_powi_scalar(tensor.clone(), 2.elem());
        let x_sq_plus_one = Self::float_add_scalar(x_squared, 1.elem());
        let sqrt_term = Self::float_sqrt(x_sq_plus_one);
        Self::float_log(Self::float_add(tensor, sqrt_term))
    }

    /// Elementwise arctangent.
    ///
    /// Candle exposes no `atan` primitive, and this method must NOT be
    /// expressed through `float_asin` (which is itself defined via
    /// `float_atan`): that pair reproduces the same argument and recurses
    /// forever. Instead atan is computed directly, branch-free:
    ///
    /// 1. reduce to `t = |x|` in `[0, 1]` using `atan(x) = pi/2 - atan(1/x)`;
    /// 2. initial guess `y0 = t / (1 + 0.28125 t^2)` (max error ~5e-3 on [0, 1]);
    /// 3. three Newton steps on `f(y) = tan(y) - t`, i.e.
    ///    `y <- y + cos(y) * (t*cos(y) - sin(y))`, which squares the error
    ///    each step (down to ~1e-18, sufficient for f64);
    /// 4. undo the reduction and restore the sign.
    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let inner = |x: Tensor| -> candle_core::Result<Tensor> {
            let zero = x.zeros_like()?;
            let one = candle_core::Tensor::ones_like(&x)?;
            let negative = x.lt(&zero)?;
            let abs = x.abs()?;
            let small = abs.le(&one)?;
            // Range reduction: work with t in [0, 1]. (recip(inf) = 0, so
            // infinite inputs land exactly on atan = pi/2 below.)
            let t = small.where_cond(&abs, &abs.recip()?)?;
            let t_sq = t.mul(&t)?;
            let mut y = t.div(&((t_sq * 0.28125)? + 1.0)?)?;
            // Newton polish on tan(y) = t; tan is monotone on (-pi/2, pi/2).
            for _ in 0..3 {
                let cos_y = y.cos()?;
                let sin_y = y.sin()?;
                y = y.add(&cos_y.mul(&t.mul(&cos_y)?.sub(&sin_y)?)?)?;
            }
            // Undo the reduction: atan(|x|) = pi/2 - atan(1/|x|) for |x| > 1.
            let reflected = ((one * core::f64::consts::FRAC_PI_2)? - &y)?;
            let unsigned = small.where_cond(&y, &reflected)?;
            // atan is odd: negate the result for negative inputs.
            negative.where_cond(&unsigned.neg()?, &unsigned)
        };
        CandleTensor::new(inner(tensor.tensor).unwrap())
    }

    /// atanh(x) = ln((1 + x) / (1 - x)) / 2, defined for |x| < 1.
    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let num = (1.0 + tensor.tensor.clone()).unwrap();
        let denom = (1.0 - tensor.tensor).unwrap();
        CandleTensor::new(((num / denom).unwrap().log().unwrap() / 2.0).unwrap())
    }

    /// atan2(y, x) via the half-angle identity
    /// `atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))`.
    ///
    /// NOTE(review): this identity is undefined on the negative x-axis
    /// (x < 0, y == 0 yields 0/0) — confirm callers never need that case or
    /// special-case it.
    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        let x_squared = Self::float_powi_scalar(rhs.clone(), 2.elem());
        let y_squared = Self::float_powi_scalar(lhs.clone(), 2.elem());
        let r = Self::float_sqrt(Self::float_add(x_squared, y_squared));
        let ratio = Self::float_div(lhs, Self::float_add(r, rhs));
        Self::float_mul_scalar(Self::float_atan(ratio), 2.elem())
    }

    /// Rounds to nearest integer, ties to even (banker's rounding).
    ///
    /// Candle's `round` handles the generic case; exact-half fractions are
    /// detected and recomputed as `round(x / 2) * 2`, which is unambiguous
    /// because `x/2` then has fraction .25 or .75 — never another tie.
    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let inner = |tensor: FloatTensor<Self>| -> candle_core::Result<FloatTensor<Self>> {
            let floor_a = tensor.tensor.floor()?;
            let frac_part = tensor.tensor.sub(&floor_a)?;
            let half = (candle_core::Tensor::ones_like(&tensor.tensor)? * 0.5)?;
            let mask_half = frac_part.eq(&half)?;
            let half_tensor = tensor.tensor.mul(&half)?;
            let rounded_half = half_tensor.round()?;
            let doubled =
                rounded_half.mul(&(candle_core::Tensor::ones_like(&tensor.tensor)? * 2.0)?)?;
            let standard_round = tensor.tensor.round()?;
            Ok(CandleTensor::new(
                mask_half.where_cond(&doubled, &standard_round)?,
            ))
        };
        inner(tensor).unwrap()
    }

    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.floor().unwrap())
    }

    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.ceil().unwrap())
    }

    /// Truncates toward zero: ceil for negative values, floor otherwise.
    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let is_negative = tensor.tensor.lt(0.0).unwrap();
        let floored = tensor.tensor.floor().unwrap();
        let ceiled = tensor.tensor.ceil().unwrap();
        CandleTensor::new(is_negative.where_cond(&ceiled, &floored).unwrap())
    }

    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.erf().unwrap())
    }

    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        super::base::cat(tensors, dim)
    }

    /// Index of the maximum along `dim` (kept dim), cast to the int dtype.
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
        CandleTensor::new(
            tensor
                .tensor
                .argmax_keepdim(dim)
                .unwrap()
                .to_dtype(I::DTYPE)
                .unwrap(),
        )
    }

    /// Index of the minimum along `dim` (kept dim), cast to the int dtype.
    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
        CandleTensor::new(
            tensor
                .tensor
                .argmin_keepdim(dim)
                .unwrap()
                .to_dtype(I::DTYPE)
                .unwrap(),
        )
    }

    fn float_clamp_max(tensor: FloatTensor<Self>, max: FloatElem<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.minimum(max).unwrap())
    }

    fn float_clamp_min(tensor: FloatTensor<Self>, min: FloatElem<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.maximum(min).unwrap())
    }

    fn float_clamp(
        tensor: FloatTensor<Self>,
        min: FloatElem<Self>,
        max: FloatElem<Self>,
    ) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.clamp(min, max).unwrap())
    }

    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.recip().unwrap())
    }

    /// Elementwise power with tensor exponent: `lhs^rhs = exp(rhs * ln(lhs))`.
    /// NOTE(review): valid only for positive bases, as with the usual
    /// exp/log formulation.
    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(
            rhs.tensor
                .broadcast_mul(&lhs.tensor.log().unwrap())
                .unwrap()
                .exp()
                .unwrap(),
        )
    }

    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        super::base::permute(tensor, axes)
    }

    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        super::base::flip(tensor, axes)
    }

    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        expand(tensor, shape)
    }

    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        unfold(tensor, dim, size, step)
    }

    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        sign(tensor)
    }

    /// Casts to the requested float dtype; a no-op when already that dtype.
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        let dtype = dtype.into_dtype();
        if tensor.tensor.dtype() == dtype {
            tensor
        } else {
            CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())
        }
    }
}