// zust-vm-spirv 0.9.2
//
// SPIR-V code generation backend for the Zust scripting language.
use anyhow::{Result, anyhow};
use dynamic::Type;
use parser::{BinaryOp, Expr, ExprKind};
use rspirv::dr::{InsertPoint, Instruction, Operand};
use spirv::{BuiltIn, Decoration, StorageClass};
use std::rc::Rc;

use crate::{
    api::BuiltinFn,
    context::{SpirvCompiler, SpirvTy, Value},
};

impl SpirvCompiler {
    /// Declares a `StorageBuffer` variable whose pointee is the buffer block type that
    /// `get_type(SpirvTy::Buffer(..))` produces for `ty`.
    ///
    /// The variable is decorated with descriptor set 0 and the given `binding`, appended
    /// to `self.buffers`, and its id is returned.
    pub(crate) fn add_storage_buffer(&mut self, ty: Type, binding: u32) -> Result<u32> {
        let ty = self.resolve_type(&ty);
        let struct_id = self.get_type(SpirvTy::Buffer(ty.clone()));
        let ptr_ty = self.builder.type_pointer(None, StorageClass::StorageBuffer, struct_id);
        let var = self.builder.variable(ptr_ty, None, StorageClass::StorageBuffer, None);
        self.builder.decorate(var, Decoration::DescriptorSet, [Operand::LiteralBit32(0)]);
        self.builder.decorate(var, Decoration::Binding, [Operand::LiteralBit32(binding)]);
        self.buffers.push(var);
        Ok(var)
    }

    /// Rebuilds `self.workgroup_statics` from `self.workgroup_static_tys`.
    ///
    /// For every `(id, ty)` entry a `Workgroup`-storage variable of the resolved type is
    /// declared, pushed onto `self.interfaces`, and recorded under `id`.
    pub(crate) fn add_workgroup_statics(&mut self) -> Result<()> {
        self.workgroup_statics.clear();
        // Cloned because `resolve_type`/`get_type` need `&mut self` while iterating.
        let statics = self.workgroup_static_tys.clone();
        for (id, ty) in statics {
            let ty = self.resolve_type(&ty);
            let ptr_ty = self.get_type(SpirvTy::Pointer(ty.clone(), StorageClass::Workgroup));
            let var = self.builder.variable(ptr_ty, None, StorageClass::Workgroup, None);
            self.interfaces.push(var);
            self.workgroup_statics.insert(id, Value { id: var, ty });
        }
        Ok(())
    }

    /// Loads member 0 of the storage buffer `buffer` as a value of type `ty`.
    pub(crate) fn load_buffer_value(&mut self, buffer: u32, ty: Type) -> Result<Value> {
        let ty = self.resolve_type(&ty);
        let zero = self.const_u32(0);
        let ptr_ty = self.get_type(SpirvTy::Pointer(ty.clone(), StorageClass::StorageBuffer));
        let ptr = self.builder.access_chain(ptr_ty, None, buffer, [zero])?;
        let ty_id = self.get_type(SpirvTy::Value(ty.clone()));
        let id = self.builder.load(ty_id, None, ptr, None, None)?;
        Ok(Value { id, ty })
    }

    /// Stores `value` into member 0 of the storage buffer `buffer`.
    ///
    /// `value.ty` is used as-is (it is not passed through `resolve_type` here).
    pub(crate) fn store_buffer_value(&mut self, buffer: u32, value: Value) -> Result<()> {
        let zero = self.const_u32(0);
        let ptr_ty = self.get_type(SpirvTy::Pointer(value.ty.clone(), StorageClass::StorageBuffer));
        let ptr = self.builder.access_chain(ptr_ty, None, buffer, [zero])?;
        self.builder.store(ptr, value.id, None, None)?;
        Ok(())
    }

    /// Returns the pointer `Value` registered for workgroup static `id`.
    ///
    /// # Errors
    /// Fails if no entry exists for `id` in `self.workgroup_statics`.
    pub(crate) fn workgroup_static_ptr(&self, id: u32) -> Result<Value> {
        self.workgroup_statics.get(&id).cloned().ok_or_else(|| anyhow!("SPIR-V workgroup static {id} not found"))
    }

