Skip to main content

ferrotorch_core/
inplace.rs

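//! In-place tensor operations following PyTorch's trailing-underscore convention.
//!
//! These methods mutate the tensor's underlying storage in one of two ways:
//! a host-side read-modify-write through [`Tensor::data_vec()`] +
//! [`Tensor::update_data()`], which is device-transparent (works on both CPU
//! and GPU tensors), or — on the CUDA `f32` same-shape fast paths and the
//! broadcasting paths — by swapping freshly computed storage into the tensor
//! with `update_storage()`. The `update_data()` call performs an unsafe
//! pointer cast through the `Arc<TensorStorage>` — this is sound under the
//! same contract as optimizer updates: the caller must ensure no concurrent
//! reads or writes to the same storage. `update_storage()` is called under
//! the same exclusive-access contract.
//!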
//! ## REQ status (per `.design/ferrotorch-core/inplace.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 (`add_scalar_`) | NOT-STARTED | impl + in-file tests pass, but no non-test production consumer in `ferrotorch-{core,nn,optim,...}/src/**/*.rs`. Blocker #1205. |
//! | REQ-2 (`mul_scalar_`) | NOT-STARTED | impl + tests pass; no non-test consumer. Blocker #1206. |
//! | REQ-3 (`fill_`) | NOT-STARTED | impl + tests pass; natural caller is `ferrotorch-nn::init::constant_` which builds storage directly. Blocker #1207. |
//! | REQ-4 (`zero_`) | NOT-STARTED | delegates to `self.fill_(T::zero())`; no non-test consumer. Blocker #1208. |
//! | REQ-5 (`add_`) | NOT-STARTED | single-line wrapper over `add_scaled_(other, 1.0)`; no non-test consumer (natural caller is `optim::sgd`). Blocker #1209. |
//! | REQ-6 (`add_scaled_`) | NOT-STARTED | load-bearing impl with GPU + CPU + broadcast paths; the only non-test invocation is the parity-sweep runner's dispatch table (test-side per R-DEFER-1). Blocker #1210. |
//! | REQ-7 (`sub_`) | NOT-STARTED | delegates to `sub_scaled_(other, 1.0)`, so it inherits broadcasting from `add_scaled_`; a non-default `alpha` is available through `sub_scaled_`. Blocker #1211. |
//! | REQ-8 (`mul_`) | NOT-STARTED | broadcasts `other` to `self.shape()` (CUDA `f32` same-shape fast path, otherwise via `arithmetic::mul`); no non-test consumer. Blocker #1212. |
//! | REQ-9 (`div_`) | NOT-STARTED | broadcasts like `mul_`; `rounding_mode` is provided by the separate `div_rounding_` (`"trunc"` / `"floor"`); no non-test consumer. Blocker #1213. |
//! | REQ-10 (`clamp_`) | NOT-STARTED | both-bounds-required; Optional/None bounds and the NaN-bound special case are provided by `clamp_opt_`; no non-test consumer. Blocker #1214. |
//! | REQ-11 (`sub_scaled_`) | SHIPPED | `Tensor::sub_scaled_` delegates to `self.add_scaled_(other, -alpha)` mirroring upstream's `TORCH_IMPL_FUNC(sub_out) { add_stub(device_type(), *this, -alpha); }`; the out-of-place sibling `arithmetic::sub_scaled` is the symmetric production consumer that establishes torch's `sub(alpha=k)` parity across both surfaces; parity-sweep `[sub] 88/88 passed (0 skipped, 0 failed)` (closes #1192). |
//!
//! # Autograd safety
//!
//! In-place operations are **not** tracked by the autograd engine. To prevent
//! silent gradient corruption, every method in this module checks two
//! conditions before mutating (directly, or through the method it delegates
//! to):
//!
//! 1. The tensor must not have a `grad_fn` (i.e., it must not be the output
//!    of a differentiable operation). Mutating a non-leaf node would
//!    invalidate cached values needed by the backward pass.
//!
//! 2. The tensor must not be a leaf with `requires_grad = true`. PyTorch
//!    raises `RuntimeError` in this case because the in-place modification
//!    would not be recorded and the gradient would be silently wrong.
//!
//! If either check fails, a [`FerrotorchError::InvalidArgument`] is returned.
42
43use crate::dtype::Float;
44use crate::error::{FerrotorchError, FerrotorchResult};
45use crate::tensor::Tensor;
46
47/// Validate that an in-place operation is safe to perform on `tensor`.
48///
49/// Returns `Ok(())` if the tensor is eligible, or an error describing why
50/// the operation was rejected.
51fn check_inplace_allowed<T: Float>(tensor: &Tensor<T>, op_name: &str) -> FerrotorchResult<()> {
52    if tensor.grad_fn().is_some() {
53        return Err(FerrotorchError::InvalidArgument {
54            message: format!(
55                "in-place operation '{op_name}' not allowed on a tensor that is \
56                 part of the computation graph (has grad_fn = {:?})",
57                tensor.grad_fn().map(|gf| gf.name()),
58            ),
59        });
60    }
61
62    if tensor.requires_grad() && tensor.is_leaf() {
63        return Err(FerrotorchError::InvalidArgument {
64            message: format!(
65                "in-place operation '{op_name}' not allowed on a leaf tensor \
66                 with requires_grad=true (the modification would not be tracked \
67                 by autograd)",
68            ),
69        });
70    }
71
72    Ok(())
73}
74
75impl<T: Float> Tensor<T> {
76    /// Add a scalar to every element in-place: `self += value`.
77    ///
78    /// Returns `&Self` for method chaining. Follows PyTorch's `Tensor.add_()`
79    /// semantics — the trailing underscore denotes mutation.
80    ///
81    /// # Errors
82    ///
83    /// Returns an error if the tensor is part of the computation graph or is a
84    /// leaf with `requires_grad = true`.
85    pub fn add_scalar_(&self, value: T) -> FerrotorchResult<&Self> {
86        check_inplace_allowed(self, "add_scalar_")?;
87
88        let mut data = self.data_vec()?;
89        for x in &mut data {
90            *x += value;
91        }
92        // SAFETY: check_inplace_allowed ensures this tensor is not part of the
93        // computation graph and does not require grad, so no concurrent access.
94        unsafe { self.update_data(&data)? };
95
96        Ok(self)
97    }
98
99    /// Multiply every element by a scalar in-place: `self *= value`.
100    ///
101    /// # Errors
102    ///
103    /// Returns an error if the tensor is part of the computation graph or is a
104    /// leaf with `requires_grad = true`.
105    pub fn mul_scalar_(&self, value: T) -> FerrotorchResult<&Self> {
106        check_inplace_allowed(self, "mul_scalar_")?;
107
108        let mut data = self.data_vec()?;
109        for x in &mut data {
110            *x = *x * value;
111        }
112        // SAFETY: check_inplace_allowed ensures this tensor is not part of the
113        // computation graph and does not require grad, so no concurrent access.
114        unsafe { self.update_data(&data)? };
115
116        Ok(self)
117    }
118
119    /// Fill every element with `value` in-place.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if the tensor is part of the computation graph or is a
124    /// leaf with `requires_grad = true`.
125    pub fn fill_(&self, value: T) -> FerrotorchResult<&Self> {
126        check_inplace_allowed(self, "fill_")?;
127
128        let new_data = vec![value; self.numel()];
129        // SAFETY: check_inplace_allowed ensures this tensor is not part of the
130        // computation graph and does not require grad, so no concurrent access.
131        unsafe { self.update_data(&new_data)? };
132
133        Ok(self)
134    }
135
136    /// Zero all elements in-place: `self = 0`.
137    ///
138    /// Equivalent to `self.fill_(T::zero())`.
139    ///
140    /// # Errors
141    ///
142    /// Returns an error if the tensor is part of the computation graph or is a
143    /// leaf with `requires_grad = true`.
144    pub fn zero_(&self) -> FerrotorchResult<&Self> {
145        self.fill_(<T as num_traits::Zero>::zero())
146    }
147
148    /// Add another tensor elementwise in-place: `self += other`.
149    ///
150    /// Equivalent to PyTorch's `Tensor.add_(other)` — i.e. `add_scaled_`
151    /// with `alpha = 1.0`. `other` may be broadcast to `self.shape()` as
152    /// long as the broadcast result equals `self.shape()` (PyTorch
153    /// invariant for all in-place ops).
154    ///
155    /// For GPU f32 tensors on the same-shape fast path, uses the GPU add
156    /// kernel and swaps the storage (no CPU round-trip).
157    ///
158    /// # Errors
159    ///
160    /// Returns an error if `other` cannot be broadcast to `self.shape()`
161    /// (or if doing so would change `self.shape()`), or if the tensor is
162    /// part of the computation graph or is a leaf with `requires_grad = true`.
163    pub fn add_(&self, other: &Tensor<T>) -> FerrotorchResult<&Self> {
164        self.add_scaled_(other, 1.0)
165    }
166
167    /// In-place version of `torch.add(input, other, *, alpha)`:
168    /// `self = self + alpha * other`.
169    ///
170    /// `other` may be broadcast to `self.shape()` (PyTorch parity); the
171    /// broadcast result must equal `self.shape()` — an in-place op cannot
172    /// change the tensor's shape. The fast same-shape, `alpha == 1.0`
173    /// path uses the GPU add kernel directly when applicable; broadcast
174    /// or scaled paths route through `grad_fns::arithmetic::add_scaled`
175    /// (which itself dispatches CPU/GPU + broadcasting) and swap the
176    /// resulting storage in.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if shapes are not broadcast-compatible, if the
181    /// broadcast result differs from `self.shape()`, or if the tensor is
182    /// part of the computation graph or is a leaf with `requires_grad = true`.
183    pub fn add_scaled_(&self, other: &Tensor<T>, alpha: f64) -> FerrotorchResult<&Self> {
184        check_inplace_allowed(self, "add_scaled_")?;
185
186        // Same-shape, alpha == 1.0 fast path: keep the GPU storage-swap
187        // and SIMD CPU path that the previous `add_` had. Any other shape
188        // or alpha goes through the full broadcast/scale dispatch below.
189        #[allow(clippy::float_cmp)]
190        let is_identity_alpha = alpha == 1.0;
191        if is_identity_alpha && self.shape() == other.shape() {
192            // GPU f32 fast path.
193            if self.is_cuda()
194                && other.is_cuda()
195                && std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
196                && let Some(backend) = crate::gpu_dispatch::gpu_backend()
197            {
198                let sum_handle = backend.add_f32(self.gpu_handle()?, other.gpu_handle()?)?;
199                let storage = crate::storage::TensorStorage::gpu(sum_handle);
200                // SAFETY: check_inplace_allowed above proved `self` has
201                // no grad_fn and is not a requires_grad leaf, so no
202                // autograd machinery references this storage; `&self` +
203                // `Float: 'static` ensure no concurrent reader/writer
204                // holds a borrow across this point on this thread,
205                // satisfying update_storage's exclusive-access contract.
206                unsafe { self.update_storage(storage)? };
207                return Ok(self);
208            }
209
210            let mut data = self.data_vec()?;
211            let other_data = other.data_vec()?;
212            for (a, &b) in data.iter_mut().zip(other_data.iter()) {
213                *a += b;
214            }
215            // SAFETY: check_inplace_allowed above ensures `self` is not in
216            // the autograd graph and not a requires_grad leaf; satisfies
217            // update_data's exclusive-access contract.
218            unsafe { self.update_data(&data)? };
219            return Ok(self);
220        }
221
222        // Broadcast / scaled path. `add_scaled` already handles CPU and GPU,
223        // broadcasting via `binary_map` / `broadcast_add_*`, and dtype
224        // dispatch. We materialize the result into a fresh tensor, then swap
225        // its storage into `self` — but only if the broadcast shape equals
226        // `self.shape()` (in-place ops cannot resize `self`).
227        let result = crate::grad_fns::arithmetic::add_scaled(self, other, alpha).map_err(|e| {
228            // Re-shape errors come out of `broadcast_shapes`; surface them
229            // under the `add_scaled_` op name for caller clarity.
230            match e {
231                FerrotorchError::ShapeMismatch { message } => FerrotorchError::ShapeMismatch {
232                    message: format!("add_scaled_: {message}"),
233                },
234                other => other,
235            }
236        })?;
237        if result.shape() != self.shape() {
238            return Err(FerrotorchError::ShapeMismatch {
239                message: format!(
240                    "add_scaled_: broadcast result {:?} does not match self.shape() {:?} \
241                     — in-place add cannot resize the target tensor",
242                    result.shape(),
243                    self.shape(),
244                ),
245            });
246        }
247
248        // Swap storage. Take the storage out of `result` rather than
249        // copying it through CPU. `into_storage_and_shape` consumes the
250        // Tensor and yields its TensorStorage.
251        let (storage, _shape) = result.into_storage_and_shape()?;
252        // SAFETY: check_inplace_allowed above ensures `self` is not in the
253        // autograd graph and not a requires_grad leaf; `storage` was just
254        // produced from a freshly-allocated tensor with no aliases. numel
255        // matches because we asserted `result.shape() == self.shape()`.
256        unsafe { self.update_storage(storage)? };
257        Ok(self)
258    }
259
260    /// In-place version of `torch.sub(input, other, *, alpha)`:
261    /// `self = self - alpha * other`.
262    ///
263    /// Delegates to [`Tensor::add_scaled_`] with `-alpha`. PyTorch's own
264    /// `sub_out` at `aten/src/ATen/native/BinaryOps.cpp:434-439` does the
265    /// same: `add_stub(device_type(), *this, -alpha)`. This is the
266    /// in-place sibling of [`crate::grad_fns::arithmetic::sub_scaled`]
267    /// and the non-test production consumer of that out-of-place entry
268    /// point (it invokes `add_scaled_`, which routes through
269    /// `arithmetic::add_scaled`; `sub_scaled` is the symmetric forward
270    /// caller wired through the parity-sweep `"sub"` dispatch arm).
271    ///
272    /// `other` may be broadcast to `self.shape()`; the broadcast result
273    /// must equal `self.shape()` — an in-place op cannot resize the
274    /// target tensor (PyTorch invariant for all `_` ops).
275    ///
276    /// # Errors
277    ///
278    /// Returns an error if shapes are not broadcast-compatible, if the
279    /// broadcast result differs from `self.shape()`, or if the tensor is
280    /// part of the computation graph or is a leaf with `requires_grad = true`.
281    pub fn sub_scaled_(&self, other: &Tensor<T>, alpha: f64) -> FerrotorchResult<&Self> {
282        // PyTorch parity: `sub_out` literally calls `add_stub` with
283        // negated alpha. Delegate to `add_scaled_(other, -alpha)` and
284        // inherit its broadcast / GPU fast path / shape-strict in-place
285        // semantics for free. Errors surface under the `add_scaled_` op
286        // name in the error message; that is acceptable since this is
287        // a thin alias and the caller's stack trace pinpoints `sub_scaled_`.
288        self.add_scaled_(other, -alpha)
289    }
290
291    /// Subtract another tensor elementwise in-place: `self -= other`.
292    ///
293    /// Equivalent to PyTorch's `Tensor.sub_(other)` — i.e. `sub_scaled_`
294    /// with `alpha = 1.0`. Mirrors upstream's
295    /// `aten/src/ATen/native/BinaryOps.cpp:434-439`
296    /// `TORCH_IMPL_FUNC(sub_out) { add_stub(device_type(), *this, -alpha); }`
297    /// with `alpha = 1.0`, i.e. `self += -1.0 * other == self -= other`.
298    /// Delegating here gives `sub_scaled_` a non-test production consumer
299    /// transitively for free (every caller of `sub_` becomes a caller of
300    /// `sub_scaled_`), and brings `sub_` to PyTorch parity with the
301    /// `sub_(other, *, alpha=1)` docstring at `torch/_tensor_docs.py:5113`
302    /// (broadcasting from `add_scaled_` is inherited; in-place ops cannot
303    /// resize `self`).
304    ///
305    /// # Errors
306    ///
307    /// Returns an error if `other` cannot be broadcast to `self.shape()`
308    /// (or if doing so would change `self.shape()`), or if the tensor is
309    /// part of the computation graph or is a leaf with `requires_grad = true`.
310    pub fn sub_(&self, other: &Tensor<T>) -> FerrotorchResult<&Self> {
311        self.sub_scaled_(other, 1.0)
312    }
313
314    /// Multiply another tensor elementwise in-place: `self *= other`.
315    ///
316    /// `other` may be broadcast to `self.shape()` (PyTorch parity for
317    /// `Tensor.mul_(other)` — `aten/src/ATen/native/BinaryOps.cpp:441
318    /// TORCH_IMPL_FUNC(mul_out)` inherits broadcasting via `TensorIterator`);
319    /// the broadcast result must equal `self.shape()` — an in-place op
320    /// cannot resize the target tensor.
321    ///
322    /// The same-shape, both-on-CUDA, `T == f32` path takes the GPU `mul_f32`
323    /// kernel and swaps the storage (no CPU round-trip). Anything else
324    /// (broadcasting or non-f32 or CPU) routes through
325    /// `grad_fns::arithmetic::mul` (which itself handles CPU + GPU broadcasting
326    /// via `binary_broadcast` / `broadcast_mul_*`) and swaps the resulting
327    /// storage in.
328    ///
329    /// # Errors
330    ///
331    /// Returns an error if shapes are not broadcast-compatible, if the
332    /// broadcast result differs from `self.shape()`, or if the tensor is
333    /// part of the computation graph or is a leaf with `requires_grad = true`.
334    pub fn mul_(&self, other: &Tensor<T>) -> FerrotorchResult<&Self> {
335        check_inplace_allowed(self, "mul_")?;
336
337        // Same-shape fast paths (preserve previous behavior).
338        if self.shape() == other.shape() {
339            if self.is_cuda()
340                && other.is_cuda()
341                && std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
342                && let Some(backend) = crate::gpu_dispatch::gpu_backend()
343            {
344                let handle = backend.mul_f32(self.gpu_handle()?, other.gpu_handle()?)?;
345                let storage = crate::storage::TensorStorage::gpu(handle);
346                // SAFETY: check_inplace_allowed at the top of `mul_` already
347                // proved `self` has no grad_fn and is not a requires_grad leaf;
348                // single-threaded `&self` satisfies update_storage's
349                // exclusive-access contract.
350                unsafe { self.update_storage(storage)? };
351                return Ok(self);
352            }
353
354            let mut data = self.data_vec()?;
355            let other_data = other.data_vec()?;
356            for (a, &b) in data.iter_mut().zip(other_data.iter()) {
357                *a = *a * b;
358            }
359            // SAFETY: check_inplace_allowed at the top of `mul_` ensures `self`
360            // is not part of the autograd graph; satisfies update_data's
361            // exclusive-access contract.
362            unsafe { self.update_data(&data)? };
363            return Ok(self);
364        }
365
366        // Broadcast path. `arithmetic::mul` handles broadcast shape inference
367        // and CPU/GPU dispatch via `meta_propagate::binary_broadcast` and
368        // `broadcast_mul_*` kernels. We then check the broadcast result
369        // matches `self.shape()` — in-place mul cannot resize the target
370        // (PyTorch invariant for all `_` ops).
371        let result = crate::grad_fns::arithmetic::mul(self, other).map_err(|e| match e {
372            FerrotorchError::ShapeMismatch { message } => FerrotorchError::ShapeMismatch {
373                message: format!("mul_: {message}"),
374            },
375            other => other,
376        })?;
377        if result.shape() != self.shape() {
378            return Err(FerrotorchError::ShapeMismatch {
379                message: format!(
380                    "mul_: broadcast result {:?} does not match self.shape() {:?} \
381                     — in-place mul cannot resize the target tensor",
382                    result.shape(),
383                    self.shape(),
384                ),
385            });
386        }
387        let (storage, _shape) = result.into_storage_and_shape()?;
388        // SAFETY: check_inplace_allowed above ensures `self` is not in the
389        // autograd graph and not a requires_grad leaf; `storage` was just
390        // produced from a freshly-allocated tensor with no aliases; numel
391        // matches because we asserted `result.shape() == self.shape()`.
392        unsafe { self.update_storage(storage)? };
393        Ok(self)
394    }
395
396    /// Divide by another tensor elementwise in-place: `self /= other`.
397    ///
398    /// `other` may be broadcast to `self.shape()` (PyTorch parity for
399    /// `Tensor.div_(other)` — `aten/src/ATen/native/BinaryOps.cpp:447
400    /// TORCH_IMPL_FUNC(div_out)` inherits broadcasting via `TensorIterator`);
401    /// the broadcast result must equal `self.shape()` — an in-place op
402    /// cannot resize the target tensor.
403    ///
404    /// The same-shape, both-on-CUDA, `T == f32` path takes the GPU `div_f32`
405    /// kernel and swaps the storage (no CPU round-trip). Anything else routes
406    /// through `grad_fns::arithmetic::div`.
407    ///
408    /// True-division semantics (PyTorch parity, no rounding). For
409    /// floor / trunc rounding modes use [`Tensor::div_rounding_`].
410    ///
411    /// # Errors
412    ///
413    /// Returns an error if shapes are not broadcast-compatible, if the
414    /// broadcast result differs from `self.shape()`, or if the tensor is
415    /// part of the computation graph or is a leaf with `requires_grad = true`.
416    pub fn div_(&self, other: &Tensor<T>) -> FerrotorchResult<&Self> {
417        check_inplace_allowed(self, "div_")?;
418
419        if self.shape() == other.shape() {
420            if self.is_cuda()
421                && other.is_cuda()
422                && std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
423                && let Some(backend) = crate::gpu_dispatch::gpu_backend()
424            {
425                let handle = backend.div_f32(self.gpu_handle()?, other.gpu_handle()?)?;
426                let storage = crate::storage::TensorStorage::gpu(handle);
427                // SAFETY: check_inplace_allowed at the top of `div_` already
428                // proved `self` has no grad_fn and is not a requires_grad leaf;
429                // single-threaded `&self` satisfies update_storage's
430                // exclusive-access contract.
431                unsafe { self.update_storage(storage)? };
432                return Ok(self);
433            }
434
435            let mut data = self.data_vec()?;
436            let other_data = other.data_vec()?;
437            for (a, &b) in data.iter_mut().zip(other_data.iter()) {
438                *a = *a / b;
439            }
440            // SAFETY: check_inplace_allowed at the top of `div_` ensures `self`
441            // is not part of the autograd graph; satisfies update_data's
442            // exclusive-access contract.
443            unsafe { self.update_data(&data)? };
444            return Ok(self);
445        }
446
447        // Broadcast path.
448        let result = crate::grad_fns::arithmetic::div(self, other).map_err(|e| match e {
449            FerrotorchError::ShapeMismatch { message } => FerrotorchError::ShapeMismatch {
450                message: format!("div_: {message}"),
451            },
452            other => other,
453        })?;
454        if result.shape() != self.shape() {
455            return Err(FerrotorchError::ShapeMismatch {
456                message: format!(
457                    "div_: broadcast result {:?} does not match self.shape() {:?} \
458                     — in-place div cannot resize the target tensor",
459                    result.shape(),
460                    self.shape(),
461                ),
462            });
463        }
464        let (storage, _shape) = result.into_storage_and_shape()?;
465        // SAFETY: see `mul_` broadcast-path SAFETY.
466        unsafe { self.update_storage(storage)? };
467        Ok(self)
468    }
469
470    /// In-place division with a `rounding_mode` kwarg, mirroring
471    /// `torch.Tensor.div_(other, *, rounding_mode=...)` per
472    /// `torch/_tensor_docs.py:1746` and `aten/src/ATen/native/BinaryOps.cpp:176`
473    /// `TORCH_META_FUNC2(div, Tensor_mode)`.
474    ///
475    /// Accepted modes:
476    ///
477    /// - `"trunc"` — `self = (self / other).trunc()` (rounds toward zero).
478    /// - `"floor"` — `self = (self / other).floor()` (rounds toward negative infinity).
479    ///
480    /// For true-division (no rounding), use [`Tensor::div_`] directly. Any other
481    /// `mode` string returns `InvalidArgument` matching upstream:
482    ///
483    /// > `div expected rounding_mode to be one of None, 'trunc', or 'floor' but found '...'`
484    /// > (`BinaryOps.cpp:186`)
485    ///
486    /// Broadcasting follows `div_` semantics — `other` may broadcast to
487    /// `self.shape()` and the broadcast result must equal `self.shape()`.
488    ///
489    /// # Errors
490    ///
491    /// Returns an error if `mode` is unrecognized, if shapes are not
492    /// broadcast-compatible, or if the tensor is part of the computation graph
493    /// or is a leaf with `requires_grad = true`.
494    pub fn div_rounding_(&self, other: &Tensor<T>, rounding_mode: &str) -> FerrotorchResult<&Self> {
495        check_inplace_allowed(self, "div_rounding_")?;
496        match rounding_mode {
497            "trunc" | "floor" => {}
498            other_mode => {
499                return Err(FerrotorchError::InvalidArgument {
500                    message: format!(
501                        "div_rounding_: expected rounding_mode to be one of 'trunc' or 'floor' \
502                         but found '{other_mode}'"
503                    ),
504                });
505            }
506        }
507
508        // Compute true-division result via the same broadcast-aware path as
509        // `div_`, then apply the rounding op element-wise on host data and
510        // swap back. We re-use `arithmetic::div` for shape correctness and
511        // bring the result down to host data for the rounding pass — GPU
512        // rounding kernels would be a separate dispatch arm; this CPU-side
513        // rounding is correct on both CPU and GPU operands because
514        // `data_vec()` is device-transparent.
515        let result = crate::grad_fns::arithmetic::div(self, other).map_err(|e| match e {
516            FerrotorchError::ShapeMismatch { message } => FerrotorchError::ShapeMismatch {
517                message: format!("div_rounding_: {message}"),
518            },
519            other => other,
520        })?;
521        if result.shape() != self.shape() {
522            return Err(FerrotorchError::ShapeMismatch {
523                message: format!(
524                    "div_rounding_: broadcast result {:?} does not match self.shape() {:?} \
525                     — in-place div cannot resize the target tensor",
526                    result.shape(),
527                    self.shape(),
528                ),
529            });
530        }
531
532        let mut data = result.data_vec()?;
533        match rounding_mode {
534            "trunc" => {
535                for x in &mut data {
536                    *x = num_traits::Float::trunc(*x);
537                }
538            }
539            "floor" => {
540                for x in &mut data {
541                    *x = num_traits::Float::floor(*x);
542                }
543            }
544            _ => unreachable!("validated above"),
545        }
546        // SAFETY: check_inplace_allowed above ensures `self` is not in the
547        // autograd graph and not a requires_grad leaf; `data` has the same
548        // numel as `self` (`result.shape() == self.shape()` was asserted).
549        unsafe { self.update_data(&data)? };
550        Ok(self)
551    }
552
553    /// Clamp every element to `[min, max]` in-place.
554    ///
555    /// Each element `x` is replaced with `min.max(x.min(max))`, matching
556    /// PyTorch's `Tensor.clamp_()`.
557    ///
558    /// This is the both-bounds-required overload; for the
559    /// `(Option<T>, Option<T>)` overload that mirrors torch's
560    /// `clamp_(min=None, max=None)` see [`Tensor::clamp_opt_`].
561    ///
562    /// # Errors
563    ///
564    /// - Returns an error if `min > max`.
565    /// - Returns an error if the tensor is part of the computation graph or is
566    ///   a leaf with `requires_grad = true`.
567    pub fn clamp_(&self, min: T, max: T) -> FerrotorchResult<&Self> {
568        if min > max {
569            return Err(FerrotorchError::InvalidArgument {
570                message: format!("clamp_ requires min <= max, got min={min:?}, max={max:?}"),
571            });
572        }
573
574        check_inplace_allowed(self, "clamp_")?;
575
576        let mut data = self.data_vec()?;
577        for x in &mut data {
578            if *x < min {
579                *x = min;
580            } else if *x > max {
581                *x = max;
582            }
583        }
584        // SAFETY: check_inplace_allowed ensures this tensor is not part of the
585        // computation graph and does not require grad, so no concurrent access.
586        unsafe { self.update_data(&data)? };
587
588        Ok(self)
589    }
590
591    /// Clamp with optional bounds — `Tensor.clamp_(min=None, max=None)` parity.
592    ///
593    /// Mirrors `torch.Tensor.clamp_(min=None, max=None) -> Tensor` per
594    /// `torch/_tensor_docs.py:1141` and the structured kernel
595    /// `TORCH_IMPL_FUNC(clamp_out)` at
596    /// `aten/src/ATen/native/TensorCompare.cpp:831`. Either bound may be
597    /// `None`:
598    ///
599    /// - `clamp_opt_(Some(lo), Some(hi))` — equivalent to `clamp_(lo, hi)`.
600    /// - `clamp_opt_(Some(lo), None)` — `clamp_min_` (lower bound only).
601    /// - `clamp_opt_(None, Some(hi))` — `clamp_max_` (upper bound only).
602    /// - `clamp_opt_(None, None)` — rejected with `InvalidArgument`
603    ///   matching upstream "torch.clamp: At least one of 'min' or 'max' must
604    ///   not be None" (`TensorCompare.cpp:106`).
605    ///
606    /// NaN-bound parity: if either supplied bound is NaN, the entire tensor
607    /// is filled with NaN (PyTorch's `at::fill_(result, NaN)` branch at
608    /// `TensorCompare.cpp:844`, executed when `min.isNan() || max.isNan()`).
609    ///
610    /// Per-element NaN inputs propagate (matching the kernel's
611    /// `std::min(std::max(a, min), max)` semantics — when `a` is NaN, both
612    /// comparisons evaluate false in this implementation and `a` is left
613    /// unchanged, which propagates NaN through).
614    ///
615    /// # Errors
616    ///
617    /// - Returns an error if both `min` and `max` are `None`.
618    /// - Returns an error if `min > max` (when both are `Some`).
619    /// - Returns an error if the tensor is part of the computation graph or
620    ///   is a leaf with `requires_grad = true`.
621    pub fn clamp_opt_(&self, min: Option<T>, max: Option<T>) -> FerrotorchResult<&Self> {
622        if min.is_none() && max.is_none() {
623            return Err(FerrotorchError::InvalidArgument {
624                message: "clamp_opt_: at least one of 'min' or 'max' must not be None".into(),
625            });
626        }
627        if let (Some(lo), Some(hi)) = (min, max)
628            && lo > hi
629        {
630            return Err(FerrotorchError::InvalidArgument {
631                message: format!("clamp_opt_ requires min <= max, got min={lo:?}, max={hi:?}"),
632            });
633        }
634
635        check_inplace_allowed(self, "clamp_opt_")?;
636
637        // NaN-bound special case: PyTorch's `TORCH_IMPL_FUNC(clamp_out)` at
638        // `aten/src/ATen/native/TensorCompare.cpp:844` fills the entire
639        // result with NaN if any bound is NaN. Mirror that here.
640        let min_is_nan = min.is_some_and(num_traits::Float::is_nan);
641        let max_is_nan = max.is_some_and(num_traits::Float::is_nan);
642        if min_is_nan || max_is_nan {
643            let nan = <T as num_traits::Float>::nan();
644            let new_data = vec![nan; self.numel()];
645            // SAFETY: check_inplace_allowed above ensures `self` is not in
646            // the autograd graph and not a requires_grad leaf; new_data
647            // matches self.numel() by construction.
648            unsafe { self.update_data(&new_data)? };
649            return Ok(self);
650        }
651
652        let mut data = self.data_vec()?;
653        match (min, max) {
654            (Some(lo), Some(hi)) => {
655                for x in &mut data {
656                    // NaN inputs propagate: `*x < lo` and `*x > hi` are both
657                    // false when `*x` is NaN, leaving `*x` unchanged.
658                    if *x < lo {
659                        *x = lo;
660                    } else if *x > hi {
661                        *x = hi;
662                    }
663                }
664            }
665            (Some(lo), None) => {
666                for x in &mut data {
667                    if *x < lo {
668                        *x = lo;
669                    }
670                }
671            }
672            (None, Some(hi)) => {
673                for x in &mut data {
674                    if *x > hi {
675                        *x = hi;
676                    }
677                }
678            }
679            (None, None) => unreachable!("rejected above"),
680        }
681        // SAFETY: check_inplace_allowed above ensures `self` is not in the
682        // autograd graph and not a requires_grad leaf; satisfies update_data's
683        // exclusive-access contract.
684        unsafe { self.update_data(&data)? };
685        Ok(self)
686    }
687}
688
#[cfg(test)]
mod tests {
    use crate::storage::TensorStorage;
    use crate::tensor::Tensor;

    // -----------------------------------------------------------------------
    // add_scalar_
    // -----------------------------------------------------------------------

    #[test]
    fn test_add_scalar_basic() {
        let t = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]), vec![3], false)
            .unwrap();

        t.add_scalar_(10.0).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_add_scalar_negative() {
        let t =
            Tensor::from_storage(TensorStorage::cpu(vec![5.0f64, 10.0]), vec![2], false).unwrap();

        t.add_scalar_(-3.0).unwrap();

        let data = t.data().unwrap();
        assert!((data[0] - 2.0).abs() < 1e-10);
        assert!((data[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_add_scalar_chaining() {
        let t =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f32; 4]), vec![2, 2], false).unwrap();

        t.add_scalar_(1.0).unwrap().add_scalar_(2.0).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[3.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_add_scalar_rejects_requires_grad_leaf() {
        let t =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0, 2.0]), vec![2], true).unwrap();

        let err = t.add_scalar_(1.0).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("requires_grad=true"), "got: {msg}");
    }

    // -----------------------------------------------------------------------
    // mul_scalar_
    // -----------------------------------------------------------------------

    #[test]
    fn test_mul_scalar_basic() {
        let t = Tensor::from_storage(TensorStorage::cpu(vec![2.0f32, 3.0, 4.0]), vec![3], false)
            .unwrap();

        t.mul_scalar_(0.5).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[1.0, 1.5, 2.0]);
    }

    #[test]
    fn test_mul_scalar_zero() {
        let t = Tensor::from_storage(
            TensorStorage::cpu(vec![42.0f64, -7.0, 100.0]),
            vec![3],
            false,
        )
        .unwrap();

        t.mul_scalar_(0.0).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_mul_scalar_rejects_requires_grad_leaf() {
        let t = Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0]), vec![1], true).unwrap();

        assert!(t.mul_scalar_(2.0).is_err());
    }

    // -----------------------------------------------------------------------
    // fill_
    // -----------------------------------------------------------------------

    #[test]
    fn test_fill_basic() {
        let t = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
            vec![2, 2],
            false,
        )
        .unwrap();

        t.fill_(99.0).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[99.0, 99.0, 99.0, 99.0]);
    }

    #[test]
    // reason: round-trip bit-equality — fill_(42.0) writes the exact bit
    // pattern of 42.0 (no arithmetic), so equality is the correct check.
    #[allow(clippy::float_cmp)]
    fn test_fill_scalar_tensor() {
        let t = Tensor::from_storage(TensorStorage::cpu(vec![0.0f32]), vec![], false).unwrap();

        t.fill_(42.0).unwrap();

        assert_eq!(t.item().unwrap(), 42.0);
    }

    #[test]
    fn test_fill_rejects_requires_grad_leaf() {
        let t =
            Tensor::<f64>::from_storage(TensorStorage::cpu(vec![1.0, 2.0]), vec![2], true).unwrap();

        assert!(t.fill_(0.0).is_err());
    }

    // -----------------------------------------------------------------------
    // zero_
    // -----------------------------------------------------------------------

    #[test]
    fn test_zero_basic() {
        let t = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]), vec![3], false)
            .unwrap();

        t.zero_().unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_zero_empty_tensor() {
        let t =
            Tensor::from_storage(TensorStorage::cpu(Vec::<f32>::new()), vec![0], false).unwrap();

        t.zero_().unwrap();

        assert_eq!(t.numel(), 0);
    }

    #[test]
    fn test_zero_rejects_requires_grad_leaf() {
        let t = Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0]), vec![1], true).unwrap();

        assert!(t.zero_().is_err());
    }

    // -----------------------------------------------------------------------
    // clamp_
    // -----------------------------------------------------------------------

    #[test]
    fn test_clamp_basic() {
        let t = Tensor::from_storage(
            TensorStorage::cpu(vec![-5.0f32, 0.0, 3.0, 10.0, 100.0]),
            vec![5],
            false,
        )
        .unwrap();

        t.clamp_(0.0, 10.0).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[0.0, 0.0, 3.0, 10.0, 10.0]);
    }

    #[test]
    fn test_clamp_all_within_range() {
        let t = Tensor::from_storage(TensorStorage::cpu(vec![1.0f64, 2.0, 3.0]), vec![3], false)
            .unwrap();

        t.clamp_(0.0, 10.0).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_clamp_single_value_range() {
        let t = Tensor::from_storage(
            TensorStorage::cpu(vec![-1.0f32, 0.0, 1.0, 5.0]),
            vec![4],
            false,
        )
        .unwrap();

        t.clamp_(3.0, 3.0).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[3.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_clamp_invalid_range() {
        let t =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0]), vec![2], false).unwrap();

        let err = t.clamp_(10.0, 0.0).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("min <= max"), "got: {msg}");
    }

    #[test]
    fn test_clamp_rejects_requires_grad_leaf() {
        let t =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0, 2.0]), vec![2], true).unwrap();

        assert!(t.clamp_(0.0, 1.0).is_err());
    }

    // -----------------------------------------------------------------------
    // clamp_opt_
    // -----------------------------------------------------------------------

    #[test]
    fn test_clamp_opt_min_only() {
        let t = Tensor::from_storage(
            TensorStorage::cpu(vec![-5.0f32, 0.0, 3.0, 100.0]),
            vec![4],
            false,
        )
        .unwrap();

        t.clamp_opt_(Some(0.0), None).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[0.0, 0.0, 3.0, 100.0]);
    }

    #[test]
    fn test_clamp_opt_max_only() {
        let t = Tensor::from_storage(
            TensorStorage::cpu(vec![-5.0f32, 0.0, 3.0, 100.0]),
            vec![4],
            false,
        )
        .unwrap();

        t.clamp_opt_(None, Some(10.0)).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[-5.0, 0.0, 3.0, 10.0]);
    }

    #[test]
    fn test_clamp_opt_both_bounds() {
        let t = Tensor::from_storage(
            TensorStorage::cpu(vec![-5.0f32, 0.0, 3.0, 10.0, 100.0]),
            vec![5],
            false,
        )
        .unwrap();

        t.clamp_opt_(Some(0.0), Some(10.0)).unwrap();

        let data = t.data().unwrap();
        assert_eq!(data, &[0.0, 0.0, 3.0, 10.0, 10.0]);
    }

    #[test]
    fn test_clamp_opt_rejects_both_none() {
        let t =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0]), vec![2], false).unwrap();

        assert!(t.clamp_opt_(None, None).is_err());

        // The rejected call must leave the storage untouched.
        let data = t.data().unwrap();
        assert_eq!(data, &[1.0, 2.0]);
    }

    #[test]
    fn test_clamp_opt_invalid_range() {
        let t =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0]), vec![2], false).unwrap();

        assert!(t.clamp_opt_(Some(10.0), Some(0.0)).is_err());
    }

    #[test]
    fn test_clamp_opt_nan_bound_fills_nan() {
        let t = Tensor::from_storage(TensorStorage::cpu(vec![-1.0f32, 0.5, 2.0]), vec![3], false)
            .unwrap();

        t.clamp_opt_(Some(f32::NAN), Some(1.0)).unwrap();
        assert!(t.data().unwrap().iter().all(|x| x.is_nan()));

        let u = Tensor::from_storage(TensorStorage::cpu(vec![-1.0f64, 0.5, 2.0]), vec![3], false)
            .unwrap();

        u.clamp_opt_(None, Some(f64::NAN)).unwrap();
        assert!(u.data().unwrap().iter().all(|x| x.is_nan()));
    }

    #[test]
    fn test_clamp_opt_nan_input_propagates() {
        let t = Tensor::from_storage(
            TensorStorage::cpu(vec![f32::NAN, -2.0, 5.0]),
            vec![3],
            false,
        )
        .unwrap();

        t.clamp_opt_(Some(0.0), Some(1.0)).unwrap();

        let data = t.data().unwrap();
        assert!(data[0].is_nan());
        assert_eq!(&data[1..], &[0.0, 1.0]);
    }

    #[test]
    fn test_clamp_opt_rejects_requires_grad_leaf() {
        let t =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0, 2.0]), vec![2], true).unwrap();

        assert!(t.clamp_opt_(Some(0.0), None).is_err());
    }

    // -----------------------------------------------------------------------
    // Integration: detached tensors are mutable
    // -----------------------------------------------------------------------

    #[test]
    fn test_detached_tensor_allows_inplace() {
        let t = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]), vec![3], true)
            .unwrap();

        // Detach drops requires_grad and grad_fn.
        let d = t.detach();
        assert!(!d.requires_grad());

        d.add_scalar_(10.0).unwrap();
        let data = d.data().unwrap();
        assert_eq!(data, &[11.0, 12.0, 13.0]);
    }

    // -----------------------------------------------------------------------
    // Chaining multiple different in-place ops
    // -----------------------------------------------------------------------

    #[test]
    fn test_mixed_inplace_chaining() {
        let t = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
            vec![4],
            false,
        )
        .unwrap();

        // (x + 10) * 2, then clamp to [20, 25]
        t.add_scalar_(10.0)
            .unwrap()
            .mul_scalar_(2.0)
            .unwrap()
            .clamp_(20.0, 25.0)
            .unwrap();

        let data = t.data().unwrap();
        // [1+10, 2+10, 3+10, 4+10] = [11, 12, 13, 14]
        // * 2 = [22, 24, 26, 28]
        // clamp [20, 25] = [22, 24, 25, 25]
        assert_eq!(data, &[22.0, 24.0, 25.0, 25.0]);
    }

    // -----------------------------------------------------------------------
    // f64 coverage
    // -----------------------------------------------------------------------

    #[test]
    fn test_inplace_ops_f64() {
        let t = Tensor::from_storage(TensorStorage::cpu(vec![1.0f64, 2.0, 3.0]), vec![3], false)
            .unwrap();

        t.add_scalar_(100.0).unwrap();
        t.mul_scalar_(0.1).unwrap();

        let data = t.data().unwrap();
        assert!((data[0] - 10.1).abs() < 1e-10);
        assert!((data[1] - 10.2).abs() < 1e-10);
        assert!((data[2] - 10.3).abs() < 1e-10);
    }
}