// ptx-90-parser 0.1.0
//
// Parse NVIDIA PTX 9.0 assembly into a structured AST and explore modules via a CLI.
// See the crate documentation for details.
use ptx_parser::{
    parse_module_directive, GlobalInitializer, InitializerValue, ModuleDirective,
    ModuleVariableDirective, NumericLiteral, PtxParseError, ScalarType, VariableDirective,
};

#[test]
fn parses_const_scalar_initializer() {
    // A `.const` scalar initialized with a plain integer literal.
    let var = expect_variable(".const .u32 foo = 42;", VariableKind::Const);
    let init = var.initializer.as_ref().expect("expected initializer");

    let GlobalInitializer::Scalar(scalar) = init else {
        panic!("expected scalar initializer, got {:?}", init);
    };
    assert_eq!(initializer_numeric_value(scalar), 42);
}

#[test]
fn parses_global_array_numeric_initializer() {
    // Unsized `.global` array whose brace list holds integer literals.
    let source = ".global .u32 bar[] = { 2, 3, 5 };";
    let var = expect_variable(source, VariableKind::Global);
    let values = collect_numeric_array(var.initializer.as_ref().expect("expected initializer"));
    assert_eq!(values, [2, 3, 5]);
}

#[test]
fn parses_global_scalar_without_initializer() {
    // A declaration with no `= ...` clause must leave the initializer unset.
    let decl = expect_variable(".global .u32 loc;", VariableKind::Global);
    let initializer = &decl.initializer;
    assert!(initializer.is_none());
}

#[test]
fn parses_global_symbol_reference() {
    // A bare symbol used as a scalar initializer is kept as its source text.
    let decl = expect_variable(".global .u32 p1 = foo;", VariableKind::Global);
    let symbol = collect_symbol(decl.initializer.as_ref().expect("expected initializer"));
    assert_eq!(symbol, "foo");
}

#[test]
fn parses_global_generic_symbol_reference() {
    // `generic(sym)` is kept verbatim as a symbol initializer.
    let decl = expect_variable(".global .u32 p2 = generic(foo);", VariableKind::Global);
    let symbol = collect_symbol(decl.initializer.as_ref().expect("expected initializer"));
    assert_eq!(symbol, "generic(foo)");
}

#[test]
fn parses_global_generic_pointer_array() {
    // Each `generic(...)` element, with or without a `+N` offset, is kept verbatim.
    let decl = expect_variable(
        ".global .u32 parr[] = { generic(bar), generic(bar)+4, generic(bar)+8 };",
        VariableKind::Global,
    );
    let init = decl.initializer.as_ref().expect("expected initializer");
    let symbols = collect_symbol_array(init);
    assert_eq!(symbols, ["generic(bar)", "generic(bar)+4", "generic(bar)+8"]);
}

#[test]
fn parses_global_mask_expressions() {
    // Mask expressions of the form `0xff..(sym)` are preserved as symbol text.
    let decl = expect_variable(
        ".global .u8 addr[] = {0xff(foo), 0xff00(foo), 0xff0000(foo)};",
        VariableKind::Global,
    );
    let init = decl.initializer.as_ref().expect("expected initializer");
    let symbols = collect_symbol_array(init);
    assert_eq!(symbols, ["0xff(foo)", "0xff00(foo)", "0xff0000(foo)"]);
}

#[test]
fn parses_global_mask_with_offset() {
    // Mask expressions wrapping `sym+N` keep the offset inside the parentheses.
    let decl = expect_variable(
        ".global .u8 addr2[] = {0xff(foo+4), 0xff00(foo+4), 0xff0000(foo+4)};",
        VariableKind::Global,
    );
    let init = decl.initializer.as_ref().expect("expected initializer");
    let symbols = collect_symbol_array(init);
    assert_eq!(symbols, ["0xff(foo+4)", "0xff00(foo+4)", "0xff0000(foo+4)"]);
}

#[test]
fn parses_global_mask_with_generic_expressions() {
    // A mask may wrap a `generic(...)` expression; the nesting is kept verbatim.
    let decl = expect_variable(
        ".global .u8 addr3[] = {0xff(generic(foo)), 0xff00(generic(foo))};",
        VariableKind::Global,
    );
    let init = decl.initializer.as_ref().expect("expected initializer");
    let symbols = collect_symbol_array(init);
    assert_eq!(symbols, ["0xff(generic(foo))", "0xff00(generic(foo))"]);
}

#[test]
fn parses_global_mask_with_generic_offset_expressions() {
    // A mask may wrap `generic(sym)+N`; both nesting and offset are kept verbatim.
    let decl = expect_variable(
        ".global .u8 addr4[] = {0xff(generic(foo)+4), 0xff00(generic(foo)+4)};",
        VariableKind::Global,
    );
    let init = decl.initializer.as_ref().expect("expected initializer");
    let symbols = collect_symbol_array(init);
    assert_eq!(symbols, ["0xff(generic(foo)+4)", "0xff00(generic(foo)+4)"]);
}

