use crate::{kind, kind::Kind, Device, IndexOp, TchError, Tensor};
use std::collections::HashMap;
/// An iterator over a pair of tensors which have the same first dimension
/// size, yielding aligned `(xs, ys)` mini-batches.
#[derive(Debug)]
pub struct Iter2 {
// Input features; dimension 0 indexes samples.
xs: Tensor,
// Targets, paired element-wise with `xs` along dimension 0.
ys: Tensor,
// Index of the next batch to yield (in units of batches, not samples).
batch_index: i64,
// Number of samples per yielded batch.
batch_size: i64,
// Number of samples, i.e. `xs.size()[0]` (== `ys.size()[0]`).
total_size: i64,
// Device that yielded batches are moved to; defaults to CPU.
device: Device,
// When true, a final batch smaller than `batch_size` is also yielded.
return_smaller_last_batch: bool,
}
impl Iter2 {
    /// Builds an iterator over mini-batches drawn from `xs` and `ys`.
    ///
    /// Both tensors are shallow-cloned, so the iterator shares their storage.
    ///
    /// # Errors
    /// Returns a shape error when `xs` and `ys` disagree on the size of
    /// dimension 0.
    pub fn f_new(xs: &Tensor, ys: &Tensor, batch_size: i64) -> Result<Iter2, TchError> {
        let total_size = xs.size()[0];
        if total_size != ys.size()[0] {
            return Err(TchError::Shape(format!(
                "different dimension for the two inputs {:?} {:?}",
                xs, ys
            )));
        }
        Ok(Iter2 {
            xs: xs.shallow_clone(),
            ys: ys.shallow_clone(),
            batch_index: 0,
            batch_size,
            total_size,
            device: Device::Cpu,
            return_smaller_last_batch: false,
        })
    }

    /// Like [`Iter2::f_new`], but panics when the first dimensions differ.
    pub fn new(xs: &Tensor, ys: &Tensor, batch_size: i64) -> Iter2 {
        Self::f_new(xs, ys, batch_size).unwrap()
    }

    /// Applies a single random permutation to both tensors along dimension 0,
    /// so the pairing between samples and targets is preserved.
    pub fn shuffle(&mut self) -> &mut Iter2 {
        let perm = Tensor::randperm(self.total_size, (Kind::Int64, self.device));
        self.xs = self.xs.index_select(0, &perm);
        self.ys = self.ys.index_select(0, &perm);
        self
    }

    /// Sets the device that yielded batches are transferred to.
    #[allow(clippy::wrong_self_convention)]
    pub fn to_device(&mut self, device: Device) -> &mut Iter2 {
        self.device = device;
        self
    }

    /// Also yield a trailing batch smaller than `batch_size` when the total
    /// number of samples is not a multiple of the batch size.
    pub fn return_smaller_last_batch(&mut self) -> &mut Iter2 {
        self.return_smaller_last_batch = true;
        self
    }
}
impl Iterator for Iter2 {
    type Item = (Tensor, Tensor);

    /// Yields the next `(xs, ys)` batch, moved to the configured device.
    ///
    /// A trailing partial batch is skipped unless
    /// `return_smaller_last_batch` was requested.
    fn next(&mut self) -> Option<Self::Item> {
        let offset = self.batch_index * self.batch_size;
        let size = std::cmp::min(self.batch_size, self.total_size - offset);
        // Stop when the data is exhausted.
        if size <= 0 {
            return None;
        }
        // Stop early on a partial final batch unless those were opted into.
        if size < self.batch_size && !self.return_smaller_last_batch {
            return None;
        }
        self.batch_index += 1;
        let xs = self.xs.i(offset..offset + size).to_device(self.device);
        let ys = self.ys.i(offset..offset + size).to_device(self.device);
        Some((xs, ys))
    }
}
/// A byte-level text dataset: the file contents re-encoded as dense labels
/// (one label per distinct byte, assigned in order of first appearance).
#[derive(Debug)]
pub struct TextData {
// Label-encoded file contents (one element per input byte).
data: Tensor,
// Maps a label (used as index) back to the original byte, stored as char.
char_for_label: Vec<char>,
// Maps an original byte to its dense label.
label_for_char: HashMap<u8, u8>,
}
/// Iterator over shuffled, batched, fixed-length windows of a [`TextData`].
#[derive(Debug)]
pub struct TextDataIter {
// Shallow clone of the label-encoded dataset.
data: Tensor,
// Length of each yielded window.
seq_len: i64,
// Index of the next batch to yield (in units of batches).
batch_index: i64,
// Number of windows per yielded batch.
batch_size: i64,
// Random permutation of all valid window start positions.
indexes: Tensor,
// Number of valid window start positions.
indexes_len: i64,
}
impl TextData {
pub fn new<P: AsRef<std::path::Path>>(filename: P) -> Result<TextData, TchError> {
let mut buffer = std::fs::read(filename)?;
let mut label_for_char = HashMap::<u8, u8>::new();
let mut char_for_label = Vec::<char>::new();
for c in buffer.iter_mut() {
*c = *label_for_char.entry(*c).or_insert_with(|| {
let label = char_for_label.len() as u8;
char_for_label.push(*c as char);
label
})
}
Ok(TextData { data: Tensor::of_slice(&buffer), char_for_label, label_for_char })
}
pub fn labels(&self) -> i64 {
self.char_for_label.len() as i64
}
pub fn data(&self) -> Tensor {
self.data.shallow_clone()
}
pub fn label_to_char(&self, label: i64) -> char {
self.char_for_label[label as usize]
}
pub fn char_to_label(&self, c: char) -> Result<u8, TchError> {
match self.label_for_char.get(&(c as u8)) {
None => Err(TchError::Convert(format!("cannot find char {}", c))),
Some(v) => Ok(*v),
}
}
pub fn iter_shuffle(&self, seq_len: i64, batch_size: i64) -> TextDataIter {
let indexes_len = self.data.size()[0] - seq_len + 1;
TextDataIter {
data: self.data.shallow_clone(),
seq_len,
batch_index: 0,
batch_size,
indexes: Tensor::randperm(indexes_len, kind::INT64_CPU),
indexes_len,
}
}
}
impl Iterator for TextDataIter {
    type Item = Tensor;

    /// Yields the next batch of windows, stacked along a new dimension 0.
    ///
    /// Iteration stops as soon as fewer than `batch_size` start positions
    /// remain; a partial final batch is never yielded.
    fn next(&mut self) -> Option<Self::Item> {
        let offset = self.batch_index * self.batch_size;
        let remaining = self.indexes_len - offset;
        if std::cmp::min(self.batch_size, remaining) < self.batch_size {
            return None;
        }
        self.batch_index += 1;
        // A full batch is available here, so exactly `batch_size` start
        // positions are consumed from the shuffled index tensor.
        let starts = Vec::<i64>::from(&self.indexes.i(offset..offset + self.batch_size));
        let windows: Vec<_> =
            starts.iter().map(|&pos| self.data.i(pos..pos + self.seq_len)).collect();
        let refs: Vec<_> = windows.iter().collect();
        Some(Tensor::stack(&refs, 0))
    }
}