// rssn/numerical/tensor.rs

1//! # Numerical Tensor Operations
2//!
3//! This module provides numerical tensor operations, primarily using `ndarray`
4//! for efficient multi-dimensional array manipulation. It includes functions
5//! for tensor contraction (tensordot), outer product, and Einstein summation (`einsum`).
6
7use ndarray::ArrayD;
8use ndarray::IxDyn;
9
10/// Performs tensor contraction between two N-dimensional arrays (tensordot).
11///
12/// # Arguments
13/// * `a`, `b` - The two tensors (`ndarray::ArrayD<f64>`) to contract.
14/// * `axes_a`, `axes_b` - The axes to contract for tensor `a` and `b` respectively.
15///
16/// # Returns
17/// The resulting contracted tensor as an `ndarray::ArrayD<f64>`.
18///
19/// # Errors
20/// Returns an error if the number of axes to contract mismatch, or if the dimensions along these axes do not match.
21
22pub fn tensordot(
23    a: &ArrayD<f64>,
24    b: &ArrayD<f64>,
25    axes_a: &[usize],
26    axes_b: &[usize],
27) -> Result<ArrayD<f64>, String> {
28
29    if axes_a.len() != axes_b.len() {
30
31        return Err("Contracted axes \
32                    must have the \
33                    same length."
34            .to_string());
35    }
36
37    for (&ax_a, &ax_b) in axes_a
38        .iter()
39        .zip(axes_b.iter())
40    {
41
42        if a.shape()[ax_a]
43            != b.shape()[ax_b]
44        {
45
46            return Err(format!(
47                "Dimension mismatch \
48                 on contracted axes: \
49                 {} != {}",
50                a.shape()[ax_a],
51                b.shape()[ax_b]
52            ));
53        }
54    }
55
56    let free_axes_a: Vec<_> = (0 .. a
57        .ndim())
58        .filter(|i| !axes_a.contains(i))
59        .collect();
60
61    let free_axes_b: Vec<_> = (0 .. b
62        .ndim())
63        .filter(|i| !axes_b.contains(i))
64        .collect();
65
66    let perm_a: Vec<_> = free_axes_a
67        .iter()
68        .chain(axes_a.iter())
69        .copied()
70        .collect();
71
72    let perm_b: Vec<_> = axes_b
73        .iter()
74        .chain(free_axes_b.iter())
75        .copied()
76        .collect();
77
78    let a_perm = a
79        .clone()
80        .permuted_axes(perm_a);
81
82    let b_perm = b
83        .clone()
84        .permuted_axes(perm_b);
85
86    let free_dim_a = free_axes_a
87        .iter()
88        .map(|&i| a.shape()[i])
89        .product::<usize>();
90
91    let free_dim_b = free_axes_b
92        .iter()
93        .map(|&i| b.shape()[i])
94        .product::<usize>();
95
96    let contracted_dim = axes_a
97        .iter()
98        .map(|&i| a.shape()[i])
99        .product::<usize>();
100
101    let a_mat = a_perm
102        .to_shape((
103            free_dim_a,
104            contracted_dim,
105        ))
106        .map_err(|e| e.to_string())?
107        .to_owned();
108
109    let b_mat = b_perm
110        .to_shape((
111            contracted_dim,
112            free_dim_b,
113        ))
114        .map_err(|e| e.to_string())?
115        .to_owned();
116
117    let result_mat = a_mat.dot(&b_mat);
118
119    let mut final_shape_dims =
120        Vec::new();
121
122    final_shape_dims.extend(
123        free_axes_a
124            .iter()
125            .map(|&i| a.shape()[i]),
126    );
127
128    final_shape_dims.extend(
129        free_axes_b
130            .iter()
131            .map(|&i| b.shape()[i]),
132    );
133
134    Ok(result_mat
135        .to_shape(IxDyn(
136            &final_shape_dims,
137        ))
138        .map_err(|e| e.to_string())?
139        .to_owned())
140}
141
142/// Computes the outer product of two tensors.
143///
144/// The outer product of two tensors `A` (rank `r`) and `B` (rank `s`)
145/// results in a new tensor `C` of rank `r + s`. Each component of `C`
146/// is the product of a component from `A` and a component from `B`.
147///
148/// # Arguments
149/// * `a` - The first tensor (`ndarray::ArrayD<f64>`).
150/// * `b` - The second tensor (`ndarray::ArrayD<f64>`).
151///
152/// # Returns
153/// The resulting outer product tensor as an `ndarray::ArrayD<f64>`.
154///
155/// # Errors
156/// Returns an error if input tensors are not contiguous.
157
158pub fn outer_product(
159    a: &ArrayD<f64>,
160    b: &ArrayD<f64>,
161) -> Result<ArrayD<f64>, String> {
162
163    let mut new_shape =
164        a.shape().to_vec();
165
166    new_shape
167        .extend_from_slice(b.shape());
168
169    let a_flat = a
170        .as_slice()
171        .ok_or_else(|| {
172
173            "Input tensor 'a' is not \
174             contiguous"
175                .to_string()
176        })?;
177
178    let b_flat = b
179        .as_slice()
180        .ok_or_else(|| {
181
182            "Input tensor 'b' is not \
183             contiguous"
184                .to_string()
185        })?;
186
187    let mut result_data =
188        Vec::with_capacity(
189            a.len() * b.len(),
190        );
191
192    for val_a in a_flat {
193
194        for val_b in b_flat {
195
196            result_data
197                .push(val_a * val_b);
198        }
199    }
200
201    ArrayD::from_shape_vec(
202        IxDyn(&new_shape),
203        result_data,
204    )
205    .map_err(|e| e.to_string())
206}
207
208/// Performs tensor-vector multiplication.
209///
210/// # Errors
211/// Returns an error if the tensor has zero dimensions or if the last dimension mismatches the vector size.
212
213pub fn tensor_vec_mul(
214    tensor: &ArrayD<f64>,
215    vector: &[f64],
216) -> Result<ArrayD<f64>, String> {
217
218    if tensor.ndim() < 1 {
219
220        return Err("Tensor must \
221                    have at least \
222                    one dimension."
223            .to_string());
224    }
225
226    let last_dim = tensor.shape()
227        [tensor.ndim() - 1];
228
229    if last_dim != vector.len() {
230
231        return Err(format!(
232            "Dimension mismatch: last \
233             tensor dim {} != vector \
234             length {}",
235            last_dim,
236            vector.len()
237        ));
238    }
239
240    let vec_arr =
241        ndarray::Array1::from_vec(
242            vector.to_vec(),
243        );
244
245    let res = tensordot(
246        tensor,
247        &vec_arr.into_dyn(),
248        &[tensor.ndim() - 1],
249        &[0],
250    )?;
251
252    Ok(res)
253}
254
255/// Computes the inner product of two tensors of the same shape.
256///
257/// # Errors
258/// Returns an error if the shapes mismatch or if tensors are not contiguous.
259
260pub fn inner_product(
261    a: &ArrayD<f64>,
262    b: &ArrayD<f64>,
263) -> Result<f64, String> {
264
265    if a.shape() != b.shape() {
266
267        return Err("Tensors must \
268                    have the same \
269                    shape for inner \
270                    product."
271            .to_string());
272    }
273
274    let a_flat = a.as_slice().ok_or(
275        "Tensor 'a' is not contiguous",
276    )?;
277
278    let b_flat = b.as_slice().ok_or(
279        "Tensor 'b' is not contiguous",
280    )?;
281
282    Ok(a_flat
283        .iter()
284        .zip(b_flat.iter())
285        .map(|(x, y)| x * y)
286        .sum())
287}
288
289/// Contracts a single tensor along two specified axes.
290///
291/// # Errors
292/// Returns an error if axes are the same, dimensions mismatch, or if general rank contraction is not yet implemented.
293
294pub fn contract(
295    a: &ArrayD<f64>,
296    axis1: usize,
297    axis2: usize,
298) -> Result<ArrayD<f64>, String> {
299
300    if axis1 == axis2 {
301
302        return Err("Axes must be \
303                    different for \
304                    contraction."
305            .to_string());
306    }
307
308    if a.shape()[axis1]
309        != a.shape()[axis2]
310    {
311
312        return Err("Dimensions \
313                    along contraction \
314                    axes must be \
315                    equal."
316            .to_string());
317    }
318
319    let n = a.shape()[axis1];
320
321    #[warn(clippy::collection_is_never_read)]
322    let mut new_shape = Vec::new();
323
324    for i in 0 .. a.ndim() {
325
326        if i != axis1 && i != axis2 {
327
328            new_shape
329                .push(a.shape()[i]);
330        }
331    }
332
333    // if new_shape.is_empty() {
334    //     let mut sum = 0.0;
335    //     for i in 0..n {
336    //         // This is actually a bit complex to index generically without recursion or specific tools
337    //         // For now, simpler implementation for trace-like contraction
338    //     }
339    // }
340
341    // Fallback: use tensordot with identity-like structure if needed, or implement manually
342    // For now, let's keep it simple or use a placeholder if it's too complex for a quick edit.
343    // Actually, sprs or ndarray might have better support.
344
345    // Simplified: Only support rank 2 (trace) for now if we want to be safe, or implement full.
346    if a.ndim() == 2 {
347
348        let mut sum = 0.0;
349
350        for i in 0 .. n {
351
352            sum += a[[i, i]];
353        }
354
355        return Ok(
356            ndarray::Array0::from_elem(
357                (),
358                sum,
359            )
360            .into_dyn(),
361        );
362    }
363
364    Err(
365        "General tensor contraction \
366         (trace) for rank > 2 not yet \
367         implemented."
368            .to_string(),
369    )
370}
371
372/// Computes the Frobenius norm of a tensor.
373#[must_use]
374
375pub fn norm(a: &ArrayD<f64>) -> f64 {
376
377    a.iter()
378        .map(|x| x * x)
379        .sum::<f64>()
380        .sqrt()
381}
382
383use serde::Deserialize;
384use serde::Serialize;
385
386/// A serializable representation of an N-dimensional tensor.
387#[derive(
388    Serialize, Deserialize, Debug, Clone,
389)]
390
391pub struct TensorData {
392    /// Dimensions of the tensor.
393    pub shape: Vec<usize>,
394    /// Flat vector of tensor data.
395    pub data: Vec<f64>,
396}
397
398impl From<&ArrayD<f64>> for TensorData {
399    fn from(arr: &ArrayD<f64>) -> Self {
400
401        Self {
402            shape : arr.shape().to_vec(),
403            data : arr
404                .clone()
405                .into_raw_vec_and_offset()
406                .0,
407        }
408    }
409}
410
411impl TensorData {
412    /// Converts back to an `ndarray::ArrayD`.
413    ///
414    /// # Errors
415    /// Returns an error if the shape and data are inconsistent.
416
417    pub fn to_arrayd(
418        &self
419    ) -> Result<ArrayD<f64>, String>
420    {
421
422        ArrayD::from_shape_vec(
423            IxDyn(&self.shape),
424            self.data.clone(),
425        )
426        .map_err(|e| e.to_string())
427    }
428}
429
#[cfg(test)]
mod tests {

    use ndarray::array;

    use super::*;

    #[test]
    fn test_tensordot() {
        // Contracting the last axis of `lhs` with the first axis of
        // `rhs` is ordinary matrix multiplication.
        let lhs = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
        let rhs = array![[5.0, 6.0], [7.0, 8.0]].into_dyn();

        let product = tensordot(&lhs, &rhs, &[1], &[0]).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product[[0, 0]], 1.0 * 5.0 + 2.0 * 7.0);
    }

    #[test]
    fn test_outer_product() {
        let u = array![1.0, 2.0].into_dyn();
        let v = array![3.0, 4.0].into_dyn();

        let outer = outer_product(&u, &v).unwrap();

        assert_eq!(outer.shape(), &[2, 2]);
        assert_eq!(outer[[0, 0]], 3.0);
        assert_eq!(outer[[1, 1]], 8.0);
    }
}