#[test]
fn parses_global_mask_with_const_expression() {
    // A mask over a constant expression is kept as text, including inner spaces.
    let decl = expect_variable(
        ".global .u8 addr5[] = { 0xFF(1000 + 546), 0xFF00(131187) };",
        VariableKind::Global,
    );
    let init = decl.initializer.as_ref().expect("expected initializer");
    let symbols = collect_symbol_array(init);
    assert_eq!(symbols, ["0xFF(1000 + 546)", "0xFF00(131187)"]);
}

#[test]
fn parses_multi_dimensional_global_array() {
    // `[][2]`: the outer dimension is left unsized, the inner one is fixed at 2,
    // and the initializer is a nested aggregate of integer rows.
    let variable = expect_variable(
        ".global .s32 offset[][2] = { {-1, 0}, {0, -1}, {1, 0}, {0, 1} };",
        VariableKind::Global,
    );

    let dims = variable
        .array
        .as_ref()
        .expect("expected array dimensions on global");
    assert_eq!(dims.dimensions, vec![None, Some(2)]);

    let init = variable
        .initializer
        .as_ref()
        .expect("expected initializer on global");
    let GlobalInitializer::Aggregate(rows) = init else {
        panic!("expected aggregate initializer, got {:?}", init);
    };
    assert_eq!(rows.len(), 4);

    const EXPECTED: [[i64; 2]; 4] = [[-1, 0], [0, -1], [1, 0], [0, 1]];

    for (row, want_row) in rows.iter().zip(&EXPECTED) {
        let GlobalInitializer::Aggregate(cols) = row else {
            panic!("expected aggregate row, got {:?}", row);
        };
        assert_eq!(cols.len(), 2);

        for (cell, want) in cols.iter().zip(want_row) {
            let GlobalInitializer::Scalar(scalar) = cell else {
                panic!("expected scalar initializer, got {:?}", cell);
            };
            assert_eq!(initializer_numeric_value(scalar), *want);
        }
    }
}

#[test]
fn parses_const_float_array_initializer() {
    // Float literals are stored as raw `f64` bits inside `NumericLiteral::Float64`.
    let variable = expect_variable(".const .f32 bias[] = {-1.0, 1.0};", VariableKind::Const);
    let init = variable.initializer.as_ref().expect("expected initializer");

    let GlobalInitializer::Aggregate(values) = init else {
        panic!("expected aggregate float initializer, got {:?}", init);
    };
    assert_eq!(values.len(), 2);

    let expected = [-1.0_f64, 1.0_f64];
    for (value, want) in values.iter().zip(expected) {
        let GlobalInitializer::Scalar(InitializerValue::Numeric(NumericLiteral::Float64(bits))) =
            value
        else {
            panic!("expected float literal, got {:?}", value);
        };
        let numeric = f64::from_bits(*bits);
        assert!((numeric - want).abs() < f64::EPSILON);
    }
}

#[test]
fn parses_global_byte_array_initializer() {
    // Explicitly sized `.u8` array with an all-zero initializer list.
    let decl = expect_variable(".global .u8 bg[4] = {0, 0, 0, 0};", VariableKind::Global);
    let init = decl.initializer.as_ref().expect("expected initializer");
    assert_eq!(collect_numeric_array(init), [0; 4]);
}

#[test]
fn parses_shared_variable() {
    // An `.align N` qualifier on a `.shared` declaration is recorded as its alignment.
    let decl = expect_variable(".shared .align 16 .u8 s0[4];", VariableKind::Shared);
    let (name, alignment) = (&decl.name, decl.alignment);
    assert_eq!(name, "s0");
    assert_eq!(alignment, Some(16));
}

#[test]
fn parses_tex_variable() {
    // A `.tex` declaration with a `.sampler` type; only the name is checked here.
    let decl = expect_variable(".tex .sampler foo;", VariableKind::Tex);
    let name = &decl.name;
    assert_eq!(name, "foo");
}

#[test]
fn parses_global_texref_type() {
    // `.texref` on a `.global` variable maps to `ScalarType::TexRef`.
    let decl = expect_variable(".global .texref tex_handle;", VariableKind::Global);
    let expected = Some(ScalarType::TexRef);
    assert_eq!(decl.ty, expected);
}

#[test]
fn parses_global_samplerref_type() {
    // `.samplerref` on a `.global` variable maps to `ScalarType::SamplerRef`.
    let decl = expect_variable(".global .samplerref sampler_handle;", VariableKind::Global);
    let expected = Some(ScalarType::SamplerRef);
    assert_eq!(decl.ty, expected);
}

#[test]
fn parses_global_surfref_type() {
    // `.surfref` on a `.global` variable maps to `ScalarType::SurfRef`.
    let decl = expect_variable(".global .surfref surface_handle;", VariableKind::Global);
    let expected = Some(ScalarType::SurfRef);
    assert_eq!(decl.ty, expected);
}

