//! zust-vm-spirv 0.9.3
//!
//! SPIR-V code generation backend for the Zust scripting language.
mod api;
mod constants;
mod context;
mod expr;
mod externs;
mod memory;
mod ops;
mod stmt;
mod symbols;
mod types;

pub use api::{BuiltinFn, ExternalFn, ExternalFnKind, Kernel, SpirvModule, spirv_builtins};

use anyhow::{Context, Result, bail};
use compiler::{Compiler, Symbol, substitute_stmt, substitute_type};
use dynamic::Type;
use parser::{Span, Stmt, StmtKind};
use std::{collections::BTreeMap, path::Path};

use crate::{
    context::SpirvCompiler,
    externs::register_externs,
    symbols::{collect_type_defs, collect_user_fns, collect_workgroup_statics},
};

/// Compiles the zust function `fn_name` from in-memory `source` into a SPIR-V [`Kernel`].
///
/// The source is imported under `module_name`, the default [`spirv_builtins`] are registered
/// as externs, no generic arguments are supplied, and the workgroup size is `[1, 1, 1]`.
///
/// # Errors
/// Propagates any failure from extern registration, importing the source, or compiling the
/// entry function.
pub fn compile_source(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str) -> Result<Kernel> {
    let builtins = spirv_builtins();
    compile_source_with_externs(source, module_name, fn_name, builtins)
}

/// Compiles `fn_name` from in-memory `source` with the default [`spirv_builtins`] and an
/// explicit `workgroup_size`.
///
/// # Errors
/// Propagates any failure from extern registration, importing the source, or compiling the
/// entry function.
pub fn compile_source_with_workgroup_size(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str, workgroup_size: [u32; 3]) -> Result<Kernel> {
    let builtins = spirv_builtins();
    compile_source_with_externs_and_workgroup_size(source, module_name, fn_name, builtins, workgroup_size)
}

pub fn compile_file_with_workgroup_size(path: impl AsRef<Path>, module_name: &str, fn_name: &str, workgroup_size: [u32; 3]) -> Result<Kernel> {
    compile_file_with_generic_args_and_workgroup_size(path, module_name, fn_name, &[], workgroup_size)
}

/// Compiles `fn_name` from the zust file at `path`, specializing it with `generic_args`, into a
/// SPIR-V [`Kernel`] with the given `workgroup_size`.
///
/// A fresh [`Compiler`] is created, the default [`spirv_builtins`] are registered on it, and
/// then the file is imported under `module_name`.
///
/// # Errors
/// Propagates any failure from extern registration, file import, or compiling the entry
/// function.
pub fn compile_file_with_generic_args_and_workgroup_size(path: impl AsRef<Path>, module_name: &str, fn_name: &str, generic_args: &[Type], workgroup_size: [u32; 3]) -> Result<Kernel> {
    let mut session = Compiler::new();
    // Builtins are registered before the module is imported.
    let extern_table = register_externs(&mut session, spirv_builtins())?;
    session.import_file(module_name, path)?;
    compile_function_with_externs_generic_args_and_workgroup_size(&mut session, module_name, fn_name, extern_table, generic_args, workgroup_size)
}

/// Compiles `fn_name` from in-memory `source` against a caller-supplied set of `externs`,
/// using a `[1, 1, 1]` workgroup size and no generic arguments.
///
/// # Errors
/// Propagates any failure from extern registration, importing the source, or compiling the
/// entry function.
pub fn compile_source_with_externs(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str, externs: impl IntoIterator<Item = ExternalFn>) -> Result<Kernel> {
    const DEFAULT_WORKGROUP_SIZE: [u32; 3] = [1, 1, 1];
    compile_source_with_externs_and_workgroup_size(source, module_name, fn_name, externs, DEFAULT_WORKGROUP_SIZE)
}

pub fn compile_source_with_externs_and_workgroup_size(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str, externs: impl IntoIterator<Item = ExternalFn>, workgroup_size: [u32; 3]) -> Result<Kernel> {
    compile_source_with_externs_generic_args_and_workgroup_size(source, module_name, fn_name, externs, &[], workgroup_size)
}

/// Compiles `fn_name` from in-memory `source` with full control over externs, generic
/// arguments and workgroup size.
///
/// A fresh [`Compiler`] is created, `externs` are registered on it, and then the source bytes
/// are copied and imported under `module_name`. The resulting extern table is handed to
/// [`compile_function_with_externs_generic_args_and_workgroup_size`].
///
/// # Errors
/// Propagates any failure from extern registration, importing the source, or compiling the
/// entry function.
pub fn compile_source_with_externs_generic_args_and_workgroup_size(
    source: impl AsRef<[u8]>,
    module_name: &str,
    fn_name: &str,
    externs: impl IntoIterator<Item = ExternalFn>,
    generic_args: &[Type],
    workgroup_size: [u32; 3],
) -> Result<Kernel> {
    let mut session = Compiler::new();
    let extern_table = register_externs(&mut session, externs)?;
    let code: Vec<u8> = source.as_ref().to_vec();
    session.import_code(module_name, code)?;
    compile_function_with_externs_generic_args_and_workgroup_size(&mut session, module_name, fn_name, extern_table, generic_args, workgroup_size)
}

