use std::cell::RefCell;
use std::sync::Arc;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::tensor::Tensor;
/// Shared, thread-safe callback that receives a tensor and returns the
/// (possibly transformed) tensor to store in its place, or an error.
///
/// This is the first element of the hook pair installed by
/// `saved_tensors_hooks` and is the callback invoked by `pack_saved_tensor`.
pub type PackHook<T> = Arc<dyn Fn(Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync>;
/// Shared, thread-safe callback that receives a previously stored tensor and
/// returns the tensor to hand back to the caller, or an error.
///
/// This is the second element of the hook pair installed by
/// `saved_tensors_hooks` and is the callback invoked by `unpack_saved_tensor`.
pub type UnpackHook<T> = Arc<dyn Fn(Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync>;
// Per-thread slot holding the currently installed `(pack, unpack)` hook pair
// for `f32` tensors. `None` means no hooks are installed on this thread.
// Written by `saved_tensors_hooks`; read by `pack_saved_tensor`,
// `unpack_saved_tensor` and `has_saved_tensor_hooks`.
thread_local! {
static HOOKS_F32: RefCell<Option<(PackHook<f32>, UnpackHook<f32>)>> =
const { RefCell::new(None) };
}
// Per-thread slot holding the currently installed `(pack, unpack)` hook pair
// for `f64` tensors. `None` means no hooks are installed on this thread.
// Written by `saved_tensors_hooks`; read by `pack_saved_tensor`,
// `unpack_saved_tensor` and `has_saved_tensor_hooks`.
thread_local! {
static HOOKS_F64: RefCell<Option<(PackHook<f64>, UnpackHook<f64>)>> =
const { RefCell::new(None) };
}
/// Runs `f` with `pack` / `unpack` installed as the current thread's saved
/// tensor hooks for element type `T`, then restores whatever was installed
/// before.
///
/// The hooks live in thread-local storage, so they only affect calls to
/// [`pack_saved_tensor`] / [`unpack_saved_tensor`] made on the calling thread
/// while `f` is executing. Scopes may be nested: the inner pair replaces the
/// outer one for the duration of the inner call and the outer pair is put back
/// afterwards.
///
/// The previous hooks are restored by a drop guard, so they are put back even
/// if `f` unwinds with a panic; a caught panic therefore cannot leave the
/// scoped hooks installed on the thread.
///
/// Only `f32` and `f64` have hook slots. For any other `T` the hooks are
/// discarded and `f` runs with the thread's hook state untouched.
///
/// # Errors
///
/// Returns whatever `f` returns; this function adds no errors of its own.
pub fn saved_tensors_hooks<T, F, R>(
    pack: impl Fn(Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync + 'static,
    unpack: impl Fn(Tensor<T>) -> FerrotorchResult<Tensor<T>> + Send + Sync + 'static,
    f: F,
) -> FerrotorchResult<R>
where
    T: Float,
    F: FnOnce() -> FerrotorchResult<R>,
{
    // Runs the wrapped closure exactly once when dropped. Used to put the
    // previous hooks back on both the normal-return and the unwinding path.
    struct RestoreOnDrop<G: FnOnce()>(Option<G>);

    impl<G: FnOnce()> Drop for RestoreOnDrop<G> {
        fn drop(&mut self) {
            if let Some(restore) = self.0.take() {
                restore();
            }
        }
    }

    let pack = Arc::new(pack) as PackHook<T>;
    let unpack = Arc::new(unpack) as UnpackHook<T>;
    let ty = std::any::TypeId::of::<T>();

    if ty == std::any::TypeId::of::<f32>() {
        // SAFETY: `T` is exactly `f32` (checked via `TypeId` above), so
        // `PackHook<T>` and `PackHook<f32>` are the same type and the
        // transmute is an identity conversion.
        let pack_f32: PackHook<f32> = unsafe { std::mem::transmute(pack) };
        // SAFETY: same argument as above, for `UnpackHook`.
        let unpack_f32: UnpackHook<f32> = unsafe { std::mem::transmute(unpack) };

        let prev = HOOKS_F32.with(|slot| slot.borrow_mut().replace((pack_f32, unpack_f32)));
        let _restore = RestoreOnDrop(Some(move || {
            HOOKS_F32.with(|slot| *slot.borrow_mut() = prev);
        }));
        f()
    } else if ty == std::any::TypeId::of::<f64>() {
        // SAFETY: `T` is exactly `f64` (checked via `TypeId` above), so
        // `PackHook<T>` and `PackHook<f64>` are the same type and the
        // transmute is an identity conversion.
        let pack_f64: PackHook<f64> = unsafe { std::mem::transmute(pack) };
        // SAFETY: same argument as above, for `UnpackHook`.
        let unpack_f64: UnpackHook<f64> = unsafe { std::mem::transmute(unpack) };

        let prev = HOOKS_F64.with(|slot| slot.borrow_mut().replace((pack_f64, unpack_f64)));
        let _restore = RestoreOnDrop(Some(move || {
            HOOKS_F64.with(|slot| *slot.borrow_mut() = prev);
        }));
        f()
    } else {
        // No hook slot exists for this element type: run `f` unhooked.
        f()
    }
}
pub fn pack_saved_tensor<T: Float>(tensor: Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
HOOKS_F32.with(|h| {
let guard = h.borrow();
if let Some((ref pack, _)) = *guard {
let t_f32: Tensor<f32> = unsafe { std::mem::transmute(tensor) };
let result = pack(t_f32)?;
Ok(unsafe { std::mem::transmute(result) })
} else {
Ok(tensor)
}
})
} else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
HOOKS_F64.with(|h| {
let guard = h.borrow();
if let Some((ref pack, _)) = *guard {
let t_f64: Tensor<f64> = unsafe { std::mem::transmute(tensor) };
let result = pack(t_f64)?;
Ok(unsafe { std::mem::transmute(result) })
} else {
Ok(tensor)
}
})
} else {
Ok(tensor)
}
}
pub fn unpack_saved_tensor<T: Float>(tensor: Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
HOOKS_F32.with(|h| {
let guard = h.borrow();
if let Some((_, ref unpack)) = *guard {
let t_f32: Tensor<f32> = unsafe { std::mem::transmute(tensor) };
let result = unpack(t_f32)?;
Ok(unsafe { std::mem::transmute(result) })
} else {
Ok(tensor)
}
})
} else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
HOOKS_F64.with(|h| {
let guard = h.borrow();
if let Some((_, ref unpack)) = *guard {
let t_f64: Tensor<f64> = unsafe { std::mem::transmute(tensor) };
let result = unpack(t_f64)?;
Ok(unsafe { std::mem::transmute(result) })
} else {
Ok(tensor)
}
})
} else {
Ok(tensor)
}
}
/// Reports whether the current thread has a hook pair installed for either
/// `f32` or `f64` tensors.
pub fn has_saved_tensor_hooks() -> bool {
    let f32_installed = HOOKS_F32.with(|slot| slot.borrow().is_some());
    // The `f64` slot is only consulted when the `f32` slot is empty.
    f32_installed || HOOKS_F64.with(|slot| slot.borrow().is_some())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TensorStorage;

    /// Builds the three-element `[1.0, 2.0, 3.0]` fixture used by the tests.
    fn sample() -> FerrotorchResult<Tensor<f32>> {
        Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]), vec![3], false)
    }

    /// Returns a new tensor with the same shape as `src` and `op` applied to
    /// every element.
    fn map_elements(
        src: &Tensor<f32>,
        op: impl Fn(f32) -> f32,
    ) -> FerrotorchResult<Tensor<f32>> {
        let mapped: Vec<f32> = src.data_vec()?.iter().map(|&v| op(v)).collect();
        Tensor::from_storage(TensorStorage::cpu(mapped), src.shape().to_vec(), false)
    }

    /// With no hooks installed, pack and unpack both leave the data untouched.
    #[test]
    fn test_pack_unpack_identity() {
        let original = sample().unwrap();

        let stored = pack_saved_tensor(original.clone()).unwrap();
        assert_eq!(stored.data_vec().unwrap(), vec![1.0, 2.0, 3.0]);

        let restored = unpack_saved_tensor(stored).unwrap();
        assert_eq!(restored.data_vec().unwrap(), vec![1.0, 2.0, 3.0]);
    }

    /// Installed hooks are applied: pack doubles, unpack halves.
    #[test]
    fn test_saved_tensors_hooks_transform() {
        let double = |t: Tensor<f32>| map_elements(&t, |v| v * 2.0);
        let halve = |t: Tensor<f32>| map_elements(&t, |v| v / 2.0);

        let outcome = saved_tensors_hooks(double, halve, || {
            let stored = pack_saved_tensor(sample()?)?;
            assert_eq!(stored.data_vec()?, vec![2.0, 4.0, 6.0]);

            let restored = unpack_saved_tensor(stored)?;
            assert_eq!(restored.data_vec()?, vec![1.0, 2.0, 3.0]);
            Ok(())
        });
        outcome.unwrap();
    }

    /// Hooks are visible inside the scope and gone once it returns.
    #[test]
    fn test_hooks_cleared_after_scope() {
        let passthrough_pack = |t: Tensor<f32>| Ok(t);
        let passthrough_unpack = |t: Tensor<f32>| Ok(t);

        saved_tensors_hooks(passthrough_pack, passthrough_unpack, || {
            assert!(has_saved_tensor_hooks());
            Ok(())
        })
        .unwrap();

        assert!(!has_saved_tensor_hooks());
    }
}