use crate::device::Device;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::gpu_dispatch::{CompareOp, GpuBufferHandle};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// A dense tensor of `bool` values whose storage lives either in host memory
/// or on a CUDA device.
///
/// An empty `shape` denotes a scalar holding exactly one element.
#[derive(Debug, Clone)]
pub struct BoolTensor {
    /// Backing buffer: a host vector or a GPU buffer handle.
    storage: TensorStorage<bool>,
    /// Logical dimensions. The validating constructors keep their product
    /// equal to `storage.len()` (`from_gpu_handle` does not check this).
    shape: Vec<usize>,
}
impl BoolTensor {
    /// Number of elements described by `shape`, or `None` if that count does
    /// not fit in a `usize`.
    ///
    /// An empty shape is a scalar (one element). A shape containing a zero
    /// dimension has zero elements whatever its other dimensions are; that case
    /// is checked first so it never reports a spurious overflow.
    fn checked_numel(shape: &[usize]) -> Option<usize> {
        if shape.contains(&0) {
            return Some(0);
        }
        shape.iter().try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    }

    /// CPU tensor of `shape` with every element set to `value`.
    ///
    /// # Panics
    ///
    /// Panics if the element count implied by `shape` overflows `usize`.
    /// Previously this wrapped silently in release builds and produced a
    /// tensor whose shape disagreed with its storage length.
    fn filled(shape: &[usize], value: bool) -> Self {
        let total = Self::checked_numel(shape).unwrap_or_else(|| {
            panic!("BoolTensor: element count of shape {shape:?} overflows usize")
        });
        Self {
            storage: TensorStorage::cpu(vec![value; total]),
            shape: shape.to_vec(),
        }
    }

