//! zust-vm-spirv 0.9.4
//!
//! SPIR-V code generation backend for the Zust scripting language.
use anyhow::Result;
use compiler::Compiler;
use dynamic::{Dynamic, Type};
use parser::Stmt;
use rspirv::{binary::Assemble, dr::Builder};
use smol_str::SmolStr;
use spirv::{AddressingModel, Capability, ExecutionModel, MemoryModel, StorageClass};
use std::{collections::BTreeMap, sync::Arc};

use crate::api::{BuiltinFn, ExternalFnKind, SpirvModule};

/// A SPIR-V `<id>` paired with the Zust [`Type`] it represents.
///
/// Depending on context `id` names either a computed value or a variable/pointer: for
/// runtime-array kernel arguments it is the storage-buffer variable itself rather than a
/// loaded value.
#[derive(Debug, Clone)]
pub(crate) struct Value {
    /// SPIR-V result id of the value (or of the variable/pointer standing in for it).
    pub(crate) id: u32,
    /// Zust type of the value.
    pub(crate) ty: Type,
}

/// Bookkeeping for a phi node being assembled.
///
/// NOTE(review): not used in this file; presumably tracks a pending SPIR-V `OpPhi` that
/// merges one local variable at a control-flow join — confirm against the users.
#[derive(Debug, Clone)]
pub(crate) struct Phi {
    /// NOTE(review): presumably the local slot (index into `SpirvCompiler::vars`) this phi merges — confirm.
    pub(crate) idx: usize,
    /// Zust type of the merged value.
    pub(crate) ty: Type,
    /// SPIR-V result id reserved for the phi.
    pub(crate) result_id: u32,
    /// Incoming edges. NOTE(review): presumably `(value id, predecessor block id)` pairs, the
    /// operand layout of `OpPhi` — confirm the tuple order.
    pub(crate) incoming: Vec<(u32, u32)>,
}

/// A user-defined Zust function as declared in source.
///
/// NOTE(review): presumably instantiated into [`UserFnInstance`]s once concrete types for
/// `generic_params` are known — confirm.
#[derive(Debug, Clone)]
pub(crate) struct UserFn {
    /// Declared parameter names.
    pub(crate) arg_names: Vec<SmolStr>,
    /// Declared parameter types (may refer to generic parameters — TODO confirm).
    pub(crate) arg_tys: Vec<Type>,
    /// Generic type parameters of the function.
    pub(crate) generic_params: Vec<Type>,
    /// Function body, shared rather than copied.
    pub(crate) body: Arc<Stmt>,
}

