// tch 0.6.1
//
// Rust wrappers for the PyTorch C++ API (libtorch).
use std::io::Write;
use tch::{data, IndexOp, Tensor};

/// Checks `data::Iter2` batching of two aligned tensors.
///
/// The test verifies four things:
/// - every batch has exactly `bsize` elements;
/// - `ys` stays aligned with `xs` (here `ys = 2 * xs`);
/// - unshuffled iteration walks the data in order, and visits every full batch;
/// - `shuffle()` actually reorders the elements.
#[test]
fn iter2() {
    let bsize: usize = 4;
    let vs: Vec<i64> = (0..1337).collect();
    let xs = Tensor::of_slice(&vs);
    let ys = Tensor::of_slice(&vs.iter().map(|x| x * 2).collect::<Vec<_>>());
    // 1337 is not a multiple of `bsize`. Every batch is asserted to be exactly `bsize`
    // long, so the trailing partial batch is expected to be dropped.
    let expected_batches = vs.len() / bsize;

    // Unshuffled: batch k must hold the consecutive values k*bsize .. k*bsize + bsize.
    let mut n_batches = 0usize;
    let batches = data::Iter2::new(&xs, &ys, bsize as i64);
    for (batch_idx, (batch_xs, batch_ys)) in batches.enumerate() {
        let xs = Vec::<i64>::from(&batch_xs);
        let ys = Vec::<i64>::from(&batch_ys);
        assert_eq!(xs.len(), bsize);
        assert_eq!(ys.len(), bsize);
        assert_eq!(xs[0], (batch_idx * bsize) as i64);
        for (x, y) in xs.iter().zip(&ys) {
            assert_eq!(*y, 2 * x);
        }
        for w in xs.windows(2) {
            assert_eq!(w[0] + 1, w[1]);
        }
        n_batches += 1;
    }
    // Guards against the loop above passing vacuously on an empty iterator.
    assert_eq!(n_batches, expected_batches);

    // Shuffled: pairs must stay aligned, but the element order should change.
    let mut n_batches = 0usize;
    let mut all_in_order = true;
    for (batch_xs, batch_ys) in data::Iter2::new(&xs, &ys, bsize as i64).shuffle() {
        let xs = Vec::<i64>::from(&batch_xs);
        let ys = Vec::<i64>::from(&batch_ys);
        assert_eq!(xs.len(), bsize);
        assert_eq!(ys.len(), bsize);
        for (x, y) in xs.iter().zip(&ys) {
            assert_eq!(*y, 2 * x);
        }
        if xs.windows(2).any(|w| w[0] + 1 != w[1]) {
            all_in_order = false;
        }
        n_batches += 1;
    }
    assert_eq!(n_batches, expected_batches);
    // Probabilistic, but with hundreds of shuffled batches a fully in-order result
    // is vanishingly unlikely.
    assert!(!all_in_order, "shuffle() left every batch in its original order");
}

/// Checks `data::TextData` on a small file of repeating digits.
///
/// Labels `0..10` must map back to the characters `'0'..='9'`. Every window produced by
/// `iter_shuffle(2, 5)` must consist of two consecutive digits (mod 10).
#[test]
fn text() {
    /// Deletes the temporary input file when the test ends, including on panic.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            // Best-effort cleanup: failing to delete must not mask the test's own result.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let filename = std::env::temp_dir().join(format!("tch-{}.txt", std::process::id()));
    // Declared before `text_data`, so it is dropped (and the file removed) last.
    let _cleanup = RemoveOnDrop(filename.clone());
    {
        let mut file = std::fs::File::create(&filename).unwrap();
        file.write_all(b"01234567890123456789").unwrap();
    }
    let text_data = data::TextData::new(&filename).unwrap();
    for i in 0..10 {
        assert_eq!(
            text_data.label_to_char(i),
            i.to_string().chars().next().unwrap()
        );
    }
    let mut n_batches = 0usize;
    for xs in text_data.iter_shuffle(2, 5) {
        // The second element of each window should be the first one plus one, wrapping at 10.
        let first_column_plus_one = (xs.i((.., ..1)) + 1).fmod(10);
        let second_column = xs.i((.., 1..=1));
        let err = i64::from(
            (first_column_plus_one - second_column)
                .pow_tensor_scalar(2)
                .sum(tch::Kind::Float),
        );
        assert_eq!(err, 0);
        n_batches += 1;
    }
    // Guards against the loop above passing vacuously on an empty iterator.
    assert!(n_batches > 0, "iter_shuffle yielded no batches");
}