use std::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::tensor::Tensor;
/// Gradient hook: inspects the incoming gradient and may return `Some(tensor)`
/// to replace it for downstream consumers, or `None` to leave it unchanged.
pub(crate) type GradHookFn<T> = Box<dyn Fn(&Tensor<T>) -> Option<Tensor<T>> + Send + Sync>;
/// Post-accumulate hook: observes a tensor after its gradient has been
/// accumulated; purely for side effects (cannot modify the tensor).
pub(crate) type PostAccumulateHookFn<T> = Box<dyn Fn(&Tensor<T>) + Send + Sync>;
/// Process-wide monotonically increasing counter backing [`HookHandle::next`].
static NEXT_HOOK_ID: AtomicU64 = AtomicU64::new(0);

/// Opaque identifier for a registered hook; pass it back to remove the hook.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HookHandle(u64);

impl HookHandle {
    /// Allocates the next globally unique handle.
    ///
    /// `Relaxed` ordering suffices: uniqueness only needs the atomic
    /// increment itself, not any cross-thread happens-before relation.
    fn next() -> Self {
        let id = NEXT_HOOK_ID.fetch_add(1, Ordering::Relaxed);
        HookHandle(id)
    }
}
/// A registered gradient hook paired with the handle used to remove it.
pub(crate) struct GradHook<T: Float> {
    /// Identifier returned at registration time.
    pub handle: HookHandle,
    /// The user-supplied callback (may replace the gradient).
    pub func: GradHookFn<T>,
}
/// A registered post-accumulate hook paired with the handle used to remove it.
pub(crate) struct PostAccumulateGradHook<T: Float> {
    /// Identifier returned at registration time.
    pub handle: HookHandle,
    /// The user-supplied side-effect callback.
    pub func: PostAccumulateHookFn<T>,
}
/// Per-tensor container for registered hooks.
/// Callers in this module wrap it in a `Mutex` for shared access.
pub(crate) struct HookStorage<T: Float> {
    /// Hooks that may observe or replace an incoming gradient, in registration order.
    pub grad_hooks: Vec<GradHook<T>>,
    /// Hooks that fire after gradient accumulation completes, in registration order.
    pub post_accumulate_hooks: Vec<PostAccumulateGradHook<T>>,
}
impl<T: Float> HookStorage<T> {
    /// Creates an empty storage with no registered hooks.
    pub fn new() -> Self {
        Self {
            grad_hooks: Vec::new(),
            post_accumulate_hooks: Vec::new(),
        }
    }

    /// Registers a gradient hook. The hook may return `Some(tensor)` to
    /// replace the gradient seen by later hooks, or `None` to pass it through.
    ///
    /// Returns a [`HookHandle`] that can later be passed to [`Self::remove`].
    pub fn add_grad_hook<F>(&mut self, func: F) -> HookHandle
    where
        F: Fn(&Tensor<T>) -> Option<Tensor<T>> + Send + Sync + 'static,
    {
        let handle = HookHandle::next();
        self.grad_hooks.push(GradHook {
            handle,
            func: Box::new(func),
        });
        handle
    }

    /// Registers a hook that fires (for side effects only) after a gradient
    /// has been accumulated into a tensor.
    ///
    /// Returns a [`HookHandle`] that can later be passed to [`Self::remove`].
    pub fn add_post_accumulate_hook<F>(&mut self, func: F) -> HookHandle
    where
        F: Fn(&Tensor<T>) + Send + Sync + 'static,
    {
        let handle = HookHandle::next();
        self.post_accumulate_hooks.push(PostAccumulateGradHook {
            handle,
            func: Box::new(func),
        });
        handle
    }

    /// Removes the hook identified by `handle`, whichever list it lives in.
    ///
    /// Returns `true` if a hook was removed, `false` if the handle was
    /// unknown (e.g. already removed). Handles are globally unique, so at
    /// most one entry can match; both lists are scanned because the caller
    /// does not record which kind of hook the handle refers to.
    pub fn remove(&mut self, handle: HookHandle) -> bool {
        let before = self.grad_hooks.len() + self.post_accumulate_hooks.len();
        self.grad_hooks.retain(|h| h.handle != handle);
        self.post_accumulate_hooks.retain(|h| h.handle != handle);
        let after = self.grad_hooks.len() + self.post_accumulate_hooks.len();
        after < before
    }

    /// Returns `true` if any gradient hooks are registered.
    pub fn has_grad_hooks(&self) -> bool {
        !self.grad_hooks.is_empty()
    }

    /// Returns `true` if any post-accumulate hooks are registered.
    pub fn has_post_accumulate_hooks(&self) -> bool {
        !self.post_accumulate_hooks.is_empty()
    }
}

