#pragma once
#include <cute/config.hpp>
#include <cute/int_tuple.hpp>
#include <cute/stride.hpp>
#include <cute/underscore.hpp>
#include <cute/numeric/arithmetic_tuple.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/numeric/integral_ratio.hpp>
#include <cute/numeric/numeric_types.hpp>
namespace cute
{
// Vocabulary aliases. Shapes, strides, steps, coordinates and tilers are all
// represented by cute::tuple; the distinct names only express intent at the
// point of use.
template <class... Modes>
using Shape = cute::tuple<Modes...>;

template <class... Modes>
using Stride = cute::tuple<Modes...>;

template <class... Modes>
using Step = cute::tuple<Modes...>;

template <class... Modes>
using Coord = cute::tuple<Modes...>;

template <class... Modes>
using Tile = cute::tuple<Modes...>;
// Factory helpers: pack the arguments into the matching tuple alias.
// Element types are deduced through `Ts const&`, so they are the plain
// (non-reference) argument types and the resulting tuple holds copies.
template <class... Ts>
CUTE_HOST_DEVICE constexpr
Shape<Ts...>
make_shape(Ts const&... modes)
{
  return Shape<Ts...>{modes...};
}

template <class... Ts>
CUTE_HOST_DEVICE constexpr
Stride<Ts...>
make_stride(Ts const&... modes)
{
  return Stride<Ts...>{modes...};
}

template <class... Ts>
CUTE_HOST_DEVICE constexpr
Step<Ts...>
make_step(Ts const&... modes)
{
  return Step<Ts...>{modes...};
}

template <class... Ts>
CUTE_HOST_DEVICE constexpr
Coord<Ts...>
make_coord(Ts const&... modes)
{
  return Coord<Ts...>{modes...};
}

template <class... Ts>
CUTE_HOST_DEVICE constexpr
Tile<Ts...>
make_tile(Ts const&... modes)
{
  return Tile<Ts...>{modes...};
}
// Layout<Shape,Stride>: pairs a (possibly nested) Shape with a Stride and acts
// as a function of coordinates. For a coordinate without Underscores,
// operator() returns crd2idx(coord, shape(), stride()).
// The pair is stored in a privately inherited cute::tuple<Shape, Stride>
// (element 0 = shape, element 1 = stride). When Stride is omitted it
// defaults to LayoutLeft::Apply<Shape>.
template <class Shape, class Stride = LayoutLeft::Apply<Shape> >
struct Layout
    : private cute::tuple<Shape, Stride>
{
  // Both arguments default to value-initialized objects, so a layout whose
  // shape and stride are fully static can be written as Layout<S,D>{}.
  CUTE_HOST_DEVICE constexpr
  Layout(Shape const& shape = {}, Stride const& stride = {})
      : cute::tuple<Shape, Stride>(shape, stride)
  {}

  // Number of top-level modes of Shape.
  static constexpr int rank = rank_v<Shape>;

  //
  // Self and component accessors
  //

  // Returns a reference to *this (decltype(auto) of the expression `*this`
  // is an lvalue reference).
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  layout() {
    return *this;
  }

  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  layout() const {
    return *this;
  }

  // shape<I...>(): get<0,I...> applied to the stored (shape, stride) tuple.
  // With no indices this is the whole shape; with indices, the selected
  // sub-shape. decltype(auto) forwards get's result type unchanged.
  template <int... I>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  shape() {
    return get<0,I...>(static_cast<cute::tuple<Shape, Stride>&>(*this));
  }

  template <int... I>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  shape() const {
    return get<0,I...>(static_cast<cute::tuple<Shape, Stride> const&>(*this));
  }

  // stride<I...>(): same as shape<I...>(), but on element 1 (the stride).
  template <int... I>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  stride() {
    return get<1,I...>(static_cast<cute::tuple<Shape, Stride>&>(*this));
  }

  template <int... I>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  stride() const {
    return get<1,I...>(static_cast<cute::tuple<Shape, Stride> const&>(*this));
  }

  //
  // Mappings
  //

  // Apply the layout to a coordinate. If the coordinate contains an
  // Underscore (_), the result is the sub-layout slice(coord, *this).
  // Otherwise it is the index crd2idx(coord, shape(), stride()).
  template <class Coord>
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(Coord const& coord) const {
    if constexpr (has_underscore<Coord>::value) {
      return slice(coord, *this);
    } else {
      return crd2idx(coord, shape(), stride());
    }
    CUTE_GCC_UNREACHABLE;
  }

  // Convenience form: layout(c0, c1, cs...) is layout(make_coord(c0, c1, cs...)).
  template <class Coord0, class Coord1, class... Coords>
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const {
    return operator()(make_coord(c0,c1,cs...));
  }

  //
  // Composition / tiling helpers
  //

  // composition(*this, other).
  template <class OtherLayout>
  CUTE_HOST_DEVICE constexpr
  auto
  compose(OtherLayout const& other) const {
    return composition(*this, other);
  }

  // composition(*this, make_tile(layouts...)): one tiler per mode.
  template <class... Layouts>
  CUTE_HOST_DEVICE constexpr
  auto
  compose(Layouts const&... layouts) const {
    return composition(*this, make_tile(layouts...));
  }

  // Compose with make_layout(shape), i.e. the compact LayoutLeft layout of
  // the requested shape.
  template <class OtherShape>
  CUTE_HOST_DEVICE constexpr
  auto
  with_shape(OtherShape const& shape) const {
    return composition(*this, make_layout(shape));
  }

  template <class... Shapes>
  CUTE_HOST_DEVICE constexpr
  auto
  with_shape(Shapes const&... shapes) const {
    return composition(*this, make_layout(make_shape(shapes...)));
  }

  // tiled_divide(*this, other).
  template <class OtherLayout>
  CUTE_HOST_DEVICE constexpr
  auto
  tile(OtherLayout const& other) const {
    return tiled_divide(*this, other);
  }

  // tiled_divide(*this, make_tile(layouts...)): one tiler per mode.
  template <class... Layouts>
  CUTE_HOST_DEVICE constexpr
  auto
  tile(Layouts const&... layouts) const {
    return tiled_divide(*this, make_tile(layouts...));
  }

  //
  // Index -> coordinate
  //

  // cute::idx2crd(idx, shape(), stride()). NOTE(review): presumably the
  // hierarchical coordinate (congruent to shape()) that maps to idx; confirm
  // against idx2crd.
  template <class IInt,
            __CUTE_REQUIRES(is_integral<IInt>::value)>
  CUTE_HOST_DEVICE constexpr
  auto
  get_hier_coord(IInt const& idx) const {
    return cute::idx2crd(idx, shape(), stride());
  }

  // Re-expresses get_hier_coord(idx) via crd2crd into the profile
  // repeat<rank>(Int<1>{}), i.e. one integer per top-level mode.
  template <class IInt,
            __CUTE_REQUIRES(is_integral<IInt>::value)>
  CUTE_HOST_DEVICE constexpr
  auto
  get_flat_coord(IInt const& idx) const {
    return cute::crd2crd(this->get_hier_coord(idx), shape(), repeat<rank>(Int<1>{}));
  }

  // crd2idx of get_hier_coord(idx) against shape() alone (no stride).
  // NOTE(review): presumably the 1-D colexicographical index in shape().
  template <class IInt,
            __CUTE_REQUIRES(is_integral<IInt>::value)>
  CUTE_HOST_DEVICE constexpr
  auto
  get_1d_coord(IInt const& idx) const {
    return cute::crd2idx(this->get_hier_coord(idx), shape());
  }
};
// Two layouts compare equal when their shapes compare equal and their
// strides compare equal. The result is the `&&` of those two comparisons.
template <class ShapeA, class StrideA,
          class ShapeB, class StrideB>
CUTE_HOST_DEVICE constexpr
auto
operator==(Layout<ShapeA,StrideA> const& layoutA, Layout<ShapeB,StrideB> const& layoutB)
{
  return (layoutA.shape()  == layoutB.shape())
      && (layoutA.stride() == layoutB.stride());
}

// Type trait: true_type exactly for specializations of cute::Layout.
template <class T>
struct is_layout : false_type {};

template <class Shape, class Stride>
struct is_layout<Layout<Shape,Stride>> : true_type {};
// Build a Layout from an explicit shape and stride. Each must be a tuple or
// an integral (checked at compile time).
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
make_layout(Shape const& shape, Stride const& stride)
{
  static_assert(is_tuple<Shape >::value || is_integral<Shape >::value);
  static_assert(is_tuple<Stride>::value || is_integral<Stride>::value);
  using ResultLayout = Layout<Shape,Stride>;
  return ResultLayout(shape, stride);
}

// Build a Layout from a shape alone, with compact_major<LayoutLeft> strides.
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_layout(Shape const& shape)
{
  static_assert(is_tuple<Shape >::value || is_integral<Shape >::value);
  auto left_stride = compact_major<LayoutLeft>(shape);
  return make_layout(shape, left_stride);
}

// Tag dispatch: explicit LayoutLeft, giving compact_major<LayoutLeft> strides.
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_layout(Shape const& shape, LayoutLeft)
{
  auto left_stride = compact_major<LayoutLeft>(shape);
  return make_layout(shape, left_stride);
}

// Tag dispatch: LayoutRight, giving compact_major<LayoutRight> strides.
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_layout(Shape const& shape, LayoutRight)
{
  auto right_stride = compact_major<LayoutRight>(shape);
  return make_layout(shape, right_stride);
}

// Wrap a single layout as the sole mode of a new layout
// (shape and stride are each wrapped in a one-element tuple).
template <class Shape0, class Stride0>
CUTE_HOST_DEVICE constexpr
auto
make_layout(Layout<Shape0,Stride0> const& layout0)
{
  auto shapes  = make_shape (layout0.shape() );
  auto strides = make_stride(layout0.stride());
  return make_layout(shapes, strides);
}

// Concatenate two layouts as the two modes of a new layout.
template <class Shape0, class Stride0,
          class Shape1, class Stride1>
CUTE_HOST_DEVICE constexpr
auto
make_layout(Layout<Shape0,Stride0> const& layout0,
            Layout<Shape1,Stride1> const& layout1)
{
  auto shapes  = make_shape (layout0.shape() , layout1.shape() );
  auto strides = make_stride(layout0.stride(), layout1.stride());
  return make_layout(shapes, strides);
}

// Concatenate three or more layouts, one mode per argument, in order.
template <class Shape0, class Stride0,
          class Shape1, class Stride1,
          class... Shapes, class... Strides>
CUTE_HOST_DEVICE constexpr
auto
make_layout(Layout<Shape0,Stride0> const& layout0,
            Layout<Shape1,Stride1> const& layout1,
            Layout<Shapes,Strides> const&... layouts)
{
  auto shapes  = make_shape (layout0.shape() , layout1.shape() , layouts.shape()... );
  auto strides = make_stride(layout0.stride(), layout1.stride(), layouts.stride()...);
  return make_layout(shapes, strides);
}
// Layout over `shape` with compact strides assigned in the mode order given
// by `order` (via compact_order).
template <class Shape, class Order>
CUTE_HOST_DEVICE constexpr
auto
make_ordered_layout(Shape const& shape, Order const& order)
{
  auto ordered_stride = compact_order(shape, order);
  return make_layout(shape, ordered_stride);
}

// Same shape as `layout`, with compact strides ordered by the original strides.
// The ordering shape passed to compact_order is filter_zeros(stride, shape);
// NOTE(review): presumably this neutralizes stride-0 modes — confirm in filter_zeros.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
make_layout_like(Layout<Shape,Stride> const& layout)
{
  auto ordering_shape = filter_zeros(layout.stride(), layout.shape());
  auto like_stride    = compact_order(ordering_shape, layout.stride());
  return make_layout(layout.shape(), like_stride);
}

// Fragment layout for `layout`:
//  - rank > 1 with a fully static shape: mode 0 gets compact LayoutLeft strides
//    (over filter_zeros of its stride/shape), the remaining modes are ordered
//    like their original strides, and the two are combined with tiled_product;
//  - otherwise: the compact LayoutLeft layout of the shape.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Layout<Shape,Stride> const& layout)
{
  constexpr int R = Layout<Shape,Stride>::rank;
  if constexpr (R > 1 && is_static<Shape>::value) {
    auto head_shape  = get<0>(layout.shape());
    auto head_stride = get<0>(layout.stride());
    auto head        = make_layout(head_shape,
                                   compact_major<LayoutLeft>(filter_zeros(head_stride, head_shape)));
    auto tail        = make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride()));
    return tiled_product(head, tail);
  } else {
    return make_layout(layout.shape());
  }
  CUTE_GCC_UNREACHABLE;
}

