relearn 0.3.1

A Reinforcement Learning library
use ndarray::{ArrayView, ArrayViewMut, Dim, Dimension, IntoDimension, Ix, IxDyn};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::{mem, slice};
use tch::{kind::Element, Device, Kind, Tensor};
use thiserror::Error;

/// An exclusive owner of a [`Tensor`] and its data.
///
/// Given an ordinary `Tensor`, it is impossible to reason about the lifetime of the data at
/// [`Tensor::data_ptr`]. Copies created by [`Tensor::shallow_clone`] share the same underlying
/// tensor object and can cause the data memory to be moved or reallocated at any time (for
/// example, by calling [`Tensor::resize_`]).
///
/// To avoid this issue, `ExclusiveTensor` manages the creation of the tensor such that it has
/// exclusive access to the underlying data. An `ExclusiveTensor` can never provide `&Tensor`
/// references to the managed tensor.
///
/// The managed tensor always lives on the CPU device.
///
/// Type parameters: `E` is the element type of the tensor data and `D` is the
/// (ndarray-style) dimension type describing its shape.
pub struct ExclusiveTensor<E, D>
    D: Dimension,
    /// The managed tensor; never handed out by reference.
    tensor: Tensor,
    /// Track shape to avoid runtime checks
    shape: D,
    /// Number of elements in the tensor
    num_elements: usize,
    /// Track element type
    element_type: PhantomData<E>,

