1mod api;
2mod constants;
3mod context;
4mod expr;
5mod externs;
6mod memory;
7mod ops;
8mod stmt;
9mod symbols;
10mod types;
11
12pub use api::{BuiltinFn, ExternalFn, ExternalFnKind, Kernel, SpirvModule, spirv_builtins};
13
14use anyhow::{Context, Result, bail};
15use compiler::{Compiler, Symbol, substitute_stmt, substitute_type};
16use dynamic::Type;
17use parser::{Span, Stmt, StmtKind};
18use std::{collections::BTreeMap, path::Path};
19
20use crate::{
21 context::SpirvCompiler,
22 externs::register_externs,
23 symbols::{collect_type_defs, collect_user_fns, collect_workgroup_statics},
24};
25
26pub fn compile_source(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str) -> Result<Kernel> {
27 compile_source_with_externs(source, module_name, fn_name, spirv_builtins())
28}
29
30pub fn compile_source_with_workgroup_size(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str, workgroup_size: [u32; 3]) -> Result<Kernel> {
31 compile_source_with_externs_and_workgroup_size(source, module_name, fn_name, spirv_builtins(), workgroup_size)
32}
33
34pub fn compile_file_with_workgroup_size(path: impl AsRef<Path>, module_name: &str, fn_name: &str, workgroup_size: [u32; 3]) -> Result<Kernel> {
35 compile_file_with_generic_args_and_workgroup_size(path, module_name, fn_name, &[], workgroup_size)
36}
37
38pub fn compile_file_with_generic_args_and_workgroup_size(path: impl AsRef<Path>, module_name: &str, fn_name: &str, generic_args: &[Type], workgroup_size: [u32; 3]) -> Result<Kernel> {
39 let mut compiler = Compiler::new();
40 let externs = register_externs(&mut compiler, spirv_builtins())?;
41 compiler.import_file(module_name, path)?;
42 compile_function_with_externs_generic_args_and_workgroup_size(&mut compiler, module_name, fn_name, externs, generic_args, workgroup_size)
43}
44
45pub fn compile_source_with_externs(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str, externs: impl IntoIterator<Item = ExternalFn>) -> Result<Kernel> {
46 compile_source_with_externs_and_workgroup_size(source, module_name, fn_name, externs, [1, 1, 1])
47}
48
49pub fn compile_source_with_externs_and_workgroup_size(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str, externs: impl IntoIterator<Item = ExternalFn>, workgroup_size: [u32; 3]) -> Result<Kernel> {
50 compile_source_with_externs_generic_args_and_workgroup_size(source, module_name, fn_name, externs, &[], workgroup_size)
51}
52
53pub fn compile_source_with_externs_generic_args_and_workgroup_size(
54 source: impl AsRef<[u8]>,
55 module_name: &str,
56 fn_name: &str,
57 externs: impl IntoIterator<Item = ExternalFn>,
58 generic_args: &[Type],
59 workgroup_size: [u32; 3],
60) -> Result<Kernel> {
61 let mut compiler = Compiler::new();
62 let externs = register_externs(&mut compiler, externs)?;
63 compiler.import_code(module_name, source.as_ref().to_vec())?;
64 compile_function_with_externs_generic_args_and_workgroup_size(&mut compiler, module_name, fn_name, externs, generic_args, workgroup_size)
65}
66
67pub fn compile_function(compiler: &mut Compiler, module_name: &str, fn_name: &str) -> Result<Kernel> {
68 let externs = register_externs(compiler, spirv_builtins())?;
69 compile_function_with_externs(compiler, module_name, fn_name, externs)
70}
71
72pub fn compile_function_with_externs(compiler: &mut Compiler, module_name: &str, fn_name: &str, externs: BTreeMap<u32, ExternalFnKind>) -> Result<Kernel> {
73 compile_function_with_externs_and_workgroup_size(compiler, module_name, fn_name, externs, [1, 1, 1])
74}
75
76pub fn compile_function_with_externs_and_workgroup_size(compiler: &mut Compiler, module_name: &str, fn_name: &str, externs: BTreeMap<u32, ExternalFnKind>, workgroup_size: [u32; 3]) -> Result<Kernel> {
77 compile_function_with_externs_generic_args_and_workgroup_size(compiler, module_name, fn_name, externs, &[], workgroup_size)
78}
79
80pub fn compile_function_with_externs_generic_args_and_workgroup_size(
81 compiler: &mut Compiler,
82 module_name: &str,
83 fn_name: &str,
84 externs: BTreeMap<u32, ExternalFnKind>,
85 generic_args: &[Type],
86 workgroup_size: [u32; 3],
87) -> Result<Kernel> {
88 let full_name = format!("{module_name}::{fn_name}");
89 let id = compiler.symbols.get_id(&full_name).or_else(|_| compiler.symbols.get_id(fn_name)).with_context(|| format!("function {full_name} not found"))?;
90 let symbol = compiler.symbols.get_symbol(id)?.1.clone();
91 let Symbol::Fn { ty, args, generic_params, cap, body, is_pub: _ } = symbol else {
92 bail!("{full_name} is not a zust function");
93 };
94 let Type::Fn { tys: decl_arg_tys, ret: _ } = ty else {
95 bail!("{full_name} has non-function type {ty:?}");
96 };
97 let (arg_tys, body) = specialize_entry_function(compiler, module_name, &args, &generic_params, generic_args, &decl_arg_tys, body.as_ref(), cap)?;
98 let ret_ty = compiler.infer_fn_with_params(id, &arg_tys, generic_args)?;
99 let ret_ty = compiler.symbols.get_type(&ret_ty)?;
100 let type_defs = collect_type_defs(compiler);
101 let user_fns = collect_user_fns(compiler)?;
102 let workgroup_static_tys = collect_workgroup_statics(compiler)?;
103 let builder = SpirvCompiler::new(externs, user_fns, type_defs, workgroup_static_tys, compiler.clone(), workgroup_size);
104 let spirv = builder.compile_kernel(&arg_tys, ret_ty.clone(), &body)?;
105 Ok(Kernel { spirv, entry: "main".into(), arg_tys: arg_tys.clone(), ret_ty })
106}
107
108fn specialize_entry_function(
109 compiler: &mut Compiler,
110 module_name: &str,
111 args: &[smol_str::SmolStr],
112 generic_params: &[Type],
113 generic_args: &[Type],
114 decl_arg_tys: &[Type],
115 body: &Stmt,
116 cap: compiler::Capture,
117) -> Result<(Vec<Type>, Stmt)> {
118 if generic_params.is_empty() {
119 let arg_tys = decl_arg_tys.iter().map(|ty| resolve_entry_type(compiler, module_name, ty)).collect::<Result<Vec<_>>>()?;
120 return Ok((arg_tys, body.clone()));
121 }
122 if generic_params.len() != generic_args.len() {
123 bail!("entry function expects {} generic args, got {}", generic_params.len(), generic_args.len());
124 }
125 let body = substitute_stmt(body, generic_params, generic_args);
126 let mut arg_tys = decl_arg_tys.iter().map(|ty| resolve_entry_type(compiler, module_name, &substitute_type(ty, generic_params, generic_args))).collect::<Result<Vec<_>>>()?;
127 let mut cap = cap;
128 let compiled = compiler.compile_fn(args, &mut arg_tys, body, &mut cap)?;
129 Ok((arg_tys, Stmt::new(StmtKind::Block(compiled), Span::default())))
130}
131
132fn resolve_entry_type(compiler: &Compiler, module_name: &str, ty: &Type) -> Result<Type> {
133 compiler.symbols.get_type(ty).or_else(|_| compiler.symbols.get_type(&qualify_entry_type(module_name, ty)))
134}
135
136fn qualify_entry_type(module_name: &str, ty: &Type) -> Type {
137 match ty {
138 Type::Ident { name, params } if !name.contains("::") && name.as_str() != "Vec" => {
139 Type::Ident { name: format!("{module_name}::{name}").into(), params: params.iter().map(|param| qualify_entry_type(module_name, param)).collect() }
140 }
141 Type::Ident { name, params } => Type::Ident { name: name.clone(), params: params.iter().map(|param| qualify_entry_type(module_name, param)).collect() },
142 Type::Array(elem, len) => Type::Array(std::rc::Rc::new(qualify_entry_type(module_name, elem)), *len),
143 Type::ArrayParam(elem, len) => Type::ArrayParam(std::rc::Rc::new(qualify_entry_type(module_name, elem)), std::rc::Rc::new(qualify_entry_type(module_name, len))),
144 Type::Vec(elem, len) => Type::Vec(std::rc::Rc::new(qualify_entry_type(module_name, elem)), *len),
145 Type::Tuple(items) => Type::Tuple(items.iter().map(|item| qualify_entry_type(module_name, item)).collect()),
146 Type::Struct { params, fields } => {
147 Type::Struct { params: params.iter().map(|param| qualify_entry_type(module_name, param)).collect(), fields: fields.iter().map(|(name, ty)| (name.clone(), qualify_entry_type(module_name, ty))).collect() }
148 }
149 Type::Fn { tys, ret } => Type::Fn { tys: tys.iter().map(|ty| qualify_entry_type(module_name, ty)).collect(), ret: std::rc::Rc::new(qualify_entry_type(module_name, ret)) },
150 Type::Symbol { id, params } => Type::Symbol { id: *id, params: params.iter().map(|param| qualify_entry_type(module_name, param)).collect() },
151 other => other.clone(),
152 }
153}
154
155#[cfg(test)]
156mod tests;