tch 0.13.0

Rust wrappers for the PyTorch C++ API (libtorch).
Documentation
use tch::{Kind, Tensor};

mod test_utils;
use test_utils::*;

/// Handle to a path in the system temp directory whose file is removed when
/// the handle is dropped.
///
/// The handle only holds the path; creating the file is left to the code
/// that uses it.
struct TmpFile(std::path::PathBuf);

impl TmpFile {
    /// Builds a path named `tch-<base>-<process id>-<thread id>` inside
    /// [`std::env::temp_dir`].
    ///
    /// The process and thread ids are part of the name so that callers on
    /// different threads or in different processes get different paths for
    /// the same `base`. No file is created here.
    fn create(base: &str) -> TmpFile {
        let name = format!(
            "tch-{}-{}-{:?}",
            base,
            std::process::id(),
            std::thread::current().id(),
        );
        TmpFile(std::env::temp_dir().join(name))
    }
}

impl std::convert::AsRef<std::path::Path> for TmpFile {
    /// Exposes the wrapped path so the handle can be passed to any API that
    /// accepts `AsRef<Path>`.
    fn as_ref(&self) -> &std::path::Path {
        self.0.as_path()
    }
}

impl Drop for TmpFile {
    /// Removes the file at the wrapped path.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be removed, unless the thread is already
    /// unwinding from another panic.
    fn drop(&mut self) {
        if let Err(err) = std::fs::remove_file(&self.0) {
            // A panic raised while the thread is already unwinding aborts the
            // process, which would hide the failure that started the unwind.
            // Only report the cleanup error when nothing else has failed.
            if !std::thread::panicking() {
                panic!("failed to remove temporary file {:?}: {}", self.0, err);
            }
        }
    }
}

/// Writes a tensor with `Tensor::save`, reads it back with `Tensor::load`,
/// and checks that the values are unchanged.
#[test]
fn save_and_load() {
    let path = TmpFile::create("save-and-load");
    let expected: Vec<f64> = vec![3.0, 1.0, 4.0, 1.0, 5.0];
    let original = Tensor::from_slice(&expected);
    original.save(&path).unwrap();
    let restored = Tensor::load(&path).unwrap();
    let actual = vec_f64_from(&restored);
    assert_eq!(actual, expected);
}

/// Writes a tensor into a freshly created file through
/// `Tensor::save_to_stream`, then reads it back by path with `Tensor::load`.
#[test]
fn save_to_stream_and_load() {
    let path = TmpFile::create("write-stream");
    let expected: Vec<f64> = vec![3.0, 1.0, 4.0, 1.0, 5.0];
    let original = Tensor::from_slice(&expected);
    let sink = std::fs::File::create(&path).unwrap();
    original.save_to_stream(sink).unwrap();
    let restored = Tensor::load(&path).unwrap();
    let actual = vec_f64_from(&restored);
    assert_eq!(actual, expected);
}

/// Writes a tensor by path with `Tensor::save`, then reads it back through a
/// buffered file reader with `Tensor::load_from_stream`.
#[test]
fn save_and_load_from_stream() {
    let path = TmpFile::create("read-stream");
    let expected: Vec<f64> = vec![3.0, 1.0, 4.0, 1.0, 5.0];
    let original = Tensor::from_slice(&expected);
    original.save(&path).unwrap();
    let file = std::fs::File::open(&path).unwrap();
    let restored = Tensor::load_from_stream(std::io::BufReader::new(file)).unwrap();
    let actual = vec_f64_from(&restored);
    assert_eq!(actual, expected);
}

/// Writes two named tensors with `Tensor::save_multi` and checks that
/// `Tensor::load_multi` returns them in the same order with the expected
/// names, and that the second tensor sums to 57.
#[test]
fn save_and_load_multi() {
    let path = TmpFile::create("save-and-load-multi");
    let pi = Tensor::from_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let e = Tensor::from_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    let named = [(&"pi", &pi), (&"e", &e)];
    Tensor::save_multi(&named, &path).unwrap();

    let loaded = Tensor::load_multi(&path).unwrap();
    assert_eq!(loaded.len(), 2);
    let (first_name, _) = &loaded[0];
    let (second_name, second_tensor) = &loaded[1];
    assert_eq!(first_name.as_str(), "pi");
    assert_eq!(second_name.as_str(), "e");
    let total = second_tensor.sum(tch::Kind::Float);
    assert_eq!(from::<i64>(&total), 57);
}

