// god_graph/tensor/ops.rs

//! Advanced tensor operations
//!
//! Provides additional tensor operations, including activation functions and normalization.
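//!
//! A minimal usage sketch, marked `ignore` because it assumes this file is
//! mounted at `crate::tensor::ops` and uses the crate-internal `DenseTensor`
//! constructor exactly as the tests below do:
//!
//! ```ignore
//! use crate::tensor::dense::DenseTensor;
//! use crate::tensor::ops::activations::relu;
//!
//! let t = DenseTensor::new(vec![-1.0, 2.0], vec![2]);
//! let r = relu(&t); // data: [0.0, 2.0]
//! ```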

use crate::tensor::dense::DenseTensor;
use crate::tensor::error::TensorError;
use crate::tensor::traits::{TensorBase, TensorOps};
8
/// Activation functions
pub mod activations {
    use super::*;

    /// ReLU activation: f(x) = max(0, x)
    pub fn relu(tensor: &DenseTensor) -> DenseTensor {
        tensor.map(|x| x.max(0.0))
    }

    /// Sigmoid activation: f(x) = 1 / (1 + exp(-x))
    pub fn sigmoid(tensor: &DenseTensor) -> DenseTensor {
        tensor.map(|x| 1.0 / (1.0 + (-x).exp()))
    }

    /// Tanh activation: f(x) = tanh(x)
    pub fn tanh(tensor: &DenseTensor) -> DenseTensor {
        tensor.map(|x| x.tanh())
    }

    /// Leaky ReLU activation: f(x) = x if x > 0, else alpha * x
    pub fn leaky_relu(tensor: &DenseTensor, alpha: f64) -> DenseTensor {
        tensor.map(|x| if x > 0.0 { x } else { alpha * x })
    }

    /// Softmax along the given axis (a negative axis counts from the end)
    pub fn softmax(tensor: &DenseTensor, axis: isize) -> DenseTensor {
        let ndim = tensor.ndim();
        // Resolve a negative axis relative to the number of dimensions
        let axis = if axis < 0 { (ndim as isize + axis) as usize } else { axis as usize };

        if ndim == 1 {
            // 1-D case: compute softmax directly, subtracting the max for numerical stability
            let max_val = tensor.max();
            let exp_data: Vec<f64> = tensor.data().iter().map(|&x| (x - max_val).exp()).collect();
            let sum: f64 = exp_data.iter().sum();
            let data: Vec<f64> = exp_data.iter().map(|&x| x / sum).collect();
            DenseTensor::new(data, tensor.shape().to_vec())
        } else if ndim == 2 {
            // 2-D case: softmax over rows or columns
            let rows = tensor.shape()[0];
            let cols = tensor.shape()[1];

            if axis == 0 {
                // Column-wise softmax
                let mut result = vec![0.0; rows * cols];
                for col in 0..cols {
                    let mut col_data = Vec::with_capacity(rows);
                    for row in 0..rows {
                        col_data.push(tensor.data()[row * cols + col]);
                    }
                    let max_val = col_data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    let exp_data: Vec<f64> =
                        col_data.iter().map(|&x| (x - max_val).exp()).collect();
                    let sum: f64 = exp_data.iter().sum();
                    for row in 0..rows {
                        result[row * cols + col] = exp_data[row] / sum;
                    }
                }
                DenseTensor::new(result, vec![rows, cols])
            } else {
                // Row-wise softmax
                let mut result = vec![0.0; rows * cols];
                for row in 0..rows {
                    let row_start = row * cols;
                    let row_data = &tensor.data()[row_start..row_start + cols];
                    let max_val = row_data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    let exp_data: Vec<f64> =
                        row_data.iter().map(|&x| (x - max_val).exp()).collect();
                    let sum: f64 = exp_data.iter().sum();
                    for col in 0..cols {
                        result[row_start + col] = exp_data[col] / sum;
                    }
                }
                DenseTensor::new(result, vec![rows, cols])
            }
        } else if ndim == 3 {
            // 3-D case: [batch, seq, dim]
            let batch = tensor.shape()[0];
            let seq = tensor.shape()[1];
            let dim = tensor.shape()[2];

            if axis == 2 {
                // Softmax along the last dimension (the common case for transformers)
                let mut result = Vec::with_capacity(batch * seq * dim);
                for b in 0..batch {
                    for s in 0..seq {
                        let start = (b * seq + s) * dim;
                        let row_data = &tensor.data()[start..start + dim];
                        let max_val = row_data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                        let exp_data: Vec<f64> = row_data.iter().map(|&x| (x - max_val).exp()).collect();
                        let sum: f64 = exp_data.iter().sum();
                        for &e in &exp_data {
                            result.push(e / sum);
                        }
                    }
                }
                DenseTensor::new(result, vec![batch, seq, dim])
            } else {
                panic!("Softmax for 3D tensors only supports axis=2 or axis=-1");
            }
        } else {
            // N-D case: only softmax along the last axis is supported
            if axis == ndim - 1 {
                let outer_size: usize = tensor.shape()[..ndim - 1].iter().product();
                let inner_size = tensor.shape()[ndim - 1];
                let mut result = Vec::with_capacity(tensor.numel());

                for i in 0..outer_size {
                    let start = i * inner_size;
                    let row_data = &tensor.data()[start..start + inner_size];
                    let max_val = row_data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    let exp_data: Vec<f64> = row_data.iter().map(|&x| (x - max_val).exp()).collect();
                    let sum: f64 = exp_data.iter().sum();
                    for &e in &exp_data {
                        result.push(e / sum);
                    }
                }
                DenseTensor::new(result, tensor.shape().to_vec())
            } else {
                panic!("Softmax for {}D tensors with axis={} is not yet implemented", ndim, axis);
            }
        }
    }

    /// GELU activation: f(x) = x * Φ(x) ≈ x * 0.5 * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
    pub fn gelu(tensor: &DenseTensor) -> DenseTensor {
        const SQRT_2_OVER_PI: f64 = 0.7978845608028654;
        const COEF: f64 = 0.044715;
        tensor.map(|x| {
            let x3 = x * x * x;
            let inner = SQRT_2_OVER_PI * (x + COEF * x3);
            x * 0.5 * (1.0 + inner.tanh())
        })
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_relu() {
            let t = DenseTensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![4]);
            let result = relu(&t);
            assert_eq!(result.data(), &[0.0, 2.0, 0.0, 4.0]);
        }

        #[test]
        fn test_sigmoid() {
            let t = DenseTensor::new(vec![0.0], vec![1]);
            let result = sigmoid(&t);
            assert!((result.data()[0] - 0.5).abs() < 1e-6);
        }

        #[test]
        fn test_softmax_1d() {
            let t = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);
            let result = softmax(&t, 0);
            let sum: f64 = result.data().iter().sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }

        #[test]
        fn test_softmax_2d() {
            let t = DenseTensor::matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
            let result = softmax(&t, 1);
            // Each row should sum to 1
            let row0_sum = result.data()[0] + result.data()[1];
            let row1_sum = result.data()[2] + result.data()[3];
            assert!((row0_sum - 1.0).abs() < 1e-6);
            assert!((row1_sum - 1.0).abs() < 1e-6);
        }
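
        // Additional sanity checks for the activations that previously had no
        // coverage. These are minimal sketches: they assert only properties
        // that follow directly from the definitions above, and the 3-D test
        // assumes DenseTensor::new accepts a 3-D shape the same way the 3-D
        // softmax branch constructs one.
        #[test]
        fn test_leaky_relu() {
            let t = DenseTensor::new(vec![-2.0, 3.0], vec![2]);
            let result = leaky_relu(&t, 0.1);
            assert!((result.data()[0] - (-0.2)).abs() < 1e-12);
            assert!((result.data()[1] - 3.0).abs() < 1e-12);
        }

        #[test]
        fn test_gelu_at_zero() {
            // GELU(0) = 0 * 0.5 * (1 + tanh(0)) = 0
            let t = DenseTensor::new(vec![0.0], vec![1]);
            let result = gelu(&t);
            assert!(result.data()[0].abs() < 1e-12);
        }

        #[test]
        fn test_softmax_3d_last_axis() {
            // Each [dim]-slice should sum to 1 after softmax along axis 2 (i.e. -1)
            let t = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]);
            let result = softmax(&t, -1);
            for slice in 0..4 {
                let s = result.data()[slice * 2] + result.data()[slice * 2 + 1];
                assert!((s - 1.0).abs() < 1e-6);
            }
        }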
    }
}

/// Normalization operations
pub mod normalization {
    use super::*;

    /// Layer Normalization
    pub fn layer_norm(tensor: &DenseTensor, epsilon: f64) -> DenseTensor {
        if tensor.ndim() == 1 {
            let mean = tensor.mean(None).data()[0];
            let centered = tensor.add_scalar(-mean);
            // std = sqrt(mean of squared deviations)
            let std =
                (centered.data().iter().map(|&x| x * x).sum::<f64>() / tensor.numel() as f64).sqrt();
            centered.mul_scalar(1.0 / (std + epsilon))
        } else {
            // For higher-rank tensors, normalize along the last axis
            // TODO: implement full N-dimensional LayerNorm
            panic!(
                "LayerNorm for {}D tensors is not yet implemented",
                tensor.ndim()
            );
        }
    }

    /// Batch Normalization (simplified)
    pub fn batch_norm(
        tensor: &DenseTensor,
        mean: &DenseTensor,
        var: &DenseTensor,
        epsilon: f64,
    ) -> DenseTensor {
        let centered = tensor.sub(mean);
        let std = var.map(|v| (v + epsilon).sqrt());
        centered.div(&std)
    }

    /// Graph Normalization
    pub fn graph_norm(tensor: &DenseTensor, epsilon: f64) -> DenseTensor {
        // Normalize the feature values of the whole graph
        let mean = tensor.mean(None).data()[0];
        let centered = tensor.add_scalar(-mean);
        // std = sqrt(mean of squared deviations)
        let std =
            (centered.data().iter().map(|&x| x * x).sum::<f64>() / tensor.numel() as f64).sqrt();
        centered.mul_scalar(1.0 / (std + epsilon))
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_layer_norm_1d() {
            let t = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5]);
            let result = layer_norm(&t, 1e-5);
            let mean = result.mean(None).data()[0];
            assert!(mean.abs() < 1e-5);
        }
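
        // A minimal sketch of extra checks for the functions above, assuming
        // DenseTensor::sub and DenseTensor::div are elementwise on same-shape
        // tensors (as batch_norm itself assumes).
        #[test]
        fn test_layer_norm_unit_variance() {
            let t = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5]);
            let result = layer_norm(&t, 0.0);
            let var = result.data().iter().map(|&x| x * x).sum::<f64>() / 5.0;
            assert!((var - 1.0).abs() < 1e-6);
        }

        #[test]
        fn test_batch_norm() {
            // With mean [1, 1], var [1, 1] and a tiny epsilon,
            // batch_norm([0, 2]) should be close to [-1, 1].
            let t = DenseTensor::new(vec![0.0, 2.0], vec![2]);
            let mean = DenseTensor::new(vec![1.0, 1.0], vec![2]);
            let var = DenseTensor::new(vec![1.0, 1.0], vec![2]);
            let result = batch_norm(&t, &mean, &var, 1e-12);
            assert!((result.data()[0] - (-1.0)).abs() < 1e-6);
            assert!((result.data()[1] - 1.0).abs() < 1e-6);
        }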
    }
}

/// Matrix operations
pub mod matrix {
    use super::*;

    /// Matrix transpose (2-D only)
    pub fn transpose(tensor: &DenseTensor) -> DenseTensor {
        tensor.transpose(None)
    }

    /// Matrix inverse (Gauss-Jordan elimination; suitable only for small matrices)
    pub fn inverse(tensor: &DenseTensor) -> Result<DenseTensor, TensorError> {
        if tensor.ndim() != 2 || tensor.shape()[0] != tensor.shape()[1] {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: tensor.ndim(),
            });
        }

        let n = tensor.shape()[0];
        let mut augmented = vec![0.0; n * n * 2];

        // Build the augmented matrix [A|I]
        for i in 0..n {
            for j in 0..n {
                augmented[i * (n * 2) + j] = tensor.data()[i * n + j];
                if i == j {
                    augmented[i * (n * 2) + n + j] = 1.0;
                }
            }
        }

        // Gauss-Jordan elimination
        for col in 0..n {
            // Partial pivoting: pick the row with the largest absolute value
            let mut max_row = col;
            for row in col + 1..n {
                if augmented[row * (n * 2) + col].abs() > augmented[max_row * (n * 2) + col].abs() {
                    max_row = row;
                }
            }

            // Swap rows
            if max_row != col {
                for j in 0..n * 2 {
                    augmented.swap(col * (n * 2) + j, max_row * (n * 2) + j);
                }
            }

            let pivot = augmented[col * (n * 2) + col];
            if pivot.abs() < 1e-10 {
                return Err(TensorError::MatrixError {
                    message: "Matrix is singular".to_string(),
                });
            }

            // Normalize the pivot row
            for j in 0..n * 2 {
                augmented[col * (n * 2) + j] /= pivot;
            }

            // Eliminate the column from all other rows
            for row in 0..n {
                if row != col {
                    let factor = augmented[row * (n * 2) + col];
                    for j in 0..n * 2 {
                        augmented[row * (n * 2) + j] -= factor * augmented[col * (n * 2) + j];
                    }
                }
            }
        }

        // Extract the inverse from the right half
        let mut inv_data = vec![0.0; n * n];
        for i in 0..n {
            for j in 0..n {
                inv_data[i * n + j] = augmented[i * (n * 2) + n + j];
            }
        }

        Ok(DenseTensor::new(inv_data, vec![n, n]))
    }

    /// Matrix determinant (recursive cofactor expansion)
    pub fn determinant(tensor: &DenseTensor) -> Result<f64, TensorError> {
        if tensor.ndim() != 2 || tensor.shape()[0] != tensor.shape()[1] {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: tensor.ndim(),
            });
        }

        let n = tensor.shape()[0];
        if n == 1 {
            return Ok(tensor.data()[0]);
        }
        if n == 2 {
            return Ok(tensor.data()[0] * tensor.data()[3] - tensor.data()[1] * tensor.data()[2]);
        }

        // Cofactor expansion along the first row (O(n!); small matrices only)
        let mut det = 0.0;
        for j in 0..n {
            let minor = get_minor(tensor, 0, j);
            let cofactor = if j % 2 == 0 { 1.0 } else { -1.0 };
            det += cofactor * tensor.data()[j] * determinant(&minor)?;
        }

        Ok(det)
    }

    /// Minor matrix: the input with the given row and column removed
    fn get_minor(tensor: &DenseTensor, row: usize, col: usize) -> DenseTensor {
        let n = tensor.shape()[0];
        let mut minor_data = Vec::with_capacity((n - 1) * (n - 1));

        for i in 0..n {
            if i == row {
                continue;
            }
            for j in 0..n {
                if j == col {
                    continue;
                }
                let src_idx = i * n + j;
                minor_data.push(tensor.data()[src_idx]);
            }
        }

        DenseTensor::new(minor_data, vec![n - 1, n - 1])
    }

    /// Dominant eigenvalue and eigenvector via power iteration
    /// (intended for symmetric matrices)
    pub fn power_iteration(tensor: &DenseTensor, max_iter: usize, tol: f64) -> (f64, DenseTensor) {
        let n = tensor.shape()[0];
        let mut v = DenseTensor::ones(vec![n]).normalize();

        let mut eigenvalue = 0.0;
        for _ in 0..max_iter {
            // Av
            let av = tensor.matmul(&v);
            // Rayleigh quotient v·Av (v is unit-norm, so no denominator is needed)
            let new_eigenvalue = v
                .data()
                .iter()
                .zip(av.data().iter())
                .map(|(&a, &b)| a * b)
                .sum::<f64>();

            // Re-normalize for the next iteration
            v = av.normalize();

            // Keep the latest estimate, then stop once it has converged
            let converged = (new_eigenvalue - eigenvalue).abs() < tol;
            eigenvalue = new_eigenvalue;
            if converged {
                break;
            }
        }

        (eigenvalue, v)
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_transpose() {
            let t = DenseTensor::matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
            let result = transpose(&t);
            assert_eq!(result.shape(), &[3, 2]);
            assert_eq!(result.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        }

        #[test]
        fn test_inverse_2x2() {
            // The inverse of [1, 2; 3, 4] is [-2, 1; 1.5, -0.5]
            let t = DenseTensor::matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
            let inv = inverse(&t).unwrap();
            assert!((inv.data()[0] - (-2.0)).abs() < 1e-6);
            assert!((inv.data()[1] - 1.0).abs() < 1e-6);
            assert!((inv.data()[2] - 1.5).abs() < 1e-6);
            assert!((inv.data()[3] - (-0.5)).abs() < 1e-6);
        }

        #[test]
        fn test_determinant() {
            let t = DenseTensor::matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
            let det = determinant(&t).unwrap();
            assert!((det - (-2.0)).abs() < 1e-6);
        }
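
        // Extra checks, as minimal sketches: the 3x3 determinant exercises the
        // cofactor-expansion branch, and the power-iteration test uses a
        // diagonal matrix whose dominant eigenvalue (3) is known exactly. Both
        // assume DenseTensor::matrix(3, 3, ...) behaves like the 2x2/2x3 calls
        // above.
        #[test]
        fn test_determinant_3x3() {
            // det([[1,2,3],[4,5,6],[7,8,10]]) = 1*(50-48) - 2*(40-42) + 3*(32-35) = -3
            let t = DenseTensor::matrix(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0]);
            let det = determinant(&t).unwrap();
            assert!((det - (-3.0)).abs() < 1e-6);
        }

        #[test]
        fn test_power_iteration_diagonal() {
            // diag(3, 1) has dominant eigenvalue 3 with eigenvector [1, 0]
            let t = DenseTensor::matrix(2, 2, vec![3.0, 0.0, 0.0, 1.0]);
            let (eigenvalue, v) = power_iteration(&t, 1000, 1e-12);
            assert!((eigenvalue - 3.0).abs() < 1e-6);
            assert!((v.data()[0].abs() - 1.0).abs() < 1e-3);
        }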
    }
}

/// Random operations
#[cfg(feature = "rand")]
pub mod random {
    use super::*;
    use rand::Rng;

    /// Randomly initialized tensor (Xavier/Glorot uniform initialization)
    pub fn xavier_init(rows: usize, cols: usize) -> DenseTensor {
        let limit = (6.0 / (rows + cols) as f64).sqrt();
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..rows * cols)
            .map(|_| rng.gen_range(-limit..limit))
            .collect();
        DenseTensor::new(data, vec![rows, cols])
    }

    /// Randomly initialized tensor (He initialization)
    pub fn he_init(rows: usize, cols: usize) -> DenseTensor {
        let std = (2.0 / rows as f64).sqrt();
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..rows * cols)
            .map(|_| {
                // Box-Muller transform for normally distributed samples;
                // the lower bound on u1 keeps ln(u1) finite
                let u1: f64 = rng.gen_range(f64::MIN_POSITIVE..1.0);
                let u2: f64 = rng.gen_range(0.0..1.0);
                std * (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
            })
            .collect();
        DenseTensor::new(data, vec![rows, cols])
    }

    /// Inverted dropout (training-time): zeroes elements with probability p
    /// and scales survivors by 1/(1-p)
    pub fn dropout(tensor: &DenseTensor, p: f64) -> DenseTensor {
        assert!((0.0..1.0).contains(&p), "dropout probability must be in [0, 1)");
        let mut rng = rand::thread_rng();
        let scale = 1.0 / (1.0 - p);
        let data: Vec<f64> = tensor
            .data()
            .iter()
            .map(|&x| if rng.gen::<f64>() < p { 0.0 } else { x * scale })
            .collect();
        DenseTensor::new(data, tensor.shape().to_vec())
    }
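
    #[cfg(test)]
    mod tests {
        use super::*;

        // Deterministic sanity checks only, sketched on the assumption that
        // this `rand`-gated module compiles in test builds: with p = 0,
        // dropout must be the identity, and Xavier samples must stay inside
        // the ±sqrt(6/(fan_in+fan_out)) bound used above.
        #[test]
        fn test_dropout_p_zero_is_identity() {
            let t = DenseTensor::new(vec![1.0, -2.0, 3.0], vec![3]);
            let result = dropout(&t, 0.0);
            assert_eq!(result.data(), &[1.0, -2.0, 3.0]);
        }

        #[test]
        fn test_xavier_init_bounds() {
            let t = xavier_init(3, 5);
            let limit = (6.0 / 8.0_f64).sqrt();
            assert!(t.data().iter().all(|&x| x.abs() <= limit));
        }
    }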
}