use st_zrt::{
CustomOp, CustomOpDomain, KernelContext, KernelInfo, OpIoSpec, SessionOptions, custom_op,
};
/// Stateless element-wise ReLU kernel: each output element is `max(x, 0.0)`
/// of the matching `f32` input element.
struct MyRelu;

impl CustomOp for MyRelu {
    const NAME: &'static str = "MyRelu";
    const DOMAIN: &'static str = "com.example";

    /// Builds the kernel. It carries no state, so the kernel info is ignored.
    fn create(_info: &KernelInfo<'_>) -> st_zrt::Result<Self> {
        Ok(MyRelu)
    }

    /// Writes `max(x, 0.0)` for every element of input 0 into output 0.
    ///
    /// The output is allocated with the input's dims, so both buffers are
    /// presumably the same length; `zip` would stop at the shorter one otherwise.
    /// Note that `f32::max` returns the non-NaN operand, so a NaN input yields 0.0.
    ///
    /// # Errors
    /// Propagates any error from fetching the input, reading its dims or data,
    /// or allocating/filling the output.
    ///
    /// # Panics
    /// Panics if input 0 is absent, which the `required` input spec should rule out.
    fn compute(&mut self, ctx: &KernelContext<'_>) -> st_zrt::Result<()> {
        let src_tensor = ctx.input(0)?.expect("MyRelu: input[0] required");
        let shape = src_tensor.dims()?;
        let src = src_tensor.as_slice::<f32>()?;

        ctx.output_mut::<f32>(0, &shape, |dst| {
            dst.iter_mut()
                .zip(src)
                .for_each(|(slot, &x)| *slot = x.max(0.0));
            Ok(())
        })
    }

    /// One required `Float` input.
    fn inputs() -> &'static [OpIoSpec] {
        static INPUT_SPECS: [OpIoSpec; 1] =
            [OpIoSpec::required(st_zrt::sys::ElementType::Float)];
        &INPUT_SPECS
    }

    /// One required `Float` output.
    fn outputs() -> &'static [OpIoSpec] {
        static OUTPUT_SPECS: [OpIoSpec; 1] =
            [OpIoSpec::required(st_zrt::sys::ElementType::Float)];
        &OUTPUT_SPECS
    }
}

// Emits the `MY_RELU_VTABLE` item that `CustomOpDomain::add_op` consumes.
// NOTE(review): the literal "MyRelu" repeats `MyRelu::NAME`; keep the two in sync
// (the macro presumably needs a literal here — confirm).
custom_op!(MyRelu, "MyRelu", as MY_RELU_VTABLE);
/// Registers `MyRelu` in its own custom-op domain, attaches that domain to a
/// fresh set of session options, and reports the registration on stdout.
///
/// # Errors
/// Returns any error from creating the domain or adding the op to it.
fn main() -> st_zrt::Result<()> {
    let relu_domain = CustomOpDomain::new(MyRelu::DOMAIN)?;
    relu_domain.add_op(&MY_RELU_VTABLE)?;

    // The options borrow `relu_domain`. Because they are declared after it, they
    // are dropped first. They are not used further in this example.
    let _session_opts = SessionOptions::default().with_custom_op_domain(&relu_domain);

    println!(
        "custom op '{}' registered on domain '{}'",
        MyRelu::NAME,
        MyRelu::DOMAIN
    );
    Ok(())
}