#pragma once
#include <cute/config.hpp>
#include <cute/layout.hpp>
#include <cute/layout_composed.hpp>
#include <cute/pointer.hpp>
#include <cute/pointer_base.hpp>
#include <cute/container/array_aligned.hpp>
#include <cute/container/array_subbyte.hpp>
#include <cute/container/tuple.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/util/type_traits.hpp>
namespace cute
{
// ArrayEngine: owning engine whose elements live in the data member storage_,
// which holds N elements of T.
// Storage is array_aligned<T,N> when sizeof_bits<T> is a multiple of 8, and
// array_subbyte<T,N> otherwise (bit widths that are not a multiple of 8).
template <class T, size_t N>
struct ArrayEngine
{
using Storage = typename conditional<(sizeof_bits<T>::value % 8 == 0),
array_aligned<T,N>,
array_subbyte<T,N>>::type;
// Iterator-related member types are all derived from Storage's iterator.
using iterator = typename Storage::iterator;
using reference = typename iterator_traits<iterator>::reference;
using element_type = typename iterator_traits<iterator>::element_type;
using value_type = typename iterator_traits<iterator>::value_type;
Storage storage_;
// begin() forwards to the storage; const and non-const overloads let the
// storage hand out const or mutable iterators respectively.
CUTE_HOST_DEVICE constexpr auto begin() const { return storage_.begin(); }
CUTE_HOST_DEVICE constexpr auto begin() { return storage_.begin(); }
};
// ArrayEngine specialization for sparse_elem<S,T>: N logical elements are backed
// by only N/S physical elements of T in storage_ (N must be a multiple of S).
template <int S, class T, size_t N>
struct ArrayEngine<sparse_elem<S,T>, N>
{
static_assert(N % S == 0, "Expected a multiple of the sparsity.");
// Unlike the primary template, value_type is fixed to the sparse element type
// rather than taken from iterator_traits.
using value_type = sparse_elem<S,T>;
// Physical storage of N/S elements of T; aligned vs. subbyte chosen as in the
// primary template.
using Storage = typename conditional<(sizeof_bits<T>::value % 8 == 0),
array_aligned<T,N/S>,
array_subbyte<T,N/S>>::type;
using iterator = sparse_ptr<S,sparse_elem<S,T>*>;
using reference = typename iterator_traits<iterator>::reference;
using element_type = typename iterator_traits<iterator>::element_type;
Storage storage_;
// begin() reinterprets the storage iterator as a pointer to value_type via
// recast_ptr. NOTE(review): the deduced return type is presumably `iterator`
// above — confirm against recast_ptr's sparse_elem overload.
CUTE_HOST_DEVICE constexpr auto begin() const { return recast_ptr<value_type>(storage_.begin()); }
CUTE_HOST_DEVICE constexpr auto begin() { return recast_ptr<value_type>(storage_.begin()); }
};
// ViewEngine: engine that holds only an iterator (no element storage of its own).
// make_tensor(iter, ...) wraps its iterator in this engine.
template <class Iterator>
struct ViewEngine
{
using iterator = Iterator;
using reference = typename iterator_traits<iterator>::reference;
using element_type = typename iterator_traits<iterator>::element_type;
using value_type = typename iterator_traits<iterator>::value_type;
iterator storage_;
// Both overloads return a reference to the stored iterator; the non-const one
// allows callers to modify that iterator in place.
CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; }
CUTE_HOST_DEVICE constexpr iterator & begin() { return storage_; }
};
// ConstViewEngine: like ViewEngine, but only exposes a const begin(), so the
// stored iterator itself cannot be modified through the engine.
template <class Iterator>
struct ConstViewEngine
{
using iterator = Iterator;
using reference = typename iterator_traits<iterator>::reference;
using element_type = typename iterator_traits<iterator>::element_type;
using value_type = typename iterator_traits<iterator>::value_type;
iterator storage_;
CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; }
};
// Tensor: a Layout (maps coordinates to offsets) paired with an Engine (supplies
// begin(), the iterator those offsets are applied to). Both are stored in rep_ as
// cute::tuple<Layout, Engine>: layout at index 0, engine at index 1.
//
// NOTE: many accessors below return decltype(auto) directly from an expression.
// Keep the `return <expr>;` forms as they are: returning a named local instead
// would make decltype(auto) deduce a value (copy) rather than a reference.
template <class Engine, class Layout>
struct Tensor
{
// Member types are forwarded from the Engine.
using iterator = typename Engine::iterator;
using value_type = typename Engine::value_type;
using element_type = typename Engine::element_type;
using reference = typename Engine::reference;
using engine_type = Engine;
using layout_type = Layout;
// Default constructor: rep_ is not explicitly initialized here; what it holds
// depends on the default constructors of Layout and Engine.
CUTE_HOST_DEVICE constexpr
Tensor() {}
// Argument order is (engine, layout), but rep_ stores (layout, engine).
CUTE_HOST_DEVICE constexpr
Tensor(Engine const& engine, Layout const& layout)
: rep_(layout, engine) {
}
// Number of top-level modes, taken from the Layout.
static constexpr int rank = Layout::rank;
// Returns *this. `*this` is an lvalue expression, so decltype(auto) deduces
// `Tensor const&` (no copy is made).
CUTE_HOST_DEVICE constexpr
decltype(auto)
tensor() const {
return *this;
}
// engine(): the Engine at index 1 of rep_.
// NOTE(review): decltype(auto) preserves whatever get<1>(rep_) yields — presumably
// a (const) reference into rep_; confirm against cute::tuple's get.
CUTE_HOST_DEVICE constexpr
decltype(auto)
engine() const {
return get<1>(rep_);
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
engine() {
return get<1>(rep_);
}
// data(): the engine's begin() iterator. The const overload calls the engine's
// const begin(), the non-const overload its non-const begin().
CUTE_HOST_DEVICE constexpr
decltype(auto)
data() const {
return engine().begin();
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
data() {
return engine().begin();
}
// layout(): the Layout at index 0 of rep_ (read-only accessor only).
CUTE_HOST_DEVICE constexpr
decltype(auto)
layout() const {
return get<0>(rep_);
}
// shape()/stride(): forwarded from the layout.
CUTE_HOST_DEVICE constexpr
decltype(auto)
shape() const {
return layout().shape();
}
// size(): cute::size of the shape (by value).
CUTE_HOST_DEVICE constexpr
auto
size() const {
return cute::size(shape());
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
stride() const {
return layout().stride();
}
// operator[]: element access at data()[layout()(coord)]; no slicing support.
template <class Coord>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator[](Coord const& coord) {
return data()[layout()(coord)];
}
template <class Coord>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator[](Coord const& coord) const {
return data()[layout()(coord)];
}
// operator()(coord):
//  - if coord contains an underscore, slice: slice_and_offset splits coord into a
//    sub-layout and an offset, and a new tensor over (data() + offset) with that
//    sub-layout is returned via make_tensor(iterator, layout);
//  - otherwise, element access exactly like operator[].
template <class Coord>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(Coord const& coord) {
if constexpr (has_underscore<Coord>::value) {
auto [sliced_layout,offset] = slice_and_offset(coord, layout());
return make_tensor(data() + offset, sliced_layout);
} else {
return data()[layout()(coord)];
}
CUTE_GCC_UNREACHABLE;
}
template <class Coord>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(Coord const& coord) const {
if constexpr (has_underscore<Coord>::value) {
auto [sliced_layout,offset] = slice_and_offset(coord, layout());
return make_tensor(data() + offset, sliced_layout);
} else {
return data()[layout()(coord)];
}
CUTE_GCC_UNREACHABLE;
}
// Variadic form: t(c0, c1, ...) packs its arguments with make_coord and defers to
// the single-coordinate operator().
template <class Coord0, class Coord1, class... Coords>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) {
return operator()(make_coord(c0,c1,cs...));
}
template <class Coord0, class Coord1, class... Coords>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const {
return operator()(make_coord(c0,c1,cs...));
}
// compose/tile: new tensor over the same data() iterator with the layout
// transformed by Layout::compose / Layout::tile.
template <class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
compose(Layouts const&... layouts) {
return make_tensor(data(), layout().compose(layouts...));
}
template <class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
compose(Layouts const&... layouts) const {
return make_tensor(data(), layout().compose(layouts...));
}
template <class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
tile(Layouts const&... layouts) {
return make_tensor(data(), layout().tile(layouts...));
}
template <class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
tile(Layouts const&... layouts) const {
return make_tensor(data(), layout().tile(layouts...));
}
// Linear-index -> coordinate conversions, delegated to the layout. Only enabled
// for integral linear_idx.
template <class Int,
__CUTE_REQUIRES(is_integral<Int>::value)>
CUTE_HOST_DEVICE constexpr
auto
get_1d_coord(Int const& linear_idx) const {
return layout().get_1d_coord(linear_idx);
}
template <class Int,
__CUTE_REQUIRES(is_integral<Int>::value)>
CUTE_HOST_DEVICE constexpr
auto
get_hier_coord(Int const& linear_idx) const {
return layout().get_hier_coord(linear_idx);
}
template <class Int,
__CUTE_REQUIRES(is_integral<Int>::value)>
CUTE_HOST_DEVICE constexpr
auto
get_flat_coord(Int const& linear_idx) const {
return layout().get_flat_coord(linear_idx);
}
// Storage: layout at index 0, engine at index 1.
cute::tuple<layout_type, engine_type> rep_;
};
// is_tensor<T>: true_type exactly when T is a cute::Tensor specialization.
// No cv/ref stripping is done here; callers apply remove_cvref_t themselves.
template <class T>
struct is_tensor : false_type {};

template <class Engine, class Layout>
struct is_tensor<Tensor<Engine,Layout>> : true_type {};

// Convenience variable template for is_tensor<T>::value.
// Declared `inline` (C++17) like the standard `_v` traits, so a header-defined
// variable template has a single definition across translation units.
template <class T>
inline constexpr bool is_tensor_v = is_tensor<T>::value;
// MakeTensor<T>: the function object behind make_tensor.
//  * If arg0 is dereferenceable (an iterator), it builds a non-owning tensor with
//    ViewEngine<Arg0>. The remaining args are either a single Layout, used as-is,
//    or are forwarded to make_layout.
//  * Otherwise it builds an owning tensor with ArrayEngine<T, cosize_v<Layout>>.
//    All args must be static, and the tensor is default-constructed.
template <class T>
struct MakeTensor
{
  template <class Arg0, class... Args>
  CUTE_HOST_DEVICE constexpr auto
  operator()(Arg0 const& arg0, Args const&... args) const
  {
    if constexpr (not has_dereference<Arg0>::value) {
      // Owning tensor: the element count must be known at compile time.
      static_assert((is_static<Arg0>::value && ... && is_static<Args>::value),
                    "Dynamic owning tensors not supported");
      if constexpr (not (sizeof...(Args) == 0 && is_layout<Arg0>::value)) {
        // Build the Layout from the shape/stride arguments.
        using Layout = decltype(make_layout(arg0, args...));
        using Engine = ArrayEngine<T, cosize_v<Layout>>;
        return Tensor<Engine,Layout>();
      } else {
        // A single Layout argument is used directly.
        using Layout = Arg0;
        using Engine = ArrayEngine<T, cosize_v<Layout>>;
        return Tensor<Engine,Layout>();
      }
    } else {
      // Non-owning tensor over the given iterator.
      using Engine = ViewEngine<Arg0>;
      if constexpr (sizeof...(Args) == 1 && (is_layout<Args>::value && ...)) {
        return Tensor{Engine{arg0}, args...};
      } else {
        return Tensor{Engine{arg0}, make_layout(args...)};
      }
    }
  }
};
// make_tensor<T>(args...): owning tensor of element type T. Every argument must be
// layout-like (a Layout, or shape/stride arguments); none may be an iterator.
template <class T, class... Args>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(Args const&... args)
{
  static_assert(!(has_dereference<Args>::value || ...), "Expected layout args... in make_tensor<T>(args...)");
  MakeTensor<T> const factory{};
  return factory(args...);
}

// make_tensor(iter, args...): non-owning tensor over `iter`. The remaining
// arguments are layout-like.
template <class Iterator, class... Args>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(Iterator const& iter, Args const&... args)
{
  static_assert(has_dereference<Iterator>::value, "Expected iterator iter in make_tensor(iter, args...)");
  static_assert(!(has_dereference<Args>::value || ...), "Expected layout args... in make_tensor(iter, args...)");
  MakeTensor<Iterator> const factory{};
  return factory(iter, args...);
}
// make_tensor_like<NewT>(layout): owning tensor of NewT whose layout comes from
// make_layout_like(layout).
template <class NewT, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor_like(Layout const& layout)
{
  auto const like_layout = make_layout_like(layout);
  return make_tensor<NewT>(like_layout);
}

// make_tensor_like<NewT>(tensor): same as above, using the tensor's layout.
template <class NewT, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor_like(Tensor<Engine,Layout> const& tensor)
{
  auto const& src_layout = tensor.layout();
  return make_tensor_like<NewT>(src_layout);
}

// make_tensor_like(tensor): as above, keeping the tensor's own value_type.
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor_like(Tensor<Engine,Layout> const& tensor)
{
  using ValueType = typename Engine::value_type;
  return make_tensor_like<ValueType>(tensor.layout());
}
// make_fragment_like<NewT>(layout): owning tensor of NewT whose layout comes from
// the layout-level make_fragment_like(layout).
template <class NewT, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Layout const& layout)
{
  auto const frag_layout = make_fragment_like(layout);
  return make_tensor<NewT>(frag_layout);
}

// make_fragment_like<NewT>(tensor): same as above, using the tensor's layout.
template <class NewT, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Tensor<Engine,Layout> const& tensor)
{
  auto const& src_layout = tensor.layout();
  return make_fragment_like<NewT>(src_layout);
}