// Fragment layout for a bare shape: its compact LayoutLeft layout.
template <class Shape,
          __CUTE_REQUIRES(is_tuple<Shape>::value || is_integral<Shape>::value)>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Shape const& shape)
{
  return make_layout(shape);
}

// Layout over `shape` whose strides are make_basis_like(shape).
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_identity_layout(Shape const& shape)
{
  auto basis_stride = make_basis_like(shape);
  return make_layout(shape, basis_stride);
}
// get<Is...>(layout): the sub-layout built from the same (nested) mode of both
// the shape and the stride.
template <size_t... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
get(Layout<Shape,Stride> const& layout)
{
  auto sub_shape  = get<Is...>(layout.shape());
  auto sub_stride = get<Is...>(layout.stride());
  return make_layout(sub_shape, sub_stride);
}

// take<B,E>(layout): the layout formed by top-level modes [B, E).
template <int B, int E, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
take(Layout<Shape,Stride> const& layout)
{
  static_assert(B < E, "take: empty range error");
  static_assert(0 <= B && E <= Layout<Shape,Stride>::rank, "take: range out of bounds");
  auto range_shape  = take<B,E>(layout.shape());
  auto range_stride = take<B,E>(layout.stride());
  return make_layout(range_shape, range_stride);
}

// select<Is...>(layout): the layout formed by the listed top-level modes.
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
select(Layout<Shape,Stride> const& layout)
{
  auto picked_shape  = select<Is...>(layout.shape());
  auto picked_stride = select<Is...>(layout.stride());
  return make_layout(picked_shape, picked_stride);
}

// flatten(layout): flattens shape and stride in lockstep.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
flatten(Layout<Shape,Stride> const& layout)
{
  auto flat_shape  = flatten(layout.shape());
  auto flat_stride = flatten(layout.stride());
  return make_layout(flat_shape, flat_stride);
}

// unflatten(layout, profile): regroups shape and stride following target_profile.
template <class Shape, class Stride, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(Layout<Shape,Stride> const& layout, TargetProfile const& target_profile)
{
  auto nested_shape  = unflatten(layout.shape(),  target_profile);
  auto nested_stride = unflatten(layout.stride(), target_profile);
  return make_layout(nested_shape, nested_stride);
}
// layout<Is...>(l): with no indices, returns the argument itself. The return
// type is decltype(auto) and `return layout;` names the parameter, so the
// result is a `Layout<Shape,Stride> const&` to the caller's object. With
// indices, it returns get<Is...>(layout), the sub-layout (built by value) from
// the selected mode of both shape and stride.
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
decltype(auto)
layout(Layout<Shape,Stride> const& layout)
{
if constexpr (sizeof...(Is) == 0) {
return layout;
} else {
return get<Is...>(layout);
}
CUTE_GCC_UNREACHABLE;
}
// shape<Is...>(l): free-function form of l.shape<Is...>(). decltype(auto)
// forwards the member's result type unchanged. NOTE(review): presumably a
// reference into the layout's stored shape, mutable for this non-const
// overload — confirm against cute::get.
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
decltype(auto)
shape(Layout<Shape,Stride>& layout)
{
return layout.template shape<Is...>();
}
// Const overload of shape<Is...>(l).
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
decltype(auto)
shape(Layout<Shape,Stride> const& layout)
{
return layout.template shape<Is...>();
}
// stride<Is...>(l): free-function form of l.stride<Is...>(), forwarding the
// member's result type unchanged (non-const overload).
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
decltype(auto)
stride(Layout<Shape,Stride>& layout)
{
return layout.template stride<Is...>();
}
// Const overload of stride<Is...>(l).
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
decltype(auto)
stride(Layout<Shape,Stride> const& layout)
{
return layout.template stride<Is...>();
}
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
size(Layout<Shape,Stride> const& layout)
{
return size(shape<Is...>(layout));
}
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
rank(Layout<Shape,Stride> const& layout)
{
return rank(shape<Is...>(layout));
}
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
depth(Layout<Shape,Stride> const& layout)
{
return depth(shape<Is...>(layout));
}
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
coprofile(Layout<Shape,Stride> const& layout)
{
return repeat_like(as_arithmetic_tuple(sum(stride<Is...>(layout))), Int<0>{});
}
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
coshape(Layout<Shape,Stride> const& layout)
{
auto m1_shapes = transform_leaf( shape<Is...>(layout), [](auto s) { return s - Int<1>{}; });
auto abs_strides = transform_leaf(stride<Is...>(layout), abs_fn{});
auto co_coord = as_arithmetic_tuple(inner_product(m1_shapes, abs_strides));
return transform_leaf(co_coord, [](auto c) { return c + Int<1>{}; });
}
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
cosize(Layout<Shape,Stride> const& layout)
{
return size(coshape<Is...>(layout));
}
template <class Layout>
using cosize_t = decltype(cosize(declval<Layout>()));
template <class Layout>
static constexpr auto cosize_v = cosize_t<Layout>::value;
// Index of coordinate c under layout (crd2idx with the layout's shape/stride).
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
crd2idx(Coord const& c, Layout<Shape,Stride> const& layout)
{
  return crd2idx(c,
                 layout.shape(),
                 layout.stride());
}

// slice(c, layout): slices shape and stride in lockstep with coordinate c.
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
slice(Coord const& c, Layout<Shape,Stride> const& layout)
{
  auto kept_shape  = slice(c, layout.shape());
  auto kept_stride = slice(c, layout.stride());
  return make_layout(kept_shape, kept_stride);
}

// Returns the tuple (slice(c, layout), crd2idx(c, layout)).
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
slice_and_offset(Coord const& c, Layout<Shape,Stride> const& layout)
{
  auto sub_layout = slice(c, layout);
  auto offset     = crd2idx(c, layout);
  return cute::make_tuple(sub_layout, offset);
}

