#ifndef KOKKOS_COMPLEX_HPP
#define KOKKOS_COMPLEX_HPP
#include <Kokkos_Atomic.hpp>
#include <Kokkos_NumericTraits.hpp>
#include <cmath>
#include <complex>
#include <iostream>
namespace Kokkos {
/// \class complex
/// \brief Complex number whose real and imaginary parts have type RealType.
///
/// Stores the two parts by value and provides constructors, assignment,
/// accessors, compound arithmetic and comparison operators.  Most members
/// are marked KOKKOS_INLINE_FUNCTION; the converting constructor, the
/// conversion operator and the assignment operator that involve
/// std::complex are not.  Many members also come in volatile-qualified
/// overloads (NOTE(review): presumably so that volatile reduction results
/// can be updated -- confirm against the callers).
///
/// Members that take a complex of a different real type only go through
/// the public real() / imag() accessors of their argument, because the
/// private members of another specialization are not accessible here.
///
/// \tparam RealType Type of the real and imaginary parts.
template<class RealType>
class complex {
private:
  RealType re_, im_;

public:
  //! The type of the real or imaginary parts of this complex number.
  typedef RealType value_type;

  //! Default constructor: both parts are initialized from 0.0.
  KOKKOS_INLINE_FUNCTION complex () :
    re_ (0.0), im_ (0.0)
  {}

  //! Copy constructor.
  KOKKOS_INLINE_FUNCTION complex (const complex<RealType>& src) :
    re_ (src.re_), im_ (src.im_)
  {}

  //! Copy constructor from a volatile source.
  KOKKOS_INLINE_FUNCTION complex (const volatile complex<RealType>& src) :
    re_ (src.re_), im_ (src.im_)
  {}

  /// \brief Conversion constructor from std::complex.
  ///
  /// Not marked KOKKOS_INLINE_FUNCTION.
  template<class InputRealType>
  complex (const std::complex<InputRealType>& src) :
    re_ (std::real (src)), im_ (std::imag (src))
  {}

  /// \brief Conversion operator to std::complex.
  ///
  /// Not marked KOKKOS_INLINE_FUNCTION.
  operator std::complex<RealType> () const {
    return std::complex<RealType> (re_, im_);
  }

  /// \brief Constructor that takes just the real part; the imaginary part
  ///   is set to zero (as a RealType value).
  template<class InputRealType>
  KOKKOS_INLINE_FUNCTION complex (const InputRealType& val) :
    re_ (val), im_ (static_cast<RealType> (0.0))
  {}

  //! Constructor that takes the real and imaginary parts.
  KOKKOS_INLINE_FUNCTION complex( const RealType& re, const RealType& im):
    re_ (re), im_ (im)
  {}

  //! Constructor that takes real and imaginary parts of (possibly) other types.
  template<class RealType1, class RealType2>
  KOKKOS_INLINE_FUNCTION complex (const RealType1& re, const RealType2& im) :
    re_ (re), im_ (im)
  {}

  //! Assignment from a complex with a (possibly) different real type.
  template<class InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>& operator= (const complex<InputRealType>& src) {
    re_ = src.real ();
    im_ = src.imag ();
    return *this;
  }

  /// \brief Assignment to a volatile target from a non-volatile source.
  ///
  /// Returns void, unlike the non-volatile overload.
  template<class InputRealType>
  KOKKOS_INLINE_FUNCTION
  void operator= (const complex<InputRealType>& src) volatile {
    re_ = src.real ();
    im_ = src.imag ();
  }

  //! Assignment to a volatile target from a volatile source.
  template<class InputRealType>
  KOKKOS_INLINE_FUNCTION
  volatile complex<RealType>& operator= (const volatile complex<InputRealType>& src) volatile {
    re_ = src.real ();
    im_ = src.imag ();
    return *this;
  }

  //! Assignment to a non-volatile target from a volatile source.
  template<class InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>& operator= (const volatile complex<InputRealType>& src) {
    re_ = src.real ();
    im_ = src.imag ();
    return *this;
  }

