use super::Tensor;
use super::handle::{ArcTensor, RcTensor};
use crate::view::{ArcTensorView, AsView, RcTensorView, TensorViewOps};
use std::ops::{Add, AddAssign};
/// Generates the shared API for a tensor handle type and its view type.
///
/// Invocation: `impl_tensor_wrapper!(Handle, HandleView, convert_expr)`; a
/// trailing comma is also accepted. For `Handle` the expansion emits:
/// - inherent `as_view`, `broadcast_to`, `transpose` and `T` methods that
///   forward to the same-named methods on a `HandleView`;
/// - `Add` / `AddAssign` implementations routed through `HandleView`'s operators;
/// - an `AsView` implementation with `View = HandleView`.
///
/// The expansion names `Add`, `AddAssign` and `AsView` unqualified, so those
/// must be in scope at the invocation site. `Handle` must implement `Clone`,
/// and `HandleView` must provide `new(Handle)` and `into_handle()`.
///
/// The third argument (`$convert`) is accepted for call-site compatibility but
/// is not referenced by the expansion.
macro_rules! impl_tensor_wrapper {
    (
        $wrapper:ident,
        $view:ty,
        $convert:expr $(,)?
    ) => {
        impl $wrapper {
            /// Returns a view over this handle.
            ///
            /// The view is built from a clone of the handle (`self` is only
            /// borrowed); this delegates to the `AsView` implementation so the
            /// two entry points cannot drift apart.
            pub fn as_view(&self) -> $view {
                <Self as AsView>::as_view(self)
            }

            /// Broadcasts a view of this handle to `target_shape`.
            ///
            /// # Errors
            /// Propagates the `Err(String)` returned by the view's `broadcast_to`.
            pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<$view, String> {
                self.as_view().broadcast_to(target_shape)
            }

            /// Permutes the axes of a view of this handle according to `axes`.
            ///
            /// # Errors
            /// Propagates the `Err(String)` returned by the view's `transpose`.
            pub fn transpose(&self, axes: &[usize]) -> Result<$view, String> {
                self.as_view().transpose(axes)
            }

            /// Forwards to the view's `T()` (presumably the full transpose, in the
            /// spirit of NumPy's `.T` — confirm against the view implementation).
            ///
            /// # Errors
            /// Propagates the `Err(String)` returned by the view's `T`.
            // The upper-case name is part of the public API; silence the lint
            // instead of renaming and breaking callers.
            #[allow(non_snake_case)]
            pub fn T(&self) -> Result<$view, String> {
                self.as_view().T()
            }
        }

        impl Add for $wrapper {
            type Output = Self;

            /// Adds two handles via their views and converts the resulting view
            /// back into a handle.
            fn add(self, other: Self) -> Self::Output {
                // Both operands are owned, so move them into their views instead
                // of cloning them.
                let lhs = <$view>::new(self);
                let rhs = <$view>::new(other);
                (lhs + rhs).into_handle()
            }
        }

        impl AddAssign for $wrapper {
            /// Adds `other` into this handle through the view's `+=`.
            ///
            /// NOTE(review): the view on the left is built from a *clone* of
            /// `self` and dropped afterwards, so `self` observes the sum only if
            /// the view's `AddAssign` writes through to storage shared with the
            /// handle. Confirm that contract in the view implementation; if views
            /// are copy-on-write, this needs `*self = lhs.into_handle();`.
            fn add_assign(&mut self, other: Self) {
                let mut lhs = self.as_view();
                // `other` is owned; move it into its view rather than cloning it.
                lhs += <$view>::new(other);
            }
        }

        impl AsView for $wrapper {
            type View = $view;

            /// Wraps a clone of this handle in a new view.
            fn as_view(&self) -> Self::View {
                <$view>::new(self.clone())
            }
        }
    };
}
// Instantiate the shared handle API for each handle/view pair. The third
// argument (the conversion function) is not referenced by the macro expansion.
impl_tensor_wrapper!(RcTensor, RcTensorView, Tensor::into_rc_raw);
impl_tensor_wrapper!(ArcTensor, ArcTensorView, Tensor::into_arc_raw);