// dice(c, layout): dices shape and stride in lockstep with coordinate c.
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
dice(Coord const& c, Layout<Shape,Stride> const& layout)
{
  auto diced_shape  = dice(c, layout.shape());
  auto diced_stride = dice(c, layout.stride());
  return make_layout(diced_shape, diced_stride);
}

// Returns the tuple (layout, layout(coord)): the layout unchanged plus the
// offset produced by applying it to coord.
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
domain_offset(Coord const& coord, Layout<Shape,Stride> const& layout)
{
  auto offset = layout(coord);
  return cute::make_tuple(layout, offset);
}
namespace detail {

// Apply f to modes I... of t and gather the results into one Layout
// (each f(...) result is passed to make_layout as one mode).
template <class Tuple, class F, int... I>
CUTE_HOST_DEVICE constexpr
auto
transform_layout(Tuple const& t, F&& f, seq<I...>)
{
  return make_layout(f(get<I>(t))...);
}

// Apply f pairwise to modes I... of t0 and t1, then append modes I0... of t0
// and modes I1... of t1 unchanged, and gather everything into one Layout.
template <class Tuple0, class Tuple1, class F, int... I, int... I0, int... I1>
CUTE_HOST_DEVICE constexpr
auto
transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f, seq<I...>, seq<I0...>, seq<I1...>)
{
  return make_layout(f(get<I>(t0), get<I>(t1))...,
                     get<I0>(t0)...,
                     get<I1>(t1)...);
}

} // end namespace detail

// Map f over every top-level mode of t; results form the modes of a Layout.
template <class Tuple, class F>
CUTE_HOST_DEVICE constexpr
auto
transform_layout(Tuple const& t, F&& f)
{
  constexpr int num_modes = decltype(rank(t))::value;
  return detail::transform_layout(t, f, make_seq<num_modes>{});
}

// Map f over the modes t0 and t1 have in common (the first min(rank) modes),
// then append the remaining modes of whichever input is longer.
template <class Tuple0, class Tuple1, class F>
CUTE_HOST_DEVICE constexpr
auto
transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f)
{
  constexpr int rank0    = decltype(rank(t0))::value;
  constexpr int rank1    = decltype(rank(t1))::value;
  constexpr int n_common = rank0 < rank1 ? rank0 : rank1;
  return detail::transform_layout(t0, t1, f,
                                  make_seq<n_common>{},            // modes combined by f
                                  make_range<n_common, rank0>{},   // leftover modes of t0
                                  make_range<n_common, rank1>{});  // leftover modes of t1
}
namespace detail {
// Backward coalescing pass over a flattened (shape, stride) pair.
// Visits old modes from index I down to 0, accumulating the result in
// (new_shape, new_stride); the accumulator's front mode is the most recently
// kept mode. For each old mode:
//  - static size-1 modes are dropped;
//  - if the accumulator is still a single static 1, the old mode replaces it;
//  - if the accumulator's front shape is static and
//    old_shape[I] * old_stride[I] statically equals its front stride, the two
//    modes are merged (front shape multiplied, front stride replaced);
//  - otherwise the old mode is prepended as a new mode.
// At I == -1, an accumulator that is a static 1 becomes Layout<_1,_0>;
// anything else is returned as Layout<NewShape,NewStride>.
template <int I, class OldShape, class OldStride, class NewShape, class NewStride>
CUTE_HOST_DEVICE constexpr
auto
bw_coalesce(OldShape const& old_shape, OldStride const& old_stride,
NewShape const& new_shape, NewStride const& new_stride)
{
if constexpr (I == -1) {
// Base case: every old mode has been visited.
if constexpr (is_constant<1, NewShape>::value) {
return Layout<_1,_0>{};
} else {
return Layout<NewShape,NewStride>{new_shape,new_stride};
}
} else if constexpr (is_constant<1, decltype(get<I>(old_shape))>::value) {
// Static size-1 mode: skip it.
return bw_coalesce<I-1>(old_shape, old_stride, new_shape, new_stride);
} else if constexpr (is_constant<1, NewShape>::value) {
// Accumulator is only a size-1 placeholder: replace it with this mode.
return bw_coalesce<I-1>(old_shape, old_stride, get<I>(old_shape), get<I>(old_stride));
} else if constexpr (is_static<decltype(get<0>(new_shape))>::value &&
is_constant<true, decltype(get<I>(old_shape) * get<I>(old_stride) == get<0>(new_stride))>::value) {
// This mode ends exactly where the front mode starts: merge them.
return bw_coalesce<I-1>(old_shape, old_stride,
replace_front(new_shape, get<I>(old_shape) * get<0>(new_shape)),
replace_front(new_stride, get<I>(old_stride)));
} else {
// Cannot merge (or cannot prove it statically): keep as a separate mode.
return bw_coalesce<I-1>(old_shape, old_stride,
prepend(new_shape, get<I>(old_shape)),
prepend(new_stride, get<I>(old_stride)));
}
CUTE_GCC_UNREACHABLE;
}
// Variant of coalesce used by composition to prepare its left-hand side.
// It flattens the layout and runs bw_coalesce. When the last flattened mode has
// static size 1, it seeds the accumulator with shape Int<2> and that mode's
// stride instead of a size-1 seed, so bw_coalesce neither replaces nor drops
// the trailing stride. NOTE(review): the resulting last-mode shape (2) is only
// a placeholder. This appears safe because composition_impl does not read the
// last lhs mode's shape: it folds over the first R-1 modes and then uses only
// get<R-1>(lhs_stride). Confirm before using coalesce_x elsewhere.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
coalesce_x(Layout<Shape,Stride> const& layout)
{
auto flat_shape = flatten(layout.shape());
auto flat_stride = flatten(layout.stride());
constexpr int R = decltype(rank(flat_shape))::value;
if constexpr (is_constant<1, decltype(get<R-1>(flat_shape))>::value) {
return detail::bw_coalesce<R-2>(flat_shape, flat_stride, Int<2>{}, get<R-1>(flat_stride));
} else {
return detail::bw_coalesce<R-2>(flat_shape, flat_stride, get<R-1>(flat_shape), get<R-1>(flat_stride));
}
}
// coalesce_x guided by a target profile.
// - Tuple profile: each of the first tuple_size<IntTuple> modes of layout is
//   coalesced against the matching profile entry. Any further modes of layout
//   are appended unchanged by cute::transform_layout.
// - Non-tuple profile: the whole layout is coalesced.
template <class Shape, class Stride, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
coalesce_x(Layout<Shape,Stride> const& layout, IntTuple const& trg_profile)
{
if constexpr (is_tuple<IntTuple>::value) {
static_assert(tuple_size<IntTuple>::value <= Layout<Shape,Stride>::rank);
return cute::transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce_x(l,t); });
} else {
return coalesce_x(layout);
}
CUTE_GCC_UNREACHABLE;
}
}
// coalesce(layout): flattens the layout and runs the backward pass
// detail::bw_coalesce, seeded with the last flattened mode. Static size-1 modes
// are dropped. Adjacent modes are merged when their contiguity is provable at
// compile time.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Layout<Shape,Stride> const& layout)
{
  auto shp = flatten(layout.shape());
  auto str = flatten(layout.stride());
  constexpr int N = decltype(rank(shp))::value;

  auto seed_shape  = get<N-1>(shp);
  auto seed_stride = get<N-1>(str);
  return detail::bw_coalesce<N-2>(shp, str, seed_shape, seed_stride);
}

// coalesce(layout, profile):
// - tuple profile: coalesce each of the first tuple_size<IntTuple> modes
//   against its profile entry (remaining modes are kept as-is by transform_layout);
// - otherwise: coalesce the whole layout.
template <class Shape, class Stride, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Layout<Shape,Stride> const& layout, IntTuple const& trg_profile)
{
  if constexpr (not is_tuple<IntTuple>::value) {
    return coalesce(layout);
  } else {
    static_assert(tuple_size<IntTuple>::value <= Layout<Shape,Stride>::rank);
    return transform_layout(layout, trg_profile,
                            [](auto const& sub_layout, auto const& sub_profile) {
                              return coalesce(sub_layout, sub_profile);
                            });
  }
  CUTE_GCC_UNREACHABLE;
}

// coalesce(shape): left fold over the flattened modes of a shape. A mode is
// multiplied into the previous entry when both are static or both are
// dynamic; otherwise it starts a new entry.
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Shape const& shape)
{
  static_assert(is_integral<Shape>::value || is_tuple<Shape>::value);
  return cute::fold_first(flatten(shape), [](auto const& acc, auto const& mode) {
    constexpr bool same_kind = is_static<decltype(back(acc))>::value == is_static<decltype(mode)>::value;
    if constexpr (same_kind) {
      return replace_back(acc, back(acc) * mode);
    } else {
      return append(acc, mode);
    }
  });
}
// filter_zeros(layout): keeps the stride and replaces the shape with
// filter_zeros(stride, shape). NOTE(review): presumably the shape of
// stride-0 modes is neutralized — confirm in the IntTuple filter_zeros.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Layout<Shape,Stride> const& layout)
{
  auto masked_shape = filter_zeros(layout.stride(), layout.shape());
  return make_layout(masked_shape, layout.stride());
}

// filter_zeros(layout, profile): same, but the zero pattern comes from trg_profile.
template <class Shape, class Stride, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Layout<Shape,Stride> const& layout, IntTuple const& trg_profile)
{
  auto masked_shape = filter_zeros(trg_profile, layout.shape());
  return make_layout(masked_shape, layout.stride());
}