    /// Emits a load of workgroup static `id` and returns the loaded value.
    pub(crate) fn load_workgroup_static(&mut self, id: u32) -> Result<Value> {
        let ptr = self.workgroup_static_ptr(id)?;
        let ty_id = self.get_type(SpirvTy::Value(ptr.ty.clone()));
        let id = self.builder.load(ty_id, None, ptr.id, None, None)?;
        Ok(Value { id, ty: ptr.ty })
    }

    /// Returns `(variable id, value type, value type id)` for the `Input` builtin variable
    /// backing `builtin`, creating it on first use and caching it in `self.builtin_vars`.
    ///
    /// The variable's type is a 3-component `u32` vector.
    ///
    /// # Panics
    /// Panics for builtins that have no input variable (`Barrier`, `AtomicAdd`).
    pub(crate) fn get_builtin_input(&mut self, builtin: BuiltinFn) -> (u32, Type, u32) {
        if let Some(value) = self.builtin_vars.get(&builtin) {
            return value.clone();
        }
        let ty = Type::Vec(Rc::new(Type::U32), 3);
        let ptr_ty = self.get_type(SpirvTy::Pointer(ty.clone(), StorageClass::Input));
        let ty_id = self.get_type(SpirvTy::Value(ty.clone()));
        let var = self.builder.variable(ptr_ty, None, StorageClass::Input, None);
        let built_in = match builtin {
            BuiltinFn::GroupId => BuiltIn::WorkgroupId,
            BuiltinFn::LocalId => BuiltIn::LocalInvocationId,
            BuiltinFn::Barrier | BuiltinFn::AtomicAdd => unreachable!("builtin has no input variable"),
        };
        self.builder.decorate(var, Decoration::BuiltIn, [Operand::BuiltIn(built_in)]);
        self.interfaces.push(var);
        self.builtin_vars.insert(builtin, (var, ty.clone(), ty_id));
        (var, ty, ty_id)
    }

    /// Assigns `value` to local slot `idx`.
    ///
    /// Aggregates (arrays/structs) are always stored through a function-local pointer,
    /// reusing the slot's existing pointer or creating one. Other values are kept as plain
    /// SSA values, and any pointer previously attached to the slot is dropped.
    ///
    /// # Panics
    /// Panics if the builder fails to create the variable or emit the store.
    pub(crate) fn set_var(&mut self, idx: usize, value: Value) {
        self.ensure_var_slot(idx);
        if value.ty.is_array() || value.ty.is_struct() {
            if let Some(ptr) = self.local_ptrs[idx].clone() {
                self.builder.store(ptr.id, value.id, None, None).expect("store aggregate local");
            } else {
                let ptr = self.create_function_var(value.ty.clone()).expect("aggregate local variable");
                self.builder.store(ptr.id, value.id, None, None).expect("initialize aggregate local");
                self.local_ptrs[idx] = Some(ptr);
            }
            self.vars[idx] = None;
        } else {
            self.local_ptrs[idx] = None;
            self.vars[idx] = Some(value);
        }
    }

    /// Like [`Self::set_var`], but an aggregate is only stored through memory if the slot
    /// already owns a pointer. Otherwise it stays an SSA value, which
    /// `materialize_var_ptr` can spill later.
    ///
    /// # Panics
    /// Panics if the builder fails to emit the store.
    pub(crate) fn set_var_lazy(&mut self, idx: usize, value: Value) {
        self.ensure_var_slot(idx);
        if value.ty.is_array() || value.ty.is_struct() {
            if let Some(ptr) = self.local_ptrs[idx].clone() {
                self.builder.store(ptr.id, value.id, None, None).expect("store aggregate local");
                self.vars[idx] = None;
            } else {
                self.vars[idx] = Some(value);
            }
        } else {
            self.local_ptrs[idx] = None;
            self.vars[idx] = Some(value);
        }
    }

