#pragma once
#include <cute/config.hpp>
#include <cute/pointer_base.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/numeric/integral_ratio.hpp>
namespace cute
{
/// A single stored value of type T whose type additionally carries a
/// compile-time sparsity factor `Sparsity`.
///
/// The wrapper adds no runtime state beyond the wrapped value; all relational
/// operators compare the wrapped values.
template <int Sparsity, class T>
struct sparse_elem
{
  // Compile-time sparsity factor carried by the type.
  static constexpr int sparsity = Sparsity;
  // Type of the wrapped storage value.
  using raw_type = T;

  // The wrapped storage value.
  T elem_;

  /// Wraps `elem`; when omitted, the default argument `{}` is used.
  CUTE_HOST_DEVICE constexpr
  explicit sparse_elem(T const& elem = {})
    : elem_(elem)
  {}

  // Each operator forwards to the *same* operator on T (none is synthesized
  // from another), so T's exact comparison semantics are preserved.
  CUTE_HOST_DEVICE constexpr friend bool
  operator==(sparse_elem const& lhs, sparse_elem const& rhs) { return lhs.elem_ == rhs.elem_; }

  CUTE_HOST_DEVICE constexpr friend bool
  operator!=(sparse_elem const& lhs, sparse_elem const& rhs) { return lhs.elem_ != rhs.elem_; }

  CUTE_HOST_DEVICE constexpr friend bool
  operator< (sparse_elem const& lhs, sparse_elem const& rhs) { return lhs.elem_ <  rhs.elem_; }

  CUTE_HOST_DEVICE constexpr friend bool
  operator<=(sparse_elem const& lhs, sparse_elem const& rhs) { return lhs.elem_ <= rhs.elem_; }

  CUTE_HOST_DEVICE constexpr friend bool
  operator> (sparse_elem const& lhs, sparse_elem const& rhs) { return lhs.elem_ >  rhs.elem_; }

  CUTE_HOST_DEVICE constexpr friend bool
  operator>=(sparse_elem const& lhs, sparse_elem const& rhs) { return lhs.elem_ >= rhs.elem_; }
};
/// Trait: derives from true_type iff T is sparse_elem<S,U> for some S and U,
/// possibly with a top-level const qualifier; false_type otherwise.
template <class T>
struct is_sparse : false_type {};

// Strip a top-level const so `sparse_elem<S,U> const` is detected too.
template <class T>
struct is_sparse<T const> : is_sparse<T> {};

template <int S, class T>
struct is_sparse<sparse_elem<S,T>> : true_type {};

/// Shorthand for is_sparse<T>::value.
/// Declared `inline` (not `static`) so that this header-defined variable
/// template names a single entity shared by every translation unit, rather
/// than a separate internal-linkage copy per TU.
template <class T>
inline constexpr auto is_sparse_v = is_sparse<T>::value;
/// sizeof_bits for a sparse element: the bit size of the wrapped type T divided
/// by the sparsity factor S, kept as an exact compile-time ratio rather than
/// being rounded to an integer.
template <int S, class T>
struct sizeof_bits<sparse_elem<S,T>>
{
  static constexpr auto value = cute::ratio(cute::Int<cute::sizeof_bits_v<T>>{},
                                            cute::Int<S>{});
};
/// Trait: whether an iterator type is (or wraps) a sparse_ptr.
/// The primary template answers "no".
template <typename Iter, typename = void>
struct is_sparse_ptr : false_type {};

// Iterator adaptors that expose a nested `iterator` type are looked through:
// the answer is whatever it is for the wrapped iterator.
template <typename Iter>
struct is_sparse_ptr<Iter, void_t<typename Iter::iterator>>
  : is_sparse_ptr<typename Iter::iterator> {};
/// Iterator adaptor over an `Iterator` whose value type is sparse_elem<Sparsity, ...>.
///
/// Offsets given to operator+ and operator[] are divided by `Sparsity` before
/// being applied to the wrapped iterator, i.e. they count in units `Sparsity`
/// times finer than the stored elements.
template <int Sparsity, class Iterator>
struct sparse_ptr : iter_adaptor<Iterator, sparse_ptr<Sparsity, Iterator>>
{
  using value_type   = typename iterator_traits<Iterator>::value_type;
  using element_type = typename iterator_traits<Iterator>::element_type;
  using reference    = typename iterator_traits<Iterator>::reference;