#[test]
fn rejects_reg_scalar_directive() {
    // `.reg` declarations must not be accepted by the module-level directive parser.
    // `expect_err` (instead of `assert!(.. .is_err())`) prints the directive that was
    // unexpectedly produced if this ever regresses.
    let _ = parse_module_directive(".reg .s32 i;", 1)
        .expect_err(".reg scalar directive should be rejected at module scope");
}

#[test]
fn rejects_reg_vector_directive() {
    // Vector `.reg` declarations must also be refused at module scope. On failure,
    // `expect_err` shows which directive the parser produced instead.
    let _ = parse_module_directive(".reg .v4 .f32 accel;", 1)
        .expect_err(".reg vector directive should be rejected at module scope");
}

#[test]
fn rejects_reg_predicate_directive() {
    // Predicate `.reg` declarations (with a multi-name list) must be refused at module
    // scope. On failure, `expect_err` shows which directive the parser produced instead.
    let _ = parse_module_directive(".reg .pred p, q, r;", 1)
        .expect_err(".reg predicate directive should be rejected at module scope");
}

#[test]
fn rejects_volatile_global_declaration() {
    // `.volatile` on a `.global` declaration must be refused with an `InvalidGlobal`
    // error whose message names the offending qualifier.
    let err = parse_module_directive(".global .volatile .b32 bad;", 1)
        .expect_err(".volatile global should be rejected");

    if let PtxParseError::InvalidGlobal { message, .. } = &err {
        assert!(message.contains(".volatile"), "unexpected error: {message}");
    } else {
        panic!("expected InvalidGlobal error, got {err:?}");
    }
}

/// Which kind of module-level variable directive a test expects to parse.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VariableKind {
    /// A `.const` declaration.
    Const,
    /// A `.global` declaration.
    Global,
    /// A `.shared` declaration.
    Shared,
    /// A `.tex` declaration.
    Tex,
}

/// Parses `line` as a single module directive and returns its variable declaration.
///
/// # Panics
///
/// Panics if `line` does not parse, if it is not a module variable directive, or
/// if the variable's kind differs from `kind`.
fn expect_variable(line: &str, kind: VariableKind) -> VariableDirective {
    let module_var = match parse_module_directive(line, 1).expect("directive should parse") {
        ModuleDirective::ModuleVariable(inner) => inner,
        other => panic!("expected module variable directive, got {:?}", other),
    };

    // Map each variant to its declaration, its kind, and the label used in diagnostics.
    let (var, actual, label) = match module_var {
        ModuleVariableDirective::Const(var) => (var, VariableKind::Const, "const"),
        ModuleVariableDirective::Global(var) => (var, VariableKind::Global, "global"),
        ModuleVariableDirective::Shared(var) => (var, VariableKind::Shared, "shared"),
        ModuleVariableDirective::Tex(var) => (var, VariableKind::Tex, "tex"),
    };

    if actual != kind {
        panic!("expected {:?} directive, got {}", kind, label);
    }
    var
}

fn initializer_numeric_value(value: &InitializerValue) -> i64 {
    match value {
        InitializerValue::Numeric(NumericLiteral::Signed(v)) => *v,
        InitializerValue::Numeric(NumericLiteral::Unsigned(v)) => *v as i64,
        other => panic!("expected integer numeric initializer, got {:?}", other),
    }
}

/// Flattens an initializer into integer values. A one-level aggregate yields one value
/// per element; a scalar yields a single-element vector.
///
/// # Panics
///
/// Panics if an aggregate element is not a scalar, or if any value is not an integer
/// literal.
fn collect_numeric_array(initializer: &GlobalInitializer) -> Vec<i64> {
    let items = match initializer {
        GlobalInitializer::Scalar(inner) => return vec![initializer_numeric_value(inner)],
        GlobalInitializer::Aggregate(items) => items,
    };

    let mut out = Vec::with_capacity(items.len());
    for item in items.iter() {
        match item {
            GlobalInitializer::Scalar(inner) => out.push(initializer_numeric_value(inner)),
            other => panic!("expected scalar numeric element, got {:?}", other),
        }
    }
    out
}

/// Returns the symbol text of each element of an aggregate initializer, or of the
/// initializer itself when it is a scalar.
///
/// # Panics
///
/// Panics (via `collect_symbol`) if any inspected value is not a scalar symbol.
fn collect_symbol_array(initializer: &GlobalInitializer) -> Vec<&str> {
    if let GlobalInitializer::Aggregate(items) = initializer {
        items.iter().map(collect_symbol).collect()
    } else {
        vec![collect_symbol(initializer)]
    }
}

/// Borrows the symbol text from a scalar symbol initializer.
///
/// # Panics
///
/// Panics if `initializer` is anything other than `Scalar(Symbol(_))`.
fn collect_symbol(initializer: &GlobalInitializer) -> &str {
    let GlobalInitializer::Scalar(InitializerValue::Symbol(symbol)) = initializer else {
        panic!("expected scalar symbol initializer, got {:?}", initializer);
    };
    symbol.as_str()
}