#pragma once
#include <cute/config.hpp>
#include <cute/tensor_impl.hpp>
#include <cute/container/tuple.hpp>
namespace cute
{
/// Iterator that advances a pack of underlying iterators in lock-step.
///
/// Dereferencing yields a cute::tuple holding, by value, the element each
/// underlying iterator points at. Offsets are supplied as a cute::tuple with
/// exactly one entry per underlying iterator (the form ZipLayout produces).
template <class... Iters>
struct ZipIterator
{
  using value_type   = cute::tuple<iter_value_t<Iters>...>;
  using element_type = cute::tuple<iter_element_t<Iters>...>;
  // operator* materializes copies of the pointed-to values, so the
  // "reference" type is the value tuple, not a tuple of references.
  using reference    = value_type;

  ZipIterator() = delete;

  CUTE_HOST_DEVICE constexpr
  ZipIterator(Iters... iters)
    : iters_(iters...)
  {}

  CUTE_HOST_DEVICE constexpr
  ZipIterator(cute::tuple<Iters...> const& iters)
    : iters_(iters)
  {}

  /// Dereference every underlying iterator and pack the results by value.
  CUTE_HOST_DEVICE constexpr
  reference operator*() const {
    return cute::apply(iters_, [](auto&&... args) { return reference(*args...); });
  }

  /// Offset the i-th underlying iterator by the i-th entry of @p idxs.
  ///
  /// The result type is deduced. If some `iter + idx` yields a different
  /// iterator type than `Iters`, the result is a ZipIterator over those new
  /// types instead of a failed conversion back to ZipIterator<Iters...>.
  /// When every `iter + idx` preserves its type (e.g. raw pointers with
  /// integral offsets), the result is ZipIterator<Iters...> as before.
  template <class... Index>
  CUTE_HOST_DEVICE constexpr
  auto operator+(cute::tuple<Index...> const& idxs) const {
    static_assert(sizeof...(Index) == sizeof...(Iters), "Expect same number of offsets as iterators.");
    // Qualified name on purpose: inside this class an unqualified `ZipIterator`
    // is the injected-class-name (== ZipIterator<Iters...>) and disables CTAD.
    return cute::ZipIterator(
        cute::transform(iters_, idxs, [](auto&& iter, auto&& idx) { return iter + idx; }));
  }

  /// Equivalent to `*(*this + idxs)`: offset all iterators, then dereference.
  template <class... Index>
  CUTE_HOST_DEVICE constexpr
  auto operator[](cute::tuple<Index...> const& idxs) const {
    return *(*this + idxs);
  }

  cute::tuple<Iters...> iters_;
};
// Memory-space traits for ZipIterator: the zipped iterator is reported as
// living in a given memory space only if every underlying iterator is
// (a conjunction over the whole pack). Mixed-space zips report false for
// each of these traits.
template <class... Iters>
struct is_rmem<ZipIterator<Iters...>> : conjunction<is_rmem<Iters>...> {};
template <class... Iters>
struct is_smem<ZipIterator<Iters...>> : conjunction<is_smem<Iters>...> {};
template <class... Iters>
struct is_gmem<ZipIterator<Iters...>> : conjunction<is_gmem<Iters>...> {};
template <class... Iters>
struct is_tmem<ZipIterator<Iters...>> : conjunction<is_tmem<Iters>...> {};
/// Layout-like wrapper that evaluates a pack of same-rank layouts on one coordinate.
///
/// - A coordinate without underscores yields a cute::tuple of offsets, one per
///   wrapped layout (suitable for ZipIterator::operator[] / operator+).
/// - A coordinate containing underscores slices every wrapped layout and
///   yields a new ZipLayout over the sliced layouts.
template <class... Layouts>
struct ZipLayout
{
  // Bitwise-OR fold of the ranks. It equals the common rank whenever the
  // rank-equality static_assert below holds (and is 0 for an empty pack).
  static constexpr int rank = (int(0) | ... | Layouts::rank);

  static_assert((is_layout<Layouts>::value && ...), "All template parameters must be layouts");
  static_assert(((Layouts::rank == rank) && ...), "All layouts must have the same rank");

  CUTE_HOST_DEVICE constexpr
  ZipLayout(Layouts const&... layouts)
    : layouts_(layouts...)
  {}

  CUTE_HOST_DEVICE constexpr
  ZipLayout(cute::tuple<Layouts...> const& layouts)
    : layouts_(layouts)
  {}

  /// Evaluate (or slice, if @p coord contains underscores) every wrapped layout.
  template <class Coord>
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(Coord const& coord) const {
    if constexpr (has_underscore<Coord>::value) {
      // Slicing generally changes each layout's type. Use the qualified name so
      // CTAD deduces ZipLayout<SlicedLayouts...>. An unqualified `ZipLayout`
      // here is the injected-class-name, i.e. ZipLayout<Layouts...>, which
      // cannot be constructed from the sliced layouts.
      return cute::ZipLayout(cute::transform(layouts_, [&] (auto const& layout) { return layout(coord); }));
    } else {
      return cute::transform(layouts_, [&] (auto const& layout) { return layout(coord); });
    }
    CUTE_GCC_UNREACHABLE;
  }