// filter(layout): filter_zeros followed by coalesce.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
filter(Layout<Shape,Stride> const& layout)
{
  auto zero_masked = filter_zeros(layout);
  return coalesce(zero_masked);
}

// filter(layout, profile): per-mode filter for a tuple profile (remaining modes
// kept as-is by transform_layout); whole-layout filter otherwise.
template <class Shape, class Stride, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
filter(Layout<Shape,Stride> const& layout, IntTuple const& trg_profile)
{
  if constexpr (not is_tuple<IntTuple>::value) {
    return filter(layout);
  } else {
    static_assert(tuple_size<IntTuple>::value <= Layout<Shape,Stride>::rank);
    return transform_layout(layout, trg_profile,
                            [](auto const& sub_layout, auto const& sub_profile) {
                              return filter(sub_layout, sub_profile);
                            });
  }
  CUTE_GCC_UNREACHABLE;
}
// append<N>(layout, x): append mode x to layout, padding to N modes
// (shape and stride handled in lockstep; x defaults to the trivial _1:_0).
template <int N, class ShapeA, class StrideA, class ShapeX = _1, class StrideX = _0>
CUTE_HOST_DEVICE constexpr
auto
append(Layout<ShapeA,StrideA> const& layout,
       Layout<ShapeX,StrideX> const& x = {})
{
  auto grown_shape  = append<N>(layout.shape(),  x.shape());
  auto grown_stride = append<N>(layout.stride(), x.stride());
  return make_layout(grown_shape, grown_stride);
}

// append(layout, x): append a single mode x (default _1:_0) to layout.
template <class ShapeA, class StrideA, class ShapeX = _1, class StrideX = _0>
CUTE_HOST_DEVICE constexpr
auto
append(Layout<ShapeA,StrideA> const& layout,
       Layout<ShapeX,StrideX> const& x = {})
{
  auto grown_shape  = append(layout.shape(),  x.shape());
  auto grown_stride = append(layout.stride(), x.stride());
  return make_layout(grown_shape, grown_stride);
}

// prepend<N>(layout, x): prepend mode x to layout, padding to N modes.
template <int N, class ShapeA, class StrideA, class ShapeX = _1, class StrideX = _0>
CUTE_HOST_DEVICE constexpr
auto
prepend(Layout<ShapeA,StrideA> const& layout,
        Layout<ShapeX,StrideX> const& x = {})
{
  auto grown_shape  = prepend<N>(layout.shape(),  x.shape());
  auto grown_stride = prepend<N>(layout.stride(), x.stride());
  return make_layout(grown_shape, grown_stride);
}

// prepend(layout, x): prepend a single mode x (default _1:_0) to layout.
template <class ShapeA, class StrideA, class ShapeX = _1, class StrideX = _0>
CUTE_HOST_DEVICE constexpr
auto
prepend(Layout<ShapeA,StrideA> const& layout,
        Layout<ShapeX,StrideX> const& x = {})
{
  auto grown_shape  = prepend(layout.shape(),  x.shape());
  auto grown_stride = prepend(layout.stride(), x.stride());
  return make_layout(grown_shape, grown_stride);
}

// replace<N>(layout, x): replace mode N of layout with x.
template <int N, class ShapeA, class StrideA, class ShapeX, class StrideX>
CUTE_HOST_DEVICE constexpr
auto
replace(Layout<ShapeA,StrideA> const& layout,
        Layout<ShapeX,StrideX> const& x)
{
  auto new_shape  = replace<N>(layout.shape(),  x.shape());
  auto new_stride = replace<N>(layout.stride(), x.stride());
  return make_layout(new_shape, new_stride);
}

// group<B,E>(layout): group modes [B, E) into a single nested mode.
template <int B, int E, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
group(Layout<Shape,Stride> const& layout)
{
  auto grouped_shape  = group<B,E>(layout.shape());
  auto grouped_stride = group<B,E>(layout.stride());
  return make_layout(grouped_shape, grouped_stride);
}
namespace detail {
// composition_impl: core of Layout composition. It builds the composed layout
// from the raw (shape, stride) of each side (NOTE(review): presumably
// result(c) == lhs(rhs(c))). In this file the lhs is first prepared by
// coalesce_x. Cases, tested in order:
//  1. rhs shape is a tuple: compose each rhs mode separately and gather the
//     results into one Layout (transform_layout).
//  2. rhs stride is a ScaledBasis: select the addressed sub-shape/sub-stride
//     of lhs with basis_get and recurse using basis_value(rhs_stride).
//  3. rhs stride is the static constant 0: the result is rhs itself.
//  4. lhs shape is integral: rhs shape with rhs stride scaled by lhs stride.
//  5. otherwise (lhs tuple, rhs integral): the fold below.
template <class LShape, class LStride,
class RShape, class RStride>
CUTE_HOST_DEVICE constexpr
auto
composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
RShape const& rhs_shape, RStride const& rhs_stride)
{
// Case 1: distribute over the modes of a tuple rhs.
if constexpr (is_tuple<RShape>::value) { return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) {
return composition_impl(lhs_shape, lhs_stride, s, d);
});
} else
// Case 2: ScaledBasis rhs stride.
if constexpr (is_scaled_basis<RStride>::value) { return composition_impl(basis_get(rhs_stride, lhs_shape), basis_get(rhs_stride, lhs_stride),
rhs_shape, basis_value(rhs_stride));
} else
// Case 3: static zero rhs stride.
if constexpr (is_constant<0, RStride>::value) { return Layout<RShape, RStride>{rhs_shape, rhs_stride};
} else
// Case 4: single-mode lhs.
if constexpr (is_integral<LShape>::value) { return Layout{rhs_shape, rhs_stride * lhs_stride};
} else { constexpr int R = tuple_size<LShape>::value;
// Case 5: fold over the first R-1 lhs modes. The state is
// (result_shape, result_stride, rest_shape, rest_stride), starting from
// ((), (), rhs_shape, rhs_stride). For each lhs mode:
//   next_shape  = ceil_div(curr_shape, |rest_stride|)
//   next_stride = ceil_div(|rest_stride|, curr_shape) * signum(rest_stride)
// If next_shape or rest_shape is statically 1, no result mode is emitted.
// Otherwise a mode of shape min(next_shape, rest_shape) and stride
// rest_stride * curr_stride is appended, and rest_shape is divided by that
// shape. In both cases rest_stride becomes next_stride.
// The last lhs mode is not folded over; only its stride is used below.
auto [result_shape, result_stride, rest_shape, rest_stride] =
cute::fold(make_seq<R-1>{}, cute::make_tuple(cute::tuple<>{}, cute::tuple<>{}, rhs_shape, rhs_stride), [&](auto const& init, auto curr_i) { auto result_shape = get<0>(init);
auto result_stride = get<1>(init);
auto rest_shape = get<2>(init);
auto rest_stride = get<3>(init);
auto curr_shape = get<curr_i>(lhs_shape);
auto curr_stride = get<curr_i>(lhs_stride);
// Stride divisibility is enforced only when both operands are static.
if constexpr (is_static<decltype(curr_shape)>::value and is_static<decltype(rest_stride)>::value) {
CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition");
} else {
// Dynamic operands: no check is performed.
}
[[maybe_unused]] auto next_shape = cute::ceil_div(curr_shape, abs(rest_stride));
[[maybe_unused]] auto next_stride = cute::ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride);
if constexpr (is_constant<1, decltype(next_shape)>::value or is_constant<1, decltype(rest_shape)>::value) {
// Nothing of rhs falls in this lhs mode: only advance the stride.
return cute::make_tuple(result_shape,
result_stride,
rest_shape,
next_stride);
} else {
auto new_shape = cute::min(next_shape, rest_shape);
// Shape divisibility is enforced only when both operands are static.
if constexpr (is_static<decltype(new_shape)>::value and is_static<decltype(rest_shape)>::value) {
CUTE_STATIC_ASSERT_V(((rest_shape % new_shape) == Int<0>{}), "Shape Divisibility Condition");
} else {
// Dynamic operands: no check is performed.
}
return cute::make_tuple(append(result_shape, new_shape),
append(result_stride, rest_stride * curr_stride),
rest_shape / new_shape,
next_stride);
}
});
// Assemble the result:
//  - no modes emitted: a single mode (rest_shape, rest_stride * last lhs stride);
//  - rest_shape statically 1: only the emitted modes (unwrapped);
//  - otherwise: the emitted modes plus the remaining rest mode.
if constexpr (tuple_size<decltype(result_shape)>::value == 0) {
return Layout{rest_shape, rest_stride * get<R-1>(lhs_stride)};
} else
if constexpr (is_constant<1, decltype(rest_shape)>::value) {
return Layout{unwrap(result_shape), unwrap(result_stride)};
} else {
return Layout{append(result_shape, rest_shape),
append(result_stride, rest_stride * get<R-1>(lhs_stride))};
}
}
CUTE_GCC_UNREACHABLE;
}
}
// composition(lhs, rhs) for two Layouts. lhs is first reshaped by
// detail::coalesce_x following rhs's coprofile, then handed with rhs to
// detail::composition_impl.
template <class LShape, class LStride,
          class RShape, class RStride>
