vyre-emit-ptx 0.6.1

PTX text emitter for vyre KernelDescriptor. Produces NVRTC-compatible CUDA assembly.
Documentation
//! Tests for PTX emission of subgroup ops, an unsupported atomic op, and a named structured for-loop.
use super::*;

/// Checks that emitting a kernel containing a `SubgroupBallot` op produces
/// text with both `activemask.b32` and `vote.sync.ballot.b32`.
#[test]
fn subgroup_ballot_emits_vote_sync_ballot() {
    // The ballot op lists `0` as its operand, which is also the result id of
    // the literal op (presumably a value reference — confirm against the IR).
    let literal_op = KernelOp {
        kind: KernelOpKind::Literal,
        operands: vec![0],
        result: Some(0),
    };
    let ballot_op = KernelOp {
        kind: KernelOpKind::SubgroupBallot,
        operands: vec![0],
        result: Some(1),
    };
    let body = KernelBody {
        ops: vec![literal_op, ballot_op],
        child_bodies: Vec::new(),
        literals: vec![LiteralValue::Bool(true)],
    };
    let kernel = KernelDescriptor {
        id: "ballot".into(),
        bindings: BindingLayout { slots: Vec::new() },
        dispatch: Dispatch::new(64, 1, 1),
        body,
    };

    let ptx = emit(&kernel).unwrap();

    assert!(ptx.contains("activemask.b32"));
    assert!(ptx.contains("vote.sync.ballot.b32"));
}

/// Checks that emitting a kernel containing a `SubgroupShuffle` op over two
/// `U32` literals produces text with `shfl.sync.idx.b32`.
#[test]
fn subgroup_shuffle_emits_shfl_sync_idx() {
    let first_literal = KernelOp {
        kind: KernelOpKind::Literal,
        operands: vec![0],
        result: Some(0),
    };
    let second_literal = KernelOp {
        kind: KernelOpKind::Literal,
        operands: vec![1],
        result: Some(1),
    };
    // The shuffle lists the result ids of both literal ops as its operands.
    let shuffle_op = KernelOp {
        kind: KernelOpKind::SubgroupShuffle,
        operands: vec![0, 1],
        result: Some(2),
    };
    let body = KernelBody {
        ops: vec![first_literal, second_literal, shuffle_op],
        child_bodies: Vec::new(),
        literals: vec![LiteralValue::U32(7), LiteralValue::U32(3)],
    };
    let kernel = KernelDescriptor {
        id: "shuffle".into(),
        bindings: BindingLayout { slots: Vec::new() },
        dispatch: Dispatch::new(64, 1, 1),
        body,
    };

    let ptx = emit(&kernel).unwrap();

    assert!(ptx.contains("shfl.sync.idx.b32"));
}

/// Checks that emitting a kernel containing a `SubgroupAdd` op fed by an
/// `F32` literal produces text with `redux.sync.add.f32`.
#[test]
fn subgroup_add_emits_redux_sync() {
    let literal_op = KernelOp {
        kind: KernelOpKind::Literal,
        operands: vec![0],
        result: Some(0),
    };
    let add_op = KernelOp {
        kind: KernelOpKind::SubgroupAdd,
        operands: vec![0],
        result: Some(1),
    };
    let body = KernelBody {
        ops: vec![literal_op, add_op],
        child_bodies: Vec::new(),
        literals: vec![LiteralValue::F32(5.0)],
    };
    let kernel = KernelDescriptor {
        id: "add".into(),
        bindings: BindingLayout { slots: Vec::new() },
        dispatch: Dispatch::new(64, 1, 1),
        body,
    };

    let ptx = emit(&kernel).unwrap();

    assert!(ptx.contains("redux.sync.add.f32"));
}

/// Checks that emitting a kernel whose only op is `SubgroupLocalId`
/// produces text containing `%laneid`.
#[test]
fn subgroup_local_id_emits_laneid() {
    let lane_id_op = KernelOp {
        kind: KernelOpKind::SubgroupLocalId,
        operands: Vec::new(),
        result: Some(0),
    };
    let body = KernelBody {
        ops: vec![lane_id_op],
        child_bodies: Vec::new(),
        literals: Vec::new(),
    };
    let kernel = KernelDescriptor {
        id: "lane".into(),
        bindings: BindingLayout { slots: Vec::new() },
        dispatch: Dispatch::new(64, 1, 1),
        body,
    };

    let ptx = emit(&kernel).unwrap();

    assert!(ptx.contains("%laneid"));
}

