#pragma once
#include <cute/config.hpp>
#include <cute/numeric/numeric_types.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/util/type_traits.hpp>
namespace cute
{
namespace detail {

// iter_ref<Iter>::type -- the reference type obtained by dereferencing Iter.
// Fallback case: Iter has no nested `reference` typedef, so use the type of
// the expression `*it` for an lvalue `it` of type Iter.
template <class Iter, class Enable = void>
struct iter_ref
{
  using type = decltype(*declval<Iter&>());
};

// Preferred case: Iter publishes its own `reference` typedef; use it verbatim.
template <class Iter>
struct iter_ref<Iter, void_t<typename Iter::reference>>
{
  using type = typename Iter::reference;
};

} // end namespace detail
template <class T>
using iter_reference = detail::iter_ref<T>;
template <class T>
using iter_reference_t = typename iter_reference<T>::type;
namespace detail {

// iter_e<Iter>::type -- the element type an iterator refers to.
// Fallback case: strip the reference from iter_ref<Iter>::type.
// cv-qualifiers are kept, so e.g. `const int&` yields `const int`.
template <class Iter, class Enable = void>
struct iter_e
{
  using type = remove_reference_t<typename iter_ref<Iter>::type>;
};

// Preferred case: Iter publishes its own `element_type` typedef.
template <class Iter>
struct iter_e<Iter, void_t<typename Iter::element_type>>
{
  using type = typename Iter::element_type;
};

} // end namespace detail
template <class T>
using iter_element = detail::iter_e<T>;
template <class T>
using iter_element_t = typename iter_element<T>::type;
namespace detail {

// iter_v<Iter>::type -- the value type of an iterator.
// Fallback case: the element type with top-level cv-qualifiers removed.
template <class Iter, class Enable = void>
struct iter_v
{
  using type = remove_cv_t<typename iter_e<Iter>::type>;
};

// Preferred case: Iter publishes its own `value_type` typedef.
template <class Iter>
struct iter_v<Iter, void_t<typename Iter::value_type>>
{
  using type = typename Iter::value_type;
};

} // end namespace detail
template <class T>
using iter_value = detail::iter_v<T>;
template <class T>
using iter_value_t = typename iter_value<T>::type;
// Collects the three derived iterator properties (reference, element and
// value type) under std-like member names.
template <class It>
struct iterator_traits
{
  using value_type   = iter_value_t<It>;
  using element_type = iter_element_t<It>;
  using reference    = iter_reference_t<It>;
};
namespace detail {

// Primary template: treat Iter as not dereferenceable.
template <class Iter, class Enable = void>
struct has_dereference : CUTE_STL_NAMESPACE::false_type {};

// Selected whenever `*it` is well-formed for an lvalue `it` of type Iter.
template <class Iter>
struct has_dereference<Iter, void_t<decltype(*declval<Iter&>())>>
  : CUTE_STL_NAMESPACE::true_type
{};

} // end namespace detail
// has_dereference<Iter>::value is true exactly when `*it` compiles for an
// lvalue `it` of type Iter.
template <class Iter>
using has_dereference = detail::has_dereference<Iter>;
// Base case of the raw_pointer_cast overload set: a raw pointer is already
// raw, so it is handed back unchanged.
template <class T>
CUTE_HOST_DEVICE constexpr
T* raw_pointer_cast(T* ptr)
{
  return ptr;
}
// A bare T* carries no static alignment information here, so the compile-time
// constant Int<0> is reported (presumably meaning "unknown" -- confirm with callers).
template <class T>
CUTE_HOST_DEVICE constexpr
Int<0> max_alignment(T*)
{
  return Int<0>{};
}
// CRTP base for iterator wrappers. It stores an underlying iterator `ptr_` and
// forwards dereference, indexing, offsetting and comparison to it. Offsetting
// returns DerivedType, so the wrapper type is preserved through arithmetic.
template <class Iterator, class DerivedType>
struct iter_adaptor
{
  using iterator     = Iterator;
  using reference    = typename iterator_traits<iterator>::reference;
  using element_type = typename iterator_traits<iterator>::element_type;
  using value_type   = typename iterator_traits<iterator>::value_type;

  // The wrapped iterator. It is public so free functions can reach it directly.
  iterator ptr_;

  CUTE_HOST_DEVICE constexpr
  iter_adaptor(iterator ptr = {})
    : ptr_(ptr)
  {}

  // Copy of the wrapped iterator.
  CUTE_HOST_DEVICE constexpr
  iterator get() const { return ptr_; }

  CUTE_HOST_DEVICE constexpr
  reference operator*() const { return *ptr_; }

  template <class Index>
  CUTE_HOST_DEVICE constexpr
  reference operator[](Index const& idx) const { return ptr_[idx]; }

  // Brace-initializes a DerivedType from the advanced underlying iterator.
  template <class Index>
  CUTE_HOST_DEVICE constexpr
  DerivedType operator+(Index const& idx) const { return {ptr_ + idx}; }

