use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::tensor::Tensor;
/// Validates that an in-place mutation of `tensor` is permitted under the
/// crate's autograd rules.
///
/// Two situations are rejected:
/// 1. The tensor has a `grad_fn` (it is an interior node of a computation
///    graph) — mutating it would corrupt gradients flowing through it.
/// 2. The tensor is a leaf with `requires_grad = true` — the in-place
///    change would bypass autograd tracking entirely.
///
/// `op_name` is the user-facing name of the in-place operation, interpolated
/// into the error message so callers can see which call was refused.
///
/// # Errors
/// Returns `FerrotorchError::InvalidArgument` in either rejected case.
fn check_inplace_allowed<T: Float>(tensor: &Tensor<T>, op_name: &str) -> FerrotorchResult<()> {
    // Bind the grad_fn once instead of querying it twice (once for the
    // check and again inside the message).
    if let Some(gf) = tensor.grad_fn() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "in-place operation '{op_name}' not allowed on a tensor that is \
                 part of the computation graph (has grad_fn = {:?})",
                Some(gf.name()),
            ),
        });
    }
    if tensor.requires_grad() && tensor.is_leaf() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "in-place operation '{op_name}' not allowed on a leaf tensor \
                 with requires_grad=true (the modification would not be tracked \
                 by autograd)"
            ),
        });
    }
    Ok(())
}
impl<T: Float> Tensor<T> {
    /// Adds `value` to every element, in place.
    ///
    /// Returns `&Self` so in-place operations can be chained.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if the tensor is part of a computation
    /// graph or is a `requires_grad` leaf (see `check_inplace_allowed`),
    /// and propagates any failure from reading or writing the data.
    pub fn add_scalar_(&self, value: T) -> FerrotorchResult<&Self> {
        check_inplace_allowed(self, "add_scalar_")?;
        let mut data = self.data_vec()?;
        for x in data.iter_mut() {
            *x += value;
        }
        // Element count is unchanged, so the tensor's shape stays valid.
        // NOTE(review): `update_data`'s exact safety contract is declared
        // elsewhere — confirm that same-length rewrites satisfy it.
        unsafe { self.update_data(&data)? };
        Ok(self)
    }

    /// Multiplies every element by `value`, in place.
    ///
    /// # Errors
    /// Same autograd restrictions as [`Self::add_scalar_`].
    pub fn mul_scalar_(&self, value: T) -> FerrotorchResult<&Self> {
        check_inplace_allowed(self, "mul_scalar_")?;
        let mut data = self.data_vec()?;
        for x in data.iter_mut() {
            *x = *x * value;
        }
        unsafe { self.update_data(&data)? };
        Ok(self)
    }

    /// Overwrites every element with `value`, in place.
    ///
    /// # Errors
    /// Same autograd restrictions as [`Self::add_scalar_`].
    pub fn fill_(&self, value: T) -> FerrotorchResult<&Self> {
        check_inplace_allowed(self, "fill_")?;
        let new_data = vec![value; self.numel()];
        unsafe { self.update_data(&new_data)? };
        Ok(self)
    }

    /// Sets every element to zero, in place.
    ///
    /// # Errors
    /// Same autograd restrictions as [`Self::add_scalar_`].
    pub fn zero_(&self) -> FerrotorchResult<&Self> {
        // Run the autograd check under this method's own name first, so a
        // rejected call reports 'zero_' rather than the 'fill_' it
        // delegates to. `fill_` then re-runs a check that is guaranteed to
        // pass, which is harmless.
        check_inplace_allowed(self, "zero_")?;
        self.fill_(<T as num_traits::Zero>::zero())
    }

    /// Clamps every element into the closed interval `[min, max]`, in place.
    ///
    /// Elements comparing neither below `min` nor above `max` are left
    /// untouched; for IEEE floats that means NaN elements pass through
    /// unchanged (both comparisons are false).
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `min > max`, plus the same autograd
    /// restrictions as [`Self::add_scalar_`].
    ///
    /// NOTE(review): a NaN `min`/`max` is not caught by the `min > max`
    /// validation (NaN comparisons are false) — consider rejecting NaN
    /// bounds explicitly.
    pub fn clamp_(&self, min: T, max: T) -> FerrotorchResult<&Self> {
        // Argument validation runs before the autograd check, so a caller
        // who passed a bad range sees that error first.
        if min > max {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("clamp_ requires min <= max, got min={min:?}, max={max:?}"),
            });
        }
        check_inplace_allowed(self, "clamp_")?;
        let mut data = self.data_vec()?;
        for x in data.iter_mut() {
            if *x < min {
                *x = min;
            } else if *x > max {
                *x = max;
            }
        }
        unsafe { self.update_data(&data)? };
        Ok(self)
    }
}
#[cfg(test)]
mod tests {
    use crate::storage::TensorStorage;
    use crate::tensor::Tensor;

