1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
//! Utility functions using tch-rs. use ndarray::ArrayD; use tch::{TchError, Tensor}; /// Converts [ndarray::ArrayD] into tch Tensor. /// Borrowed from tch-rs. The original code didn't work with ndarray 0.14. pub fn try_from<T>(value: ArrayD<T>) -> Result<Tensor, TchError> where T: tch::kind::Element, { // TODO: Replace this with `?` once it works with `std::option::ErrorNone` let slice = match value.as_slice() { None => return Err(TchError::Convert("cannot convert to slice".to_string())), Some(v) => v, }; let tn = Tensor::f_of_slice(slice)?; let shape: Vec<i64> = value.shape().iter().map(|s| *s as i64).collect(); // Ok(tn.f_reshape(&shape)?) tn.f_reshape(&shape) }