// `new` takes no arguments, so provide the conventional `Default` impl
// (clippy::new_without_default); lets the type be used with
// `Default::default()`, `mem::take`, struct-update syntax, etc.
impl<T: Float> Default for HookStorage<T> {
    fn default() -> Self {
        Self::new()
    }
}
/// Runs all registered gradient hooks over `grad`, in registration order.
///
/// Each hook sees the current gradient; if it returns `Some(replacement)`,
/// later hooks (and the final result) see the replacement instead. Returns
/// the final, possibly replaced, gradient.
///
/// # Errors
///
/// Returns [`FerrotorchError::LockPoisoned`] if the hook mutex is poisoned.
pub(crate) fn run_grad_hooks<T: Float>(
    hooks: &Mutex<HookStorage<T>>,
    grad: Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let guard = hooks.lock().map_err(|e| FerrotorchError::LockPoisoned {
        message: format!("hook storage mutex: {e}"),
    })?;
    let mut current = grad;
    for hook in &guard.grad_hooks {
        // Fix: original read `(hook.func)(¤t)` — the leading `&curren`
        // of `&current` had been mangled into the HTML entity `¤` by a
        // bad encoding round-trip, which does not compile.
        if let Some(replacement) = (hook.func)(&current) {
            current = replacement;
        }
    }
    Ok(current)
}
/// Invokes every registered post-accumulate hook with `tensor`, in
/// registration order. Hooks are side-effect-only and cannot alter the tensor.
///
/// # Errors
///
/// Returns [`FerrotorchError::LockPoisoned`] if the hook mutex is poisoned.
pub(crate) fn run_post_accumulate_hooks<T: Float>(
    hooks: &Mutex<HookStorage<T>>,
    tensor: &Tensor<T>,
) -> FerrotorchResult<()> {
    let storage = hooks.lock().map_err(|e| FerrotorchError::LockPoisoned {
        message: format!("hook storage mutex: {e}"),
    })?;
    storage
        .post_accumulate_hooks
        .iter()
        .for_each(|hook| (hook.func)(tensor));
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TensorStorage;

    /// Builds a rank-0 (scalar) f32 tensor holding `val`.
    fn scalar(val: f32) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], false).unwrap()
    }

    #[test]
    fn test_hook_handle_uniqueness() {
        // Consecutive allocations must never collide.
        let first = HookHandle::next();
        let second = HookHandle::next();
        assert_ne!(first, second);
    }

    #[test]
    fn test_hook_storage_add_remove() {
        let hooks: Mutex<HookStorage<f32>> = Mutex::new(HookStorage::new());
        let id = hooks.lock().unwrap().add_grad_hook(|_g| None);
        assert!(hooks.lock().unwrap().has_grad_hooks());
        // Removal of a live handle reports success and empties the storage.
        assert!(hooks.lock().unwrap().remove(id));
        assert!(!hooks.lock().unwrap().has_grad_hooks());
    }

    #[test]
    fn test_run_grad_hooks_passthrough() {
        // With no hooks registered the gradient comes back untouched.
        let hooks: Mutex<HookStorage<f32>> = Mutex::new(HookStorage::new());
        let out = run_grad_hooks(&hooks, scalar(3.0)).unwrap();
        assert!((out.item().unwrap() - 3.0).abs() < 1e-7);
    }

    #[test]
    fn test_run_grad_hooks_replace() {
        let hooks: Mutex<HookStorage<f32>> = Mutex::new(HookStorage::new());
        hooks.lock().unwrap().add_grad_hook(|_g| {
            Some(Tensor::from_storage(TensorStorage::cpu(vec![99.0]), vec![], false).unwrap())
        });
        // The hook's replacement wins over the incoming gradient.
        let out = run_grad_hooks(&hooks, scalar(3.0)).unwrap();
        assert!((out.item().unwrap() - 99.0).abs() < 1e-7);
    }

    #[test]
    fn test_run_grad_hooks_chain() {
        // Hooks run in registration order: (5 * 2) + 1 == 11.
        let hooks: Mutex<HookStorage<f32>> = Mutex::new(HookStorage::new());
        {
            let mut guard = hooks.lock().unwrap();
            guard.add_grad_hook(|g| {
                let doubled = g.item().unwrap() * 2.0;
                Some(Tensor::from_storage(TensorStorage::cpu(vec![doubled]), vec![], false).unwrap())
            });
            guard.add_grad_hook(|g| {
                let bumped = g.item().unwrap() + 1.0;
                Some(Tensor::from_storage(TensorStorage::cpu(vec![bumped]), vec![], false).unwrap())
            });
        }
        let out = run_grad_hooks(&hooks, scalar(5.0)).unwrap();
        assert!((out.item().unwrap() - 11.0).abs() < 1e-7);
    }

    #[test]
    fn test_post_accumulate_hook_fires() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        let fired = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&fired);
        let hooks: Mutex<HookStorage<f32>> = Mutex::new(HookStorage::new());
        hooks
            .lock()
            .unwrap()
            .add_post_accumulate_hook(move |_t| flag.store(true, Ordering::Relaxed));
        let t = scalar(1.0);
        run_post_accumulate_hooks(&hooks, &t).unwrap();
        assert!(fired.load(Ordering::Relaxed));
    }

    #[test]
    fn test_remove_nonexistent_handle() {
        // An unknown handle reports `false` rather than panicking.
        let hooks: Mutex<HookStorage<f32>> = Mutex::new(HookStorage::new());
        assert!(!hooks.lock().unwrap().remove(HookHandle(999_999)));
    }
}