use crate::prelude_dev::*;
use core::mem::transmute;
/// Vector dot product of `a` and `b` over the axes selected by `axes_pair`.
///
/// `axes_pair` may be `None` (contract over the last axis of each operand), a single axis
/// shared by both operands, or an explicit pair of axis lists. The non-contracted
/// dimensions of `a` and `b` are broadcast together to form the output shape.
///
/// This is the infallible front end of `vecdot_f`.
///
/// # Panics
///
/// The `Result` from `vecdot_f` is unwrapped with `rstsr_unwrap`. Any error it reports is
/// therefore presumably surfaced as a panic: device mismatch, an invalid axis specification,
/// or mismatched contracted dimensions.
pub fn vecdot<TA, TB, DA, DB, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
) -> Tensor<TA::Output, B, IxD>
where
    TA: Mul<TB>,
    DA: DimAPI,
    DB: DimAPI,
    B: DeviceVecdotAPI<TA, TB, TA::Output, DA, DB, IxD>
        + DeviceAPI<TA>
        + DeviceAPI<TB>
        + DeviceAPI<TA::Output>
        + DeviceCreationAnyAPI<TA::Output>,
{
    let result = vecdot_f(a, b, axes_pair);
    result.rstsr_unwrap()
}
/// Vector dot product of `a` and `b`, written into the existing output tensor `c`.
///
/// `axes_pair` selects the contracted axes in the same way as for `vecdot`. `None` means
/// the last axis of each operand. The shape of `c` must equal the broadcast of the
/// non-contracted dimensions of `a` and `b`.
///
/// This forwards directly to `vecdot_from_f` and, like it, returns a `Result` rather than
/// panicking.
///
/// # Errors
///
/// Propagates every error from `vecdot_from_f`:
/// - device mismatch;
/// - an invalid axis specification;
/// - mismatched contracted dimensions;
/// - an output shape incompatible with the broadcast result.
pub fn vecdot_from<TA, TB, TC, DA, DB, DC, B>(
    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
) -> Result<()>
where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    B: DeviceVecdotAPI<TA, TB, TC, DA, DB, DC>
        + DeviceAPI<TA>
        + DeviceAPI<TB>
        + DeviceAPI<TC>
        + DeviceAPI<MaybeUninit<TC>>,
{
    // Output tensor first, then the two operands, then the axis specification.
    vecdot_from_f(c, a, b, axes_pair)
}
/// Fallible vector dot product of `a` and `b` over the axes selected by `axes_pair`.
///
/// Axis selection:
/// - `None` behaves exactly like a single axis of `-1`, i.e. the last axis of each operand.
/// - `Val(axis)` with `axis < 0` must lie in `[-N, -1]`, where `N = min(a.ndim, b.ndim)`. It
///   counts from the end of each operand independently.
/// - `Val(axis)` with `axis >= 0` must lie in `[0, N)`. The same index is used in both
///   operands.
/// - `Pair(axes_a, axes_b)` normalizes each list against its own operand. Both lists must
///   then have the same length.
///
/// The contracted dimensions of `a` and `b` must match exactly. The remaining dimensions are
/// broadcast together to form the output shape. The output layout is:
/// - C-contiguous when `TensorIterOrder::default()` is `C`;
/// - F-contiguous when it is `F`;
/// - otherwise derived from both broadcast operand layouts.
///
/// # Errors
///
/// - `DeviceMismatch` if `a` and `b` are on different devices.
/// - `InvalidValue` if the axis is out of range or the axis lists differ in length.
/// - `InvalidLayout` if the contracted dimensions of `a` and `b` differ.
/// - Any error from axis normalization, broadcasting, allocation or the device kernel.
pub fn vecdot_f<TA, TB, DA, DB, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
) -> Result<Tensor<TA::Output, B, IxD>>
where
    TA: Mul<TB>,
    DA: DimAPI,
    DB: DimAPI,
    B: DeviceVecdotAPI<TA, TB, TA::Output, DA, DB, IxD>
        + DeviceAPI<TA>
        + DeviceAPI<TB>
        + DeviceAPI<TA::Output>
        + DeviceCreationAnyAPI<TA::Output>,
{
    let a = a.view();
    let b = b.view();

    // Both operands must share a device; the output is allocated on it.
    let device = a.device().clone();
    rstsr_assert!(device.same_device(b.device()), DeviceMismatch)?;

    // Turns one shared axis into one concrete axis index per operand.
    let resolve_single = |axis: isize| -> Result<(Vec<isize>, Vec<isize>)> {
        if axis >= 0 {
            rstsr_pattern!(
                axis,
                0..(a.ndim().min(b.ndim()) as isize),
                InvalidValue,
                "axis should be [0, N) where N is min(a.ndim, b.ndim)"
            )?;
            Ok((vec![axis], vec![axis]))
        } else {
            rstsr_pattern!(
                axis,
                -(a.ndim().min(b.ndim()) as isize)..=-1,
                InvalidValue,
                "axis should be [-N, -1] where N is min(a.ndim, b.ndim)"
            )?;
            // Negative axes are offset by each operand's own rank.
            Ok((vec![axis + a.ndim() as isize], vec![axis + b.ndim() as isize]))
        }
    };

    let (axes_a, axes_b) = match axes_pair.try_into().map_err(Into::into)? {
        // No explicit axis: contract over the last axis of each operand.
        AxesPairIndex::None => resolve_single(-1)?,
        AxesPairIndex::Val(axis) => resolve_single(axis)?,
        AxesPairIndex::Pair(axes_a, axes_b) => {
            let axes_a = normalize_axes_index(axes_a, a.ndim(), false, false)?;
            let axes_b = normalize_axes_index(axes_b, b.ndim(), false, false)?;
            rstsr_assert_eq!(
                axes_a.len(),
                axes_b.len(),
                InvalidValue,
                "axes_a and axes_b should have the same length"
            )?;
            (axes_a, axes_b)
        },
    };

    // Split each layout into its contracted part (`*s`) and its remaining part (`*m`).
    let (las, lam) = a.layout().dim_split_axes(&axes_a)?;
    let (lbs, lbm) = b.layout().dim_split_axes(&axes_b)?;
    rstsr_assert_eq!(
        las.shape(),
        lbs.shape(),
        InvalidLayout,
        "the dimensions of a and b along the contracted axis should be the same"
    )?;

    // The output spans the broadcast of the remaining dimensions.
    let order = device.default_order();
    let (lam_b, lbm_b) = broadcast_layout(&lam, &lbm, order)?;
    let layout_c = match TensorIterOrder::default() {
        TensorIterOrder::C => lam_b.shape().c(),
        TensorIterOrder::F => lam_b.shape().f(),
        _ => get_layout_for_binary_op(&lam_b, &lbm_b, order)?,
    };

    // Allocate uninitialized storage sized by the second component of `bounds_index`, then
    // let the device kernel fill it.
    let len_c = layout_c.bounds_index()?.1;
    let mut storage_c = device.uninit_impl(len_c)?;
    device.vecdot(storage_c.raw_mut(), &layout_c, a.raw(), a.layout(), b.raw(), b.layout(), &axes_a, &axes_b)?;

    // SAFETY (review): assumes `device.vecdot` wrote every element addressed by `layout_c`,
    // so the storage is fully initialized here. TODO: confirm against the kernel contract.
    unsafe { Tensor::new_f(B::assume_init_impl(storage_c)?, layout_c) }
}
/// Fallible vector dot product of `a` and `b`, written into the existing output tensor `c`.
///
/// Axis selection:
/// - `None` behaves exactly like a single axis of `-1`, i.e. the last axis of each operand.
/// - `Val(axis)` with `axis < 0` must lie in `[-N, -1]`, where `N = min(a.ndim, b.ndim)`. It
///   counts from the end of each operand independently.
/// - `Val(axis)` with `axis >= 0` must lie in `[0, N)`. The same index is used in both
///   operands.
/// - `Pair(axes_a, axes_b)` normalizes each list against its own operand. Both lists must
///   then have the same length.
///
/// The contracted dimensions of `a` and `b` must match exactly. The shape of `c` must equal
/// the broadcast of the remaining dimensions of `a` and `b`.
///
/// # Errors
///
/// - `DeviceMismatch` if `a` or `b` is on a different device than `c`.
/// - `InvalidValue` if the axis is out of range or the axis lists differ in length.
/// - `InvalidLayout` if the contracted dimensions differ or `c` has an incompatible shape.
/// - Any error from axis normalization, shape broadcasting or the device kernel.
pub fn vecdot_from_f<TA, TB, TC, DA, DB, DC, B>(
    mut c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
) -> Result<()>
where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    B: DeviceVecdotAPI<TA, TB, TC, DA, DB, DC>
        + DeviceAPI<TA>
        + DeviceAPI<TB>
        + DeviceAPI<TC>
        + DeviceAPI<MaybeUninit<TC>>,
{
    let a = a.view();
    let b = b.view();
    let mut c = c.view_mut();

    // All three tensors must live on the device of the output.
    let device = c.device().clone();
    rstsr_assert!(device.same_device(a.device()), DeviceMismatch)?;
    rstsr_assert!(device.same_device(b.device()), DeviceMismatch)?;

    // Turns one shared axis into one concrete axis index per operand.
    let resolve_single = |axis: isize| -> Result<(Vec<isize>, Vec<isize>)> {
        if axis >= 0 {
            rstsr_pattern!(
                axis,
                0..(a.ndim().min(b.ndim()) as isize),
                InvalidValue,
                "axis should be [0, N) where N is min(a.ndim, b.ndim)"
            )?;
            Ok((vec![axis], vec![axis]))
        } else {
            rstsr_pattern!(
                axis,
                -(a.ndim().min(b.ndim()) as isize)..=-1,
                InvalidValue,
                "axis should be [-N, -1] where N is min(a.ndim, b.ndim)"
            )?;
            // Negative axes are offset by each operand's own rank.
            Ok((vec![axis + a.ndim() as isize], vec![axis + b.ndim() as isize]))
        }
    };

    let (axes_a, axes_b) = match axes_pair.try_into().map_err(Into::into)? {
        // No explicit axis: contract over the last axis of each operand.
        AxesPairIndex::None => resolve_single(-1)?,
        AxesPairIndex::Val(axis) => resolve_single(axis)?,
        AxesPairIndex::Pair(axes_a, axes_b) => {
            let axes_a = normalize_axes_index(axes_a, a.ndim(), false, false)?;
            let axes_b = normalize_axes_index(axes_b, b.ndim(), false, false)?;
            rstsr_assert_eq!(
                axes_a.len(),
                axes_b.len(),
                InvalidValue,
                "axes_a and axes_b should have the same length"
            )?;
            (axes_a, axes_b)
        },
    };

    // Split each layout into its contracted part (`*s`) and its remaining part (`*m`).
    let (las, lam) = a.layout().dim_split_axes(&axes_a)?;
    let (lbs, lbm) = b.layout().dim_split_axes(&axes_b)?;
    rstsr_assert_eq!(
        las.shape(),
        lbs.shape(),
        InvalidLayout,
        "the dimensions of a and b along the contracted axis should be the same"
    )?;

    // The caller-provided output must already have the broadcast shape of the remaining dims.
    let shape_c_expect = broadcast_shapes_f(&[lam.shape().to_vec(), lbm.shape().to_vec()], device.default_order())?;
    let shape_c = c.shape();
    rstsr_assert_eq!(shape_c_expect, shape_c.as_ref(), InvalidLayout, "incompatible shapes in vecdot")?;

    let c_layout = c.layout().clone();
    // SAFETY (review): reinterprets the initialized storage of `c` as `MaybeUninit<TC>`
    // storage, because that is the kernel's output parameter type. This assumes
    // `<B as DeviceRawAPI<TC>>::Raw` and `<B as DeviceRawAPI<MaybeUninit<TC>>>::Raw` have
    // identical layout for every backend. It also relies on the kernel writing only
    // initialized values. TODO: confirm both.
    let c_raw_mut = unsafe {
        transmute::<&mut <B as DeviceRawAPI<TC>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<TC>>>::Raw>(c.raw_mut())
    };
    device.vecdot(c_raw_mut, &c_layout, a.raw(), a.layout(), b.raw(), b.layout(), &axes_a, &axes_b)
}
#[cfg(test)]
mod test {
    use rstsr::prelude::*;