  // Compile-time sanity checks on the wrapped iterator:
  //  - it must iterate over sparse elements,
  //  - whose sparsity matches this pointer's,
  //  - and it must not itself already be (or wrap) a sparse_ptr.
  static_assert(is_sparse<value_type>::value, "Enforce sparse value-type");
  static_assert(Sparsity == iter_value_t<Iterator>::sparsity, "Enforce sparsity S");
  static_assert(!is_sparse_ptr<Iterator>::value, "Enforce sparse singleton");

  /// Returns this pointer advanced by `offset` sparse units.
  /// `offset` must be a multiple of Sparsity (checked by assert); the wrapped
  /// iterator is advanced by offset / Sparsity.
  template <class Offset>
  CUTE_HOST_DEVICE constexpr
  sparse_ptr operator+(Offset const& offset) const {
    assert(offset % Sparsity == 0);
    auto const stored_offset = offset / Sparsity;
    return {this->get() + stored_offset};
  }

  /// Dereferences the wrapped iterator at offset / Sparsity.
  /// Any offset is accepted here; this is deliberately not written in terms of
  /// operator+, which would reject offsets that are not multiples of Sparsity.
  template <class Offset>
  CUTE_HOST_DEVICE constexpr
  reference operator[](Offset const& offset) const {
    auto const stored_offset = offset / Sparsity;
    return *(this->get() + stored_offset);
  }
};
// A sparse_ptr itself is, trivially, a sparse pointer.
template <int Sparsity, class Iterator>
struct is_sparse_ptr<sparse_ptr<Sparsity, Iterator>> : true_type {};
/// Wraps `iter` in a sparse_ptr<Sparsity, Iter>.
/// A sparsity of 1 needs no wrapper, so in that case a copy of `iter` itself
/// is returned.
template <int Sparsity, class Iter>
CUTE_HOST_DEVICE constexpr
auto
make_sparse_ptr(Iter const& iter)
{
  if constexpr (Sparsity != 1) {
    return sparse_ptr<Sparsity, Iter>{iter};
  } else {
    return iter;
  }
  CUTE_GCC_UNREACHABLE;
}
/// Recasts a sparse_ptr to a pointer over NewT.
/// The sparse wrapper is dropped: the result is recast_ptr<NewT> applied to the
/// wrapped iterator. Only non-sparse target types are accepted.
template <class NewT, int S, class Iter>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(sparse_ptr<S,Iter> const& ptr)
{
  static_assert(!is_sparse<NewT>::value,
                "recast_ptr: a sparse_ptr can only be recast to a non-sparse element type");
  return recast_ptr<NewT>(ptr.get());
}
/// Prints "sparse<S>_" followed by the wrapped iterator's own print output.
template <int S, class Iter>
CUTE_HOST_DEVICE void print(sparse_ptr<S,Iter> ptr)
{
  // Sparsity tag first...
  printf("sparse<%d>_", S);
  // ...then the underlying iterator.
  print(ptr.get());
}
#if !defined(__CUDACC_RTC__)
/// Host-only stream output: writes "sparse<S>_" followed by the wrapped
/// iterator. Only compiled when __CUDACC_RTC__ is not defined.
template <int S, class Iter>
CUTE_HOST std::ostream& operator<<(std::ostream& os, sparse_ptr<S,Iter> ptr)
{
  os << "sparse<" << S << ">_";
  return os << ptr.get();
}
#endif
}