Skip to main content

cjc_runtime/
tensor.rs

1//! N-dimensional tensor with element-wise, reduction, linalg, and NN operations.
2//!
3//! [`Tensor`] is the primary numerical type in CJC. It is backed by a
4//! [`Buffer<f64>`] with COW (copy-on-write) semantics, so cloning a tensor
5//! is O(1) and mutation triggers a deep copy only when the buffer is shared.
6//!
7//! # Determinism Guarantees
8//!
9//! - **Reductions** (`sum`, `mean`, `sum_axis`) use [`BinnedAccumulatorF64`]
10//!   for order-invariant, bit-identical results.
11//! - **Matmul** uses Kahan-compensated accumulation (sequential path) or
12//!   tiled + parallel strategies (large matrices) with deterministic per-row
13//!   thread assignment.
14//! - **SIMD** kernels (via [`tensor_simd`]) avoid hardware FMA to preserve
15//!   cross-platform bit-identity.
16//! - **No** `HashMap`/`HashSet` anywhere -- all ordering is deterministic.
17//!
18//! # Layout
19//!
20//! Tensors use row-major (C-order) layout with explicit strides. Non-contiguous
21//! views (from `slice`, `transpose`, `broadcast_to`) share the underlying
22//! buffer with adjusted strides and offset. Call [`Tensor::to_contiguous`] to
23//! materialize a contiguous copy when needed.
24//!
25//! # Operation Categories
26//!
27//! | Category | Methods |
28//! |----------|---------|
29//! | Construction | `zeros`, `ones`, `randn`, `from_vec`, `from_bytes` |
30//! | Shape | `shape`, `ndim`, `len`, `reshape`, `transpose`, `slice`, `unsqueeze`, `squeeze`, `flatten` |
31//! | Element-wise | `add`, `sub`, `mul_elem`, `div_elem`, `elem_pow`, `map`, `map_simd` |
32//! | Reductions | `sum`, `mean`, `sum_axis`, `mean_axis`, `var_axis`, `std_axis` |
33//! | Linalg | `matmul`, `bmm`, `linear`, `einsum` |
34//! | NN | `softmax`, `layer_norm`, `relu`, `sigmoid`, `gelu`, `conv1d`, `conv2d`, `maxpool2d` |
35//! | Indexing | `get`, `set`, `gather`, `scatter`, `index_select`, `argsort` |
36//! | Attention | `scaled_dot_product_attention`, `split_heads`, `merge_heads` |
37//!
38//! [`Buffer<f64>`]: crate::buffer::Buffer
39//! [`BinnedAccumulatorF64`]: crate::accumulator::BinnedAccumulatorF64
40//! [`tensor_simd`]: crate::tensor_simd
41
42use cjc_repro::Rng;
43
44use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
45use cjc_repro::KahanAccumulatorF64;
46
47use crate::accumulator;
48use crate::buffer::Buffer;
49use crate::dispatch;
50use crate::error::RuntimeError;
51use crate::kernel as kernel_fns;
52use crate::tensor_simd::{self, BinOp, UnaryOp};
53use crate::tensor_tiled::TiledMatmul;
54
55// ---------------------------------------------------------------------------
56// 2. Tensor Runtime
57// ---------------------------------------------------------------------------
58
/// An N-dimensional tensor backed by a [`Buffer<f64>`](crate::buffer::Buffer).
///
/// Supports element-wise arithmetic (SIMD-accelerated), matrix multiplication
/// (tiled + parallel), numerically-stable reductions via
/// [`BinnedAccumulatorF64`](crate::accumulator::BinnedAccumulatorF64),
/// and neural network operations (softmax, layer norm, attention).
///
/// # Memory
///
/// The underlying [`Buffer<f64>`](crate::buffer::Buffer) uses copy-on-write
/// semantics. Cloning a `Tensor` is O(1). Operations like `reshape` and
/// `transpose` return zero-copy views when possible. Mutation via `set`
/// triggers a deep copy only when the buffer is shared.
///
/// # Addressing
///
/// The element at multi-index `idx` is stored at buffer position
/// `offset + sum(idx[d] * strides[d])`. Both `offset` and `strides` are
/// measured in `f64` elements, not bytes.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// The underlying COW data buffer.
    pub buffer: Buffer<f64>,
    /// Dimensions of the tensor (e.g., `[3, 4]` for a 3x4 matrix).
    pub(crate) shape: Vec<usize>,
    /// Per-dimension strides, in elements. Freshly constructed tensors use the
    /// standard row-major pattern; non-contiguous views may differ from it
    /// (e.g., stride=0 for broadcast dimensions, reversed strides after
    /// `transpose`).
    pub(crate) strides: Vec<usize>,
    /// Element offset (an index into the buffer, not a byte count) where this
    /// tensor's data begins. Non-zero for slice views.
    pub(crate) offset: usize,
}
86
87impl Tensor {
88    // -- Construction -------------------------------------------------------
89
90    /// Compute row-major strides for a given shape.
91    pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
92        let mut strides = vec![1usize; shape.len()];
93        for i in (0..shape.len().saturating_sub(1)).rev() {
94            strides[i] = strides[i + 1] * shape[i + 1];
95        }
96        strides
97    }
98
99    /// Total number of elements implied by `shape`.
100    fn shape_numel(shape: &[usize]) -> usize {
101        shape.iter().product()
102    }
103
104    /// Create a tensor filled with zeros.
105    pub fn zeros(shape: &[usize]) -> Self {
106        let numel = Self::shape_numel(shape);
107        Tensor {
108            buffer: Buffer::alloc(numel, 0.0),
109            shape: shape.to_vec(),
110            strides: Self::compute_strides(shape),
111            offset: 0,
112        }
113    }
114
115    /// Create a tensor filled with ones.
116    pub fn ones(shape: &[usize]) -> Self {
117        let numel = Self::shape_numel(shape);
118        Tensor {
119            buffer: Buffer::alloc(numel, 1.0),
120            shape: shape.to_vec(),
121            strides: Self::compute_strides(shape),
122            offset: 0,
123        }
124    }
125
126    /// Create a tensor filled with samples from the standard normal
127    /// distribution, drawn deterministically from `rng`.
128    pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
129        let numel = Self::shape_numel(shape);
130        let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
131        Tensor {
132            buffer: Buffer::from_vec(data),
133            shape: shape.to_vec(),
134            strides: Self::compute_strides(shape),
135            offset: 0,
136        }
137    }
138
139    /// Create a tensor from raw data and a shape. Returns an error if the
140    /// number of elements does not match the shape.
141    pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
142        let numel = Self::shape_numel(shape);
143        if data.len() != numel {
144            return Err(RuntimeError::ShapeMismatch {
145                expected: numel,
146                got: data.len(),
147            });
148        }
149        Ok(Tensor {
150            buffer: Buffer::from_vec(data),
151            shape: shape.to_vec(),
152            strides: Self::compute_strides(shape),
153            offset: 0,
154        })
155    }
156
157    // -- Accessors ----------------------------------------------------------
158
159    /// The shape of this tensor.
160    pub fn shape(&self) -> &[usize] {
161        &self.shape
162    }
163
164    /// Number of dimensions.
165    pub fn ndim(&self) -> usize {
166        self.shape.len()
167    }
168
169    /// Total number of elements.
170    pub fn len(&self) -> usize {
171        Self::shape_numel(&self.shape)
172    }
173
174    /// Whether the tensor has zero elements.
175    pub fn is_empty(&self) -> bool {
176        self.len() == 0
177    }
178
179    /// Flatten a multi-dimensional index to a linear offset in the buffer.
180    fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
181        if indices.len() != self.shape.len() {
182            return Err(RuntimeError::DimensionMismatch {
183                expected: self.shape.len(),
184                got: indices.len(),
185            });
186        }
187        let mut off = self.offset;
188        for (i, &idx) in indices.iter().enumerate() {
189            if idx >= self.shape[i] {
190                return Err(RuntimeError::IndexOutOfBounds {
191                    index: idx,
192                    length: self.shape[i],
193                });
194            }
195            off += idx * self.strides[i];
196        }
197        Ok(off)
198    }
199
200    /// Whether this tensor is contiguous in memory (row-major, no offset).
201    pub fn is_contiguous(&self) -> bool {
202        if self.offset != 0 {
203            return false;
204        }
205        let expected = Self::compute_strides(&self.shape);
206        self.strides == expected
207    }
208
209    /// Create a zero-copy slice (view) of this tensor.
210    /// `ranges` contains `(start, end)` for each dimension.
211    pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
212        if ranges.len() != self.shape.len() {
213            return Err(RuntimeError::DimensionMismatch {
214                expected: self.shape.len(),
215                got: ranges.len(),
216            });
217        }
218        let mut new_offset = self.offset;
219        let mut new_shape = Vec::with_capacity(ranges.len());
220        for (i, &(start, end)) in ranges.iter().enumerate() {
221            if end > self.shape[i] || start > end {
222                return Err(RuntimeError::IndexOutOfBounds {
223                    index: end,
224                    length: self.shape[i],
225                });
226            }
227            new_offset += start * self.strides[i];
228            new_shape.push(end - start);
229        }
230        Ok(Tensor {
231            buffer: self.buffer.clone(), // shared — zero copy
232            shape: new_shape,
233            strides: self.strides.clone(),
234            offset: new_offset,
235        })
236    }
237
238    /// Materialize a contiguous copy if this tensor is non-contiguous.
239    pub fn to_contiguous(&self) -> Tensor {
240        if self.is_contiguous() {
241            return self.clone();
242        }
243        let data = self.to_vec();
244        Tensor {
245            buffer: Buffer::from_vec(data),
246            shape: self.shape.clone(),
247            strides: Self::compute_strides(&self.shape),
248            offset: 0,
249        }
250    }
251
252    /// Create a broadcast view of this tensor to `target_shape`.
253    /// Uses stride=0 for dimensions that need broadcasting (size 1 -> target size).
254    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
255        let src_ndim = self.shape.len();
256        let tgt_ndim = target_shape.len();
257        if tgt_ndim < src_ndim {
258            return Err(RuntimeError::InvalidOperation(
259                "cannot broadcast to a smaller rank".to_string(),
260            ));
261        }
262        let pad = tgt_ndim - src_ndim;
263        let mut new_strides = vec![0usize; tgt_ndim];
264        for i in 0..tgt_ndim {
265            if i < pad {
266                // Padded dimension: stride = 0 (broadcast)
267                new_strides[i] = 0;
268            } else {
269                let src_i = i - pad;
270                if self.shape[src_i] == target_shape[i] {
271                    new_strides[i] = self.strides[src_i];
272                } else if self.shape[src_i] == 1 {
273                    new_strides[i] = 0; // broadcast
274                } else {
275                    return Err(RuntimeError::ShapeMismatch {
276                        expected: target_shape[i],
277                        got: self.shape[src_i],
278                    });
279                }
280            }
281        }
282        Ok(Tensor {
283            buffer: self.buffer.clone(),
284            shape: target_shape.to_vec(),
285            strides: new_strides,
286            offset: self.offset,
287        })
288    }
289
290    /// Read the element at the given multi-dimensional index.
291    pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
292        let offset = self.linear_index(indices)?;
293        self.buffer
294            .get(offset)
295            .ok_or(RuntimeError::IndexOutOfBounds {
296                index: offset,
297                length: self.buffer.len(),
298            })
299    }
300
301    /// Write the element at the given multi-dimensional index.
302    pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
303        let offset = self.linear_index(indices)?;
304        self.buffer.set(offset, val)
305    }
306
307    /// Extract the raw data as a `Vec<f64>`, respecting strides and offset.
308    pub fn to_vec(&self) -> Vec<f64> {
309        if self.is_contiguous() {
310            let full = self.buffer.borrow_data();
311            let numel = self.len();
312            if full.len() == numel {
313                return full.to_vec();
314            }
315            // Buffer may be larger than the tensor's logical size
316            // (e.g. Scratchpad pre-allocates extra capacity)
317            return full[..numel].to_vec();
318        }
319        // Non-contiguous: iterate via strided access
320        let numel = self.len();
321        let mut result = Vec::with_capacity(numel);
322        let ndim = self.shape.len();
323        let mut indices = vec![0usize; ndim];
324        for _ in 0..numel {
325            let mut off = self.offset;
326            for d in 0..ndim {
327                off += indices[d] * self.strides[d];
328            }
329            result.push(self.buffer.get(off).unwrap_or(0.0));
330            // Increment multi-index (row-major order)
331            for d in (0..ndim).rev() {
332                indices[d] += 1;
333                if indices[d] < self.shape[d] {
334                    break;
335                }
336                indices[d] = 0;
337            }
338        }
339        result
340    }
341
342    // -- Reshape ------------------------------------------------------------
343
344    /// Reshape to `new_shape`. The new shape must have the same total number
345    /// of elements. The returned tensor **shares** the underlying buffer.
346    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
347        let new_numel = Self::shape_numel(new_shape);
348        if new_numel != self.len() {
349            return Err(RuntimeError::ShapeMismatch {
350                expected: self.len(),
351                got: new_numel,
352            });
353        }
354        // Reshape requires contiguous data; materialize if needed
355        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
356        Ok(Tensor {
357            buffer: tensor.buffer,
358            shape: new_shape.to_vec(),
359            strides: Self::compute_strides(new_shape),
360            offset: 0,
361        })
362    }
363
364    // -- Element-wise operations --------------------------------------------
365
366    /// Apply a binary operation element-wise with broadcasting support.
367    fn elementwise_binop(
368        &self,
369        other: &Tensor,
370        op: impl Fn(f64, f64) -> f64,
371    ) -> Result<Tensor, RuntimeError> {
372        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
373            // Fast path: same shape, both contiguous — borrow without cloning
374            let a = self.buffer.borrow_data();
375            let b = other.buffer.borrow_data();
376            let data: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
377            return Ok(Tensor {
378                buffer: Buffer::from_vec(data),
379                shape: self.shape.clone(),
380                strides: Self::compute_strides(&self.shape),
381                offset: 0,
382            });
383        }
384
385        // Broadcasting path: compute result shape
386        let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
387        let a_broadcast = self.broadcast_to(&result_shape)?;
388        let b_broadcast = other.broadcast_to(&result_shape)?;
389
390        let numel = Self::shape_numel(&result_shape);
391        let ndim = result_shape.len();
392        let mut data = Vec::with_capacity(numel);
393        let mut indices = vec![0usize; ndim];
394
395        for _ in 0..numel {
396            let mut off_a = a_broadcast.offset;
397            let mut off_b = b_broadcast.offset;
398            for d in 0..ndim {
399                off_a += indices[d] * a_broadcast.strides[d];
400                off_b += indices[d] * b_broadcast.strides[d];
401            }
402            let va = a_broadcast.buffer.get(off_a).ok_or_else(|| {
403                RuntimeError::InvalidOperation(format!(
404                    "broadcast binop: left operand index {} out of bounds (buffer len {})",
405                    off_a,
406                    a_broadcast.buffer.len()
407                ))
408            })?;
409            let vb = b_broadcast.buffer.get(off_b).ok_or_else(|| {
410                RuntimeError::InvalidOperation(format!(
411                    "broadcast binop: right operand index {} out of bounds (buffer len {})",
412                    off_b,
413                    b_broadcast.buffer.len()
414                ))
415            })?;
416            data.push(op(va, vb));
417
418            for d in (0..ndim).rev() {
419                indices[d] += 1;
420                if indices[d] < result_shape[d] {
421                    break;
422                }
423                indices[d] = 0;
424            }
425        }
426
427        Ok(Tensor {
428            buffer: Buffer::from_vec(data),
429            shape: result_shape.clone(),
430            strides: Self::compute_strides(&result_shape),
431            offset: 0,
432        })
433    }
434
435    /// Compute the broadcast result shape for two shapes (NumPy rules).
436    fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
437        let max_ndim = a.len().max(b.len());
438        let mut result = Vec::with_capacity(max_ndim);
439        for i in 0..max_ndim {
440            let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
441            let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
442            if da == db {
443                result.push(da);
444            } else if da == 1 {
445                result.push(db);
446            } else if db == 1 {
447                result.push(da);
448            } else {
449                return Err(RuntimeError::ShapeMismatch {
450                    expected: da,
451                    got: db,
452                });
453            }
454        }
455        Ok(result)
456    }
457
458    /// SIMD-accelerated element-wise binary operation for known ops.
459    ///
460    /// For same-shape contiguous tensors, uses AVX2 (4-wide f64) when available.
461    /// Falls back to the generic closure path for broadcast cases.
462    fn elementwise_binop_simd(
463        &self,
464        other: &Tensor,
465        op: BinOp,
466        fallback: impl Fn(f64, f64) -> f64,
467    ) -> Result<Tensor, RuntimeError> {
468        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
469            // SIMD fast path: same shape, both contiguous
470            let a = self.buffer.borrow_data();
471            let b = other.buffer.borrow_data();
472            let data = tensor_simd::simd_binop(&a, &b, op);
473            return Ok(Tensor {
474                buffer: Buffer::from_vec(data),
475                shape: self.shape.clone(),
476                strides: Self::compute_strides(&self.shape),
477                offset: 0,
478            });
479        }
480        // Broadcast path: fall through to generic
481        self.elementwise_binop(other, fallback)
482    }
483
484    /// Element-wise addition (SIMD-accelerated for contiguous same-shape tensors).
485    pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
486        self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
487    }
488
489    /// Element-wise subtraction (SIMD-accelerated for contiguous same-shape tensors).
490    pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
491        self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
492    }
493
494    /// Element-wise (Hadamard) multiplication (SIMD-accelerated for contiguous same-shape tensors).
495    pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
496        self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
497    }
498
499    /// Element-wise division (SIMD-accelerated for contiguous same-shape tensors).
500    pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
501        self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
502    }
503
504    /// Fused multiply-add: `self * b + c` element-wise in a single pass.
505    ///
506    /// Eliminates the intermediate tensor that separate mul + add would create.
507    /// Uses software FMA (`a * b + c` with two roundings, not hardware FMA)
508    /// to preserve bit-identity with the non-fused path.
509    pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
510        if self.shape != b.shape || self.shape != c.shape {
511            return Err(RuntimeError::InvalidOperation(
512                "broadcast_fma: all three tensors must have the same shape".to_string(),
513            ));
514        }
515        if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
516            let a_data = self.buffer.borrow_data();
517            let b_data = b.buffer.borrow_data();
518            let c_data = c.buffer.borrow_data();
519            let n = a_data.len();
520            let mut out = vec![0.0f64; n];
521            // Software FMA: a*b + c (two roundings — NOT hardware FMA which uses one rounding).
522            // This produces identical results to separate broadcast2("mul") + broadcast2("add").
523            for i in 0..n {
524                out[i] = a_data[i] * b_data[i] + c_data[i];
525            }
526            return Ok(Tensor {
527                buffer: Buffer::from_vec(out),
528                shape: self.shape.clone(),
529                strides: Self::compute_strides(&self.shape),
530                offset: 0,
531            });
532        }
533        // Non-contiguous fallback: mul then add
534        let temp = self.mul_elem(b)?;
535        temp.add(c)
536    }
537
538    /// Fused `alpha * self + y` (scalar `alpha`) in one pass — the BLAS axpy,
539    /// ubiquitous in optimizers, EMAs, and linear combinations. Eliminates the
540    /// intermediate tensor that `scalar_mul` + `add` would allocate. Software
541    /// two-rounding (separate mul then add, **no hardware FMA**) so the result
542    /// is bit-identical to `self.scalar_mul(alpha).add(y)`.
543    pub fn fused_axpy(&self, alpha: f64, y: &Tensor) -> Result<Tensor, RuntimeError> {
544        if self.shape != y.shape {
545            return Err(RuntimeError::InvalidOperation(
546                "fused_axpy: self and y must have the same shape".to_string(),
547            ));
548        }
549        if self.is_contiguous() && y.is_contiguous() {
550            let x_data = self.buffer.borrow_data();
551            let y_data = y.buffer.borrow_data();
552            let n = x_data.len();
553            let mut out = vec![0.0f64; n];
554            for i in 0..n {
555                out[i] = alpha * x_data[i] + y_data[i];
556            }
557            return Ok(Tensor {
558                buffer: Buffer::from_vec(out),
559                shape: self.shape.clone(),
560                strides: Self::compute_strides(&self.shape),
561                offset: 0,
562            });
563        }
564        // Non-contiguous fallback: scalar_mul then add.
565        self.scalar_mul(alpha).add(y)
566    }
567
568    /// Fused `self * b - c` in one pass. Eliminates the `mul_elem` intermediate.
569    /// Bit-identical to `self.mul_elem(b).sub(c)` (two roundings, no hardware FMA).
570    pub fn fused_mul_sub(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
571        if self.shape != b.shape || self.shape != c.shape {
572            return Err(RuntimeError::InvalidOperation(
573                "fused_mul_sub: all three tensors must have the same shape".to_string(),
574            ));
575        }
576        if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
577            let a_data = self.buffer.borrow_data();
578            let b_data = b.buffer.borrow_data();
579            let c_data = c.buffer.borrow_data();
580            let n = a_data.len();
581            let mut out = vec![0.0f64; n];
582            for i in 0..n {
583                out[i] = a_data[i] * b_data[i] - c_data[i];
584            }
585            return Ok(Tensor {
586                buffer: Buffer::from_vec(out),
587                shape: self.shape.clone(),
588                strides: Self::compute_strides(&self.shape),
589                offset: 0,
590            });
591        }
592        let temp = self.mul_elem(b)?;
593        temp.sub(c)
594    }
595
596    /// Fused `(self - b)^2` in one pass — the squared difference at the heart of
597    /// MSE, variance, and Euclidean distance. Eliminates the `sub` intermediate.
598    /// Bit-identical to `let d = self.sub(b); d.mul_elem(&d)`.
599    pub fn fused_sub_sq(&self, b: &Tensor) -> Result<Tensor, RuntimeError> {
600        if self.shape != b.shape {
601            return Err(RuntimeError::InvalidOperation(
602                "fused_sub_sq: self and b must have the same shape".to_string(),
603            ));
604        }
605        if self.is_contiguous() && b.is_contiguous() {
606            let a_data = self.buffer.borrow_data();
607            let b_data = b.buffer.borrow_data();
608            let n = a_data.len();
609            let mut out = vec![0.0f64; n];
610            for i in 0..n {
611                let d = a_data[i] - b_data[i];
612                out[i] = d * d;
613            }
614            return Ok(Tensor {
615                buffer: Buffer::from_vec(out),
616                shape: self.shape.clone(),
617                strides: Self::compute_strides(&self.shape),
618                offset: 0,
619            });
620        }
621        let d = self.sub(b)?;
622        d.mul_elem(&d)
623    }
624
625    // ── v0.1 Broadcasting: additional element-wise binary ops ──
626
627    /// Element-wise power: `a^b`.
628    pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
629        self.elementwise_binop(other, |a, b| a.powf(b))
630    }
631
632    /// Element-wise minimum.
633    pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
634        self.elementwise_binop(other, |a, b| a.min(b))
635    }
636
637    /// Element-wise maximum.
638    pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
639        self.elementwise_binop(other, |a, b| a.max(b))
640    }
641
642    /// Element-wise atan2(self, other).
643    pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
644        self.elementwise_binop(other, |a, b| a.atan2(b))
645    }
646
647    /// Element-wise hypot(self, other).
648    pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
649        self.elementwise_binop(other, |a, b| a.hypot(b))
650    }
651
652    /// Apply a unary function to every element, returning a new contiguous tensor.
653    pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
654        let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
655        Tensor {
656            buffer: Buffer::from_vec(data),
657            shape: self.shape.clone(),
658            strides: Self::compute_strides(&self.shape),
659            offset: 0,
660        }
661    }
662
663    /// SIMD-accelerated unary map for known operations (sqrt, abs, neg, relu).
664    ///
665    /// Uses AVX2 (4-wide f64) when available, scalar fallback otherwise.
666    /// Bit-identical to `map(f)` for the supported operations.
667    pub fn map_simd(&self, op: UnaryOp) -> Tensor {
668        let src = self.to_vec();
669        let data = tensor_simd::simd_unary(&src, op);
670        Tensor {
671            buffer: Buffer::from_vec(data),
672            shape: self.shape.clone(),
673            strides: Self::compute_strides(&self.shape),
674            offset: 0,
675        }
676    }
677
678    // -- Reductions (using BinnedAccumulator) --------------------------------
679
680    /// Sum of all elements (binned accumulation — order-invariant, deterministic).
681    pub fn sum(&self) -> f64 {
682        let data = self.buffer.borrow_data();
683        binned_sum_f64(&data)
684    }
685
686    /// Sum of all elements using BinnedAccumulator (order-invariant, deterministic).
687    ///
688    /// Bit-identical results regardless of element ordering or reduction schedule.
689    pub fn binned_sum(&self) -> f64 {
690        let data = self.buffer.borrow_data();
691        accumulator::binned_sum_f64(&data)
692    }
693
694    /// Sum with dispatched strategy based on execution context.
695    ///
696    /// Uses Kahan in serial mode, Binned in parallel/@nogc/strict/linalg mode.
697    pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
698        let data = self.buffer.borrow_data();
699        dispatch::dispatch_sum_f64(&data, ctx)
700    }
701
702    /// Mean of all elements (binned sum / count).
703    pub fn mean(&self) -> f64 {
704        let n = self.len();
705        if n == 0 {
706            return 0.0;
707        }
708        self.sum() / n as f64
709    }
710
711    /// Mean with dispatched strategy based on execution context.
712    pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
713        let n = self.len();
714        if n == 0 {
715            return 0.0;
716        }
717        self.dispatched_sum(ctx) / n as f64
718    }
719
720    /// Sum along a specific axis, returning a tensor with that dimension reduced.
721    ///
722    /// Supports N-D tensors. The reduced axis becomes size 1 in the output.
723    /// Uses BinnedAccumulator for order-invariant, deterministic summation.
724    ///
725    /// Examples:
726    /// - 2D [M, N] with axis=0: result [1, N] (sum columns)
727    /// - 2D [M, N] with axis=1: result [M, 1] (sum rows)
728    /// - 3D [A, B, C] with axis=1: result [A, 1, C]
729    pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
730        let ndim = self.ndim();
731        if axis >= ndim {
732            return Err(RuntimeError::IndexOutOfBounds {
733                index: axis,
734                length: ndim,
735            });
736        }
737
738        // Build output shape: same as input but with axis dimension = 1
739        let mut out_shape = self.shape.clone();
740        out_shape[axis] = 1;
741        let out_numel = Self::shape_numel(&out_shape);
742        let out_strides = Self::compute_strides(&out_shape);
743
744        let data = self.to_vec();
745        let axis_len = self.shape[axis];
746        let mut result = vec![0.0f64; out_numel];
747
748        // For each output position, sum over the reduced axis with binned accumulation.
749        let mut indices = vec![0usize; ndim];
750        for out_idx in 0..out_numel {
751            // Compute the N-D index from flat output index
752            {
753                let mut remaining = out_idx;
754                for d in 0..ndim {
755                    indices[d] = remaining / out_strides[d];
756                    remaining %= out_strides[d];
757                }
758            }
759
760            let mut acc = BinnedAccumulatorF64::new();
761            for k in 0..axis_len {
762                // Compute input flat index with indices[axis] = k
763                let mut flat = self.offset;
764                for d in 0..ndim {
765                    let idx = if d == axis { k } else { indices[d] };
766                    flat += idx * self.strides[d];
767                }
768                acc.add(data[flat]);
769            }
770            result[out_idx] = acc.finalize();
771        }
772
773        Tensor::from_vec(result, &out_shape)
774    }
775
776    // -- Matrix multiplication (2-D only) -----------------------------------
777
778    /// Negate every element, returning a new tensor.
779    pub fn neg(&self) -> Tensor {
780        self.map(|x| -x)
781    }
782
783    /// Transpose a tensor. For 2-D: swaps rows and columns (zero-copy view).
784    /// For N-D: reverses all axes (zero-copy view).
785    pub fn transpose(&self) -> Tensor {
786        let ndim = self.ndim();
787        if ndim <= 1 {
788            return self.clone();
789        }
790        // Reverse shape and strides — zero-copy view
791        let mut new_shape = self.shape.clone();
792        let mut new_strides = self.strides.clone();
793        new_shape.reverse();
794        new_strides.reverse();
795        Tensor {
796            buffer: self.buffer.clone(), // shared — zero copy
797            shape: new_shape,
798            strides: new_strides,
799            offset: self.offset,
800        }
801    }
802
803    /// Transpose with explicit axis permutation (N-D). Zero-copy view.
804    ///
805    /// `axes` must be a permutation of `[0, 1, ..., ndim-1]`.
806    pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
807        let ndim = self.ndim();
808        if axes.len() != ndim {
809            return Err(RuntimeError::InvalidOperation(
810                format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
811            ));
812        }
813        // Validate permutation
814        let mut seen = vec![false; ndim];
815        for &ax in axes {
816            if ax >= ndim {
817                return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
818            }
819            if seen[ax] {
820                return Err(RuntimeError::InvalidOperation(
821                    format!("transpose_axes: duplicate axis {ax}"),
822                ));
823            }
824            seen[ax] = true;
825        }
826        let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
827        let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
828        Ok(Tensor {
829            buffer: self.buffer.clone(),
830            shape: new_shape,
831            strides: new_strides,
832            offset: self.offset,
833        })
834    }
835
836    /// Multiply every element by a scalar, returning a new tensor.
837    pub fn scalar_mul(&self, s: f64) -> Tensor {
838        self.map(|x| x * s)
839    }
840
841    // ── Panicking convenience constructors (used by AD engine) --------
842
843    /// Create a tensor from raw data and shape.
844    /// **Panics** if `data.len()` does not match the shape.
845    pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
846        Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
847    }
848
849    /// Element-wise addition. **Panics** on shape mismatch.
850    pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
851        self.add(other).expect("Tensor::add shape mismatch")
852    }
853
854    /// In-place element-wise addition. **Panics** on shape mismatch.
855    ///
856    /// Mutates `self` instead of allocating a new tensor. Uses COW: if the
857    /// underlying buffer is shared, `make_unique()` deep-copies first.
858    pub fn add_assign_unchecked(&mut self, other: &Tensor) {
859        assert_eq!(self.shape, other.shape, "Tensor::add_assign shape mismatch");
860        self.buffer.make_unique();
861        let other_data = other.buffer.borrow_data();
862        let mut self_data = self.buffer.borrow_data_mut();
863        for (a, b) in self_data.iter_mut().zip(other_data.iter()) {
864            *a += b;
865        }
866    }
867
868    /// Element-wise subtraction. **Panics** on shape mismatch.
869    pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
870        self.sub(other).expect("Tensor::sub shape mismatch")
871    }
872
873    /// Element-wise multiplication. **Panics** on shape mismatch.
874    pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
875        self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
876    }
877
878    /// Element-wise division. **Panics** on shape mismatch.
879    pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
880        self.div_elem(other).expect("Tensor::div_elem shape mismatch")
881    }
882
883    /// Matrix multiplication. **Panics** on dimension mismatch.
884    pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
885        self.matmul(other).expect("Tensor::matmul dimension mismatch")
886    }
887
888    /// Matrix multiplication for 2-D tensors.
889    ///
890    /// `self` is (M, K), `other` is (K, N) => result is (M, N).
891    pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
892        if self.ndim() != 2 || other.ndim() != 2 {
893            return Err(RuntimeError::InvalidOperation(
894                "matmul requires 2-D tensors".to_string(),
895            ));
896        }
897        let m = self.shape[0];
898        let k = self.shape[1];
899        let k2 = other.shape[0];
900        let n = other.shape[1];
901        if k != k2 {
902            return Err(RuntimeError::DimensionMismatch {
903                expected: k,
904                got: k2,
905            });
906        }
907
908        let a = self.to_vec();
909        let b = other.to_vec();
910
911        // Parallel path (Mode A): parallelize over output rows when the parallel
912        // feature is enabled and the matrix is large enough (>= 256 in any dim).
913        #[cfg(feature = "parallel")]
914        {
915            if m >= 256 || n >= 256 || k >= 256 {
916                return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
917            }
918        }
919
920        // Tiled path: use L2-friendly tiled matmul for medium-to-large matrices.
921        // Threshold: any dimension >= 64 (the default tile size).
922        // NOTE: tiled path uses naive accumulation (not binned) — different
923        // numerical path for large matrices, but better cache locality.
924        if m >= 64 || n >= 64 || k >= 64 {
925            return Self::matmul_tiled(&a, &b, m, n, k);
926        }
927
928        // Sequential path: single-threaded with binned accumulation.
929        Self::matmul_sequential(&a, &b, m, n, k)
930    }
931
932    /// Sequential matmul (always available, deterministic reference).
933    fn matmul_sequential(
934        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
935    ) -> Result<Tensor, RuntimeError> {
936        let mut result = vec![0.0f64; m * n];
937        for i in 0..m {
938            for j in 0..n {
939                let mut acc = KahanAccumulatorF64::new();
940                for p in 0..k {
941                    acc.add(a[i * k + p] * b[p * n + j]);
942                }
943                result[i * n + j] = acc.finalize();
944            }
945        }
946        Tensor::from_vec(result, &[m, n])
947    }
948
949    /// Tiled matmul: delegates to `TiledMatmul` for L2-cache-friendly tiling.
950    ///
951    /// Used for medium matrices (any dimension >= 64) where cache locality
952    /// matters but parallel overhead isn't justified. The tiled path uses
953    /// naive accumulation (not binned accumulation), trading a small amount of
954    /// floating-point precision for better cache behavior.
955    fn matmul_tiled(
956        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
957    ) -> Result<Tensor, RuntimeError> {
958        let engine = TiledMatmul::new();
959        let result = engine.matmul(a, m, k, b, n);
960        Tensor::from_vec(result, &[m, n])
961    }
962
/// Parallel matmul "Mode A": the output is split by rows and each piece is
/// written by exactly one rayon task.
///
/// `a` is the row-major `m x k` left operand, `b` the row-major `k x n` right
/// operand. Two strategies are used:
///
/// - `m >= 512 && n >= 512`: the rows are cut into bands (at least 64 rows
///   each, sized from the current rayon thread count) and every band is
///   multiplied with its own `TiledMatmul` engine.
/// - otherwise: one task per output row, each element a Kahan-compensated
///   dot product accumulated in ascending `p` order.
///
/// No partial sums are merged across tasks: every output element is produced
/// entirely inside one task.
///
/// An output with zero rows or zero columns is returned immediately as an
/// empty tensor; rayon's `par_chunks_mut` panics on a chunk size of 0, which
/// would otherwise happen for `n == 0`.
#[cfg(feature = "parallel")]
fn matmul_parallel_mode_a(
    a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
) -> Result<Tensor, RuntimeError> {
    use rayon::prelude::*;
    use cjc_repro::KahanAccumulatorF64;

    // Nothing to compute, and `par_chunks_mut(0)` below would panic.
    if m == 0 || n == 0 {
        return Tensor::from_vec(Vec::new(), &[m, n]);
    }

    if m >= 512 && n >= 512 {
        // Large case: row bands, each handled by a tiled engine. The work is
        // wrapped in `run_parallel` so the runtime policy controls the pool.
        let result = crate::runtime_policy::run_parallel(|| {
            // Read the thread count inside the closure so it reflects the
            // pool this work actually runs on.
            let threads = rayon::current_num_threads();
            let band_size = ((m + threads - 1) / threads).max(64);
            let mut result = vec![0.0f64; m * n];
            result
                .par_chunks_mut(band_size * n)
                .enumerate()
                .for_each(|(band_idx, band)| {
                    let i_start = band_idx * band_size;
                    // The last band may be shorter than `band_size`.
                    let i_end = (i_start + band_size).min(m);
                    let band_m = i_end - i_start;
                    let a_band = &a[i_start * k..i_end * k];
                    let engine = crate::tensor_tiled::TiledMatmul::new();
                    let tiled_result = engine.matmul(a_band, band_m, k, b, n);
                    band[..band_m * n].copy_from_slice(&tiled_result);
                });
            result
        });

        return Tensor::from_vec(result, &[m, n]);
    }

    // Remaining case: one task per output row with a Kahan accumulator per
    // element. Rows are independent of each other.
    let result = crate::runtime_policy::run_parallel(|| {
        let mut result = vec![0.0f64; m * n];
        result
            .par_chunks_mut(n)
            .enumerate()
            .for_each(|(i, row)| {
                for (j, cell) in row.iter_mut().enumerate() {
                    let mut acc = KahanAccumulatorF64::new();
                    for p in 0..k {
                        acc.add(a[i * k + p] * b[p * n + j]);
                    }
                    *cell = acc.finalize();
                }
            });
        result
    });

    Tensor::from_vec(result, &[m, n])
}
1035
1036    // -- Transformer Kernels ------------------------------------------------
1037
1038    /// Batched matrix multiplication.
1039    ///
1040    /// `self` is `[..., M, K]`, `other` is `[..., K, N]` => result is `[..., M, N]`.
1041    /// The batch dimensions must be identical (no broadcast).
1042    /// For 2-D inputs, delegates to `matmul`.
1043    pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
1044        if self.ndim() < 2 || other.ndim() < 2 {
1045            return Err(RuntimeError::InvalidOperation(
1046                "bmm requires at least 2-D tensors".to_string(),
1047            ));
1048        }
1049        if self.ndim() == 2 && other.ndim() == 2 {
1050            return self.matmul(other);
1051        }
1052        if self.ndim() != other.ndim() {
1053            return Err(RuntimeError::InvalidOperation(
1054                format!(
1055                    "bmm requires same number of dimensions, got {} and {}",
1056                    self.ndim(),
1057                    other.ndim()
1058                ),
1059            ));
1060        }
1061        let nd = self.ndim();
1062        let batch_dims_a = &self.shape[..nd - 2];
1063        let batch_dims_b = &other.shape[..nd - 2];
1064        if batch_dims_a != batch_dims_b {
1065            return Err(RuntimeError::InvalidOperation(
1066                format!(
1067                    "bmm batch dimensions mismatch: {:?} vs {:?}",
1068                    batch_dims_a, batch_dims_b
1069                ),
1070            ));
1071        }
1072        let m = self.shape[nd - 2];
1073        let k = self.shape[nd - 1];
1074        let k2 = other.shape[nd - 2];
1075        let n = other.shape[nd - 1];
1076        if k != k2 {
1077            return Err(RuntimeError::DimensionMismatch {
1078                expected: k,
1079                got: k2,
1080            });
1081        }
1082
1083        let batch_size: usize = batch_dims_a.iter().product();
1084        let a = self.to_vec();
1085        let b = other.to_vec();
1086        let mat_a_stride = m * k;
1087        let mat_b_stride = k * n;
1088        let mat_c_stride = m * n;
1089        let mut result = vec![0.0f64; batch_size * mat_c_stride];
1090
1091        // Helper closure: compute one batch into c_slice
1092        let compute_batch = |batch: usize, c_slice: &mut [f64]| {
1093            let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
1094            let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];
1095
1096            if m >= 64 || n >= 64 || k >= 64 {
1097                let engine = crate::tensor_tiled::TiledMatmul::new();
1098                let tiled = engine.matmul(a_slice, m, k, b_slice, n);
1099                c_slice.copy_from_slice(&tiled);
1100            } else {
1101                for i in 0..m {
1102                    for j in 0..n {
1103                        let mut acc = KahanAccumulatorF64::new();
1104                        for p in 0..k {
1105                            acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
1106                        }
1107                        c_slice[i * n + j] = acc.finalize();
1108                    }
1109                }
1110            }
1111        };
1112
1113        // Parallel path: parallelize over batches when workload is large enough
1114        #[cfg(feature = "parallel")]
1115        {
1116            if batch_size > 1 && m * k >= 4096 {
1117                use rayon::prelude::*;
1118                // Each batch is independent → throttling concurrency via
1119                // `run_parallel` keeps output bit-identical.
1120                crate::runtime_policy::run_parallel(|| {
1121                    result
1122                        .par_chunks_mut(mat_c_stride)
1123                        .enumerate()
1124                        .for_each(|(batch, c_slice)| {
1125                            compute_batch(batch, c_slice);
1126                        });
1127                });
1128
1129                let mut out_shape = batch_dims_a.to_vec();
1130                out_shape.push(m);
1131                out_shape.push(n);
1132                return Tensor::from_vec(result, &out_shape);
1133            }
1134        }
1135
1136        // Sequential fallback
1137        for batch in 0..batch_size {
1138            let c_off = batch * mat_c_stride;
1139            compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
1140        }
1141
1142        let mut out_shape = batch_dims_a.to_vec();
1143        out_shape.push(m);
1144        out_shape.push(n);
1145        Tensor::from_vec(result, &out_shape)
1146    }
1147
1148    /// Softmax along the last dimension (two-pass stable algorithm).
1149    ///
1150    /// Pass 1: find max per row (prevents overflow in exp)
1151    /// Pass 2: compute exp(x - max), accumulate sum, normalize
1152    ///
1153    /// For a tensor of shape `[..., N]`, softmax is applied independently
1154    /// to each length-N slice along the last axis.
1155    pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
1156        if self.ndim() == 0 {
1157            return Err(RuntimeError::InvalidOperation(
1158                "softmax requires at least 1-D tensor".to_string(),
1159            ));
1160        }
1161        // Avoid allocation when tensor is already contiguous and starts at offset 0
1162        let data_ref;
1163        let data_vec;
1164        let data: &[f64] = if self.is_contiguous() && self.offset == 0 {
1165            data_ref = self.buffer.borrow_data();
1166            &data_ref
1167        } else {
1168            data_vec = self.to_vec();
1169            &data_vec
1170        };
1171        let n = *self.shape.last().unwrap(); // last dimension size
1172        let outer: usize = data.len() / n;  // product of all dims except last
1173        let mut result = vec![0.0f64; data.len()];
1174
1175        for row in 0..outer {
1176            let start = row * n;
1177            let end = start + n;
1178            let slice = &data[start..end];
1179
1180            // Pass 1: find max for numerical stability
1181            let mut max_val = f64::NEG_INFINITY;
1182            for &v in slice {
1183                if v > max_val {
1184                    max_val = v;
1185                }
1186            }
1187
1188            // Pass 2: exp(x - max) and accumulate sum
1189            let mut exp_vals = vec![0.0f64; n];
1190            let mut sum = 0.0f64;
1191            let mut comp = 0.0f64; // Kahan compensation
1192            for i in 0..n {
1193                let e = (slice[i] - max_val).exp();
1194                exp_vals[i] = e;
1195                // Kahan summation for the denominator
1196                let y = e - comp;
1197                let t = sum + y;
1198                comp = (t - sum) - y;
1199                sum = t;
1200            }
1201
1202            // Normalize
1203            if sum == 0.0 {
1204                // Degenerate case: all -inf inputs → uniform
1205                let uniform = 1.0 / n as f64;
1206                for i in 0..n {
1207                    result[start + i] = uniform;
1208                }
1209            } else {
1210                for i in 0..n {
1211                    result[start + i] = exp_vals[i] / sum;
1212                }
1213            }
1214        }
1215
1216        Tensor::from_vec(result, &self.shape)
1217    }
1218
1219    /// Layer normalization over the last dimension.
1220    ///
1221    /// For each length-D slice along the last axis:
1222    ///   1. mean = Σx / D  (BinnedAccumulator)
1223    ///   2. var  = Σ(x - mean)² / D  (BinnedAccumulator)
1224    ///   3. normalized = (x - mean) / √(var + eps)
1225    ///   4. output = gamma * normalized + beta
1226    ///
1227    /// `gamma` and `beta` are 1-D tensors of shape `[D]`.
1228    /// `eps` is a small constant (typically 1e-5).
1229    pub fn layer_norm(
1230        &self,
1231        gamma: &Tensor,
1232        beta: &Tensor,
1233        eps: f64,
1234    ) -> Result<Tensor, RuntimeError> {
1235        if self.ndim() == 0 {
1236            return Err(RuntimeError::InvalidOperation(
1237                "layer_norm requires at least 1-D tensor".to_string(),
1238            ));
1239        }
1240        let d = *self.shape.last().unwrap();
1241        if gamma.len() != d || beta.len() != d {
1242            return Err(RuntimeError::InvalidOperation(
1243                format!(
1244                    "layer_norm: gamma/beta length {} must match last dim {}",
1245                    gamma.len(),
1246                    d
1247                ),
1248            ));
1249        }
1250
1251        let data = self.to_vec();
1252        let gamma_data = gamma.to_vec();
1253        let beta_data = beta.to_vec();
1254        let outer = data.len() / d;
1255        let mut result = vec![0.0f64; data.len()];
1256
1257        for row in 0..outer {
1258            let start = row * d;
1259            let slice = &data[start..start + d];
1260
1261            // Pass 1: compute mean via BinnedAccumulator
1262            let mean = binned_sum_f64(slice) / d as f64;
1263
1264            // Pass 2: compute variance via BinnedAccumulator
1265            let diffs: Vec<f64> = slice.iter().map(|&x| {
1266                let diff = x - mean;
1267                diff * diff
1268            }).collect();
1269            let variance = binned_sum_f64(&diffs) / d as f64;
1270
1271            // Normalize, scale, shift
1272            let inv_std = 1.0 / (variance + eps).sqrt();
1273            for i in 0..d {
1274                let normalized = (slice[i] - mean) * inv_std;
1275                result[start + i] = gamma_data[i] * normalized + beta_data[i];
1276            }
1277        }
1278
1279        Tensor::from_vec(result, &self.shape)
1280    }
1281
1282    /// Apply a function element-wise, reusing the buffer when possible (COW).
1283    ///
1284    /// If the tensor is contiguous, starts at offset 0, and has refcount == 1,
1285    /// the data is mutated in place (zero allocations). Otherwise, allocates
1286    /// a new buffer.
1287    fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
1288        if self.is_contiguous() && self.offset == 0 && self.buffer.refcount() == 1 {
1289            // Fast path: mutate in place (COW — we're the sole owner)
1290            let mut data = self.buffer.borrow_data().clone();
1291            for x in data.iter_mut() {
1292                *x = f(*x);
1293            }
1294            Tensor::from_vec(data, &self.shape).unwrap()
1295        } else {
1296            // Fallback: allocate new buffer
1297            let data = self.to_vec();
1298            let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
1299            Tensor::from_vec(result, &self.shape).unwrap()
1300        }
1301    }
1302
1303    /// ReLU activation: `max(0, x)` element-wise.
1304    pub fn relu(&self) -> Tensor {
1305        self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
1306    }
1307
1308    /// Sigmoid activation: 1 / (1 + exp(-x)) element-wise.
1309    pub fn sigmoid(&self) -> Tensor {
1310        self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
1311    }
1312
1313    /// Tanh activation element-wise.
1314    pub fn tanh_activation(&self) -> Tensor {
1315        self.map_elementwise(|x| x.tanh())
1316    }
1317
1318    /// Leaky ReLU activation: max(alpha*x, x) element-wise.
1319    pub fn leaky_relu(&self, alpha: f64) -> Tensor {
1320        self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
1321    }
1322
1323    /// SiLU (Swish) activation: x * sigmoid(x) element-wise.
1324    pub fn silu(&self) -> Tensor {
1325        let data = self.to_vec();
1326        let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
1327        Tensor::from_vec(result, &self.shape).unwrap()
1328    }
1329
1330    /// Mish activation: x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))).
1331    pub fn mish(&self) -> Tensor {
1332        let data = self.to_vec();
1333        let result: Vec<f64> = data.iter().map(|&x| {
1334            let sp = (1.0 + x.exp()).ln();
1335            x * sp.tanh()
1336        }).collect();
1337        Tensor::from_vec(result, &self.shape).unwrap()
1338    }
1339
1340    /// Argmax: index of the maximum element (first occurrence, deterministic).
1341    pub fn argmax(&self) -> usize {
1342        let data = self.to_vec();
1343        let mut best_idx = 0;
1344        let mut best_val = f64::NEG_INFINITY;
1345        for (i, &v) in data.iter().enumerate() {
1346            if v > best_val || (v == best_val && i < best_idx) {
1347                best_val = v;
1348                best_idx = i;
1349            }
1350        }
1351        best_idx
1352    }
1353
1354    /// Argmin: index of the minimum element (first occurrence, deterministic).
1355    pub fn argmin(&self) -> usize {
1356        let data = self.to_vec();
1357        let mut best_idx = 0;
1358        let mut best_val = f64::INFINITY;
1359        for (i, &v) in data.iter().enumerate() {
1360            if v < best_val || (v == best_val && i < best_idx) {
1361                best_val = v;
1362                best_idx = i;
1363            }
1364        }
1365        best_idx
1366    }
1367
1368    /// Clamp all elements to [min, max].
1369    pub fn clamp(&self, min: f64, max: f64) -> Tensor {
1370        let data = self.to_vec();
1371        let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
1372        Tensor::from_vec(result, &self.shape).unwrap()
1373    }
1374
1375    /// One-hot encoding: given a 1D tensor of integer indices and a depth,
1376    /// returns a 2D tensor of shape [len, depth].
1377    pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
1378        let n = indices.len();
1379        let mut data = vec![0.0; n * depth];
1380        for (i, &idx) in indices.iter().enumerate() {
1381            if idx >= depth {
1382                return Err(RuntimeError::InvalidOperation(format!(
1383                    "one_hot: index {idx} >= depth {depth}"
1384                )));
1385            }
1386            data[i * depth + idx] = 1.0;
1387        }
1388        Tensor::from_vec(data, &[n, depth])
1389    }
1390
1391    // -----------------------------------------------------------------------
1392    // Phase B4: Tensor extensions (cat, stack, topk)
1393    // -----------------------------------------------------------------------
1394
1395    /// Concatenate tensors along existing axis.
1396    pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
1397        if tensors.is_empty() {
1398            return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
1399        }
1400        let ndim = tensors[0].ndim();
1401        if axis >= ndim {
1402            return Err(RuntimeError::InvalidOperation(
1403                format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
1404            ));
1405        }
1406        for (i, t) in tensors.iter().enumerate().skip(1) {
1407            if t.ndim() != ndim {
1408                return Err(RuntimeError::InvalidOperation(
1409                    format!("cat: tensor {i} has different ndim"),
1410                ));
1411            }
1412            for d in 0..ndim {
1413                if d != axis && t.shape[d] != tensors[0].shape[d] {
1414                    return Err(RuntimeError::InvalidOperation(
1415                        format!("cat: shape mismatch at dim {d}"),
1416                    ));
1417                }
1418            }
1419        }
1420        let mut out_shape = tensors[0].shape.clone();
1421        for t in tensors.iter().skip(1) {
1422            out_shape[axis] += t.shape[axis];
1423        }
1424        let total = out_shape.iter().product::<usize>();
1425        let mut result = vec![0.0; total];
1426        let mut out_strides = vec![1usize; ndim];
1427        for d in (0..ndim - 1).rev() {
1428            out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
1429        }
1430        let mut offset = 0;
1431        for t in tensors {
1432            let t_data = t.to_vec();
1433            let t_total: usize = t.shape.iter().product();
1434            let mut t_strides = vec![1usize; ndim];
1435            for d in (0..ndim - 1).rev() {
1436                t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
1437            }
1438            for idx in 0..t_total {
1439                let mut remaining = idx;
1440                let mut out_flat = 0;
1441                for d in 0..ndim {
1442                    let coord = remaining / t_strides[d];
1443                    remaining %= t_strides[d];
1444                    let out_coord = if d == axis { coord + offset } else { coord };
1445                    out_flat += out_coord * out_strides[d];
1446                }
1447                result[out_flat] = t_data[idx];
1448            }
1449            offset += t.shape[axis];
1450        }
1451        Tensor::from_vec(result, &out_shape)
1452    }
1453
1454    /// Stack tensors along a new axis.
1455    pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
1456        if tensors.is_empty() {
1457            return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
1458        }
1459        let base_shape = &tensors[0].shape;
1460        let ndim = base_shape.len();
1461        if axis > ndim {
1462            return Err(RuntimeError::InvalidOperation(
1463                format!("stack: axis {axis} out of bounds"),
1464            ));
1465        }
1466        for (i, t) in tensors.iter().enumerate().skip(1) {
1467            if &t.shape != base_shape {
1468                return Err(RuntimeError::InvalidOperation(
1469                    format!("stack: tensor {i} shape mismatch"),
1470                ));
1471            }
1472        }
1473        let mut out_shape = Vec::with_capacity(ndim + 1);
1474        for d in 0..axis { out_shape.push(base_shape[d]); }
1475        out_shape.push(tensors.len());
1476        for d in axis..ndim { out_shape.push(base_shape[d]); }
1477        let total: usize = out_shape.iter().product();
1478        let mut result = vec![0.0; total];
1479        let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
1480        let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
1481        for (t_idx, t) in tensors.iter().enumerate() {
1482            let t_data = t.to_vec();
1483            for outer in 0..outer_size {
1484                for inner in 0..inner_size {
1485                    let src = outer * inner_size + inner;
1486                    let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
1487                    if src < t_data.len() && dst < result.len() {
1488                        result[dst] = t_data[src];
1489                    }
1490                }
1491            }
1492        }
1493        Tensor::from_vec(result, &out_shape)
1494    }
1495
1496    /// Top-k values and indices (largest k values from flat data).
1497    pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
1498        let data = self.to_vec();
1499        let n = data.len();
1500        if k > n {
1501            return Err(RuntimeError::InvalidOperation(
1502                format!("topk: k={k} exceeds data length {n}"),
1503            ));
1504        }
1505        let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
1506        indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
1507        let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
1508        let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
1509        let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
1510        Ok((Tensor::from_vec(values, &[k])?, indices))
1511    }
1512
1513    /// GELU activation (approximate): x * 0.5 * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
1514    pub fn gelu(&self) -> Tensor {
1515        let data = self.to_vec();
1516        let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
1517        let result: Vec<f64> = data.iter().map(|&x| {
1518            let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
1519            0.5 * x * (1.0 + inner.tanh())
1520        }).collect();
1521        Tensor::from_vec(result, &self.shape).unwrap()
1522    }
1523
1524    /// Linear layer: output = input @ weight^T + bias
1525    ///
1526    /// `self` is `[..., in_features]`, `weight` is `[out_features, in_features]`,
1527    /// `bias` is `[out_features]`.
1528    /// Result is `[..., out_features]`.
1529    pub fn linear(
1530        &self,
1531        weight: &Tensor,
1532        bias: &Tensor,
1533    ) -> Result<Tensor, RuntimeError> {
1534        if weight.ndim() != 2 {
1535            return Err(RuntimeError::InvalidOperation(
1536                "linear: weight must be 2-D [out_features, in_features]".to_string(),
1537            ));
1538        }
1539        let out_features = weight.shape[0];
1540        let in_features = weight.shape[1];
1541        let last_dim = *self.shape.last().ok_or_else(|| {
1542            RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
1543        })?;
1544        if last_dim != in_features {
1545            return Err(RuntimeError::DimensionMismatch {
1546                expected: in_features,
1547                got: last_dim,
1548            });
1549        }
1550        if bias.len() != out_features {
1551            return Err(RuntimeError::InvalidOperation(
1552                format!(
1553                    "linear: bias length {} must match out_features {}",
1554                    bias.len(),
1555                    out_features
1556                ),
1557            ));
1558        }
1559
1560        let data = self.to_vec();
1561        let w = weight.to_vec();
1562        let b = bias.to_vec();
1563        let outer = data.len() / in_features;
1564        let mut result = vec![0.0f64; outer * out_features];
1565
1566        for row in 0..outer {
1567            let x_start = row * in_features;
1568            let x_slice = &data[x_start..x_start + in_features];
1569            let y_start = row * out_features;
1570            for j in 0..out_features {
1571                let w_start = j * in_features;
1572                let mut acc = BinnedAccumulatorF64::new();
1573                for p in 0..in_features {
1574                    acc.add(x_slice[p] * w[w_start + p]);
1575                }
1576                result[y_start + j] = acc.finalize() + b[j];
1577            }
1578        }
1579
1580        let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
1581        out_shape.push(out_features);
1582        Tensor::from_vec(result, &out_shape)
1583    }
1584
1585    /// 1D convolution: signal `[signal_len]` * filters `[out_ch, kernel_size]` + bias
1586    ///
1587    /// Returns `[out_ch, signal_len - kernel_size + 1]` (valid mode, stride=1).
1588    pub fn conv1d(
1589        &self,
1590        filters: &Tensor,
1591        bias: &Tensor,
1592    ) -> Result<Tensor, RuntimeError> {
1593        if self.ndim() != 1 {
1594            return Err(RuntimeError::InvalidOperation(
1595                "conv1d: input must be 1-D [signal_len]".to_string(),
1596            ));
1597        }
1598        if filters.ndim() != 2 {
1599            return Err(RuntimeError::InvalidOperation(
1600                "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
1601            ));
1602        }
1603        let signal_len = self.shape[0];
1604        let out_channels = filters.shape[0];
1605        let kernel_size = filters.shape[1];
1606        if signal_len < kernel_size {
1607            return Err(RuntimeError::InvalidOperation(
1608                format!(
1609                    "conv1d: signal_len {} < kernel_size {}",
1610                    signal_len, kernel_size
1611                ),
1612            ));
1613        }
1614        if bias.len() != out_channels {
1615            return Err(RuntimeError::InvalidOperation(
1616                format!(
1617                    "conv1d: bias length {} must match out_channels {}",
1618                    bias.len(), out_channels
1619                ),
1620            ));
1621        }
1622        let out_len = signal_len - kernel_size + 1;
1623        let s = self.to_vec();
1624        let f = filters.to_vec();
1625        let b = bias.to_vec();
1626        let mut result = vec![0.0; out_channels * out_len];
1627        kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
1628        Tensor::from_vec(result, &[out_channels, out_len])
1629    }
1630
1631    /// 2D convolution — NCHW layout, valid mode, configurable stride.
1632    ///
1633    /// # Arguments
1634    /// - `self`:    `[N, C_in, H, W]` input tensor
1635    /// - `filters`: `[C_out, C_in, kH, kW]`
1636    /// - `bias`:    `[C_out]`
1637    /// - `stride`:  spatial stride (default 1)
1638    ///
1639    /// # Returns
1640    /// `[N, C_out, H_out, W_out]` where `H_out = (H - kH) / stride + 1`.
1641    ///
1642    /// Uses `BinnedAccumulatorF64` for every dot product — bit-identical results
1643    /// across all runs and hardware configurations.
1644    pub fn conv2d(
1645        &self,
1646        filters: &Tensor,
1647        bias: &Tensor,
1648        stride: usize,
1649    ) -> Result<Tensor, RuntimeError> {
1650        if self.ndim() != 4 {
1651            return Err(RuntimeError::InvalidOperation(
1652                "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
1653            ));
1654        }
1655        if filters.ndim() != 4 {
1656            return Err(RuntimeError::InvalidOperation(
1657                "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
1658            ));
1659        }
1660        if stride == 0 {
1661            return Err(RuntimeError::InvalidOperation(
1662                "conv2d: stride must be >= 1".to_string(),
1663            ));
1664        }
1665
1666        let n    = self.shape[0];
1667        let c_in = self.shape[1];
1668        let h_in = self.shape[2];
1669        let w_in = self.shape[3];
1670
1671        let c_out      = filters.shape[0];
1672        let c_in_check = filters.shape[1];
1673        let kh         = filters.shape[2];
1674        let kw         = filters.shape[3];
1675
1676        if c_in != c_in_check {
1677            return Err(RuntimeError::InvalidOperation(format!(
1678                "conv2d: input C_in={} does not match filter C_in={}",
1679                c_in, c_in_check
1680            )));
1681        }
1682        if h_in < kh || w_in < kw {
1683            return Err(RuntimeError::InvalidOperation(format!(
1684                "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
1685                h_in, w_in, kh, kw
1686            )));
1687        }
1688        if bias.len() != c_out {
1689            return Err(RuntimeError::InvalidOperation(format!(
1690                "conv2d: bias length {} must match C_out={}",
1691                bias.len(), c_out
1692            )));
1693        }
1694
1695        let h_out = (h_in - kh) / stride + 1;
1696        let w_out = (w_in - kw) / stride + 1;
1697
1698        let inp = self.to_vec();
1699        let flt = filters.to_vec();
1700        let b   = bias.to_vec();
1701        let mut result = vec![0.0f64; n * c_out * h_out * w_out];
1702
1703        kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
1704                           n, c_in, h_in, w_in, c_out, kh, kw, stride);
1705
1706        Tensor::from_vec(result, &[n, c_out, h_out, w_out])
1707    }
1708
1709    /// 2D max-pooling — NCHW layout, non-overlapping windows.
1710    ///
1711    /// - `self`: `[N, C, H, W]`
1712    /// - `ph`, `pw`: pool height/width (stride = window size)
1713    ///
1714    /// Returns `[N, C, H/ph, W/pw]`.
1715    pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
1716        if self.ndim() != 4 {
1717            return Err(RuntimeError::InvalidOperation(
1718                "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
1719            ));
1720        }
1721        if ph == 0 || pw == 0 {
1722            return Err(RuntimeError::InvalidOperation(
1723                "maxpool2d: pool size must be >= 1".to_string(),
1724            ));
1725        }
1726
1727        let n    = self.shape[0];
1728        let c    = self.shape[1];
1729        let h_in = self.shape[2];
1730        let w_in = self.shape[3];
1731
1732        if h_in < ph || w_in < pw {
1733            return Err(RuntimeError::InvalidOperation(format!(
1734                "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
1735                h_in, w_in, ph, pw
1736            )));
1737        }
1738
1739        let h_out = h_in / ph;
1740        let w_out = w_in / pw;
1741
1742        let inp = self.to_vec();
1743        let mut result = vec![0.0f64; n * c * h_out * w_out];
1744
1745        kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);
1746
1747        Tensor::from_vec(result, &[n, c, h_out, w_out])
1748    }
1749
1750    /// Applies 2-D average pooling over a `[C, H, W]` tensor.
1751    ///
1752    /// # Arguments
1753    ///
1754    /// * `kernel_h` / `kernel_w` - Pooling window size
1755    /// * `stride_h` / `stride_w` - Stride for the pooling window
1756    ///
1757    /// # Returns
1758    ///
1759    /// Tensor of shape `[C, out_h, out_w]` where `out_h = (H - kernel_h) / stride_h + 1`.
1760    ///
1761    /// # Errors
1762    ///
1763    /// Returns an error if the tensor is not 3-D or if kernel/stride produce invalid output dimensions.
1764    pub fn avgpool2d(&self, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize) -> Result<Tensor, RuntimeError> {
1765        let shape = self.shape();
1766        if shape.len() != 3 {
1767            return Err(RuntimeError::InvalidOperation(format!("avgpool2d requires 3-D [C,H,W], got {:?}", shape)));
1768        }
1769        let (c, h, w) = (shape[0], shape[1], shape[2]);
1770        if kernel_h > h || kernel_w > w {
1771            return Err(RuntimeError::InvalidOperation("avgpool2d: kernel larger than input".into()));
1772        }
1773        let out_h = (h - kernel_h) / stride_h + 1;
1774        let out_w = (w - kernel_w) / stride_w + 1;
1775        let data = self.to_vec();
1776        let mut out = Vec::with_capacity(c * out_h * out_w);
1777        let pool_size = (kernel_h * kernel_w) as f64;
1778
1779        for ch in 0..c {
1780            for oh in 0..out_h {
1781                for ow in 0..out_w {
1782                    let mut sum = 0.0f64;
1783                    for kh in 0..kernel_h {
1784                        for kw in 0..kernel_w {
1785                            let ih = oh * stride_h + kh;
1786                            let iw = ow * stride_w + kw;
1787                            sum += data[ch * h * w + ih * w + iw];
1788                        }
1789                    }
1790                    out.push(sum / pool_size);
1791                }
1792            }
1793        }
1794        Tensor::from_vec(out, &[c, out_h, out_w])
1795    }
1796
1797    /// Scaled dot-product attention (single head).
1798    ///
1799    /// `queries` is `[..., T, d_k]`
1800    /// `keys`    is `[..., S, d_k]`
1801    /// `values`  is `[..., S, d_v]`
1802    ///
1803    /// Computes: softmax(Q × Kᵀ / √d_k) × V
1804    /// Returns `[..., T, d_v]`.
1805    pub fn scaled_dot_product_attention(
1806        queries: &Tensor,
1807        keys: &Tensor,
1808        values: &Tensor,
1809    ) -> Result<Tensor, RuntimeError> {
1810        if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
1811            return Err(RuntimeError::InvalidOperation(
1812                "attention: Q, K, V must be at least 2-D".to_string(),
1813            ));
1814        }
1815        let nd = queries.ndim();
1816        let d_k = queries.shape[nd - 1];
1817        let scale = 1.0 / (d_k as f64).sqrt();
1818
1819        // Transpose keys: swap last two dims
1820        let keys_t = keys.transpose_last_two()?;
1821
1822        // Q × K^T → [... T, S]
1823        let scores = queries.bmm(&keys_t)?;
1824
1825        // Scale
1826        let scores_scaled = scores.scalar_mul(scale);
1827
1828        // Softmax along last dim
1829        let attn_weights = scores_scaled.softmax()?;
1830
1831        // Attn × V → [... T, d_v]
1832        attn_weights.bmm(values)
1833    }
1834
1835    /// Transpose the last two dimensions of a tensor.
1836    ///
1837    /// `[..., A, B]` → `[..., B, A]`
1838    pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
1839        if self.ndim() < 2 {
1840            return Err(RuntimeError::InvalidOperation(
1841                "transpose_last_two requires at least 2-D tensor".to_string(),
1842            ));
1843        }
1844        let nd = self.ndim();
1845        let rows = self.shape[nd - 2];
1846        let cols = self.shape[nd - 1];
1847        let data = self.to_vec();
1848        let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
1849        let mat_size = rows * cols;
1850        let mut result = vec![0.0f64; data.len()];
1851
1852        for b in 0..batch_size {
1853            let off = b * mat_size;
1854            for i in 0..rows {
1855                for j in 0..cols {
1856                    result[off + j * rows + i] = data[off + i * cols + j];
1857                }
1858            }
1859        }
1860
1861        let mut out_shape = self.shape.clone();
1862        out_shape[nd - 2] = cols;
1863        out_shape[nd - 1] = rows;
1864        Tensor::from_vec(result, &out_shape)
1865    }
1866
1867    // -- Zero-Copy Weight Mapping -------------------------------------------
1868
1869    /// Create a tensor view from raw bytes — **zero allocation**.
1870    ///
1871    /// Interprets `bytes` as a contiguous block of `f64` (8 bytes each) or
1872    /// `f32` (4 bytes each, promoted to f64) values and maps them into a
1873    /// `Tensor` with the given shape.
1874    ///
1875    /// `dtype` must be `"f64"` or `"f32"`.
1876    ///
1877    /// For f64: bytes.len() must equal shape_numel * 8.
1878    /// For f32: bytes.len() must equal shape_numel * 4.
1879    ///
1880    /// The returned tensor **owns** its buffer (copied from the raw bytes)
1881    /// but performs exactly one allocation for the data vector.
1882    pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
1883        let numel = Self::shape_numel(shape);
1884        match dtype {
1885            "f64" => {
1886                let expected = numel * 8;
1887                if bytes.len() != expected {
1888                    return Err(RuntimeError::ShapeMismatch {
1889                        expected,
1890                        got: bytes.len(),
1891                    });
1892                }
1893                let mut data = Vec::with_capacity(numel);
1894                for i in 0..numel {
1895                    let off = i * 8;
1896                    let mut buf = [0u8; 8];
1897                    buf.copy_from_slice(&bytes[off..off + 8]);
1898                    data.push(f64::from_le_bytes(buf));
1899                }
1900                Ok(Tensor {
1901                    buffer: Buffer::from_vec(data),
1902                    shape: shape.to_vec(),
1903                    strides: Self::compute_strides(shape),
1904                    offset: 0,
1905                })
1906            }
1907            "f32" => {
1908                let expected = numel * 4;
1909                if bytes.len() != expected {
1910                    return Err(RuntimeError::ShapeMismatch {
1911                        expected,
1912                        got: bytes.len(),
1913                    });
1914                }
1915                let mut data = Vec::with_capacity(numel);
1916                for i in 0..numel {
1917                    let off = i * 4;
1918                    let mut buf = [0u8; 4];
1919                    buf.copy_from_slice(&bytes[off..off + 4]);
1920                    data.push(f32::from_le_bytes(buf) as f64);
1921                }
1922                Ok(Tensor {
1923                    buffer: Buffer::from_vec(data),
1924                    shape: shape.to_vec(),
1925                    strides: Self::compute_strides(shape),
1926                    offset: 0,
1927                })
1928            }
1929            _ => Err(RuntimeError::InvalidOperation(
1930                format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
1931            )),
1932        }
1933    }
1934
1935    // -- Multi-Head Attention Splitting -------------------------------------
1936
1937    /// Reshape a 3D tensor `[batch, seq, model_dim]` into 4D
1938    /// `[batch, num_heads, seq, head_dim]` by splitting the last dimension.
1939    ///
1940    /// This is a **zero-copy view** — it only changes shape/strides metadata.
1941    /// `model_dim` must be divisible by `num_heads`.
1942    pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
1943        if self.ndim() != 3 {
1944            return Err(RuntimeError::DimensionMismatch {
1945                expected: 3,
1946                got: self.ndim(),
1947            });
1948        }
1949        let batch = self.shape[0];
1950        let seq = self.shape[1];
1951        let model_dim = self.shape[2];
1952        if model_dim % num_heads != 0 {
1953            return Err(RuntimeError::InvalidOperation(
1954                format!(
1955                    "split_heads: model_dim {} not divisible by num_heads {}",
1956                    model_dim, num_heads
1957                ),
1958            ));
1959        }
1960        let head_dim = model_dim / num_heads;
1961        // Need contiguous data for the reshape
1962        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
1963        // Reshape [B, S, H*D] -> [B, S, H, D] then transpose to [B, H, S, D]
1964        let reshaped = Tensor {
1965            buffer: tensor.buffer.clone(),
1966            shape: vec![batch, seq, num_heads, head_dim],
1967            strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
1968            offset: 0,
1969        };
1970        // Transpose dims 1 and 2: [B, S, H, D] -> [B, H, S, D]
1971        // New strides: swap strides[1] and strides[2]
1972        Ok(Tensor {
1973            buffer: reshaped.buffer,
1974            shape: vec![batch, num_heads, seq, head_dim],
1975            strides: vec![
1976                reshaped.strides[0], // batch stride unchanged
1977                reshaped.strides[2], // head stride (was dim 2)
1978                reshaped.strides[1], // seq stride (was dim 1)
1979                reshaped.strides[3], // head_dim stride unchanged
1980            ],
1981            offset: 0,
1982        })
1983    }
1984
1985    /// Merge heads back: reshape 4D `[batch, num_heads, seq, head_dim]` into
1986    /// 3D `[batch, seq, model_dim]`. Materializes if non-contiguous.
1987    pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
1988        if self.ndim() != 4 {
1989            return Err(RuntimeError::DimensionMismatch {
1990                expected: 4,
1991                got: self.ndim(),
1992            });
1993        }
1994        let batch = self.shape[0];
1995        let num_heads = self.shape[1];
1996        let seq = self.shape[2];
1997        let head_dim = self.shape[3];
1998        // Need [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D]
1999        // Transpose dims 1 and 2 first
2000        let transposed = Tensor {
2001            buffer: self.buffer.clone(),
2002            shape: vec![batch, seq, num_heads, head_dim],
2003            strides: vec![
2004                self.strides[0],
2005                self.strides[2], // seq stride
2006                self.strides[1], // head stride
2007                self.strides[3],
2008            ],
2009            offset: self.offset,
2010        };
2011        // Materialize contiguous then reshape
2012        let contig = transposed.to_contiguous();
2013        let model_dim = num_heads * head_dim;
2014        Ok(Tensor {
2015            buffer: contig.buffer,
2016            shape: vec![batch, seq, model_dim],
2017            strides: Self::compute_strides(&[batch, seq, model_dim]),
2018            offset: 0,
2019        })
2020    }
2021
2022    /// View-only reshape: reinterpret shape without copying.
2023    /// Only works on contiguous tensors. Falls back to copy if non-contiguous.
2024    pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
2025        self.reshape(new_shape)
2026    }
2027
2028    // -----------------------------------------------------------------------
2029    // Phase C4: Sorting & Tensor Indexing
2030    // -----------------------------------------------------------------------
2031
2032    /// Returns indices that would sort the flattened tensor in ascending order.
2033    /// Uses f64::total_cmp for deterministic ordering of NaN.
2034    pub fn argsort(&self) -> Tensor {
2035        let data = self.to_vec();
2036        let mut indices: Vec<usize> = (0..data.len()).collect();
2037        indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
2038        let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
2039        Tensor::from_vec_unchecked(result, &[data.len()])
2040    }
2041
2042    /// Gather elements from the tensor along a dimension using index tensor.
2043    /// For 1D: result[i] = self[indices[i]]
2044    /// For 2D dim=0: result[i][j] = self[indices[i][j]][j]
2045    /// For 2D dim=1: result[i][j] = self[i][indices[i][j]]
2046    pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
2047        let data = self.to_vec();
2048        let idx_data = indices.to_vec();
2049        if self.ndim() == 1 {
2050            let mut result = Vec::with_capacity(idx_data.len());
2051            for &idx in &idx_data {
2052                let i = idx as usize;
2053                if i >= data.len() {
2054                    return Err(RuntimeError::InvalidOperation(
2055                        format!("gather: index {} out of bounds for size {}", i, data.len()),
2056                    ));
2057                }
2058                result.push(data[i]);
2059            }
2060            Ok(Tensor::from_vec_unchecked(result, indices.shape()))
2061        } else if self.ndim() == 2 {
2062            let rows = self.shape[0];
2063            let cols = self.shape[1];
2064            let idx_shape = indices.shape();
2065            let out_rows = idx_shape[0];
2066            let out_cols = idx_shape[1];
2067            let mut result = vec![0.0; out_rows * out_cols];
2068            for i in 0..out_rows {
2069                for j in 0..out_cols {
2070                    let idx = idx_data[i * out_cols + j] as usize;
2071                    let val = if dim == 0 {
2072                        if idx >= rows {
2073                            return Err(RuntimeError::InvalidOperation(
2074                                format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
2075                            ));
2076                        }
2077                        data[idx * cols + j]
2078                    } else {
2079                        if idx >= cols {
2080                            return Err(RuntimeError::InvalidOperation(
2081                                format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
2082                            ));
2083                        }
2084                        data[i * cols + idx]
2085                    };
2086                    result[i * out_cols + j] = val;
2087                }
2088            }
2089            Ok(Tensor::from_vec_unchecked(result, idx_shape))
2090        } else {
2091            Err(RuntimeError::InvalidOperation(
2092                "gather: only 1D and 2D tensors supported".into(),
2093            ))
2094        }
2095    }
2096
2097    /// Scatter src values into a tensor of given shape at indices along a dimension.
2098    /// For 1D: result[indices[i]] = src[i]
2099    /// For 2D dim=0: result[indices[i][j]][j] = src[i][j]
2100    /// For 2D dim=1: result[i][indices[i][j]] = src[i][j]
2101    pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
2102        let mut result = self.to_vec();
2103        let idx_data = indices.to_vec();
2104        let src_data = src.to_vec();
2105        if self.ndim() == 1 {
2106            for (k, &idx) in idx_data.iter().enumerate() {
2107                let i = idx as usize;
2108                if i >= result.len() {
2109                    return Err(RuntimeError::InvalidOperation(
2110                        format!("scatter: index {} out of bounds for size {}", i, result.len()),
2111                    ));
2112                }
2113                result[i] = src_data[k];
2114            }
2115            Ok(Tensor::from_vec_unchecked(result, self.shape()))
2116        } else if self.ndim() == 2 {
2117            let cols = self.shape[1];
2118            let idx_shape = indices.shape();
2119            let out_cols = idx_shape[1];
2120            let out_rows = idx_shape[0];
2121            for i in 0..out_rows {
2122                for j in 0..out_cols {
2123                    let idx = idx_data[i * out_cols + j] as usize;
2124                    let src_val = src_data[i * out_cols + j];
2125                    if dim == 0 {
2126                        if idx >= self.shape[0] {
2127                            return Err(RuntimeError::InvalidOperation(
2128                                format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
2129                            ));
2130                        }
2131                        result[idx * cols + j] = src_val;
2132                    } else {
2133                        if idx >= cols {
2134                            return Err(RuntimeError::InvalidOperation(
2135                                format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
2136                            ));
2137                        }
2138                        result[i * cols + idx] = src_val;
2139                    }
2140                }
2141            }
2142            Ok(Tensor::from_vec_unchecked(result, self.shape()))
2143        } else {
2144            Err(RuntimeError::InvalidOperation(
2145                "scatter: only 1D and 2D tensors supported".into(),
2146            ))
2147        }
2148    }
2149
2150    /// Select slices along a dimension by index.
2151    /// For 2D dim=0: selects rows
2152    /// For 2D dim=1: selects columns
2153    pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
2154        let data = self.to_vec();
2155        let idx_data = indices.to_vec();
2156        if self.ndim() == 1 {
2157            let mut result = Vec::with_capacity(idx_data.len());
2158            for &idx in &idx_data {
2159                let i = idx as usize;
2160                if i >= data.len() {
2161                    return Err(RuntimeError::InvalidOperation(
2162                        format!("index_select: index {} out of bounds for size {}", i, data.len()),
2163                    ));
2164                }
2165                result.push(data[i]);
2166            }
2167            Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
2168        } else if self.ndim() == 2 {
2169            let rows = self.shape[0];
2170            let cols = self.shape[1];
2171            let n = idx_data.len();
2172            if dim == 0 {
2173                let mut result = Vec::with_capacity(n * cols);
2174                for &idx in &idx_data {
2175                    let i = idx as usize;
2176                    if i >= rows {
2177                        return Err(RuntimeError::InvalidOperation(
2178                            format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
2179                        ));
2180                    }
2181                    for j in 0..cols {
2182                        result.push(data[i * cols + j]);
2183                    }
2184                }
2185                Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
2186            } else {
2187                let mut result = Vec::with_capacity(rows * n);
2188                for i in 0..rows {
2189                    for &idx in &idx_data {
2190                        let j = idx as usize;
2191                        if j >= cols {
2192                            return Err(RuntimeError::InvalidOperation(
2193                                format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
2194                            ));
2195                        }
2196                        result.push(data[i * cols + j]);
2197                    }
2198                }
2199                Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
2200            }
2201        } else {
2202            Err(RuntimeError::InvalidOperation(
2203                "index_select: only 1D and 2D tensors supported".into(),
2204            ))
2205        }
2206    }
2207
2208    // -----------------------------------------------------------------------
2209    // Phase 2: Boolean / Masking Ops
2210    // -----------------------------------------------------------------------
2211
2212    /// Element-wise conditional select: `where(condition, other)`.
2213    /// For each element, returns `self[i]` if `condition[i] != 0.0`, else `other[i]`.
2214    pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
2215        if self.shape() != condition.shape() || self.shape() != other.shape() {
2216            return Err(RuntimeError::InvalidOperation(
2217                format!("where: shape mismatch self={:?} cond={:?} other={:?}",
2218                    self.shape(), condition.shape(), other.shape()),
2219            ));
2220        }
2221        let s = self.to_vec();
2222        let c = condition.to_vec();
2223        let o = other.to_vec();
2224        let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
2225            .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
2226            .collect();
2227        Tensor::from_vec(result, self.shape())
2228    }
2229
2230    /// Return `true` if any element is non-zero.
2231    pub fn any(&self) -> bool {
2232        let data = self.to_vec();
2233        data.iter().any(|&x| x != 0.0)
2234    }
2235
2236    /// Return `true` if all elements are non-zero.
2237    pub fn all(&self) -> bool {
2238        let data = self.to_vec();
2239        data.iter().all(|&x| x != 0.0)
2240    }
2241
2242    /// Return a 1-D tensor of flat indices where elements are non-zero.
2243    ///
2244    /// If no elements are non-zero, returns an empty tensor of shape `[0]`.
2245    pub fn nonzero(&self) -> Tensor {
2246        let data = self.to_vec();
2247        let indices: Vec<f64> = data.iter().enumerate()
2248            .filter(|(_, &v)| v != 0.0)
2249            .map(|(i, _)| i as f64)
2250            .collect();
2251        let len = indices.len();
2252        if len == 0 {
2253            Tensor::from_vec(vec![], &[0]).unwrap()
2254        } else {
2255            Tensor::from_vec(indices, &[len]).unwrap()
2256        }
2257    }
2258
2259    /// Fill elements where `mask` is non-zero with `value`.
2260    pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
2261        if self.shape() != mask.shape() {
2262            return Err(RuntimeError::InvalidOperation(
2263                format!("masked_fill: shape mismatch self={:?} mask={:?}",
2264                    self.shape(), mask.shape()),
2265            ));
2266        }
2267        let data = self.to_vec();
2268        let m = mask.to_vec();
2269        let result: Vec<f64> = data.iter().zip(m.iter())
2270            .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
2271            .collect();
2272        Tensor::from_vec(result, self.shape())
2273    }
2274
2275    // -----------------------------------------------------------------------
2276    // Phase 2: Axis Reductions with keepdim
2277    // -----------------------------------------------------------------------
2278
2279    /// Generic axis reduction using a caller-provided reduction function.
2280    ///
2281    /// Gathers values along the specified axis for each output position
2282    /// and applies `reduce_fn`. If `keepdim` is `true`, the reduced axis
2283    /// retains size 1; otherwise it is removed from the output shape.
2284    fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
2285        -> Result<Tensor, RuntimeError>
2286    where
2287        F: Fn(&[f64]) -> f64,
2288    {
2289        let ndim = self.ndim();
2290        if axis >= ndim {
2291            return Err(RuntimeError::IndexOutOfBounds {
2292                index: axis,
2293                length: ndim,
2294            });
2295        }
2296
2297        let axis_len = self.shape[axis];
2298        // Build output shape
2299        let mut out_shape: Vec<usize> = self.shape.clone();
2300        out_shape[axis] = 1;
2301        let out_numel = Self::shape_numel(&out_shape);
2302        let out_strides = Self::compute_strides(&out_shape);
2303
2304        let data = self.to_vec();
2305        let mut result = Vec::with_capacity(out_numel);
2306        let mut indices = vec![0usize; ndim];
2307
2308        for out_idx in 0..out_numel {
2309            // Compute N-D index from flat output index
2310            {
2311                let mut remaining = out_idx;
2312                for d in 0..ndim {
2313                    indices[d] = remaining / out_strides[d];
2314                    remaining %= out_strides[d];
2315                }
2316            }
2317
2318            // Gather values along the reduction axis
2319            let mut vals = Vec::with_capacity(axis_len);
2320            for k in 0..axis_len {
2321                let mut flat = self.offset;
2322                for d in 0..ndim {
2323                    let idx = if d == axis { k } else { indices[d] };
2324                    flat += idx * self.strides[d];
2325                }
2326                vals.push(data[flat]);
2327            }
2328            result.push(reduce_fn(&vals));
2329        }
2330
2331        let final_shape = if keepdim {
2332            out_shape
2333        } else {
2334            // Remove the axis dimension
2335            let mut s: Vec<usize> = self.shape.iter().enumerate()
2336                .filter(|&(i, _)| i != axis)
2337                .map(|(_, &v)| v)
2338                .collect();
2339            if s.is_empty() {
2340                s.push(1); // scalar result
2341            }
2342            s
2343        };
2344
2345        Tensor::from_vec(result, &final_shape)
2346    }
2347
2348    /// Mean along an axis with optional keepdim.
2349    ///
2350    /// Uses [`BinnedAccumulatorF64`](crate::accumulator::BinnedAccumulatorF64)
2351    /// for deterministic summation before dividing by the axis length.
2352    pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2353        self.reduce_axis(axis, keepdim, |vals| {
2354            let mut acc = BinnedAccumulatorF64::new();
2355            for &v in vals { acc.add(v); }
2356            acc.finalize() / vals.len() as f64
2357        })
2358    }
2359
2360    /// Max along an axis with optional keepdim. Return `(values, indices)`.
2361    ///
2362    /// Ties are broken by choosing the first occurrence (smallest index).
2363    pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2364        let ndim = self.ndim();
2365        if axis >= ndim {
2366            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2367        }
2368        let axis_len = self.shape[axis];
2369        let mut out_shape = self.shape.clone();
2370        out_shape[axis] = 1;
2371        let out_numel = Self::shape_numel(&out_shape);
2372        let out_strides = Self::compute_strides(&out_shape);
2373        let data = self.to_vec();
2374        let mut values = Vec::with_capacity(out_numel);
2375        let mut idx_vals = Vec::with_capacity(out_numel);
2376        let mut indices = vec![0usize; ndim];
2377
2378        for out_idx in 0..out_numel {
2379            let mut remaining = out_idx;
2380            for d in 0..ndim {
2381                indices[d] = remaining / out_strides[d];
2382                remaining %= out_strides[d];
2383            }
2384            let mut best_val = f64::NEG_INFINITY;
2385            let mut best_idx = 0usize;
2386            for k in 0..axis_len {
2387                let mut flat = self.offset;
2388                for d in 0..ndim {
2389                    let idx = if d == axis { k } else { indices[d] };
2390                    flat += idx * self.strides[d];
2391                }
2392                let v = data[flat];
2393                if v > best_val {
2394                    best_val = v;
2395                    best_idx = k;
2396                }
2397            }
2398            values.push(best_val);
2399            idx_vals.push(best_idx as f64);
2400        }
2401
2402        let final_shape = if keepdim {
2403            out_shape
2404        } else {
2405            let mut s: Vec<usize> = self.shape.iter().enumerate()
2406                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2407            if s.is_empty() { s.push(1); }
2408            s
2409        };
2410        Ok((
2411            Tensor::from_vec(values, &final_shape)?,
2412            Tensor::from_vec(idx_vals, &final_shape)?,
2413        ))
2414    }
2415
2416    /// Min along an axis with optional keepdim. Return `(values, indices)`.
2417    ///
2418    /// Ties are broken by choosing the first occurrence (smallest index).
2419    pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2420        let ndim = self.ndim();
2421        if axis >= ndim {
2422            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2423        }
2424        let axis_len = self.shape[axis];
2425        let mut out_shape = self.shape.clone();
2426        out_shape[axis] = 1;
2427        let out_numel = Self::shape_numel(&out_shape);
2428        let out_strides = Self::compute_strides(&out_shape);
2429        let data = self.to_vec();
2430        let mut values = Vec::with_capacity(out_numel);
2431        let mut idx_vals = Vec::with_capacity(out_numel);
2432        let mut indices = vec![0usize; ndim];
2433
2434        for out_idx in 0..out_numel {
2435            let mut remaining = out_idx;
2436            for d in 0..ndim {
2437                indices[d] = remaining / out_strides[d];
2438                remaining %= out_strides[d];
2439            }
2440            let mut best_val = f64::INFINITY;
2441            let mut best_idx = 0usize;
2442            for k in 0..axis_len {
2443                let mut flat = self.offset;
2444                for d in 0..ndim {
2445                    let idx = if d == axis { k } else { indices[d] };
2446                    flat += idx * self.strides[d];
2447                }
2448                let v = data[flat];
2449                if v < best_val {
2450                    best_val = v;
2451                    best_idx = k;
2452                }
2453            }
2454            values.push(best_val);
2455            idx_vals.push(best_idx as f64);
2456        }
2457
2458        let final_shape = if keepdim {
2459            out_shape
2460        } else {
2461            let mut s: Vec<usize> = self.shape.iter().enumerate()
2462                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2463            if s.is_empty() { s.push(1); }
2464            s
2465        };
2466        Ok((
2467            Tensor::from_vec(values, &final_shape)?,
2468            Tensor::from_vec(idx_vals, &final_shape)?,
2469        ))
2470    }
2471
2472    /// Variance along an axis with optional keepdim.
2473    ///
2474    /// Computes population variance: `Var = sum((x - mean)^2) / N`.
2475    /// Uses [`BinnedAccumulatorF64`](crate::accumulator::BinnedAccumulatorF64)
2476    /// for the squared-differences summation.
2477    pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2478        let mean_t = self.mean_axis(axis, true)?;
2479        let ndim = self.ndim();
2480        if axis >= ndim {
2481            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2482        }
2483        let axis_len = self.shape[axis];
2484        let mut out_shape = self.shape.clone();
2485        out_shape[axis] = 1;
2486        let out_numel = Self::shape_numel(&out_shape);
2487        let out_strides = Self::compute_strides(&out_shape);
2488        let data = self.to_vec();
2489        let mean_data = mean_t.to_vec();
2490        let mut result = Vec::with_capacity(out_numel);
2491        let mut indices = vec![0usize; ndim];
2492
2493        for out_idx in 0..out_numel {
2494            let mut remaining = out_idx;
2495            for d in 0..ndim {
2496                indices[d] = remaining / out_strides[d];
2497                remaining %= out_strides[d];
2498            }
2499            let mu = mean_data[out_idx];
2500            let mut acc = BinnedAccumulatorF64::new();
2501            for k in 0..axis_len {
2502                let mut flat = self.offset;
2503                for d in 0..ndim {
2504                    let idx = if d == axis { k } else { indices[d] };
2505                    flat += idx * self.strides[d];
2506                }
2507                let diff = data[flat] - mu;
2508                acc.add(diff * diff);
2509            }
2510            result.push(acc.finalize() / axis_len as f64);
2511        }
2512
2513        let final_shape = if keepdim {
2514            out_shape
2515        } else {
2516            let mut s: Vec<usize> = self.shape.iter().enumerate()
2517                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2518            if s.is_empty() { s.push(1); }
2519            s
2520        };
2521        Tensor::from_vec(result, &final_shape)
2522    }
2523
2524    /// Standard deviation along an axis with optional keepdim.
2525    ///
2526    /// Computed as `sqrt(var_axis(axis, keepdim))`.
2527    pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2528        let var = self.var_axis(axis, keepdim)?;
2529        Ok(var.map(|x| x.sqrt()))
2530    }
2531
2532    /// Product along an axis with optional keepdim.
2533    ///
2534    /// Computes a simple sequential product (exact for integer-like values).
2535    pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2536        self.reduce_axis(axis, keepdim, |vals| {
2537            // Product via exp(sum(ln(abs))) for numerical stability is overkill here;
2538            // simple product is deterministic and exact for integer-like values.
2539            let mut product = 1.0f64;
2540            for &v in vals { product *= v; }
2541            product
2542        })
2543    }
2544
2545    // -----------------------------------------------------------------------
2546    // Phase 2: Sort Operations
2547    // -----------------------------------------------------------------------
2548
2549    /// Sort along an axis (stable sort). Return the sorted tensor.
2550    ///
2551    /// For N-D tensors, independently sorts each 1-D slice along the specified
2552    /// axis. Uses `f64::partial_cmp` with deterministic tie-breaking by
2553    /// original index position.
2554    pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2555        let ndim = self.ndim();
2556        if axis >= ndim {
2557            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2558        }
2559        let data = self.to_vec();
2560        let axis_len = self.shape[axis];
2561        let out_shape = self.shape.clone();
2562        let out_numel = Self::shape_numel(&out_shape);
2563
2564        // Build strides for iterating over all non-axis positions
2565        let mut iter_shape: Vec<usize> = Vec::new();
2566        for (i, &s) in self.shape.iter().enumerate() {
2567            if i != axis { iter_shape.push(s); }
2568        }
2569        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2570
2571        let mut result = vec![0.0f64; out_numel];
2572
2573        // We iterate over all positions with axis index = 0
2574        let mut pos = vec![0usize; ndim];
2575        for slice_idx in 0..n_slices {
2576            // Compute the N-D position (with axis dim = 0)
2577            let mut remaining = slice_idx;
2578            let mut dim_idx = 0;
2579            for d in 0..ndim {
2580                if d == axis {
2581                    pos[d] = 0;
2582                } else {
2583                    let stride = {
2584                        let mut s = 1usize;
2585                        let mut di = 0;
2586                        for d2 in 0..ndim {
2587                            if d2 == axis { continue; }
2588                            if di > dim_idx { s *= self.shape[d2]; }
2589                            di += 1;
2590                        }
2591                        s
2592                    };
2593                    pos[d] = remaining / stride;
2594                    remaining %= stride;
2595                    dim_idx += 1;
2596                }
2597            }
2598
2599            // Gather values along axis
2600            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2601            for k in 0..axis_len {
2602                let mut flat = self.offset;
2603                for d in 0..ndim {
2604                    let idx = if d == axis { k } else { pos[d] };
2605                    flat += idx * self.strides[d];
2606                }
2607                vals.push((data[flat], k));
2608            }
2609
2610            // Stable sort with deterministic tie-breaking by original index
2611            if descending {
2612                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2613                    .then(a.1.cmp(&b.1)));
2614            } else {
2615                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2616                    .then(a.1.cmp(&b.1)));
2617            }
2618
2619            // Scatter back
2620            for (k, &(v, _)) in vals.iter().enumerate() {
2621                let mut flat = 0;
2622                let out_strides_local = Self::compute_strides(&out_shape);
2623                for d in 0..ndim {
2624                    let idx = if d == axis { k } else { pos[d] };
2625                    flat += idx * out_strides_local[d];
2626                }
2627                result[flat] = v;
2628            }
2629        }
2630
2631        Tensor::from_vec(result, &out_shape)
2632    }
2633
2634    /// N-D argsort along an axis. Return a tensor of indices that would sort
2635    /// each slice along the given axis.
2636    ///
2637    /// Deterministic tie-breaking: ties are resolved by original index order.
2638    pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2639        let ndim = self.ndim();
2640        if axis >= ndim {
2641            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2642        }
2643        let data = self.to_vec();
2644        let axis_len = self.shape[axis];
2645        let out_shape = self.shape.clone();
2646        let out_numel = Self::shape_numel(&out_shape);
2647
2648        let mut iter_shape: Vec<usize> = Vec::new();
2649        for (i, &s) in self.shape.iter().enumerate() {
2650            if i != axis { iter_shape.push(s); }
2651        }
2652        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2653
2654        let mut result = vec![0.0f64; out_numel];
2655        let mut pos = vec![0usize; ndim];
2656
2657        for slice_idx in 0..n_slices {
2658            let mut remaining = slice_idx;
2659            let mut dim_idx = 0;
2660            for d in 0..ndim {
2661                if d == axis {
2662                    pos[d] = 0;
2663                } else {
2664                    let stride = {
2665                        let mut s = 1usize;
2666                        let mut di = 0;
2667                        for d2 in 0..ndim {
2668                            if d2 == axis { continue; }
2669                            if di > dim_idx { s *= self.shape[d2]; }
2670                            di += 1;
2671                        }
2672                        s
2673                    };
2674                    pos[d] = remaining / stride;
2675                    remaining %= stride;
2676                    dim_idx += 1;
2677                }
2678            }
2679
2680            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2681            for k in 0..axis_len {
2682                let mut flat = self.offset;
2683                for d in 0..ndim {
2684                    let idx = if d == axis { k } else { pos[d] };
2685                    flat += idx * self.strides[d];
2686                }
2687                vals.push((data[flat], k));
2688            }
2689
2690            if descending {
2691                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2692                    .then(a.1.cmp(&b.1)));
2693            } else {
2694                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2695                    .then(a.1.cmp(&b.1)));
2696            }
2697
2698            for (k, &(_, orig_idx)) in vals.iter().enumerate() {
2699                let out_strides_local = Self::compute_strides(&out_shape);
2700                let mut flat = 0;
2701                for d in 0..ndim {
2702                    let idx = if d == axis { k } else { pos[d] };
2703                    flat += idx * out_strides_local[d];
2704                }
2705                result[flat] = orig_idx as f64;
2706            }
2707        }
2708
2709        Tensor::from_vec(result, &out_shape)
2710    }
2711
2712    // -----------------------------------------------------------------------
2713    // Phase 2: Einsum
2714    // -----------------------------------------------------------------------
2715
2716    /// Einstein summation notation.
2717    /// Supports patterns like "ij,jk->ik" (matmul), "ii->i" (diagonal),
2718    /// "ij->ji" (transpose), "ijk,ikl->ijl" (batched matmul).
2719    /// Uses BinnedAccumulator for all reductions.
2720    pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
2721        // Parse notation: "subscripts->output" or just "subscripts" (implicit)
2722        let parts: Vec<&str> = notation.split("->").collect();
2723        if parts.len() != 2 {
2724            return Err(RuntimeError::InvalidOperation(
2725                format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
2726            ));
2727        }
2728        let input_specs: Vec<&str> = parts[0].split(',').collect();
2729        let output_spec = parts[1];
2730
2731        if input_specs.len() != inputs.len() {
2732            return Err(RuntimeError::InvalidOperation(
2733                format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
2734            ));
2735        }
2736
2737        // Build label → size mapping
2738        let mut label_size = std::collections::BTreeMap::new();
2739        for (i, &spec) in input_specs.iter().enumerate() {
2740            let chars: Vec<char> = spec.chars().collect();
2741            if chars.len() != inputs[i].ndim() {
2742                return Err(RuntimeError::InvalidOperation(
2743                    format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
2744                ));
2745            }
2746            for (d, &c) in chars.iter().enumerate() {
2747                let sz = inputs[i].shape()[d];
2748                if let Some(&prev) = label_size.get(&c) {
2749                    if prev != sz {
2750                        return Err(RuntimeError::InvalidOperation(
2751                            format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
2752                        ));
2753                    }
2754                } else {
2755                    label_size.insert(c, sz);
2756                }
2757            }
2758        }
2759
2760        // Determine output shape
2761        let output_chars: Vec<char> = output_spec.chars().collect();
2762        let output_shape: Vec<usize> = output_chars.iter()
2763            .map(|c| label_size.get(c).copied().ok_or_else(||
2764                RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
2765            .collect::<Result<_, _>>()?;
2766        let out_numel = Self::shape_numel(&output_shape);
2767
2768        // Determine contraction labels (in input but not output)
2769        let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
2770        let contract_labels: Vec<char> = label_size.keys()
2771            .filter(|c| !output_set.contains(c))
2772            .copied()
2773            .collect();
2774        let contract_sizes: Vec<usize> = contract_labels.iter()
2775            .map(|c| label_size[c])
2776            .collect();
2777        let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);
2778
2779        // Precompute input spec chars
2780        let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();
2781
2782        // For each output position, iterate over contraction indices
2783        let out_strides = Self::compute_strides(&output_shape);
2784        let mut result = vec![0.0f64; out_numel];
2785
2786        // Pre-read input data
2787        let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
2788        let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
2789        let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();
2790
2791        for out_idx in 0..out_numel {
2792            // Compute output label values
2793            let mut label_vals = std::collections::BTreeMap::new();
2794            let mut remaining = out_idx;
2795            for (d, &c) in output_chars.iter().enumerate() {
2796                let stride = if d < out_strides.len() { out_strides[d] } else { 1 };
2797                label_vals.insert(c, remaining / stride);
2798                remaining %= stride;
2799            }
2800
2801            let mut acc = BinnedAccumulatorF64::new();
2802            // Iterate over all contraction index combinations
2803            for cidx in 0..contract_numel {
2804                // Compute contraction label values
2805                let mut cr = cidx;
2806                for (ci, &cl) in contract_labels.iter().enumerate() {
2807                    let stride: usize = contract_sizes[ci+1..].iter().product::<usize>().max(1);
2808                    label_vals.insert(cl, cr / stride);
2809                    cr %= stride;
2810                }
2811
2812                // Compute product of input elements
2813                let mut product = 1.0f64;
2814                for (inp_idx, chars) in input_chars.iter().enumerate() {
2815                    let mut flat = input_offsets[inp_idx];
2816                    for (d, &c) in chars.iter().enumerate() {
2817                        flat += label_vals[&c] * input_strides[inp_idx][d];
2818                    }
2819                    product *= input_data[inp_idx][flat];
2820                }
2821                acc.add(product);
2822            }
2823            result[out_idx] = acc.finalize();
2824        }
2825
2826        if output_shape.is_empty() {
2827            Tensor::from_vec(result, &[1])
2828        } else {
2829            Tensor::from_vec(result, &output_shape)
2830        }
2831    }
2832
2833    // -----------------------------------------------------------------------
2834    // Phase 2: Reshape / View Enhancements
2835    // -----------------------------------------------------------------------
2836
2837    /// Add a dimension of size 1 at position `dim`.
2838    ///
2839    /// For a tensor of shape `[A, B]`, `unsqueeze(0)` yields `[1, A, B]`,
2840    /// `unsqueeze(1)` yields `[A, 1, B]`, etc.
2841    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
2842        let ndim = self.ndim();
2843        if dim > ndim {
2844            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
2845        }
2846        let mut new_shape = self.shape.clone();
2847        new_shape.insert(dim, 1);
2848        self.reshape(&new_shape)
2849    }
2850
2851    /// Remove a dimension of size 1 at position `dim`.
2852    /// If `dim` is `None`, removes all dimensions of size 1.
2853    pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
2854        match dim {
2855            Some(d) => {
2856                if d >= self.ndim() {
2857                    return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
2858                }
2859                if self.shape[d] != 1 {
2860                    return Err(RuntimeError::InvalidOperation(
2861                        format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
2862                    ));
2863                }
2864                let mut new_shape = self.shape.clone();
2865                new_shape.remove(d);
2866                if new_shape.is_empty() {
2867                    new_shape.push(1); // scalar
2868                }
2869                self.reshape(&new_shape)
2870            }
2871            None => {
2872                let new_shape: Vec<usize> = self.shape.iter()
2873                    .filter(|&&s| s != 1)
2874                    .copied()
2875                    .collect();
2876                let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
2877                self.reshape(&new_shape)
2878            }
2879        }
2880    }
2881
2882    /// Broadcast without copying. Return a view with `stride=0` for broadcasted dims.
2883    ///
2884    /// Alias for [`broadcast_to`](Tensor::broadcast_to).
2885    pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
2886        self.broadcast_to(target_shape)
2887    }
2888
2889    /// Flatten a range of dimensions [start_dim, end_dim] into a single dimension.
2890    pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
2891        if start_dim > end_dim || end_dim >= self.ndim() {
2892            return Err(RuntimeError::InvalidOperation(
2893                format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
2894            ));
2895        }
2896        let mut new_shape = Vec::new();
2897        for i in 0..start_dim {
2898            new_shape.push(self.shape[i]);
2899        }
2900        let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
2901        new_shape.push(flat_size);
2902        for i in (end_dim + 1)..self.ndim() {
2903            new_shape.push(self.shape[i]);
2904        }
2905        self.reshape(&new_shape)
2906    }
2907
2908    /// Split tensor into `n` roughly equal chunks along dimension `dim`.
2909    pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2910        if dim >= self.ndim() {
2911            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2912        }
2913        if n == 0 {
2914            return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
2915        }
2916        let dim_size = self.shape[dim];
2917        let chunk_size = (dim_size + n - 1) / n;
2918        let mut sizes = Vec::new();
2919        let mut remaining = dim_size;
2920        while remaining > 0 {
2921            let s = remaining.min(chunk_size);
2922            sizes.push(s);
2923            remaining -= s;
2924        }
2925        self.split(&sizes, dim)
2926    }
2927
2928    /// Split tensor along dimension `dim` according to the given sizes.
2929    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2930        if dim >= self.ndim() {
2931            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2932        }
2933        let total: usize = sizes.iter().sum();
2934        if total != self.shape[dim] {
2935            return Err(RuntimeError::InvalidOperation(
2936                format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
2937            ));
2938        }
2939
2940        let mut results = Vec::new();
2941        let mut offset = 0;
2942
2943        for &sz in sizes {
2944            let ranges: Vec<(usize, usize)> = self.shape.iter()
2945                .enumerate()
2946                .map(|(i, &s)| {
2947                    if i == dim { (offset, offset + sz) } else { (0, s) }
2948                })
2949                .collect();
2950            let chunk = self.slice(&ranges)?;
2951            // Materialize as contiguous
2952            results.push(chunk.to_contiguous());
2953            offset += sz;
2954        }
2955
2956        Ok(results)
2957    }
2958
2959    /// Fused `alpha * self + beta * other` element-wise. Single pass, one allocation.
2960    ///
2961    /// Critical for LSTM/GRU gates where `f * c_prev + i * g` would otherwise
2962    /// create 3 intermediate tensors.
2963    pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
2964        if self.shape != other.shape {
2965            return Err(RuntimeError::InvalidOperation(
2966                "scale_add: shape mismatch".to_string(),
2967            ));
2968        }
2969        let a = self.to_vec();
2970        let b = other.to_vec();
2971        let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
2972        Tensor::from_vec(result, &self.shape)
2973    }
2974}
2975