    #[test]
    fn test_add_scalar_basic() {
        let tensor =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]), vec![3], false)
                .unwrap();
        tensor.add_scalar_(10.0).unwrap();
        assert_eq!(tensor.data().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_add_scalar_negative() {
        let tensor =
            Tensor::from_storage(TensorStorage::cpu(vec![5.0f64, 10.0]), vec![2], false).unwrap();
        tensor.add_scalar_(-3.0).unwrap();
        let values = tensor.data().unwrap();
        for (got, want) in values.iter().zip([2.0, 7.0]) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_add_scalar_chaining() {
        let tensor =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f32; 4]), vec![2, 2], false).unwrap();
        // `add_scalar_` returns &Self, so in-place calls chain.
        tensor.add_scalar_(1.0).unwrap().add_scalar_(2.0).unwrap();
        assert_eq!(tensor.data().unwrap(), &[3.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_add_scalar_rejects_requires_grad_leaf() {
        let tensor =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0, 2.0]), vec![2], true).unwrap();
        let msg = format!("{}", tensor.add_scalar_(1.0).unwrap_err());
        assert!(msg.contains("requires_grad=true"), "got: {msg}");
    }

    #[test]
    fn test_mul_scalar_basic() {
        let tensor =
            Tensor::from_storage(TensorStorage::cpu(vec![2.0f32, 3.0, 4.0]), vec![3], false)
                .unwrap();
        tensor.mul_scalar_(0.5).unwrap();
        assert_eq!(tensor.data().unwrap(), &[1.0, 1.5, 2.0]);
    }

    #[test]
    fn test_mul_scalar_zero() {
        let input = vec![42.0f64, -7.0, 100.0];
        let tensor = Tensor::from_storage(TensorStorage::cpu(input), vec![3], false).unwrap();
        tensor.mul_scalar_(0.0).unwrap();
        assert_eq!(tensor.data().unwrap(), &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_mul_scalar_rejects_requires_grad_leaf() {
        let tensor =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0]), vec![1], true).unwrap();
        assert!(tensor.mul_scalar_(2.0).is_err());
    }

    #[test]
    fn test_fill_basic() {
        let tensor = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
            vec![2, 2],
            false,
        )
        .unwrap();
        tensor.fill_(99.0).unwrap();
        assert_eq!(tensor.data().unwrap(), &[99.0; 4]);
    }

    #[test]
    fn test_fill_scalar_tensor() {
        // Rank-0 (scalar) tensor: the shape vector is empty.
        let tensor = Tensor::from_storage(TensorStorage::cpu(vec![0.0f32]), vec![], false).unwrap();
        tensor.fill_(42.0).unwrap();
        assert_eq!(tensor.item().unwrap(), 42.0);
    }

    #[test]
    fn test_fill_rejects_requires_grad_leaf() {
        let tensor =
            Tensor::<f64>::from_storage(TensorStorage::cpu(vec![1.0, 2.0]), vec![2], true).unwrap();
        assert!(tensor.fill_(0.0).is_err());
    }

    #[test]
    fn test_zero_basic() {
        let tensor =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]), vec![3], false)
                .unwrap();
        tensor.zero_().unwrap();
        assert_eq!(tensor.data().unwrap(), &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_zero_empty_tensor() {
        let empty: Vec<f32> = Vec::new();
        let tensor = Tensor::from_storage(TensorStorage::cpu(empty), vec![0], false).unwrap();
        tensor.zero_().unwrap();
        assert_eq!(tensor.numel(), 0);
    }

    #[test]
    fn test_zero_rejects_requires_grad_leaf() {
        let tensor =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0]), vec![1], true).unwrap();
        assert!(tensor.zero_().is_err());
    }

    #[test]
    fn test_clamp_basic() {
        let tensor = Tensor::from_storage(
            TensorStorage::cpu(vec![-5.0f32, 0.0, 3.0, 10.0, 100.0]),
            vec![5],
            false,
        )
        .unwrap();
        tensor.clamp_(0.0, 10.0).unwrap();
        assert_eq!(tensor.data().unwrap(), &[0.0, 0.0, 3.0, 10.0, 10.0]);
    }

    #[test]
    fn test_clamp_all_within_range() {
        let tensor =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f64, 2.0, 3.0]), vec![3], false)
                .unwrap();
        tensor.clamp_(0.0, 10.0).unwrap();
        assert_eq!(tensor.data().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_clamp_single_value_range() {
        // min == max collapses every element to that single value.
        let tensor = Tensor::from_storage(
            TensorStorage::cpu(vec![-1.0f32, 0.0, 1.0, 5.0]),
            vec![4],
            false,
        )
        .unwrap();
        tensor.clamp_(3.0, 3.0).unwrap();
        assert_eq!(tensor.data().unwrap(), &[3.0; 4]);
    }

    #[test]
    fn test_clamp_invalid_range() {
        let tensor =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0]), vec![2], false).unwrap();
        let msg = format!("{}", tensor.clamp_(10.0, 0.0).unwrap_err());
        assert!(msg.contains("min <= max"), "got: {msg}");
    }

    #[test]
    fn test_clamp_rejects_requires_grad_leaf() {
        let tensor =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0, 2.0]), vec![2], true).unwrap();
        assert!(tensor.clamp_(0.0, 1.0).is_err());
    }

    #[test]
    fn test_detached_tensor_allows_inplace() {
        let tracked =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]), vec![3], true)
                .unwrap();
        let detached = tracked.detach();
        assert!(!detached.requires_grad());
        detached.add_scalar_(10.0).unwrap();
        assert_eq!(detached.data().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_mixed_inplace_chaining() {
        let tensor = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
            vec![4],
            false,
        )
        .unwrap();
        // (x + 10) * 2, then clamped into [20, 25].
        tensor
            .add_scalar_(10.0)
            .unwrap()
            .mul_scalar_(2.0)
            .unwrap()
            .clamp_(20.0, 25.0)
            .unwrap();
        assert_eq!(tensor.data().unwrap(), &[22.0, 24.0, 25.0, 25.0]);
    }

    #[test]
    fn test_inplace_ops_f64() {
        let tensor =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f64, 2.0, 3.0]), vec![3], false)
                .unwrap();
        tensor.add_scalar_(100.0).unwrap();
        tensor.mul_scalar_(0.1).unwrap();
        let values = tensor.data().unwrap();
        for (got, want) in values.iter().zip([10.1, 10.2, 10.3]) {
            assert!((got - want).abs() < 1e-10);
        }
    }
}