CUTE_HOST_DEVICE constexpr
auto
composition(Layout<LShape,LStride> const& lhs,
            Layout<RShape,RStride> const& rhs)
{
  auto lhs_prepared = detail::coalesce_x(lhs, coprofile(rhs));
  return detail::composition_impl(lhs_prepared.shape(), lhs_prepared.stride(),
                                  rhs.shape(),          rhs.stride());
}

// composition(lhs, tiler) for a non-Layout tiler:
//  - tuple: compose mode-by-mode. Modes of lhs beyond the tiler's length are
//    dropped (empty leftover sequences);
//  - Underscore: lhs unchanged;
//  - integral n: compose the coalesce_x'd lhs with the layout n:1.
template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
composition(Layout<LShape,LStride> const& lhs,
            Tiler const& rhs)
{
  if constexpr (is_tuple<Tiler>::value) {
    static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank);
    constexpr int n_tiled = tuple_size<Tiler>::value;
    return detail::transform_layout(lhs, rhs,
                                    [](auto const& sub_lhs, auto const& sub_tiler) { return composition(sub_lhs, sub_tiler); },
                                    make_seq<n_tiled>{}, seq<>{}, seq<>{});
  } else if constexpr (is_underscore<Tiler>::value) {
    return lhs;
  } else if constexpr (is_integral<Tiler>::value) {
    auto lhs_prepared = detail::coalesce_x(lhs);
    return detail::composition_impl(lhs_prepared.shape(), lhs_prepared.stride(), rhs, Int<1>{});
  }
  CUTE_GCC_UNREACHABLE;
}
namespace detail {
// complement(shape, stride, cotarget): builds the layout of the "gaps" left by
// (shape, stride), extended so that it covers cotarget.
//  - A static zero stride gives the compact layout of coalesce(cotarget).
//  - Otherwise the strides must be static unless the rank is 1. Each fold step
//    takes the mode with the smallest remaining stride and emits a mode of
//    shape min_stride / (last emitted stride) and stride
//    min_stride * (that mode's shape). It then removes that mode; the emitted
//    strides start from (1).
//  - The final emitted shape is (last remaining stride) / (last emitted stride).
//    A "rest" part with compact LayoutLeft strides then starts at
//    (last stride * last shape) and covers ceil_div(cotarget, that stride).
//  - The result is coalesced. A statically zero emitted shape is rejected as
//    non-injective.
template <class Shape, class Stride, class CoTarget>
CUTE_HOST_DEVICE constexpr
auto
complement(Shape const& shape, Stride const& stride, CoTarget const& cotarget)
{
if constexpr (is_constant<0, Stride>::value) {
return make_layout(coalesce(cotarget));
} else {
constexpr int R = rank_v<Shape>;
static_assert(R == 1 || is_static<Stride>::value,
"Dynamic-stride complement only for rank-1 layouts");
auto [shape_, stride_, result_shape_, result_stride] =
fold(make_seq<R-1>{},
cute::make_tuple(shape, stride, cute::make_tuple(), cute::make_tuple(Int<1>{})),
[](auto const& init, auto i)
{
auto [shape, stride, result_shape, result_stride] = init;
// Mode with the smallest remaining stride.
auto min_stride = cute::min(stride);
auto min_idx = cute::find(stride, min_stride);
// Gap between the last emitted stride and this mode's stride.
auto new_shape = min_stride / get<i>(result_stride);
// The next gap starts just past this mode's extent.
auto new_stride = min_stride * get<min_idx>(shape);
static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement.");
return cute::make_tuple(remove<min_idx>(shape), remove<min_idx>(stride), append(result_shape , new_shape ), append(result_stride, new_stride)); });
// Gap before the single remaining mode.
auto new_shape = get<0>(stride_) / get<R-1>(result_stride); static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement.");
auto result_shape = append(result_shape_, new_shape);
// Rest part: compact modes starting past the last mode, covering cotarget.
auto new_stride = get<0>(stride_) * get<0>(shape_); auto rest_shape = coalesce(ceil_div(cotarget, new_stride));
auto rest_stride = compact_major<LayoutLeft>(rest_shape, new_stride);
return coalesce(make_layout(make_shape (result_shape , rest_shape ),
make_stride(result_stride, rest_stride)));
}
CUTE_GCC_UNREACHABLE;
}
}
// complement(layout, cotarget): filter the layout first, then take the
// complement against shape(cotarget).
template <class Shape, class Stride, class CoTarget>
CUTE_HOST_DEVICE constexpr
auto
complement(Layout<Shape,Stride> const& layout, CoTarget const& cotarget)
{
  auto reduced = filter(layout);
  auto target  = shape(cotarget);
  return detail::complement(reduced.shape(), reduced.stride(), target);
}

// complement(layout): as above, using the filtered layout's own cosize as cotarget.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
complement(Layout<Shape,Stride> const& layout)
{
  auto reduced = filter(layout);
  auto target  = cosize(reduced);
  return detail::complement(reduced.shape(), reduced.stride(), target);
}
// right_inverse(layout): builds an inverse from the modes that chain
// contiguously from stride 1.
//  - Coalesce the layout and wrap its shape/stride as tuples.
//  - preprod_shape = (1, s0, s0*s1, ...): the running product of the shape.
//  - Only modes with static strides are considered, visited in increasing
//    stride order (SortByKey).
//  - Starting from curr = 1, a mode whose stride statically equals curr
//    contributes shape s_i and stride preprod_shape[i]; curr then becomes
//    s_i * stride_i. Other modes are skipped.
//  - The collected (shape, stride), seeded with _1:_0, is coalesced.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
right_inverse(Layout<Shape,Stride> const& layout)
{
auto clayout = coalesce(layout);
auto lstride = wrap(clayout.stride());
auto lshape = wrap(clayout.shape());
auto preprod_shape = cute::fold(lshape, cute::tuple<_1>{}, [](auto c, auto vi) { return append(c, vi*back(c)); });
// Indices of the modes whose stride is static.
[[maybe_unused]] auto filtered_seq = filter_tuple(make_seq<rank(lstride)>{}, lstride, [](auto i, auto d) {
return conditional_return<is_static_v<decltype(d)>>(cute::tuple{i}, cute::tuple<>{}); });
[[maybe_unused]] auto filtered_stride = transform(filtered_seq, [&](auto i) { return get<i>(lstride); });
// Those indices, ordered by their stride values.
using Sorted = detail::SortByKey<decltype(filtered_stride), decltype(filtered_seq)>;
auto sorted_seq = typename Sorted::val_type{};
auto [result_shape, result_stride, curr] = cute::fold(sorted_seq, tuple<tuple<_1>,tuple<_0>,_1>{},
[&](auto const& init, auto i) {
[[maybe_unused]] auto ishape = get<i>(lshape);
[[maybe_unused]] auto istride = get<i>(lstride);
[[maybe_unused]] auto curr_stride = get<2>(init);
// Extend the chain only if this mode starts exactly at curr.
if constexpr (is_constant<decltype(istride)::value, decltype(curr_stride)>::value) {
return make_tuple(append(get<0>(init), ishape), append(get<1>(init), get<i>(preprod_shape)), ishape * istride);
} else {
return init;
}
});
return coalesce(make_layout(result_shape, result_stride));
}
// The Underscore wildcard is its own right inverse: returned as-is.
CUTE_HOST_DEVICE constexpr
auto
right_inverse(Underscore const& _)
{
return _;
}
// left_inverse(layout): requires static strides (after coalescing).
//  - preprod_shape = (1, s0, s0*s1, ...) as in right_inverse.
//  - Visit modes in increasing stride order, skipping static-zero strides.
//    Each visited mode i appends shape stride_i / size(result_shape_so_far),
//    statically checked for divisibility, and stride preprod_shape[i].
//  - Finally, the shape of the mode with the largest stride is appended to the
//    result shape, and the result is coalesced.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
left_inverse(Layout<Shape,Stride> const& layout)
{
auto clayout = coalesce(layout);
auto lstride = wrap(clayout.stride());
auto lshape = wrap(clayout.shape());
auto preprod_shape = cute::fold(lshape, cute::tuple<_1>{}, [](auto c, auto vi) { return append(c, vi*back(c)); });
static_assert(is_static<decltype(lstride)>::value, "Left inverse requires static strides.");
using Sorted = detail::SortByKey<decltype(lstride), tuple_seq<decltype(lstride)>>;
auto sorted_seq = typename Sorted::val_type{};
auto [result_shape, result_stride] = cute::fold(sorted_seq, tuple<tuple<>,tuple<_0>>{},
[&](auto const& init, auto i) {
[[maybe_unused]] auto istride = get<i>(lstride);
if constexpr (is_constant<0, decltype(istride)>::value) {
return init;
} else {
auto result_shape = get<0>(init);
auto result_stride = get<1>(init);
CUTE_STATIC_ASSERT_V((istride % size(result_shape)) == Int<0>{}, "Left inverse divisibility condition");
return make_tuple(append(result_shape, istride / size(result_shape)),
append(result_stride, get<i>(preprod_shape)));
}
});
return coalesce(make_layout(append(result_shape, get<back(sorted_seq)>(lshape)),
result_stride));
}
// The Underscore wildcard is its own left inverse: returned as-is.
CUTE_HOST_DEVICE constexpr
auto
left_inverse(Underscore const& _)
{
return _;
}
// max_common_layout(a, b): composes a with right_inverse(b) and coalesces.
// If the first mode of that result has a static shape and static stride 1,
// returns right_inverse(b) composed with that first mode; otherwise returns
// the trivial Layout<_1,_0>.
template <class ShapeA, class StrideA,
          class ShapeB, class StrideB>
