cubecl-cpp 0.10.0-pre.3

CPP transpiler for CubeCL
Documentation
use std::fmt::Display;

use crate::shared::{Component, FmtLeft};

use super::{Dialect, IndexedVariable, Item, Variable};

#[derive(Clone, Debug)]
pub enum WarpInstruction<D: Dialect> {
    ReduceSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    InclusiveSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    ExclusiveSum {
        input: Variable<D>,
        out: Variable<D>,
    },
    ReduceProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    InclusiveProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    ExclusiveProd {
        input: Variable<D>,
        out: Variable<D>,
    },
    ReduceMax {
        input: Variable<D>,
        out: Variable<D>,
    },
    ReduceMin {
        input: Variable<D>,
        out: Variable<D>,
    },
    ElectFallback {
        out: Variable<D>,
    },
    Elect {
        out: Variable<D>,
    },
    All {
        input: Variable<D>,
        out: Variable<D>,
    },
    Any {
        input: Variable<D>,
        out: Variable<D>,
    },
    Ballot {
        input: Variable<D>,
        out: Variable<D>,
    },
    Broadcast {
        input: Variable<D>,
        id: Variable<D>,
        out: Variable<D>,
    },
    Shuffle {
        input: Variable<D>,
        src_lane: Variable<D>,
        out: Variable<D>,
    },
    ShuffleXor {
        input: Variable<D>,
        mask: Variable<D>,
        out: Variable<D>,
    },
    ShuffleUp {
        input: Variable<D>,
        delta: Variable<D>,
        out: Variable<D>,
    },
    ShuffleDown {
        input: Variable<D>,
        delta: Variable<D>,
        out: Variable<D>,
    },
}

impl<D: Dialect> Display for WarpInstruction<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WarpInstruction::ReduceSum { input, out } => D::warp_reduce_sum(f, input, out),
            WarpInstruction::ReduceProd { input, out } => D::warp_reduce_prod(f, input, out),
            WarpInstruction::ReduceMax { input, out } => D::warp_reduce_max(f, input, out),
            WarpInstruction::ReduceMin { input, out } => D::warp_reduce_min(f, input, out),
            WarpInstruction::All { input, out } => D::warp_reduce_all(f, input, out),
            WarpInstruction::Any { input, out } => D::warp_reduce_any(f, input, out),

            WarpInstruction::InclusiveSum { input, out } => {
                D::warp_reduce_sum_inclusive(f, input, out)
            }
            WarpInstruction::InclusiveProd { input, out } => {
                D::warp_reduce_prod_inclusive(f, input, out)
            }
            WarpInstruction::ExclusiveSum { input, out } => {
                D::warp_reduce_sum_exclusive(f, input, out)
            }
            WarpInstruction::ExclusiveProd { input, out } => {
                D::warp_reduce_prod_exclusive(f, input, out)
            }
            WarpInstruction::Ballot { input, out } => {
                assert_eq!(
                    input.item().vectorization,
                    1,
                    "Ballot can't support vectorized input"
                );
                let out_fmt = out.fmt_left();
                write!(
                    f,
                    "
{out_fmt} = {{ "
                )?;
                D::compile_warp_ballot(f, input, out.item().elem())?;
                writeln!(f, ", 0, 0, 0 }};")
            }
            WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id),
            WarpInstruction::Shuffle {
                input,
                src_lane,
                out,
            } => {
                let out_fmt = out.fmt_left();
                write!(f, "{out_fmt} = {{ ")?;
                for i in 0..input.item().vectorization {
                    let comma = if i > 0 { ", " } else { "" };
                    write!(f, "{comma}")?;
                    D::compile_warp_shuffle(
                        f,
                        &format!("{}", input.index(i)),
                        &format!("{src_lane}"),
                    )?;
                }
                writeln!(f, " }};")
            }
            WarpInstruction::ShuffleXor { input, mask, out } => {
                let out_fmt = out.fmt_left();
                write!(f, "{out_fmt} = {{ ")?;
                for i in 0..input.item().vectorization {
                    let comma = if i > 0 { ", " } else { "" };
                    write!(f, "{comma}")?;
                    D::compile_warp_shuffle_xor(
                        f,
                        &format!("{}", input.index(i)),
                        input.item().elem(),
                        &format!("{mask}"),
                    )?;
                }
                writeln!(f, " }};")
            }
            WarpInstruction::ShuffleUp { input, delta, out } => {
                let out_fmt = out.fmt_left();
                write!(f, "{out_fmt} = {{ ")?;
                for i in 0..input.item().vectorization {
                    let comma = if i > 0 { ", " } else { "" };
                    write!(f, "{comma}")?;
                    D::compile_warp_shuffle_up(
                        f,
                        &format!("{}", input.index(i)),
                        &format!("{delta}"),
                    )?;
                }
                writeln!(f, " }};")
            }
            WarpInstruction::ShuffleDown { input, delta, out } => {
                let out_fmt = out.fmt_left();
                write!(f, "{out_fmt} = {{ ")?;
                for i in 0..input.item().vectorization {
                    let comma = if i > 0 { ", " } else { "" };
                    write!(f, "{comma}")?;
                    D::compile_warp_shuffle_down(
                        f,
                        &format!("{}", input.index(i)),
                        &format!("{delta}"),
                    )?;
                }
                writeln!(f, " }};")
            }
            WarpInstruction::ElectFallback { out } => {
                let out = out.fmt_left();
                write!(
                    f,
                    "
unsigned int mask = __activemask();
unsigned int leader = __ffs(mask) - 1;
{out} = threadIdx.x % warpSize == leader;
            "
                )
            }
            WarpInstruction::Elect { out } => {
                let out = out.fmt_left();
                D::compile_warp_elect(f, &out)
            }
        }
    }
}