// make_fragment_like(tensor): as above, keeping the tensor's own value_type.
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Tensor<Engine,Layout> const& tensor)
{
  using ValueType = typename Engine::value_type;
  return make_fragment_like<ValueType>(tensor.layout());
}
// make_counting_tensor(layout): tensor whose iterator is make_inttuple_iter over
// coprofile(layout), viewed through `layout`.
template <class Layout, __CUTE_REQUIRES(is_layout<Layout>::value)>
CUTE_HOST_DEVICE constexpr
auto
make_counting_tensor(Layout const& layout)
{
  auto const counting_iter = make_inttuple_iter(coprofile(layout));
  return make_tensor(counting_iter, layout);
}

// make_identity_tensor(shape): counting tensor over make_identity_layout(shape).
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_identity_tensor(Shape const& shape)
{
  auto const identity_layout = make_identity_layout(shape);
  return make_counting_tensor(identity_layout);
}
// tensor<Is...>(t):
//  * with no Is..., returns t itself (decayed to a value);
//  * otherwise returns a tensor over t's iterator whose layout is get<Is...> of
//    t's layout (the sub-layout selected by Is...).
// NOTE(review): the second form views t's storage; for an owning (ArrayEngine)
// rvalue argument that storage does not outlive the call.
template <int... Is, class Tensor>
CUTE_HOST_DEVICE constexpr
auto
tensor(Tensor&& tensor)
{
  if constexpr (sizeof...(Is) == 0) {
    // Forward so an rvalue argument is moved into the returned value instead of
    // copied (matches how the sibling functions forward `Tensor&&`).
    return static_cast<Tensor&&>(tensor);
  } else {
    return make_tensor(static_cast<Tensor&&>(tensor).data(), get<Is...>(tensor.layout()));
  }
  CUTE_GCC_UNREACHABLE;
}
// Free-function accessors: each forwards to the same-named layout function,
// applied to the tensor's layout, with the same Is... mode selectors.

