Skip to main content

vm_spirv_mandelbrot/
mandelbrot.rs

1use anyhow::Result;
2use dynamic::Type;
3
4const BIGFLOAT_LIMBS: usize = 2;
5
6fn main() -> Result<()> {
7    let source_path = std::env::var("MANDEL_ZS").unwrap_or_else(|_| "zusts/gpu/mandelbrot.zs".to_string());
8    let module_name = std::env::var("MANDEL_MODULE").unwrap_or_else(|_| "mandelbrot".to_string());
9    let spv_path = std::env::var("MANDEL_SPV").unwrap_or_else(|_| "mandel.spv".to_string());
10    let asm_path = std::env::var("MANDEL_SPVASM").unwrap_or_else(|_| "mandel.spvasm".to_string());
11    let local_size = std::env::var("MANDEL_LOCAL_SIZE").ok().map(|value| parse_workgroups(&value)).transpose()?.unwrap_or([16, 16, 1]);
12
13    let kernel = vm_spirv::compile_file_with_generic_args_and_workgroup_size(&source_path, &module_name, "main", &[Type::ConstInt(BIGFLOAT_LIMBS as i64)], local_size)?;
14    let bytes = kernel.spirv.words().iter().flat_map(|word| word.to_le_bytes()).collect::<Vec<_>>();
15    std::fs::write(&spv_path, bytes)?;
16    std::fs::write(&asm_path, kernel.spirv.disassemble())?;
17    println!("compiled {source_path}");
18    println!("bigfloat_limbs: {BIGFLOAT_LIMBS}");
19    println!("wrote {spv_path} ({} words)", kernel.spirv.words().len());
20    Ok(())
21}
22
23fn parse_workgroups(value: &str) -> Result<[u32; 3]> {
24    let parts = value.split(['x', ',', ' ']).filter(|part| !part.is_empty()).map(str::parse::<u32>).collect::<Result<Vec<_>, _>>()?;
25    match parts.as_slice() {
26        [x, y] => Ok([*x, *y, 1]),
27        [x, y, z] => Ok([*x, *y, *z]),
28        _ => anyhow::bail!("MANDEL_LOCAL_SIZE must be `x,y` or `x,y,z`, got {value:?}"),
29    }
30}