/// A concrete instance of a user function queued for emission as its own SPIR-V function.
///
/// Instances are emitted by `SpirvCompiler::compile_user_fn_instance`.
#[derive(Debug, Clone)]
pub(crate) struct UserFnInstance {
    /// Identifier of the originating function, reported in error messages.
    /// NOTE(review): presumably the key of `SpirvCompiler::user_fns` — confirm.
    pub(crate) source_id: u32,
    /// NOTE(review): presumably the concrete types substituted for the generic parameters — confirm.
    pub(crate) generic_args: Vec<Type>,
    /// Parameter names; parameter `i` is bound to local slot `i` under `arg_names[i]`.
    pub(crate) arg_names: Vec<SmolStr>,
    /// Parameter types; each becomes one `OpFunctionParameter`.
    pub(crate) arg_tys: Vec<Type>,
    /// Return type; a void type makes the function end with a plain `OpReturn`.
    pub(crate) ret_ty: Type,
    /// Body statement compiled for this instance.
    pub(crate) body: Stmt,
    /// Pre-allocated result id passed to `OpFunction`.
    pub(crate) fn_id: u32,
    /// Id of the function's `OpTypeFunction`.
    pub(crate) fn_ty: u32,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum SpirvTy {
    Value(Type),
    LayoutValue(Type),
    Pointer(Type, StorageClass),
    Buffer(Type),
}

/// State for lowering Zust code into a single SPIR-V compute module.
///
/// Module-wide caches (types, constants, builtins) are not reset between functions, while
/// per-function fields such as `vars`, `local_ptrs`, `names`, `loop_stack`, `current_block`
/// and `function_var_pool` are cleared at function boundaries.
pub(crate) struct SpirvCompiler {
    /// rspirv builder accumulating the module.
    pub(crate) builder: Builder,
    /// Cache of emitted type ids keyed by [`SpirvTy`].
    pub(crate) types: Vec<(SpirvTy, u32)>,
    /// NOTE(review): presumably a cache of constant ids keyed by value — confirm.
    pub(crate) consts: Vec<(Dynamic, u32)>,
    /// Per-function local slots; slot `i` holds the current value of local `i`
    /// (function/kernel arguments occupy the first slots).
    pub(crate) vars: Vec<Option<Value>>,
    /// Parallel to `vars`. NOTE(review): presumably the backing pointer for locals stored in
    /// memory; kernel arguments are registered with `None` — confirm.
    pub(crate) local_ptrs: Vec<Option<Value>>,
    /// Parallel to `vars`: optional source name of each slot.
    pub(crate) names: Vec<Option<SmolStr>>,
    /// Storage-buffer variable ids; all of them are listed in the entry-point interface.
    pub(crate) buffers: Vec<u32>,
    /// Additional global variable ids appended to the entry-point interface.
    pub(crate) interfaces: Vec<u32>,
    /// Label id of the block currently being emitted; `None` when no block is open.
    pub(crate) current_block: Option<u32>,
    /// NOTE(review): presumably label pairs for enclosing loops (e.g. continue/merge) — confirm order.
    pub(crate) loop_stack: Vec<(u32, u32)>,
    /// External functions, keyed by `u32`.
    pub(crate) externs: BTreeMap<u32, ExternalFnKind>,
    /// User-defined function declarations, keyed by `u32`.
    pub(crate) user_fns: BTreeMap<u32, UserFn>,
    /// User type definitions, keyed by `u32`.
    pub(crate) type_defs: BTreeMap<u32, Type>,
    /// Declared types of workgroup statics, keyed by `u32`.
    pub(crate) workgroup_static_tys: BTreeMap<u32, Type>,
    /// NOTE(review): presumably the variables created for `workgroup_static_tys` — confirm.
    pub(crate) workgroup_statics: BTreeMap<u32, Value>,
    /// Front-end compiler instance supplied by the caller.
    pub(crate) compiler: Compiler,
    /// NOTE(review): presumably ids of functions currently being inlined — confirm.
    pub(crate) inline_stack: Vec<u32>,
    /// Queue of user function instances to emit; drained after the kernel body.
    pub(crate) fn_instances: Vec<UserFnInstance>,
    /// Number of entries of `fn_instances` already emitted.
    pub(crate) compiled_fn_instances: usize,
    /// Cache of builtin input variables. NOTE(review): tuple presumably
    /// `(variable id, value type, type id)` — confirm.
    pub(crate) builtin_vars: BTreeMap<BuiltinFn, (u32, Type, u32)>,
    /// Local workgroup size emitted as the `LocalSize` execution mode.
    pub(crate) workgroup_size: [u32; 3],
    /// NOTE(review): presumably temporary pointers allocated while lowering the current
    /// statement — confirm lifetime.
    pub(crate) statement_temp_ptrs: Vec<Value>,
    /// Per-function pool of reusable variables grouped by type.
    pub(crate) function_var_pool: Vec<(Type, Vec<Value>)>,
}

impl SpirvCompiler {
    pub(crate) fn new(
        externs: BTreeMap<u32, ExternalFnKind>,
        user_fns: BTreeMap<u32, UserFn>,
        type_defs: BTreeMap<u32, Type>,
        workgroup_static_tys: BTreeMap<u32, Type>,
        compiler: Compiler,
        workgroup_size: [u32; 3],
    ) -> Self {
        let mut builder = Builder::new();
        builder.set_version(1, 5);
        builder.capability(Capability::Shader);
        builder.memory_model(AddressingModel::Logical, MemoryModel::GLSL450);
        Self {
            builder,
            types: Vec::new(),
            consts: Vec::new(),
            vars: Vec::new(),
            local_ptrs: Vec::new(),
            names: Vec::new(),
            buffers: Vec::new(),
            interfaces: Vec::new(),
            current_block: None,
            loop_stack: Vec::new(),
            externs,
            user_fns,
            type_defs,
            workgroup_static_tys,
            workgroup_statics: BTreeMap::new(),
            compiler,
            inline_stack: Vec::new(),
            fn_instances: Vec::new(),
            compiled_fn_instances: 0,
            builtin_vars: BTreeMap::new(),
            workgroup_size,
            statement_temp_ptrs: Vec::new(),
            function_var_pool: Vec::new(),
        }
    }