    /// Exercises `rt::vecdot` with the default axis (`None`, i.e. the last axis) on a
    /// row-major serial CPU device.
    #[test]
    fn test_vecdot() {
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);

        // Equal-shaped integer operands: each output element is a row-wise dot product.
        // [0, 1, 2]·[6, 7, 8] = 23 and [3, 4, 5]·[9, 10, 11] = 122.
        {
            let a = rt::arange((6, &device)).into_shape((2, 3));
            let b = rt::arange((6, 12, &device)).into_shape((2, 3));
            let c = rt::vecdot(&a, &b, None);
            println!("Result c: {c}");
            let expected = rt::tensor_from_nested!([23, 122], &device);
            assert!(rt::allclose(&c, &expected, None));
        }

        // Matrix against a vector: `b` is broadcast across the rows of `a`.
        // Rows give 5*0.6 = 3, 10*0.8 = 8 and 6*0.6 + 8*0.8 = 10.
        {
            let a = rt::tensor_from_nested!([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.]], &device);
            let b = rt::tensor_from_nested!([0., 0.6, 0.8], &device);
            let c = rt::vecdot(&a, &b, None);
            println!("Result c: {c}");
            let expected = rt::tensor_from_nested!([3., 8., 10.], &device);
            assert!(rt::allclose(&c, &expected, None));
        }
    }
}