// layout<Is...>(t): (sub-)layout of t.
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
layout(Tensor<Engine,Layout> const& tensor)
{
  auto const& lay = tensor.layout();
  return layout<Is...>(lay);
}

// shape<Is...>(t): (sub-)shape of t.
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
shape(Tensor<Engine,Layout> const& tensor)
{
  auto const& lay = tensor.layout();
  return shape<Is...>(lay);
}

// stride<Is...>(t): (sub-)stride of t.
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
stride(Tensor<Engine,Layout> const& tensor)
{
  auto const& lay = tensor.layout();
  return stride<Is...>(lay);
}

// size<Is...>(t): size of t or of the selected mode.
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
size(Tensor<Engine,Layout> const& tensor)
{
  auto const& lay = tensor.layout();
  return size<Is...>(lay);
}

// rank<Is...>(t): rank of t or of the selected mode.
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
rank(Tensor<Engine,Layout> const& tensor)
{
  auto const& lay = tensor.layout();
  return rank<Is...>(lay);
}

// depth<Is...>(t): depth of t or of the selected mode.
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
depth(Tensor<Engine, Layout> const& tensor)
{
  auto const& lay = tensor.layout();
  return depth<Is...>(lay);
}
// flatten(t): same iterator, layout replaced by flatten(t.layout()).
// Three overloads (const&, &, &&) so that non-const arguments keep the mutable
// data() iterator rather than the const one.
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
flatten(Tensor<Engine,Layout> const& tensor) {
  auto flat_layout = flatten(tensor.layout());
  return make_tensor(tensor.data(), flat_layout);
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
flatten(Tensor<Engine,Layout>& tensor) {
  auto flat_layout = flatten(tensor.layout());
  return make_tensor(tensor.data(), flat_layout);
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
flatten(Tensor<Engine,Layout>&& tensor) {
  auto flat_layout = flatten(tensor.layout());
  return make_tensor(tensor.data(), flat_layout);
}
// coalesce(t, profile): same iterator, layout replaced by
// coalesce(t.layout(), profile). The profile defaults to Int<1>{}.
// Three overloads (const&, &, &&) preserve the const-ness of data().
template <class Engine, class Layout, class Profile = Int<1>>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Tensor<Engine,Layout> const& tensor, Profile const& profile = {}) {
  auto merged = coalesce(tensor.layout(), profile);
  return make_tensor(tensor.data(), merged);
}
template <class Engine, class Layout, class Profile = Int<1>>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Tensor<Engine,Layout>& tensor, Profile const& profile = {}) {
  auto merged = coalesce(tensor.layout(), profile);
  return make_tensor(tensor.data(), merged);
}
template <class Engine, class Layout, class Profile = Int<1>>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Tensor<Engine,Layout>&& tensor, Profile const& profile = {}) {
  auto merged = coalesce(tensor.layout(), profile);
  return make_tensor(tensor.data(), merged);
}
// filter_zeros(t): same iterator, layout replaced by filter_zeros(t.layout()).
// Three overloads (const&, &, &&) preserve the const-ness of data().
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Tensor<Engine,Layout> const& tensor) {
  auto nonzero_layout = filter_zeros(tensor.layout());
  return make_tensor(tensor.data(), nonzero_layout);
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Tensor<Engine,Layout>& tensor) {
  auto nonzero_layout = filter_zeros(tensor.layout());
  return make_tensor(tensor.data(), nonzero_layout);
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Tensor<Engine,Layout>&& tensor) {
  auto nonzero_layout = filter_zeros(tensor.layout());
  return make_tensor(tensor.data(), nonzero_layout);
}