CUTE_HOST_DEVICE constexpr
auto
max_common_layout(Layout<ShapeA,StrideA> const& a,
                  Layout<ShapeB,StrideB> const& b)
{
  Layout inv_b    = right_inverse(b);
  Layout a_over_b = coalesce(composition(a, inv_b));

  constexpr bool unit_stride_head = is_static<decltype(shape<0>(a_over_b))>::value &&
                                    is_constant<1, decltype(stride<0>(a_over_b))>::value;
  if constexpr (unit_stride_head) {
    return composition(inv_b, layout<0>(a_over_b));
  } else {
    return Layout<_1,_0>{};
  }
  CUTE_GCC_UNREACHABLE;
}

// max_common_vector(a, b): same test as max_common_layout, but returns only
// the (static) size of that first mode, or Int<1> when the test fails.
template <class ShapeA, class StrideA,
          class ShapeB, class StrideB>
CUTE_HOST_DEVICE constexpr
auto
max_common_vector(Layout<ShapeA,StrideA> const& a,
                  Layout<ShapeB,StrideB> const& b)
{
  Layout a_over_b = coalesce(composition(a, right_inverse(b)));

  constexpr bool unit_stride_head = is_static<decltype(shape<0>(a_over_b))>::value &&
                                    is_constant<1, decltype(stride<0>(a_over_b))>::value;
  if constexpr (unit_stride_head) {
    return shape<0>(a_over_b);
  } else {
    return Int<1>{};
  }
  CUTE_GCC_UNREACHABLE;
}
// domain_distribute(a, b): a must flatten to a static shape; b must be a static
// integral. Walks the flattened modes of a, giving each mode gcd(mode,
// remaining b) and dividing that out of the remaining b. The result has those
// gcd values as its shape and the compact LayoutLeft strides of a's flattened
// shape, and is then coalesced.
template <class ShapeA, class ShapeB>
CUTE_HOST_DEVICE constexpr
auto
domain_distribute(ShapeA const& a, ShapeB const& b)
{
  static_assert(is_integral<ShapeB>::value);
  static_assert(is_static<ShapeB>::value);

  auto a_modes = flatten(shape(a));
  static_assert(is_static<decltype(a_modes)>::value);

  [[maybe_unused]] auto [dist_shape, leftover] =
    cute::fold(a_modes, cute::make_tuple(cute::tuple<>{}, size(b)), [](auto acc, auto mode) {
      auto [kept, remaining] = acc;
      auto g = gcd(mode, remaining);
      return cute::make_tuple(append(kept, g), remaining / g);
    });

  auto dist_stride = compact_major<LayoutLeft>(a_modes);
  return coalesce(make_layout(dist_shape, dist_stride));
}
namespace detail {

// Compile-time scan of a flat stride tuple. Starting at mode NextI, appends to
// seq<Is...> the index of every mode whose stride is the static constant 0.
template <int NextI, class Stride, int... Is>
CUTE_HOST_DEVICE constexpr
auto
nullspace_seq(Stride const& stride, seq<Is...>)
{
  if constexpr (NextI == rank_v<Stride>) {
    return seq<Is...>{};
  } else {
    constexpr bool zero_stride = is_constant<0, decltype(get<NextI>(stride))>::value;
    if constexpr (zero_stride) {
      return detail::nullspace_seq<NextI+1>(stride, seq<Is..., NextI>{});
    } else {
      return detail::nullspace_seq<NextI+1>(stride, seq<Is...>{});
    }
  }
  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail

// nullspace(layout): flattens the layout and collects the modes with a static
// zero stride. Returns Layout<_1,_0> if there are none. Otherwise it returns a
// layout with those modes' shapes and the matching entries of the compact
// LayoutLeft stride of the flattened shape.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
nullspace(Layout<Shape,Stride> const& layout)
{
  auto flat       = flatten(layout);
  auto zero_modes = detail::nullspace_seq<0>(flat.stride(), seq<>{});

  if constexpr (zero_modes.size() == 0) {
    return Layout<_1,_0>{};
  } else {
    auto compact_stride = compact_major<LayoutLeft>(flat.shape());
    auto null_shape     = transform(zero_modes, [&](auto i) { return shape<i>(flat); });
    auto null_stride    = transform(zero_modes, [&](auto i) { return get<i>(compact_stride); });
    return make_layout(unwrap(null_shape), unwrap(null_stride));
  }
  CUTE_GCC_UNREACHABLE;
}
// zip(layout): zips shape and stride in lockstep.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
zip(Layout<Shape,Stride> const& layout)
{
  auto zipped_shape  = zip(layout.shape());
  auto zipped_stride = zip(layout.stride());
  return make_layout(zipped_shape, zipped_stride);
}

// zip(layoutA, layoutB): zips the two shapes and the two strides in lockstep.
template <class TShape, class TStride,
          class UShape, class UStride>
CUTE_HOST_DEVICE constexpr
auto
zip(Layout<TShape,TStride> const& layoutA,
    Layout<UShape,UStride> const& layoutB)
{
  auto zipped_shape  = zip(layoutA.shape(),  layoutB.shape());
  auto zipped_stride = zip(layoutA.stride(), layoutB.stride());
  return make_layout(zipped_shape, zipped_stride);
}

// tile_unzip(layout, tiler): regroups shape and stride with zip2_by, guided by tiler.
template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tile_unzip(Layout<LShape,LStride> const& layout,
           Tiler const& tiler)
{
  auto regrouped_shape  = zip2_by(layout.shape(),  tiler);
  auto regrouped_stride = zip2_by(layout.stride(), tiler);
  return make_layout(regrouped_shape, regrouped_stride);
}
template <class LShape, class LStride,
class TShape, class TStride>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(Layout<LShape,LStride> const& layout,
Layout<TShape,TStride> const& tiler)
{
return composition(layout, make_layout(tiler, complement(tiler, shape(coalesce(layout)))));
}
// Generic-tiler division. Dispatches on the kind of `tiler`:
//   * tuple      -> divide mode-by-mode via transform_layout (the tiler may not
//                   have more modes than the layout);
//   * underscore -> leave the layout unchanged;
//   * integer N  -> divide by the layout make_layout(N).
template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(Layout<LShape,LStride> const& layout,
               Tiler                  const& tiler)
{
  if constexpr (is_tuple<Tiler>::value) {
    static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank, "logical_divide: Too many modes in tiler.");
    auto divide_mode = [](auto const& l, auto const& t) { return logical_divide(l,t); };
    return transform_layout(layout, tiler, divide_mode);
  } else if constexpr (is_underscore<Tiler>::value) {
    return layout;
  } else if constexpr (is_integral<Tiler>::value) {
    auto tile_layout = make_layout(tiler);
    return logical_divide(layout, tile_layout);
  }
  CUTE_GCC_UNREACHABLE;
}
// Returns the shape of complement(tiler, shape(target)). This is the shape of
// the "rest" mode that a layout-by-layout logical_divide builds when the
// divided (coalesced) shape equals shape(target).
template <class Target, class TShape, class TStride>
CUTE_HOST_DEVICE constexpr
auto
ceil_div(Target                 const& target,
         Layout<TShape,TStride> const& tiler)
{
  auto rest = complement(tiler, shape(target));
  return shape(rest);
}
// logical_divide followed by tile_unzip. The per-mode results of the division
// are regrouped by zip2_by, guided by `tiler`.
template <class LShape, class LStride,
          class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_divide(Layout<LShape,LStride> const& layout,
              Tiler                  const& tiler)
{
  auto divided = logical_divide(layout, tiler);
  return tile_unzip(divided, tiler);
}
// zipped_divide, then the result is sliced with (_, (_,...,_)). The _ tuple
// has as many entries as mode 1 has sub-modes. Mode 0 is kept as one mode, and
// the sub-modes of mode 1 are unpacked into separate top-level modes.
template <class LShape, class LStride,
          class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tiled_divide(Layout<LShape,LStride> const& layout,
             Tiler                  const& tiler)
{
  auto zipped    = zipped_divide(layout, tiler);
  auto rest_rank = rank<1>(zipped);
  return zipped(_, repeat<rest_rank>(_));
}
// zipped_divide, then the result is sliced with ((_,...,_), (_,...,_)), one _
// per sub-mode of mode 0 and of mode 1. The sub-modes of both top-level modes
// are unpacked into separate top-level modes.
template <class LShape, class LStride,
          class Tiler>
CUTE_HOST_DEVICE constexpr
auto
flat_divide(Layout<LShape,LStride> const& layout,
            Tiler                  const& tiler)
{
  auto zipped    = zipped_divide(layout, tiler);
  auto tile_rank = rank<0>(zipped);
  auto rest_rank = rank<1>(zipped);
  return zipped(repeat<tile_rank>(_), repeat<rest_rank>(_));
}
// Layout-by-layout product. Returns the rank-2 layout (block, repeats):
//   * mode 0 is `block` itself;
//   * mode 1 is the complement of `block` within size(block)*cosize(tiler),
//     composed with `tiler`.
template <class LShape, class LStride,
          class TShape, class TStride>
CUTE_HOST_DEVICE constexpr
auto
logical_product(Layout<LShape,LStride> const& block,
                Layout<TShape,TStride> const& tiler)
{
  auto codomain_extent = size(block) * cosize(tiler);
  auto repeats         = composition(complement(block, codomain_extent), tiler);
  return make_layout(block, repeats);
}
// Generic-tiler product. Dispatches on the kind of `tiler`:
//   * tuple      -> product mode-by-mode via transform_layout (the tiler may not
//                   have more modes than the block);
//   * underscore -> leave the block unchanged;
//   * integer N  -> product with the layout make_layout(N).
template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_product(Layout<LShape,LStride> const& block,
                Tiler                  const& tiler)
{
  if constexpr (is_tuple<Tiler>::value) {
    static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank, "logical_product: Too many modes in tiler.");
    auto product_mode = [](auto const& l, auto const& t) { return logical_product(l,t); };
    return transform_layout(block, tiler, product_mode);
  } else if constexpr (is_underscore<Tiler>::value) {
    return block;
  } else if constexpr (is_integral<Tiler>::value) {
    auto tile_layout = make_layout(tiler);
    return logical_product(block, tile_layout);
  }
  CUTE_GCC_UNREACHABLE;
}
// logical_product followed by tile_unzip. The per-mode results of the product
// are regrouped by zip2_by, guided by `tiler`.
template <class LShape, class LStride,
          class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_product(Layout<LShape,LStride> const& block,
               Tiler                  const& tiler)
{
  auto product = logical_product(block, tiler);
  return tile_unzip(product, tiler);
}
// zipped_product, then the result is sliced with (_, (_,...,_)). Mode 0 is
// kept as one mode, and the sub-modes of mode 1 are unpacked into separate
// top-level modes.
template <class LShape, class LStride,
          class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tiled_product(Layout<LShape,LStride> const& block,
              Tiler                  const& tiler)
{
  auto zipped    = zipped_product(block, tiler);
  auto rest_rank = rank<1>(zipped);
  return zipped(_, repeat<rest_rank>(_));
}
// zipped_product, then the result is sliced with ((_,...,_), (_,...,_)). The
// sub-modes of both top-level modes are unpacked into separate top-level modes.
template <class LShape, class LStride,
          class Tiler>
