use std::ffi::c_void;
use std::ptr;
use flodl_sys as ffi;
/// Reports whether gradient tracking is currently enabled, according to the native library.
///
/// The underlying FFI call returns a C-style integer flag; any non-zero value maps to `true`.
pub fn is_grad_enabled() -> bool {
    // SAFETY: `flodl_is_grad_enabled` takes no arguments and only returns a flag value.
    let raw_flag = unsafe { ffi::flodl_is_grad_enabled() };
    raw_flag != 0
}
/// RAII handle around a native no-grad guard.
///
/// The handle is obtained from `flodl_no_grad_guard_new` when the guard is constructed and
/// handed back to `flodl_no_grad_guard_delete` when the guard is dropped. The guard is meant
/// to suspend gradient tracking for as long as it is alive. That intent is inferred from the
/// FFI naming; confirm it against `flodl_sys`.
///
/// The guard only has an effect while it is bound to a live variable. Writing
/// `NoGradGuard::new();` on its own drops it immediately, so the type is `#[must_use]`.
///
/// Because the handle is a raw pointer, this type is neither `Send` nor `Sync`. A guard
/// therefore cannot be moved to or shared with another thread. This is presumably desirable
/// if the backend's grad mode is per-thread (NOTE(review): confirm).
#[derive(Debug)]
#[must_use = "the no-grad guard only takes effect while it is alive; bind it to a named variable"]
pub struct NoGradGuard {
    // Opaque native handle; null means "nothing to release".
    guard: *mut c_void,
}
impl Default for NoGradGuard {
fn default() -> Self {
Self::new()
}
}
impl NoGradGuard {
    /// Acquires a native no-grad guard via `flodl_no_grad_guard_new`.
    ///
    /// The returned pointer is stored unchecked. Releasing it is the job of this type's `Drop`.
    pub fn new() -> Self {
        // SAFETY: `flodl_no_grad_guard_new` takes no arguments. Ownership of the returned
        // handle moves into the new value, which is responsible for releasing it.
        Self {
            guard: unsafe { ffi::flodl_no_grad_guard_new() },
        }
    }
}
impl Drop for NoGradGuard {
    /// Hands the native guard handle back to `flodl_no_grad_guard_delete`, if one is held.
    fn drop(&mut self) {
        // Move the handle out and leave null behind, so it can never be released twice.
        let handle = std::mem::replace(&mut self.guard, ptr::null_mut());
        if handle.is_null() {
            return;
        }
        // SAFETY: `handle` is non-null and was produced by `flodl_no_grad_guard_new`. It has
        // just been detached from `self`, so this is its only release.
        unsafe { ffi::flodl_no_grad_guard_delete(handle) };
    }
}
/// Runs `f` while a [`NoGradGuard`] is alive and returns whatever `f` returns.
///
/// The guard is created before `f` is invoked. It is dropped after `f` returns, or during
/// unwinding if `f` panics with `panic = "unwind"`. The guarded region is therefore exactly
/// the call to `f`.
pub fn no_grad<F: FnOnce() -> R, R>(f: F) -> R {
    let _scope = NoGradGuard::new();
    f()
}