    pub(crate) fn compile_kernel(mut self, arg_tys: &[Type], ret_ty: Type, body: &Stmt) -> Result<SpirvModule> {
        let void_ty = self.get_type(SpirvTy::Value(Type::Void));
        let fn_ty = self.builder.type_function(void_ty, Vec::<u32>::new());

        self.vars.clear();
        self.local_ptrs.clear();
        self.names.clear();
        self.buffers.clear();
        self.function_var_pool.clear();
        let arg_buffers = arg_tys.iter().enumerate().map(|(idx, ty)| self.add_storage_buffer(ty.clone(), idx as u32)).collect::<Result<Vec<_>>>()?;
        let ret_buffer = if ret_ty.is_void() { None } else { Some(self.add_storage_buffer(ret_ty.clone(), arg_tys.len() as u32)?) };
        self.add_workgroup_statics()?;

        self.get_builtin_input(BuiltinFn::GroupId);
        self.get_builtin_input(BuiltinFn::LocalId);

        let fn_id = self.builder.id();
        self.builder.begin_function(void_ty, Some(fn_id), spirv::FunctionControl::NONE, fn_ty)?;
        let entry_block = self.builder.begin_block(None)?;
        self.current_block = Some(entry_block);

        for (idx, ty) in arg_tys.iter().enumerate() {
            let ptr = arg_buffers[idx];
            let val = if Self::is_runtime_array(ty) { Value { id: ptr, ty: ty.clone() } } else { self.load_buffer_value(ptr, ty.clone())? };
            self.vars.push(Some(val));
            self.local_ptrs.push(None);
            self.names.push(None);
        }

        let ret = self.gen_stmt(body)?;
        if let (Some(buf), Some(value)) = (ret_buffer, ret) {
            let value = self.convert(value, ret_ty.clone())?;
            self.store_buffer_value(buf, value)?;
        }
        if self.current_block.is_some() {
            self.builder.ret()?;
        }
        self.builder.end_function()?;

        while self.compiled_fn_instances < self.fn_instances.len() {
            let idx = self.compiled_fn_instances;
            self.compiled_fn_instances += 1;
            self.compile_user_fn_instance(idx)?;
        }

        let mut interfaces = self.buffers.clone();
        interfaces.extend(self.interfaces.iter().copied());
        self.builder.entry_point(ExecutionModel::GLCompute, fn_id, "main", interfaces);
        self.builder.execution_mode(fn_id, spirv::ExecutionMode::LocalSize, self.workgroup_size);
        let module = self.builder.module();
        let words = module.assemble();
        Ok(SpirvModule { words, module })
    }

    pub(crate) fn compile_user_fn_instance(&mut self, idx: usize) -> Result<()> {
        let instance = self.fn_instances[idx].clone();
        let ret_id = self.get_type(SpirvTy::Value(instance.ret_ty.clone()));

        self.vars.clear();
        self.local_ptrs.clear();
        self.names.clear();
        self.loop_stack.clear();
        self.current_block = None;
        self.function_var_pool.clear();

        self.builder.begin_function(ret_id, Some(instance.fn_id), spirv::FunctionControl::NONE, instance.fn_ty)?;

        let mut params = Vec::with_capacity(instance.arg_tys.len());
        for ty in &instance.arg_tys {
            let ty_id = self.get_type(SpirvTy::Value(ty.clone()));
            let id = self.builder.function_parameter(ty_id)?;
            params.push(Value { id, ty: ty.clone() });
        }

        self.current_block = Some(self.builder.begin_block(None)?);
        for (arg_idx, param) in params.into_iter().enumerate() {
            self.set_var(arg_idx, param);
            self.names[arg_idx] = instance.arg_names.get(arg_idx).cloned();
        }
        let ret = self.gen_stmt(&instance.body)?;
        if self.current_block.is_some() {
            if instance.ret_ty.is_void() {
                self.builder.ret()?;
            } else {
                let value = ret.ok_or_else(|| anyhow::anyhow!("SPIR-V function {} did not produce a return value", instance.source_id))?;
                let value = self.convert(value, instance.ret_ty.clone())?;
                self.builder.ret_value(value.id)?;
            }
        }
        self.builder.end_function()?;

        self.vars.clear();
        self.local_ptrs.clear();
        self.names.clear();
        self.loop_stack.clear();
        self.current_block = None;
        self.function_var_pool.clear();
        Ok(())
    }
}