use std::collections::HashMap;
use std::convert::TryInto;
use std::ops::{Deref, DerefMut};
use ndarray::{s, Array1, Array2, ArrayView1};
use tch::Tensor;
/// Label matrices, stored per encoder name.
///
/// This is a thin wrapper around a `HashMap` that maps a name to a
/// two-dimensional `i64` array.
pub struct LabelTensor {
/// The wrapped map. `from_shape` fills it with one zero-initialized
/// matrix of shape `(batch_size, time_steps)` per encoder name.
inner: HashMap<String, Array2<i64>>,
}
impl LabelTensor {
fn from_shape(
encoder_names: impl IntoIterator<Item = impl Into<String>>,
batch_size: usize,
time_steps: usize,
) -> Self {
let labels = encoder_names
.into_iter()
.map(Into::into)
.map(|encoder_name| (encoder_name, Array2::zeros((batch_size, time_steps))))
.collect();
LabelTensor { inner: labels }
}
}
/// Exposes the wrapped `HashMap` for read-only use through dereferencing,
/// so that the map's `&self` methods can be called on a `LabelTensor`.
impl Deref for LabelTensor {
type Target = HashMap<String, Array2<i64>>;
/// Returns a shared reference to the wrapped map.
fn deref(&self) -> &Self::Target {
&self.inner
}
}
/// Exposes the wrapped `HashMap` for mutation through dereferencing,
/// so that the map's `&mut self` methods can be called on a `LabelTensor`.
impl DerefMut for LabelTensor {
/// Returns a mutable reference to the wrapped map.
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
/// Builder that collects instances row by row into fixed-size matrices.
///
/// The matrices are allocated up front by the constructors with shape
/// `(batch_size, max_seq_len)` and are zero-initialized. Each added
/// instance overwrites a prefix of the next free row; the remainder of
/// the row keeps its zeros.
pub struct TensorBuilder {
/// Index of the row that the next added instance is written to.
current_sequence: usize,
/// Input matrix; one row per instance.
inputs: Array2<i64>,
/// Per-encoder label matrices; `None` when the builder was constructed
/// with `new_without_labels`.
labels: Option<LabelTensor>,
/// Token mask matrix; one row per instance.
token_mask: Array2<i32>,
/// For each row, the length of the input that was stored in it.
seq_lens: Array1<i32>,
}
impl TensorBuilder {
    /// Create a builder that does not store labels.
    ///
    /// The input and token mask matrices get the shape
    /// `(batch_size, max_seq_len)`, the sequence lengths vector gets
    /// `batch_size` elements. Everything is zero-initialized and the
    /// first instance will be written to row 0.
    pub fn new_without_labels(batch_size: usize, max_seq_len: usize) -> Self {
        let matrix_shape = (batch_size, max_seq_len);

        Self {
            current_sequence: 0,
            inputs: Array2::zeros(matrix_shape),
            token_mask: Array2::zeros(matrix_shape),
            labels: None,
            seq_lens: Array1::zeros((batch_size,)),
        }
    }

    /// Create a builder that stores labels for the given encoders.
    ///
    /// This allocates the same zero-initialized matrices as
    /// `new_without_labels` and, in addition, one zero-initialized label
    /// matrix of shape `(batch_size, max_seq_len)` for every name in
    /// `encoder_names`.
    pub fn new_with_labels(
        batch_size: usize,
        max_seq_len: usize,
        encoder_names: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let label_tensor = LabelTensor::from_shape(encoder_names, batch_size, max_seq_len);

        // All other fields are identical to those of a label-less builder.
        Self {
            labels: Some(label_tensor),
            ..Self::new_without_labels(batch_size, max_seq_len)
        }
    }
}
impl TensorBuilder {
pub fn add_without_labels(&mut self, input: ArrayView1<i64>, token_mask: ArrayView1<i32>) {
assert!(
self.current_sequence < self.inputs.shape()[0],
"TensorBuilder is already filled."
);
#[allow(clippy::deref_addrof)]
self.inputs
.row_mut(self.current_sequence)
.slice_mut(s![0..input.len()])
.assign(&input);
self.token_mask
.row_mut(self.current_sequence)
.slice_mut(s![0..token_mask.len()])
.assign(&token_mask);
self.seq_lens[self.current_sequence] = input.len() as i32;
self.current_sequence += 1
}
pub fn add_with_labels(
&mut self,
input: ArrayView1<i64>,
labels: HashMap<&str, Array1<i64>>,
token_mask: ArrayView1<i32>,
) {
assert!(
self.current_sequence < self.inputs.shape()[0],
"TensorBuilder is already filled."
);
assert_eq!(
self.labels.as_ref().unwrap().len(),
labels.len(),
"Expected labels for {} encoders, got labels for {}",
self.labels.as_ref().unwrap().len(),
labels.len(),
);
for (encoder_name, labels) in labels {
assert_eq!(
labels.len(),
token_mask.len(),
"Input for encoder {} has length {}, but the mask length is {}",
encoder_name,
labels.len(),
token_mask.len()
);
#[allow(clippy::deref_addrof)]
self.labels
.as_mut()
.unwrap()
.get_mut(encoder_name)
.unwrap_or_else(|| panic!("Undefined encoder: {}", encoder_name))
.row_mut(self.current_sequence)
.slice_mut(s![0..input.len()])
.assign(&labels)
}
self.add_without_labels(input, token_mask);
}
}
/// Tensors that are obtained by converting a `TensorBuilder`.
///
/// Each field is the `Tensor` counterpart of the array with the same name
/// in the builder.
#[derive(Debug)]
pub struct Tensors {
/// Converted input matrix.
pub inputs: Tensor,
/// Converted label matrices per encoder name; `None` when the builder
/// did not hold labels.
pub labels: Option<HashMap<String, Tensor>>,
/// Converted token mask matrix.
pub token_mask: Tensor,
/// Converted sequence lengths.
pub seq_lens: Tensor,
}
impl From<TensorBuilder> for Tensors {
fn from(builder: TensorBuilder) -> Self {
let labels = builder.labels.map(|labels| {
labels
.inner
.into_iter()
.map(|(encoder_name, matrix)| (encoder_name, matrix.try_into().unwrap()))
.collect()
});
Tensors {
inputs: builder.inputs.try_into().unwrap(),
labels,
token_mask: builder.token_mask.try_into().unwrap(),
seq_lens: builder.seq_lens.try_into().unwrap(),
}
}
}
#[cfg(test)]
mod tests {
    use ndarray::arr1;
    use tch::Tensor;