/// Compiles `fn_name` from a module already loaded into `compiler`.
///
/// The default [`spirv_builtins`] are registered on the caller's `compiler`, which mutates it.
/// No generic arguments are passed, and the workgroup size is `[1, 1, 1]`.
///
/// # Errors
/// Propagates any failure from extern registration or compiling the entry function.
pub fn compile_function(compiler: &mut Compiler, module_name: &str, fn_name: &str) -> Result<Kernel> {
    let registered = register_externs(compiler, spirv_builtins())?;
    compile_function_with_externs(compiler, module_name, fn_name, registered)
}

/// Compiles `fn_name` from `compiler` with a prepared extern table, a `[1, 1, 1]` workgroup
/// size, and no generic arguments.
///
/// NOTE(review): `externs` is presumably the table returned by registering externs on this
/// same `compiler`; confirm.
///
/// # Errors
/// Propagates any failure from compiling the entry function.
pub fn compile_function_with_externs(compiler: &mut Compiler, module_name: &str, fn_name: &str, externs: BTreeMap<u32, ExternalFnKind>) -> Result<Kernel> {
    const DEFAULT_WORKGROUP_SIZE: [u32; 3] = [1, 1, 1];
    compile_function_with_externs_and_workgroup_size(compiler, module_name, fn_name, externs, DEFAULT_WORKGROUP_SIZE)
}

pub fn compile_function_with_externs_and_workgroup_size(compiler: &mut Compiler, module_name: &str, fn_name: &str, externs: BTreeMap<u32, ExternalFnKind>, workgroup_size: [u32; 3]) -> Result<Kernel> {
    compile_function_with_externs_generic_args_and_workgroup_size(compiler, module_name, fn_name, externs, &[], workgroup_size)
}

/// Compiles the zust function `fn_name` of `module_name` (already loaded into `compiler`) into
/// a SPIR-V [`Kernel`] whose entry point is named `"main"`.
///
/// The function is looked up as `module_name::fn_name` first and, failing that, as the bare
/// `fn_name`. If it declares generic parameters, `generic_args` must supply one concrete type
/// per parameter, and the body is specialized before code generation. The return type is
/// inferred from the resolved argument types.
///
/// NOTE(review): `externs` is presumably the table returned by registering externs on this
/// same `compiler`; confirm.
///
/// # Errors
/// Fails if the function is not found, the symbol is not a zust function or lacks a function
/// type, the generic argument count does not match, or type resolution, inference, or SPIR-V
/// generation fails.
pub fn compile_function_with_externs_generic_args_and_workgroup_size(
    compiler: &mut Compiler,
    module_name: &str,
    fn_name: &str,
    externs: BTreeMap<u32, ExternalFnKind>,
    generic_args: &[Type],
    workgroup_size: [u32; 3],
) -> Result<Kernel> {
    let full_name = format!("{module_name}::{fn_name}");
    // Prefer the module-qualified symbol; fall back to the bare function name.
    let id = compiler
        .symbols
        .get_id(&full_name)
        .or_else(|_| compiler.symbols.get_id(fn_name))
        .with_context(|| format!("function {full_name} not found"))?;
    // Clone the symbol out so `compiler` can be borrowed mutably during specialization.
    let symbol = compiler.symbols.get_symbol(id)?.1.clone();
    let Symbol::Fn { ty, args, generic_params, cap, body, is_pub: _ } = symbol else {
        bail!("{full_name} is not a zust function");
    };
    let Type::Fn { tys: decl_arg_tys, ret: _ } = ty else {
        bail!("{full_name} has non-function type {ty:?}");
    };
    let (arg_tys, body) = specialize_entry_function(compiler, module_name, &args, &generic_params, generic_args, &decl_arg_tys, body.as_ref(), cap)?;
    let ret_ty = compiler.infer_fn_with_params(id, &arg_tys, generic_args)?;
    let ret_ty = compiler.symbols.get_type(&ret_ty)?;
    let type_defs = collect_type_defs(compiler);
    let user_fns = collect_user_fns(compiler)?;
    let workgroup_static_tys = collect_workgroup_statics(compiler)?;
    let builder = SpirvCompiler::new(externs, user_fns, type_defs, workgroup_static_tys, compiler.clone(), workgroup_size);
    let spirv = builder.compile_kernel(&arg_tys, ret_ty.clone(), &body)?;
    // `arg_tys` is not needed after code generation, so move it instead of cloning.
    Ok(Kernel { spirv, entry: "main".into(), arg_tys, ret_ty })
}

