#![allow(clippy::all)]
use blaze_rs::prelude::{blaze, global_context, Result, SimpleContext};
use std::mem::MaybeUninit;
// Global context registered with blaze through the `#[global_context]` attribute.
// NOTE(review): presumably this is the context `FloatTanh::new(None)` falls back to when no
// explicit context is supplied — confirm against the blaze_rs docs.
#[global_context]
static CONTEXT: SimpleContext = SimpleContext::default();
// Declares the public `FloatTanh` program through the `blaze` attribute macro, linked against
// the OpenCL C source in `KERNEL`. Each foreign fn below shares its name with a `__kernel` in
// that source; keep the parameter lists in sync with the kernel signatures there
// (`ulong` <-> `u64`, `__global float*` <-> `*mut f32`, `const __global float*` <-> `*const f32`).
// Plain `//` comments are used so nothing extra is handed to the proc macro.
#[blaze(pub FloatTanh)]
#[link = KERNEL]
extern "C" {
// Kernel `forward`: in place, x_buffer[i] = tanh(x_buffer[i]) for every i < n.
fn forward(n: u64, x_buffer: *mut f32);
// Kernel `backward`: reads x_buffer and writes 1 / cosh(x)^2 (d/dx tanh) into y_buffer.
// NOTE(review): `MaybeUninit` presumably marks y_buffer as write-only output whose prior
// contents are never read — confirm against the blaze_rs docs.
fn backward(n: u64, x_buffer: *const f32, y_buffer: *mut MaybeUninit<f32>);
}
/// OpenCL C source for the `FloatTanh` program (referenced by `#[link = KERNEL]`).
///
/// * `forward(n, x)` replaces each of the first `n` elements of `x` with `tanh(x[i])`.
/// * `backward(n, x, y)` writes `1 / cosh(x[i])^2`, the derivative of `tanh`, into `y[i]` for
///   each of the first `n` elements.
///
/// Both kernels start at `get_global_id(0)` and step by `get_global_size(0)`, so every index
/// below `n` is visited regardless of the global work size chosen at launch.
///
/// The reciprocal uses the single-precision literal `1.0f`. An unsuffixed `1.0` is a `double`
/// in OpenCL C. It would either promote the division to double precision (slow on most GPUs) or,
/// on devices without `cl_khr_fp64`, rely on the compiler demoting the constant.
const KERNEL: &str = r#"
__kernel void forward (ulong n, __global float* x_buffer) {
for (ulong i = get_global_id(0); i < n; i += get_global_size(0)) {
x_buffer[i] = tanh(x_buffer[i]);
}
}
__kernel void backward (ulong n, const __global float* x_buffer, __global float* y_buffer) {
for (ulong i = get_global_id(0); i < n; i += get_global_size(0)) {
const float c = cosh(x_buffer[i]);
y_buffer[i] = 1.0f / (c * c);
}
}
"#;
/// Smoke test: constructing `FloatTanh` with no explicit context must succeed. Any error
/// from `FloatTanh::new` is propagated as the test's `Err`.
///
/// NOTE(review): the name `gemm` looks copy-pasted from another test; this one exercises the
/// tanh program. It is left unchanged so existing test filters keep matching.
#[test]
fn gemm() -> Result<()> {
    // `black_box` hides the freshly built program from the optimizer before it is dropped.
    std::hint::black_box(FloatTanh::new(None)?);
    Ok(())
}