    use super::{TensorBuilder, Tensors};

    /// Instances without labels end up in consecutive rows, with zeros
    /// after the end of shorter instances.
    #[test]
    fn instances_are_added() {
        let mut builder: TensorBuilder = TensorBuilder::new_without_labels(2, 3);
        builder.add_without_labels(arr1(&[1, 2]).view(), arr1(&[1, 0]).view());
        builder.add_without_labels(arr1(&[3, 4, 5]).view(), arr1(&[1, 0, 1]).view());

        let tensors: Tensors = builder.into();

        // No labels.
        assert_eq!(tensors.labels, None);

        assert_eq!(
            tensors.token_mask,
            Tensor::of_slice(&[1, 0, 0, 1, 0, 1]).reshape(&[2, 3])
        );
        assert_eq!(tensors.seq_lens, Tensor::of_slice(&[2, 3]));
        assert_eq!(
            tensors.inputs,
            Tensor::of_slice(&[1, 2, 0, 3, 4, 5]).reshape(&[2, 3])
        );
    }

    /// Instances with labels additionally fill one label matrix per
    /// encoder.
    #[test]
    fn instances_are_added_with_labels() {
        let mut builder: TensorBuilder = TensorBuilder::new_with_labels(2, 3, vec!["a", "b"]);
        builder.add_with_labels(
            arr1(&[1, 2]).view(),
            vec![("a", arr1(&[11, 12])), ("b", arr1(&[21, 22]))]
                .into_iter()
                .collect(),
            arr1(&[1, 0]).view(),
        );
        builder.add_with_labels(
            arr1(&[3, 4, 5]).view(),
            vec![("a", arr1(&[13, 14, 15])), ("b", arr1(&[23, 24, 25]))]
                .into_iter()
                .collect(),
            arr1(&[1, 0, 1]).view(),
        );

        let tensors: Tensors = builder.into();

        assert_eq!(
            tensors.labels,
            Some(
                vec![
                    (
                        "a".to_string(),
                        Tensor::of_slice(&[11, 12, 0, 13, 14, 15]).reshape(&[2, 3])
                    ),
                    (
                        "b".to_string(),
                        Tensor::of_slice(&[21, 22, 0, 23, 24, 25]).reshape(&[2, 3])
                    )
                ]
                .into_iter()
                .collect()
            )
        );

        assert_eq!(
            tensors.token_mask,
            Tensor::of_slice(&[1, 0, 0, 1, 0, 1]).reshape(&[2, 3])
        );
        assert_eq!(tensors.seq_lens, Tensor::of_slice(&[2, 3]));
        assert_eq!(
            tensors.inputs,
            Tensor::of_slice(&[1, 2, 0, 3, 4, 5]).reshape(&[2, 3])
        );
    }

    // The panic tests below pin the panic message, so that they cannot
    // pass because of an unrelated panic.

    #[should_panic(expected = "Input for encoder a has length 1, but the mask length is 2")]
    #[test]
    fn panics_when_labels_and_mask_len_differ() {
        let mut builder: TensorBuilder = TensorBuilder::new_with_labels(2, 3, vec!["a", "b"]);
        builder.add_with_labels(
            arr1(&[1, 2]).view(),
            vec![("a", arr1(&[11])), ("b", arr1(&[21, 22]))]
                .into_iter()
                .collect(),
            arr1(&[1, 0]).view(),
        );
    }

    // The message is deliberately not pinned here: the labels agree with
    // the mask, but they cannot be stored in a row prefix that has the
    // length of the input.
    #[should_panic]
    #[test]
    fn panics_when_labels_and_input_len_differ() {
        let mut builder: TensorBuilder = TensorBuilder::new_with_labels(2, 3, vec!["a", "b"]);
        builder.add_with_labels(
            arr1(&[1, 2, 3]).view(),
            vec![("a", arr1(&[11, 12])), ("b", arr1(&[21, 22]))]
                .into_iter()
                .collect(),
            arr1(&[1, 0]).view(),
        );
    }

    #[should_panic(expected = "TensorBuilder is already filled.")]
    #[test]
    fn panics_when_too_many_instances_pushed() {
        let mut builder: TensorBuilder = TensorBuilder::new_without_labels(1, 3);
        builder.add_without_labels(arr1(&[1, 2]).view(), arr1(&[1, 0]).view());
        builder.add_without_labels(arr1(&[3, 4, 5]).view(), arr1(&[1, 0, 1]).view());
    }

    #[should_panic(expected = "Expected labels for 2 encoders, got labels for 1")]
    #[test]
    fn panics_when_labels_for_encoder_missing() {
        let mut builder: TensorBuilder = TensorBuilder::new_with_labels(2, 3, vec!["a", "b"]);
        builder.add_with_labels(
            arr1(&[1, 2]).view(),
            vec![("b", arr1(&[21, 22]))].into_iter().collect(),
            arr1(&[1, 0]).view(),
        );
    }
}