pub(crate) fn reduce_operator<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    input: &Variable<D>,
    out: &Variable<D>,
    op: &str,
) -> core::fmt::Result {
    let in_optimized = input.optimized();
    let acc_item = in_optimized.item();

    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
        let acc_indexed = maybe_index(acc, index);
        write!(f, "{acc_indexed} {op} ")?;
        D::compile_warp_shuffle_xor(f, &acc_indexed, acc.item().elem(), "offset")?;
        writeln!(f, ";")
    })
}

pub(crate) fn reduce_comparison<
    D: Dialect,
    I: Fn(&mut core::fmt::Formatter<'_>, Item<D>) -> std::fmt::Result,
>(
    f: &mut core::fmt::Formatter<'_>,
    input: &Variable<D>,
    out: &Variable<D>,
    instruction: I,
) -> core::fmt::Result {
    let in_optimized = input.optimized();
    let acc_item = in_optimized.item();
    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
        let acc_indexed = maybe_index(acc, index);
        let acc_elem = acc_item.elem();
        write!(f, "        {acc_indexed} = ")?;
        instruction(f, in_optimized.item())?;
        write!(f, "({acc_indexed}, ")?;
        D::compile_warp_shuffle_xor(f, &acc_indexed, acc_elem, "offset")?;
        writeln!(f, ");")
    })
}

pub(crate) fn reduce_inclusive<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    input: &Variable<D>,
    out: &Variable<D>,
    op: &str,
) -> core::fmt::Result {
    let in_optimized = input.optimized();
    let acc_item = in_optimized.item();

    reduce_with_loop(f, input, out, acc_item, |f, acc, index| {
        let acc_indexed = maybe_index(acc, index);
        let tmp = Variable::tmp(Item::scalar(acc_item.elem, false));
        let tmp_left = tmp.fmt_left();
        let lane_id = Variable::<D>::UnitPosPlane;
        write!(
            f,
            "
{tmp_left} = "
        )?;
        D::compile_warp_shuffle_up(f, &acc_indexed, "offset")?;
        write!(
            f,
            ";
if({lane_id} >= offset) {{
    {acc_indexed} {op} {tmp};
}}
"
        )
    })
}