// Constructors. These require `IntoTorchShape` so the ndarray-style shape can be handed to
// the `tch` tensor-creation functions as a slice of `i64`.
impl<E, D> ExclusiveTensor<E, D>
    E: Element,
    D: Dimension + IntoTorchShape,
    /// Create a zero-initialized tensor.
    ///
    /// The tensor is created on [`Device::Cpu`] with the kind `E::KIND`.
    ///
    /// # Panics
    /// If the shape fails the size checks performed by `from_tensor_fn`.
    pub fn zeros<Sh: IntoDimension<Dim = D>>(shape: Sh) -> Self {
        // SAFETY: the closure forwards the shape and kind it is given to `Tensor::zeros`
        // and explicitly requests `Device::Cpu`.
        unsafe {
            Self::from_tensor_fn(shape, |shape, kind| {
                Tensor::zeros(shape, (kind, Device::Cpu))

    /// Create a one-initialized tensor.
    ///
    /// The tensor is created on [`Device::Cpu`] with the kind `E::KIND`.
    ///
    /// # Panics
    /// If the shape fails the size checks performed by `from_tensor_fn`.
    pub fn ones<Sh: IntoDimension<Dim = D>>(shape: Sh) -> Self {
        // SAFETY: the closure forwards the shape and kind it is given to `Tensor::ones`
        // and explicitly requests `Device::Cpu`.
        unsafe {
            Self::from_tensor_fn(shape, |shape, kind| {
                Tensor::ones(shape, (kind, Device::Cpu))

    /// Initialize given a tensor construction function.
    ///
    /// `f` is called once with the shape (as a slice of `i64`) and `E::KIND`, and must
    /// return the tensor to be managed.
    ///
    /// # Safety
    /// The constructed tensor must
    ///     * have number of elements corresponding to `shape`,
    ///     * have elements of type `E`,
    ///     * use `Device::Cpu`, and
    ///     * exclusively manage its own memory (e.g. no `shallow_clone`).
    ///
    /// # Panics
    /// If the number of elements, or the total size in bytes of all elements, overflows or
    /// is not strictly less than `isize::MAX`.
    unsafe fn from_tensor_fn<Sh, F>(shape: Sh, f: F) -> Self
        Sh: IntoDimension<Dim = D>,
        F: FnOnce(&[i64], Kind) -> Tensor,
        let shape = shape.into_dimension();
        // Element-count check. `size_checked` yields `None` on overflow, which falls
        // through to the panic arm.
        // NOTE(review): the guard is a strict `<`, so a count exactly equal to `isize::MAX`
        // also panics even though the message says "must not exceed".
        let num_elements = match shape.size_checked() {
            Some(size) if size < isize::MAX as usize => size,
            _ => panic!("number of elements must not exceed isize::MAX"),
        // Byte-size check: element count times `size_of::<E>()`, with the same strict bound.
        match num_elements.checked_mul(mem::size_of::<E>()) {
            Some(size) if size < isize::MAX as usize => {}
            _ => panic!("size of allocated memory must not exceed isize::MAX"),
        // Build the tensor only after both size checks have passed.
        let tensor = f(shape.clone().into_torch_shape().as_ref(), E::KIND);
        Self {
            element_type: PhantomData,

// Operations available for any element type (no `Element` bound required).
impl<E, D: Dimension> ExclusiveTensor<E, D> {
    /// Convert into the inner tensor.
    ///
    /// Takes `self` by value, so the `ExclusiveTensor` wrapper no longer exists afterwards.
    #[allow(clippy::missing_const_for_fn)] // false positive; cannot run destructors
    pub fn into_tensor(self) -> Tensor {

// Constructors specific to dynamically-shaped tensors.
impl<E> ExclusiveTensor<E, IxDyn>
    E: Element,
    /// Try to create a dynamic-shape tensor by deep copying from a `Tensor`.
    /// The tensor `dtype` must match the element type `E`.
    ///
    /// # Errors
    /// Returns [`ExclusiveTensorError::MismatchedKind`] if `tensor.kind()` differs from
    /// `E::KIND`.
    ///
    /// # Panics
    /// If a dimension of the source tensor cannot be converted from `i64` to `usize`.
    pub fn try_copy_from(tensor: &Tensor) -> Result<Self, ExclusiveTensorError> {
        // Reject the source up front when its kind is not the one associated with `E`.
        let kind = tensor.kind();
        if kind != E::KIND {
            return Err(ExclusiveTensorError::MismatchedKind {
                expected: E::KIND,
                actual: kind,

        // Rebuild the source shape as `usize` dimensions for ndarray.
        let shape_vec: Vec<usize> = tensor
            .map(|d| d.try_into().unwrap()) // i64 -> usize
        let shape = IxDyn(&shape_vec);
        // SAFETY: the closure allocates a new tensor on `Device::Cpu` using the shape and
        // kind provided by `from_tensor_fn`.
        unsafe {
            Ok(Self::from_tensor_fn(shape, |shape, kind| {
                let mut new_tensor = Tensor::zeros(shape, (kind, Device::Cpu));

// Data access: slice and ndarray views over the tensor memory. All of these go through
// `data_ptr` and rely on the invariants established at construction.
impl<E, D> ExclusiveTensor<E, D>
    E: Element,
    D: Dimension,
    /// View the tensor data as a slice.
    ///
    /// The slice has `num_elements` entries.
    pub fn as_slice(&self) -> &[E] {
        // # Safety
        // ✓ **data must be valid for reads for `len * mem::size_of::<T>()` many bytes,
        //   and it must be properly aligned.**
        //   The tensor is storing that amount of data at the pointer, so long as the size is
        //   non-empty. The pointer is NonNull::dangling for empty tensors.
        // ✓ **data must point to len consecutive properly initialized values of type T.**
        //   The tensor has been fully initialized with valid data.
        // ✓ **The memory referenced by the returned slice must not be mutated for the duration of
        //   lifetime 'a, except inside an UnsafeCell.**
        //   Managed by the lifetime of self, which has exclusive access to the tensor memory.
        // ✓ **The total size len * mem::size_of::<T>() must be no larger than isize::MAX.**
        //   Asserted in construction and probably must hold for Tensor anyways.
        unsafe { slice::from_raw_parts(self.data_ptr().as_ptr(), self.num_elements) }

    /// View the tensor data as a mutable slice.
    ///
    /// The slice has `num_elements` entries; `&mut self` ensures no other view is live.
    pub fn as_slice_mut(&mut self) -> &mut [E] {
        // # Safety
        // See `Self::as_slice` implementation
        unsafe { slice::from_raw_parts_mut(self.data_ptr().as_ptr(), self.num_elements) }

    /// View as an n-dimensional array.
    ///
    /// The view uses the shape tracked by this `ExclusiveTensor` and borrows from `self`.
    pub fn array_view(&self) -> ArrayView<E, D> {
        // # Safety
        // ✓ **Elements must live as long as 'a (in ArrayView<'a, E, D>).**
        //   Managed by the lifetime of self, which has exclusive access to the tensor memory.
        // ✓ **ptr must be non-null and aligned, and it must be safe to .offset() ptr by zero.**
        //   This is up to torch but it should be true for non-empty tensors since data is being
        //   stored at this pointer value.
        //   In the case of empty tensors, the data pointer is NonNull::dangling.
        // ? **It must be safe to .offset() the pointer repeatedly along all axes and calculate the
        //   counts for the .offset() calls without overflow, even if the array is empty or the
        //   elements are zero-sized.**
        //   Up to pytorch but again it should be true since the full tensor's worth of data is
        //   being stored at this pointer value.
        // ✓ **The product of non-zero axis lengths must not exceed isize::MAX.**
        //   Asserted in constructors; but probably a similar constraint applies to the tensor
        //   creation by pytorch.
        // ✓ **Strides must be non-negative.**
        //   Dimension as IntoDimension as Into<StrideShape> always uses C-style strides
        //   which have a value of 0 or 1 depending on the array shape.
        unsafe { ArrayView::from_shape_ptr(self.shape.clone(), self.data_ptr().as_ptr()) }

    /// View as a mutable n-dimensional array.
    ///
    /// The view uses the shape tracked by this `ExclusiveTensor` and mutably borrows `self`.
    pub fn array_view_mut(&mut self) -> ArrayViewMut<E, D> {
        // # Safety
        // See `Self::array_view` implementation
        unsafe { ArrayViewMut::from_shape_ptr(self.shape.clone(), self.data_ptr().as_ptr()) }

    /// The current tensor data pointer; may be dangling if the tensor is empty.
    /// This is not cached in case additional methods are added that can cause the tensor to
    /// re-allocate.
    ///
    /// # Panics
    /// If the tensor is non-empty but reports a null data pointer.
    fn data_ptr(&self) -> NonNull<E> {
        // Empty tensors are special-cased so the tensor's own pointer is never consulted.
        if self.num_elements == 0 {
        } else {
            // Reinterpret the tensor's raw pointer as a pointer to `E`; null is a bug here.
            NonNull::new(self.tensor.data_ptr().cast()).expect("unexpected null data_ptr")

/// Conversion that consumes an [`ExclusiveTensor`] and yields a plain [`Tensor`].
impl<E, D: Dimension> From<ExclusiveTensor<E, D>> for Tensor {
    /// Takes the `ExclusiveTensor` by value; it cannot be used after the conversion.
    fn from(exclusive: ExclusiveTensor<E, D>) -> Self {

/// Conversion from a borrowed [`ExclusiveTensor`] to an [`ArrayView`] with the same element
/// and dimension types; the view's lifetime `'a` is that of the borrow.
impl<'a, E, D> From<&'a ExclusiveTensor<E, D>> for ArrayView<'a, E, D>
    E: Element,
    D: Dimension,
    /// Produce an array view borrowing from `exclusive`.
    fn from(exclusive: &'a ExclusiveTensor<E, D>) -> Self {

/// Convert an ndarray axis length (`Ix`) to the `i64` used for torch shapes.
///
/// # Panics
/// With the message "dimension too large" if `x` does not fit in an `i64`.
fn to_i64(x: Ix) -> i64 {
    x.try_into().expect("dimension too large")

/// Convert an ndarray-style dimension into the shape type used by [`tch`].
pub trait IntoTorchShape {
    /// The torch-side shape representation; must be viewable as a slice of `i64`.
    type TorchDim: AsRef<[i64]>;
    /// Consume the dimension and return the equivalent torch shape.
    fn into_torch_shape(self) -> Self::TorchDim;
/// Dynamic-rank shapes become a heap-allocated `Vec<i64>`.
impl IntoTorchShape for IxDyn {
    type TorchDim = Vec<i64>;
    /// Each axis length is converted with the checked `to_i64` helper.
    fn into_torch_shape(self) -> Self::TorchDim {
            .map(|&x| to_i64(x))
/// Zero-dimensional shapes map to a zero-length `i64` array.
impl IntoTorchShape for Dim<[Ix; 0]> {
    type TorchDim = [i64; 0];
    /// Consume the zero-dimensional shape and return its torch shape.
    fn into_torch_shape(self) -> Self::TorchDim {
impl IntoTorchShape for Dim<[Ix; 1]> {
    type TorchDim = [i64; 1];
    fn into_torch_shape(self) -> Self::TorchDim {
        [self.into_pattern() as _]
/// Two-dimensional shapes map to a two-element `i64` array.
impl IntoTorchShape for Dim<[Ix; 2]> {
    type TorchDim = [i64; 2];
    /// Each axis length is converted with the checked `to_i64` helper, in axis order.
    fn into_torch_shape(self) -> Self::TorchDim {
        let (a, b) = self.into_pattern();
        [to_i64(a), to_i64(b)]
/// Three-dimensional shapes map to a three-element `i64` array.
impl IntoTorchShape for Dim<[Ix; 3]> {
    type TorchDim = [i64; 3];
    /// Each axis length is converted with the checked `to_i64` helper, in axis order.
    fn into_torch_shape(self) -> Self::TorchDim {
        let (a, b, c) = self.into_pattern();
        [to_i64(a), to_i64(b), to_i64(c)]
/// Four-dimensional shapes map to a four-element `i64` array.
impl IntoTorchShape for Dim<[Ix; 4]> {
    type TorchDim = [i64; 4];
    /// Each axis length is converted with the checked `to_i64` helper, in axis order.
    fn into_torch_shape(self) -> Self::TorchDim {
        let (a, b, c, d) = self.into_pattern();
        [to_i64(a), to_i64(b), to_i64(c), to_i64(d)]
/// Five-dimensional shapes map to a five-element `i64` array.
impl IntoTorchShape for Dim<[Ix; 5]> {
    type TorchDim = [i64; 5];
    /// Each axis length is converted with the checked `to_i64` helper, in axis order.
    fn into_torch_shape(self) -> Self::TorchDim {
        let (a, b, c, d, e) = self.into_pattern();
        [to_i64(a), to_i64(b), to_i64(c), to_i64(d), to_i64(e)]
/// Six-dimensional shapes map to a six-element `i64` array.
impl IntoTorchShape for Dim<[Ix; 6]> {
    type TorchDim = [i64; 6];
    /// Consume the shape and return its six axis lengths as a torch shape.
    fn into_torch_shape(self) -> Self::TorchDim {
        // Destructure the six axis lengths.
        // NOTE(review): the expression that builds the result is not visible here —
        // presumably it mirrors the lower-rank impls; confirm.
        let (a, b, c, d, e, f) = self.into_pattern();

/// Error produced when constructing an [`ExclusiveTensor`] from an existing tensor fails.
#[derive(Error, Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExclusiveTensorError {
    /// The source tensor's kind (`actual`) differs from the kind of the requested element
    /// type (`expected`).
    #[error("expected kind {expected:?} but got {actual:?}")]
    MismatchedKind { expected: Kind, actual: Kind },

mod tests {
    use super::*;
    use ndarray::{arr2, Array};

    /// `zeros` yields a CPU `Float` tensor with the requested `[2, 4, 3]` shape.
    fn zeros() {
        let u = ExclusiveTensor::<f32, _>::zeros([2, 4, 3]);
        let tensor: Tensor = u.into();
        assert_eq!(tensor.size(), [2, 4, 3]);
        assert_eq!(tensor.kind(), Kind::Float);
        assert_eq!(tensor.device(), Device::Cpu);
            // Reference all-zeros tensor.
            // NOTE(review): the enclosing assertion is not visible here — confirm.
            Tensor::zeros(&[2, 4, 3], (Kind::Float, Device::Cpu))

    /// `ones` yields a CPU `Float` tensor of shape `[2, 4, 3]` equal to `Tensor::ones`.
    fn ones() {
        let u = ExclusiveTensor::<f32, _>::ones([2, 4, 3]);
        let tensor: Tensor = u.into();
        assert_eq!(tensor.size(), [2, 4, 3]);
        assert_eq!(tensor.kind(), Kind::Float);
        assert_eq!(tensor.device(), Device::Cpu);
        assert_eq!(tensor, Tensor::ones(&[2, 4, 3], (Kind::Float, Device::Cpu)));

    /// `try_copy_from` on a `[2, 3]` float tensor preserves shape, kind, and contents, and
    /// the result is on the CPU.
    fn try_copy_from() {
        let src = Tensor::of_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
        let copy = ExclusiveTensor::<f32, _>::try_copy_from(&src).unwrap();
        let tensor: Tensor = copy.into();
        assert_eq!(tensor.size(), [2, 3]);
        assert_eq!(tensor.kind(), Kind::Float);
        assert_eq!(tensor.device(), Device::Cpu);
        assert_eq!(tensor, src);

    /// The copy lives on the CPU and equals the source once the source is moved to the CPU.
    /// NOTE(review): per the test name the source is presumably placed on a CUDA device when
    /// one is available; the device-selection step is not visible here — confirm.
    fn try_copy_from_cuda_if_available() {
        let src = Tensor::of_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0])
        let copy = ExclusiveTensor::<f32, _>::try_copy_from(&src).unwrap();
        let tensor: Tensor = copy.into();
        assert_eq!(tensor.device(), Device::Cpu);
        assert_eq!(tensor, src.to_device(Device::Cpu));

    /// Copying an `f32` tensor into an `f64` `ExclusiveTensor` fails with `MismatchedKind`
    /// reporting `Double` as expected and `Float` as actual.
    fn try_copy_from_mismatched_type() {
        let src = Tensor::of_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
            ExclusiveTensor::<f64, _>::try_copy_from(&src).unwrap_err(),
            ExclusiveTensorError::MismatchedKind {
                expected: Kind::Double,
                actual: Kind::Float

    /// `as_slice` on a `[3, 1, 2]` tensor of ones exposes six `1.0` values.
    fn slice_f64() {
        let u = ExclusiveTensor::<f64, _>::ones([3, 1, 2]);
        assert_eq!(u.as_slice().len(), 6);
        assert_eq!(u.as_slice(), &[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);

    /// Writes made through `as_slice_mut` are visible through `as_slice`.
    fn slice_mut_i16() {
        let mut u = ExclusiveTensor::<i16, _>::ones([3, 1, 2]);
        assert_eq!(u.as_slice_mut().len(), 6);
        // Overwrite each element with its flat index.
        for (i, x) in u.as_slice_mut().iter_mut().enumerate() {
            *x = i.try_into().unwrap()
        assert_eq!(u.as_slice(), &[0, 1, 2, 3, 4, 5]);
        let tensor: Tensor = u.into();
            // Reference tensor with the same values and shape.
            // NOTE(review): the enclosing assertion is not visible here — confirm.
            Tensor::of_slice(&[0, 1, 2, 3, 4, 5]).reshape(&[3, 1, 2])

    /// `array_view` reports the `(2, 4, 3)` shape and matches an ndarray array of ones.
    fn array_view_f32() {
        let u = ExclusiveTensor::<f32, _>::ones([2, 4, 3]);
        let view = u.array_view();
        assert_eq!(view.dim(), (2, 4, 3));
        assert_eq!(view, Array::<f32, _>::ones((2, 4, 3)));

    /// A zero-dimensional tensor of ones views as a scalar array holding `1`.
    fn array_view_i64_scalar() {
        let u = ExclusiveTensor::<i64, _>::ones([]);
        let view = u.array_view();
        assert_eq!(view.dim(), ());
        assert_eq!(view.into_scalar(), &1);

    /// `array_view` works on an empty one-dimensional tensor and reports length 0.
    fn array_view_f32_empty() {
        let u = ExclusiveTensor::<f32, _>::ones([0]);
        let view = u.array_view();
        assert_eq!(view.dim(), 0);

    /// Writes through `array_view_mut` are visible both in the view and in the tensor
    /// obtained by converting the `ExclusiveTensor`.
    fn array_view_mut() {
        let mut u = ExclusiveTensor::<i32, _>::ones([3, 4]);
        let mut view = u.array_view_mut();
        // Fill cell (i, j) with `i * 10 + j` so every position has a distinct value.
        for (i, mut row) in view.rows_mut().into_iter().enumerate() {
            for (j, cell) in row.iter_mut().enumerate() {
                *cell = (i * 10 + j).try_into().unwrap();
        let expected = arr2(&[[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]);
        assert_eq!(view, expected); // Compare as arrays
        let t: Tensor = u.into();
        let expected: Tensor = expected.try_into().unwrap();
        assert_eq!(t, expected); // Compare as tensors

    fn array_view_mut_empty() {
        let mut u = ExclusiveTensor::<f32, _>::ones([2, 0, 3]);
        let mut view = u.array_view_mut();