  // Hidden-friend comparisons on DerivedType: each one compares the wrapped iterators.
  CUTE_HOST_DEVICE constexpr
  friend bool operator==(DerivedType const& lhs, DerivedType const& rhs) { return lhs.ptr_ == rhs.ptr_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator!=(DerivedType const& lhs, DerivedType const& rhs) { return lhs.ptr_ != rhs.ptr_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator< (DerivedType const& lhs, DerivedType const& rhs) { return lhs.ptr_ <  rhs.ptr_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator<=(DerivedType const& lhs, DerivedType const& rhs) { return lhs.ptr_ <= rhs.ptr_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator> (DerivedType const& lhs, DerivedType const& rhs) { return lhs.ptr_ >  rhs.ptr_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator>=(DerivedType const& lhs, DerivedType const& rhs) { return lhs.ptr_ >= rhs.ptr_; }
};
// Unwraps an iter_adaptor by applying raw_pointer_cast to its stored iterator.
template <class I, class D>
CUTE_HOST_DEVICE constexpr
auto raw_pointer_cast(iter_adaptor<I,D> const& x)
{
  return raw_pointer_cast(x.ptr_);
}
// An iter_adaptor's alignment is whatever max_alignment reports for the
// iterator it wraps.
template <class I, class D>
CUTE_HOST_DEVICE constexpr
auto max_alignment(iter_adaptor<I,D> const& x)
{
  return max_alignment(x.ptr_);
}
// Iterator over the sequence n, n+1, n+2, ... of index_type values.
// Dereferencing yields the current count by value (reference == T), so the
// iterator refers to no storage. T defaults to int.
template <class T = int>
struct counting_iterator
{
  using index_type = T;
  using value_type = T;
  using reference  = T;   // operator* returns by value

  index_type n_;

  CUTE_HOST_DEVICE constexpr
  counting_iterator(index_type n = 0) : n_(n) {}

  CUTE_HOST_DEVICE constexpr
  index_type operator*() const { return n_; }

  // For index types narrower than int (short, uint8_t, ...), `n_ + i` is
  // promoted to int; convert back to index_type explicitly.
  CUTE_HOST_DEVICE constexpr
  index_type operator[](index_type i) const { return static_cast<index_type>(n_ + i); }

  // Direct-initialization from an explicitly converted sum. A braced
  // `{n_ + i}` would be ill-formed for narrow index types, because
  // list-initialization rejects the int -> index_type narrowing conversion.
  CUTE_HOST_DEVICE constexpr
  counting_iterator operator+(index_type i) const { return counting_iterator(static_cast<index_type>(n_ + i)); }

  CUTE_HOST_DEVICE constexpr
  counting_iterator& operator++() { ++n_; return *this; }

  CUTE_HOST_DEVICE constexpr
  counting_iterator operator++(int) { counting_iterator ret = *this; ++n_; return ret; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator==(counting_iterator const& x, counting_iterator const& y) { return x.n_ == y.n_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator!=(counting_iterator const& x, counting_iterator const& y) { return x.n_ != y.n_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator< (counting_iterator const& x, counting_iterator const& y) { return x.n_ < y.n_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator<=(counting_iterator const& x, counting_iterator const& y) { return x.n_ <= y.n_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator> (counting_iterator const& x, counting_iterator const& y) { return x.n_ > y.n_; }

  CUTE_HOST_DEVICE constexpr
  friend bool operator>=(counting_iterator const& x, counting_iterator const& y) { return x.n_ >= y.n_; }
};
// The "raw" form of a counting_iterator is simply its current count.
template <class T>
CUTE_HOST_DEVICE constexpr
T raw_pointer_cast(counting_iterator<T> const& x)
{
  return x.n_;
}
// Prints a raw pointer as "ptr[<bits>b](<address>)", where <bits> is
// sizeof_bits<T>::value.
template <class T>
CUTE_HOST_DEVICE void print(T const* const ptr)
{
  // printf's %p requires a void* argument. Passing a T const* through varargs
  // is undefined, so convert first. The hop through `void const volatile*`
  // keeps the cast valid even when T itself is volatile-qualified. The pointer
  // is only printed, never written through.
  void* const addr = const_cast<void*>(static_cast<void const volatile*>(ptr));
  printf("ptr["); print(sizeof_bits<T>::value); printf("b](%p)", addr);
}
// Prints a counting_iterator as "counting_iter(<n>)".
template <class T>
CUTE_HOST_DEVICE void print(counting_iterator<T> iter)
{
  printf("counting_iter(");
  print(iter.n_);
  printf(")");
}
#if !defined(__CUDACC_RTC__)
// Host-side stream insertion matching print(): "counting_iter(<n>)".
// Compiled out under NVRTC (__CUDACC_RTC__), presumably because std::ostream
// is unavailable there.
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator<T> iter)
{
  return os << "counting_iter("
            << iter.n_
            << ")";
}
#endif
}