// rasen 0.12.0
//
// Build a SPIR-V module from a data flow graph.
use spirv_headers::*;
use rspirv::mr::{
    Instruction, Operand
};

use builder::Builder;
use types::*;
use errors::*;

/// Pushes an `Op::IMul` instruction computing `lhs * rhs` into `res_id`.
///
/// `res_type` is registered with the builder first; the id returned by the
/// builder is used as the instruction's result type.
#[inline]
fn imul<B: Builder>(builder: &mut B, res_type: &'static TypeName, res_id: u32, lhs: u32, rhs: u32) {
    let type_id = builder.register_type(res_type);
    let operands = vec![Operand::IdRef(lhs), Operand::IdRef(rhs)];
    let instruction = Instruction::new(Op::IMul, Some(type_id), Some(res_id), operands);

    builder.push_instruction(instruction);
}
/// Pushes an `Op::FMul` instruction computing `lhs * rhs` into `res_id`.
///
/// `res_type` is registered with the builder first; the id returned by the
/// builder is used as the instruction's result type.
#[inline]
fn fmul<B: Builder>(builder: &mut B, res_type: &'static TypeName, res_id: u32, lhs: u32, rhs: u32) {
    let type_id = builder.register_type(res_type);
    let operands = vec![Operand::IdRef(lhs), Operand::IdRef(rhs)];
    let instruction = Instruction::new(Op::FMul, Some(type_id), Some(res_id), operands);

    builder.push_instruction(instruction);
}
/// Pushes an `Op::VectorTimesScalar` instruction into `res_id`.
///
/// The operands are emitted in the order `vector`, then `scalar`, so callers
/// must pass the vector value id first. `res_type` is registered with the
/// builder and the returned id is used as the instruction's result type.
#[inline]
fn vector_times_scalar<B: Builder>(builder: &mut B, res_type: &'static TypeName, res_id: u32, vector: u32, scalar: u32) {
    let type_id = builder.register_type(res_type);
    let operands = vec![Operand::IdRef(vector), Operand::IdRef(scalar)];
    let instruction = Instruction::new(Op::VectorTimesScalar, Some(type_id), Some(res_id), operands);

    builder.push_instruction(instruction);
}
/// Pushes an `Op::MatrixTimesScalar` instruction into `res_id`.
///
/// The operands are emitted in the order `matrix`, then `scalar`, so callers
/// must pass the matrix value id first. `res_type` is registered with the
/// builder and the returned id is used as the instruction's result type.
#[inline]
fn matrix_times_scalar<B: Builder>(builder: &mut B, res_type: &'static TypeName, res_id: u32, matrix: u32, scalar: u32) {
    let type_id = builder.register_type(res_type);
    let operands = vec![Operand::IdRef(matrix), Operand::IdRef(scalar)];
    let instruction = Instruction::new(Op::MatrixTimesScalar, Some(type_id), Some(res_id), operands);

    builder.push_instruction(instruction);
}

