use std::borrow::Borrow;
use tch::{TchError, Tensor};
/// Flattens each tensor in `tensors` to 1-D and concatenates the results along dimension 0,
/// preserving iteration order.
///
/// # Errors
///
/// Stops at the first tensor whose flatten fails and returns that error. Otherwise returns
/// whatever error libtorch reports for the concatenation; an empty `tensors` is presumably
/// one such case (TODO confirm against the libtorch version in use).
pub fn f_flatten_tensors<I>(tensors: I) -> Result<Tensor, TchError>
where
    I: IntoIterator,
    I::Item: Borrow<Tensor>,
{
    // Bring every input to 1-D first so tensors of differing ranks can be joined end to end.
    let mut flattened = Vec::new();
    for item in tensors {
        flattened.push(item.borrow().f_flatten(0, -1)?);
    }
    Tensor::f_cat(&flattened, 0)
}
/// Flattens each tensor in `tensors` to 1-D and concatenates them in iteration order.
///
/// Panicking counterpart of [`f_flatten_tensors`].
///
/// # Panics
///
/// Panics if [`f_flatten_tensors`] returns an error, i.e. if libtorch rejects a flatten or
/// the final concatenation.
// `#[must_use]` to match `unflatten_tensors`: this is a pure computation, so ignoring the
// returned tensor is always a mistake.
#[must_use]
pub fn flatten_tensors<I>(tensors: I) -> Tensor
where
    I: IntoIterator,
    I::Item: Borrow<Tensor>,
{
    f_flatten_tensors(tensors).unwrap()
}
/// Returns the number of elements in a tensor of the given `shape`.
///
/// An empty shape (a scalar) yields 1, and any zero dimension yields 0.
///
/// # Panics
///
/// Panics if any dimension of `shape` is negative.
fn shape_size(shape: &[i64]) -> i64 {
    let has_negative = shape.iter().any(|&dim| dim < 0);
    assert!(!has_negative, "Negative dimension in shape {:?}", shape);

    let mut count = 1;
    for &dim in shape {
        count *= dim;
    }
    count
}
pub fn f_unflatten_tensors(vector: &Tensor, shapes: &[Vec<i64>]) -> Result<Vec<Tensor>, TchError> {
let sizes: Vec<_> = shapes.iter().map(|shape| shape_size(shape)).collect();
vector
.f_split_with_sizes(&sizes, 0)?
.iter()
.zip(shapes)
.map(|(t, shape)| t.f_reshape(shape))
.collect()
}
/// Splits `vector` into tensors shaped like `shapes`.
///
/// Panicking counterpart of [`f_unflatten_tensors`].
///
/// # Panics
///
/// Panics on a negative dimension in `shapes`, or if libtorch cannot split or reshape `vector`
/// as requested.
#[must_use]
pub fn unflatten_tensors(vector: &Tensor, shapes: &[Vec<i64>]) -> Vec<Tensor> {
    let pieces = f_unflatten_tensors(vector, shapes);
    pieces.unwrap()
}
/// Computes the dot product of `a` and `b` after flattening both to 1-D, i.e. the sum of their
/// element-wise products.
///
/// # Errors
///
/// Propagates any libtorch error from flattening or from `dot`. Presumably this includes
/// operands with different element counts or dtypes (TODO confirm).
pub fn f_flat_dot(a: &Tensor, b: &Tensor) -> Result<Tensor, TchError> {
    let lhs = a.f_flatten(0, -1)?;
    let rhs = b.f_flatten(0, -1)?;
    lhs.f_dot(&rhs)
}
/// Computes the dot product of `a` and `b` after flattening both to 1-D.
///
/// Panicking counterpart of [`f_flat_dot`].
///
/// # Panics
///
/// Panics if [`f_flat_dot`] returns an error.
// `#[must_use]` to match `unflatten_tensors`: the result is the only output of this pure call.
#[must_use]
pub fn flat_dot(a: &Tensor, b: &Tensor) -> Tensor {
    f_flat_dot(a, b).unwrap()
}
#[cfg(test)]
mod flatten {
    use super::*;

    /// Tensors of different ranks are flattened and joined in input order.
    #[test]
    fn test_flatten_tensors() {
        let matrix = Tensor::of_slice(&[1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let column = Tensor::of_slice(&[10, 11, 12, 13]).reshape(&[4, 1, 1]);
        let expected = Tensor::of_slice(&[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]);
        assert_eq!(flatten_tensors(&[matrix, column]), expected);
    }

    /// A flat vector is split and reshaped back into the requested shapes, in order.
    #[test]
    fn test_unflatten_tensors() {
        let flat = Tensor::of_slice(&[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]);
        let shapes = [vec![2, 3], vec![4, 1, 1]];
        let expected = vec![
            Tensor::of_slice(&[1, 2, 3, 4, 5, 6]).reshape(&[2, 3]),
            Tensor::of_slice(&[10, 11, 12, 13]).reshape(&[4, 1, 1]),
        ];
        assert_eq!(unflatten_tensors(&flat, &shapes), expected);
    }
}
#[cfg(test)]
mod flat_dot {
    use super::*;
    use tch::{Device, Kind};

    /// Dotting two flattened 2x2 matrices gives the sum of their element-wise products.
    #[test]
    fn test_flat_dot() {
        let lhs = Tensor::of_slice(&[1, 2, 3, 4]).reshape(&[2, 2]);
        let rhs = Tensor::of_slice(&[10, 9, 8, 7]).reshape(&[2, 2]);
        // 1*10 + 2*9 + 3*8 + 4*7 = 80
        let want = Tensor::scalar_tensor(10 + 2 * 9 + 3 * 8 + 4 * 7, (Kind::Int, Device::Cpu));
        let got = flat_dot(&lhs, &rhs);
        assert_eq!(got, want);
    }
}