  //! Assignment from a scalar: sets the real part, zeroes the imaginary part.
  template<class InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>& operator= (const InputRealType& val) {
    re_ = val;
    im_ = static_cast<RealType> (0.0);
    return *this;
  }

  //! Assignment from a scalar to a volatile target (returns void).
  template<class InputRealType>
  KOKKOS_INLINE_FUNCTION
  void operator= (const InputRealType& val) volatile {
    re_ = val;
    im_ = static_cast<RealType> (0.0);
  }

  /// \brief Assignment from std::complex.
  ///
  /// Not marked KOKKOS_INLINE_FUNCTION.
  template<class InputRealType>
  complex<RealType>& operator= (const std::complex<InputRealType>& src) {
    re_ = std::real (src);
    im_ = std::imag (src);
    return *this;
  }

  //! Mutable reference to the imaginary part.
  KOKKOS_INLINE_FUNCTION RealType& imag () {
    return im_;
  }

  //! Mutable reference to the real part.
  KOKKOS_INLINE_FUNCTION RealType& real () {
    return re_;
  }

  //! The imaginary part, by value.
  KOKKOS_INLINE_FUNCTION const RealType imag () const {
    return im_;
  }

  //! The real part, by value.
  KOKKOS_INLINE_FUNCTION const RealType real () const {
    return re_;
  }

  //! Volatile reference to the imaginary part of a volatile object.
  KOKKOS_INLINE_FUNCTION volatile RealType& imag () volatile {
    return im_;
  }

  //! Volatile reference to the real part of a volatile object.
  KOKKOS_INLINE_FUNCTION volatile RealType& real () volatile {
    return re_;
  }

  //! The imaginary part of a const volatile object, by value.
  KOKKOS_INLINE_FUNCTION const RealType imag () const volatile {
    return im_;
  }

  //! The real part of a const volatile object, by value.
  KOKKOS_INLINE_FUNCTION const RealType real () const volatile {
    return re_;
  }

  //! Set the imaginary part to \c v.
  KOKKOS_INLINE_FUNCTION void imag (RealType v) {
    im_ = v;
  }

  //! Set the real part to \c v.
  KOKKOS_INLINE_FUNCTION void real (RealType v) {
    re_ = v;
  }

  //! Componentwise in-place addition of another complex.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator += (const complex<InputRealType>& src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    re_ += src.real ();
    im_ += src.imag ();
    return *this;
  }

  //! Componentwise in-place addition, volatile target and source (returns void).
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  void
  operator += (const volatile complex<InputRealType>& src) volatile {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    re_ += src.real ();
    im_ += src.imag ();
  }

  //! Componentwise in-place addition of a std::complex.
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator += (const std::complex<RealType>& src) {
    re_ += src.real();
    im_ += src.imag();
    return *this;
  }