    /// Grows `vars`, `local_ptrs` and `names` (filling with `None`) so index `idx` is valid.
    pub(crate) fn ensure_var_slot(&mut self, idx: usize) {
        if idx >= self.vars.len() {
            self.vars.resize(idx + 1, None);
        }
        if idx >= self.local_ptrs.len() {
            self.local_ptrs.resize(idx + 1, None);
        }
        if idx >= self.names.len() {
            self.names.resize(idx + 1, None);
        }
    }

    /// Returns a `Function`-storage pointer for the resolved `ty`.
    ///
    /// A pooled variable of the same resolved type from `self.function_var_pool` is
    /// reused when one is available. Otherwise a new `OpVariable` is emitted.
    pub(crate) fn create_function_var(&mut self, ty: Type) -> Result<Value> {
        let ty = self.resolve_type(&ty);
        if let Some((_, values)) = self.function_var_pool.iter_mut().find(|(existing, values)| existing == &ty && !values.is_empty())
            && let Some(value) = values.pop()
        {
            return Ok(value);
        }
        let ptr_ty = self.get_type(SpirvTy::Pointer(ty.clone(), StorageClass::Function));
        let id = self.function_variable(ptr_ty)?;
        Ok(Value { id, ty })
    }

    /// Returns a pointer for local slot `idx`, spilling an aggregate SSA value to a
    /// function-local variable if needed.
    ///
    /// Returns `Ok(None)` when the slot is empty or holds a non-aggregate value.
    pub(crate) fn materialize_var_ptr(&mut self, idx: usize) -> Result<Option<Value>> {
        if let Some(ptr) = self.local_ptrs.get(idx).and_then(Clone::clone) {
            return Ok(Some(ptr));
        }
        let Some(value) = self.vars.get(idx).and_then(Clone::clone) else {
            return Ok(None);
        };
        if !(value.ty.is_array() || value.ty.is_struct()) {
            return Ok(None);
        }
        self.ensure_var_slot(idx);
        let ptr = self.create_function_var(value.ty.clone())?;
        self.builder.store(ptr.id, value.id, None, None)?;
        self.local_ptrs[idx] = Some(ptr.clone());
        self.vars[idx] = None;
        Ok(Some(ptr))
    }

    /// Spills an aggregate `value` into a fresh function-local variable and tracks it in
    /// `self.statement_temp_ptrs`.
    ///
    /// # Errors
    /// Fails if `value` is not an array or struct, or if the builder fails.
    pub(crate) fn materialize_temp_ptr(&mut self, value: Value) -> Result<Value> {
        if !(value.ty.is_array() || value.ty.is_struct()) {
            return Err(anyhow!("SPIR-V temporary pointer requires aggregate value, got {:?}", value.ty));
        }
        let ptr = self.create_function_var(value.ty.clone())?;
        self.builder.store(ptr.id, value.id, None, None)?;
        self.statement_temp_ptrs.push(ptr.clone());
        Ok(ptr)
    }

    /// Spills every slot in `assigned` that currently holds an aggregate SSA value to a
    /// function-local pointer (via `materialize_var_ptr`).
    pub(crate) fn materialize_assigned_aggregate_vars(&mut self, assigned: &std::collections::BTreeSet<usize>) -> Result<()> {
        for &idx in assigned {
            let Some(value) = self.vars.get(idx).and_then(Clone::clone) else {
                continue;
            };
            if value.ty.is_array() || value.ty.is_struct() {
                self.materialize_var_ptr(idx)?;
            }
        }
        Ok(())
    }

    /// Emits a `Function`-storage `OpVariable` of pointer type `ptr_ty` at the start of
    /// block 0 of the current function and returns its id.
    ///
    /// SPIR-V requires function-storage variables to be declared in the function's first
    /// block. The caller's block selection is restored afterwards, even when the insertion
    /// fails, so the builder cursor is never left parked in block 0.
    pub(crate) fn function_variable(&mut self, ptr_ty: u32) -> Result<u32> {
        let selected_block = self.builder.selected_block();
        let id = self.builder.id();
        let inst = Instruction::new(spirv::Op::Variable, Some(ptr_ty), Some(id), vec![Operand::StorageClass(StorageClass::Function)]);
        self.builder.select_block(Some(0))?;
        let inserted = self.builder.insert_into_block(InsertPoint::Begin, inst);
        // Restore the cursor before surfacing any insertion error.
        self.builder.select_block(selected_block)?;
        inserted?;
        Ok(id)
    }

