use rocmrc::driver::{HipContext, HipSlice};
use rocmrc::driver::result as drv;
use rocmrc::rocblas::{
Axpy, AxpyConfig, Copy as BlasCopy, CopyConfig, Dot, DotConfig, Gemm, GemmConfig, Nrm2,
Nrm2Config, Operation, RocblasHandle, Scal, ScalConfig, rocblas_pointer_mode,
};
/// End-to-end smoke test for the rocBLAS bindings.
///
/// Runs one Level-3 routine (SGEMM) and a chain of Level-1 routines
/// (SCAL -> AXPY -> COPY -> DOT, plus NRM2) on device 0 of the architecture named
/// by the `ROCMRC_GFX` environment variable (default `gfx1102`). Each result is
/// then compared against a host reference whose reductions are accumulated in f64.
///
/// # Panics
///
/// Panics if any HIP/rocBLAS call returns an error or if a result exceeds its
/// error tolerance.
fn main() {
    let arch = std::env::var("ROCMRC_GFX").unwrap_or_else(|_| "gfx1102".to_string());
    let ctx = HipContext::new(0, &arch).expect("HipContext");
    let stream = ctx.default_stream();
    println!("device = {}", ctx.name().unwrap_or_else(|_| "<unknown>".into()));
    let handle = RocblasHandle::new(stream.clone()).expect("RocblasHandle");

    // ---------------- Level 3: C = A * B (N x N, column-major) ----------------
    const N: usize = 32;
    let a_host: Vec<f32> = (0..N * N).map(|i| (i % 7) as f32).collect();
    let b_host: Vec<f32> = (0..N * N).map(|i| (i % 5) as f32 - 2.0).collect();
    let d_a: HipSlice<f32> = ctx.alloc(N * N).unwrap();
    let d_b: HipSlice<f32> = ctx.alloc(N * N).unwrap();
    let d_c: HipSlice<f32> = ctx.alloc(N * N).unwrap();
    // SAFETY: d_a/d_b were allocated for N*N f32 elements (assuming `alloc` counts
    // elements — TODO confirm), matching the host slices being copied. a_host/b_host
    // live until the end of main, past the `stream.synchronize()` below, so the async
    // copies never read freed host memory.
    unsafe {
        drv::memcpy_htod_async(
            d_a.device_ptr(),
            bytemuck::cast_slice(&a_host),
            stream.hip_stream(),
        )
        .unwrap();
        drv::memcpy_htod_async(
            d_b.device_ptr(),
            bytemuck::cast_slice(&b_host),
            stream.hip_stream(),
        )
        .unwrap();
    }
    let cfg = GemmConfig::<f32> {
        transa: Operation::None,
        transb: Operation::None,
        m: N as i32,
        n: N as i32,
        k: N as i32,
        alpha: 1.0,
        lda: N as i32,
        ldb: N as i32,
        beta: 0.0,
        ldc: N as i32,
    };
    // SAFETY: every operand is N x N, untransposed, with leading dimension N, so all
    // elements rocBLAS touches lie inside the N*N-element device buffers above.
    unsafe {
        handle
            .gemm(cfg, d_a.device_ptr(), d_b.device_ptr(), d_c.device_ptr())
            .expect("sgemm");
    }
    // Download straight into f32 storage: reinterpreting a Vec<u8> as &[f32] relies
    // on the allocator happening to return 4-byte-aligned memory, which Rust does not
    // guarantee for u8 allocations (bytemuck panics on a misaligned cast).
    let mut c_host = vec![0f32; N * N];
    // SAFETY: c_host has N*N f32 elements, the size of d_c, and is only read after
    // the stream has been synchronized.
    unsafe {
        drv::memcpy_dtoh_async(
            bytemuck::cast_slice_mut(c_host.as_mut_slice()),
            d_c.device_ptr(),
            stream.hip_stream(),
        )
        .unwrap();
    }
    stream.synchronize().expect("sync");
    let c_ref = sgemm_reference(&a_host, &b_host, N);
    let max_err = c_host
        .iter()
        .zip(&c_ref)
        .map(|(&got, &want)| (f64::from(got) - want).abs())
        .fold(0.0f64, f64::max);
    println!("sgemm ({N}x{N}, col-major) max abs err = {max_err:.3e}");
    assert!(max_err < 1e-3, "sgemm precision");

    // ------ Level 1: x *= 2; y += 3x; z = y; dot(x, z); nrm2(x) ------
    const M: usize = 1024;
    let x_host: Vec<f32> = (0..M).map(|i| 0.001 * i as f32).collect();
    let y_host: Vec<f32> = (0..M).map(|i| 0.5 - 0.001 * i as f32).collect();
    let d_x: HipSlice<f32> = ctx.alloc(M).unwrap();
    let d_y: HipSlice<f32> = ctx.alloc(M).unwrap();
    let d_z: HipSlice<f32> = ctx.alloc(M).unwrap();
    // One-element device buffer that receives the DOT and NRM2 results once the
    // handle is switched to device pointer mode.
    let d_scratch: HipSlice<f32> = ctx.alloc(1).unwrap();
    // SAFETY: d_x/d_y hold M f32 elements, matching the host slices; x_host/y_host
    // outlive the later synchronizations.
    unsafe {
        drv::memcpy_htod_async(
            d_x.device_ptr(),
            bytemuck::cast_slice(&x_host),
            stream.hip_stream(),
        )
        .unwrap();
        drv::memcpy_htod_async(
            d_y.device_ptr(),
            bytemuck::cast_slice(&y_host),
            stream.hip_stream(),
        )
        .unwrap();
    }
    // SAFETY: n = M with unit stride stays inside the M-element d_x buffer.
    unsafe {
        handle
            .scal(ScalConfig { n: M as i32, alpha: 2.0f32, incx: 1 }, d_x.device_ptr())
            .expect("scal");
    }
    // SAFETY: n = M with unit strides stays inside the M-element d_x and d_y buffers.
    unsafe {
        handle
            .axpy(
                AxpyConfig { n: M as i32, alpha: 3.0f32, incx: 1, incy: 1 },
                d_x.device_ptr(),
                d_y.device_ptr(),
            )
            .expect("axpy");
    }
    // SAFETY: n = M with unit strides stays inside the M-element d_y and d_z buffers.
    unsafe {
        BlasCopy::<f32>::copy(
            &*handle,
            CopyConfig { n: M as i32, incx: 1, incy: 1 },
            d_y.device_ptr(),
            d_z.device_ptr(),
        )
        .expect("copy");
    }
    // Switched only after SCAL/AXPY, whose alpha is supplied in the config structs;
    // from here on, DOT/NRM2 write their scalar result into d_scratch.
    handle
        .set_pointer_mode(rocblas_pointer_mode::rocblas_pointer_mode_device)
        .expect("ptr mode");
    // SAFETY: x/z hold M elements each and d_scratch holds the single result.
    unsafe {
        Dot::<f32>::dot(
            &*handle,
            DotConfig { n: M as i32, incx: 1, incy: 1 },
            d_x.device_ptr(),
            d_z.device_ptr(),
            d_scratch.device_ptr(),
        )
        .expect("dot");
    }
    let mut dot_buf = [0f32];
    // SAFETY: dot_buf is one f32, the size of d_scratch; it is read after sync.
    unsafe {
        drv::memcpy_dtoh_async(
            bytemuck::cast_slice_mut(&mut dot_buf),
            d_scratch.device_ptr(),
            stream.hip_stream(),
        )
        .unwrap();
    }
    stream.synchronize().unwrap();
    // SAFETY: x holds M elements and d_scratch holds the single result.
    unsafe {
        Nrm2::<f32>::nrm2(
            &*handle,
            Nrm2Config { n: M as i32, incx: 1 },
            d_x.device_ptr(),
            d_scratch.device_ptr(),
        )
        .expect("nrm2");
    }
    let mut nrm_buf = [0f32];
    // SAFETY: nrm_buf is one f32, the size of d_scratch; it is read after sync.
    unsafe {
        drv::memcpy_dtoh_async(
            bytemuck::cast_slice_mut(&mut nrm_buf),
            d_scratch.device_ptr(),
            stream.hip_stream(),
        )
        .unwrap();
    }
    stream.synchronize().unwrap();
    let (dot_ref, nrm_ref) = level1_reference(&x_host, &y_host);
    let dot_err = rel_err(dot_buf[0], dot_ref);
    let nrm_err = rel_err(nrm_buf[0], nrm_ref);
    println!("dot = {} (ref {}), rel err {:.3e}", dot_buf[0], dot_ref, dot_err);
    println!("nrm2(x) = {} (ref {}), rel err {:.3e}", nrm_buf[0], nrm_ref, nrm_err);
    assert!(dot_err < 1e-4 && nrm_err < 1e-4, "L1 precision");
    println!("ok");
}