// filter_zeros(t, profile): as above, with filter_zeros(t.layout(), profile).
template <class Engine, class Layout, class Profile>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Tensor<Engine,Layout> const& tensor, Profile const& profile)
{
  auto nonzero_layout = filter_zeros(tensor.layout(), profile);
  return make_tensor(tensor.data(), nonzero_layout);
}
template <class Engine, class Layout, class Profile>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Tensor<Engine,Layout>& tensor, Profile const& profile)
{
  auto nonzero_layout = filter_zeros(tensor.layout(), profile);
  return make_tensor(tensor.data(), nonzero_layout);
}
template <class Engine, class Layout, class Profile>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Tensor<Engine,Layout>&& tensor, Profile const& profile)
{
  auto nonzero_layout = filter_zeros(tensor.layout(), profile);
  return make_tensor(tensor.data(), nonzero_layout);
}
// filter(t): same iterator, layout replaced by filter(t.layout()).
// Three overloads (const&, &, &&) preserve the const-ness of data().
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
filter(Tensor<Engine,Layout> const& tensor) {
  auto filtered = filter(tensor.layout());
  return make_tensor(tensor.data(), filtered);
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
filter(Tensor<Engine,Layout>& tensor) {
  auto filtered = filter(tensor.layout());
  return make_tensor(tensor.data(), filtered);
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
filter(Tensor<Engine,Layout>&& tensor) {
  auto filtered = filter(tensor.layout());
  return make_tensor(tensor.data(), filtered);
}
// group_modes<B,E>(t): same iterator, layout replaced by group<B,E>(t.layout()).
// Three overloads (const&, &, &&) preserve the const-ness of data().
template <int B, int E, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
group_modes(Tensor<Engine,Layout> const& tensor) {
  auto grouped = group<B,E>(tensor.layout());
  return make_tensor(tensor.data(), grouped);
}
template <int B, int E, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
group_modes(Tensor<Engine,Layout>& tensor) {
  auto grouped = group<B,E>(tensor.layout());
  return make_tensor(tensor.data(), grouped);
}
template <int B, int E, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
group_modes(Tensor<Engine,Layout>&& tensor) {
  auto grouped = group<B,E>(tensor.layout());
  return make_tensor(tensor.data(), grouped);
}
// take<B,E>(t): same iterator, layout replaced by take<B,E>(t.layout()).
// Three overloads (const&, &, &&) preserve the const-ness of data().
template <int B, int E, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
take(Tensor<Engine,Layout> const& tensor) {
  auto taken = take<B,E>(tensor.layout());
  return make_tensor(tensor.data(), taken);
}
template <int B, int E, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
take(Tensor<Engine,Layout>& tensor) {
  auto taken = take<B,E>(tensor.layout());
  return make_tensor(tensor.data(), taken);
}
template <int B, int E, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
take(Tensor<Engine,Layout>&& tensor) {
  auto taken = take<B,E>(tensor.layout());
  return make_tensor(tensor.data(), taken);
}
// domain_offset(coord, t): applies the layout-level domain_offset, which yields a
// new layout and a pointer offset. Returns a tensor over (t.data() + offset) with
// that new layout.
template <class Coord, class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
domain_offset(Coord const& coord, Tensor&& tensor)
{
  auto [shifted_layout, offset] = domain_offset(coord, tensor.layout());
  auto shifted_iter = static_cast<Tensor&&>(tensor).data() + offset;
  return make_tensor(shifted_iter, shifted_layout);
}
// recast<NewType>(t): reinterpret t's elements as NewType. The layout is converted
// with recast_layout<OldType,NewType> and the iterator with recast_ptr<NewType>.
template <class NewType, class Tensor>
CUTE_HOST_DEVICE constexpr
auto
recast(Tensor&& tensor)
{
using OldType = typename remove_cvref_t<Tensor>::value_type;
auto old_layout = tensor.layout();
auto new_layout = recast_layout<OldType,NewType>(old_layout);
// Upcast of a non-composed layout: for each flattened mode, compute
// (old extent - new extent) * old stride and sum only the negative products
// (min(a,0)). That sum is added to the iterator before recasting.
// NOTE(review): this appears intended to compensate for modes with negative
// strides, so the recast base points at the right element — confirm.
if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout<decltype(old_layout)>::value) {
auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{});
auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{});
auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); });
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() + offset), new_layout);
} else {
// Downcast, same-size, or composed layout: recast the iterator unchanged.
return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() ), new_layout);
}
CUTE_GCC_UNREACHABLE;
}
// max_common_vector(a, b): the layout-level max_common_vector of a and b. It is
// only used when both tensors share a trivially copyable value_type and both
// engines' `reference` types are real references. Otherwise it returns Int<0>{}.
template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE constexpr
auto
max_common_vector(Tensor<SrcEngine,SrcLayout> const& a,
                  Tensor<DstEngine,DstLayout> const& b)
{
  using SrcType = typename SrcEngine::value_type;
  using SrcRef  = typename SrcEngine::reference;
  using DstType = typename DstEngine::value_type;
  using DstRef  = typename DstEngine::reference;

  constexpr bool same_trivial_type = cute::is_same<SrcType, DstType>::value &&
                                     is_trivially_copyable<SrcType>::value &&
                                     is_trivially_copyable<DstType>::value;
  constexpr bool both_true_refs    = is_reference<SrcRef>::value &&
                                     is_reference<DstRef>::value;

  if constexpr (not (same_trivial_type && both_true_refs)) {
    return Int<0>{};
  } else {
    return max_common_vector(a.layout(), b.layout());
  }
  CUTE_GCC_UNREACHABLE;
}
// max_common_layout(a, b): the layout-level max_common_layout of a and b, under
// the same conditions as max_common_vector. Otherwise it returns Layout<_1,_0>{}.
template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE constexpr
auto
max_common_layout(Tensor<SrcEngine,SrcLayout> const& a,
                  Tensor<DstEngine,DstLayout> const& b)
{
  using SrcType = typename SrcEngine::value_type;
  using SrcRef  = typename SrcEngine::reference;
  using DstType = typename DstEngine::value_type;
  using DstRef  = typename DstEngine::reference;

  constexpr bool same_trivial_type = cute::is_same<SrcType, DstType>::value &&
                                     is_trivially_copyable<SrcType>::value &&
                                     is_trivially_copyable<DstType>::value;
  constexpr bool both_true_refs    = is_reference<SrcRef>::value &&
                                     is_reference<DstRef>::value;

  if constexpr (not (same_trivial_type && both_true_refs)) {
    return Layout<_1,_0>{};
  } else {
    return max_common_layout(a.layout(), b.layout());
  }
  CUTE_GCC_UNREACHABLE;
}
// max_alignment(t): gcd of the iterator's max_alignment and the layout's
// max_alignment scaled by sizeof_bits of the value type.
// NOTE(review): the scaling suggests both terms are measured in bits — confirm.
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
max_alignment(Tensor<Engine,Layout> const& t)
{
  auto const iter_alignment   = max_alignment(t.data());
  auto const layout_alignment = max_alignment(t.layout()) * static_value<sizeof_bits<typename Engine::value_type>>();
  return gcd(iter_alignment, layout_alignment);
}
// Layout-algebra wrappers: each returns a tensor over the same iterator, with the
// tensor's layout transformed by the same-named layout function and `tiler`.
// `tensor` is forwarded so data() is called on the argument's own value category.