    /// Returns the current value of local slot `idx`.
    ///
    /// If the slot is backed by a pointer, a fresh load is emitted. A failed load is
    /// reported as `None` (the builder error is discarded). Otherwise the stored SSA value
    /// (if any) is returned.
    pub(crate) fn get_var(&mut self, idx: usize) -> Option<Value> {
        if let Some(ptr) = self.local_ptrs.get(idx).and_then(Clone::clone) {
            let ty_id = self.get_type(SpirvTy::Value(ptr.ty.clone()));
            let id = self.builder.load(ty_id, None, ptr.id, None, None).ok()?;
            return Some(Value { id, ty: ptr.ty });
        }
        self.vars.get(idx).and_then(Clone::clone)
    }

    /// Returns a pointer for `expr`, spilling aggregate `Var` slots to memory if needed.
    pub(crate) fn expr_ptr(&mut self, expr: &Expr) -> Result<Option<Value>> {
        self.expr_ptr_with_materialize(expr, true)
    }

    /// Returns a pointer for `expr` only if its slot already owns one; never spills.
    pub(crate) fn expr_existing_ptr(&mut self, expr: &Expr) -> Result<Option<Value>> {
        self.expr_ptr_with_materialize(expr, false)
    }

    /// Resolves `expr` to a pointer `Value`, if it denotes addressable local storage.
    ///
    /// * `Var(idx)`: the slot's pointer, spilling an aggregate SSA value first when
    ///   `materialize` is true.
    /// * `Ident(name)`: the existing pointer of the innermost slot bound to `name` (the
    ///   same slot `get_named_var` reads). It never spills and never falls back to a
    ///   shadowed outer binding.
    /// * `left[right]`: indexes the pointer for `left` (see `index_local_ptr`).
    /// * anything else: `None`.
    pub(crate) fn expr_ptr_with_materialize(&mut self, expr: &Expr, materialize: bool) -> Result<Option<Value>> {
        match &expr.kind {
            ExprKind::Var(idx) => {
                if materialize {
                    self.materialize_var_ptr(*idx as usize)
                } else {
                    Ok(self.local_ptrs.get(*idx as usize).and_then(Clone::clone))
                }
            }
            ExprKind::Ident(name) => {
                // Resolve the innermost binding first. Previously an innermost slot without a
                // pointer let the search continue into shadowed outer slots of the same name.
                let Some(idx) = self.innermost_named_slot(name) else {
                    return Ok(None);
                };
                Ok(self.local_ptrs.get(idx).and_then(Clone::clone))
            }
            ExprKind::Binary { left, op: BinaryOp::Idx, right } => {
                if let Some(ptr) = self.expr_ptr_with_materialize(left, materialize)? {
                    let idx = self.gen_expr(right)?;
                    return self.index_local_ptr(ptr, idx).map(Some);
                }
                Ok(None)
            }
            _ => Ok(None),
        }
    }

    /// Returns the current value bound to the innermost local named `name`.
    ///
    /// # Errors
    /// Fails if no slot carries `name`, or if that slot has no value.
    pub(crate) fn get_named_var(&mut self, name: &str) -> Result<Value> {
        let idx = self.innermost_named_slot(name).ok_or_else(|| anyhow!("SPIR-V identifier {name} not found"))?;
        self.get_var(idx).ok_or_else(|| anyhow!("SPIR-V identifier {name} has no value"))
    }

    /// Index of the last (innermost) slot in `self.names` bound to `name`, if any.
    fn innermost_named_slot(&self, name: &str) -> Option<usize> {
        self.names.iter().rposition(|existing| existing.as_deref() == Some(name))
    }
}