  //! In-place addition of a scalar: only the real part changes.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator += (const InputRealType& src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    re_ += src;
    return *this;
  }

  //! In-place addition of a volatile scalar to a volatile target (returns void).
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  void
  operator += (const volatile InputRealType& src) volatile {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    re_ += src;
  }

  //! Componentwise in-place subtraction of another complex.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator -= (const complex<InputRealType>& src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    re_ -= src.real ();
    im_ -= src.imag ();
    return *this;
  }

  //! Componentwise in-place subtraction of a std::complex.
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator -= (const std::complex<RealType>& src) {
    re_ -= src.real();
    im_ -= src.imag();
    return *this;
  }

  //! In-place subtraction of a scalar: only the real part changes.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator -= (const InputRealType& src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    re_ -= src;
    return *this;
  }

  /// \brief In-place complex multiplication.
  ///
  /// Both products are computed into temporaries first so that the
  /// imaginary part is formed from the old real part.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator *= (const complex<InputRealType>& src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    const RealType realPart = re_ * src.real () - im_ * src.imag ();
    const RealType imagPart = re_ * src.imag () + im_ * src.real ();
    re_ = realPart;
    im_ = imagPart;
    return *this;
  }

  //! In-place complex multiplication, volatile target and source (returns void).
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  void
  operator *= (const volatile complex<InputRealType>& src) volatile {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    const RealType realPart = re_ * src.real () - im_ * src.imag ();
    const RealType imagPart = re_ * src.imag () + im_ * src.real ();
    re_ = realPart;
    im_ = imagPart;
  }

  //! In-place complex multiplication by a std::complex.
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator *= (const std::complex<RealType>& src) {
    const RealType realPart = re_ * src.real() - im_ * src.imag();
    const RealType imagPart = re_ * src.imag() + im_ * src.real();
    re_ = realPart;
    im_ = imagPart;
    return *this;
  }

  //! In-place scaling of both parts by a scalar.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator *= (const InputRealType& src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    re_ *= src;
    im_ *= src;
    return *this;
  }

  //! In-place scaling by a volatile scalar, volatile target (returns void).
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  void
  operator *= (const volatile InputRealType& src) volatile {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    re_ *= src;
    im_ *= src;
  }

  /// \brief In-place complex division by \c y.
  ///
  /// Numerator and denominator are both divided by
  /// s = |real(y)| + |imag(y)| before the quotient is formed as
  /// (x/s) * conj(y/s) / |y/s|^2.  When s is zero, both parts of *this are
  /// divided by s directly.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator /= (const complex<InputRealType>& y) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    const RealType s = std::fabs (y.real ()) + std::fabs (y.imag ());
    if (s == 0.0) {
      this->re_ /= s;
      this->im_ /= s;
    }
    else {
      const complex<RealType> x_scaled (this->re_ / s, this->im_ / s);
      const complex<RealType> y_conj_scaled (y.real () / s, -(y.imag ()) / s);
      const RealType y_scaled_abs = y_conj_scaled.re_ * y_conj_scaled.re_ +
        y_conj_scaled.im_ * y_conj_scaled.im_; // squared magnitude of y/s
      *this = x_scaled * y_conj_scaled;
      *this /= y_scaled_abs;
    }
    return *this;
  }

  /// \brief In-place complex division by a std::complex.
  ///
  /// Same scaled algorithm as the Kokkos::complex overload; the divisor is
  /// read through std::complex's real() / imag() accessors.
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator /= (const std::complex<RealType>& y) {
    const RealType s = std::fabs (y.real ()) + std::fabs (y.imag ());
    if (s == 0.0) {
      this->re_ /= s;
      this->im_ /= s;
    }
    else {
      const complex<RealType> x_scaled (this->re_ / s, this->im_ / s);
      const complex<RealType> y_conj_scaled (y.real () / s, -(y.imag ()) / s);
      const RealType y_scaled_abs = y_conj_scaled.re_ * y_conj_scaled.re_ +
        y_conj_scaled.im_ * y_conj_scaled.im_; // squared magnitude of y/s
      *this = x_scaled * y_conj_scaled;
      *this /= y_scaled_abs;
    }
    return *this;
  }

  //! In-place division of both parts by a scalar.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  complex<RealType>&
  operator /= (const InputRealType& src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    re_ /= src;
    im_ /= src;
    return *this;
  }

  // NOTE(review): the member comparison operators below are not
  // const-qualified.  Adding const may change how they rank against the
  // free operator== / operator!= templates in overload resolution, so
  // confirm before altering the qualification.

  //! Equality with another complex; the argument's parts are cast to RealType.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  bool
  operator == (const complex<InputRealType>& src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    return (re_ == static_cast<RealType>(src.real ())) && (im_ == static_cast<RealType>(src.imag ()));
  }

  //! Equality with a std::complex.
  KOKKOS_INLINE_FUNCTION
  bool
  operator == (const std::complex<RealType>& src) {
    return (re_ == src.real()) && (im_ == src.imag());
  }

  //! Equality with a scalar: true when real part matches and imaginary part is zero.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  bool
  operator == (const InputRealType src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    return (re_ == static_cast<RealType>(src)) && (im_ == RealType(0));
  }

  //! Inequality with another complex; the argument's parts are cast to RealType.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  bool
  operator != (const complex<InputRealType>& src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    return (re_ != static_cast<RealType>(src.real ())) || (im_ != static_cast<RealType>(src.imag ()));
  }

  //! Inequality with a std::complex.
  KOKKOS_INLINE_FUNCTION
  bool
  operator != (const std::complex<RealType>& src) {
    return (re_ != src.real()) || (im_ != src.imag());
  }

  //! Inequality with a scalar: true when real part differs or imaginary part is nonzero.
  template<typename InputRealType>
  KOKKOS_INLINE_FUNCTION
  bool
  operator != (const InputRealType src) {
    static_assert(std::is_convertible<InputRealType,RealType>::value,
                  "InputRealType must be convertible to RealType");
    return (re_ != static_cast<RealType>(src)) || (im_ != RealType(0));
  }
};
//! Binary + for two complex numbers; the result uses the common real type.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator + (const complex<RealType1>& x, const complex<RealType2>& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> sum_type;
  return sum_type (x.real () + y.real (),
                   x.imag () + y.imag ());
}
//! Binary + for complex + scalar: the scalar is added to the real part only.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator + (const complex<RealType1>& x, const RealType2& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> sum_type;
  return sum_type (x.real () + y, x.imag ());
}
//! Binary + for scalar + complex: the scalar is added to the real part only.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator + (const RealType1& x, const complex<RealType2>& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> sum_type;
  return sum_type (x + y.real (), y.imag ());
}
//! Unary +: returns a copy of the argument.
template<class RealType>
KOKKOS_INLINE_FUNCTION
complex<RealType>
operator + (const complex<RealType>& x) {
  return complex<RealType> (x);
}
//! Binary - for two complex numbers; the result uses the common real type.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator - (const complex<RealType1>& x, const complex<RealType2>& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> diff_type;
  return diff_type (x.real () - y.real (),
                    x.imag () - y.imag ());
}
//! Binary - for complex - scalar: the scalar is subtracted from the real part only.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator - (const complex<RealType1>& x, const RealType2& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> diff_type;
  return diff_type (x.real () - y, x.imag ());
}
//! Binary - for scalar - complex: real part is x - real(y), imaginary part is -imag(y).
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator - (const RealType1& x, const complex<RealType2>& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> diff_type;
  return diff_type (x - y.real (), - y.imag ());
}
//! Unary -: negates both the real and the imaginary part.
template<class RealType>
KOKKOS_INLINE_FUNCTION
complex<RealType>
operator - (const complex<RealType>& x) {
  typedef complex<RealType> negated_type;
  return negated_type (-x.real (), -x.imag ());
}
/// \brief Binary * for two complex numbers.
///
/// Real part: re(x) re(y) - im(x) im(y); imaginary part:
/// re(x) im(y) + im(x) re(y).  The result uses the common real type.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator * (const complex<RealType1>& x, const complex<RealType2>& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> product_type;
  return product_type (x.real () * y.real () - x.imag () * y.imag (),
                       x.real () * y.imag () + x.imag () * y.real ());
}
/// \brief Binary * for std::complex times Kokkos::complex.
///
/// Declared plain \c inline rather than KOKKOS_INLINE_FUNCTION.  The
/// result is a Kokkos::complex over the common real type.
template<class RealType1, class RealType2>
inline
complex<typename std::common_type<RealType1,RealType2>::type>
operator * (const std::complex<RealType1>& x, const complex<RealType2>& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> product_type;
  return product_type (x.real () * y.real () - x.imag () * y.imag (),
                       x.real () * y.imag () + x.imag () * y.real ());
}
//! Binary * for scalar times complex: both parts are scaled by x.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator * (const RealType1& x, const complex<RealType2>& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> product_type;
  return product_type (x * y.real (),
                       x * y.imag ());
}
/// \brief Binary * for complex times scalar: both parts are scaled by x.
///
/// Note the parameter naming: the complex operand is \c y and the scalar
/// is \c x.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator * (const complex<RealType1>& y, const RealType2& x) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> product_type;
  return product_type (x * y.real (),
                       x * y.imag ());
}
//! Imaginary part of a complex number (free-function form of x.imag()).
template<class RealType>
KOKKOS_INLINE_FUNCTION
RealType imag (const complex<RealType>& x) {
  const RealType imaginary_part = x.imag ();
  return imaginary_part;
}
//! Real part of a complex number (free-function form of x.real()).
template<class RealType>
KOKKOS_INLINE_FUNCTION
RealType real (const complex<RealType>& x) {
  const RealType real_part = x.real ();
  return real_part;
}
/// \brief Absolute value (magnitude) of a complex number, via hypot.
///
/// std::hypot is only pulled in with a using-declaration when
/// __CUDA_ARCH__ is not defined; otherwise the unqualified \c hypot call
/// resolves to whatever is visible at that point.
template<class RealType>
KOKKOS_INLINE_FUNCTION
RealType abs (const complex<RealType>& x) {
#ifndef __CUDA_ARCH__
  using std::hypot;
#endif
  const RealType re_part = x.real ();
  const RealType im_part = x.imag ();
  return hypot (re_part, im_part);
}
/// \brief Raise a complex number to a real power.
///
/// Uses the polar form: with r = |x| and phi = arg(x), the result is
/// r^e * (cos(phi*e) + i sin(phi*e)).
///
/// The argument is computed with atan2(imag, real) so that it lands in
/// the correct quadrant.  atan(imag/real) would fold the left half-plane
/// onto the right one and would divide by zero for purely imaginary
/// input.
///
/// \param x Base.
/// \param e Real exponent.
template<class RealType>
KOKKOS_INLINE_FUNCTION
Kokkos::complex<RealType> pow (const complex<RealType>& x, const RealType& e) {
  RealType r = abs(x);
  RealType phi = std::atan2(x.imag(), x.real());
  return std::pow(r,e) * Kokkos::complex<RealType>(std::cos(phi*e),std::sin(phi*e));
}
/// \brief Square root of a complex number.
///
/// Uses the polar form: with r = |x| and phi = arg(x), the result is
/// sqrt(r) * (cos(phi/2) + i sin(phi/2)).
///
/// The argument is computed with atan2(imag, real), which returns a value
/// in [-pi, pi], so phi/2 lies in the right half-plane.  atan(imag/real)
/// would give the wrong angle whenever the real part is negative (for
/// example sqrt(-1) would come out as 1 instead of i).
template<class RealType>
KOKKOS_INLINE_FUNCTION
Kokkos::complex<RealType> sqrt (const complex<RealType>& x) {
  RealType r = abs(x);
  RealType phi = std::atan2(x.imag(), x.real());
  return std::sqrt(r) * Kokkos::complex<RealType>(std::cos(phi*0.5),std::sin(phi*0.5));
}
//! Complex conjugate: same real part, negated imaginary part.
template<class RealType>
KOKKOS_INLINE_FUNCTION
complex<RealType> conj (const complex<RealType>& x) {
  typedef complex<RealType> conjugate_type;
  return conjugate_type (x.real (), -x.imag ());
}
/// \brief Complex exponential.
///
/// Computed as exp(real(x)) * (cos(imag(x)) + i sin(imag(x))).
template<class RealType>
KOKKOS_INLINE_FUNCTION
complex<RealType> exp (const complex<RealType>& x) {
  const complex<RealType> unit_phasor (std::cos (x.imag ()), std::sin (x.imag ()));
  return std::exp (x.real ()) * unit_phasor;
}
/// \brief Complex exponential of a std::complex, returned as Kokkos::complex.
///
/// Declared plain \c inline rather than KOKKOS_INLINE_FUNCTION.  Real
/// part: exp(re) cos(im); imaginary part: exp(re) sin(im).
template<class RealType>
inline
complex<RealType>
exp (const std::complex<RealType>& c) {
  typedef complex<RealType> result_type;
  return result_type (std::exp (c.real ()) * std::cos (c.imag ()),
                      std::exp (c.real ()) * std::sin (c.imag ()));
}
//! Binary / for complex divided by scalar: both parts are divided by y.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator / (const complex<RealType1>& x, const RealType2& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> quotient_type;
  return quotient_type (x.real () / y,
                        x.imag () / y);
}
/// \brief Binary / for two complex numbers.
///
/// Numerator and denominator are both divided by
/// s = |real(y)| + |imag(y)| before the quotient is formed as
/// (x/s) * conj(y/s) / |y/s|^2.  When s is zero, the parts of x are
/// divided by s directly.
///
/// All intermediates, including the squared magnitude of the scaled
/// denominator, are held in the common real type of RealType1 and
/// RealType2, so a narrower RealType1 does not truncate the divisor.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator / (const complex<RealType1>& x, const complex<RealType2>& y) {
  typedef typename std::common_type<RealType1,RealType2>::type common_real_type;
  const common_real_type s = std::fabs (real (y)) + std::fabs (imag (y));
  if (s == 0.0) {
    return complex<common_real_type> (real (x) / s, imag (x) / s);
  }
  else {
    const complex<common_real_type> x_scaled (real (x) / s, imag (x) / s);
    const complex<common_real_type> y_conj_scaled (real (y) / s, -imag (y) / s);
    const common_real_type y_scaled_abs = real (y_conj_scaled) * real (y_conj_scaled) +
      imag (y_conj_scaled) * imag (y_conj_scaled); // squared magnitude of y/s
    complex<common_real_type> result = x_scaled * y_conj_scaled;
    result /= y_scaled_abs;
    return result;
  }
}
/// \brief Binary / for scalar divided by complex.
///
/// The scalar is first converted to a complex over the common real type,
/// then divided by \c y.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
complex<typename std::common_type<RealType1,RealType2>::type>
operator / (const RealType1& x, const complex<RealType2>& y) {
  typedef complex<typename std::common_type<RealType1,RealType2>::type> quotient_type;
  const quotient_type numerator (x);
  return numerator / y;
}
//! Equality of two complex numbers, compared part by part in the common real type.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
bool
operator == (const complex<RealType1>& x, const complex<RealType2>& y) {
  typedef typename std::common_type<RealType1,RealType2>::type common_real_type;
  if (! (static_cast<common_real_type>(real (x)) == static_cast<common_real_type>(real (y)))) {
    return false;
  }
  return static_cast<common_real_type>(imag (x)) == static_cast<common_real_type>(imag (y));
}
/// \brief Equality of a std::complex and a Kokkos::complex.
///
/// Declared plain \c inline rather than KOKKOS_INLINE_FUNCTION.  The
/// parts are compared in the common real type.
template<class RealType1, class RealType2>
inline
bool
operator == (const std::complex<RealType1>& x, const complex<RealType2>& y) {
  typedef typename std::common_type<RealType1,RealType2>::type common_real_type;
  if (! (static_cast<common_real_type>(std::real (x)) == static_cast<common_real_type>(real (y)))) {
    return false;
  }
  return static_cast<common_real_type>(std::imag (x)) == static_cast<common_real_type>(imag (y));
}
//! Equality of complex and scalar: real part equals y and imaginary part equals zero.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
bool
operator == (const complex<RealType1>& x, const RealType2& y) {
  typedef typename std::common_type<RealType1,RealType2>::type common_real_type;
  if (! (static_cast<common_real_type>(real (x)) == static_cast<common_real_type>(y))) {
    return false;
  }
  return static_cast<common_real_type>(imag (x)) == static_cast<common_real_type>(0.0);
}
//! Equality of scalar and complex; forwards to the (complex, scalar) comparison.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
bool
operator == (const RealType1& x, const complex<RealType2>& y) {
  const bool are_equal = (y == x);
  return are_equal;
}
//! Inequality of two complex numbers, compared part by part in the common real type.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
bool
operator != (const complex<RealType1>& x, const complex<RealType2>& y) {
  typedef typename std::common_type<RealType1,RealType2>::type common_real_type;
  if (static_cast<common_real_type>(real (x)) != static_cast<common_real_type>(real (y))) {
    return true;
  }
  return static_cast<common_real_type>(imag (x)) != static_cast<common_real_type>(imag (y));
}
/// \brief Inequality of a std::complex and a Kokkos::complex.
///
/// Declared plain \c inline rather than KOKKOS_INLINE_FUNCTION.  The
/// parts are compared in the common real type.
template<class RealType1, class RealType2>
inline
bool
operator != (const std::complex<RealType1>& x, const complex<RealType2>& y) {
  typedef typename std::common_type<RealType1,RealType2>::type common_real_type;
  if (static_cast<common_real_type>(std::real (x)) != static_cast<common_real_type>(real (y))) {
    return true;
  }
  return static_cast<common_real_type>(std::imag (x)) != static_cast<common_real_type>(imag (y));
}
//! Inequality of complex and scalar: real part differs from y or imaginary part is nonzero.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
bool
operator != (const complex<RealType1>& x, const RealType2& y) {
  typedef typename std::common_type<RealType1,RealType2>::type common_real_type;
  if (static_cast<common_real_type>(real (x)) != static_cast<common_real_type>(y)) {
    return true;
  }
  return static_cast<common_real_type>(imag (x)) != static_cast<common_real_type>(0.0);
}
//! Inequality of scalar and complex; forwards to the (complex, scalar) comparison.
template<class RealType1, class RealType2>
KOKKOS_INLINE_FUNCTION
bool
operator != (const RealType1& x, const complex<RealType2>& y) {
  const bool are_different = (y != x);
  return are_different;
}
template<class RealType>
std::ostream& operator << (std::ostream& os, const complex<RealType>& x) {
const std::complex<RealType> x_std (Kokkos::real (x), Kokkos::imag (x));
os << x_std;
return os;
}
/// \brief Stream extraction for Kokkos::complex.
///
/// Reads a std::complex<RealType> from the input stream and assigns it to
/// \c x, so the accepted input format is whatever std::complex's
/// operator>> accepts.
///
/// The stream type is std::istream: extraction is not defined for
/// std::ostream, so a signature taking std::ostream could never be
/// instantiated.
///
/// \param is Input stream to read from.
/// \param x  Destination for the value read.
/// \return \c is, to allow chaining.
template<class RealType>
std::istream& operator >> (std::istream& is, complex<RealType>& x) {
  std::complex<RealType> x_std;
  is >> x_std;
  x = x_std; // uses complex's assignment from std::complex
  return is;
}
/// \brief Specialization of reduction_identity for Kokkos::complex<T>.
///
/// Each identity is assembled from the identities of the scalar type T.
template<class T>
struct reduction_identity<Kokkos::complex<T> > {
  //! Reduction identities of the underlying scalar type.
  typedef reduction_identity<T> t_red_ident;

  //! Identity for sums: both parts are the scalar sum identity.
  KOKKOS_FORCEINLINE_FUNCTION constexpr static Kokkos::complex<T> sum ()
  {
    return Kokkos::complex<T> (t_red_ident::sum (), t_red_ident::sum ());
  }

  //! Identity for products: real part is the scalar product identity,
  //! imaginary part is the scalar sum identity.
  KOKKOS_FORCEINLINE_FUNCTION constexpr static Kokkos::complex<T> prod ()
  {
    return Kokkos::complex<T> (t_red_ident::prod (), t_red_ident::sum ());
  }
};
}
#endif