template <class Tensor, class Tiler,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
composition(Tensor    && tensor,
            Tiler const& tiler)
{
  auto composed = composition(tensor.layout(), tiler);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), composed);
}

template <class Tensor, class Tiler,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(Tensor    && tensor,
               Tiler const& tiler)
{
  auto divided = logical_divide(tensor.layout(), tiler);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), divided);
}

template <class Tensor, class Tiler,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
zipped_divide(Tensor    && tensor,
              Tiler const& tiler)
{
  auto zipped = zipped_divide(tensor.layout(), tiler);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), zipped);
}

template <class Tensor, class Tiler,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
tiled_divide(Tensor    && tensor,
             Tiler const& tiler)
{
  auto tiled = tiled_divide(tensor.layout(), tiler);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), tiled);
}

template <class Tensor, class Tiler,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
flat_divide(Tensor    && tensor,
            Tiler const& tiler)
{
  auto flat = flat_divide(tensor.layout(), tiler);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), flat);
}
// inner_partition(t, tiler, coord): zipped_divide t by tiler, keep all of mode 0
// (the tile), and index mode 1 (the rest) with coord.
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
inner_partition(Tensor    && tensor,
                Tiler const& tiler,
                Coord const& coord)
{
  auto tiled = zipped_divide(static_cast<Tensor&&>(tensor), tiler);
  constexpr int R0 = decltype(rank<0>(tiled))::value;

  if constexpr (not is_tuple<Coord>::value) {
    // Non-tuple coord indexes the rest-mode directly.
    return tiled(repeat<R0>(_), coord);
  } else {
    // Tuple coord: append<R1>(coord,_) extends it to the rest-mode's rank.
    constexpr int R1 = decltype(rank<1>(tiled))::value;
    return tiled(repeat<R0>(_), append<R1>(coord,_));
  }
}
// outer_partition(t, tiler, coord): zipped_divide t by tiler, index mode 0 (the
// tile) with coord, and keep all of mode 1 (the rest).
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
outer_partition(Tensor    && tensor,
                Tiler const& tiler,
                Coord const& coord)
{
  auto tiled = zipped_divide(static_cast<Tensor&&>(tensor), tiler);
  constexpr int R1 = decltype(rank<1>(tiled))::value;

  if constexpr (not is_tuple<Coord>::value) {
    // Non-tuple coord indexes the tile-mode directly.
    return tiled(coord, repeat<R1>(_));
  } else {
    // Tuple coord: append<R0>(coord,_) extends it to the tile-mode's rank.
    constexpr int R0 = decltype(rank<0>(tiled))::value;
    return tiled(append<R0>(coord,_), repeat<R1>(_));
  }
}
// local_tile(t, tiler, coord): identical to inner_partition(t, tiler, coord).
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
local_tile(Tensor    && tensor,
           Tiler const& tiler,
           Coord const& coord)
{
  return inner_partition(static_cast<Tensor&&>(tensor), tiler, coord);
}