/// Column-major host reference for `C = A * B` with square `n x n` operands.
///
/// Products and sums are formed in f64, so the oracle's own rounding error is
/// negligible compared with the tolerance applied to the device result.
///
/// # Panics
///
/// Panics if `a` or `b` does not hold exactly `n * n` elements.
fn sgemm_reference(a: &[f32], b: &[f32], n: usize) -> Vec<f64> {
    assert!(a.len() == n * n && b.len() == n * n, "operands must be n x n");
    let mut c = vec![0.0f64; n * n];
    for j in 0..n {
        for i in 0..n {
            // Column-major: A[i,k] = a[k*n + i], B[k,j] = b[j*n + k], C[i,j] = c[j*n + i].
            c[j * n + i] = (0..n)
                .map(|k| f64::from(a[k * n + i]) * f64::from(b[j * n + k]))
                .sum();
        }
    }
    c
}

/// Host reference for the Level-1 chain `x' = 2x`, `y' = 3x' + y`.
///
/// Returns `(dot(x', y'), ||x'||_2)`. The element updates are rounded to f32, to
/// mirror the device buffers, while both reductions accumulate in f64. Extra
/// elements of the longer input are ignored (zip semantics).
fn level1_reference(x: &[f32], y: &[f32]) -> (f64, f64) {
    let (dot, sum_sq) = x
        .iter()
        .zip(y)
        .fold((0.0f64, 0.0f64), |(dot, sum_sq), (&xi, &yi)| {
            let xs = 2.0 * xi;
            let ys = 3.0 * xs + yi;
            let xs64 = f64::from(xs);
            (dot + xs64 * f64::from(ys), sum_sq + xs64 * xs64)
        });
    (dot, sum_sq.sqrt())
}

/// Relative error of a device f32 result against an f64 reference.
///
/// The denominator is clamped at 1e-6 so that a near-zero reference does not
/// blow the ratio up.
fn rel_err(got: f32, want: f64) -> f64 {
    (f64::from(got) - want).abs() / want.abs().max(1e-6)
}