//! Torch utilities.
use std::borrow::Borrow;
use tch::{TchError, Tensor};

/// Flatten a set of tensors into a single vector.
///
/// # Errors
/// Returns any error raised by `tch` or the Torch C++ API.
/// These are generally related to incorrect `Tensor` shapes or types.
pub fn f_flatten_tensors<I>(tensors: I) -> Result<Tensor, TchError>
where
    I: IntoIterator,
    I::Item: Borrow<Tensor>,
{
    Tensor::f_cat(
        &tensors
            .into_iter()
            .map(|t| t.borrow().f_flatten(0, -1))
            .collect::<Result<Vec<_>, _>>()?,
        0,
    )
}

/// Flatten a set of tensors into a single vector.
///
/// # Warning
/// Undefined tensors are silently flattened to nothing.
///
/// # Panics
/// If [`f_flatten_tensors`] fails.
pub fn flatten_tensors<I>(tensors: I) -> Tensor
where
    I: IntoIterator,
    I::Item: Borrow<Tensor>,
{
    f_flatten_tensors(tensors).unwrap()
}

/// The number of elements in a Tensor with the given shape.
///
/// # Panics
/// If any dimension has negative size.
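///
/// For example, a shape of `[2, 3, 4]` describes `2 * 3 * 4 = 24` elements.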
fn shape_size(shape: &[i64]) -> i64 {
    assert!(
        shape.iter().all(|&d| d >= 0),
        "Negative dimension in shape {:?}",
        shape
    );
    shape.iter().product()
}

/// Unflatten a vector into a set of tensors with the given shapes.
///
/// # Errors
/// Returns any error raised by `tch` or the Torch C++ API.
/// These are generally related to incorrect `Tensor` shapes or types.
///
/// # Panics
/// Panics if any shape has a dimension with negative size.
pub fn f_unflatten_tensors(vector: &Tensor, shapes: &[Vec<i64>]) -> Result<Vec<Tensor>, TchError> {
    let sizes: Vec<_> = shapes.iter().map(|shape| shape_size(shape)).collect();
    vector
        .f_split_with_sizes(&sizes, 0)?
        .iter()
        .zip(shapes)
        .map(|(t, shape)| t.f_reshape(shape))
        .collect()
}

/// Unflatten a vector into a set of tensors with the given shapes.
///
/// # Panics
/// If [`f_unflatten_tensors`] fails.
#[must_use]
pub fn unflatten_tensors(vector: &Tensor, shapes: &[Vec<i64>]) -> Vec<Tensor> {
    f_unflatten_tensors(vector, shapes).unwrap()
}

/// Dot product of two flattened tensors.
///
/// # Errors
/// Returns any error raised by `tch` or the Torch C++ API.
/// These are generally related to incorrect `Tensor` shapes or types.
pub fn f_flat_dot(a: &Tensor, b: &Tensor) -> Result<Tensor, TchError> {
    a.f_flatten(0, -1)?.f_dot(&b.f_flatten(0, -1)?)
}

/// Dot product of two flattened tensors.
///
/// Equivalently, the sum of the elementwise product of two tensors.
/// The shapes may differ so long as the total number of elements is the same.
///
/// # Panics
/// If [`f_flat_dot`] fails.
pub fn flat_dot(a: &Tensor, b: &Tensor) -> Tensor {
    f_flat_dot(a, b).unwrap()
}

#[cfg(test)]
mod flatten {
    use super::*;

    #[test]
    fn test_flatten_tensors() {
        let a = Tensor::of_slice(&[1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let b = Tensor::of_slice(&[10, 11, 12, 13]).reshape(&[4, 1, 1]);

        let v = flatten_tensors(&[a, b]);
        assert_eq!(v, Tensor::of_slice(&[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]));
    }
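
    // A minimal error-path sketch for `f_flatten_tensors`, assuming LibTorch's
    // `cat` rejects an empty list of tensors; flattening no tensors should then
    // surface that error rather than panic.
    #[test]
    fn test_f_flatten_tensors_empty_is_err() {
        let tensors: Vec<Tensor> = Vec::new();
        assert!(f_flatten_tensors(tensors).is_err());
    }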

    #[test]
    fn test_unflatten_tensors() {
        let v = Tensor::of_slice(&[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]);
        let shapes = [vec![2, 3], vec![4, 1, 1]];
        let ts = unflatten_tensors(&v, &shapes);

        let a = Tensor::of_slice(&[1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let b = Tensor::of_slice(&[10, 11, 12, 13]).reshape(&[4, 1, 1]);
        assert_eq!(ts, vec![a, b]);
    }
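
    // A round-trip sketch: flattening and then unflattening with the original
    // shapes should reproduce the input tensors.
    #[test]
    fn test_flatten_unflatten_roundtrip() {
        let a = Tensor::of_slice(&[1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let b = Tensor::of_slice(&[10, 11, 12, 13]).reshape(&[4, 1, 1]);
        let shapes = [vec![2, 3], vec![4, 1, 1]];

        let v = flatten_tensors(&[a.shallow_clone(), b.shallow_clone()]);
        let ts = unflatten_tensors(&v, &shapes);
        assert_eq!(ts, vec![a, b]);
    }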
}

#[cfg(test)]
mod flat_dot {
    use super::*;
    use tch::{Device, Kind};

    #[test]
    fn test_flat_dot() {
        let a = Tensor::of_slice(&[1, 2, 3, 4]).reshape(&[2, 2]);
        let b = Tensor::of_slice(&[10, 9, 8, 7]).reshape(&[2, 2]);
        let expected = Tensor::scalar_tensor(10 + 9 * 2 + 8 * 3 + 4 * 7, (Kind::Int, Device::Cpu));
        assert_eq!(flat_dot(&a, &b), expected);
    }
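
    // A sketch of the "shapes may differ" case documented on `flat_dot`: a
    // 1-D tensor against a 2x2 tensor with the same number of elements.
    #[test]
    fn test_flat_dot_different_shapes() {
        let a = Tensor::of_slice(&[1, 2, 3, 4]);
        let b = Tensor::of_slice(&[10, 9, 8, 7]).reshape(&[2, 2]);
        let expected = Tensor::scalar_tensor(10 + 2 * 9 + 3 * 8 + 4 * 7, (Kind::Int, Device::Cpu));
        assert_eq!(flat_dot(&a, &b), expected);
    }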
}