  /// Multi-argument convenience form: packs the arguments into one coordinate.
  template <class Coord0, class Coord1, class... Coords>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const {
    return operator()(make_coord(c0,c1,cs...));
  }

  cute::tuple<Layouts...> layouts_;
};
// Report ZipLayout as a layout for is_layout queries.
// NOTE(review): presumably required so make_tensor accepts the ZipLayout that
// make_zip_tensor passes it — confirm against make_tensor's constraints.
template <class... Layouts>
struct is_layout<ZipLayout<Layouts...>> : true_type {};
/// Zip several tensors into one tensor whose elements are tuples.
///
/// The resulting tensor pairs a ZipIterator over every input's data iterator
/// with a ZipLayout over every input's layout. A coordinate is therefore
/// mapped to a tuple of per-input offsets, which the ZipIterator dereferences
/// into a tuple of per-input values.
template <class... Engines, class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
make_zip_tensor(Tensor<Engines,Layouts> const&... tensors)
{
  auto zipped_iter   = ZipIterator(tensors.data()...);
  auto zipped_layout = ZipLayout(tensors.layout()...);
  return make_tensor(zipped_iter, zipped_layout);
}
/// Split a zipped tensor back into a cute::tuple of ordinary tensors.
///
/// Expects a tensor whose data() is a ZipIterator and whose layout() is a
/// ZipLayout (as built by make_zip_tensor). The i-th underlying iterator is
/// rejoined with the i-th underlying layout.
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
unzip_tensor(Tensor<Engine,Layout> const& tensor)
{
  auto component_iters   = tensor.data().iters_;
  auto component_layouts = tensor.layout().layouts_;
  auto rejoin = [](auto iter, auto layout) { return make_tensor(iter, layout); };
  return cute::transform(component_iters, component_layouts, rejoin);
}
/// Rank of a ZipLayout, or of a sub-mode when Is... is given.
///
/// ZipLayout enforces equal top-level ranks across its layouts, so the
/// answer is taken from the first wrapped layout. Sub-mode queries (Is...)
/// likewise consult only that first layout.
template <int... Is, class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
rank(ZipLayout<Layouts...> const& layouts)
{
  auto const& leading_layout = get<0>(layouts.layouts_);
  return rank<Is...>(leading_layout);
}
/// Size of a ZipLayout, or of a sub-mode when Is... is given.
///
/// Only the first wrapped layout is consulted. The other layouts are assumed
/// to agree (ZipLayout itself only checks that the ranks match).
template <int... Is, class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
size(ZipLayout<Layouts...> const& layouts)
{
  auto const& leading_layout = get<0>(layouts.layouts_);
  return size<Is...>(leading_layout);
}
/// Apply append<N>(layout, x) to every wrapped layout and re-zip the results.
template <int N, class... Layouts, class ShapeX = _1, class StrideX = _0>
CUTE_HOST_DEVICE constexpr
auto
append(ZipLayout<Layouts...> const& layouts,
       Layout<ShapeX,StrideX> const& x = {})
{
  auto append_each = [&x](auto const& sub_layout) { return append<N>(sub_layout, x); };
  return ZipLayout(cute::transform(layouts.layouts_, append_each));
}
/// Apply prepend<N>(layout, x) to every wrapped layout and re-zip the results.
template <int N, class... Layouts, class ShapeX = _1, class StrideX = _0>
CUTE_HOST_DEVICE constexpr
auto
prepend(ZipLayout<Layouts...> const& layouts,
        Layout<ShapeX,StrideX> const& x = {})
{
  auto prepend_each = [&x](auto const& sub_layout) { return prepend<N>(sub_layout, x); };
  return ZipLayout(cute::transform(layouts.layouts_, prepend_each));
}
/// Divide every wrapped layout by the same tiler (logical_divide) and re-zip.
template <class... Layouts, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(ZipLayout<Layouts...> const& layouts,
               Tiler const& tiler)
{
  auto divide_each = [&tiler](auto const& sub_layout) { return logical_divide(sub_layout, tiler); };
  return ZipLayout(cute::transform(layouts.layouts_, divide_each));
}
/// Divide every wrapped layout by the same tiler (zipped_divide) and re-zip.
template <class... Layouts, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_divide(ZipLayout<Layouts...> const& layouts,
              Tiler const& tiler)
{
  auto divide_each = [&tiler](auto const& sub_layout) { return zipped_divide(sub_layout, tiler); };
  return ZipLayout(cute::transform(layouts.layouts_, divide_each));
}
/// Slice every wrapped layout at @p c.
///
/// Returns a cute::tuple of (ZipLayout of the sliced layouts, cute::tuple of
/// per-layout offsets). The offset tuple is the shape ZipIterator::operator+
/// expects.
template <class Coord, class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
slice_and_offset(Coord const& c, ZipLayout<Layouts...> const& layouts)
{
  // One (sliced_layout, offset) pair per wrapped layout ...
  auto pairs = cute::transform(layouts.layouts_,
                               [&c](auto const& layout) { return slice_and_offset(c, layout); });
  // ... transposed into (all sliced layouts, all offsets).
  auto columns = cute::zip(pairs);
  auto sliced  = ZipLayout(get<0>(columns));
  auto offsets = get<1>(columns);
  return cute::make_tuple(sliced, offsets);
}
}