/// Resolves the entry function's argument types and produces the body to compile.
///
/// If `generic_params` is empty, each declared argument type is resolved as-is and `body` is
/// returned unchanged; `generic_args` and `cap` are not consulted on that path.
///
/// Otherwise, `generic_args` must have exactly one entry per generic parameter. They are
/// substituted into the body and into each declared argument type, the argument types are
/// resolved, and the specialized body is compiled with `Compiler::compile_fn`. The result is
/// wrapped in a block statement. `compile_fn` receives the argument types mutably, so the
/// returned types are whatever it leaves in the vector.
///
/// # Errors
/// Fails on a generic-argument count mismatch, on an unresolvable argument type, or when
/// `compile_fn` fails.
fn specialize_entry_function(
    compiler: &mut Compiler,
    module_name: &str,
    args: &[smol_str::SmolStr],
    generic_params: &[Type],
    generic_args: &[Type],
    decl_arg_tys: &[Type],
    body: &Stmt,
    mut cap: compiler::Capture,
) -> Result<(Vec<Type>, Stmt)> {
    // Non-generic entry: nothing to substitute, only resolve the declared types.
    if generic_params.is_empty() {
        let mut resolved = Vec::with_capacity(decl_arg_tys.len());
        for decl in decl_arg_tys {
            resolved.push(resolve_entry_type(compiler, module_name, decl)?);
        }
        return Ok((resolved, body.clone()));
    }

    let (expected, supplied) = (generic_params.len(), generic_args.len());
    if expected != supplied {
        bail!("entry function expects {} generic args, got {}", expected, supplied);
    }

    let specialized_body = substitute_stmt(body, generic_params, generic_args);
    let mut concrete_tys = Vec::with_capacity(decl_arg_tys.len());
    for decl in decl_arg_tys {
        let substituted = substitute_type(decl, generic_params, generic_args);
        concrete_tys.push(resolve_entry_type(compiler, module_name, &substituted)?);
    }

    let compiled = compiler.compile_fn(args, &mut concrete_tys, specialized_body, &mut cap)?;
    let block = Stmt::new(StmtKind::Block(compiled), Span::default());
    Ok((concrete_tys, block))
}

/// Looks `ty` up in the compiler's symbol table. If that fails, it retries with unqualified
/// named types prefixed by `module_name::` (via [`qualify_entry_type`]).
///
/// If both lookups fail, the error from the qualified lookup is returned.
fn resolve_entry_type(compiler: &Compiler, module_name: &str, ty: &Type) -> Result<Type> {
    match compiler.symbols.get_type(ty) {
        Ok(resolved) => Ok(resolved),
        Err(_) => compiler.symbols.get_type(&qualify_entry_type(module_name, ty)),
    }
}

/// Returns a copy of `ty` in which every unqualified `Ident` is prefixed with `module_name::`.
///
/// An `Ident` whose name contains no `::` and is not `Vec` receives the prefix; other idents
/// keep their name. The rewrite recurses through:
/// - ident and symbol parameters,
/// - array/vector element types and array length types,
/// - tuple items,
/// - struct parameters and field types,
/// - function argument and return types.
///
/// Every other variant is returned as a clone.
fn qualify_entry_type(module_name: &str, ty: &Type) -> Type {
    use std::rc::Rc;

    // Recurse into one nested type with the same module prefix.
    let qualify = |inner: &Type| qualify_entry_type(module_name, inner);

    match ty {
        Type::Ident { name, params } => {
            let params = params.iter().map(qualify).collect();
            if !name.contains("::") && name.as_str() != "Vec" {
                Type::Ident { name: format!("{module_name}::{name}").into(), params }
            } else {
                Type::Ident { name: name.clone(), params }
            }
        }
        Type::Array(elem, len) => Type::Array(Rc::new(qualify(elem)), *len),
        Type::ArrayParam(elem, len) => Type::ArrayParam(Rc::new(qualify(elem)), Rc::new(qualify(len))),
        Type::Vec(elem, len) => Type::Vec(Rc::new(qualify(elem)), *len),
        Type::Tuple(items) => Type::Tuple(items.iter().map(qualify).collect()),
        Type::Struct { params, fields } => Type::Struct {
            params: params.iter().map(qualify).collect(),
            fields: fields.iter().map(|(field, field_ty)| (field.clone(), qualify(field_ty))).collect(),
        },
        Type::Fn { tys, ret } => Type::Fn { tys: tys.iter().map(qualify).collect(), ret: Rc::new(qualify(ret)) },
        Type::Symbol { id, params } => Type::Symbol { id: *id, params: params.iter().map(qualify).collect() },
        other => other.clone(),
    }
}

#[cfg(test)]
mod tests;