pub(crate) fn reduce_exclusive<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    input: &Variable<D>,
    out: &Variable<D>,
    op: &str,
    default: &str,
) -> core::fmt::Result {
    let in_optimized = input.optimized();
    let acc_item = in_optimized.item();

    let inclusive = Variable::tmp(acc_item);
    reduce_inclusive(f, input, &inclusive, op)?;
    let shfl = Variable::tmp(acc_item);
    writeln!(f, "{} = {{", shfl.fmt_left())?;
    for k in 0..acc_item.vectorization {
        let inclusive_indexed = maybe_index(&inclusive, k);
        let comma = if k > 0 { ", " } else { "" };
        write!(f, "{comma}")?;
        D::compile_warp_shuffle_up(f, &inclusive_indexed.to_string(), "1")?;
    }
    writeln!(f, "}};")?;
    let lane_id = Variable::<D>::UnitPosPlane;

    write!(
        f,
        "{} = ({lane_id} == 0) ? {}{{",
        out.fmt_left(),
        out.item(),
    )?;
    for _ in 0..out.item().vectorization {
        write!(f, "{default},")?;
    }
    writeln!(f, "}} : {};", cast(&shfl, out.item()))
}

pub(crate) fn reduce_broadcast<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    input: &Variable<D>,
    out: &Variable<D>,
    id: &Variable<D>,
) -> core::fmt::Result {
    let out_fmt = out.fmt_left();
    write!(f, "{out_fmt} = {{ ")?;
    for i in 0..input.item().vectorization {
        let comma = if i > 0 { ", " } else { "" };
        write!(f, "{comma}")?;
        D::compile_warp_shuffle(f, &format!("{}", input.index(i)), &format!("{id}"))?;
    }
    writeln!(f, " }};")
}

fn reduce_with_loop<
    D: Dialect,
    I: Fn(&mut core::fmt::Formatter<'_>, &Variable<D>, usize) -> std::fmt::Result,
>(
    f: &mut core::fmt::Formatter<'_>,
    input: &Variable<D>,
    out: &Variable<D>,
    acc_item: Item<D>,
    instruction: I,
) -> core::fmt::Result {
    let acc = Variable::Named {
        name: "acc",
        item: acc_item,
    };
    let vectorization = acc_item.vectorization;

    writeln!(f, "auto plane_{out} = [&]() -> {} {{", out.item())?;
    writeln!(f, "    {} {} = {};", acc_item, acc, cast(input, acc_item))?;
    write!(f, "    for (uint offset = 1; offset < ")?;
    D::compile_plane_dim_checked(f)?;
    writeln!(f, "; offset *=2 ) {{")?;
    for k in 0..vectorization {
        instruction(f, &acc, k)?;
    }
    writeln!(f, "    }};")?;
    writeln!(f, "    return {};", cast(&acc, out.item()))?;
    writeln!(f, "}};")?;
    writeln!(f, "{} = plane_{}();", out.fmt_left(), out)
}

pub(crate) fn reduce_quantifier<
    D: Dialect,
    Q: Fn(&mut core::fmt::Formatter<'_>, &IndexedVariable<D>) -> std::fmt::Result,
>(
    f: &mut core::fmt::Formatter<'_>,
    input: &Variable<D>,
    out: &Variable<D>,
    quantifier: Q,
) -> core::fmt::Result {
    let out_fmt = out.fmt_left();
    write!(f, "{out_fmt} = {{ ")?;
    for i in 0..input.item().vectorization {
        let comma = if i > 0 { ", " } else { "" };
        write!(f, "{comma}")?;
        quantifier(f, &input.index(i))?;
    }
    writeln!(f, "}};")
}

fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
    if target != input.item() {
        let addr_space = D::address_space_for_variable(input);
        let qualifier = input.const_qualifier();
        format!("reinterpret_cast<{addr_space}{target}{qualifier}&>({input})")
    } else {
        format!("{input}")
    }
}

fn maybe_index<D: Dialect>(var: &Variable<D>, k: usize) -> String {
    if var.item().vectorization > 1 {
        format!("{var}.i_{k}")
    } else {
        format!("{var}")
    }
}