/// Writes two named tensors into a freshly created file through
/// `Tensor::save_multi_to_stream`, then reads them back by path with
/// `Tensor::load_multi` and checks names, order, and the second tensor's sum.
#[test]
fn save_to_stream_and_load_multi() {
    let path = TmpFile::create("save-to-stream-and-load-multi");
    let pi = Tensor::from_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let e = Tensor::from_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    let named = [(&"pi", &pi), (&"e", &e)];
    let sink = std::fs::File::create(&path).unwrap();
    Tensor::save_multi_to_stream(&named, sink).unwrap();

    let loaded = Tensor::load_multi(&path).unwrap();
    assert_eq!(loaded.len(), 2);
    let (first_name, _) = &loaded[0];
    let (second_name, second_tensor) = &loaded[1];
    assert_eq!(first_name.as_str(), "pi");
    assert_eq!(second_name.as_str(), "e");
    let total = second_tensor.sum(tch::Kind::Float);
    assert_eq!(from::<i64>(&total), 57);
}

/// Writes two named tensors by path with `Tensor::save_multi`, then reads
/// them back through a buffered file reader with
/// `Tensor::load_multi_from_stream` and checks names, order, and the second
/// tensor's sum.
#[test]
fn save_and_load_multi_from_stream() {
    let path = TmpFile::create("save-and-load-multi-from-stream");
    let pi = Tensor::from_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let e = Tensor::from_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    let named = [(&"pi", &pi), (&"e", &e)];
    Tensor::save_multi(&named, &path).unwrap();

    let file = std::fs::File::open(&path).unwrap();
    let loaded = Tensor::load_multi_from_stream(std::io::BufReader::new(file)).unwrap();
    assert_eq!(loaded.len(), 2);
    let (first_name, _) = &loaded[0];
    let (second_name, second_tensor) = &loaded[1];
    assert_eq!(first_name.as_str(), "pi");
    assert_eq!(second_name.as_str(), "e");
    let total = second_tensor.sum(tch::Kind::Float);
    assert_eq!(from::<i64>(&total), 57);
}

/// Writes two named tensors with `Tensor::write_npz` and checks that
/// `Tensor::read_npz` returns them in the same order with the expected
/// names, and that the second tensor sums to 57.
#[test]
fn save_and_load_npz() {
    let path = TmpFile::create("save-and-load-npz");
    let pi = Tensor::from_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let e = Tensor::from_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    let named = [(&"pi", &pi), (&"e", &e)];
    Tensor::write_npz(&named, &path).unwrap();

    let loaded = Tensor::read_npz(&path).unwrap();
    assert_eq!(loaded.len(), 2);
    let (first_name, _) = &loaded[0];
    let (second_name, second_tensor) = &loaded[1];
    assert_eq!(first_name.as_str(), "pi");
    assert_eq!(second_name.as_str(), "e");
    let total = second_tensor.sum(tch::Kind::Float);
    assert_eq!(from::<i64>(&total), 57);
}

/// Same round trip as the plain npz test, but both tensors are converted to
/// `Kind::Half` before being written with `Tensor::write_npz`.
#[test]
fn save_and_load_npz_half() {
    let path = TmpFile::create("save-and-load-npz-half");
    let pi_source = Tensor::from_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let pi = pi_source.to_dtype(Kind::Half, true, false);
    let e_source = Tensor::from_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    let e = e_source.to_dtype(Kind::Half, true, false);
    let named = [(&"pi", &pi), (&"e", &e)];
    Tensor::write_npz(&named, &path).unwrap();

    let loaded = Tensor::read_npz(&path).unwrap();
    assert_eq!(loaded.len(), 2);
    let (first_name, _) = &loaded[0];
    let (second_name, second_tensor) = &loaded[1];
    assert_eq!(first_name.as_str(), "pi");
    assert_eq!(second_name.as_str(), "e");
    let total = second_tensor.sum(tch::Kind::Float);
    assert_eq!(from::<i64>(&total), 57);
}

