#pragma once
#include <cute/config.hpp>
#include <cute/tensor_impl.hpp>
#include <cute/tensor_predicate.hpp>
namespace cute
{
//
// axpby, overload taking the destination tensor `y` as an rvalue
// (e.g. a temporary tensor expression passed directly at the call site).
//
// Inside this function the named parameter `y` is an lvalue, so the call
// below binds to the `Tensor<YEngine, YLayout>&` overload of axpby, which
// performs the actual element-wise update. All other arguments are passed
// through unchanged.
//
template <class Alpha,
          class XEngine, class XLayout,
          class Beta,
          class YEngine, class YLayout,
          class PrdTensor = TrivialPredTensor>
CUTE_HOST_DEVICE
void
axpby(Alpha                    const& alpha,
      Tensor<XEngine, XLayout> const& x,
      Beta                     const& beta,
      Tensor<YEngine, YLayout>     && y,
      PrdTensor                const& p = {})
{
  // `y` is an lvalue here: dispatches to the lvalue-reference overload.
  axpby(alpha, x, beta, y, p);
}
//
// Predicated element-wise update  y(i) = alpha * x(i) + beta * y(i).
//
// For every linear index i in [0, size(x)) for which p(i) evaluates to true,
// y(i) is overwritten; elements whose predicate is false are left untouched.
//
// When beta compares equal to zero (for complex Beta: both real and imaginary
// parts are zero), the old value of y(i) is not read at all and the result is
// simply alpha * x(i).
// NOTE(review): presumably this keeps stale or non-finite contents of y from
// leaking into the result when beta == 0 -- confirm intent.
//
// Parameters:
//   alpha  scale applied to each element of x
//   x      source tensor; its size determines the iteration count
//   beta   scale applied to each existing element of y
//   y      destination tensor, read (unless beta is zero) and written in place
//          (assumes y is indexable for every i < size(x) -- TODO confirm)
//   p      predicate, queried as p(i); defaults to a default-constructed
//          TrivialPredTensor (presumably selecting every element)
//
template <class Alpha,
          class XEngine, class XLayout,
          class Beta,
          class YEngine, class YLayout,
          class PrdTensor = TrivialPredTensor>
CUTE_HOST_DEVICE
void
axpby(Alpha                    const& alpha,
      Tensor<XEngine, XLayout> const& x,
      Beta                     const& beta,
      Tensor<YEngine, YLayout>      & y,
      PrdTensor                const& p = {})
{
  // Evaluate the "beta is zero" test once, outside the loop. The result type
  // is whatever the comparison yields, so it is left as `auto`.
  auto beta_is_zero = [&] () {
    if constexpr (!is_complex<Beta>::value) {
      return beta == Int<0>{};
    }
    else {
      return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
    }

    // Both branches above return; this point is never reached.
    CUTE_GCC_UNREACHABLE;
  } ();

  CUTE_UNROLL
  for (int idx = 0; idx < size(x); ++idx) {
    if (p(idx)) {
      // Only the selected arm is evaluated, so y(idx) is not read when
      // beta is zero.
      y(idx) = beta_is_zero ? alpha * x(idx)
                            : alpha * x(idx) + beta * y(idx);
    }
  }
}
}