CUTE_HOST_DEVICE constexpr
auto
flat_product(Layout<LShape,LStride> const& block,
             Tiler                  const& tiler)
{
  auto zipped     = zipped_product(block, tiler);
  auto block_rank = rank<0>(zipped);
  auto rest_rank  = rank<1>(zipped);
  return zipped(repeat<block_rank>(_), repeat<rest_rank>(_));
}
// Both operands are padded to the same rank, then combined with
// logical_product. The two halves of the result are zipped so that each output
// mode i pairs mode i of the block (first) with mode i of the repeated tiler
// (second).
template <class TShape, class TStride,
          class UShape, class UStride>
CUTE_HOST_DEVICE constexpr
auto
blocked_product(Layout<TShape,TStride> const& block,
                Layout<UShape,UStride> const& tiler)
{
  constexpr int common_rank = cute::max(rank_v<TShape>, rank_v<UShape>);
  auto padded_block = append<common_rank>(block);
  auto padded_tiler = append<common_rank>(tiler);

  auto product = logical_product(padded_block, padded_tiler);
  return zip(get<0>(product), get<1>(product));
}
// Same as blocked_product, except the zip order is reversed: each output mode i
// pairs mode i of the repeated tiler (first) with mode i of the block (second).
template <class TShape, class TStride,
          class UShape, class UStride>
CUTE_HOST_DEVICE constexpr
auto
raked_product(Layout<TShape,TStride> const& block,
              Layout<UShape,UStride> const& tiler)
{
  constexpr int common_rank = cute::max(rank_v<TShape>, rank_v<UShape>);
  auto padded_block = append<common_rank>(block);
  auto padded_tiler = append<common_rank>(tiler);

  auto product = logical_product(padded_block, padded_tiler);
  return zip(get<1>(product), get<0>(product));
}
// Repeats `block` (via blocked_product) to cover `trg_shape`. The number of
// repetitions in each mode is ceil_div(target extent, block extent). The
// repetitions are arranged by make_ordered_layout in the order `ord_shape`.
// When the target shape is static, also statically asserts that the block
// evenly divides it.
template <class Shape, class Stride,
          class TrgShape, class ModeOrder = LayoutLeft>
CUTE_HOST_DEVICE constexpr
auto
tile_to_shape(Layout<Shape,Stride> const& block,
              TrgShape             const& trg_shape,
              ModeOrder            const& ord_shape = {})
{
  CUTE_STATIC_ASSERT_V(rank(block) <= rank(trg_shape), "Rank of layout must be <= rank of target shape.");
  constexpr int R = rank_v<TrgShape>;

  auto padded_block = append<R>(block);
  // Collapse every mode to its total extent so block and target can be compared mode-wise.
  auto target_shape = product_each(shape(trg_shape));
  auto block_shape  = product_each(shape(padded_block));

  if constexpr (is_static<decltype(target_shape)>::value) {
    CUTE_STATIC_ASSERT_V(evenly_divides(target_shape, block_shape),
                         "tile_to_shape: block shape does not divide the target shape.");
  }

  auto reps = make_ordered_layout(ceil_div(target_shape, block_shape), ord_shape);
  return blocked_product(padded_block, reps);
}
// Re-expresses a (shape, stride) pair in units N times wider.
//   * tuple shape       -> recurse mode-wise via transform_layout;
//   * static stride 0   -> returned unchanged;
//   * other static d    -> shape ceil-divided by ceil_div(N,|d|);
//                          stride becomes signum(d)*ceil_div(|d|,N).
//                          d must divide N or be divisible by N (static_assert);
//   * dynamic stride    -> stride becomes safe_div(d, N).
// NOTE(review): the dynamic branch assumes d is a multiple of N.
template <int N, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
upcast(Shape const& shape, Stride const& stride)
{
  if constexpr (is_tuple<Shape>::value) {
    auto upcast_mode = [](auto const& s, auto const& d) { return upcast<N>(s,d); };
    return transform_layout(shape, stride, upcast_mode);
  } else if constexpr (is_constant<0, Stride>::value) {
    return Layout<Shape,Stride>{shape,stride};
  } else if constexpr (is_static<Stride>::value) {
    static_assert(Stride::value % N == 0 or N % Stride::value == 0, "Divisibility condition");
    auto magnitude  = abs(stride);
    auto new_shape  = ceil_div(shape, ceil_div(Int<N>{}, magnitude));
    auto new_stride = signum(stride) * ceil_div(magnitude, Int<N>{});
    return make_layout(new_shape, new_stride);
  } else {
    return make_layout(shape, safe_div(stride, Int<N>{}));
  }
  CUTE_GCC_UNREACHABLE;
}
// Layout overload of upcast: forwards the layout's shape and stride to
// upcast<N>(shape, stride).
template <int N, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
upcast(Layout<Shape,Stride> const& layout)
{
  auto const& s = layout.shape();
  auto const& d = layout.stride();
  return upcast<N>(s, d);
}
// Re-expresses a (shape, stride) pair in units N times narrower.
//   * tuple shape          -> recurse mode-wise via transform_layout;
//   * static stride +1/-1  -> extent multiplied by N, stride unchanged;
//   * any other stride     -> extent unchanged, stride multiplied by N.
template <int N, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
downcast(Shape const& shape, Stride const& stride)
{
  constexpr bool unit_stride = is_constant< 1, Stride>::value
                            || is_constant<-1, Stride>::value;
  if constexpr (is_tuple<Shape>::value) {
    auto downcast_mode = [](auto const& s, auto const& d) { return downcast<N>(s,d); };
    return transform_layout(shape, stride, downcast_mode);
  } else if constexpr (unit_stride) {
    return make_layout(shape * Int<N>{}, stride);
  } else {
    return make_layout(shape, stride * Int<N>{});
  }
  CUTE_GCC_UNREACHABLE;
}
// Layout overload of downcast. Requires (statically, via has_int1) that the
// stride contains a static 1 somewhere. Then forwards to
// downcast<N>(shape, stride).
template <int N, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
downcast(Layout<Shape,Stride> const& layout)
{
  CUTE_STATIC_ASSERT(has_int1<Stride>::value, "Downcast requires adjacent elements");
  auto const& s = layout.shape();
  auto const& d = layout.stride();
  return downcast<N>(s, d);
}
// Converts a layout indexed in OldType elements into one indexed in NewType
// elements. Let scale = sizeof_bits<NewType> / sizeof_bits<OldType> as a
// reduced ratio num/den:
//   * num == 1 and den == 1 -> returned unchanged;
//   * den == 1 only         -> upcast<num>;
//   * num == 1 only         -> downcast<den>;
//   * otherwise             -> upcast<num>, then downcast<den>.
template <class OldType, class NewType,
          class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