/// Same round trip as the plain npz test, but both tensors are converted to
/// `Kind::Int8` before being written, and the second tensor's sum is taken
/// as `Int8` and read back as an `i8`.
#[test]
fn save_and_load_npz_byte() {
    let path = TmpFile::create("save-and-load-npz-byte");
    let pi_source = Tensor::from_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let pi = pi_source.to_dtype(Kind::Int8, true, false);
    let e_source = Tensor::from_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    let e = e_source.to_dtype(Kind::Int8, true, false);
    let named = [(&"pi", &pi), (&"e", &e)];
    Tensor::write_npz(&named, &path).unwrap();

    let loaded = Tensor::read_npz(&path).unwrap();
    assert_eq!(loaded.len(), 2);
    let (first_name, _) = &loaded[0];
    let (second_name, second_tensor) = &loaded[1];
    assert_eq!(first_name.as_str(), "pi");
    assert_eq!(second_name.as_str(), "e");
    let total = second_tensor.sum(tch::Kind::Int8);
    assert_eq!(from::<i8>(&total), 57);
}

/// Round-trips a tensor through `write_npy` / `Tensor::read_npy` twice using
/// the same path: first as a flat tensor, then reshaped to `[3, 1, 2]`,
/// checking that both the values and the shape come back.
#[test]
fn save_and_load_npy() {
    let path = TmpFile::create("save-and-load-npy");
    let values: [f64; 6] = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0];

    // First pass: flat tensor.
    let flat = Tensor::from_slice(&values);
    flat.write_npy(&path).unwrap();
    let flat_loaded = Tensor::read_npy(&path).unwrap();
    assert_eq!(vec_f64_from(&flat_loaded), values);

    // Second pass: the loaded tensor, reshaped, written to the same path.
    let shaped = flat_loaded.reshape([3, 1, 2]);
    shaped.write_npy(&path).unwrap();
    let shaped_loaded = Tensor::read_npy(&path).unwrap();
    assert_eq!(shaped_loaded.size(), [3, 1, 2]);
    let flattened = shaped_loaded.flatten(0, -1);
    assert_eq!(vec_f64_from(&flattened), values);
}

/// Writes two named tensors with `Tensor::write_safetensors` and checks that
/// `Tensor::read_safetensors` returns exactly two entries whose sums match
/// the tensor of the same name.
///
/// Entries are matched by name rather than by position, so the test does not
/// rely on the order in which they are returned.
#[test]
fn save_and_load_safetensors() {
    let tmp_file = TmpFile::create("save-and-load-safetensors");
    let pi = Tensor::from_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let e = Tensor::from_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    Tensor::write_safetensors(&[(&"pi", &pi), (&"e", &e)], &tmp_file).unwrap();
    let named_tensors = Tensor::read_safetensors(&tmp_file).unwrap();
    assert_eq!(named_tensors.len(), 2);
    for (name, tensor) in named_tensors {
        let expected_sum = match name.as_str() {
            "pi" => 14,
            "e" => 57,
            // Report the offending name so a failure can be diagnosed.
            other => panic!("unknown tensor name: {:?}", other),
        };
        assert_eq!(
            from::<i64>(&tensor.sum(tch::Kind::Float)),
            expected_sum,
            "unexpected sum for tensor {:?}",
            name
        );
    }
}

/// Same round trip as the plain safetensors test, but both tensors are
/// converted to `Kind::Half` before being written with
/// `Tensor::write_safetensors`.
///
/// Entries are matched by name rather than by position, so the test does not
/// rely on the order in which they are returned.
#[test]
fn save_and_load_safetensors_half() {
    let tmp_file = TmpFile::create("save-and-load-safetensors-half");
    let pi = Tensor::from_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]).to_dtype(Kind::Half, true, false);
    let e =
        Tensor::from_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]).to_dtype(Kind::Half, true, false);
    Tensor::write_safetensors(&[(&"pi", &pi), (&"e", &e)], &tmp_file).unwrap();
    let named_tensors = Tensor::read_safetensors(&tmp_file).unwrap();
    assert_eq!(named_tensors.len(), 2);
    for (name, tensor) in named_tensors {
        let expected_sum = match name.as_str() {
            "pi" => 14,
            "e" => 57,
            // Report the offending name so a failure can be diagnosed.
            other => panic!("unknown tensor name: {:?}", other),
        };
        assert_eq!(
            from::<i64>(&tensor.sum(tch::Kind::Float)),
            expected_sum,
            "unexpected sum for tensor {:?}",
            name
        );
    }
}