use std::error::Error;
use cubecl::prelude::*;
/// Image width in pixels.
const WIDTH: u32 = 512;
/// `WIDTH` as a float, for use in floating-point pixel arithmetic.
const FWIDTH: f32 = WIDTH as f32;
/// Image height in pixels.
const HEIGHT: u32 = 256;
/// `HEIGHT` as a float, for use in floating-point pixel arithmetic.
const FHEIGHT: f32 = HEIGHT as f32;
/// Total number of pixels (one `f32` output value each).
const ARRAY_LEN: u32 = WIDTH * HEIGHT;
/// Upper bound on escape-time iterations per pixel.
const ITERATIONS: u32 = 1000;
/// `ITERATIONS` as a float, used to normalise the iteration count.
const FITERATIONS: f32 = ITERATIONS as f32;
#[cube(launch_unchecked)]
fn kernel(output: &mut Array<f32>) {
if ABSOLUTE_POS < output.len() {
let pix_x = ABSOLUTE_POS as f32 % FWIDTH;
let pix_y = ABSOLUTE_POS as f32 * FHEIGHT;
output[ABSOLUTE_POS] = mandlebrot(pix_x, pix_y);
}
}
/// Normalised escape-time value of the Mandelbrot iteration for one pixel.
///
/// `pix_x` and `pix_y` are pixel coordinates. They are divided by `FWIDTH` /
/// `FHEIGHT` and then mapped linearly so that a pixel at `(0, 0)` corresponds
/// to the point `(-2.0, -1.12)`, with a span of `2.47` horizontally and `2.24`
/// vertically.
///
/// Returns the number of iterations performed divided by `FITERATIONS`; a
/// result of `1.0` means the orbit did not escape within `ITERATIONS` steps.
#[cube]
fn mandlebrot(pix_x: f32, pix_y: f32) -> f32 {
    let screen_x = pix_x / FWIDTH;
    let screen_y = pix_y / FHEIGHT;
    let scaled_x = screen_x * 2.47 - 2.0;
    let scaled_y = screen_y * 2.24 - 1.12;

    let mut x = 0.0f32;
    let mut y = 0.0f32;
    let mut i = 0u32;
    loop {
        // Stop at the iteration cap or once |z|^2 exceeds 4 (|z| > 2).
        if i >= ITERATIONS || x * x + y * y > 4.0 {
            break;
        }
        // z <- z^2 + c; both components are computed from the old (x, y)
        // before either is overwritten.
        let next_x = x * x - y * y + scaled_x;
        let next_y = 2.0 * x * y + scaled_y;
        x = next_x;
        y = next_y;
        i += 1;
    }
    i as f32 / FITERATIONS
}
/// Runs `kernel` on `device` and returns the computed values.
///
/// Allocates a device buffer of `ARRAY_LEN` `f32` values, launches enough
/// cubes that every element is covered by one unit, reads the buffer back and
/// returns it as a `Vec<f32>` of length `ARRAY_LEN`.
fn launch<R: Runtime>(device: &R::Device) -> Vec<f32> {
    /// Units per cube along the x axis.
    const UNITS_PER_CUBE: u32 = 256;

    let element_count = ARRAY_LEN as usize;
    let client = R::client(device);
    // `empty` is sized in bytes, not elements.
    let output = client.empty(element_count * size_of::<f32>());

    // Round up so a trailing partial cube still covers the last elements; the
    // kernel's own bounds check discards the surplus units.
    let cube_count = ARRAY_LEN.div_ceil(UNITS_PER_CUBE);

    // SAFETY: `output` was allocated above with room for exactly
    // `element_count` f32 values, and that same element count (not the byte
    // size) is passed as the array length, so the kernel's
    // `ABSOLUTE_POS < output.len()` guard keeps every write inside the buffer.
    unsafe {
        kernel::launch_unchecked::<R>(
            &client,
            CubeCount::Static(cube_count, 1, 1),
            CubeDim::new(UNITS_PER_CUBE, 1, 1),
            ArrayArg::from_raw_parts(&output, element_count, 1),
        )
    };

    let bytes = client.read(output.binding());
    f32::from_bytes(&bytes).to_vec()
}
/// Entry point: runs `launch` on the wgpu runtime with a default-constructed
/// device and prints a few summary values of the result with `dbg!`.
///
/// Always returns `Ok(())`; nothing in the body produces an `Err`.
fn main() -> Result<(), Box<dyn Error>> {
let x = launch::<cubecl::wgpu::WgpuRuntime>(&Default::default());
// Number of values returned by `launch`.
dbg!(x.len());
// Arithmetic mean over all returned values.
dbg!(x.iter().sum::<f32>() / x.len() as f32);
// NOTE(review): this sums only the first 256 values but still divides by the
// full length, so it is not the mean of those 256 values — confirm intent.
dbg!(x.iter().take(256).sum::<f32>() / x.len() as f32);
// Largest value under `f32::total_cmp` ordering (`None` if `x` is empty).
dbg!(x.iter().max_by(|a, b| a.total_cmp(b)));
Ok(())
}