recast_layout(Layout<Shape,Stride> const& layout)
{
  using scale = decltype(trait_ratio(sizeof_bits<NewType>{}, sizeof_bits<OldType>{}));
  if constexpr (scale::den == 1) {
    if constexpr (scale::num == 1) {
      return layout;
    } else {
      return upcast<scale::num>(layout);
    }
  } else if constexpr (scale::num == 1) {
    return downcast<scale::den>(layout);
  } else {
    auto widened = upcast<scale::num>(layout);
    return downcast<scale::den>(widened);
  }
  CUTE_GCC_UNREACHABLE;
}
// Computes the layout's maximum alignment in four steps:
//   1. Coalesce the layout.
//   2. Keep only the statically known parts: dynamic extents become 1 and
//      dynamic strides become 0.
//   3. Divide the resulting layout by its own right_inverse.
//   4. Return gcd(size of mode 0, stride of mode 1) of that division.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
max_alignment(Layout<Shape,Stride> const& layout)
{
  auto coalesced = coalesce(layout);

  auto keep_static_extent = [](auto s){ return conditional_return<is_static<decltype(s)>::value>(s, Int<1>{}); };
  auto keep_static_stride = [](auto d){ return conditional_return<is_static<decltype(d)>::value>(d, Int<0>{}); };
  auto static_part = make_layout(transform( shape(coalesced), keep_static_extent),
                                 transform(stride(coalesced), keep_static_stride));

  auto divided = logical_divide(static_part, right_inverse(static_part));
  return gcd(size<0>(divided), stride<1>(divided));
}
// Prints the layout as "shape:stride".
template <class Shape, class Stride>
CUTE_HOST_DEVICE void print(Layout<Shape,Stride> const& layout)
{
  print(layout.shape());
  print(":");
  print(layout.stride());
}
#if !defined(__CUDACC_RTC__)
// Host-only stream insertion (compiled out under NVRTC). Writes the layout as
// "shape:stride" and returns the stream for chaining.
template <class Shape, class Stride>
CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout<Shape,Stride> const& layout)
{
  os << shape(layout) << ":";
  return os << stride(layout);
}
#endif
// Prints a rank-2 layout (m,n) -> idx as an ASCII table.
// Output is: the layout itself, a row of column indices, then one bordered
// row per m with its row index on the left.
// Every cell is idx_width+1 characters wide. idx_width is sized from
// num_digits(cosize(layout)) so every value fits.
// All rows share a 4-character left margin: "    " on header and border rows,
// "%2d  " on value rows. Tables with 100 or more rows will shift the value rows.
template <class Layout>
CUTE_HOST_DEVICE
void
print_layout(Layout const& layout)
{
  CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});

  int idx_width = num_digits(cosize(layout)) + 2;
  const char* delim = "+-----------------------";

  print(layout); print("\n");

  // Column indices. "  %*d " is idx_width+1 chars, so each index sits exactly
  // above the number in the "| %*d " value cells below.
  print("    ");
  for (int n = 0; n < size<1>(layout); ++n) { printf("  %*d ", idx_width-2, n); }
  printf("\n");

  for (int m = 0; m < size<0>(layout); ++m) {
    // Border: idx_width+1 chars of delim per cell, closed by a trailing '+'.
    print("    ");
    for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
    printf("+\n");
    // Values: 4-char row label, then "| value " cells, closed by a trailing '|'.
    printf("%2d  ", m);
    for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); }
    printf("|\n");
  }

  // Closing border.
  print("    ");
  for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); }
  printf("+\n");
}
// Prints a rank-2 layout as an ASCII grid of thread/value pairs. Each
// layout(m,n) is split as:
//   thread = thrid(layout(m,n) % size(thrid))
//   value  = layout(m,n) / size(thrid)
// Each cell is printed as "|TTT-VV". The layout and thrid are printed first.
template <class Layout, class ThrID>
CUTE_HOST_DEVICE
void
print_layout(Layout const& layout, ThrID const& thrid)
{
  CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});

  print(layout); print("\n");
  print(thrid);  print("\n");

  // Border row: one 7-char "+------" per column, matching the 7-char value cells.
  auto print_border = [&]() {
    for (int n = 0; n < size<1>(layout); ++n) { printf("+------"); }
    printf("+\n");
  };

  for (int m = 0; m < size<0>(layout); ++m) {
    print_border();
    for (int n = 0; n < size<1>(layout); ++n) {
      auto const tv = layout(m,n);
      printf("|%03d-%02d", int(thrid(tv % size(thrid))), int(tv / size(thrid)));
    }
    printf("|\n");
  }
  print_border();
}
// TikZ color functor that fills every cell with "white", whatever the index.
struct TikzColor_White {
  CUTE_HOST_DEVICE char const*
  operator()(int /*idx*/) const {
    return "white";
  }
};
// TikZ color functor that maps an index onto one of 8 gray shades, cycling
// with period 8.
// The shade is chosen by the non-negative residue of idx mod 8. A negative
// idx (possible for layouts with negative strides) therefore still lands
// inside color_map instead of reading before it.
struct TikzColor_BWx8 {
  CUTE_HOST_DEVICE char const*
  operator()(int idx) const {
    static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60",
                                       "black!10", "black!50", "black!30", "black!70"};
    // In C++, % keeps the dividend's sign. Fold negative remainders into [0,8).
    int slot = idx % 8;
    if (slot < 0) { slot += 8; }
    return color_map[slot];
  }
};
// TikZ color functor for thread/value plots. The color depends only on the
// thread id: one of 8 pastel RGB colors, cycling with period 8. The value id
// is accepted but unused.
// The non-negative residue of tid mod 8 is used, so a negative tid cannot
// index outside color_map.
struct TikzColor_TV {
  CUTE_HOST_DEVICE char const*
  operator()(int tid, int vid) const {
    static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
                                       "{rgb,255:red,175;green,255;blue,175}",
                                       "{rgb,255:red,255;green,255;blue,175}",
                                       "{rgb,255:red,255;green,175;blue,175}",
                                       "{rgb,255:red,210;green,210;blue,255}",
                                       "{rgb,255:red,210;green,255;blue,210}",
                                       "{rgb,255:red,255;green,255;blue,210}",
                                       "{rgb,255:red,255;green,210;blue,210}"};
    (void)vid;  // color is per-thread only
    // In C++, % keeps the dividend's sign. Fold negative remainders into [0,8).
    int slot = tid % 8;
    if (slot < 0) { slot += 8; }
    return color_map[slot];
  }
};
// Emits a standalone LaTeX/TikZ document that draws a layout of rank <= 2 as
// a grid.
//   * A rank-1 layout is promoted to rank 2 by appending a (1):(0) mode.
//   * Each cell (i,j) shows layout(i,j), filled with color(layout(i,j)).
//   * Row indices are drawn at column -1 and column indices at row -1.
//   * The layout is also recorded in a leading "%" comment line.
template <class LayoutA, class TikzColorFn = TikzColor_BWx8>
CUTE_HOST_DEVICE
void
print_latex(LayoutA const& layout_a, TikzColorFn color = {})
{
  CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{});
  auto layout = append<2>(layout_a, Layout<_1,_0>{});
  auto const rows = size<0>(layout);
  auto const cols = size<1>(layout);

  printf("%% Layout: "); print(layout); printf("\n");
  printf("\\documentclass[convert]{standalone}\n"
         "\\usepackage{tikz}\n\n"
         "\\begin{document}\n"
         "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");

  // Cell contents.
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      int value = layout(r,c);
      printf("\\node[fill=%s] at (%d,%d) {%d};\n",
             color(value), r, c, value);
    }
  }

  // Cell borders.
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n",
         int(rows), int(cols));

  // Row labels (column -1), then column labels (row -1).
  for (int r = 0; r < rows; ++r) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", r, -1, r);
  }
  for (int c = 0; c < cols; ++c) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", -1, c, c);
  }

  printf("\\end{tikzpicture}\n"
         "\\end{document}\n");
}
// Emits a standalone LaTeX/TikZ document that draws a rank-2 layout as a grid
// of thread/value pairs.
//   * Each layout(i,j) is split into a thread slot layout(i,j) % size(thr),
//     mapped through thr, and a value index layout(i,j) / size(thr).
//   * Each cell shows "T<thread> V<value>", filled with color(thread, value).
//   * Row indices are drawn at column -1 and column indices at row -1.
//   * The layout and thr are recorded in leading "%" comment lines.
template <class Layout, class ThrID, class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex(Layout const& layout, ThrID const& thr, TikzColorFn color = {})
{
  CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
  auto const rows     = size<0>(layout);
  auto const cols     = size<1>(layout);
  auto const thr_size = size(thr);

  printf("%% Layout: "); print(layout); printf("\n");
  printf("%% ThrID : "); print(thr);    printf("\n");
  printf("\\documentclass[convert]{standalone}\n"
         "\\usepackage{tikz}\n\n"
         "\\begin{document}\n"
         "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");

  // Cell contents.
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      auto const tv   = layout(r,c);
      int thrid       = tv % thr_size;
      int val_idx     = tv / thr_size;
      int thr_idx     = thr(thrid);
      printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
             color(thr_idx, val_idx),
             r, c,
             thr_idx, val_idx);
    }
  }

  // Cell borders.
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n",
         int(rows), int(cols));

  // Row labels (column -1), then column labels (row -1).
  for (int r = 0; r < rows; ++r) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", r, -1, r);
  }
  for (int c = 0; c < cols; ++c) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", -1, c, c);
  }

  printf("\\end{tikzpicture}\n"
         "\\end{document}\n");
}
}