/// Checks that a `SubgroupSize` op, emitted with `subgroup_size: 16` in the
/// options, produces text containing `mov.u32` and `, 16;`.
#[test]
fn subgroup_size_emits_probed_width_literal() {
    let size_op = KernelOp {
        kind: KernelOpKind::SubgroupSize,
        operands: Vec::new(),
        result: Some(0),
    };
    let body = KernelBody {
        ops: vec![size_op],
        child_bodies: Vec::new(),
        literals: Vec::new(),
    };
    let kernel = KernelDescriptor {
        id: "wsz".into(),
        bindings: BindingLayout { slots: Vec::new() },
        dispatch: Dispatch::new(64, 1, 1),
        body,
    };
    // A subgroup size of 16 is what the `, 16;` assertion below looks for.
    let options = PtxEmitOptions {
        target: ComputeCapability::SM_70,
        subgroup_size: 16,
        ulp_budget: None,
    };

    let ptx = emit_with_options(&kernel, options).unwrap();

    assert!(ptx.contains("mov.u32") && ptx.contains(", 16;"));
}

/// Checks that emitting a kernel whose atomic op is `AtomicOp::FetchNand`
/// returns `Err(EmitError::UnsupportedOp(_))`.
#[test]
fn atomic_unsupported_op_returns_error() {
    use vyre_foundation::ir::AtomicOp;
    use vyre_foundation::runtime::memory_model::MemoryOrdering;

    // Single read-write global `U32` binding in slot 0.
    let binding = BindingSlot {
        slot: 0,
        element_type: DataType::U32,
        element_count: None,
        memory_class: MemoryClass::Global,
        visibility: BindingVisibility::ReadWrite,
        name: "b".into(),
    };
    let first_literal = KernelOp {
        kind: KernelOpKind::Literal,
        operands: vec![0],
        result: Some(0),
    };
    let second_literal = KernelOp {
        kind: KernelOpKind::Literal,
        operands: vec![1],
        result: Some(1),
    };
    let atomic_op = KernelOp {
        kind: KernelOpKind::Atomic {
            op: AtomicOp::FetchNand,
            ordering: MemoryOrdering::SeqCst,
        },
        operands: vec![0, 0, 1],
        result: Some(2),
    };
    let body = KernelBody {
        ops: vec![first_literal, second_literal, atomic_op],
        child_bodies: Vec::new(),
        literals: vec![LiteralValue::U32(0), LiteralValue::U32(1)],
    };
    let kernel = KernelDescriptor {
        id: "k".into(),
        bindings: BindingLayout {
            slots: vec![binding],
        },
        dispatch: Dispatch::new(1, 1, 1),
        body,
    };

    let outcome = emit(&kernel);

    assert!(matches!(outcome, Err(EmitError::UnsupportedOp(_))));
}

/// Checks that a `StructuredForLoop` with `loop_var` set to `row_idx`
/// produces text containing the comment prefix `// for row_idx in`.
#[test]
fn for_loop_var_name_appears_in_comment() {
    let first_literal = KernelOp {
        kind: KernelOpKind::Literal,
        operands: vec![0],
        result: Some(0),
    };
    let second_literal = KernelOp {
        kind: KernelOpKind::Literal,
        operands: vec![1],
        result: Some(1),
    };
    // The loop produces no result value; its third operand is `0`
    // (presumably the index of the single child body — confirm).
    let loop_op = KernelOp {
        kind: KernelOpKind::StructuredForLoop {
            loop_var: "row_idx".into(),
        },
        operands: vec![0, 1, 0],
        result: None,
    };
    let body = KernelBody {
        ops: vec![first_literal, second_literal, loop_op],
        child_bodies: vec![empty_child_body()],
        literals: vec![LiteralValue::U32(0), LiteralValue::U32(16)],
    };
    let kernel = KernelDescriptor {
        id: "named_loop".into(),
        bindings: BindingLayout { slots: Vec::new() },
        dispatch: Dispatch::new(64, 1, 1),
        body,
    };

    let ptx = emit(&kernel).unwrap();

    assert!(ptx.contains("// for row_idx in"));
}