// local_tile(t, tiler, coord, proj): applies dice(proj, .) to both tiler and
// coord, then defers to the three-argument local_tile.
template <class Tensor, class Tiler, class Coord, class Proj,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_tile(Tensor    && tensor,
           Tiler const& tiler,
           Coord const& coord,
           Proj  const& proj)
{
  auto const diced_tiler = dice(proj, tiler);
  auto const diced_coord = dice(proj, coord);
  return local_tile(static_cast<Tensor&&>(tensor), diced_tiler, diced_coord);
}
// local_partition(t, tile, index): outer_partition of t. The tiler is the
// per-mode product of tile's shape, and the coordinate is tile.get_flat_coord(index).
// index must be integral.
template <class Tensor, class LShape, class LStride, class Index,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,
                Index                  const& index)
{
  static_assert(is_integral<Index>::value);
  auto const tile_shape = product_each(shape(tile));
  auto const tile_coord = tile.get_flat_coord(index);
  return outer_partition(static_cast<Tensor&&>(tensor), tile_shape, tile_coord);
}

// local_partition(t, tile, index, proj): applies dice(proj, tile), then defers
// to the three-argument overload.
template <class Tensor, class LShape, class LStride, class Index, class Projection,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,
                Index                  const& index,
                Projection             const& proj)
{
  auto const diced_tile = dice(proj, tile);
  return local_partition(static_cast<Tensor&&>(tensor), diced_tile, index);
}
// print(t): emits "<iterator> o <layout>" with no trailing newline.
template <class Engine, class Layout>
CUTE_HOST_DEVICE void print(Tensor<Engine,Layout> const& tensor)
{
  print(tensor.data());
  print(" o ");
  print(tensor.layout());
}
// print_tensor(t, print_type): pretty-prints the elements of a rank-1..4 tensor.
// If print_type is true, print(t) followed by ":\n" is emitted first.
//  * rank 1: one element per line.
//  * rank 2: one row per line.
//  * rank 3: each (_,_,k) slice is printed as a rank-2 block, with a line of
//            5*size<1> dashes between consecutive blocks.
//  * rank 4: each (_,_,_,p) slice is printed as a rank-3 block, with a line of
//            5*size<1> '=' characters between consecutive blocks.
// Other ranks print nothing beyond the optional header.
template <class Engine, class Layout>
CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor, bool print_type = true)
{
  if (print_type) {
    print(tensor); print(":\n");
  }

  if constexpr (Layout::rank == 1)
  {
    for (int m = 0; m < size(tensor); ++m) {
      pretty_print(tensor(m));
      printf("\n");
    }
  } else
  if constexpr (Layout::rank == 2)
  {
    for (int m = 0; m < size<0>(tensor); ++m) {
      for (int n = 0; n < size<1>(tensor); ++n) {
        pretty_print(tensor(m,n));
      }
      printf("\n");
    }
  } else
  if constexpr (Layout::rank == 3)
  {
    // Separators go *between* slices. Starting at k = 0 (instead of printing slice 0
    // unconditionally) means an empty last mode prints nothing rather than reading
    // elements of a zero-size tensor.
    for (int k = 0; k < size<2>(tensor); ++k) {
      if (k > 0) {
        for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n");
      }
      print_tensor(tensor(_,_,k), false);
    }
  } else
  if constexpr (Layout::rank == 4)
  {
    // Same guard as the rank-3 case, applied to the last mode.
    for (int p = 0; p < size<3>(tensor); ++p) {
      if (p > 0) {
        for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n");
      }
      print_tensor(tensor(_,_,_,p), false);
    }
  }
}
#if !defined(__CUDACC_RTC__)
// print_tensor_os(os, t): streams the elements of a rank-1..4 tensor to os, each
// in a field of width 9 (std::setw), and returns os.
//  * rank 1: one element per line.
//  * rank 2: one row per line.
//  * rank 3: each (_,_,k) slice as a rank-2 block, with a line of 9*size<1> dashes
//            between consecutive blocks.
//  * rank 4: each (_,_,_,p) slice as a rank-3 block, with a line of 9*size<1> '='
//            characters between consecutive blocks.
// Other ranks write nothing.
template <class Engine, class Layout>
CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
  constexpr int digits = 9;

  if constexpr (Layout::rank == 1)
  {
    for (int m = 0; m < size(tensor); ++m) {
      os << std::setw(digits) << tensor(m) << std::endl;
    }
  } else
  if constexpr (Layout::rank == 2)
  {
    for (int m = 0; m < size<0>(tensor); ++m) {
      for (int n = 0; n < size<1>(tensor); ++n) {
        os << std::setw(digits) << tensor(m,n);
      }
      os << std::endl;
    }
  } else
  if constexpr (Layout::rank == 3)
  {
    // Separators go *between* slices. Starting at k = 0 avoids streaming slice 0 of
    // a tensor whose last mode is empty (which would read nonexistent elements).
    for (int k = 0; k < size<2>(tensor); ++k) {
      if (k > 0) {
        for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl;
      }
      print_tensor_os(os, tensor(_,_,k));
    }
  } else
  if constexpr (Layout::rank == 4)
  {
    // Same guard as the rank-3 case, applied to the last mode.
    for (int p = 0; p < size<3>(tensor); ++p) {
      if (p > 0) {
        for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl;
      }
      print_tensor_os(os, tensor(_,_,_,p));
    }
  }

  return os;
}
// Stream insertion for Tensor: writes the layout on its own line, then the element
// grid via print_tensor_os. Returns os.
template <class Engine, class Layout>
CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
  os << tensor.layout() << std::endl;
  print_tensor_os(os, tensor);
  return os;
}
#endif
}