ferrotorch_core/methods.rs
1//! Method-style API for Tensor operations.
2//!
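//! Enables `a.matmul(&b)`, `a.relu()`, `a.sum_all()`, `a.reshape_t(&[2, 3])` etc.
//! Each method is a thin wrapper that delegates to the corresponding free
//! function — mostly in `crate::grad_fns`, plus `crate::linalg` (`lu_factor`)
//! and `crate::einsum` (`einsum`).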
5//!
6//! ## REQ status (per `.design/ferrotorch-core/methods.md`)
7//!
8//! | REQ | Status | Evidence |
9//! |---|---|---|
10//! | REQ-1 (arithmetic methods) | SHIPPED | `add_t / sub_t / rsub_t / mul_t / div_t / neg_t / pow_t / sqrt_t / abs_t` each delegate to `crate::grad_fns::arithmetic::<op>`; non-test consumer `ferrotorch-nn/src/hooks.rs` (`add_t`); `rsub_t` itself is the production-consumer surface that closes R-DEFER-1 for `arithmetic::rsub`; parity-sweep `add` `88/88 passed`. |
11//! | REQ-2 (transcendental methods) | SHIPPED | `exp_t / log_t / sin_t / cos_t / clamp_t` delegate to `crate::grad_fns::transcendental::*`; non-test consumer `ferrotorch-diffusion/src/vae_encoder.rs` (`clamp_t`). |
12//! | REQ-3 (activation methods) | SHIPPED | `relu / sigmoid / tanh_t / gelu / gelu_with / silu / softmax / log_softmax` delegate to `crate::grad_fns::activation::*`; non-test consumers in `ferrotorch-vision`, `ferrotorch-diffusion`. |
13//! | REQ-4 (global reductions) | SHIPPED | `sum_all / mean_all / prod_all / amin / amax` delegate to `crate::grad_fns::reduction::*`; consumers in `ferrotorch-distributions`, `stride_tricks.rs`, and `grad_fns::activation`. |
14//! | REQ-5 (dim reductions) | NOT-STARTED | `sum_dim / mean_dim` methods exist; all production callers use the free `grad_fns::reduction::{sum_dim, mean_dim}` form directly. Blocker #1221. |
//! | REQ-6 (linalg methods) | SHIPPED | `matmul / mm / mm_bt / bmm / mv_t / dot_t` delegate to `crate::grad_fns::linalg::*_differentiable`, `t` delegates to `crate::grad_fns::shape::transpose_2d`, and `einsum` delegates to `crate::einsum::einsum_differentiable`; consumers across `ferrotorch-distributions`, `ferrotorch-optim`, `ferrotorch-diffusion`. |
16//! | REQ-7 (`lu_factor`) | NOT-STARTED | delegation exists; no non-test caller. Blocker #1220. |
17//! | REQ-8 (reshape / shape methods) | SHIPPED | `reshape_t / flatten_t / squeeze_t / unsqueeze_t / permute / transpose` delegate to `crate::grad_fns::shape::*`; consumers pervasively across `ferrotorch-nn`, `ferrotorch-diffusion`, `ferrotorch-vision`, `flex_attention.rs`, `einops.rs`. |
18//! | REQ-9 (view / contiguous / narrow) | SHIPPED | `view / contiguous / narrow` delegate to free functions `view_t / contiguous_t / narrow_t` with `ContiguousBackward` / `NarrowBackward`; consumers pervasive (attention, blocks, distributions, vision, nn). |
19//! | REQ-10 (chunk / split) | SHIPPED | `chunk / split` delegate to `chunk_t / split_t` (GPU `strided_split_f32` + CPU fallback, `SplitBackward` autograd); consumers in `ferrotorch-diffusion` (`vae_encoder.rs`, `attention.rs`). |
20//! | REQ-11 (`size` / `dim` aliases) | NOT-STARTED | aliases compile; all in-tree callers are inside `#[cfg(test)]` modules. Blocker #1222. |
21//! | REQ-12 (`print` utility) | NOT-STARTED | emits a `tracing::info!` event; the only invocation is the in-file test. Blocker #1223. |
22//! | REQ-13 (cumulative methods) | SHIPPED | `cumsum_t / cumprod_t / logcumsumexp_t` delegate to `crate::grad_fns::cumulative::*`; these methods themselves close the R-DEFER-1 consumer requirement for the previously vocabulary-only `lib.rs` re-exports of the three ops; parity `[cumsum] 32/32 / [cumprod] 80/80 / [logcumsumexp] 48/48 passed`; closes #1232. `cummax_t / cummin_t` intentionally excluded (the tuple-form callers in `einops.rs` are the existing consumer, and the underlying ops remain NOT-STARTED behind #1231). |
23
24use crate::dtype::Float;
25use crate::error::FerrotorchResult;
26use crate::storage::TensorStorage;
27use crate::tensor::Tensor;
28
29impl<T: Float> Tensor<T> {
30 // --- Arithmetic ---
31
32 pub fn add_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
33 crate::grad_fns::arithmetic::add(self, other)
34 }
35
36 pub fn sub_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
37 crate::grad_fns::arithmetic::sub(self, other)
38 }
39
40 /// `torch.Tensor.rsub(other, *, alpha=1)` — reverse subtract:
41 /// `self - alpha * other` is the `sub_t` semantic; rsub is the
42 /// operand-swapped variant returning `other - alpha * self`.
43 ///
44 /// Per upstream `aten/src/ATen/native/BinaryOps.cpp:1169 Tensor rsub(
45 /// const Tensor& self, const Tensor& other, const Scalar& alpha) {
46 /// return at::sub(other, self, alpha); }` — a literal operand-swap
47 /// delegation. The non-test production consumer wiring for
48 /// `arithmetic::rsub` per R-DEFER-1: this method is the public,
49 /// chainable surface that closes the consumer requirement.
50 pub fn rsub_t(&self, other: &Tensor<T>, alpha: f64) -> FerrotorchResult<Tensor<T>> {
51 crate::grad_fns::arithmetic::rsub(self, other, alpha)
52 }
53
54 pub fn mul_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
55 crate::grad_fns::arithmetic::mul(self, other)
56 }
57
58 pub fn div_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
59 crate::grad_fns::arithmetic::div(self, other)
60 }
61
62 pub fn neg_t(&self) -> FerrotorchResult<Tensor<T>> {
63 crate::grad_fns::arithmetic::neg(self)
64 }
65
66 pub fn pow_t(&self, exponent: f64) -> FerrotorchResult<Tensor<T>> {
67 crate::grad_fns::arithmetic::pow(self, exponent)
68 }
69
70 pub fn sqrt_t(&self) -> FerrotorchResult<Tensor<T>> {
71 crate::grad_fns::arithmetic::sqrt(self)
72 }
73
74 /// `torch.Tensor.rsqrt()` — reciprocal square root: `1 / sqrt(self)`.
75 ///
76 /// Mirrors `torch.rsqrt(input, *, out=None)` per `torch/_torch_docs.py:9656`
77 /// and the upstream impl macro at
78 /// `aten/src/ATen/native/UnaryOps.cpp:346
79 /// CREATE_UNARY_TORCH_IMPL_FUNC(rsqrt_out, rsqrt_stub)`. The non-test
80 /// production consumer wiring for `arithmetic::rsqrt` per R-DEFER-1:
81 /// this method is the public, chainable surface that closes the
82 /// consumer requirement.
83 pub fn rsqrt_t(&self) -> FerrotorchResult<Tensor<T>> {
84 crate::grad_fns::arithmetic::rsqrt(self)
85 }
86
87 /// `torch.Tensor.reciprocal()` — elementwise reciprocal: `1 / self`.
88 ///
89 /// Mirrors `torch.reciprocal(input, *, out=None)` per
90 /// `torch/_torch_docs.py:2584` and the upstream impl macro at
91 /// `aten/src/ATen/native/UnaryOps.cpp:345
92 /// CREATE_UNARY_TORCH_IMPL_FUNC(reciprocal_out, reciprocal_stub)`. The
93 /// non-test production consumer wiring for `arithmetic::reciprocal` per
94 /// R-DEFER-1: this method is the public, chainable surface that closes
95 /// the consumer requirement.
96 pub fn reciprocal_t(&self) -> FerrotorchResult<Tensor<T>> {
97 crate::grad_fns::arithmetic::reciprocal(self)
98 }
99
100 pub fn abs_t(&self) -> FerrotorchResult<Tensor<T>> {
101 crate::grad_fns::arithmetic::abs(self)
102 }
103
104 /// `torch.Tensor.remainder(other)` — elementwise remainder with the
105 /// **sign of the divisor** (Python `%` / NumPy semantics).
106 ///
107 /// Mirrors `torch.remainder(input, other, *, out=None)` per
108 /// `torch/_torch_docs.py:9453-9472` and the upstream C++ entry at
109 /// `aten/src/ATen/native/BinaryOps.cpp:1184 Tensor remainder(const
110 /// Tensor& self, const Scalar& other)`. The float-tensor CPU
111 /// implementation is at `aten/src/ATen/native/cpu/BinaryOpsKernel.cpp:
112 /// 391-409 remainder_kernel`. Registration at
113 /// `torch/overrides.py:1100 torch.remainder: lambda input, other,
114 /// out=None: -1`.
115 ///
116 /// Distinct from `fmod_t` (dividend-sign / C99 semantics, REQ-14 NOT-
117 /// STARTED): for `remainder(-5, 3)` ferrotorch returns `1` (sign
118 /// matches divisor `+3`); `fmod(-5, 3)` returns `-2` (sign matches
119 /// dividend `-5`).
120 ///
121 /// The non-test production consumer wiring for `arithmetic::remainder`
122 /// per R-DEFER-1: this method is the public, chainable surface that
123 /// closes the consumer requirement.
124 pub fn remainder_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
125 crate::grad_fns::arithmetic::remainder(self, other)
126 }
127
128 /// `torch.fmod(input, other, *, out=None)` — elementwise remainder
129 /// with the sign of the **dividend** (C99 `std::fmod` semantics).
130 ///
131 /// Mirrors `torch.Tensor.fmod` via the same upstream registration
132 /// `torch/overrides.py:666 torch.fmod: lambda input, other, out=None: -1`.
133 ///
134 /// Distinct from `remainder_t` (divisor-sign, REQ-13 SHIPPED): for
135 /// `fmod(-5, 3)` ferrotorch returns `-2` (sign matches dividend
136 /// `-5`); `remainder(-5, 3)` returns `1` (sign matches divisor
137 /// `+3`). See `arithmetic::fmod` docs for the per-quadrant table.
138 ///
139 /// The non-test production consumer wiring for `arithmetic::fmod`
140 /// per R-DEFER-1: this method is the public, chainable surface that
141 /// closes the consumer requirement.
142 pub fn fmod_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
143 crate::grad_fns::arithmetic::fmod(self, other)
144 }
145
146 /// `torch.Tensor.floor_divide(other)` — elementwise floor division
147 /// (true floor, toward `-infinity`).
148 ///
149 /// Mirrors `torch.floor_divide(input, other, *, out=None)` per
150 /// `torch/_torch_docs.py:4265-4296`:
151 ///
152 /// > Computes :attr:`input` divided by :attr:`other`, elementwise, and
153 /// > floors the result.
154 /// >
155 /// > .. math::
156 /// > out_i = floor(input_i / other_i)
157 ///
158 /// Upstream entry at `aten/src/ATen/native/BinaryOps.cpp:979 Tensor
159 /// floor_divide(const Tensor& self, const Tensor& other)` dispatching
160 /// to `div_floor_stub` -> `div_floor_kernel` at
161 /// `aten/src/ATen/native/cpu/BinaryOpsKernel.cpp:297-349` ->
162 /// `c10::div_floor_floating` at `c10/util/generic_math.h:34-58`.
163 /// Registration at `torch/overrides.py:664 torch.floor_divide: lambda
164 /// input, other: -1`.
165 ///
166 /// `torch.floor_divide` was historically broken (performed trunc, NOT
167 /// floor) and `torch/_torch_docs.py:4267-4271` explicitly notes:
168 ///
169 /// > .. note::
170 /// > Before PyTorch 1.13 :func:`torch.floor_divide` incorrectly
171 /// > performed truncation division. To restore the previous
172 /// > behavior use :func:`torch.div` with ``rounding_mode='trunc'``.
173 ///
174 /// As of PyTorch 1.13+ (and as of the upstream pin this ferrotorch is
175 /// translated against), `torch.floor_divide` performs TRUE FLOOR.
176 /// Verified live on 2026-05-25:
177 /// `torch.floor_divide(-7.0, 3.0).item() == -3.0`.
178 ///
179 /// Distinct from `remainder_t` and `fmod_t`. The 3-way identity
180 /// `a == floor_divide(a,b) * b + remainder(a,b)` holds; the
181 /// `fmod` sibling is the trunc-division remainder. For `a=-7, b=3`:
182 /// - `floor_divide(-7, 3) = -3` (true floor)
183 /// - `remainder(-7, 3) = 2` (sign of divisor)
184 /// - `fmod(-7, 3) = -1` (sign of dividend / trunc remainder)
185 ///
186 /// Backward: `torch.floor_divide` has no derivative — verified live
187 /// `grad_fn=<NotImplemented object>` raises `derivative for
188 /// aten::floor_divide is not implemented`. `FloorDivideBackward`
189 /// mirrors that by erroring on `.backward()`.
190 ///
191 /// The non-test production consumer wiring for
192 /// `arithmetic::floor_divide` per R-DEFER-1: this method is the
193 /// public, chainable surface that closes the consumer requirement.
194 pub fn floor_divide_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
195 crate::grad_fns::arithmetic::floor_divide(self, other)
196 }
197
198 /// `torch.Tensor.addcmul(tensor1, tensor2, *, value=1)` — fused
199 /// `self + value * tensor1 * tensor2` (receiver is `input`).
200 ///
201 /// Mirrors `torch.addcmul(input, tensor1, tensor2, *, value=1, out=None)`
202 /// per `torch/_torch_docs.py:510-544`:
203 ///
204 /// > Performs the element-wise multiplication of :attr:`tensor1` by
205 /// > :attr:`tensor2`, multiplies the result by the scalar :attr:`value`
206 /// > and adds it to :attr:`input`.
207 /// >
208 /// > .. math::
209 /// > \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i
210 ///
211 /// Upstream C++ entry at `aten/src/ATen/native/PointwiseOps.cpp:57-64
212 /// TORCH_IMPL_FUNC(addcmul_out)`. Registration at
213 /// `torch/overrides.py:462 torch.addcmul: lambda input, tensor1, tensor2,
214 /// value=1, out=None: -1`.
215 ///
216 /// Broadcasting: the 3 input tensors (`self`, `tensor1`, `tensor2`) are
217 /// jointly broadcast to a common output shape. Backward: per
218 /// `tools/autograd/derivatives.yaml`, `d_input = grad`, `d_tensor1 =
219 /// grad * value * tensor2`, `d_tensor2 = grad * value * tensor1` (no
220 /// gradient with respect to the scalar `value`).
221 ///
222 /// The non-test production consumer wiring for `arithmetic::addcmul`
223 /// per R-DEFER-1: this method is the public, chainable surface that
224 /// closes the consumer requirement.
225 pub fn addcmul_t(
226 &self,
227 tensor1: &Tensor<T>,
228 tensor2: &Tensor<T>,
229 value: f64,
230 ) -> FerrotorchResult<Tensor<T>> {
231 crate::grad_fns::arithmetic::addcmul(self, tensor1, tensor2, value)
232 }
233
234 /// `torch.Tensor.addcdiv(tensor1, tensor2, *, value=1)` — fused
235 /// `self + value * tensor1 / tensor2` (receiver is `input`).
236 ///
237 /// Mirrors `torch.addcdiv(input, tensor1, tensor2, *, value=1, out=None)`
238 /// per `torch/_torch_docs.py:461-473`:
239 ///
240 /// > Performs the element-wise division of :attr:`tensor1` by
241 /// > :attr:`tensor2`, multiplies the result by the scalar :attr:`value`
242 /// > and adds it to :attr:`input`.
243 /// >
244 /// > .. math::
245 /// > \text{out}_i = \text{input}_i + \text{value} \times
246 /// > \frac{\text{tensor1}_i}{\text{tensor2}_i}
247 ///
248 /// Upstream C++ entry at `aten/src/ATen/native/PointwiseOps.cpp:66-73
249 /// TORCH_IMPL_FUNC(addcdiv_out)`. The integer-dtype deprecation block at
250 /// `PointwiseOps.cpp:38-50 TORCH_META_FUNC(addcdiv)` is unreachable for
251 /// the `Tensor<T: Float>` family.
252 ///
253 /// Broadcasting: the 3 input tensors (`self`, `tensor1`, `tensor2`) are
254 /// jointly broadcast to a common output shape. Backward: per
255 /// `tools/autograd/derivatives.yaml`, `d_input = grad`, `d_tensor1 =
256 /// grad * value / tensor2`, `d_tensor2 = -grad * value * tensor1 /
257 /// (tensor2 * tensor2)` (no gradient with respect to the scalar
258 /// `value`). At `tensor2=0` the d_tensor2 path produces NaN / ±Inf via
259 /// IEEE-754 — matches upstream (R-DEV-1).
260 ///
261 /// The non-test production consumer wiring for `arithmetic::addcdiv`
262 /// per R-DEFER-1: this method is the public, chainable surface that
263 /// closes the consumer requirement.
264 pub fn addcdiv_t(
265 &self,
266 tensor1: &Tensor<T>,
267 tensor2: &Tensor<T>,
268 value: f64,
269 ) -> FerrotorchResult<Tensor<T>> {
270 crate::grad_fns::arithmetic::addcdiv(self, tensor1, tensor2, value)
271 }
272
273 // --- Cumulative (scan) ---
274
275 /// `torch.Tensor.cumsum(dim)` — cumulative sum along `dim`.
276 ///
277 /// Mirrors `torch.cumsum(input, dim, *, dtype=None, out=None)` per
278 /// `torch/_torch_docs.py:3429 cumsum(input, dim, *, dtype=None,
279 /// out=None) -> Tensor` and the `torch.Tensor` method docstring at
280 /// `torch/_tensor_docs.py:1500-1506 add_docstr_all("cumsum", r"""
281 /// cumsum(dim, dtype=None) -> Tensor [...] See :func:`torch.cumsum``.
282 /// Upstream C++ entry at `aten/src/ATen/native/ReduceOps.cpp:511
283 /// TORCH_IMPL_FUNC(cumsum_out)` dispatching `cumsum_stub`. Autograd
284 /// VJP per `tools/autograd/derivatives.yaml:529-531 (name: cumsum(
285 /// Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor; self:
286 /// cumsum_backward(grad.to(self.scalar_type()), dim))` which is the
287 /// `reverse_cumsum` (flip → cumsum → flip) upper-triangular
288 /// multiplication at `ReduceOps.cpp:527-529 static Tensor
289 /// reversed_cumsum(const Tensor& w, int64_t dim)`.
290 ///
291 /// ferrotorch does NOT accept the `dtype` kwarg (the dtype-promotion
292 /// branch at `ReduceOps.cpp:267` is unreachable for the `Tensor<T:
293 /// Float>` family — see `.design/ferrotorch-core/grad_fns/
294 /// cumulative.md` REQ-1).
295 ///
296 /// The non-test production consumer wiring for
297 /// `grad_fns::cumulative::cumsum` per R-DEFER-1: this method is the
298 /// public, chainable surface that closes the consumer requirement
299 /// (blocker #1232).
300 pub fn cumsum_t(&self, dim: i64) -> FerrotorchResult<Tensor<T>> {
301 crate::grad_fns::cumulative::cumsum(self, dim)
302 }
303
304 /// `torch.Tensor.cumprod(dim)` — cumulative product along `dim`.
305 ///
306 /// Mirrors `torch.cumprod(input, dim, *, dtype=None, out=None)` per
307 /// `torch/_torch_docs.py:3390 cumprod(input, dim, *, dtype=None,
308 /// out=None) -> Tensor` and the `torch.Tensor` method docstring at
309 /// `torch/_tensor_docs.py:1482-1488 add_docstr_all("cumprod", r"""
310 /// cumprod(dim, dtype=None) -> Tensor [...] See :func:`torch.cumprod`.
311 /// Upstream C++ entry at `aten/src/ATen/native/ReduceOps.cpp:519
312 /// TORCH_IMPL_FUNC(cumprod_out)`. Autograd VJP per
313 /// `tools/autograd/derivatives.yaml:525-527 (name: cumprod(Tensor
314 /// self, int dim, *, ScalarType? dtype=None) -> Tensor; self:
315 /// cumprod_backward(grad.to(self.scalar_type()), self, dim, result))`
316 /// routing through `cumprod_backward` at `ReduceOps.cpp:531-790`
317 /// with the zeros-aware reverse-cumsum-divide algorithm.
318 ///
319 /// ferrotorch does NOT accept the `dtype` kwarg; the zeros-present
320 /// path uses an O(n^3) brute-force backward rather than upstream's
321 /// composite-compliance masked-fill (numerically identical, slower,
322 /// not second-order-differentiable — see
323 /// `.design/ferrotorch-core/grad_fns/cumulative.md` REQ-2).
324 ///
325 /// The non-test production consumer wiring for
326 /// `grad_fns::cumulative::cumprod` per R-DEFER-1: this method is the
327 /// public, chainable surface that closes the consumer requirement
328 /// (blocker #1232).
329 pub fn cumprod_t(&self, dim: i64) -> FerrotorchResult<Tensor<T>> {
330 crate::grad_fns::cumulative::cumprod(self, dim)
331 }
332
333 /// `torch.Tensor.logcumsumexp(dim)` — numerically stable
334 /// `log(cumsum(exp(self)))` along `dim`.
335 ///
336 /// Mirrors `torch.logcumsumexp(input, dim, *, out=None)` per
337 /// `torch/_torch_docs.py:3298 logcumsumexp(input, dim, *, out=None)
338 /// -> Tensor` and the `torch.Tensor` method docstring at
339 /// `torch/_tensor_docs.py:1455-1462 add_docstr_all("logcumsumexp",
340 /// r""" logcumsumexp(dim) -> Tensor [...] See
341 /// :func:`torch.logcumsumexp``. Upstream C++ entry at
342 /// `aten/src/ATen/native/ReduceOps.cpp:475 Tensor logcumsumexp(const
343 /// Tensor& self, int64_t dim)` dispatching `_logcumsumexp_cpu` at
344 /// `:465-468` → `logcumsumexp_stub` at `:471`. Autograd VJP per
345 /// `tools/autograd/derivatives.yaml:521-523 (name: logcumsumexp(
346 /// Tensor self, int dim) -> Tensor; self: logcumsumexp_backward(grad,
347 /// self, result, dim))` factors as `grad_input[i] = exp(input[i]) *
348 /// reverse_cumsum(grad_output * exp(-output))` (softmax-weighted
349 /// reverse cumsum).
350 ///
351 /// The numerical-stability invariant (large inputs ~1000.0 stay
352 /// finite) is preserved by the two-pass max-rescaling forward
353 /// algorithm at `ops/cumulative.rs:378-410`. See
354 /// `.design/ferrotorch-core/grad_fns/cumulative.md` REQ-5.
355 ///
356 /// The non-test production consumer wiring for
357 /// `grad_fns::cumulative::logcumsumexp` per R-DEFER-1: this method
358 /// is the public, chainable surface that closes the consumer
359 /// requirement (blocker #1232).
360 pub fn logcumsumexp_t(&self, dim: i64) -> FerrotorchResult<Tensor<T>> {
361 crate::grad_fns::cumulative::logcumsumexp(self, dim)
362 }
363
364 // --- Transcendental ---
365
366 pub fn exp_t(&self) -> FerrotorchResult<Tensor<T>> {
367 crate::grad_fns::transcendental::exp(self)
368 }
369
370 pub fn log_t(&self) -> FerrotorchResult<Tensor<T>> {
371 crate::grad_fns::transcendental::log(self)
372 }
373
374 pub fn sin_t(&self) -> FerrotorchResult<Tensor<T>> {
375 crate::grad_fns::transcendental::sin(self)
376 }
377
378 pub fn cos_t(&self) -> FerrotorchResult<Tensor<T>> {
379 crate::grad_fns::transcendental::cos(self)
380 }
381
382 pub fn clamp_t(&self, min: T, max: T) -> FerrotorchResult<Tensor<T>> {
383 crate::grad_fns::transcendental::clamp(self, min, max)
384 }
385
386 /// `clip` is a literal alias of `clamp` per upstream
387 /// `aten/src/ATen/native/TensorCompare.cpp:918-930 Tensor clip(...)`
388 /// (pass-through to `at::clamp(self, min, max)`).
389 pub fn clip_t(&self, min: T, max: T) -> FerrotorchResult<Tensor<T>> {
390 crate::grad_fns::transcendental::clamp(self, min, max)
391 }
392
393 // --- Transcendental: extended unary family
394 // (closes #1303 #1305 #1307 #1309 #1311 #1313 #1315 #1316 #1317 #1319
395 // #1320 #1322 #1323 #1324 #1325 #1326 #1327 #1328 #1329 #1330 #1331
396 // #1333 — impl + non-test consumer in same commit per S5 / R-DEFER-1)
397
398 pub fn tan_t(&self) -> FerrotorchResult<Tensor<T>> {
399 crate::grad_fns::transcendental::tan(self)
400 }
401
402 pub fn asin_t(&self) -> FerrotorchResult<Tensor<T>> {
403 crate::grad_fns::transcendental::asin(self)
404 }
405
406 pub fn acos_t(&self) -> FerrotorchResult<Tensor<T>> {
407 crate::grad_fns::transcendental::acos(self)
408 }
409
410 pub fn atan_t(&self) -> FerrotorchResult<Tensor<T>> {
411 crate::grad_fns::transcendental::atan(self)
412 }
413
414 pub fn sinh_t(&self) -> FerrotorchResult<Tensor<T>> {
415 crate::grad_fns::transcendental::sinh(self)
416 }
417
418 pub fn cosh_t(&self) -> FerrotorchResult<Tensor<T>> {
419 crate::grad_fns::transcendental::cosh(self)
420 }
421
422 pub fn asinh_t(&self) -> FerrotorchResult<Tensor<T>> {
423 crate::grad_fns::transcendental::asinh(self)
424 }
425
426 pub fn acosh_t(&self) -> FerrotorchResult<Tensor<T>> {
427 crate::grad_fns::transcendental::acosh(self)
428 }
429
430 pub fn atanh_t(&self) -> FerrotorchResult<Tensor<T>> {
431 crate::grad_fns::transcendental::atanh(self)
432 }
433
434 pub fn exp2_t(&self) -> FerrotorchResult<Tensor<T>> {
435 crate::grad_fns::transcendental::exp2(self)
436 }
437
438 pub fn expm1_t(&self) -> FerrotorchResult<Tensor<T>> {
439 crate::grad_fns::transcendental::expm1(self)
440 }
441
442 pub fn log2_t(&self) -> FerrotorchResult<Tensor<T>> {
443 crate::grad_fns::transcendental::log2(self)
444 }
445
446 pub fn log10_t(&self) -> FerrotorchResult<Tensor<T>> {
447 crate::grad_fns::transcendental::log10(self)
448 }
449
450 pub fn log1p_t(&self) -> FerrotorchResult<Tensor<T>> {
451 crate::grad_fns::transcendental::log1p(self)
452 }
453
454 pub fn ceil_t(&self) -> FerrotorchResult<Tensor<T>> {
455 crate::grad_fns::transcendental::ceil(self)
456 }
457
458 pub fn floor_t(&self) -> FerrotorchResult<Tensor<T>> {
459 crate::grad_fns::transcendental::floor(self)
460 }
461
462 pub fn round_t(&self) -> FerrotorchResult<Tensor<T>> {
463 crate::grad_fns::transcendental::round(self)
464 }
465
466 pub fn trunc_t(&self) -> FerrotorchResult<Tensor<T>> {
467 crate::grad_fns::transcendental::trunc(self)
468 }
469
470 pub fn frac_t(&self) -> FerrotorchResult<Tensor<T>> {
471 crate::grad_fns::transcendental::frac(self)
472 }
473
474 pub fn sign_t(&self) -> FerrotorchResult<Tensor<T>> {
475 crate::grad_fns::transcendental::sign(self)
476 }
477
478 pub fn sinc_t(&self) -> FerrotorchResult<Tensor<T>> {
479 crate::grad_fns::transcendental::sinc(self)
480 }
481
482 // --- Activation ---
483
484 pub fn relu(&self) -> FerrotorchResult<Tensor<T>> {
485 crate::grad_fns::activation::relu(self)
486 }
487
488 pub fn sigmoid(&self) -> FerrotorchResult<Tensor<T>> {
489 crate::grad_fns::activation::sigmoid(self)
490 }
491
492 pub fn tanh_t(&self) -> FerrotorchResult<Tensor<T>> {
493 crate::grad_fns::activation::tanh(self)
494 }
495
496 pub fn gelu(&self) -> FerrotorchResult<Tensor<T>> {
497 crate::grad_fns::activation::gelu(self)
498 }
499
500 pub fn gelu_with(
501 &self,
502 approximate: crate::grad_fns::activation::GeluApproximate,
503 ) -> FerrotorchResult<Tensor<T>> {
504 crate::grad_fns::activation::gelu_with(self, approximate)
505 }
506
507 pub fn silu(&self) -> FerrotorchResult<Tensor<T>> {
508 crate::grad_fns::activation::silu(self)
509 }
510
511 pub fn softmax(&self) -> FerrotorchResult<Tensor<T>> {
512 crate::grad_fns::activation::softmax(self)
513 }
514
515 pub fn log_softmax(&self) -> FerrotorchResult<Tensor<T>> {
516 crate::grad_fns::activation::log_softmax(self)
517 }
518
519 /// `torch.Tensor.threshold(threshold, value)` — replace each element below
520 /// (or equal to) `threshold` with `value`, leave the rest unchanged.
521 ///
522 /// Mirrors `torch.nn.functional.threshold(input, threshold, value)` per
523 /// `torch/nn/functional.py:1682-1700` and
524 /// `TORCH_IMPL_FUNC(threshold_out)` at
525 /// `aten/src/ATen/native/Activation.cpp:688-690`. The non-test production
526 /// consumer wiring for `grad_fns::activation::threshold` per R-DEFER-1:
527 /// this method is the public, chainable surface that closes the
528 /// consumer requirement (closes #1341 REQ-19).
529 pub fn threshold_t(&self, threshold: f64, value: f64) -> FerrotorchResult<Tensor<T>> {
530 crate::grad_fns::activation::threshold(self, threshold, value)
531 }
532
533 /// `torch.Tensor.rrelu(lower, upper, training)` — randomized leaky ReLU.
534 ///
535 /// Mirrors `torch.nn.functional.rrelu(input, lower, upper, training,
536 /// inplace)` per `torch/nn/functional.py:1962-1989` and
537 /// `Tensor& rrelu_with_noise_out_cpu(...)` at
538 /// `aten/src/ATen/native/Activation.cpp:611-654`. The non-test production
539 /// consumer wiring for `grad_fns::activation::rrelu` per R-DEFER-1:
540 /// this method is the public, chainable surface that closes the
541 /// consumer requirement (closes #1341 REQ-20).
542 ///
543 /// Note: `training=true` falls back to the deterministic mean-slope
544 /// inference path (per the GradFn docs at `activation.rs`). The
545 /// RNG-stateful training-mode VJP is a separately-tracked follow-up.
546 pub fn rrelu_t(&self, lower: f64, upper: f64, training: bool) -> FerrotorchResult<Tensor<T>> {
547 crate::grad_fns::activation::rrelu(self, lower, upper, training)
548 }
549
550 /// `torch.Tensor.celu(alpha)` —
551 /// `celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`.
552 ///
553 /// Mirrors `torch.nn.functional.celu(input, alpha=1.0)` per
554 /// `torch/nn/functional.py:1874-1894` and
555 /// `Tensor celu(const Tensor& self, const Scalar& alpha)` at
556 /// `aten/src/ATen/native/Activation.cpp:540-545`. The non-test production
557 /// consumer wiring for `grad_fns::activation::celu` per R-DEFER-1:
558 /// this method is the public, chainable surface that closes the
559 /// consumer requirement (closes #1341 REQ-21).
560 pub fn celu_t(&self, alpha: f64) -> FerrotorchResult<Tensor<T>> {
561 crate::grad_fns::activation::celu(self, alpha)
562 }
563
564 /// `torch.Tensor.softmin()` — `softmin(x) = softmax(-x)` along the last
565 /// axis (fused single-`GradFn` variant).
566 ///
567 /// Mirrors `torch.nn.functional.softmin(input, dim=None, dtype=None)` per
568 /// `torch/nn/functional.py:2095-2125`. The non-test production consumer
569 /// wiring for `grad_fns::activation::softmin` per R-DEFER-1: this method
570 /// is the public, chainable surface that closes the consumer requirement
571 /// (closes #1341 REQ-22). The composition-route variant
572 /// (`ferrotorch_nn::functional::softmin` = neg -> softmax, two GradFn
573 /// nodes) remains available; this method routes through the fused VJP.
574 pub fn softmin_t(&self) -> FerrotorchResult<Tensor<T>> {
575 crate::grad_fns::activation::softmin(self)
576 }
577
578 // --- Reduction ---
579
580 pub fn sum_all(&self) -> FerrotorchResult<Tensor<T>> {
581 crate::grad_fns::reduction::sum(self)
582 }
583
584 pub fn mean_all(&self) -> FerrotorchResult<Tensor<T>> {
585 crate::grad_fns::reduction::mean(self)
586 }
587
588 pub fn prod_all(&self) -> FerrotorchResult<Tensor<T>> {
589 crate::grad_fns::reduction::prod(self)
590 }
591
592 /// Global minimum across all elements. Mirrors `torch.amin(self)` with
593 /// no `dim` argument. Returns a 0-d tensor. On CUDA f32/f64, dispatches
594 /// to the native PTX reduce_min kernel; on CPU walks the buffer. (#627)
595 pub fn amin(&self) -> FerrotorchResult<Tensor<T>> {
596 crate::grad_fns::reduction::amin(self)
597 }
598
599 /// Global maximum across all elements. Mirrors `torch.amax(self)`. (#627)
600 pub fn amax(&self) -> FerrotorchResult<Tensor<T>> {
601 crate::grad_fns::reduction::amax(self)
602 }
603
604 /// LU factorization in cuSOLVER's packed form: returns
605 /// `(LU_packed, pivots)`. Mirrors `torch.linalg.lu_factor`. On CUDA
606 /// f32/f64, runs natively via cuSOLVER `getrf` with no host bounce
607 /// for the matrix; pivots come back as a host `Vec<i32>` (O(n)). (#604)
608 pub fn lu_factor(&self) -> FerrotorchResult<(Tensor<T>, Vec<i32>)> {
609 crate::linalg::lu_factor(self)
610 }
611
612 // --- Linalg ---
613
614 pub fn matmul(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
615 crate::grad_fns::linalg::matmul_differentiable(self, other)
616 }
617
618 pub fn mm(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
619 crate::grad_fns::linalg::mm_differentiable(self, other)
620 }
621
622 /// Fused A @ B^T — avoids materializing the transpose of B.
623 /// A: [M, K], B: [N, K] -> [M, N].
624 pub fn mm_bt(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
625 crate::grad_fns::linalg::mm_bt_differentiable(self, other)
626 }
627
628 pub fn bmm(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
629 crate::grad_fns::linalg::bmm_differentiable(self, other)
630 }
631
632 pub fn mv_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
633 crate::grad_fns::linalg::mv_differentiable(self, other)
634 }
635
636 pub fn dot_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
637 crate::grad_fns::linalg::dot_differentiable(self, other)
638 }
639
640 pub fn t(&self) -> FerrotorchResult<Tensor<T>> {
641 crate::grad_fns::shape::transpose_2d(self)
642 }
643
644 /// Einstein summation with this tensor as the first operand.
645 ///
646 /// `others` contains the remaining input tensors (if any). The equation
647 /// must include subscripts for `self` followed by the `others`.
648 ///
649 /// ```ignore
650 /// // Matrix multiply: self @ other
651 /// let c = a.einsum("ij,jk->ik", &[&b])?;
652 ///
653 /// // Trace of self
654 /// let t = a.einsum("ii->", &[])?;
655 /// ```
656 pub fn einsum(&self, equation: &str, others: &[&Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
657 let mut inputs: Vec<&Tensor<T>> = vec![self];
658 inputs.extend_from_slice(others);
659 crate::einsum::einsum_differentiable(equation, &inputs)
660 }
661
662 // --- Reduction (dim) ---
663
664 pub fn sum_dim(&self, dim: i64, keepdim: bool) -> FerrotorchResult<Tensor<T>> {
665 crate::grad_fns::reduction::sum_dim(self, dim, keepdim)
666 }
667
668 pub fn mean_dim(&self, dim: i64, keepdim: bool) -> FerrotorchResult<Tensor<T>> {
669 crate::grad_fns::reduction::mean_dim(self, dim, keepdim)
670 }
671
672 /// Differentiable full-reduction logsumexp. Mirrors
673 /// `torch.logsumexp(self)` — numerically stable `log(sum(exp(self)))`
674 /// to a 0-D scalar. Backward `grad * exp(self - result)`. Closes #1310.
675 pub fn logsumexp_t(&self) -> FerrotorchResult<Tensor<T>> {
676 crate::grad_fns::reduction::logsumexp(self)
677 }
678
679 /// Differentiable dim-keyed logsumexp. Mirrors
680 /// `torch.logsumexp(self, dim, keepdim)`.
681 pub fn logsumexp_dim_t(&self, dim: i64, keepdim: bool) -> FerrotorchResult<Tensor<T>> {
682 crate::grad_fns::reduction::logsumexp_dim(self, dim, keepdim)
683 }
684
685 /// Non-differentiable global argmax. Mirrors `torch.argmax(self)`.
686 /// Returns a 0-D IntTensor<i64> with the flat index of the largest
687 /// element. Closes #1304 (argmax).
688 pub fn argmax_t(&self) -> FerrotorchResult<crate::int_tensor::IntTensor<i64>> {
689 crate::grad_fns::reduction::argmax(self)
690 }
691
692 /// Non-differentiable dim-keyed argmax.
693 pub fn argmax_dim_t(
694 &self,
695 dim: i64,
696 keepdim: bool,
697 ) -> FerrotorchResult<crate::int_tensor::IntTensor<i64>> {
698 crate::grad_fns::reduction::argmax_dim(self, dim, keepdim)
699 }
700
701 /// Non-differentiable global argmin. Mirrors `torch.argmin(self)`.
702 pub fn argmin_t(&self) -> FerrotorchResult<crate::int_tensor::IntTensor<i64>> {
703 crate::grad_fns::reduction::argmin(self)
704 }
705
706 /// Non-differentiable dim-keyed argmin.
707 pub fn argmin_dim_t(
708 &self,
709 dim: i64,
710 keepdim: bool,
711 ) -> FerrotorchResult<crate::int_tensor::IntTensor<i64>> {
712 crate::grad_fns::reduction::argmin_dim(self, dim, keepdim)
713 }
714
715 /// Differentiable full-reduction variance with optional Bessel
716 /// correction. `unbiased=true` divides by `n-1`; false divides by
717 /// `n`. Closes #1301 (var).
718 pub fn var_t(&self, unbiased: bool) -> FerrotorchResult<Tensor<T>> {
719 crate::grad_fns::reduction::var(self, unbiased)
720 }
721
722 /// Differentiable full-reduction standard deviation. Closes #1301
723 /// (std).
724 pub fn std_t(&self, unbiased: bool) -> FerrotorchResult<Tensor<T>> {
725 crate::grad_fns::reduction::std(self, unbiased)
726 }
727
728 /// Differentiable full-reduction variance with arbitrary Bessel
729 /// correction. Mirrors `torch.var(input, correction=...)` —
730 /// `denom = max(0, n - correction)`. Closes #1346 (audit 7cef63f88
731 /// REQ-8 full-reduction correction-API gap).
732 pub fn var_with_correction_t(&self, correction: f64) -> FerrotorchResult<Tensor<T>> {
733 crate::grad_fns::reduction::var_with_correction(self, correction)
734 }
735
736 /// Differentiable full-reduction standard deviation with arbitrary
737 /// `correction`. Mirrors `torch.std(input, correction=...)`. Closes
738 /// #1346 (audit 7cef63f88 REQ-8 full-reduction correction-API gap).
739 pub fn std_with_correction_t(&self, correction: f64) -> FerrotorchResult<Tensor<T>> {
740 crate::grad_fns::reduction::std_with_correction(self, correction)
741 }
742
743 /// Non-differentiable full-reduction `any`. Returns a 0-D BoolTensor
744 /// holding `true` iff any element is non-zero. Closes #1312 (any).
745 pub fn any_t(&self) -> FerrotorchResult<crate::bool_tensor::BoolTensor> {
746 crate::grad_fns::reduction::any(self)
747 }
748
749 /// Non-differentiable full-reduction `all`. Closes #1312 (all).
750 pub fn all_t(&self) -> FerrotorchResult<crate::bool_tensor::BoolTensor> {
751 crate::grad_fns::reduction::all(self)
752 }
753
754 /// Non-differentiable full-reduction `count_nonzero`. Returns a 0-D
755 /// IntTensor<i64> with the count of non-zero elements. Closes #1312
756 /// (count_nonzero).
757 pub fn count_nonzero_t(&self) -> FerrotorchResult<crate::int_tensor::IntTensor<i64>> {
758 crate::grad_fns::reduction::count_nonzero(self)
759 }
760
761 // --- Shape ---
762
763 pub fn reshape_t(&self, shape: &[isize]) -> FerrotorchResult<Tensor<T>> {
764 crate::grad_fns::shape::reshape(self, shape)
765 }
766
767 pub fn flatten_t(&self) -> FerrotorchResult<Tensor<T>> {
768 crate::grad_fns::shape::flatten(self)
769 }
770
771 pub fn squeeze_t(&self, axis: isize) -> FerrotorchResult<Tensor<T>> {
772 crate::grad_fns::shape::squeeze(self, axis)
773 }
774
775 pub fn unsqueeze_t(&self, axis: isize) -> FerrotorchResult<Tensor<T>> {
776 crate::grad_fns::shape::unsqueeze(self, axis)
777 }
778
779 /// Permute tensor dimensions. Like PyTorch's `tensor.permute(dims)`.
780 ///
781 /// Zero-copy: returns a view with permuted shape and strides.
782 /// `dims` must be a valid permutation of `0..ndim`.
783 pub fn permute(&self, dims: &[usize]) -> FerrotorchResult<Tensor<T>> {
784 permute_t(self, dims)
785 }
786
787 /// Swap two dimensions. Like PyTorch's `tensor.transpose(dim0, dim1)`.
788 ///
789 /// Zero-copy: returns a view with swapped strides.
790 pub fn transpose(&self, dim0: usize, dim1: usize) -> FerrotorchResult<Tensor<T>> {
791 let ndim = self.ndim();
792 if dim0 >= ndim || dim1 >= ndim {
793 return Err(crate::error::FerrotorchError::InvalidArgument {
794 message: format!("transpose: dims ({dim0}, {dim1}) out of bounds for ndim {ndim}"),
795 });
796 }
797 if dim0 == dim1 {
798 return Ok(self.clone());
799 }
800 let mut perm: Vec<usize> = (0..ndim).collect();
801 perm.swap(dim0, dim1);
802 permute_t(self, &perm)
803 }
804
805 /// Swap two axes. Like PyTorch's `tensor.swapaxes(axis0, axis1)` — a
806 /// literal alias of `transpose` per upstream
807 /// `aten/src/ATen/native/TensorShape.cpp:4776`.
808 pub fn swapaxes(&self, axis0: usize, axis1: usize) -> FerrotorchResult<Tensor<T>> {
809 crate::grad_fns::shape::swapaxes(self, axis0, axis1)
810 }
811
812 /// Swap two dims. Like PyTorch's `tensor.swapdims(dim0, dim1)` — a literal
813 /// alias of `transpose` per upstream
814 /// `aten/src/ATen/native/TensorShape.cpp:4784`.
815 pub fn swapdims(&self, dim0: usize, dim1: usize) -> FerrotorchResult<Tensor<T>> {
816 crate::grad_fns::shape::swapdims(self, dim0, dim1)
817 }
818
819 /// Reshape a single dimension `dim` into multiple `sizes`. Like PyTorch's
820 /// `tensor.unflatten(dim, sizes)` per upstream
821 /// `aten/src/ATen/native/TensorShape.cpp:4350`. At most one `-1`
822 /// inference slot is allowed in `sizes`.
823 pub fn unflatten_t(&self, dim: isize, sizes: &[isize]) -> FerrotorchResult<Tensor<T>> {
824 crate::grad_fns::shape::unflatten(self, dim, sizes)
825 }
826
827 /// Broadcast this tensor to the shape of `other`. Like PyTorch's
828 /// `tensor.expand_as(other)` per upstream
829 /// `aten/src/ATen/native/TensorShape.cpp:1374`.
830 pub fn expand_as_t(&self, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
831 crate::grad_fns::shape::expand_as(self, other)
832 }
833
834 /// Reverse element order along each axis in `dims`. Like PyTorch's
835 /// `torch.flip(input, dims)` per upstream
836 /// `aten/src/ATen/native/TensorTransformations.cpp:36`.
837 pub fn flip_t(&self, dims: &[isize]) -> FerrotorchResult<Tensor<T>> {
838 crate::grad_fns::shape::flip(self, dims)
839 }
840
841 /// Flip left-to-right (along dim 1). Like PyTorch's `torch.fliplr` per
842 /// upstream `aten/src/ATen/native/TensorTransformations.cpp:180`.
843 pub fn fliplr_t(&self) -> FerrotorchResult<Tensor<T>> {
844 crate::grad_fns::shape::fliplr(self)
845 }
846
847 /// Flip up-to-down (along dim 0). Like PyTorch's `torch.flipud` per
848 /// upstream `aten/src/ATen/native/TensorTransformations.cpp:186`.
849 pub fn flipud_t(&self) -> FerrotorchResult<Tensor<T>> {
850 crate::grad_fns::shape::flipud(self)
851 }
852
853 /// Rotate 90° `k` times in the plane spanned by `dims`. Like PyTorch's
854 /// `torch.rot90(input, k, dims)` per upstream
855 /// `aten/src/ATen/native/TensorTransformations.cpp:134`.
856 pub fn rot90_t(&self, k: i64, dims: &[isize]) -> FerrotorchResult<Tensor<T>> {
857 crate::grad_fns::shape::rot90(self, k, dims)
858 }
859
860 /// Reposition dims from `source` to `destination`. Like PyTorch's
861 /// `torch.movedim(input, source, destination)` per upstream
862 /// `aten/src/ATen/native/TensorShape.cpp:4657`.
863 pub fn movedim_t(
864 &self,
865 source: &[isize],
866 destination: &[isize],
867 ) -> FerrotorchResult<Tensor<T>> {
868 crate::grad_fns::shape::movedim(self, source, destination)
869 }
870
871 /// Reposition axes from `source` to `destination`. Like PyTorch's
872 /// `torch.moveaxis` (an alias of `movedim`) per upstream
873 /// `aten/src/ATen/native/TensorShape.cpp:4768`.
874 pub fn moveaxis_t(
875 &self,
876 source: &[isize],
877 destination: &[isize],
878 ) -> FerrotorchResult<Tensor<T>> {
879 crate::grad_fns::shape::moveaxis(self, source, destination)
880 }
881
882 /// Broadcast this tensor to `shape`. Like PyTorch's
883 /// `torch.broadcast_to(input, shape)` (an alias of `expand`) per upstream
884 /// `aten/src/ATen/native/TensorShape.cpp:652`.
885 pub fn broadcast_to_t(&self, shape: &[usize]) -> FerrotorchResult<Tensor<T>> {
886 crate::grad_fns::shape::broadcast_to(self, shape)
887 }
888
889 /// Tile this tensor `repeats[i]` times along each axis. Like PyTorch's
890 /// `tensor.repeat(*repeats)` per upstream
891 /// `aten/src/ATen/native/TensorShape.cpp:1909`.
892 pub fn repeat_t(&self, repeats: &[isize]) -> FerrotorchResult<Tensor<T>> {
893 crate::grad_fns::shape::repeat(self, repeats)
894 }
895
896 /// NumPy-style tile. Like PyTorch's `torch.tile(input, reps)` per upstream
897 /// `aten/src/ATen/native/TensorShape.cpp:1971`.
898 pub fn tile_t(&self, reps: &[isize]) -> FerrotorchResult<Tensor<T>> {
899 crate::grad_fns::shape::tile(self, reps)
900 }
901
902 /// Repeat each element `repeats` times consecutively along `dim`. Like
903 /// PyTorch's `torch.repeat_interleave(input, repeats, dim)`.
904 pub fn repeat_interleave_t(&self, repeats: usize, dim: isize) -> FerrotorchResult<Tensor<T>> {
905 crate::grad_fns::shape::repeat_interleave(self, repeats, dim)
906 }
907
908 /// Split into `size(dim)` slices with `dim` removed. Like PyTorch's
909 /// `torch.unbind(input, dim)` per upstream
910 /// `aten/src/ATen/native/TensorShape.cpp:4367`.
911 pub fn unbind_t(&self, dim: isize) -> FerrotorchResult<Vec<Tensor<T>>> {
912 crate::grad_fns::shape::unbind(self, dim)
913 }
914
915 /// Split at the integer section boundaries `indices` along `dim`. Like
916 /// PyTorch's `torch.tensor_split(input, indices, dim)` per upstream
917 /// `aten/src/ATen/native/TensorShape.cpp:1167`.
918 pub fn tensor_split_t(
919 &self,
920 indices: &[usize],
921 dim: isize,
922 ) -> FerrotorchResult<Vec<Tensor<T>>> {
923 crate::grad_fns::shape::tensor_split(self, indices, dim)
924 }
925
926 /// Stack tensors row-wise (≥2-D then `cat` dim 0). Like PyTorch's
927 /// `torch.vstack` per upstream
928 /// `aten/src/ATen/native/TensorShape.cpp:3532`.
929 pub fn vstack_t(tensors: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
930 crate::grad_fns::shape::vstack(tensors)
931 }
932
933 /// Stack tensors column-wise. Like PyTorch's `torch.hstack` per upstream
934 /// `aten/src/ATen/native/TensorShape.cpp:3514`.
935 pub fn hstack_t(tensors: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
936 crate::grad_fns::shape::hstack(tensors)
937 }
938
939 /// Stack tensors depth-wise (≥3-D then `cat` dim 2). Like PyTorch's
940 /// `torch.dstack` per upstream
941 /// `aten/src/ATen/native/TensorShape.cpp:3544`.
942 pub fn dstack_t(tensors: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
943 crate::grad_fns::shape::dstack(tensors)
944 }
945
946 /// Stack ≤1-D tensors as columns of a 2-D matrix. Like PyTorch's
947 /// `torch.column_stack` per upstream
948 /// `aten/src/ATen/native/TensorShape.cpp:3628`.
949 pub fn column_stack_t(tensors: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
950 crate::grad_fns::shape::column_stack(tensors)
951 }
952
953 /// Return a narrowed view along `dim` starting at `start` with `length`
954 /// elements. Like PyTorch's `tensor.narrow(dim, start, length)`.
955 ///
956 /// Zero-copy: shares storage with the original tensor.
957 pub fn narrow(&self, dim: usize, start: usize, length: usize) -> FerrotorchResult<Tensor<T>> {
958 narrow_t(self, dim, start, length)
959 }
960
961 /// View tensor with new shape. Like PyTorch's `tensor.view(shape)`.
962 ///
963 /// Exactly one dimension may be `-1`, in which case it is inferred.
964 /// Requires the tensor to be contiguous.
965 pub fn view(&self, shape: &[i64]) -> FerrotorchResult<Tensor<T>> {
966 view_t(self, shape)
967 }
968
969 /// Make tensor contiguous — if already contiguous, returns a cheap clone.
970 /// Otherwise materializes a new contiguous buffer.
971 pub fn contiguous(&self) -> FerrotorchResult<Tensor<T>> {
972 contiguous_t(self)
973 }
974
975 /// Split tensor into `chunks` roughly equal pieces along `dim`.
976 pub fn chunk(&self, chunks: usize, dim: usize) -> FerrotorchResult<Vec<Tensor<T>>> {
977 chunk_t(self, chunks, dim)
978 }
979
980 /// Split tensor into pieces of given sizes along `dim`.
981 pub fn split(&self, split_sizes: &[usize], dim: usize) -> FerrotorchResult<Vec<Tensor<T>>> {
982 split_t(self, split_sizes, dim)
983 }
984
985 // --- Quantization ---
986
987 /// `torch.Tensor.fake_quantize_per_tensor_affine(scale, zero_point,
988 /// quant_min, quant_max)` — per-tensor affine fake quantization with
989 /// autograd-tracked clipped STE backward.
990 ///
991 /// Mirrors `torch.fake_quantize_per_tensor_affine` per
992 /// `torch/overrides.py:622 torch.fake_quantize_per_tensor_affine: lambda
993 /// input, scale, zero_point, quant_min, quant_max: -1` and the upstream
994 /// implementation at `aten/src/ATen/native/quantized/
995 /// FakeQuantPerTensorAffine.cpp:31-40 Tensor fake_quantize_per_tensor_affine(
996 /// const Tensor& self, double scale, int64_t zero_point, int64_t quant_min,
997 /// int64_t quant_max)`. Backward per `tools/autograd/derivatives.yaml:673-674
998 /// fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)` returning
999 /// `dY * mask` where the mask is `1` iff
1000 /// `quant_min <= round_ties_even(input/scale) + zero_point <= quant_max`.
1001 ///
1002 /// The non-test production consumer wiring for
1003 /// `grad_fns::quantize_grad::fake_quantize_per_tensor_affine` per
1004 /// R-DEFER-1: this method is the public, chainable surface that closes
1005 /// the consumer requirement for the per-tensor variant (blocker #1238).
1006 pub fn fake_quantize_per_tensor_affine_t(
1007 &self,
1008 scale: f64,
1009 zero_point: i64,
1010 quant_min: i64,
1011 quant_max: i64,
1012 ) -> FerrotorchResult<Tensor<T>> {
1013 crate::grad_fns::quantize_grad::fake_quantize_per_tensor_affine(
1014 self, scale, zero_point, quant_min, quant_max,
1015 )
1016 }
1017
1018 /// `torch.Tensor.fake_quantize_per_channel_affine(scale, zero_point, axis,
1019 /// quant_min, quant_max)` — per-channel affine fake quantization with
1020 /// autograd-tracked clipped STE backward.
1021 ///
1022 /// Mirrors `torch.fake_quantize_per_channel_affine` per
1023 /// `torch/overrides.py:621 torch.fake_quantize_per_channel_affine: lambda
1024 /// input, scale, zero_point, axis, quant_min, quant_max: -1` and the
1025 /// upstream implementation at `aten/src/ATen/native/quantized/
1026 /// FakeQuantPerChannelAffine.cpp:32-42 Tensor fake_quantize_per_channel_affine(
1027 /// const Tensor& self, const Tensor& scale, const Tensor& zero_point,
1028 /// int64_t axis, int64_t quant_min, int64_t quant_max)`. Backward per
1029 /// `tools/autograd/derivatives.yaml fake_quantize_per_channel_affine_cachemask_backward(
1030 /// grad, mask)` returning `dY * mask` where the per-channel mask is `1`
1031 /// iff `quant_min <= round_ties_even(input/scale[c]) + zero_point[c]
1032 /// <= quant_max` for the channel `c` along `axis`.
1033 ///
1034 /// The non-test production consumer wiring for
1035 /// `grad_fns::quantize_grad::fake_quantize_per_channel_affine` per
1036 /// R-DEFER-1: this method is the public, chainable surface that closes
1037 /// the consumer requirement for the per-channel variant (blocker #1239).
1038 pub fn fake_quantize_per_channel_affine_t(
1039 &self,
1040 scale: &Tensor<T>,
1041 zero_point: &crate::int_tensor::IntTensor<i64>,
1042 axis: i64,
1043 quant_min: i64,
1044 quant_max: i64,
1045 ) -> FerrotorchResult<Tensor<T>> {
1046 crate::grad_fns::quantize_grad::fake_quantize_per_channel_affine(
1047 self, scale, zero_point, axis, quant_min, quant_max,
1048 )
1049 }
1050
1051 // --- Indexing (REQ-8 from `.design/ferrotorch-core/grad_fns/indexing.md`) ---
1052
1053 /// `torch.Tensor.index_fill(dim, index, value)` — overwrite slices along
1054 /// `dim` at `index` positions with the scalar `value`.
1055 ///
1056 /// Mirrors `torch.index_fill(input, dim, index, value)` per the upstream
1057 /// docstring at `torch/_torch_docs.py:6563-6567 index_fill(dim, index,
1058 /// value) -> Tensor [...] Out-of-place version of :meth:`torch.Tensor.
1059 /// index_fill_`` and `torch/_tensor_docs.py:2489-2509` which gives the
1060 /// canonical example
1061 ///
1062 /// ```text
1063 /// >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
1064 /// >>> index = torch.tensor([0, 2])
1065 /// >>> x.index_fill_(1, index, -1)
1066 /// tensor([[-1., 2., -1.],
1067 /// [-1., 5., -1.],
1068 /// [-1., 8., -1.]])
1069 /// ```
1070 ///
1071 /// Upstream C++ entry at `aten/src/ATen/native/TensorAdvancedIndexing.cpp:
1072 /// 1979 Tensor index_fill(const Tensor& self, int64_t dim, const Tensor&
1073 /// index, const Scalar& source) { return self.clone(at::MemoryFormat::
1074 /// Preserve).index_fill_(dim, index, source); }`. Registration at
1075 /// `torch/overrides.py:710 torch.index_fill: lambda input, dim, index,
1076 /// value: -1`.
1077 ///
1078 /// Backward per `tools/autograd/derivatives.yaml:884-887`:
1079 /// `- name: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor`
1080 /// / `self: grad.index_fill(dim, index, 0)` /
1081 /// `index: non_differentiable` /
1082 /// `result: self_t.index_fill(dim, index, 0)`
1083 /// — gradient is zeroed at every position the fill overwrote (those
1084 /// positions were replaced by a constant and no longer depend on the
1085 /// input).
1086 ///
1087 /// `dim` follows PyTorch's negative-wrapping convention (`at::maybe_wrap_dim`
1088 /// at `TensorAdvancedIndexing.cpp:1919`). The `index` tensor must be 1-D
1089 /// or scalar (upstream `TORCH_CHECK(index.dim() <= 1)` at `:1920`).
1090 /// Negative index values are accepted and wrapped per upstream's
1091 /// `index_fill_kernel` at `aten/src/ATen/native/cpu/IndexKernel.cpp:
1092 /// 224-229` (`TORCH_CHECK_INDEX(idx >= -self_dim_size && idx <
1093 /// self_dim_size, ...); if (idx < 0) { idx += self_dim_size; }`). Indices
1094 /// strictly outside `[-dim_size, dim_size)` raise `IndexOutOfBounds`
1095 /// matching upstream's `TORCH_CHECK_INDEX`. A 0-d input is accepted: the
1096 /// implementation mirrors upstream's `self.unsqueeze(-1)` at
1097 /// `TensorAdvancedIndexing.cpp:1917` by treating the scalar as a length-1
1098 /// 1-d tensor for the fill (only `dim ∈ {-1, 0}` and `index ∈ {-1, 0}`
1099 /// are in range for that case).
1100 ///
1101 /// The non-test production consumer wiring for `grad_fns::indexing::
1102 /// index_fill` per R-DEFER-1: this method is the public, chainable
1103 /// surface that closes the consumer requirement (blocker #1249).
1104 pub fn index_fill_t(
1105 &self,
1106 dim: i64,
1107 index: &crate::int_tensor::IntTensor<i64>,
1108 value: f64,
1109 ) -> FerrotorchResult<Tensor<T>> {
1110 crate::grad_fns::indexing::index_fill(self, dim, index, value)
1111 }
1112
1113 /// `torch.Tensor.scatter_reduce(dim, index, src, reduce, *, include_self=True)`
1114 /// — reduce-mode scatter onto a clone of `self`. Mirrors upstream
1115 /// `Tensor scatter_reduce(...)` at `aten/src/ATen/native/
1116 /// TensorAdvancedIndexing.cpp:2354 TORCH_IMPL_FUNC(scatter_reduce_two)`.
1117 /// `reduce` ∈ {`"sum"` SHIPPED, `"prod"`, `"amax"`, `"amin"`}; backward
1118 /// is implemented only for `"sum"` per `tools/autograd/derivatives.yaml:
1119 /// 3074-3077` (other modes return a no-grad tensor — the
1120 /// op_db characterization sweep emits only `"sum"`).
1121 ///
1122 /// Non-test production consumer wiring for `grad_fns::indexing::
1123 /// scatter_reduce` per R-DEFER-1: this method is the chainable surface.
1124 /// Closes blocker #1245.
1125 pub fn scatter_reduce_t(
1126 &self,
1127 dim: i64,
1128 index: &[usize],
1129 index_shape: &[usize],
1130 src: &Tensor<T>,
1131 reduce: &str,
1132 include_self: bool,
1133 ) -> FerrotorchResult<Tensor<T>> {
1134 let mode =
1135 crate::grad_fns::indexing::ScatterReduce::parse_str(reduce).ok_or_else(|| {
1136 crate::error::FerrotorchError::InvalidArgument {
1137 message: format!(
1138 "scatter_reduce_t: unknown reduce mode '{reduce}' \
1139 (expected sum|prod|amax|amin)"
1140 ),
1141 }
1142 })?;
1143 crate::grad_fns::indexing::scatter_reduce(
1144 self,
1145 dim,
1146 index,
1147 index_shape,
1148 src,
1149 mode,
1150 include_self,
1151 )
1152 }
1153
1154 /// `torch.Tensor.index_add(dim, index, source, *, alpha=1)` —
1155 /// `out = self.clone(); out[..., index[i], ...] += alpha * source[..., i, ...]`
1156 /// along `dim`. Mirrors upstream `Tensor index_add(const Tensor& self,
1157 /// int64_t dim, const Tensor& index, const Tensor& source, const Scalar&
1158 /// alpha)` at `aten/src/ATen/native/TensorAdvancedIndexing.cpp:1153
1159 /// TORCH_IMPL_FUNC(index_add_cpu_out)`. Backward per
1160 /// `tools/autograd/derivatives.yaml:862-869 self: grad / source:
1161 /// maybe_multiply(grad.index_select(dim, index).expand_as(source), alpha)`.
1162 ///
1163 /// Non-test production consumer wiring for `grad_fns::indexing::
1164 /// index_add` per R-DEFER-1: this method is the chainable surface.
1165 /// Closes blocker #1247.
1166 pub fn index_add_t(
1167 &self,
1168 dim: i64,
1169 index: &crate::int_tensor::IntTensor<i64>,
1170 source: &Tensor<T>,
1171 alpha: f64,
1172 ) -> FerrotorchResult<Tensor<T>> {
1173 crate::grad_fns::indexing::index_add(self, dim, index, source, alpha)
1174 }
1175
1176 /// `torch.Tensor.index_copy(dim, index, source)` — `out = self.clone();
1177 /// out[..., index[i], ...] = source[..., i, ...]` along `dim`. Mirrors
1178 /// upstream `Tensor index_copy(...)` at `aten/src/ATen/native/
1179 /// TensorAdvancedIndexing.cpp:1082 TORCH_IMPL_FUNC(index_copy_out)`.
1180 /// Backward per `tools/autograd/derivatives.yaml:875-883
1181 /// self: grad.index_fill(dim, index, 0) / source:
1182 /// grad.index_select(dim, index).expand_as(source)`.
1183 ///
1184 /// Non-test production consumer wiring for `grad_fns::indexing::
1185 /// index_copy` per R-DEFER-1: this method is the chainable surface.
1186 /// Closes blocker #1248.
1187 pub fn index_copy_t(
1188 &self,
1189 dim: i64,
1190 index: &crate::int_tensor::IntTensor<i64>,
1191 source: &Tensor<T>,
1192 ) -> FerrotorchResult<Tensor<T>> {
1193 crate::grad_fns::indexing::index_copy(self, dim, index, source)
1194 }
1195
1196 /// `torch.Tensor.masked_scatter(mask, source)` — copy elements from
1197 /// `source` into a clone of `self` at positions where `mask` is true,
1198 /// in C-order. Mirrors upstream `Tensor masked_scatter(const Tensor&
1199 /// self, const Tensor& mask, const Tensor& source)` at
1200 /// `aten/src/ATen/native/TensorAdvancedIndexing.cpp:2402-2409`.
1201 /// Backward per `tools/autograd/derivatives.yaml:1105-1108
1202 /// self: grad.masked_fill(mask, 0) / source: masked_scatter_backward(...)`.
1203 ///
1204 /// Non-test production consumer wiring for `grad_fns::indexing::
1205 /// masked_scatter` per R-DEFER-1: this method is the chainable surface.
1206 /// Closes blocker #1252.
1207 pub fn masked_scatter_t(
1208 &self,
1209 mask: &crate::bool_tensor::BoolTensor,
1210 source: &Tensor<T>,
1211 ) -> FerrotorchResult<Tensor<T>> {
1212 crate::grad_fns::indexing::masked_scatter(self, mask, source)
1213 }
1214
1215 /// `torch.Tensor.take(index)` — `out[i] = self.view(-1)[index[i]]`, a
1216 /// flat-index gather producing a tensor of shape `index.shape()`.
1217 /// Mirrors upstream `Tensor take(const Tensor& self, const Tensor& index)`
1218 /// at `aten/src/ATen/native/TensorAdvancedIndexing.cpp:1067-1071`.
1219 /// Backward per `tools/autograd/derivatives.yaml:1766-1769
1220 /// self: take_backward(grad, self, index)` — scatter-add grad into a
1221 /// zeros buffer at the flat index positions.
1222 ///
1223 /// Non-test production consumer wiring for `grad_fns::indexing::take`
1224 /// per R-DEFER-1: this method is the chainable surface.
1225 /// Closes blocker #1253.
1226 pub fn take_t(&self, index: &crate::int_tensor::IntTensor<i64>) -> FerrotorchResult<Tensor<T>> {
1227 crate::grad_fns::indexing::take(self, index)
1228 }
1229
1230 /// `torch.Tensor.put(index, source, accumulate=False)` — flat-index
1231 /// scatter into a clone of `self`: `out.view(-1)[index[i]] = source[i]`
1232 /// (or `+= source[i]` when `accumulate=true`). Mirrors upstream
1233 /// `Tensor put(const Tensor& self, const Tensor& index, const Tensor&
1234 /// source, const bool accumulate)` at `aten/src/ATen/native/
1235 /// TensorAdvancedIndexing.cpp:928-934`. Backward per
1236 /// `tools/autograd/derivatives.yaml:1421-1424`.
1237 ///
1238 /// Non-test production consumer wiring for `grad_fns::indexing::put`
1239 /// per R-DEFER-1: this method is the chainable surface.
1240 /// Closes blocker #1254.
1241 pub fn put_t(
1242 &self,
1243 index: &crate::int_tensor::IntTensor<i64>,
1244 source: &Tensor<T>,
1245 accumulate: bool,
1246 ) -> FerrotorchResult<Tensor<T>> {
1247 crate::grad_fns::indexing::put(self, index, source, accumulate)
1248 }
1249
1250 /// `torch.where(condition, self, other)` — pointwise ternary selection
1251 /// taking a host `&[bool]` mask. Returns a tensor where each element is
1252 /// `self[i]` if `condition[i]` is true, else `other[i]`. Differentiable
1253 /// — a `WhereBackward` node is attached when grad tracking is enabled
1254 /// on either input.
1255 ///
1256 /// Mirrors `torch.where(condition, input, other)` per
1257 /// `torch/_torch_docs.py:13089` and the upstream impl macro at
1258 /// `aten/src/ATen/native/TensorCompare.cpp:646
1259 /// TORCH_IMPL_FUNC(where_out)` — the `self`-vs-other dispatch shape.
1260 ///
1261 /// Non-test production consumer wiring for
1262 /// `grad_fns::comparison::where_` per R-DEFER-1 (closes blocker #1295):
1263 /// this method is the public, chainable surface that closes the
1264 /// consumer requirement. The boolean-tensor variant is `where_bt_t`.
1265 pub fn where_t(&self, condition: &[bool], other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1266 crate::grad_fns::comparison::where_(condition, self, other)
1267 }
1268
1269 /// `torch.where(condition, self, other)` — `BoolTensor` overload.
1270 ///
1271 /// Pointwise ternary selection where `condition` is a first-class
1272 /// [`BoolTensor`](crate::bool_tensor::BoolTensor). The condition must
1273 /// match `self.numel()` and `self.shape() == other.shape()`. Delegates
1274 /// to `grad_fns::comparison::where_bt` which validates shape +
1275 /// materialises the host mask and dispatches to `where_` for
1276 /// autograd-aware forward.
1277 ///
1278 /// Mirrors `torch.where(cond, x, y)` for `cond: BoolTensor` per
1279 /// `torch/_torch_docs.py:13089`.
1280 ///
1281 /// Non-test production consumer wiring for
1282 /// `grad_fns::comparison::where_bt` per R-DEFER-1 (closes blocker
1283 /// #1297): this method is the public, chainable surface that closes
1284 /// the consumer requirement.
1285 pub fn where_bt_t(
1286 &self,
1287 condition: &crate::bool_tensor::BoolTensor,
1288 other: &Tensor<T>,
1289 ) -> FerrotorchResult<Tensor<T>> {
1290 crate::grad_fns::comparison::where_bt(condition, self, other)
1291 }
1292
1293 /// `torch.Tensor.scatter_(dim, index, value)` (scalar-src overload) —
1294 /// scatter a single scalar `value` into a clone of `self` at the
1295 /// positions named by `index` along `dim`. Mirrors the upstream
1296 /// scalar overload `Tensor& scatter_(int64_t dim, const Tensor& index,
1297 /// const Scalar& value)` at
1298 /// `aten/src/ATen/native/TensorAdvancedIndexing.cpp:2278` —
1299 /// the `scatter.value` dispatch arm that op_db emits as a distinct
1300 /// sample family alongside the tensor-src overload.
1301 ///
1302 /// Equivalent to `self.scatter_(dim, index, full_like(index, value))`
1303 /// but avoids the temporary `src` allocation. No autograd is attached
1304 /// because the scalar `value` is not a differentiable input.
1305 ///
1306 /// Non-test production consumer wiring for
1307 /// `crate::ops::indexing::scatter_value` per R-DEFER-1 (closes blocker
1308 /// #1258): this method is the public, chainable surface that closes
1309 /// the consumer requirement.
1310 pub fn scatter_value_t(
1311 &self,
1312 dim: i64,
1313 index: &[usize],
1314 index_shape: &[usize],
1315 value: T,
1316 ) -> FerrotorchResult<Tensor<T>> {
1317 crate::ops::indexing::scatter_value(self, dim as isize, index, index_shape, value)
1318 }
1319
1320 // --- PyTorch compatibility aliases ---
1321
    /// Alias for `shape()`. Returns the tensor dimensions like PyTorch's `Tensor.size()`.
    ///
    /// The returned slice borrows from `self`, so nothing is copied.
    #[inline]
    pub fn size(&self) -> &[usize] {
        self.shape()
    }
1327
    /// Alias for `ndim()`. Returns the number of dimensions like PyTorch's `Tensor.dim()`.
    ///
    /// Not to be confused with the `dim` *parameters* of the reduction and
    /// indexing methods; this takes no arguments and only reports the rank.
    #[inline]
    pub fn dim(&self) -> usize {
        self.ndim()
    }
1333
1334 // --- Utility ---
1335
    /// Log the tensor's `Display` form and return `self` for chaining.
    ///
    /// Emits a `tracing::info!` event on target `ferrotorch::tensor`. Behaviour
    /// change vs. earlier versions: this no longer writes directly to stdout —
    /// callers must install a `tracing` subscriber (e.g. `tracing_subscriber`)
    /// to see the output. Library code should not write to stdout; downstream
    /// consumers control logging policy.
    pub fn print(&self) -> &Self {
        // Formatting goes through the tensor's `Display` impl via `{self}`.
        tracing::info!(target: "ferrotorch::tensor", "{self}");
        // Returning the borrow allows `t.print().sum_all()`-style chaining.
        self
    }
1347}
1348
1349// ---------------------------------------------------------------------------
// Free functions: permute, narrow, view, contiguous, chunk, split
1351// ---------------------------------------------------------------------------
1352
1353/// Permute tensor dimensions. Like PyTorch's `tensor.permute(dims)`.
1354///
1355/// `dims` must be a valid permutation of `0..ndim`.
1356pub fn permute_t<T: Float>(input: &Tensor<T>, dims: &[usize]) -> FerrotorchResult<Tensor<T>> {
1357 use crate::error::FerrotorchError;
1358
1359 let ndim = input.ndim();
1360 if dims.len() != ndim {
1361 return Err(FerrotorchError::InvalidArgument {
1362 message: format!(
1363 "permute: dims length {} does not match tensor ndim {}",
1364 dims.len(),
1365 ndim
1366 ),
1367 });
1368 }
1369
1370 // Validate that dims is a valid permutation.
1371 let mut seen = vec![false; ndim];
1372 for &d in dims {
1373 if d >= ndim {
1374 return Err(FerrotorchError::InvalidArgument {
1375 message: format!("permute: dim {d} is out of bounds for ndim {ndim}"),
1376 });
1377 }
1378 if seen[d] {
1379 return Err(FerrotorchError::InvalidArgument {
1380 message: format!("permute: duplicate dim {d} in permutation"),
1381 });
1382 }
1383 seen[d] = true;
1384 }
1385
1386 // Zero-copy: permute shape and strides without copying data.
1387 let in_shape = input.shape();
1388 let in_strides = input.strides();
1389 let out_shape: Vec<usize> = dims.iter().map(|&d| in_shape[d]).collect();
1390 let out_strides: Vec<isize> = dims.iter().map(|&d| in_strides[d]).collect();
1391 let offset = input.storage_offset();
1392
1393 if crate::autograd::no_grad::is_grad_enabled() && input.requires_grad() {
1394 let grad_fn = std::sync::Arc::new(PermuteBackward {
1395 input: input.clone(),
1396 dims: dims.to_vec(),
1397 });
1398 Ok(input.stride_view_operation(out_shape, out_strides, offset, grad_fn))
1399 } else {
1400 Ok(input.stride_view(out_shape, out_strides, offset))
1401 }
1402}
1403
/// Backward for permute: apply the inverse permutation to the gradient.
#[derive(Debug)]
struct PermuteBackward<T: Float> {
    /// The tensor that was permuted; its gradient is produced by `backward`.
    input: Tensor<T>,
    /// The forward permutation: output axis `i` was taken from input axis
    /// `dims[i]` (see how `permute_t` builds `out_shape`).
    dims: Vec<usize>,
}
1410
1411impl<T: Float> crate::tensor::GradFn<T> for PermuteBackward<T> {
1412 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1413 if !self.input.requires_grad() {
1414 return Ok(vec![None]);
1415 }
1416 // Compute inverse permutation.
1417 let mut inv_dims = vec![0usize; self.dims.len()];
1418 for (i, &d) in self.dims.iter().enumerate() {
1419 inv_dims[d] = i;
1420 }
1421 let grad_input = permute_t(grad_output, &inv_dims)?;
1422 Ok(vec![Some(grad_input)])
1423 }
1424
1425 fn inputs(&self) -> Vec<&Tensor<T>> {
1426 vec![&self.input]
1427 }
1428
1429 fn name(&self) -> &'static str {
1430 "PermuteBackward"
1431 }
1432}
1433
1434/// Zero-copy narrow (slice) along a dimension.
1435///
1436/// Returns a view with the same storage, adjusting offset and shape.
1437/// Like PyTorch's `tensor.narrow(dim, start, length)`.
1438pub fn narrow_t<T: Float>(
1439 input: &Tensor<T>,
1440 dim: usize,
1441 start: usize,
1442 length: usize,
1443) -> FerrotorchResult<Tensor<T>> {
1444 use crate::error::FerrotorchError;
1445
1446 let ndim = input.ndim();
1447 if dim >= ndim {
1448 return Err(FerrotorchError::InvalidArgument {
1449 message: format!("narrow: dim {dim} out of bounds for ndim {ndim}"),
1450 });
1451 }
1452 let dim_size = input.shape()[dim];
1453 if start + length > dim_size {
1454 return Err(FerrotorchError::InvalidArgument {
1455 message: format!(
1456 "narrow: start({}) + length({}) = {} exceeds dim size {}",
1457 start,
1458 length,
1459 start + length,
1460 dim_size,
1461 ),
1462 });
1463 }
1464
1465 let strides = input.strides();
1466 let mut new_shape = input.shape().to_vec();
1467 new_shape[dim] = length;
1468
1469 // Advance offset by start * stride[dim] elements.
1470 let new_offset = input.storage_offset() + start * strides[dim] as usize;
1471
1472 if crate::autograd::no_grad::is_grad_enabled() && input.requires_grad() {
1473 let grad_fn = std::sync::Arc::new(NarrowBackward {
1474 input: input.clone(),
1475 dim,
1476 start,
1477 });
1478 Ok(input.stride_view_operation(new_shape, strides.to_vec(), new_offset, grad_fn))
1479 } else {
1480 Ok(input.stride_view(new_shape, strides.to_vec(), new_offset))
1481 }
1482}
1483
1484/// Backward for narrow: pad the gradient with zeros in the sliced dimension.
1485#[derive(Debug)]
1486struct NarrowBackward<T: Float> {
1487 input: Tensor<T>,
1488 dim: usize,
1489 start: usize,
1490}
1491
1492impl<T: Float> crate::tensor::GradFn<T> for NarrowBackward<T> {
1493 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1494 if !self.input.requires_grad() {
1495 return Ok(vec![None]);
1496 }
1497 // Create a zero tensor matching the input shape and scatter the
1498 // gradient into the narrowed region.
1499 let mut grad_data = vec![<T as num_traits::Zero>::zero(); self.input.numel()];
1500 let grad_out_data = grad_output.data_vec()?;
1501 let in_shape = self.input.shape();
1502 let dim = self.dim;
1503 let start = self.start;
1504 let _length = grad_output.shape()[dim];
1505
1506 // Walk contiguous output elements and map to input flat indices.
1507 let out_strides = crate::shape::c_contiguous_strides(grad_output.shape());
1508 let in_strides = crate::shape::c_contiguous_strides(in_shape);
1509 let ndim = in_shape.len();
1510 let out_numel = grad_out_data.len();
1511
1512 for (flat, &grad_val) in grad_out_data[..out_numel].iter().enumerate() {
1513 // Decompose flat index to output coords.
1514 let mut rem = flat;
1515 let mut in_flat: usize = 0;
1516 for d in 0..ndim {
1517 let coord = rem / out_strides[d] as usize;
1518 rem %= out_strides[d] as usize;
1519 let in_coord = if d == dim { coord + start } else { coord };
1520 in_flat += in_coord * in_strides[d] as usize;
1521 }
1522 grad_data[in_flat] = grad_val;
1523 }
1524
1525 let device = self.input.device();
1526 let storage = crate::storage::TensorStorage::on_device(grad_data, device)?;
1527 let grad_input = Tensor::from_storage(storage, in_shape.to_vec(), false)?;
1528 Ok(vec![Some(grad_input)])
1529 }
1530
1531 fn inputs(&self) -> Vec<&Tensor<T>> {
1532 vec![&self.input]
1533 }
1534
1535 fn name(&self) -> &'static str {
1536 "NarrowBackward"
1537 }
1538}
1539
1540/// View tensor with new shape. Like PyTorch's `tensor.view(shape)`.
1541///
1542/// Exactly one dimension may be `-1`, in which case it is inferred.
1543/// Requires the tensor to be contiguous (currently all tensors are).
1544pub fn view_t<T: Float>(input: &Tensor<T>, shape: &[i64]) -> FerrotorchResult<Tensor<T>> {
1545 use crate::error::FerrotorchError;
1546
1547 if !input.is_contiguous() {
1548 return Err(FerrotorchError::InvalidArgument {
1549 message: "view: tensor must be contiguous; call .contiguous() first".into(),
1550 });
1551 }
1552
1553 // Convert i64 shape to isize for reshape (which handles -1 inference).
1554 let isize_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
1555 crate::grad_fns::shape::reshape(input, &isize_shape)
1556}
1557
1558/// Make tensor contiguous (copy data if needed).
1559///
1560/// If the tensor is already contiguous this returns a cheap clone.
1561/// Otherwise it gathers the data in C-order and creates a new
1562/// contiguous tensor, preserving the original device.
1563///
1564/// **GPU fast path (CL-496).** For non-contiguous CUDA tensors of rank
1565/// ≤ 8, this dispatches to the backend's `strided_copy_{f32,f64}`
1566/// kernel which gathers the view on-device and avoids the CPU
1567/// roundtrip that `data_vec()` would otherwise incur. Higher ranks
1568/// or missing GPU backends fall back to the host-memory path.
1569pub fn contiguous_t<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1570 use std::any::TypeId;
1571
1572 // A tensor whose strides look C-contiguous BUT carries a non-zero
1573 // `storage_offset` is NOT safe to fast-path-`clone()`: `is_contiguous()`
1574 // mirrors PyTorch (`c10::TensorImpl::is_contiguous`, computed from
1575 // sizes/strides ONLY — see `compute_contiguous`), so it returns `true`
1576 // for a row-narrowed view (`full.narrow(0, 1, 3)` keeps row-major strides
1577 // `[D, 1]`). The GPU consumers (`scatter_add_segments_cuda` et al.) read
1578 // `gpu_handle()` from the base buffer pointer (element 0) and DROP the
1579 // offset — silent wrong results (#1657). Materialising an offset != 0 view
1580 // through the strided-copy path below yields a fresh packed offset-0 buffer
1581 // whose `gpu_handle()` points at the logical element 0, fixing the whole
1582 // class at one site. `strided_copy_*` (and the CPU `data_vec()` fallback)
1583 // both honour `storage_offset`.
1584 if input.is_contiguous() && input.storage_offset() == 0 {
1585 return Ok(input.clone());
1586 }
1587 let device = input.device();
1588
1589 // GPU fast path: dispatch to the backend's strided_copy kernel
1590 // when the input is a non-contiguous CUDA tensor with rank ≤ 8.
1591 if device.is_cuda()
1592 && input.shape().len() <= 8
1593 && let Some(backend) = crate::gpu_dispatch::gpu_backend()
1594 {
1595 let in_handle = input.gpu_handle()?;
1596 let out_shape = input.shape().to_vec();
1597 let src_strides = input.strides().to_vec();
1598 let src_offset = input.storage_offset();
1599
1600 let out_handle = if TypeId::of::<T>() == TypeId::of::<f32>() {
1601 backend.strided_copy_f32(in_handle, &out_shape, &src_strides, src_offset)
1602 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
1603 backend.strided_copy_f64(in_handle, &out_shape, &src_strides, src_offset)
1604 } else {
1605 // Unsupported dtype — fall through to CPU path.
1606 return contiguous_t_cpu(input);
1607 };
1608
1609 if let Ok(handle) = out_handle {
1610 let storage = TensorStorage::gpu(handle);
1611 return if crate::autograd::no_grad::is_grad_enabled() && input.requires_grad() {
1612 let grad_fn = std::sync::Arc::new(ContiguousBackward {
1613 input: input.clone(),
1614 });
1615 Tensor::from_operation(storage, out_shape, grad_fn)
1616 } else {
1617 Tensor::from_storage(storage, out_shape, false)
1618 };
1619 }
1620 // Kernel failure (negative strides, overflow, etc.) —
1621 // fall through to the host path which handles any layout.
1622 }
1623
1624 contiguous_t_cpu(input)
1625}
1626
1627/// CPU path for [`contiguous_t`]. Always valid for any layout; used
1628/// as a fallback when the GPU fast path declines or errors.
1629fn contiguous_t_cpu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1630 let device = input.device();
1631 let data = input.data_vec()?;
1632 let storage = TensorStorage::on_device(data, device)?;
1633
1634 // Preserve the autograd graph: contiguous is a pure data copy, so the
1635 // backward is the identity (same shape, same semantics). Without this,
1636 // calling .contiguous() on a non-contiguous view severs the grad_fn chain.
1637 if crate::autograd::no_grad::is_grad_enabled() && input.requires_grad() {
1638 let grad_fn = std::sync::Arc::new(ContiguousBackward {
1639 input: input.clone(),
1640 });
1641 Tensor::from_operation(storage, input.shape().to_vec(), grad_fn)
1642 } else {
1643 Tensor::from_storage(storage, input.shape().to_vec(), false)
1644 }
1645}
1646
1647/// Backward for contiguous: gradient passes through unchanged (identity).
1648#[derive(Debug)]
1649struct ContiguousBackward<T: Float> {
1650 input: Tensor<T>,
1651}
1652
1653impl<T: Float> crate::tensor::GradFn<T> for ContiguousBackward<T> {
1654 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1655 if self.input.requires_grad() {
1656 Ok(vec![Some(grad_output.clone())])
1657 } else {
1658 Ok(vec![None])
1659 }
1660 }
1661
1662 fn inputs(&self) -> Vec<&Tensor<T>> {
1663 vec![&self.input]
1664 }
1665
1666 fn name(&self) -> &'static str {
1667 "ContiguousBackward"
1668 }
1669}
1670
1671/// Split tensor into `chunks` roughly equal pieces along `dim`.
1672///
1673/// If the tensor size along `dim` is not evenly divisible by `chunks`,
1674/// the last chunk will be smaller.
1675pub fn chunk_t<T: Float>(
1676 input: &Tensor<T>,
1677 chunks: usize,
1678 dim: usize,
1679) -> FerrotorchResult<Vec<Tensor<T>>> {
1680 use crate::error::FerrotorchError;
1681
1682 if chunks == 0 {
1683 return Err(FerrotorchError::InvalidArgument {
1684 message: "chunk: chunks must be > 0".into(),
1685 });
1686 }
1687
1688 let shape = input.shape();
1689 if dim >= shape.len() {
1690 return Err(FerrotorchError::InvalidArgument {
1691 message: format!(
1692 "chunk: dim {} is out of bounds for tensor with {} dimensions",
1693 dim,
1694 shape.len()
1695 ),
1696 });
1697 }
1698
1699 let dim_size = shape[dim];
1700 let chunk_size = dim_size.div_ceil(chunks);
1701 let mut split_sizes = Vec::new();
1702 let mut remaining = dim_size;
1703 while remaining > 0 {
1704 let s = chunk_size.min(remaining);
1705 split_sizes.push(s);
1706 remaining -= s;
1707 }
1708
1709 split_t(input, &split_sizes, dim)
1710}
1711
1712/// Split tensor into pieces of given sizes along `dim`.
1713///
1714/// The sum of `split_sizes` must equal the tensor's size along `dim`.
1715/// When gradient tracking is enabled and the input requires grad, each
1716/// output chunk is connected to the autograd graph via `SplitBackward`.
1717pub fn split_t<T: Float>(
1718 input: &Tensor<T>,
1719 split_sizes: &[usize],
1720 dim: usize,
1721) -> FerrotorchResult<Vec<Tensor<T>>> {
1722 use crate::autograd::no_grad::is_grad_enabled;
1723 use crate::error::FerrotorchError;
1724 use crate::grad_fns::shape::SplitBackward;
1725 use crate::storage::TensorStorage;
1726 use std::any::TypeId;
1727 use std::sync::Arc;
1728
1729 let shape = input.shape();
1730 let ndim = shape.len();
1731
1732 if dim >= ndim {
1733 return Err(FerrotorchError::InvalidArgument {
1734 message: format!("split: dim {dim} is out of bounds for tensor with {ndim} dimensions"),
1735 });
1736 }
1737
1738 let total: usize = split_sizes.iter().sum();
1739 if total != shape[dim] {
1740 return Err(FerrotorchError::InvalidArgument {
1741 message: format!(
1742 "split: split_sizes sum {} does not match dim {} size {}",
1743 total, dim, shape[dim]
1744 ),
1745 });
1746 }
1747
1748 let device = input.device();
1749 let needs_grad = is_grad_enabled() && input.requires_grad();
1750
1751 // GPU fast path: use strided_split to extract each chunk directly on GPU.
1752 if device.is_cuda()
1753 && TypeId::of::<T>() == TypeId::of::<f32>()
1754 && let Some(backend) = crate::gpu_dispatch::gpu_backend()
1755 {
1756 let inner: usize = if dim + 1 < ndim {
1757 shape[dim + 1..].iter().product()
1758 } else {
1759 1
1760 };
1761 let total_along_dim = shape[dim];
1762 let in_handle = input.gpu_handle()?;
1763
1764 let mut results = Vec::with_capacity(split_sizes.len());
1765 let mut offset_along_dim = 0usize;
1766
1767 for &split_size in split_sizes {
1768 let mut chunk_shape = shape.to_vec();
1769 chunk_shape[dim] = split_size;
1770 let chunk_numel: usize = chunk_shape.iter().product();
1771
1772 let chunk_handle = backend.strided_split_f32(
1773 in_handle,
1774 total_along_dim,
1775 offset_along_dim,
1776 split_size,
1777 inner,
1778 chunk_numel,
1779 )?;
1780
1781 let storage = TensorStorage::gpu(chunk_handle);
1782 let t = if needs_grad {
1783 let grad_fn = Arc::new(SplitBackward::new(
1784 input.clone(),
1785 dim,
1786 offset_along_dim,
1787 split_size,
1788 ));
1789 Tensor::from_operation(storage, chunk_shape, grad_fn)?
1790 } else {
1791 Tensor::from_storage(storage, chunk_shape, false)?
1792 };
1793 results.push(t);
1794 offset_along_dim += split_size;
1795 }
1796
1797 return Ok(results);
1798 }
1799
1800 // CPU path (also serves as fallback for non-f32 or missing backend).
1801 let in_data = input.data_vec()?;
1802
1803 let outer: usize = shape[..dim].iter().product();
1804 let inner: usize = if dim + 1 < ndim {
1805 shape[dim + 1..].iter().product()
1806 } else {
1807 1
1808 };
1809 let total_along_dim = shape[dim];
1810
1811 let mut results = Vec::with_capacity(split_sizes.len());
1812 let mut offset_along_dim = 0usize;
1813
1814 for &split_size in split_sizes {
1815 let mut chunk_shape = shape.to_vec();
1816 chunk_shape[dim] = split_size;
1817 let chunk_numel: usize = chunk_shape.iter().product();
1818 let mut chunk_data = vec![<T as num_traits::Zero>::zero(); chunk_numel];
1819
1820 for o in 0..outer {
1821 let src_start = o * total_along_dim * inner + offset_along_dim * inner;
1822 let dst_start = o * split_size * inner;
1823 let row_len = split_size * inner;
1824 chunk_data[dst_start..dst_start + row_len]
1825 .copy_from_slice(&in_data[src_start..src_start + row_len]);
1826 }
1827
1828 let storage = TensorStorage::on_device(chunk_data, device)?;
1829 let t = if needs_grad {
1830 let grad_fn = Arc::new(SplitBackward::new(
1831 input.clone(),
1832 dim,
1833 offset_along_dim,
1834 split_size,
1835 ));
1836 Tensor::from_operation(storage, chunk_shape, grad_fn)?
1837 } else {
1838 Tensor::from_storage(storage, chunk_shape, false)?
1839 };
1840 results.push(t);
1841 offset_along_dim += split_size;
1842 }
1843
1844 Ok(results)
1845}
1846
#[cfg(test)]
mod tests {
    // Unit tests for the method-style wrappers and for the view / split /
    // contiguous helpers defined in this module.
    use crate::*;

    #[test]
    // reason: relu is pure passthrough or hard-zero; both branches preserve
    // the exact bit pattern (no arithmetic), so equality is the right check.
    #[allow(clippy::float_cmp)]
    fn test_method_relu() {
        let a = scalar(2.0f32).unwrap();
        assert_eq!(a.relu().unwrap().item().unwrap(), 2.0);

        let b = scalar(-1.0f32).unwrap();
        assert_eq!(b.relu().unwrap().item().unwrap(), 0.0);
    }

    #[test]
    fn test_method_matmul() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let c = a.matmul(&b).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
    }

    #[test]
    // reason: sum of small integer-valued floats (1+2+3=6) is bit-exact in
    // any deterministic order — the partial sums never lose mantissa bits,
    // so equality is the right check.
    #[allow(clippy::float_cmp)]
    fn test_method_sum() {
        let a = tensor(&[1.0f32, 2.0, 3.0]).unwrap();
        let s = a.sum_all().unwrap();
        assert_eq!(s.item().unwrap(), 6.0);
    }

    #[test]
    fn test_method_transpose() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let b = a.t().unwrap();
        assert_eq!(b.shape(), &[3, 2]);
    }

    #[test]
    // reason: 3^2 = 9 in f32 is bit-exact (small integer power of small
    // integer), and relu of a positive integer is passthrough. The whole
    // chain produces exactly 9.0, so equality is the right check.
    #[allow(clippy::float_cmp)]
    fn test_method_chain() {
        let a = scalar(3.0f32).unwrap().requires_grad_(true);
        // a.pow(2).relu() = relu(9) = 9 (no reduction needed: `a` is a scalar)
        let c = a.pow_t(2.0).unwrap().relu().unwrap();
        assert_eq!(c.item().unwrap(), 9.0);
    }

    #[test]
    fn test_method_sigmoid() {
        let a = scalar(0.0f32).unwrap();
        let s = a.sigmoid().unwrap();
        assert!((s.item().unwrap() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_method_flatten() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let f = a.flatten_t().unwrap();
        assert_eq!(f.shape(), &[6]);
    }

    #[test]
    fn test_method_print_chain() {
        let a = scalar(42.0f32).unwrap();
        // .print() returns &Self for chaining
        let _ = a.print();
    }

    // --- sum_dim / mean_dim method wrappers ---

    #[test]
    fn test_method_sum_dim() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let s = a.sum_dim(1, false).unwrap();
        assert_eq!(s.shape(), &[2]);
        assert!((s.data().unwrap()[0] - 6.0).abs() < 1e-6);
        assert!((s.data().unwrap()[1] - 15.0).abs() < 1e-6);
    }

    #[test]
    fn test_method_mean_dim() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let m = a.mean_dim(0, false).unwrap();
        assert_eq!(m.shape(), &[3]);
        assert!((m.data().unwrap()[0] - 2.5).abs() < 1e-6);
    }

    // --- permute ---

    #[test]
    fn test_method_permute_2d() {
        // Transpose via permute — now zero-copy (stride view).
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let b = a.permute(&[1, 0]).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        // Non-contiguous view — use data_vec() to read logical order.
        assert_eq!(b.data_vec().unwrap(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        // Permuted strides are not C-contiguous, i.e. this is a strided view
        // rather than a packed copy. (Storage sharing itself is not asserted.)
        assert!(!b.is_contiguous());
    }

    #[test]
    // reason: permute is pure indexing — it rearranges values without any
    // arithmetic, so each output slot holds the exact bit pattern of the
    // corresponding input slot.
    #[allow(clippy::float_cmp)]
    fn test_method_permute_3d() {
        let data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
        let a = from_slice(&data, &[2, 3, 4]).unwrap();
        // With dims [2, 0, 1]: b[i, j, k] = a[j, k, i].
        let b = a.permute(&[2, 0, 1]).unwrap();
        assert_eq!(b.shape(), &[4, 2, 3]);
        let bdata = b.data_vec().unwrap();
        // element [0,0,0] of output = element [0,0,0] of input = 1.0
        assert_eq!(bdata[0], 1.0);
        // element [1,0,0] of output (flat index 1 * 2 * 3) = input[0,0,1] = 2.0
        assert_eq!(bdata[2 * 3], 2.0);
    }

    #[test]
    fn test_permute_invalid_dims() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert!(a.permute(&[0]).is_err()); // wrong length
        assert!(a.permute(&[0, 0]).is_err()); // duplicate
        assert!(a.permute(&[0, 2]).is_err()); // out of bounds
    }

    // --- view ---

    #[test]
    fn test_method_view() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let b = a.view(&[3, 2]).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(b.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_method_view_infer() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        // -1 is inferred from the element count: 6 / 2 = 3.
        let b = a.view(&[2, -1]).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    // --- contiguous ---

    #[test]
    fn test_method_contiguous() {
        // Already packed input: exercises the clone shortcut.
        let a = from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = a.contiguous().unwrap();
        assert_eq!(b.shape(), &[3]);
        assert_eq!(b.data().unwrap(), &[1.0, 2.0, 3.0]);
    }

    // --- chunk ---

    #[test]
    fn test_method_chunk_even() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let chunks = a.chunk(3, 0).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].data().unwrap(), &[1.0, 2.0]);
        assert_eq!(chunks[1].data().unwrap(), &[3.0, 4.0]);
        assert_eq!(chunks[2].data().unwrap(), &[5.0, 6.0]);
    }

    #[test]
    fn test_method_chunk_uneven() {
        // ceil(5 / 3) = 2, so the pieces are [2, 2, 1].
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let chunks = a.chunk(3, 0).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].shape(), &[2]);
        assert_eq!(chunks[1].shape(), &[2]);
        assert_eq!(chunks[2].shape(), &[1]);
    }

    #[test]
    fn test_method_chunk_2d() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let chunks = a.chunk(2, 0).unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].shape(), &[2, 2]);
        assert_eq!(chunks[1].shape(), &[1, 2]);
    }

    // --- split ---

    #[test]
    fn test_method_split() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let parts = a.split(&[2, 3], 0).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].data().unwrap(), &[1.0, 2.0]);
        assert_eq!(parts[1].data().unwrap(), &[3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_method_split_2d_axis1() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap();
        let parts = a.split(&[1, 3], 1).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 1]);
        assert_eq!(parts[0].data().unwrap(), &[1.0, 5.0]);
        assert_eq!(parts[1].shape(), &[2, 3]);
        assert_eq!(parts[1].data().unwrap(), &[2.0, 3.0, 4.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_split_bad_sizes() {
        let a = from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        assert!(a.split(&[1, 1], 0).is_err()); // sum != 3
    }

    // --- split/chunk autograd ---

    #[test]
    fn test_split_preserves_grad() {
        // Split a requires-grad tensor and verify chunks have grad_fn.
        let a = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]),
            vec![6],
            true,
        )
        .unwrap();
        let chunks = a.split(&[2, 4], 0).unwrap();
        assert!(chunks[0].grad_fn().is_some(), "chunk 0 should have grad_fn");
        assert!(chunks[1].grad_fn().is_some(), "chunk 1 should have grad_fn");
    }

    #[test]
    #[allow(clippy::needless_range_loop)]
    fn test_split_backward_simple() {
        // x = [1, 2, 3, 4, 5, 6], split into [1,2,3] and [4,5,6].
        // loss = sum(chunk0) + 2*sum(chunk1)
        // d_loss/d_x = [1, 1, 1, 2, 2, 2]
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]),
            vec![6],
            true,
        )
        .unwrap();
        let chunks = x.split(&[3, 3], 0).unwrap();

        let sum0 = crate::grad_fns::reduction::sum(&chunks[0]).unwrap();
        let sum1 = crate::grad_fns::reduction::sum(&chunks[1]).unwrap();

        // 2 * sum1
        let two = Tensor::from_storage(TensorStorage::cpu(vec![2.0f64]), vec![], false).unwrap();
        let scaled = crate::grad_fns::arithmetic::mul(&sum1, &two).unwrap();
        let loss = crate::grad_fns::arithmetic::add(&sum0, &scaled).unwrap();

        loss.backward().unwrap();

        let grad = x.grad().unwrap().expect("x should have grad");
        assert_eq!(grad.shape(), &[6]);
        let g = grad.data().unwrap();
        // First 3 elements: grad from sum0 = 1.0 each
        // Last 3 elements: grad from 2*sum1 = 2.0 each
        for i in 0..3 {
            assert!(
                (g[i] - 1.0).abs() < 1e-10,
                "grad[{i}] = {}, expected 1.0",
                g[i]
            );
        }
        for i in 3..6 {
            assert!(
                (g[i] - 2.0).abs() < 1e-10,
                "grad[{i}] = {}, expected 2.0",
                g[i]
            );
        }
    }

    #[test]
    fn test_chunk_backward_2d() {
        // x shape [2, 4], chunk into 2 along dim=1 -> two [2, 2] tensors.
        // loss = sum(chunk0) * 3 + sum(chunk1)
        // grad_x[:, 0:2] = 3, grad_x[:, 2:4] = 1
        let x =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f64; 8]), vec![2, 4], true).unwrap();
        let chunks = x.chunk(2, 1).unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].shape(), &[2, 2]);
        assert_eq!(chunks[1].shape(), &[2, 2]);

        let sum0 = crate::grad_fns::reduction::sum(&chunks[0]).unwrap();
        let sum1 = crate::grad_fns::reduction::sum(&chunks[1]).unwrap();

        let three = Tensor::from_storage(TensorStorage::cpu(vec![3.0f64]), vec![], false).unwrap();
        let scaled = crate::grad_fns::arithmetic::mul(&sum0, &three).unwrap();
        let loss = crate::grad_fns::arithmetic::add(&scaled, &sum1).unwrap();
        loss.backward().unwrap();

        let grad = x.grad().unwrap().expect("x should have grad");
        assert_eq!(grad.shape(), &[2, 4]);
        let g = grad.data().unwrap();
        // Row 0: [3, 3, 1, 1], Row 1: [3, 3, 1, 1]
        let expected = [3.0, 3.0, 1.0, 1.0, 3.0, 3.0, 1.0, 1.0];
        for (i, (&actual, &exp)) in g.iter().zip(expected.iter()).enumerate() {
            assert!(
                (actual - exp).abs() < 1e-10,
                "grad[{i}] = {actual}, expected {exp}"
            );
        }
    }

    #[test]
    fn test_split_no_grad_when_disabled() {
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]),
            vec![3],
            false, // no grad
        )
        .unwrap();
        let chunks = x.split(&[1, 2], 0).unwrap();
        assert!(chunks[0].grad_fn().is_none());
        assert!(chunks[1].grad_fn().is_none());
    }

    // --- size / dim aliases ---

    #[test]
    fn test_size_alias() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(a.size(), &[2, 3]);
        assert_eq!(a.size(), a.shape());
    }

    #[test]
    fn test_dim_alias() {
        let a = from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(a.dim(), 2);
        assert_eq!(a.dim(), a.ndim());
    }

    // --- cumulative (scan) methods ---

    #[test]
    // reason: cumulative sum of small integer-valued floats (1+2+3 = 6 at
    // most) is bit-exact in any deterministic order — the partial sums
    // never lose mantissa bits, so equality is the right check. The
    // expected values [1, 3, 6] are constructed from the named upstream
    // recurrence `out_i = sum_{k=0..=i} input_k` per
    // `aten/src/ATen/native/ReduceOps.cpp:511 TORCH_IMPL_FUNC(cumsum_out)`
    // and the math definition at `torch/_torch_docs.py:3431-3438
    // y_i = x_1 + x_2 + ... + x_i`. The dispatch-correctness assertion
    // (method == free function) protects R-DEFER-1 wiring at the
    // method boundary.
    #[allow(clippy::float_cmp)]
    fn test_method_cumsum_t_1d() {
        let a = from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();

        // Expected derived from upstream recurrence:
        //   out[0] = 1
        //   out[1] = 1 + 2 = 3
        //   out[2] = 1 + 2 + 3 = 6
        let expected = [1.0f32, 1.0 + 2.0, 1.0 + 2.0 + 3.0];

        let via_method = a.cumsum_t(0).unwrap();
        assert_eq!(via_method.shape(), &[3]);
        let m = via_method.data_vec().unwrap();
        for i in 0..3 {
            assert_eq!(m[i], expected[i], "method cumsum[{i}] != expected");
        }

        // Dispatch-correctness: method MUST equal the free function on
        // identical input. This is the production-consumer parity check.
        let via_free = crate::grad_fns::cumulative::cumsum(&a, 0).unwrap();
        let f = via_free.data_vec().unwrap();
        for i in 0..3 {
            assert_eq!(m[i], f[i], "cumsum_t and free fn disagree at {i}");
        }
    }

    #[test]
    // reason: cumprod of small ints (1, 2, 6) is bit-exact in f32 (small
    // integer mantissas), so equality is the right check. The expected
    // values [1, 2, 6] are constructed from the named upstream recurrence
    // `out_i = prod_{k=0..=i} input_k` per `aten/src/ATen/native/
    // ReduceOps.cpp:519 TORCH_IMPL_FUNC(cumprod_out)` and the math
    // definition at `torch/_torch_docs.py:3392-3399 y_i = x_1 * x_2 *
    // ... * x_i`.
    #[allow(clippy::float_cmp)]
    fn test_method_cumprod_t_1d() {
        let a = from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();

        // Expected derived from upstream recurrence:
        //   out[0] = 1
        //   out[1] = 1 * 2 = 2
        //   out[2] = 1 * 2 * 3 = 6
        let expected = [1.0f32, 1.0 * 2.0, 1.0 * 2.0 * 3.0];

        let via_method = a.cumprod_t(0).unwrap();
        assert_eq!(via_method.shape(), &[3]);
        let m = via_method.data_vec().unwrap();
        for i in 0..3 {
            assert_eq!(m[i], expected[i], "method cumprod[{i}] != expected");
        }

        // Dispatch-correctness check.
        let via_free = crate::grad_fns::cumulative::cumprod(&a, 0).unwrap();
        let f = via_free.data_vec().unwrap();
        for i in 0..3 {
            assert_eq!(m[i], f[i], "cumprod_t and free fn disagree at {i}");
        }
    }

    #[test]
    // reason: logcumsumexp on a single-element vector is the identity:
    // `log(exp(x)) == x` numerically (one term in the sum). The expected
    // value 42.0 is the input value itself, derived from the math
    // definition at `torch/_torch_docs.py:3304-3305
    // logcumsumexp(x)_ij = log(sum_{k=0..=j} exp(x_ik))` evaluated at
    // j=0 (single-element scan). For the 3-element case we also check
    // monotonicity (logcumsumexp is non-decreasing along the scan dim
    // because the running sum-of-exp is non-decreasing and log is
    // monotonic) — verified live 2026-05-25 with torch 2.11.0.
    fn test_method_logcumsumexp_t_1d() {
        // Single-element: the math identity `log(exp(x)) = x` makes the
        // expected value structurally derivable without calling the
        // function on itself.
        let a = from_slice(&[42.0f32], &[1]).unwrap();
        let via_method = a.logcumsumexp_t(0).unwrap();
        assert_eq!(via_method.shape(), &[1]);
        let m = via_method.data_vec().unwrap();
        // logcumsumexp on a single element equals the input. Allow a
        // small fp slop because exp/log round-trip is not bit-exact.
        assert!(
            (m[0] - 42.0_f32).abs() < 1e-3,
            "logcumsumexp single-elt: got {} expected 42.0",
            m[0]
        );

        // Dispatch-correctness check on a 3-element input: method MUST
        // equal the free function bit-exactly (both go through the same
        // forward kernel).
        let b = from_slice(&[0.0f32, 1.0, 2.0], &[3]).unwrap();
        let via_method = b.logcumsumexp_t(0).unwrap();
        let via_free = crate::grad_fns::cumulative::logcumsumexp(&b, 0).unwrap();
        let m = via_method.data_vec().unwrap();
        let f = via_free.data_vec().unwrap();
        for i in 0..3 {
            assert!(
                (m[i] - f[i]).abs() < 1e-6,
                "logcumsumexp_t and free fn disagree at {i}: {} vs {}",
                m[i],
                f[i]
            );
        }
        // Monotonicity: y_0 <= y_1 <= y_2 (running sum of exp is
        // monotonic, log is monotonic).
        assert!(m[0] <= m[1], "logcumsumexp not monotonic: m[0]>m[1]");
        assert!(m[1] <= m[2], "logcumsumexp not monotonic: m[1]>m[2]");
    }
}