ferrotorch_core/inplace.rs
1//! In-place tensor operations following PyTorch's trailing-underscore convention.
2//!
3//! These methods mutate the tensor's underlying storage through
4//! [`Tensor::data_vec()`] + [`Tensor::update_data()`], which is
5//! device-transparent (works on both CPU and GPU tensors). The
6//! `update_data()` call performs an unsafe pointer cast through the
7//! `Arc<TensorStorage>` — this is sound under the same contract as
8//! optimizer updates: the caller must ensure no concurrent reads or
9//! writes to the same storage.
10//!
11//! ## REQ status (per `.design/ferrotorch-core/inplace.md`)
12//!
13//! | REQ | Status | Evidence |
14//! |---|---|---|
15//! | REQ-1 (`add_scalar_`) | NOT-STARTED | impl + in-file tests pass, but no non-test production consumer in `ferrotorch-{core,nn,optim,...}/src/**/*.rs`. Blocker #1205. |
16//! | REQ-2 (`mul_scalar_`) | NOT-STARTED | impl + tests pass; no non-test consumer. Blocker #1206. |
17//! | REQ-3 (`fill_`) | NOT-STARTED | impl + tests pass; natural caller is `ferrotorch-nn::init::constant_` which builds storage directly. Blocker #1207. |
18//! | REQ-4 (`zero_`) | NOT-STARTED | delegates to `self.fill_(T::zero())`; no non-test consumer. Blocker #1208. |
19//! | REQ-5 (`add_`) | NOT-STARTED | single-line wrapper over `add_scaled_(other, 1.0)`; no non-test consumer (natural caller is `optim::sgd`). Blocker #1209. |
20//! | REQ-6 (`add_scaled_`) | NOT-STARTED | load-bearing impl with GPU + CPU + broadcast paths; the only non-test invocation is the parity-sweep runner's dispatch table (test-side per R-DEFER-1). Blocker #1210. |
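//! | REQ-7 (`sub_`) | NOT-STARTED | delegates to `sub_scaled_(other, 1.0)` and so inherits broadcasting from `add_scaled_`; `sub_` itself takes no `alpha` kwarg (that lives on `sub_scaled_`). Blocker #1211. |
//! | REQ-8 (`mul_`) | NOT-STARTED | same-shape fast path plus a broadcast path through `arithmetic::mul`; no non-test consumer. Blocker #1212. |
//! | REQ-9 (`div_`) | NOT-STARTED | same-shape fast path plus a broadcast path through `arithmetic::div`; the `rounding_mode` kwarg lives on the separate `div_rounding_` method; no non-test consumer. Blocker #1213. |
//! | REQ-10 (`clamp_`) | NOT-STARTED | both-bounds-required; Optional/None handling and the NaN-bound special case live on the separate `clamp_opt_` method; no non-test consumer. Blocker #1214. |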
25//! | REQ-11 (`sub_scaled_`) | SHIPPED | `Tensor::sub_scaled_` delegates to `self.add_scaled_(other, -alpha)` mirroring upstream's `TORCH_IMPL_FUNC(sub_out) { add_stub(device_type(), *this, -alpha); }`; the out-of-place sibling `arithmetic::sub_scaled` is the symmetric production consumer that establishes torch's `sub(alpha=k)` parity across both surfaces; parity-sweep `[sub] 88/88 passed (0 skipped, 0 failed)` (closes #1192). |
26//!
27//! # Autograd safety
28//!
29//! In-place operations are **not** tracked by the autograd engine. To prevent
30//! silent gradient corruption, every method in this module checks two
31//! conditions before mutating:
32//!
33//! 1. The tensor must not have a `grad_fn` (i.e., it must not be the output
34//! of a differentiable operation). Mutating a non-leaf node would
35//! invalidate cached values needed by the backward pass.
36//!
37//! 2. The tensor must not be a leaf with `requires_grad = true`. PyTorch
38//! raises `RuntimeError` in this case because the in-place modification
39//! would not be recorded and the gradient would be silently wrong.
40//!
//! If either check fails, a [`FerrotorchError::InvalidArgument`] is returned.
42
43use crate::dtype::Float;
44use crate::error::{FerrotorchError, FerrotorchResult};
45use crate::tensor::Tensor;
46
47/// Validate that an in-place operation is safe to perform on `tensor`.
48///
49/// Returns `Ok(())` if the tensor is eligible, or an error describing why
50/// the operation was rejected.
51fn check_inplace_allowed<T: Float>(tensor: &Tensor<T>, op_name: &str) -> FerrotorchResult<()> {
52 if tensor.grad_fn().is_some() {
53 return Err(FerrotorchError::InvalidArgument {
54 message: format!(
55 "in-place operation '{op_name}' not allowed on a tensor that is \
56 part of the computation graph (has grad_fn = {:?})",
57 tensor.grad_fn().map(|gf| gf.name()),
58 ),
59 });
60 }
61
62 if tensor.requires_grad() && tensor.is_leaf() {
63 return Err(FerrotorchError::InvalidArgument {
64 message: format!(
65 "in-place operation '{op_name}' not allowed on a leaf tensor \
66 with requires_grad=true (the modification would not be tracked \
67 by autograd)",
68 ),
69 });
70 }
71
72 Ok(())
73}
74
75impl<T: Float> Tensor<T> {
76 /// Add a scalar to every element in-place: `self += value`.
77 ///
78 /// Returns `&Self` for method chaining. Follows PyTorch's `Tensor.add_()`
79 /// semantics — the trailing underscore denotes mutation.
80 ///
81 /// # Errors
82 ///
83 /// Returns an error if the tensor is part of the computation graph or is a
84 /// leaf with `requires_grad = true`.
85 pub fn add_scalar_(&self, value: T) -> FerrotorchResult<&Self> {
86 check_inplace_allowed(self, "add_scalar_")?;
87
88 let mut data = self.data_vec()?;
89 for x in &mut data {
90 *x += value;
91 }
92 // SAFETY: check_inplace_allowed ensures this tensor is not part of the
93 // computation graph and does not require grad, so no concurrent access.
94 unsafe { self.update_data(&data)? };
95
96 Ok(self)
97 }
98
99 /// Multiply every element by a scalar in-place: `self *= value`.
100 ///
101 /// # Errors
102 ///
103 /// Returns an error if the tensor is part of the computation graph or is a
104 /// leaf with `requires_grad = true`.
105 pub fn mul_scalar_(&self, value: T) -> FerrotorchResult<&Self> {
106 check_inplace_allowed(self, "mul_scalar_")?;
107
108 let mut data = self.data_vec()?;
109 for x in &mut data {
110 *x = *x * value;
111 }
112 // SAFETY: check_inplace_allowed ensures this tensor is not part of the
113 // computation graph and does not require grad, so no concurrent access.
114 unsafe { self.update_data(&data)? };
115
116 Ok(self)
117 }
118
119 /// Fill every element with `value` in-place.
120 ///
121 /// # Errors
122 ///
123 /// Returns an error if the tensor is part of the computation graph or is a
124 /// leaf with `requires_grad = true`.
125 pub fn fill_(&self, value: T) -> FerrotorchResult<&Self> {
126 check_inplace_allowed(self, "fill_")?;
127
128 let new_data = vec![value; self.numel()];
129 // SAFETY: check_inplace_allowed ensures this tensor is not part of the
130 // computation graph and does not require grad, so no concurrent access.
131 unsafe { self.update_data(&new_data)? };
132
133 Ok(self)
134 }
135
136 /// Zero all elements in-place: `self = 0`.
137 ///
138 /// Equivalent to `self.fill_(T::zero())`.
139 ///
140 /// # Errors
141 ///
142 /// Returns an error if the tensor is part of the computation graph or is a
143 /// leaf with `requires_grad = true`.
144 pub fn zero_(&self) -> FerrotorchResult<&Self> {
145 self.fill_(<T as num_traits::Zero>::zero())
146 }
147
148 /// Add another tensor elementwise in-place: `self += other`.
149 ///
150 /// Equivalent to PyTorch's `Tensor.add_(other)` — i.e. `add_scaled_`
151 /// with `alpha = 1.0`. `other` may be broadcast to `self.shape()` as
152 /// long as the broadcast result equals `self.shape()` (PyTorch
153 /// invariant for all in-place ops).
154 ///
155 /// For GPU f32 tensors on the same-shape fast path, uses the GPU add
156 /// kernel and swaps the storage (no CPU round-trip).
157 ///
158 /// # Errors
159 ///
160 /// Returns an error if `other` cannot be broadcast to `self.shape()`
161 /// (or if doing so would change `self.shape()`), or if the tensor is
162 /// part of the computation graph or is a leaf with `requires_grad = true`.
163 pub fn add_(&self, other: &Tensor<T>) -> FerrotorchResult<&Self> {
164 self.add_scaled_(other, 1.0)
165 }
166
167 /// In-place version of `torch.add(input, other, *, alpha)`:
168 /// `self = self + alpha * other`.
169 ///
170 /// `other` may be broadcast to `self.shape()` (PyTorch parity); the
171 /// broadcast result must equal `self.shape()` — an in-place op cannot
172 /// change the tensor's shape. The fast same-shape, `alpha == 1.0`
173 /// path uses the GPU add kernel directly when applicable; broadcast
174 /// or scaled paths route through `grad_fns::arithmetic::add_scaled`
175 /// (which itself dispatches CPU/GPU + broadcasting) and swap the
176 /// resulting storage in.
177 ///
178 /// # Errors
179 ///
180 /// Returns an error if shapes are not broadcast-compatible, if the
181 /// broadcast result differs from `self.shape()`, or if the tensor is
182 /// part of the computation graph or is a leaf with `requires_grad = true`.
183 pub fn add_scaled_(&self, other: &Tensor<T>, alpha: f64) -> FerrotorchResult<&Self> {
184 check_inplace_allowed(self, "add_scaled_")?;
185
186 // Same-shape, alpha == 1.0 fast path: keep the GPU storage-swap
187 // and SIMD CPU path that the previous `add_` had. Any other shape
188 // or alpha goes through the full broadcast/scale dispatch below.
189 #[allow(clippy::float_cmp)]
190 let is_identity_alpha = alpha == 1.0;
191 if is_identity_alpha && self.shape() == other.shape() {
192 // GPU f32 fast path.
193 if self.is_cuda()
194 && other.is_cuda()
195 && std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
196 && let Some(backend) = crate::gpu_dispatch::gpu_backend()
197 {
198 let sum_handle = backend.add_f32(self.gpu_handle()?, other.gpu_handle()?)?;
199 let storage = crate::storage::TensorStorage::gpu(sum_handle);
200 // SAFETY: check_inplace_allowed above proved `self` has
201 // no grad_fn and is not a requires_grad leaf, so no
202 // autograd machinery references this storage; `&self` +
203 // `Float: 'static` ensure no concurrent reader/writer
204 // holds a borrow across this point on this thread,
205 // satisfying update_storage's exclusive-access contract.
206 unsafe { self.update_storage(storage)? };
207 return Ok(self);
208 }
209
210 let mut data = self.data_vec()?;
211 let other_data = other.data_vec()?;
212 for (a, &b) in data.iter_mut().zip(other_data.iter()) {
213 *a += b;
214 }
215 // SAFETY: check_inplace_allowed above ensures `self` is not in
216 // the autograd graph and not a requires_grad leaf; satisfies
217 // update_data's exclusive-access contract.
218 unsafe { self.update_data(&data)? };
219 return Ok(self);
220 }
221
222 // Broadcast / scaled path. `add_scaled` already handles CPU and GPU,
223 // broadcasting via `binary_map` / `broadcast_add_*`, and dtype
224 // dispatch. We materialize the result into a fresh tensor, then swap
225 // its storage into `self` — but only if the broadcast shape equals
226 // `self.shape()` (in-place ops cannot resize `self`).
227 let result = crate::grad_fns::arithmetic::add_scaled(self, other, alpha).map_err(|e| {
228 // Re-shape errors come out of `broadcast_shapes`; surface them
229 // under the `add_scaled_` op name for caller clarity.
230 match e {
231 FerrotorchError::ShapeMismatch { message } => FerrotorchError::ShapeMismatch {
232 message: format!("add_scaled_: {message}"),
233 },
234 other => other,
235 }
236 })?;
237 if result.shape() != self.shape() {
238 return Err(FerrotorchError::ShapeMismatch {
239 message: format!(
240 "add_scaled_: broadcast result {:?} does not match self.shape() {:?} \
241 — in-place add cannot resize the target tensor",
242 result.shape(),
243 self.shape(),
244 ),
245 });
246 }
247
248 // Swap storage. Take the storage out of `result` rather than
249 // copying it through CPU. `into_storage_and_shape` consumes the
250 // Tensor and yields its TensorStorage.
251 let (storage, _shape) = result.into_storage_and_shape()?;
252 // SAFETY: check_inplace_allowed above ensures `self` is not in the
253 // autograd graph and not a requires_grad leaf; `storage` was just
254 // produced from a freshly-allocated tensor with no aliases. numel
255 // matches because we asserted `result.shape() == self.shape()`.
256 unsafe { self.update_storage(storage)? };
257 Ok(self)
258 }
259
260 /// In-place version of `torch.sub(input, other, *, alpha)`:
261 /// `self = self - alpha * other`.
262 ///
263 /// Delegates to [`Tensor::add_scaled_`] with `-alpha`. PyTorch's own
264 /// `sub_out` at `aten/src/ATen/native/BinaryOps.cpp:434-439` does the
265 /// same: `add_stub(device_type(), *this, -alpha)`. This is the
266 /// in-place sibling of [`crate::grad_fns::arithmetic::sub_scaled`]
267 /// and the non-test production consumer of that out-of-place entry
268 /// point (it invokes `add_scaled_`, which routes through
269 /// `arithmetic::add_scaled`; `sub_scaled` is the symmetric forward
270 /// caller wired through the parity-sweep `"sub"` dispatch arm).
271 ///
272 /// `other` may be broadcast to `self.shape()`; the broadcast result
273 /// must equal `self.shape()` — an in-place op cannot resize the
274 /// target tensor (PyTorch invariant for all `_` ops).
275 ///
276 /// # Errors
277 ///
278 /// Returns an error if shapes are not broadcast-compatible, if the
279 /// broadcast result differs from `self.shape()`, or if the tensor is
280 /// part of the computation graph or is a leaf with `requires_grad = true`.
281 pub fn sub_scaled_(&self, other: &Tensor<T>, alpha: f64) -> FerrotorchResult<&Self> {
282 // PyTorch parity: `sub_out` literally calls `add_stub` with
283 // negated alpha. Delegate to `add_scaled_(other, -alpha)` and
284 // inherit its broadcast / GPU fast path / shape-strict in-place
285 // semantics for free. Errors surface under the `add_scaled_` op
286 // name in the error message; that is acceptable since this is
287 // a thin alias and the caller's stack trace pinpoints `sub_scaled_`.
288 self.add_scaled_(other, -alpha)
289 }
290
291 /// Subtract another tensor elementwise in-place: `self -= other`.
292 ///
293 /// Equivalent to PyTorch's `Tensor.sub_(other)` — i.e. `sub_scaled_`
294 /// with `alpha = 1.0`. Mirrors upstream's
295 /// `aten/src/ATen/native/BinaryOps.cpp:434-439`
296 /// `TORCH_IMPL_FUNC(sub_out) { add_stub(device_type(), *this, -alpha); }`
297 /// with `alpha = 1.0`, i.e. `self += -1.0 * other == self -= other`.
298 /// Delegating here gives `sub_scaled_` a non-test production consumer
299 /// transitively for free (every caller of `sub_` becomes a caller of
300 /// `sub_scaled_`), and brings `sub_` to PyTorch parity with the
301 /// `sub_(other, *, alpha=1)` docstring at `torch/_tensor_docs.py:5113`
302 /// (broadcasting from `add_scaled_` is inherited; in-place ops cannot
303 /// resize `self`).
304 ///
305 /// # Errors
306 ///
307 /// Returns an error if `other` cannot be broadcast to `self.shape()`
308 /// (or if doing so would change `self.shape()`), or if the tensor is
309 /// part of the computation graph or is a leaf with `requires_grad = true`.
310 pub fn sub_(&self, other: &Tensor<T>) -> FerrotorchResult<&Self> {
311 self.sub_scaled_(other, 1.0)
312 }
313
314 /// Multiply another tensor elementwise in-place: `self *= other`.
315 ///
316 /// `other` may be broadcast to `self.shape()` (PyTorch parity for
317 /// `Tensor.mul_(other)` — `aten/src/ATen/native/BinaryOps.cpp:441
318 /// TORCH_IMPL_FUNC(mul_out)` inherits broadcasting via `TensorIterator`);
319 /// the broadcast result must equal `self.shape()` — an in-place op
320 /// cannot resize the target tensor.
321 ///
322 /// The same-shape, both-on-CUDA, `T == f32` path takes the GPU `mul_f32`
323 /// kernel and swaps the storage (no CPU round-trip). Anything else
324 /// (broadcasting or non-f32 or CPU) routes through
325 /// `grad_fns::arithmetic::mul` (which itself handles CPU + GPU broadcasting
326 /// via `binary_broadcast` / `broadcast_mul_*`) and swaps the resulting
327 /// storage in.
328 ///
329 /// # Errors
330 ///
331 /// Returns an error if shapes are not broadcast-compatible, if the
332 /// broadcast result differs from `self.shape()`, or if the tensor is
333 /// part of the computation graph or is a leaf with `requires_grad = true`.
334 pub fn mul_(&self, other: &Tensor<T>) -> FerrotorchResult<&Self> {
335 check_inplace_allowed(self, "mul_")?;
336
337 // Same-shape fast paths (preserve previous behavior).
338 if self.shape() == other.shape() {
339 if self.is_cuda()
340 && other.is_cuda()
341 && std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
342 && let Some(backend) = crate::gpu_dispatch::gpu_backend()
343 {
344 let handle = backend.mul_f32(self.gpu_handle()?, other.gpu_handle()?)?;
345 let storage = crate::storage::TensorStorage::gpu(handle);
346 // SAFETY: check_inplace_allowed at the top of `mul_` already
347 // proved `self` has no grad_fn and is not a requires_grad leaf;
348 // single-threaded `&self` satisfies update_storage's
349 // exclusive-access contract.
350 unsafe { self.update_storage(storage)? };
351 return Ok(self);
352 }
353
354 let mut data = self.data_vec()?;
355 let other_data = other.data_vec()?;
356 for (a, &b) in data.iter_mut().zip(other_data.iter()) {
357 *a = *a * b;
358 }
359 // SAFETY: check_inplace_allowed at the top of `mul_` ensures `self`
360 // is not part of the autograd graph; satisfies update_data's
361 // exclusive-access contract.
362 unsafe { self.update_data(&data)? };
363 return Ok(self);
364 }
365
366 // Broadcast path. `arithmetic::mul` handles broadcast shape inference
367 // and CPU/GPU dispatch via `meta_propagate::binary_broadcast` and
368 // `broadcast_mul_*` kernels. We then check the broadcast result
369 // matches `self.shape()` — in-place mul cannot resize the target
370 // (PyTorch invariant for all `_` ops).
371 let result = crate::grad_fns::arithmetic::mul(self, other).map_err(|e| match e {
372 FerrotorchError::ShapeMismatch { message } => FerrotorchError::ShapeMismatch {
373 message: format!("mul_: {message}"),
374 },
375 other => other,
376 })?;
377 if result.shape() != self.shape() {
378 return Err(FerrotorchError::ShapeMismatch {
379 message: format!(
380 "mul_: broadcast result {:?} does not match self.shape() {:?} \
381 — in-place mul cannot resize the target tensor",
382 result.shape(),
383 self.shape(),
384 ),
385 });
386 }
387 let (storage, _shape) = result.into_storage_and_shape()?;
388 // SAFETY: check_inplace_allowed above ensures `self` is not in the
389 // autograd graph and not a requires_grad leaf; `storage` was just
390 // produced from a freshly-allocated tensor with no aliases; numel
391 // matches because we asserted `result.shape() == self.shape()`.
392 unsafe { self.update_storage(storage)? };
393 Ok(self)
394 }
395
396 /// Divide by another tensor elementwise in-place: `self /= other`.
397 ///
398 /// `other` may be broadcast to `self.shape()` (PyTorch parity for
399 /// `Tensor.div_(other)` — `aten/src/ATen/native/BinaryOps.cpp:447
400 /// TORCH_IMPL_FUNC(div_out)` inherits broadcasting via `TensorIterator`);
401 /// the broadcast result must equal `self.shape()` — an in-place op
402 /// cannot resize the target tensor.
403 ///
404 /// The same-shape, both-on-CUDA, `T == f32` path takes the GPU `div_f32`
405 /// kernel and swaps the storage (no CPU round-trip). Anything else routes
406 /// through `grad_fns::arithmetic::div`.
407 ///
408 /// True-division semantics (PyTorch parity, no rounding). For
409 /// floor / trunc rounding modes use [`Tensor::div_rounding_`].
410 ///
411 /// # Errors
412 ///
413 /// Returns an error if shapes are not broadcast-compatible, if the
414 /// broadcast result differs from `self.shape()`, or if the tensor is
415 /// part of the computation graph or is a leaf with `requires_grad = true`.
416 pub fn div_(&self, other: &Tensor<T>) -> FerrotorchResult<&Self> {
417 check_inplace_allowed(self, "div_")?;
418
419 if self.shape() == other.shape() {
420 if self.is_cuda()
421 && other.is_cuda()
422 && std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
423 && let Some(backend) = crate::gpu_dispatch::gpu_backend()
424 {
425 let handle = backend.div_f32(self.gpu_handle()?, other.gpu_handle()?)?;
426 let storage = crate::storage::TensorStorage::gpu(handle);
427 // SAFETY: check_inplace_allowed at the top of `div_` already
428 // proved `self` has no grad_fn and is not a requires_grad leaf;
429 // single-threaded `&self` satisfies update_storage's
430 // exclusive-access contract.
431 unsafe { self.update_storage(storage)? };
432 return Ok(self);
433 }
434
435 let mut data = self.data_vec()?;
436 let other_data = other.data_vec()?;
437 for (a, &b) in data.iter_mut().zip(other_data.iter()) {
438 *a = *a / b;
439 }
440 // SAFETY: check_inplace_allowed at the top of `div_` ensures `self`
441 // is not part of the autograd graph; satisfies update_data's
442 // exclusive-access contract.
443 unsafe { self.update_data(&data)? };
444 return Ok(self);
445 }
446
447 // Broadcast path.
448 let result = crate::grad_fns::arithmetic::div(self, other).map_err(|e| match e {
449 FerrotorchError::ShapeMismatch { message } => FerrotorchError::ShapeMismatch {
450 message: format!("div_: {message}"),
451 },
452 other => other,
453 })?;
454 if result.shape() != self.shape() {
455 return Err(FerrotorchError::ShapeMismatch {
456 message: format!(
457 "div_: broadcast result {:?} does not match self.shape() {:?} \
458 — in-place div cannot resize the target tensor",
459 result.shape(),
460 self.shape(),
461 ),
462 });
463 }
464 let (storage, _shape) = result.into_storage_and_shape()?;
465 // SAFETY: see `mul_` broadcast-path SAFETY.
466 unsafe { self.update_storage(storage)? };
467 Ok(self)
468 }
469
470 /// In-place division with a `rounding_mode` kwarg, mirroring
471 /// `torch.Tensor.div_(other, *, rounding_mode=...)` per
472 /// `torch/_tensor_docs.py:1746` and `aten/src/ATen/native/BinaryOps.cpp:176`
473 /// `TORCH_META_FUNC2(div, Tensor_mode)`.
474 ///
475 /// Accepted modes:
476 ///
477 /// - `"trunc"` — `self = (self / other).trunc()` (rounds toward zero).
478 /// - `"floor"` — `self = (self / other).floor()` (rounds toward negative infinity).
479 ///
480 /// For true-division (no rounding), use [`Tensor::div_`] directly. Any other
481 /// `mode` string returns `InvalidArgument` matching upstream:
482 ///
483 /// > `div expected rounding_mode to be one of None, 'trunc', or 'floor' but found '...'`
484 /// > (`BinaryOps.cpp:186`)
485 ///
486 /// Broadcasting follows `div_` semantics — `other` may broadcast to
487 /// `self.shape()` and the broadcast result must equal `self.shape()`.
488 ///
489 /// # Errors
490 ///
491 /// Returns an error if `mode` is unrecognized, if shapes are not
492 /// broadcast-compatible, or if the tensor is part of the computation graph
493 /// or is a leaf with `requires_grad = true`.
494 pub fn div_rounding_(&self, other: &Tensor<T>, rounding_mode: &str) -> FerrotorchResult<&Self> {
495 check_inplace_allowed(self, "div_rounding_")?;
496 match rounding_mode {
497 "trunc" | "floor" => {}
498 other_mode => {
499 return Err(FerrotorchError::InvalidArgument {
500 message: format!(
501 "div_rounding_: expected rounding_mode to be one of 'trunc' or 'floor' \
502 but found '{other_mode}'"
503 ),
504 });
505 }
506 }
507
508 // Compute true-division result via the same broadcast-aware path as
509 // `div_`, then apply the rounding op element-wise on host data and
510 // swap back. We re-use `arithmetic::div` for shape correctness and
511 // bring the result down to host data for the rounding pass — GPU
512 // rounding kernels would be a separate dispatch arm; this CPU-side
513 // rounding is correct on both CPU and GPU operands because
514 // `data_vec()` is device-transparent.
515 let result = crate::grad_fns::arithmetic::div(self, other).map_err(|e| match e {
516 FerrotorchError::ShapeMismatch { message } => FerrotorchError::ShapeMismatch {
517 message: format!("div_rounding_: {message}"),
518 },
519 other => other,
520 })?;
521 if result.shape() != self.shape() {
522 return Err(FerrotorchError::ShapeMismatch {
523 message: format!(
524 "div_rounding_: broadcast result {:?} does not match self.shape() {:?} \
525 — in-place div cannot resize the target tensor",
526 result.shape(),
527 self.shape(),
528 ),
529 });
530 }
531
532 let mut data = result.data_vec()?;
533 match rounding_mode {
534 "trunc" => {
535 for x in &mut data {
536 *x = num_traits::Float::trunc(*x);
537 }
538 }
539 "floor" => {
540 for x in &mut data {
541 *x = num_traits::Float::floor(*x);
542 }
543 }
544 _ => unreachable!("validated above"),
545 }
546 // SAFETY: check_inplace_allowed above ensures `self` is not in the
547 // autograd graph and not a requires_grad leaf; `data` has the same
548 // numel as `self` (`result.shape() == self.shape()` was asserted).
549 unsafe { self.update_data(&data)? };
550 Ok(self)
551 }
552
553 /// Clamp every element to `[min, max]` in-place.
554 ///
555 /// Each element `x` is replaced with `min.max(x.min(max))`, matching
556 /// PyTorch's `Tensor.clamp_()`.
557 ///
558 /// This is the both-bounds-required overload; for the
559 /// `(Option<T>, Option<T>)` overload that mirrors torch's
560 /// `clamp_(min=None, max=None)` see [`Tensor::clamp_opt_`].
561 ///
562 /// # Errors
563 ///
564 /// - Returns an error if `min > max`.
565 /// - Returns an error if the tensor is part of the computation graph or is
566 /// a leaf with `requires_grad = true`.
567 pub fn clamp_(&self, min: T, max: T) -> FerrotorchResult<&Self> {
568 if min > max {
569 return Err(FerrotorchError::InvalidArgument {
570 message: format!("clamp_ requires min <= max, got min={min:?}, max={max:?}"),
571 });
572 }
573
574 check_inplace_allowed(self, "clamp_")?;
575
576 let mut data = self.data_vec()?;
577 for x in &mut data {
578 if *x < min {
579 *x = min;
580 } else if *x > max {
581 *x = max;
582 }
583 }
584 // SAFETY: check_inplace_allowed ensures this tensor is not part of the
585 // computation graph and does not require grad, so no concurrent access.
586 unsafe { self.update_data(&data)? };
587
588 Ok(self)
589 }
590
591 /// Clamp with optional bounds — `Tensor.clamp_(min=None, max=None)` parity.
592 ///
593 /// Mirrors `torch.Tensor.clamp_(min=None, max=None) -> Tensor` per
594 /// `torch/_tensor_docs.py:1141` and the structured kernel
595 /// `TORCH_IMPL_FUNC(clamp_out)` at
596 /// `aten/src/ATen/native/TensorCompare.cpp:831`. Either bound may be
597 /// `None`:
598 ///
599 /// - `clamp_opt_(Some(lo), Some(hi))` — equivalent to `clamp_(lo, hi)`.
600 /// - `clamp_opt_(Some(lo), None)` — `clamp_min_` (lower bound only).
601 /// - `clamp_opt_(None, Some(hi))` — `clamp_max_` (upper bound only).
602 /// - `clamp_opt_(None, None)` — rejected with `InvalidArgument`
603 /// matching upstream "torch.clamp: At least one of 'min' or 'max' must
604 /// not be None" (`TensorCompare.cpp:106`).
605 ///
606 /// NaN-bound parity: if either supplied bound is NaN, the entire tensor
607 /// is filled with NaN (PyTorch's `at::fill_(result, NaN)` branch at
608 /// `TensorCompare.cpp:844`, executed when `min.isNan() || max.isNan()`).
609 ///
610 /// Per-element NaN inputs propagate (matching the kernel's
611 /// `std::min(std::max(a, min), max)` semantics — when `a` is NaN, both
612 /// comparisons evaluate false in this implementation and `a` is left
613 /// unchanged, which propagates NaN through).
614 ///
615 /// # Errors
616 ///
617 /// - Returns an error if both `min` and `max` are `None`.
618 /// - Returns an error if `min > max` (when both are `Some`).
619 /// - Returns an error if the tensor is part of the computation graph or
620 /// is a leaf with `requires_grad = true`.
621 pub fn clamp_opt_(&self, min: Option<T>, max: Option<T>) -> FerrotorchResult<&Self> {
622 if min.is_none() && max.is_none() {
623 return Err(FerrotorchError::InvalidArgument {
624 message: "clamp_opt_: at least one of 'min' or 'max' must not be None".into(),
625 });
626 }
627 if let (Some(lo), Some(hi)) = (min, max)
628 && lo > hi
629 {
630 return Err(FerrotorchError::InvalidArgument {
631 message: format!("clamp_opt_ requires min <= max, got min={lo:?}, max={hi:?}"),
632 });
633 }
634
635 check_inplace_allowed(self, "clamp_opt_")?;
636
637 // NaN-bound special case: PyTorch's `TORCH_IMPL_FUNC(clamp_out)` at
638 // `aten/src/ATen/native/TensorCompare.cpp:844` fills the entire
639 // result with NaN if any bound is NaN. Mirror that here.
640 let min_is_nan = min.is_some_and(num_traits::Float::is_nan);
641 let max_is_nan = max.is_some_and(num_traits::Float::is_nan);
642 if min_is_nan || max_is_nan {
643 let nan = <T as num_traits::Float>::nan();
644 let new_data = vec![nan; self.numel()];
645 // SAFETY: check_inplace_allowed above ensures `self` is not in
646 // the autograd graph and not a requires_grad leaf; new_data
647 // matches self.numel() by construction.
648 unsafe { self.update_data(&new_data)? };
649 return Ok(self);
650 }
651
652 let mut data = self.data_vec()?;
653 match (min, max) {
654 (Some(lo), Some(hi)) => {
655 for x in &mut data {
656 // NaN inputs propagate: `*x < lo` and `*x > hi` are both
657 // false when `*x` is NaN, leaving `*x` unchanged.
658 if *x < lo {
659 *x = lo;
660 } else if *x > hi {
661 *x = hi;
662 }
663 }
664 }
665 (Some(lo), None) => {
666 for x in &mut data {
667 if *x < lo {
668 *x = lo;
669 }
670 }
671 }
672 (None, Some(hi)) => {
673 for x in &mut data {
674 if *x > hi {
675 *x = hi;
676 }
677 }
678 }
679 (None, None) => unreachable!("rejected above"),
680 }
681 // SAFETY: check_inplace_allowed above ensures `self` is not in the
682 // autograd graph and not a requires_grad leaf; satisfies update_data's
683 // exclusive-access contract.
684 unsafe { self.update_data(&data)? };
685 Ok(self)
686 }
687}
688
689#[cfg(test)]
690mod tests {
691 use crate::storage::TensorStorage;
692 use crate::tensor::Tensor;
693
694 // -----------------------------------------------------------------------
695 // add_scalar_
696 // -----------------------------------------------------------------------
697
#[test]
fn test_add_scalar_basic() {
    // Adding 10 shifts every element of a 1-D f32 tensor.
    let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]);
    let tensor = Tensor::from_storage(storage, vec![3], false).unwrap();

    tensor.add_scalar_(10.0).unwrap();

    assert_eq!(tensor.data().unwrap(), &[11.0, 12.0, 13.0]);
}
708
#[test]
fn test_add_scalar_negative() {
    // A negative scalar subtracts; checked with a tolerance on f64.
    let storage = TensorStorage::cpu(vec![5.0f64, 10.0]);
    let tensor = Tensor::from_storage(storage, vec![2], false).unwrap();

    tensor.add_scalar_(-3.0).unwrap();

    let values = tensor.data().unwrap();
    let (first, second) = (values[0], values[1]);
    assert!((first - 2.0).abs() < 1e-10);
    assert!((second - 7.0).abs() < 1e-10);
}
720
#[test]
fn test_add_scalar_chaining() {
    // The returned `&Self` lets a second in-place call follow the first.
    let storage = TensorStorage::cpu(vec![0.0f32; 4]);
    let tensor = Tensor::from_storage(storage, vec![2, 2], false).unwrap();

    let after_first = tensor.add_scalar_(1.0).unwrap();
    after_first.add_scalar_(2.0).unwrap();

    assert_eq!(tensor.data().unwrap(), &[3.0, 3.0, 3.0, 3.0]);
}
731
#[test]
fn test_add_scalar_rejects_requires_grad_leaf() {
    // A requires_grad leaf must be refused, and the message must say why.
    let storage = TensorStorage::cpu(vec![1.0f32, 2.0]);
    let leaf = Tensor::<f32>::from_storage(storage, vec![2], true).unwrap();

    let msg = leaf.add_scalar_(1.0).unwrap_err().to_string();

    assert!(msg.contains("requires_grad=true"), "got: {msg}");
}
741
742 // -----------------------------------------------------------------------
743 // mul_scalar_
744 // -----------------------------------------------------------------------
745
#[test]
fn test_mul_scalar_basic() {
    // Multiplying by 0.5 halves every element.
    let storage = TensorStorage::cpu(vec![2.0f32, 3.0, 4.0]);
    let tensor = Tensor::from_storage(storage, vec![3], false).unwrap();

    tensor.mul_scalar_(0.5).unwrap();

    assert_eq!(tensor.data().unwrap(), &[1.0, 1.5, 2.0]);
}
756
#[test]
fn test_mul_scalar_zero() {
    // Multiplying by zero compares equal to an all-zero slice.
    let storage = TensorStorage::cpu(vec![42.0f64, -7.0, 100.0]);
    let tensor = Tensor::from_storage(storage, vec![3], false).unwrap();

    tensor.mul_scalar_(0.0).unwrap();

    assert_eq!(tensor.data().unwrap(), &[0.0, 0.0, 0.0]);
}
771
#[test]
fn test_mul_scalar_rejects_requires_grad_leaf() {
    // A requires_grad leaf must be refused.
    let storage = TensorStorage::cpu(vec![1.0f32]);
    let leaf = Tensor::<f32>::from_storage(storage, vec![1], true).unwrap();

    let outcome = leaf.mul_scalar_(2.0);

    assert!(outcome.is_err());
}
778
779 // -----------------------------------------------------------------------
780 // fill_
781 // -----------------------------------------------------------------------
782
#[test]
fn test_fill_basic() {
    // Every element of a 2x2 tensor is replaced by the fill value.
    let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]);
    let tensor = Tensor::from_storage(storage, vec![2, 2], false).unwrap();

    tensor.fill_(99.0).unwrap();

    assert_eq!(tensor.data().unwrap(), &[99.0, 99.0, 99.0, 99.0]);
}
797
/// `fill_` also works on a tensor with an empty shape (a single element read
/// back through `item()`).
#[test]
// reason: fill_(42.0) stores the value without performing arithmetic on it,
// so the read-back is bit-identical and exact equality is the right check.
#[allow(clippy::float_cmp)]
fn test_fill_scalar_tensor() {
    let storage = TensorStorage::cpu(vec![0.0f32]);
    let scalar = Tensor::from_storage(storage, vec![], false).unwrap();

    scalar.fill_(42.0).unwrap();

    let stored = scalar.item().unwrap();
    assert_eq!(stored, 42.0);
}
809
/// `fill_` must return an error for a tensor created with the
/// `requires_grad` flag set.
#[test]
fn test_fill_rejects_requires_grad_leaf() {
    let storage = TensorStorage::cpu(vec![1.0f64, 2.0]);
    let leaf = Tensor::<f64>::from_storage(storage, vec![2], true).unwrap();

    let outcome = leaf.fill_(0.0);
    assert!(outcome.is_err());
}
817
818 // -----------------------------------------------------------------------
819 // zero_
820 // -----------------------------------------------------------------------
821
/// `zero_` resets every element of a non-empty tensor to `0.0`.
#[test]
fn test_zero_basic() {
    let input = vec![1.0f32, 2.0, 3.0];
    let tensor = Tensor::from_storage(TensorStorage::cpu(input), vec![3], false).unwrap();

    tensor.zero_().unwrap();

    let cleared = tensor.data().unwrap();
    assert_eq!(cleared, &[0.0, 0.0, 0.0]);
}
832
/// `zero_` succeeds on a tensor with no elements and leaves it empty.
#[test]
fn test_zero_empty_tensor() {
    let no_elements: Vec<f32> = Vec::new();
    let empty = Tensor::from_storage(TensorStorage::cpu(no_elements), vec![0], false).unwrap();

    empty.zero_().unwrap();

    assert_eq!(empty.numel(), 0);
}
842
/// `zero_` must return an error for a tensor created with the
/// `requires_grad` flag set.
#[test]
fn test_zero_rejects_requires_grad_leaf() {
    let storage = TensorStorage::cpu(vec![1.0f32]);
    let leaf = Tensor::<f32>::from_storage(storage, vec![1], true).unwrap();

    let outcome = leaf.zero_();
    assert!(outcome.is_err());
}
849
850 // -----------------------------------------------------------------------
851 // clamp_
852 // -----------------------------------------------------------------------
853
/// `clamp_(0, 10)` raises values below the range to 0, lowers values above it
/// to 10, and leaves in-range values (including both bounds) unchanged.
#[test]
fn test_clamp_basic() {
    let storage = TensorStorage::cpu(vec![-5.0f32, 0.0, 3.0, 10.0, 100.0]);
    let tensor = Tensor::from_storage(storage, vec![5], false).unwrap();

    tensor.clamp_(0.0, 10.0).unwrap();

    let clamped = tensor.data().unwrap();
    assert_eq!(clamped, &[0.0, 0.0, 3.0, 10.0, 10.0]);
}
868
/// When every element already lies inside `[min, max]`, `clamp_` leaves the
/// data untouched (f64 storage).
#[test]
fn test_clamp_all_within_range() {
    let input = vec![1.0f64, 2.0, 3.0];
    let tensor = Tensor::from_storage(TensorStorage::cpu(input), vec![3], false).unwrap();

    tensor.clamp_(0.0, 10.0).unwrap();

    let untouched = tensor.data().unwrap();
    assert_eq!(untouched, &[1.0, 2.0, 3.0]);
}
879
/// A degenerate range with `min == max` is accepted and collapses every
/// element onto that single value.
#[test]
fn test_clamp_single_value_range() {
    let storage = TensorStorage::cpu(vec![-1.0f32, 0.0, 1.0, 5.0]);
    let tensor = Tensor::from_storage(storage, vec![4], false).unwrap();

    tensor.clamp_(3.0, 3.0).unwrap();

    let collapsed = tensor.data().unwrap();
    assert_eq!(collapsed, &[3.0, 3.0, 3.0, 3.0]);
}
894
/// Passing `min > max` to `clamp_` is an error, and the error text states the
/// `min <= max` requirement.
#[test]
fn test_clamp_invalid_range() {
    let storage = TensorStorage::cpu(vec![1.0f32, 2.0]);
    let tensor = Tensor::from_storage(storage, vec![2], false).unwrap();

    // Bounds are deliberately reversed: min = 10, max = 0.
    let msg = tensor.clamp_(10.0, 0.0).unwrap_err().to_string();
    assert!(msg.contains("min <= max"), "got: {msg}");
}
904
/// `clamp_` must return an error for a tensor created with the
/// `requires_grad` flag set, even when the bounds themselves are valid.
#[test]
fn test_clamp_rejects_requires_grad_leaf() {
    let storage = TensorStorage::cpu(vec![1.0f32, 2.0]);
    let leaf = Tensor::<f32>::from_storage(storage, vec![2], true).unwrap();

    let outcome = leaf.clamp_(0.0, 1.0);
    assert!(outcome.is_err());
}
912
913 // -----------------------------------------------------------------------
914 // Integration: detached tensors are mutable
915 // -----------------------------------------------------------------------
916
/// A tensor obtained via `detach()` from a `requires_grad` tensor reports
/// `requires_grad() == false` and accepts in-place mutation.
#[test]
fn test_detached_tensor_allows_inplace() {
    let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]);
    let tracked = Tensor::from_storage(storage, vec![3], true).unwrap();

    // After detaching, the tensor must no longer report requires_grad.
    let detached = tracked.detach();
    assert!(!detached.requires_grad());

    detached.add_scalar_(10.0).unwrap();

    let shifted = detached.data().unwrap();
    assert_eq!(shifted, &[11.0, 12.0, 13.0]);
}
930
931 // -----------------------------------------------------------------------
932 // Chaining multiple different in-place ops
933 // -----------------------------------------------------------------------
934
/// Different in-place ops can be chained on one tensor: add, then scale, then
/// clamp, with each step observing the result of the previous one.
#[test]
fn test_mixed_inplace_chaining() {
    let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]);
    let tensor = Tensor::from_storage(storage, vec![4], false).unwrap();

    // Pipeline: x + 10, then * 2, then clamp into [20, 25].
    tensor
        .add_scalar_(10.0)
        .unwrap()
        .mul_scalar_(2.0)
        .unwrap()
        .clamp_(20.0, 25.0)
        .unwrap();

    // Step by step:
    //   + 10            -> [11, 12, 13, 14]
    //   * 2             -> [22, 24, 26, 28]
    //   clamp [20, 25]  -> [22, 24, 25, 25]
    let result = tensor.data().unwrap();
    assert_eq!(result, &[22.0, 24.0, 25.0, 25.0]);
}
958
959 // -----------------------------------------------------------------------
960 // f64 coverage
961 // -----------------------------------------------------------------------
962
/// f64 coverage: `add_scalar_` followed by `mul_scalar_` on double-precision
/// storage, compared against the expected values within a `1e-10` tolerance.
#[test]
fn test_inplace_ops_f64() {
    let storage = TensorStorage::cpu(vec![1.0f64, 2.0, 3.0]);
    let tensor = Tensor::from_storage(storage, vec![3], false).unwrap();

    tensor.add_scalar_(100.0).unwrap();
    tensor.mul_scalar_(0.1).unwrap();

    // Multiplying by 0.1 is inexact in binary floating point, so compare with
    // an absolute tolerance rather than for exact equality.
    let close = |got: f64, want: f64| (got - want).abs() < 1e-10;

    let values = tensor.data().unwrap();
    assert!(close(values[0], 10.1));
    assert!(close(values[1], 10.2));
    assert!(close(values[2], 10.3));
}
976}