// god_graph/tensor/dense.rs

//! Dense tensor implementation
//!
//! N-dimensional dense tensor with ndarray interop; matrix operations support BLAS acceleration.

use core::fmt;

#[cfg(feature = "tensor")]
use ndarray::Array2;

use crate::tensor::error::TensorError;
use crate::tensor::traits::{DType, Device, TensorBase, TensorOps};

/// Dense tensor: a high-performance N-dimensional array
///
/// Data is stored contiguously in row-major (C) order. Matrix operations are
/// BLAS-friendly; note that `Vec<f64>` only guarantees 8-byte alignment
/// (see [`DenseTensor::alignment`]).
#[derive(Clone, PartialEq)]
pub struct DenseTensor {
    /// Tensor data (contiguous, row-major)
    data: Vec<f64>,
    /// Tensor shape
    shape: Vec<usize>,
    /// Strides, in elements
    strides: Vec<usize>,
    /// Data type
    dtype: DType,
    /// Device type
    device: Device,
}

#[cfg(feature = "tensor")]
impl DenseTensor {
    /// Size of the tensor data in bytes
    pub fn nbytes(&self) -> usize {
        self.data.len() * self.dtype.size_bytes()
    }

    /// Whether the tensor is stored contiguously
    pub fn is_contiguous(&self) -> bool {
        self.is_c_contiguous()
    }

    /// Target alignment in bytes
    pub fn alignment(&self) -> usize {
        // Intended alignment for SIMD/BLAS. `Vec<f64>` itself only guarantees
        // 8-byte alignment; an aligned allocator would make this the actual value.
        64
    }

    /// Create a new dense tensor
    ///
    /// # Arguments
    /// * `data` - data vector (row-major, C order)
    /// * `shape` - tensor shape
    ///
    /// # Returns
    /// The newly created `DenseTensor`
    ///
    /// # Panics
    /// Panics if the length of `data` does not match the product of `shape`
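    ///
    /// # Example
    ///
    /// An illustrative sketch (assumes the `tensor` feature is enabled):
    ///
    /// ```ignore
    /// let t = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
    /// assert_eq!(t.shape(), &[2, 3]);
    /// assert_eq!(t.strides(), &[3, 1]); // row-major strides
    /// assert_eq!(t.get(&[1, 2]).unwrap(), 6.0);
    /// ```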
    pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Self {
        let expected_len = shape.iter().product::<usize>();
        assert_eq!(
            data.len(),
            expected_len,
            "Data length {} does not match shape product {}",
            data.len(),
            expected_len
        );

        let strides = compute_strides(&shape);
        Self {
            data,
            shape,
            strides,
            dtype: DType::F64,
            device: Device::Cpu,
        }
    }

    /// Create a tensor from a `Vec` (row-major order)
    pub fn from_vec(data: Vec<f64>, shape: Vec<usize>) -> Self {
        Self::new(data, shape)
    }

    /// Create a tensor of all zeros
    pub fn zeros(shape: Vec<usize>) -> Self {
        let data = vec![0.0; shape.iter().product()];
        Self::new(data, shape)
    }

    /// Create a tensor of all ones
    pub fn ones(shape: Vec<usize>) -> Self {
        let data = vec![1.0; shape.iter().product()];
        Self::new(data, shape)
    }

    /// Create a scalar (0-dimensional) tensor
    pub fn scalar(value: f64) -> Self {
        Self {
            data: vec![value],
            shape: vec![],
            strides: vec![],
            dtype: DType::F64,
            device: Device::Cpu,
        }
    }

    /// Create a 2D matrix
    pub fn matrix(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        Self::new(data, vec![rows, cols])
    }

    /// Create a 2D identity matrix
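    ///
    /// # Example
    ///
    /// A small sketch of the expected layout (assumes the `tensor` feature):
    ///
    /// ```ignore
    /// let i = DenseTensor::eye(2);
    /// assert_eq!(i.data(), &[1.0, 0.0, 0.0, 1.0]);
    /// ```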
    pub fn eye(size: usize) -> Self {
        let mut data = vec![0.0; size * size];
        for i in 0..size {
            data[i * size + i] = 1.0;
        }
        Self::new(data, vec![size, size])
    }

    /// Create from an `ndarray::Array2`
    #[cfg(feature = "tensor")]
    pub fn from_ndarray(arr: &Array2<f64>) -> Self {
        let shape = vec![arr.nrows(), arr.ncols()];
        // Iterate in logical (row-major) order; this also works for arrays whose
        // memory layout is not standard, where `as_slice()` would return None.
        let data = arr.iter().copied().collect();
        Self::new(data, shape)
    }

    /// Convert to an `ndarray::Array2`
    #[cfg(feature = "tensor")]
    pub fn to_ndarray(&self) -> Result<Array2<f64>, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }
        // Length was validated at construction, so this cannot fail.
        Ok(Array2::from_shape_vec((self.shape[0], self.shape[1]), self.data.clone()).unwrap())
    }

    /// Get the data as a slice
    pub fn data(&self) -> &[f64] {
        &self.data
    }

    /// Get the data as a mutable slice
    pub fn data_mut(&mut self) -> &mut [f64] {
        &mut self.data
    }

    /// Get the strides
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Check whether memory is contiguous (C order)
    pub fn is_c_contiguous(&self) -> bool {
        if self.ndim() <= 1 {
            return true;
        }
        for i in 0..self.ndim() - 1 {
            if self.strides[i] != self.strides[i + 1] * self.shape[i + 1] {
                return false;
            }
        }
        true
    }

    /// Get the element at the given indices
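    ///
    /// # Example
    ///
    /// Illustrative only (assumes the `tensor` feature):
    ///
    /// ```ignore
    /// let t = DenseTensor::matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// assert_eq!(t.get(&[1, 0]).unwrap(), 3.0);
    /// assert!(t.get(&[2, 0]).is_err()); // row index out of bounds
    /// ```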
    pub fn get(&self, indices: &[usize]) -> Result<f64, TensorError> {
        if indices.len() != self.ndim() {
            return Err(TensorError::DimensionMismatch {
                expected: self.ndim(),
                got: indices.len(),
            });
        }

        let mut offset = 0;
        for (i, &idx) in indices.iter().enumerate() {
            if idx >= self.shape[i] {
                return Err(TensorError::IndexOutOfBounds {
                    index: idx,
                    dim: i,
                    size: self.shape[i],
                });
            }
            offset += idx * self.strides[i];
        }

        Ok(self.data[offset])
    }

    /// Set the element at the given indices
    pub fn set(&mut self, indices: &[usize], value: f64) -> Result<(), TensorError> {
        if indices.len() != self.ndim() {
            return Err(TensorError::DimensionMismatch {
                expected: self.ndim(),
                got: indices.len(),
            });
        }

        let mut offset = 0;
        for (i, &idx) in indices.iter().enumerate() {
            if idx >= self.shape[i] {
                return Err(TensorError::IndexOutOfBounds {
                    index: idx,
                    dim: i,
                    size: self.shape[i],
                });
            }
            offset += idx * self.strides[i];
        }

        self.data[offset] = value;
        Ok(())
    }

    /// Get the data of the given row
    pub fn row(&self, row: usize) -> Result<Vec<f64>, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }
        if row >= self.shape[0] {
            return Err(TensorError::IndexOutOfBounds {
                index: row,
                dim: 0,
                size: self.shape[0],
            });
        }

        let start = row * self.strides[0];
        let cols = self.shape[1];
        Ok(self.data[start..start + cols].to_vec())
    }

    /// Get the data of the given column
    pub fn col(&self, col: usize) -> Result<Vec<f64>, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }
        if col >= self.shape[1] {
            return Err(TensorError::IndexOutOfBounds {
                index: col,
                dim: 1,
                size: self.shape[1],
            });
        }

        let mut result = Vec::with_capacity(self.shape[0]);
        for row in 0..self.shape[0] {
            let idx = row * self.strides[0] + col;
            result.push(self.data[idx]);
        }
        Ok(result)
    }
}

/// Compute strides (C order, row-major)
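///
/// For shape `[d0, d1, ..., dn]`, `strides[i]` is the product of all dimensions
/// after `i`. A sketch of the expected output:
///
/// ```ignore
/// assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
/// ```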
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let ndim = shape.len();
    if ndim == 0 {
        return vec![];
    }

    let mut strides = vec![1; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

#[cfg(feature = "tensor")]
impl TensorBase for DenseTensor {
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn dtype(&self) -> DType {
        self.dtype
    }

    fn device(&self) -> Device {
        self.device
    }

    fn to_dense(&self) -> DenseTensor {
        self.clone()
    }

    #[cfg(feature = "tensor")]
    fn to_sparse(&self) -> Option<crate::tensor::sparse::SparseTensor> {
        // Convert the dense tensor to CSR format (2D only); entries with
        // absolute value <= 1e-10 are treated as zero and dropped.
        let mut row_offsets = vec![0];
        let mut col_indices = Vec::new();
        let mut values = Vec::new();

        if self.ndim() == 2 {
            let rows = self.shape[0];
            let cols = self.shape[1];

            for row in 0..rows {
                for col in 0..cols {
                    let val = self.get(&[row, col]).unwrap();
                    if val.abs() > 1e-10 {
                        col_indices.push(col);
                        values.push(val);
                    }
                }
                row_offsets.push(col_indices.len());
            }

            let nnz = values.len();
            let values_tensor = DenseTensor::new(values, vec![nnz]);
            let csr = crate::tensor::sparse::CSRTensor::new(
                row_offsets,
                col_indices,
                values_tensor,
                [self.shape[0], self.shape[1]],
            );
            Some(crate::tensor::sparse::SparseTensor::CSR(csr))
        } else {
            None
        }
    }
}

#[cfg(feature = "tensor")]
impl TensorOps for DenseTensor {
    fn add(&self, other: &Self) -> Self {
        assert_eq!(
            self.shape, other.shape,
            "Shape mismatch for addition: {:?} vs {:?}",
            self.shape, other.shape
        );

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| a + b)
            .collect();

        Self::new(data, self.shape.clone())
    }

    fn sub(&self, other: &Self) -> Self {
        assert_eq!(
            self.shape, other.shape,
            "Shape mismatch for subtraction: {:?} vs {:?}",
            self.shape, other.shape
        );

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| a - b)
            .collect();

        Self::new(data, self.shape.clone())
    }

    fn mul(&self, other: &Self) -> Self {
        assert_eq!(
            self.shape, other.shape,
            "Shape mismatch for element-wise multiplication: {:?} vs {:?}",
            self.shape, other.shape
        );

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| a * b)
            .collect();

        Self::new(data, self.shape.clone())
    }

    fn div(&self, other: &Self) -> Self {
        assert_eq!(
            self.shape, other.shape,
            "Shape mismatch for division: {:?} vs {:?}",
            self.shape, other.shape
        );

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| a / b)
            .collect();

        Self::new(data, self.shape.clone())
    }

    fn matmul(&self, other: &Self) -> Self {
        assert_eq!(
            self.ndim(),
            2,
            "matmul requires 2D tensors, got {}D",
            self.ndim()
        );
        assert_eq!(
            other.ndim(),
            2,
            "matmul requires 2D tensors, got {}D",
            other.ndim()
        );
        assert_eq!(
            self.shape[1], other.shape[0],
            "Shape mismatch for matmul: {:?} x {:?}",
            self.shape, other.shape
        );

        let m = self.shape[0];
        let k = self.shape[1];
        let n = other.shape[1];

        let mut result = vec![0.0; m * n];

        // Naive O(m*k*n) matrix multiplication (a BLAS path could replace this later)
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for p in 0..k {
                    sum += self.data[i * k + p] * other.data[p * n + j];
                }
                result[i * n + j] = sum;
            }
        }

        Self::new(result, vec![m, n])
    }

    fn transpose(&self, axes: Option<&[usize]>) -> Self {
        if self.ndim() == 0 {
            return self.clone();
        }

        if self.ndim() == 2 {
            // 2D transpose
            let rows = self.shape[0];
            let cols = self.shape[1];
            let mut result = vec![0.0; cols * rows];

            for i in 0..rows {
                for j in 0..cols {
                    result[j * rows + i] = self.data[i * cols + j];
                }
            }

            Self::new(result, vec![cols, rows])
        } else {
            // N-dimensional transpose
            let default_axes: Vec<usize> = (0..self.ndim()).rev().collect();
            let axes = axes.unwrap_or(&default_axes);

            assert_eq!(axes.len(), self.ndim(), "Axes length must match ndim");

            let new_shape: Vec<usize> = axes.iter().map(|&a| self.shape[a]).collect();
            let old_strides = compute_strides(&self.shape);
            let new_strides = compute_strides(&new_shape);
            let mut result = vec![0.0; self.numel()];

            // For each source element, recover its multi-index from the flat
            // index, permute it by `axes`, and write to the permuted position:
            // new_index[j] = old_index[axes[j]].
            for (i, &val) in self.data.iter().enumerate() {
                let mut new_idx = 0;
                for (j, &a) in axes.iter().enumerate() {
                    let coord = (i / old_strides[a]) % self.shape[a];
                    new_idx += coord * new_strides[j];
                }
                result[new_idx] = val;
            }

            Self::new(result, new_shape)
        }
    }

    fn sum(&self, axes: Option<&[usize]>) -> Self {
        if let Some(axes) = axes {
            if axes.is_empty() {
                return self.clone();
            }

            // Simplified implementation: only single-axis reduction is supported
            if axes.len() == 1 {
                let axis = axes[0];
                if self.ndim() == 2 && axis == 0 {
                    // Reduce over rows: one sum per column
                    let cols = self.shape[1];
                    let mut result = vec![0.0; cols];
                    for row in self.data.chunks(cols) {
                        for (j, val) in row.iter().enumerate() {
                            result[j] += val;
                        }
                    }
                    return Self::new(result, vec![cols]);
                } else if self.ndim() == 2 && axis == 1 {
                    // Reduce over columns: one sum per row
                    let rows = self.shape[0];
                    let cols = self.shape[1];
                    let mut result = vec![0.0; rows];
                    for (i, row_sum) in result.iter_mut().enumerate().take(rows) {
                        let row_start = i * cols;
                        *row_sum = self.data[row_start..row_start + cols].iter().sum();
                    }
                    return Self::new(result, vec![rows]);
                }
            }

            // Fallback: unsupported axis combinations reduce to the scalar sum
            let sum: f64 = self.data.iter().sum();
            Self::scalar(sum)
        } else {
            // No axes: sum of all elements
            let sum: f64 = self.data.iter().sum();
            Self::scalar(sum)
        }
    }

    fn mean(&self, axes: Option<&[usize]>) -> Self {
        let sum = self.sum(axes);
        let count = if let Some(axes) = axes {
            if axes.is_empty() {
                1
            } else {
                axes.iter().map(|&a| self.shape[a]).product::<usize>()
            }
        } else {
            self.numel()
        };

        sum.mul_scalar(1.0 / count as f64)
    }

    fn mul_scalar(&self, scalar: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x * scalar).collect();
        Self::new(data, self.shape.clone())
    }

    fn add_scalar(&self, scalar: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x + scalar).collect();
        Self::new(data, self.shape.clone())
    }

    fn map<F>(&self, f: F) -> Self
    where
        F: Fn(f64) -> f64 + Send + Sync,
    {
        let data: Vec<f64> = self.data.iter().copied().map(f).collect();
        Self::new(data, self.shape.clone())
    }

    fn reshape(&self, new_shape: &[usize]) -> Self {
        let new_size: usize = new_shape.iter().product();
        assert_eq!(
            new_size,
            self.numel(),
            "Reshape size mismatch: {} vs {}",
            new_size,
            self.numel()
        );

        Self::new(self.data.clone(), new_shape.to_vec())
    }

    fn slice(&self, axes: &[usize], ranges: &[core::ops::Range<usize>]) -> Self {
        assert_eq!(axes.len(), ranges.len(), "Axes and ranges length mismatch");

        // Simplified implementation: only 2D slicing is supported
        if self.ndim() == 2 && axes.len() == 2 {
            let row_range = if axes[0] == 0 {
                ranges[0].clone()
            } else {
                ranges[1].clone()
            };
            let col_range = if axes[1] == 1 {
                ranges[1].clone()
            } else {
                ranges[0].clone()
            };

            let new_rows = row_range.len();
            let new_cols = col_range.len();
            let mut result = Vec::with_capacity(new_rows * new_cols);

            for i in row_range {
                for j in col_range.clone() {
                    result.push(self.data[i * self.shape[1] + j]);
                }
            }

            return Self::new(result, vec![new_rows, new_cols]);
        }

        // Unsupported cases fall back to returning a clone
        self.clone()
    }

    fn concat(&self, other: &Self, axis: usize) -> Self {
        assert_eq!(
            self.ndim(),
            other.ndim(),
            "Concat ndim mismatch: {} vs {}",
            self.ndim(),
            other.ndim()
        );
        assert!(
            axis < self.ndim(),
            "Axis {} out of range for {}D tensor",
            axis,
            self.ndim()
        );

        // Check that all dimensions other than the concat axis match
        for (i, (&s, &o)) in self.shape.iter().zip(other.shape.iter()).enumerate() {
            if i != axis {
                assert_eq!(s, o, "Shape mismatch at dim {}", i);
            }
        }

        // Simplified implementation: only 2D concatenation along axis 0 is supported
        if self.ndim() == 2 && axis == 0 {
            assert_eq!(
                self.shape[1], other.shape[1],
                "Column count mismatch for concat"
            );

            let new_rows = self.shape[0] + other.shape[0];
            let cols = self.shape[1];
            let mut result = Vec::with_capacity(new_rows * cols);

            // Copy the first tensor, then append the rows of the second
            result.extend_from_slice(&self.data);
            result.extend_from_slice(&other.data);

            return Self::new(result, vec![new_rows, cols]);
        }

        // Other cases panic until a more general implementation exists
        unimplemented!("Concat for this case is not implemented")
    }

    fn max(&self) -> f64 {
        self.data.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
    }

    fn min(&self) -> f64 {
        self.data.iter().cloned().fold(f64::INFINITY, f64::min)
    }

    fn norm(&self) -> f64 {
        self.data.iter().map(|&x| x * x).sum::<f64>().sqrt()
    }

    fn normalize(&self) -> Self {
        let norm = self.norm();
        if norm > 1e-10 {
            self.mul_scalar(1.0 / norm)
        } else {
            self.clone()
        }
    }
}

// Additional DenseTensor methods for transformer support
#[cfg(feature = "tensor")]
impl DenseTensor {
    /// SiLU activation: f(x) = x * sigmoid(x)
    pub fn silu(&self) -> Self {
        self.map(|x| x / (1.0 + (-x).exp()))
    }

    /// GELU derivative (for backpropagation)
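    ///
    /// Uses the tanh approximation of GELU,
    /// GELU(x) ≈ 0.5·x·(1 + tanh(u)) with u = √(2/π)·(x + 0.044715·x³),
    /// whose derivative is
    /// 0.5·(1 + tanh(u)) + 0.5·x·(1 − tanh²(u))·√(2/π)·(1 + 3·0.044715·x²),
    /// which is exactly what the closure below computes.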
    pub fn gelu_derivative(&self) -> Self {
        const SQRT_2_OVER_PI: f64 = 0.7978845608028654;
        const COEF: f64 = 0.044715;
        self.map(|x| {
            let x2 = x * x;
            let x3 = x * x2;
            let tanh_arg = SQRT_2_OVER_PI * (x + COEF * x3);
            let tanh_val = tanh_arg.tanh();
            0.5 * (1.0 + tanh_val)
                + x * 0.5 * (1.0 - tanh_val * tanh_val) * SQRT_2_OVER_PI * (1.0 + 3.0 * COEF * x2)
        })
    }

    /// Mean along a specific dimension
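    ///
    /// For a 2D tensor, `dim = 0` averages over rows and returns shape `[1, cols]`;
    /// `dim = 1` averages over columns and returns shape `[rows, 1]`. Negative
    /// `dim` counts from the end. Unsupported cases fall back to a scalar mean.
    ///
    /// # Example
    ///
    /// A sketch (assumes the `tensor` feature):
    ///
    /// ```ignore
    /// let t = DenseTensor::matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// assert_eq!(t.mean_dim(1).data(), &[2.0, 5.0]);      // shape [2, 1]
    /// assert_eq!(t.mean_dim(0).data(), &[2.5, 3.5, 4.5]); // shape [1, 3]
    /// ```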
    pub fn mean_dim(&self, dim: isize) -> Self {
        let ndim = self.ndim();
        let axis = if dim < 0 { (ndim as isize + dim) as usize } else { dim as usize };

        if ndim == 2 && axis == 0 {
            // Mean over rows (result: [1, cols])
            let cols = self.shape[1];
            let rows = self.shape[0];
            let mut result = vec![0.0; cols];
            #[allow(clippy::needless_range_loop)]
            for col in 0..cols {
                for row in 0..rows {
                    result[col] += self.data[row * cols + col];
                }
                result[col] /= rows as f64;
            }
            Self::new(result, vec![1, cols])
        } else if ndim == 2 && axis == 1 {
            // Mean over columns (result: [rows, 1])
            let rows = self.shape[0];
            let cols = self.shape[1];
            let mut result = vec![0.0; rows];
            #[allow(clippy::needless_range_loop)]
            for row in 0..rows {
                let row_start = row * cols;
                result[row] = self.data[row_start..row_start + cols].iter().sum::<f64>() / cols as f64;
            }
            Self::new(result, vec![rows, 1])
        } else if ndim == 3 && axis == 2 {
            // Mean along the last dimension of a 3D tensor (result: [batch, seq, 1])
            let batch = self.shape[0];
            let seq = self.shape[1];
            let dim = self.shape[2];
            let mut result = vec![0.0; batch * seq];
            for b in 0..batch {
                for s in 0..seq {
                    let start = (b * seq + s) * dim;
                    let sum: f64 = self.data[start..start + dim].iter().sum();
                    result[b * seq + s] = sum / dim as f64;
                }
            }
            Self::new(result, vec![batch, seq, 1])
        } else {
            // Fallback: return the scalar mean
            let sum: f64 = self.data.iter().sum();
            Self::scalar(sum / self.numel() as f64)
        }
    }

    /// Variance along a specific dimension
    ///
    /// Computes the population variance (divides by N, not N - 1), using the
    /// mean from [`DenseTensor::mean_dim`] along the same dimension.
    pub fn var_dim(&self, dim: isize) -> Self {
        let mean = self.mean_dim(dim);
        let ndim = self.ndim();
        let axis = if dim < 0 { (ndim as isize + dim) as usize } else { dim as usize };

        if ndim == 2 && axis == 0 {
            let cols = self.shape[1];
            let rows = self.shape[0];
            let mut result = vec![0.0; cols];
            #[allow(clippy::needless_range_loop)]
            for col in 0..cols {
                for row in 0..rows {
                    let diff = self.data[row * cols + col] - mean.data()[col];
                    result[col] += diff * diff;
                }
                result[col] /= rows as f64;
            }
            Self::new(result, vec![1, cols])
        } else if ndim == 2 && axis == 1 {
            let rows = self.shape[0];
            let cols = self.shape[1];
            let mut result = vec![0.0; rows];
            #[allow(clippy::needless_range_loop)]
            for row in 0..rows {
                let row_start = row * cols;
                let m = mean.data()[row];
                let var: f64 = self.data[row_start..row_start + cols]
                    .iter()
                    .map(|&x| (x - m) * (x - m))
                    .sum::<f64>() / cols as f64;
                result[row] = var;
            }
            Self::new(result, vec![rows, 1])
        } else if ndim == 3 && axis == 2 {
            let batch = self.shape[0];
            let seq = self.shape[1];
            let dim = self.shape[2];
            let mut result = vec![0.0; batch * seq];
            for b in 0..batch {
                for s in 0..seq {
                    let start = (b * seq + s) * dim;
                    let m = mean.data()[b * seq + s];
                    let var: f64 = self.data[start..start + dim]
                        .iter()
                        .map(|&x| (x - m) * (x - m))
                        .sum::<f64>() / dim as f64;
                    result[b * seq + s] = var;
                }
            }
            Self::new(result, vec![batch, seq, 1])
        } else {
            // Fallback: return the scalar variance over all elements
            let mean_val = self.data.iter().sum::<f64>() / self.numel() as f64;
            let var: f64 = self.data.iter().map(|&x| (x - mean_val) * (x - mean_val)).sum::<f64>()
                / self.numel() as f64;
            Self::scalar(var)
        }
    }

    /// Element-wise square root
    pub fn sqrt(&self) -> Self {
        self.map(|x| x.sqrt())
    }

    /// Negate the tensor
    pub fn neg(&self) -> Self {
        self.mul_scalar(-1.0)
    }

    /// Element-wise greater-than comparison (returns 1.0 if true, 0.0 otherwise)
    pub fn gt(&self, value: f64) -> Self {
        self.map(|x| if x > value { 1.0 } else { 0.0 })
    }

    /// Replace values with `value` wherever `mask` is 1.0
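    ///
    /// # Example
    ///
    /// Illustrative sketch (assumes the `tensor` feature), e.g. for attention masking:
    ///
    /// ```ignore
    /// let scores = DenseTensor::matrix(1, 3, vec![0.5, 1.5, 2.5]);
    /// let mask = DenseTensor::matrix(1, 3, vec![0.0, 1.0, 0.0]);
    /// let masked = scores.mask_fill(&mask, f64::NEG_INFINITY);
    /// assert_eq!(masked.data(), &[0.5, f64::NEG_INFINITY, 2.5]);
    /// ```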
    pub fn mask_fill(&self, mask: &Self, value: f64) -> Self {
        assert_eq!(self.shape, mask.shape, "Shape mismatch for mask_fill");
        let data: Vec<f64> = self.data.iter()
            .zip(mask.data.iter())
            .map(|(&v, &m)| if m > 0.5 { value } else { v })
            .collect();
        Self::new(data, self.shape.clone())
    }

    /// Transpose for 2D tensors (convenience method)
    pub fn transpose_2d(&self) -> Self {
        self.transpose(None)
    }

    /// Get a row from a 2D or 3D tensor
    pub fn get_row(&self, row: usize) -> Self {
        if self.ndim() == 2 {
            let cols = self.shape[1];
            let start = row * cols;
            Self::new(self.data[start..start + cols].to_vec(), vec![1, cols])
        } else if self.ndim() == 3 {
            // For a 3D tensor [batch, seq, dim], take position `row` along the
            // seq dimension, returning a 2D slice of shape [batch, dim].
            let batch = self.shape[0];
            let dim = self.shape[2];
            let mut result_data = Vec::with_capacity(batch * dim);

            for b in 0..batch {
                let offset = (b * self.shape[1] + row) * dim;
                result_data.extend_from_slice(&self.data[offset..offset + dim]);
            }

            Self::new(result_data, vec![batch, dim])
        } else {
            // Fallback: return the first element
            Self::scalar(self.data[0])
        }
    }

    /// Set a row in the tensor (mutating; 2D tensors only, other ranks are a no-op)
    pub fn set_row(&mut self, row: usize, data: &Self) {
        if self.ndim() == 2 && data.ndim() == 2 {
            let cols = self.shape[1];
            let start = row * cols;
            self.data[start..start + cols].copy_from_slice(data.data());
        }
    }

    /// Create a tensor filled with a value
    pub fn full(shape: &[usize], value: f64) -> Self {
        let size: usize = shape.iter().product();
        let data = vec![value; size];
        Self::new(data, shape.to_vec())
    }

    /// Scale the tensor by a scalar
    pub fn scale(&self, scalar: f64) -> Self {
        self.mul_scalar(scalar)
    }

    /// Softmax along the given dimension
    pub fn softmax(&self, dim: isize) -> Self {
        crate::tensor::ops::activations::softmax(self, dim)
    }

    /// ReLU activation
    pub fn relu(&self) -> Self {
        crate::tensor::ops::activations::relu(self)
    }

    /// GELU activation
    pub fn gelu(&self) -> Self {
        crate::tensor::ops::activations::gelu(self)
    }

    /// Element-wise cosine
    pub fn cos(&self) -> Self {
        self.map(|x| x.cos())
    }

    /// Element-wise sine
    pub fn sin(&self) -> Self {
        self.map(|x| x.sin())
    }

    /// Element-wise natural logarithm
    pub fn ln(&self) -> Self {
        self.map(|x| x.ln())
    }

    /// Batched matrix multiplication
    ///
    /// For 3D tensors: `[batch, seq, hidden] @ [hidden, out] -> [batch, seq, out]`.
    /// Broadcasts a 2D weight matrix across the batch dimension.
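    ///
    /// # Example
    ///
    /// A shape-level sketch (assumes the `tensor` feature):
    ///
    /// ```ignore
    /// let x = DenseTensor::ones(vec![2, 4, 3]); // [batch=2, seq=4, hidden=3]
    /// let w = DenseTensor::ones(vec![3, 5]);    // [hidden=3, out=5]
    /// let y = x.bmm_broadcast_weight(&w);
    /// assert_eq!(y.shape(), &[2, 4, 5]);
    /// assert_eq!(y.data()[0], 3.0); // each output is a sum of 3 ones
    /// ```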
    pub fn bmm_broadcast_weight(&self, weight: &DenseTensor) -> Self {
        assert_eq!(self.ndim(), 3, "bmm_broadcast_weight requires 3D tensor, got {}D", self.ndim());
        assert_eq!(weight.ndim(), 2, "weight must be 2D, got {}D", weight.ndim());
        assert_eq!(
            self.shape[2], weight.shape[0],
            "Shape mismatch for bmm: {:?} x {:?}",
            self.shape, weight.shape
        );

        let batch = self.shape[0];
        let seq = self.shape[1];
        let hidden = self.shape[2];
        let out = weight.shape[1];

        let mut result = vec![0.0; batch * seq * out];

        // Batched matmul: every (batch, seq) position is multiplied by the same weight
        for b in 0..batch {
            for s in 0..seq {
                let input_start = (b * seq + s) * hidden;
                let output_start = (b * seq + s) * out;

                for o in 0..out {
                    let mut sum = 0.0;
                    for h in 0..hidden {
                        sum += self.data[input_start + h] * weight.data[h * out + o];
                    }
                    result[output_start + o] = sum;
                }
            }
        }

        Self::new(result, vec![batch, seq, out])
    }

    /// Expand the last dimension from 1 to `target_dim` (for broadcasting)
    /// E.g., `[batch, seq, 1] -> [batch, seq, target_dim]`
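    ///
    /// # Example
    ///
    /// Illustrative sketch (assumes the `tensor` feature):
    ///
    /// ```ignore
    /// let t = DenseTensor::new(vec![1.0, 2.0], vec![1, 2, 1]);
    /// let e = t.expand_last_dim(3);
    /// assert_eq!(e.shape(), &[1, 2, 3]);
    /// assert_eq!(e.data(), &[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
    /// ```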
    pub fn expand_last_dim(&self, target_dim: usize) -> Self {
        assert!(
            self.ndim() >= 1 && self.shape()[self.ndim() - 1] == 1,
            "Last dimension must be 1 for expansion"
        );

        let mut new_shape = self.shape.to_vec();
        new_shape[self.ndim() - 1] = target_dim;

        let mut data = Vec::with_capacity(self.numel() * target_dim);
        for &val in self.data.iter() {
            for _ in 0..target_dim {
                data.push(val);
            }
        }

        Self::new(data, new_shape)
    }

    /// Expand a 1D tensor [hidden] to 3D [batch, seq, hidden]
    pub fn expand_to_3d(&self, batch: usize, seq: usize) -> Self {
        assert_eq!(self.ndim(), 1, "Must be 1D tensor for 3D expansion");
        let hidden = self.shape[0];

        let mut data = Vec::with_capacity(batch * seq * hidden);
        for _ in 0..batch * seq {
            data.extend_from_slice(&self.data);
        }

        Self::new(data, vec![batch, seq, hidden])
    }

    /// Expand the last dimension from 1 to target_dim for 2D tensors
    /// E.g., [seq, 1] -> [seq, target_dim]
    pub fn expand_last_dim_2d(&self, target_dim: usize) -> Self {
        assert!(
            self.ndim() == 2 && self.shape()[1] == 1,
            "Must be 2D tensor with last dim 1 for expansion"
        );

        let seq = self.shape[0];
        let mut data = Vec::with_capacity(seq * target_dim);
        for &val in self.data.iter() {
            for _ in 0..target_dim {
                data.push(val);
            }
        }

        Self::new(data, vec![seq, target_dim])
    }

    /// Expand a 1D tensor [hidden] to 2D [seq, hidden]
    pub fn expand_to_2d(&self, seq: usize) -> Self {
        assert_eq!(self.ndim(), 1, "Must be 1D tensor for 2D expansion");
        let hidden = self.shape[0];

        let mut data = Vec::with_capacity(seq * hidden);
        for _ in 0..seq {
            data.extend_from_slice(&self.data);
        }

        Self::new(data, vec![seq, hidden])
    }
}

#[cfg(feature = "tensor")]
impl fmt::Debug for DenseTensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DenseTensor")
            .field("shape", &self.shape)
            .field("dtype", &self.dtype)
            .field("device", &self.device)
            .field("numel", &self.numel())
            .finish()
    }
}

#[cfg(feature = "tensor")]
impl Default for DenseTensor {
    fn default() -> Self {
        Self::zeros(vec![1])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_tensor_creation() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = DenseTensor::new(data.clone(), vec![2, 3]);

        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.data(), &data);
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.ndim(), 2);
    }

    #[test]
    fn test_zeros_and_ones() {
        let zeros = DenseTensor::zeros(vec![2, 3]);
        assert!(zeros.data().iter().all(|&x| x == 0.0));

        let ones = DenseTensor::ones(vec![2, 3]);
        assert!(ones.data().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_matrix_operations() {
        let a = DenseTensor::matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let b = DenseTensor::matrix(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let sum = a.add(&b);
        assert_eq!(sum.data(), &[6.0, 8.0, 10.0, 12.0]);

        let diff = a.sub(&b);
        assert_eq!(diff.data(), &[-4.0, -4.0, -4.0, -4.0]);
    }

    #[test]
    fn test_matmul() {
        let a = DenseTensor::matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = DenseTensor::matrix(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);

        let result = a.matmul(&b);
        assert_eq!(result.shape(), &[2, 2]);
        // [1,2,3]·[7,9,11] = 7+18+33 = 58
        // [1,2,3]·[8,10,12] = 8+20+36 = 64
        // [4,5,6]·[7,9,11] = 28+45+66 = 139
        // [4,5,6]·[8,10,12] = 32+50+72 = 154
        assert_eq!(result.data(), &[58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_transpose() {
        let a = DenseTensor::matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let t = a.transpose(None);

        assert_eq!(t.shape(), &[3, 2]);
        assert_eq!(t.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_scalar_operations() {
        let a = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);

        let mul = a.mul_scalar(2.0);
        assert_eq!(mul.data(), &[2.0, 4.0, 6.0]);

        let add = a.add_scalar(1.0);
        assert_eq!(add.data(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_norm_and_normalize() {
        let a = DenseTensor::new(vec![3.0, 4.0], vec![2]);

        assert!((a.norm() - 5.0).abs() < 1e-10);

        let normalized = a.normalize();
        assert!((normalized.norm() - 1.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic]
    fn test_shape_mismatch_panic() {
        let a = DenseTensor::new(vec![1.0, 2.0], vec![2]);
        let b = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        let _ = a.add(&b);
    }
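
    // Additional illustrative tests, added as sketches of the helpers defined
    // above; expected values follow directly from the definitions in this file.
    #[test]
    fn test_mean_and_var_dim() {
        let t = DenseTensor::matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let mean_rows = t.mean_dim(1); // mean over columns, shape [2, 1]
        assert_eq!(mean_rows.shape(), &[2, 1]);
        assert_eq!(mean_rows.data(), &[2.0, 5.0]);

        let var_rows = t.var_dim(1); // population variance per row
        assert_eq!(var_rows.shape(), &[2, 1]);
        // row [1, 2, 3] has mean 2 and variance (1 + 0 + 1) / 3 = 2/3
        assert!((var_rows.data()[0] - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_transpose_3d() {
        // Default axes reverse the shape: new[i2, i1, i0] = old[i0, i1, i2]
        let t = DenseTensor::new((0..8).map(|x| x as f64).collect(), vec![2, 2, 2]);
        let r = t.transpose(None);
        assert_eq!(r.shape(), &[2, 2, 2]);
        assert_eq!(r.data(), &[0.0, 4.0, 2.0, 6.0, 1.0, 5.0, 3.0, 7.0]);
    }

    #[test]
    fn test_expand_to_3d() {
        let bias = DenseTensor::new(vec![1.0, 2.0], vec![2]);
        let expanded = bias.expand_to_3d(2, 2); // [batch=2, seq=2, hidden=2]
        assert_eq!(expanded.shape(), &[2, 2, 2]);
        assert_eq!(expanded.data(), &[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
    }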
}