    /// Builds a CPU tensor from owned `data` laid out according to `shape`.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::ShapeMismatch`] in two cases:
    /// - `data.len()` differs from the element count implied by `shape` (one
    ///   element for an empty shape);
    /// - that element count overflows `usize`.
    pub fn from_vec(data: Vec<bool>, shape: Vec<usize>) -> FerrotorchResult<Self> {
        let Some(expected) = Self::checked_numel(&shape) else {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BoolTensor::from_vec: prod(shape) overflows usize for shape {shape:?}"
                ),
            });
        };
        if data.len() != expected {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BoolTensor::from_vec: data.len()={} != prod(shape)={} for shape {:?}",
                    data.len(),
                    expected,
                    shape
                ),
            });
        }
        Ok(Self {
            storage: TensorStorage::cpu(data),
            shape,
        })
    }

    /// Borrowing variant of [`BoolTensor::from_vec`]; copies `data` and `shape`.
    ///
    /// # Errors
    ///
    /// Same as [`BoolTensor::from_vec`].
    pub fn from_slice(data: &[bool], shape: &[usize]) -> FerrotorchResult<Self> {
        Self::from_vec(data.to_vec(), shape.to_vec())
    }

    /// CPU tensor of `shape` filled with `false`.
    ///
    /// # Panics
    ///
    /// Panics if the element count implied by `shape` overflows `usize`.
    pub fn zeros(shape: &[usize]) -> Self {
        Self::filled(shape, false)
    }

    /// CPU tensor of `shape` filled with `true`.
    ///
    /// # Panics
    ///
    /// Panics if the element count implied by `shape` overflows `usize`.
    pub fn ones(shape: &[usize]) -> Self {
        Self::filled(shape, true)
    }

    /// Builds a CPU mask with the same shape as `t`, where each element is
    /// `pred` applied to the corresponding value of `t`.
    ///
    /// The values of `t` are materialised with `Tensor::data_vec`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Tensor::data_vec` or [`BoolTensor::from_vec`].
    pub fn from_predicate<T: Float>(
        t: &Tensor<T>,
        pred: impl Fn(T) -> bool,
    ) -> FerrotorchResult<Self> {
        let data = t.data_vec()?;
        let mask: Vec<bool> = data.iter().map(|&v| pred(v)).collect();
        Self::from_vec(mask, t.shape().to_vec())
    }

    /// Logical dimensions of the tensor (empty for a scalar).
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Number of elements held by the underlying storage.
    pub fn numel(&self) -> usize {
        self.storage.len()
    }

    /// Number of dimensions (0 for a scalar).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Device on which the storage lives.
    #[inline]
    pub fn device(&self) -> Device {
        self.storage.device()
    }

    /// `true` when the storage lives on a CUDA device.
    #[inline]
    pub fn is_cuda(&self) -> bool {
        self.device().is_cuda()
    }

    /// Borrows the elements as a host slice.
    ///
    /// # Errors
    ///
    /// Propagates the error from `TensorStorage::try_as_slice`. Presumably this
    /// happens when the storage is not host-resident (e.g. CUDA); TODO confirm.
    pub fn data(&self) -> FerrotorchResult<&[bool]> {
        self.storage.try_as_slice()
    }

    /// Borrows the GPU buffer handle backing this tensor.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] when the storage has no
    /// GPU handle.
    pub fn gpu_handle(&self) -> FerrotorchResult<&GpuBufferHandle> {
        // `ok_or_else` builds the error message only on the failure path.
        self.storage
            .gpu_handle()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "BoolTensor is not on a CUDA GPU".into(),
            })
    }

    /// Wraps an existing GPU buffer as a `BoolTensor` with the given `shape`.
    ///
    /// In debug builds this asserts that the handle's dtype tag is `bool`.
    /// NOTE(review): `shape` is not validated against the handle's element
    /// count; callers must ensure they agree.
    pub fn from_gpu_handle(handle: GpuBufferHandle, shape: Vec<usize>) -> Self {
        debug_assert_eq!(
            handle.dtype(),
            <bool as crate::dtype::Element>::dtype(),
            "from_gpu_handle: handle dtype tag must be Bool"
        );
        Self {
            storage: TensorStorage::gpu(handle),
            shape,
        }
    }

    /// Copies the tensor to `device`, or clones it if it is already there.
    ///
    /// Supported transfers:
    /// - CPU -> CUDA;
    /// - CUDA -> CPU (non-zero bytes are read back as `true`);
    /// - CUDA -> CUDA on a different device, which is staged through the CPU.
    ///
    /// # Errors
    ///
    /// - [`FerrotorchError::DeviceUnavailable`] when no GPU backend is registered.
    /// - [`FerrotorchError::InvalidArgument`] for any other device pair.
    /// - Any error raised by the backend during the copy.
    pub fn to(&self, device: Device) -> FerrotorchResult<BoolTensor> {
        if self.device() == device {
            return Ok(self.clone());
        }
        match (self.device(), device) {
            (Device::Cpu, Device::Cuda(_)) => {
                let data = self.data()?.to_vec();
                let storage = TensorStorage::on_device(data, device)?;
                Ok(Self {
                    storage,
                    shape: self.shape.clone(),
                })
            }
            (Device::Cuda(_), Device::Cpu) => {
                let backend =
                    crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
                let handle = self.gpu_handle()?;
                let bytes = backend.gpu_to_cpu(handle)?;
                let data: Vec<bool> = bytes.iter().map(|&b| b != 0).collect();
                Ok(Self {
                    storage: TensorStorage::cpu(data),
                    shape: self.shape.clone(),
                })
            }
            (Device::Cuda(_), Device::Cuda(_)) => {
                // Cross-device GPU copy is staged through host memory.
                let cpu = self.to(Device::Cpu)?;
                cpu.to(device)
            }
            (from, to) => Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "BoolTensor::to: unsupported device transfer {from:?} -> {to:?} \
                     (CPU <-> CUDA only)"
                ),
            }),
        }
    }

    /// Elementwise logical NOT, on the tensor's own device.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - the GPU kernel (or backend lookup) fails for a CUDA tensor;
    /// - the host data of a non-CUDA tensor cannot be borrowed.
    pub fn not(&self) -> Self {
        if self.is_cuda() {
            return self
                .unary_gpu(|b, h| b.bool_not(h))
                .expect("BoolTensor::not GPU kernel");
        }
        let out: Vec<bool> = self
            .data()
            .expect("CPU BoolTensor data")
            .iter()
            .map(|&b| !b)
            .collect();
        Self {
            storage: TensorStorage::cpu(out),
            shape: self.shape.clone(),
        }
    }

    /// Elementwise logical AND.
    ///
    /// # Errors
    ///
    /// Device mismatch, shape mismatch, or a backend failure
    /// (see `binary_op`).
    pub fn and(&self, other: &Self) -> FerrotorchResult<Self> {
        self.binary_op(other, |b, a, c| b.bool_and(a, c), |a, b| a && b, "and")
    }

    /// Elementwise logical OR.
    ///
    /// # Errors
    ///
    /// Device mismatch, shape mismatch, or a backend failure
    /// (see `binary_op`).
    pub fn or(&self, other: &Self) -> FerrotorchResult<Self> {
        self.binary_op(other, |b, a, c| b.bool_or(a, c), |a, b| a || b, "or")
    }

    /// Elementwise logical XOR.
    ///
    /// # Errors
    ///
    /// Device mismatch, shape mismatch, or a backend failure
    /// (see `binary_op`).
    pub fn xor(&self, other: &Self) -> FerrotorchResult<Self> {
        self.binary_op(other, |b, a, c| b.bool_xor(a, c), |a, b| a ^ b, "xor")
    }

    /// Runs a single-input GPU kernel on this tensor's handle and wraps the
    /// resulting buffer with the same shape.
    fn unary_gpu(
        &self,
        gpu: impl FnOnce(
            &dyn crate::gpu_dispatch::GpuBackend,
            &GpuBufferHandle,
        ) -> FerrotorchResult<GpuBufferHandle>,
    ) -> FerrotorchResult<Self> {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let h = gpu(backend, self.gpu_handle()?)?;
        Ok(Self::from_gpu_handle(h, self.shape.clone()))
    }

    /// Shared implementation of the elementwise binary logical ops.
    ///
    /// Checks the device first, then the shape. CUDA tensors run `gpu`; all
    /// others run `f` element by element on host data.
    fn binary_op(
        &self,
        other: &Self,
        gpu: impl FnOnce(
            &dyn crate::gpu_dispatch::GpuBackend,
            &GpuBufferHandle,
            &GpuBufferHandle,
        ) -> FerrotorchResult<GpuBufferHandle>,
        f: impl Fn(bool, bool) -> bool,
        op_name: &str,
    ) -> FerrotorchResult<Self> {
        if self.device() != other.device() {
            return Err(FerrotorchError::DeviceMismatch {
                expected: self.device(),
                got: other.device(),
            });
        }
        if self.shape != other.shape {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BoolTensor::{op_name}: shapes {:?} vs {:?}",
                    self.shape, other.shape
                ),
            });
        }
        if self.is_cuda() {
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let h = gpu(backend, self.gpu_handle()?, other.gpu_handle()?)?;
            return Ok(Self::from_gpu_handle(h, self.shape.clone()));
        }
        let out: Vec<bool> = self
            .data()?
            .iter()
            .zip(other.data()?.iter())
            .map(|(&a, &b)| f(a, b))
            .collect();
        Ok(Self {
            storage: TensorStorage::cpu(out),
            shape: self.shape.clone(),
        })
    }

    /// Returns a view with a new `shape` over the same storage.
    ///
    /// The storage is shared via `TensorStorage::clone`; no data is rearranged.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::ShapeMismatch`] in either case:
    /// - the new element count differs from the current one;
    /// - the new element count overflows `usize`. Previously the wrapped
    ///   product could spuriously match the current count.
    pub fn reshape(&self, shape: &[usize]) -> FerrotorchResult<Self> {
        let cur = self.storage.len();
        let Some(new_total) = Self::checked_numel(shape) else {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BoolTensor::reshape: new shape {shape:?} overflows usize (current numel {cur})"
                ),
            });
        };
        if new_total != cur {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BoolTensor::reshape: new shape {shape:?} (numel {new_total}) != current numel {cur}"
                ),
            });
        }
        Ok(Self {
            storage: self.storage.clone(),
            shape: shape.to_vec(),
        })
    }

    /// Number of `true` elements.
    ///
    /// CUDA tensors are first copied to the CPU. Previously they failed here,
    /// even though `any`/`all` support them.
    ///
    /// # Errors
    ///
    /// Propagates transfer errors, or the error from [`BoolTensor::data`] for
    /// storage that cannot be read on the host.
    pub fn count_true(&self) -> FerrotorchResult<usize> {
        if self.is_cuda() {
            return self.to(Device::Cpu)?.count_true();
        }
        Ok(self.data()?.iter().filter(|&&b| b).count())
    }

    /// `true` if at least one element is `true` (`false` for an empty CPU tensor).
    ///
    /// # Errors
    ///
    /// Backend or data-access failures.
    pub fn any(&self) -> FerrotorchResult<bool> {
        if self.is_cuda() {
            return self.reduce_gpu(|b, h| b.bool_any(h));
        }
        Ok(self.data()?.iter().any(|&b| b))
    }

    /// `true` if every element is `true` (`true` for an empty CPU tensor).
    ///
    /// # Errors
    ///
    /// Backend or data-access failures.
    pub fn all(&self) -> FerrotorchResult<bool> {
        if self.is_cuda() {
            return self.reduce_gpu(|b, h| b.bool_all(h));
        }
        Ok(self.data()?.iter().all(|&b| b))
    }

    /// Runs a GPU reduction kernel and reads back its result.
    ///
    /// The first byte of the read-back buffer is the answer (non-zero means
    /// `true`). An empty read-back yields `false`.
    fn reduce_gpu(
        &self,
        gpu: impl FnOnce(
            &dyn crate::gpu_dispatch::GpuBackend,
            &GpuBufferHandle,
        ) -> FerrotorchResult<GpuBufferHandle>,
    ) -> FerrotorchResult<bool> {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let reduced = gpu(backend, self.gpu_handle()?)?;
        let bytes = backend.gpu_to_cpu(&reduced)?;
        Ok(bytes.first().is_some_and(|&b| b != 0))
    }

    /// Elementwise `a > b` for float tensors of identical shape and device.
    pub fn gt<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Self> {
        Self::compare_float(a, b, CompareOp::Gt, |x, y| x > y)
    }

    /// Elementwise `a < b` for float tensors of identical shape and device.
    pub fn lt<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Self> {
        Self::compare_float(a, b, CompareOp::Lt, |x, y| x < y)
    }

    /// Elementwise `a >= b` for float tensors of identical shape and device.
    pub fn ge<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Self> {
        Self::compare_float(a, b, CompareOp::Ge, |x, y| x >= y)
    }

    /// Elementwise `a <= b` for float tensors of identical shape and device.
    pub fn le<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Self> {
        Self::compare_float(a, b, CompareOp::Le, |x, y| x <= y)
    }

    /// Elementwise `a == b` for float tensors of identical shape and device.
    pub fn eq_t<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Self> {
        Self::compare_float(a, b, CompareOp::Eq, |x, y| x == y)
    }

    /// Elementwise `a != b` for float tensors of identical shape and device.
    pub fn ne<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Self> {
        Self::compare_float(a, b, CompareOp::Ne, |x, y| x != y)
    }

    /// Shared float comparison: validates shape, then device.
    ///
    /// CUDA inputs are made contiguous and dispatched to `backend.compare`.
    /// Other inputs are compared on the host with `f`, which uses Rust's
    /// IEEE-754 comparison semantics for the CPU path.
    fn compare_float<T: Float>(
        a: &Tensor<T>,
        b: &Tensor<T>,
        op: CompareOp,
        f: impl Fn(T, T) -> bool,
    ) -> FerrotorchResult<Self> {
        if a.shape() != b.shape() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BoolTensor::{}: shapes {:?} vs {:?}",
                    op.suffix(),
                    a.shape(),
                    b.shape()
                ),
            });
        }
        if a.device() != b.device() {
            return Err(FerrotorchError::DeviceMismatch {
                expected: a.device(),
                got: b.device(),
            });
        }
        if a.is_cuda() {
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let a = a.contiguous()?;
            let b = b.contiguous()?;
            let h = backend.compare(a.gpu_handle()?, b.gpu_handle()?, op)?;
            return Ok(Self::from_gpu_handle(h, a.shape().to_vec()));
        }
        let a_data = a.data_vec()?;
        let b_data = b.data_vec()?;
        let result: Vec<bool> = a_data
            .iter()
            .zip(b_data.iter())
            .map(|(&x, &y)| f(x, y))
            .collect();
        Self::from_vec(result, a.shape().to_vec())
    }

    /// Elementwise `a > b` for integer tensors of identical shape and device.
    pub fn gt_int<I: crate::int_tensor::IntElement>(
        a: &crate::int_tensor::IntTensor<I>,
        b: &crate::int_tensor::IntTensor<I>,
    ) -> FerrotorchResult<Self> {
        Self::compare_int(a, b, CompareOp::Gt, |x, y| x > y)
    }

    /// Elementwise `a < b` for integer tensors of identical shape and device.
    pub fn lt_int<I: crate::int_tensor::IntElement>(
        a: &crate::int_tensor::IntTensor<I>,
        b: &crate::int_tensor::IntTensor<I>,
    ) -> FerrotorchResult<Self> {
        Self::compare_int(a, b, CompareOp::Lt, |x, y| x < y)
    }

    /// Elementwise `a >= b` for integer tensors of identical shape and device.
    pub fn ge_int<I: crate::int_tensor::IntElement>(
        a: &crate::int_tensor::IntTensor<I>,
        b: &crate::int_tensor::IntTensor<I>,
    ) -> FerrotorchResult<Self> {
        Self::compare_int(a, b, CompareOp::Ge, |x, y| x >= y)
    }

    /// Elementwise `a <= b` for integer tensors of identical shape and device.
    pub fn le_int<I: crate::int_tensor::IntElement>(
        a: &crate::int_tensor::IntTensor<I>,
        b: &crate::int_tensor::IntTensor<I>,
    ) -> FerrotorchResult<Self> {
        Self::compare_int(a, b, CompareOp::Le, |x, y| x <= y)
    }

    /// Elementwise `a == b` for integer tensors of identical shape and device.
    pub fn eq_int<I: crate::int_tensor::IntElement>(
        a: &crate::int_tensor::IntTensor<I>,
        b: &crate::int_tensor::IntTensor<I>,
    ) -> FerrotorchResult<Self> {
        Self::compare_int(a, b, CompareOp::Eq, |x, y| x == y)
    }

    /// Elementwise `a != b` for integer tensors of identical shape and device.
    pub fn ne_int<I: crate::int_tensor::IntElement>(
        a: &crate::int_tensor::IntTensor<I>,
        b: &crate::int_tensor::IntTensor<I>,
    ) -> FerrotorchResult<Self> {
        Self::compare_int(a, b, CompareOp::Ne, |x, y| x != y)
    }

    /// Shared integer comparison: validates shape, then device.
    ///
    /// CUDA inputs go to `backend.compare` (their handles are used as-is).
    /// Host inputs are widened with `to_i64` and compared with `f`.
    fn compare_int<I: crate::int_tensor::IntElement>(
        a: &crate::int_tensor::IntTensor<I>,
        b: &crate::int_tensor::IntTensor<I>,
        op: CompareOp,
        f: impl Fn(i64, i64) -> bool,
    ) -> FerrotorchResult<Self> {
        if a.shape() != b.shape() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BoolTensor::{}_int: shapes {:?} vs {:?}",
                    op.suffix(),
                    a.shape(),
                    b.shape()
                ),
            });
        }
        if a.device() != b.device() {
            return Err(FerrotorchError::DeviceMismatch {
                expected: a.device(),
                got: b.device(),
            });
        }
        if a.is_cuda() {
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let h = backend.compare(a.gpu_handle()?, b.gpu_handle()?, op)?;
            return Ok(Self::from_gpu_handle(h, a.shape().to_vec()));
        }
        let result: Vec<bool> = a
            .data()?
            .iter()
            .zip(b.data()?.iter())
            .map(|(&x, &y)| f(x.to_i64(), y.to_i64()))
            .collect();
        Self::from_vec(result, a.shape().to_vec())
    }

    /// Converts the mask to a float tensor (`true` -> 1, `false` -> 0) on the
    /// same device, without gradient tracking.
    ///
    /// # Errors
    ///
    /// Backend, data-access, or tensor-construction failures.
    pub fn to_float<T: Float>(&self) -> FerrotorchResult<Tensor<T>> {
        if self.is_cuda() {
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let h = backend
                .cast_bool_to_f(self.gpu_handle()?, <T as crate::dtype::Element>::dtype())?;
            return Tensor::from_storage(TensorStorage::gpu(h), self.shape.clone(), false);
        }
        let one = T::from(1.0).expect("1.0 is representable in every Float type");
        let zero = T::from(0.0).expect("0.0 is representable in every Float type");
        let data: Vec<T> = self
            .data()?
            .iter()
            .map(|&b| if b { one } else { zero })
            .collect();
        Tensor::from_storage(TensorStorage::cpu(data), self.shape.clone(), false)
    }
}
/// Compact summary: shape, element count and device. Element values are not
/// printed.
impl std::fmt::Display for BoolTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shape = &self.shape;
        let len = self.storage.len();
        let device = self.device();
        write!(
            f,
            "BoolTensor(shape={:?}, len={}, device={:?})",
            shape, len, device,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU `f32` tensor for the comparison tests.
    fn floats(values: &[f32], shape: &[usize]) -> Tensor<f32> {
        crate::creation::from_slice::<f32>(values, shape).unwrap()
    }

    /// Builds a 1-D CPU mask whose single dimension is `values.len()`.
    fn mask(values: &[bool]) -> BoolTensor {
        BoolTensor::from_vec(values.to_vec(), vec![values.len()]).unwrap()
    }

    #[test]
    fn zeros_and_ones() {
        let all_false = BoolTensor::zeros(&[2, 3]);
        let all_true = BoolTensor::ones(&[2, 3]);
        assert_eq!(all_false.numel(), 6);
        assert_eq!(all_true.numel(), 6);
        assert!(!all_false.data().unwrap().contains(&true));
        assert!(!all_true.data().unwrap().contains(&false));
    }

    #[test]
    fn from_vec_shape_mismatch_errors() {
        let outcome = BoolTensor::from_vec(vec![true, false], vec![3]);
        assert!(matches!(outcome, Err(FerrotorchError::ShapeMismatch { .. })));
    }

    #[test]
    fn from_predicate_builds_mask() {
        let values = floats(&[-1.0, 0.0, 1.0, 2.0], &[4]);
        let positive = BoolTensor::from_predicate(&values, |x| x > 0.0).unwrap();
        assert_eq!(positive.data().unwrap(), &[false, false, true, true]);
    }

    #[test]
    fn pointwise_not() {
        let flipped = mask(&[true, false, true]).not();
        assert_eq!(flipped.data().unwrap(), &[false, true, false]);
    }

    #[test]
    fn pointwise_and_or_xor() {
        let lhs = mask(&[true, false, true, false]);
        let rhs = mask(&[true, true, false, false]);

        let both = lhs.and(&rhs).unwrap();
        let either = lhs.or(&rhs).unwrap();
        let exactly_one = lhs.xor(&rhs).unwrap();

        assert_eq!(both.data().unwrap(), &[true, false, false, false]);
        assert_eq!(either.data().unwrap(), &[true, true, true, false]);
        assert_eq!(exactly_one.data().unwrap(), &[false, true, true, false]);
    }

    #[test]
    fn binary_op_shape_mismatch() {
        let three = BoolTensor::ones(&[3]);
        let two = BoolTensor::ones(&[2]);
        assert!(matches!(
            three.and(&two),
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn count_true_any_all() {
        let mixed = mask(&[true, false, true]);
        assert_eq!(mixed.count_true().unwrap(), 2);
        assert!(mixed.any().unwrap());
        assert!(!mixed.all().unwrap());

        let none_set = BoolTensor::zeros(&[3]);
        assert!(!none_set.any().unwrap());
        assert_eq!(none_set.count_true().unwrap(), 0);

        let all_set = BoolTensor::ones(&[3]);
        assert!(all_set.all().unwrap());
        assert_eq!(all_set.count_true().unwrap(), 3);
    }

    #[test]
    fn reshape_preserves_data() {
        let flat = mask(&[true, false, true, false, true, false]);
        let grid = flat.reshape(&[2, 3]).unwrap();
        assert_eq!(grid.shape(), &[2, 3]);
        assert_eq!(grid.data().unwrap(), flat.data().unwrap());
    }

    #[test]
    fn to_float_emits_zeros_and_ones() {
        let as_float = mask(&[true, false, true]).to_float::<f32>().unwrap();
        assert_eq!(as_float.data().unwrap(), &[1.0_f32, 0.0, 1.0]);
    }

    #[test]
    fn cpu_tensor_reports_cpu_device() {
        let host = BoolTensor::ones(&[5]);
        assert_eq!(host.device(), Device::Cpu);
        assert!(!host.is_cuda());
        assert!(host.gpu_handle().is_err());
    }

    #[test]
    fn clone_preserves_cpu_data() {
        let original = mask(&[true, false, true, true, false]);
        let copy = original.clone();
        assert_eq!(copy.data().unwrap(), &[true, false, true, true, false]);
        assert_eq!(copy.device(), Device::Cpu);
    }

    #[test]
    fn compare_gt_basic() {
        let lhs = floats(&[1.0, 2.0, 3.0, 4.0], &[4]);
        let rhs = floats(&[0.0, 3.0, 3.0, 5.0], &[4]);
        let greater = BoolTensor::gt(&lhs, &rhs).unwrap();
        assert_eq!(greater.data().unwrap(), &[true, false, false, false]);
    }

    #[test]
    fn compare_lt_basic() {
        let lhs = floats(&[1.0, 2.0, 3.0], &[3]);
        let rhs = floats(&[2.0, 2.0, 4.0], &[3]);
        let less = BoolTensor::lt(&lhs, &rhs).unwrap();
        assert_eq!(less.data().unwrap(), &[true, false, true]);
    }

    #[test]
    fn compare_ge_le() {
        let lhs = floats(&[1.0, 2.0, 3.0], &[3]);
        let rhs = floats(&[1.0, 3.0, 2.0], &[3]);
        let ge = BoolTensor::ge(&lhs, &rhs).unwrap();
        let le = BoolTensor::le(&lhs, &rhs).unwrap();
        assert_eq!(ge.data().unwrap(), &[true, false, true]);
        assert_eq!(le.data().unwrap(), &[true, true, false]);
    }

    #[test]
    fn compare_eq_ne() {
        let lhs = floats(&[1.0, 2.0, 3.0], &[3]);
        let rhs = floats(&[1.0, 5.0, 3.0], &[3]);
        let equal = BoolTensor::eq_t(&lhs, &rhs).unwrap();
        let unequal = BoolTensor::ne(&lhs, &rhs).unwrap();
        assert_eq!(equal.data().unwrap(), &[true, false, true]);
        assert_eq!(unequal.data().unwrap(), &[false, true, false]);
    }

    #[test]
    fn compare_int_basic() {
        use crate::int_tensor::IntTensor;
        let lhs = IntTensor::<i32>::from_vec(vec![1, 5, 3, 8], vec![4]).unwrap();
        let rhs = IntTensor::<i32>::from_vec(vec![2, 5, 1, 8], vec![4]).unwrap();

        let greater = BoolTensor::gt_int(&lhs, &rhs).unwrap();
        let equal = BoolTensor::eq_int(&lhs, &rhs).unwrap();
        let at_most = BoolTensor::le_int(&lhs, &rhs).unwrap();

        assert_eq!(greater.data().unwrap(), &[false, false, true, false]);
        assert_eq!(equal.data().unwrap(), &[false, true, false, true]);
        assert_eq!(at_most.data().unwrap(), &[true, true, false, true]);
    }

    #[test]
    fn compare_rejects_shape_mismatch() {
        let short = floats(&[1.0, 2.0], &[2]);
        let long = floats(&[1.0, 2.0, 3.0], &[3]);
        assert!(matches!(
            BoolTensor::gt(&short, &long),
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
    }
}