#[inline]
pub fn multiply<B: Builder>(builder: &mut B, args: &[(&'static TypeName, u32)]) -> Result<(&'static TypeName, u32)> {
    use types::TypeName::*;

    let (l_arg, r_arg) = match args.len() {
        2 => (
            args[0],
            args[1],
        ),
        n if n > 2 => (
            multiply(builder, &args[0..n - 1])?,
            args[n - 1],
        ),
        n => bail!(ErrorKind::WrongArgumentsCount(n, 2)),
    };

    let (l_type, l_value) = l_arg;
    let (r_type, r_value) = r_arg;
    let res_id = builder.get_id();

    let res_type = match (l_type, r_type) {
        _ if l_type == r_type && r_type.is_integer() => {
            imul(builder, l_type, res_id, l_value, r_value);
            l_type
        },
        (&Vec(l_len, l_scalar), &Vec(r_len, r_scalar)) if l_len == r_len && l_scalar == r_scalar && r_scalar.is_integer() => {
            imul(builder, l_type, res_id, l_value, r_value);
            l_type
        },

        _ if l_type == r_type && l_type.is_float() => {
            fmul(builder, l_type, res_id, l_value, r_value);
            l_type
        },
        (&Vec(l_len, l_scalar), &Vec(r_len, r_scalar)) if l_len == r_len && l_scalar == r_scalar && r_scalar.is_float() => {
            fmul(builder, l_type, res_id, l_value, r_value);
            l_type
        },

        (&Vec(_, v_scalar), t_scalar) if t_scalar == v_scalar && t_scalar.is_float() => {
            vector_times_scalar(builder, l_type, res_id, r_value, l_value);
            l_type
        },
        (t_scalar, &Vec(_, v_scalar)) if t_scalar == v_scalar && t_scalar.is_float() => {
            vector_times_scalar(builder, r_type, res_id, l_value, r_value);
            r_type
        },

        (&Mat(_, &Vec(_, m_scalar)), t_scalar) if t_scalar == m_scalar && t_scalar.is_float() => {
            matrix_times_scalar(builder, l_type, res_id, l_value, r_value);
            l_type
        },
        (t_scalar, &Mat(_, &Vec(_, m_scalar))) if t_scalar == m_scalar && t_scalar.is_float() => {
            matrix_times_scalar(builder, r_type, res_id, r_value, l_value);
            r_type
        },

        (&Vec(v_len, l_scalar), &Mat(m_len, &Vec(_, r_scalar))) if v_len == m_len && l_scalar == r_scalar && l_scalar.is_float() => {
            let res_type = builder.register_type(l_type);

            builder.push_instruction(
                Instruction::new(
                    Op::VectorTimesMatrix,
                    Some(res_type),
                    Some(res_id),
                    vec![
                        Operand::IdRef(l_value),
                        Operand::IdRef(r_value),
                     ]
                 )
             );

            l_type
        },
        (&Mat(m_len, &Vec(_, l_scalar)), &Vec(v_len, r_scalar)) if v_len == m_len && l_scalar == r_scalar && l_scalar.is_float() => {
            let res_type = builder.register_type(r_type);

            builder.push_instruction(
                Instruction::new(
                    Op::MatrixTimesVector,
                    Some(res_type),
                    Some(res_id),
                    vec![
                        Operand::IdRef(l_value),
                        Operand::IdRef(r_value),
                    ]
                )
            );

            r_type
        },

        (&Mat(l_len, &Vec(_, l_scalar)), &Mat(r_len, &Vec(_, r_scalar))) if l_len == r_len && l_scalar == r_scalar && l_scalar.is_float() => {
            let res_type = builder.register_type(l_type);

            builder.push_instruction(Instruction::new(Op::MatrixTimesMatrix, Some(res_type), Some(res_id), vec![ Operand::IdRef(l_value),
                Operand::IdRef(r_value),
             ]));

            l_type
        },

        _ => bail!(ErrorKind::BadArguments(Box::new([ l_type, r_type ]))),
    };

    Ok((res_type, res_id))
}

/// Emits an `Op::Dot` instruction for two float vectors and returns the
/// `(scalar type, result id)` pair.
///
/// Both arguments must be vectors of the same length and the same float
/// scalar type; the result has that scalar type.
///
/// # Errors
///
/// - `ErrorKind::WrongArgumentsCount` unless exactly two arguments are given.
/// - `ErrorKind::BadArguments` when the operands are not matching float vectors.
#[inline]
pub fn dot<B: Builder>(builder: &mut B, args: &[(&'static TypeName, u32)]) -> Result<(&'static TypeName, u32)> {
    use types::TypeName::*;

    let arg_count = args.len();
    if arg_count != 2 {
        bail!(ErrorKind::WrongArgumentsCount(arg_count, 2));
    }

    let (l_type, l_value) = args[0];
    let (r_type, r_value) = args[1];

    if let (&Vec(l_size, l_scalar), &Vec(r_size, r_scalar)) = (l_type, r_type) {
        if l_size == r_size && l_scalar == r_scalar && l_scalar.is_float() {
            // The type is registered before the result id is requested.
            let scalar_type_id = builder.register_type(l_scalar);
            let result_id = builder.get_id();

            let operands = vec![Operand::IdRef(l_value), Operand::IdRef(r_value)];
            builder.push_instruction(
                Instruction::new(Op::Dot, Some(scalar_type_id), Some(result_id), operands)
            );

            return Ok((l_scalar, result_id));
        }
    }

    bail!(ErrorKind::BadArguments(Box::new([ l_type, r_type ])));
}

/// Generates a public binary math operation function.
///
/// Macro arguments:
/// - `name`: identifier of the generated function.
/// - `node`: an identifier that is accepted but not used by the expansion.
/// - `variadic`: boolean expression; when true, more than two arguments are
///   accepted and folded from the left.
/// - opcodes: either two opcodes `(integer, float)` or three opcodes
///   `(unsigned, signed, float)`.
///
/// The generated function accepts operands that are either of identical type
/// or vectors of equal length and scalar type, and always emits the operands
/// in `left, right` order so non-commutative operations keep their meaning.
macro_rules! impl_math_op {
    ( $name:ident, $node:ident, $variadic:expr, $( $opcode:ident ),* ) => {
        /// Emits the instruction for this operation and returns the
        /// `(type, id)` pair of the result.
        ///
        /// Fails with `ErrorKind::WrongArgumentsCount` when the number of
        /// arguments is not supported, and with `ErrorKind::BadArguments`
        /// when the operand types do not match any supported combination.
        #[inline]
        pub fn $name<B: Builder>(builder: &mut B, args: Vec<(&'static TypeName, u32)>) -> Result<(&'static TypeName, u32)> {
            use types::TypeName::*;

            let (l_arg, r_arg) = match args.len() {
                2 => (
                    args[0],
                    args[1],
                ),
                // Reduce everything but the last argument first, then combine with the last one.
                n if $variadic && n > 2 => (
                    $name(builder, args[0..n - 1].to_vec())?,
                    args[n - 1],
                ),
                n => bail!(ErrorKind::WrongArgumentsCount(n, 2)),
            };

            let result_id = builder.get_id();

            let (l_type, l_value) = l_arg;
            let (r_type, r_value) = r_arg;

            // Shared emit step for every successful arm: register the left
            // operand's type as the result type and push the instruction with
            // the operands in left, right order.
            let mut emit = |opcode: Op| {
                let res_type = builder.register_type(l_type);

                builder.push_instruction(
                    Instruction::new(
                        opcode,
                        Some(res_type),
                        Some(result_id),
                        vec![
                            Operand::IdRef(l_value),
                            Operand::IdRef(r_value),
                        ]
                    )
                );

                l_type
            };

            macro_rules! match_types {
                // Separate opcodes for unsigned, signed and float operands.
                ( $uopcode:ident, $sopcode:ident, $fopcode:ident ) => {
                    match (l_type, r_type) {
                        _ if l_type == r_type && r_type.is_signed() => emit(Op::$sopcode),
                        (&Vec(l_len, l_scalar), &Vec(r_len, r_scalar)) if l_len == r_len && l_scalar == r_scalar && r_scalar.is_signed() => emit(Op::$sopcode),

                        _ if l_type == r_type && r_type.is_integer() && !r_type.is_signed() => emit(Op::$uopcode),
                        // Vectors whose scalar reports is_signed() were already matched above.
                        (&Vec(l_len, l_scalar), &Vec(r_len, r_scalar)) if l_len == r_len && l_scalar == r_scalar && r_scalar.is_integer() => emit(Op::$uopcode),

                        _ if l_type == r_type && r_type.is_float() => emit(Op::$fopcode),
                        (&Vec(l_len, l_scalar), &Vec(r_len, r_scalar)) if l_len == r_len && l_scalar == r_scalar && r_scalar.is_float() => emit(Op::$fopcode),

                        _ => bail!(ErrorKind::BadArguments(Box::new([ l_type, r_type ]))),
                    }
                };
                // One opcode for all integer operands and one for float operands.
                ( $iopcode:ident, $fopcode:ident ) => {
                    match (l_type, r_type) {
                        _ if l_type == r_type && r_type.is_integer() => emit(Op::$iopcode),
                        (&Vec(l_len, l_scalar), &Vec(r_len, r_scalar)) if l_len == r_len && l_scalar == r_scalar && r_scalar.is_integer() => emit(Op::$iopcode),

                        _ if l_type == r_type && r_type.is_float() => emit(Op::$fopcode),
                        (&Vec(l_len, l_scalar), &Vec(r_len, r_scalar)) if l_len == r_len && l_scalar == r_scalar && l_scalar.is_float() => emit(Op::$fopcode),

                        _ => bail!(ErrorKind::BadArguments(Box::new([ l_type, r_type ]))),
                    }
                };
            }

            let res_type = match_types!( $( $opcode ),* );
            Ok((res_type, result_id))
        }
    };
}

// Instantiate the binary math operations. The arguments are: generated
// function name, node identifier, whether more than two operands are
// accepted, and then the opcodes, either (integer, float) or
// (unsigned, signed, float).
impl_math_op!(add, Add, true, IAdd, FAdd);
impl_math_op!(subtract, Subtract, true, ISub, FSub);
impl_math_op!(divide, Divide, true, UDiv, SDiv, FDiv);
// Not variadic: `modulus` accepts exactly two operands.
impl_math_op!(modulus, Modulus, false, UMod, SMod, FMod);