/// Identifies the subsystem an init hook belongs to.
///
/// `try_init` uses this to decide which registered hooks to run. All payloads
/// are `&'static str`, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InitSubsystem {
    /// Cryptography subsystem.
    Crypto,
    /// Recurrent-network subsystem.
    Rnn,
    /// Key/value cache subsystem.
    KvCache,
    /// Any other subsystem, identified by name; two `Other` values are equal
    /// when their names are equal.
    Other(&'static str),
}
/// Error produced while running init hooks.
#[derive(Debug)]
pub enum InitError {
    /// A registered hook returned an error: the hook's name and the rendered
    /// (`Display`) message of the error it returned.
    HookFailed(&'static str, String),
    /// Every hook failure collected during a single init run.
    Aggregate(Vec<InitError>),
}

impl ::std::fmt::Display for InitError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            InitError::HookFailed(name, msg) => write!(f, "hook '{}' failed: {}", name, msg),
            InitError::Aggregate(errors) => {
                // Keep the original count prefix, then list each cause. Without
                // the causes, an aggregate flattened via `to_string()` is opaque.
                write!(f, "aggregate init errors: {}", errors.len())?;
                for (i, err) in errors.iter().enumerate() {
                    let sep = if i == 0 { " [" } else { "; " };
                    write!(f, "{}{}", sep, err)?;
                }
                if !errors.is_empty() {
                    f.write_str("]")?;
                }
                Ok(())
            }
        }
    }
}

impl ::std::error::Error for InitError {}
/// Error returned when a hook cannot be registered.
#[derive(Debug)]
pub enum RegisterError {
    /// The registry is already marked initialized, so registration is refused.
    AlreadyInitialized,
}

impl ::std::fmt::Display for RegisterError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let text = match self {
            RegisterError::AlreadyInitialized => "registry already initialized",
        };
        f.write_str(text)
    }
}

impl ::std::error::Error for RegisterError {}
/// Signature every init hook must have.
type InitHook = fn() -> Result<(), InitError>;

/// A single registered hook.
struct HookEntry {
    /// Name reported in `InitError::HookFailed` when the hook fails.
    name: &'static str,
    /// Subsystem used by `try_init` to select the hook.
    subsystem: InitSubsystem,
    /// The function to run.
    hook: InitHook,
}

/// Mutable state guarded by `INIT_REGISTRY`'s mutex.
#[derive(Default)]
struct Registry {
    /// Set after an init run that produced no errors; cleared by `shutdown`.
    initialized: bool,
    /// Registered hooks, kept in registration order.
    hooks: Vec<HookEntry>,
}

/// Process-wide registry, created lazily on first access.
static INIT_REGISTRY: ::std::sync::OnceLock<::std::sync::Mutex<Registry>> =
    ::std::sync::OnceLock::new();

/// Returns the global registry, creating an empty (uninitialized, no hooks)
/// one on first call.
fn registry() -> &'static ::std::sync::Mutex<Registry> {
    INIT_REGISTRY.get_or_init(::std::sync::Mutex::default)
}
/// Registers `hook` under `name` for `subsystem`.
///
/// Hooks are stored in registration order and run in that order by `init`.
/// `try_init` also runs them, but only when their subsystem is selected.
///
/// # Errors
/// Returns [`RegisterError::AlreadyInitialized`] if the registry is already
/// marked initialized; the hook is not stored in that case.
///
/// NOTE(review): `init`/`try_init` hold the registry lock while running hooks,
/// so calling this from inside a hook will deadlock or panic.
pub fn register_init_hook(
    name: &'static str,
    subsystem: InitSubsystem,
    hook: InitHook,
) -> Result<(), RegisterError> {
    // A hook that panics during init poisons the mutex. The registry is not
    // mid-update at that point, so recover the guard instead of turning every
    // later registration into a panic.
    let mut r = registry()
        .lock()
        .unwrap_or_else(::std::sync::PoisonError::into_inner);
    if r.initialized {
        return Err(RegisterError::AlreadyInitialized);
    }
    r.hooks.push(HookEntry {
        name,
        subsystem,
        hook,
    });
    Ok(())
}
pub fn init() -> Result<(), InitError> {
match crate::std::crypto_std::register_module_init_hooks() {
Ok(()) => {}
Err(RegisterError::AlreadyInitialized) => {}
}
match crate::std::rnn_std::register_module_init_hooks() {
Ok(()) => {}
Err(RegisterError::AlreadyInitialized) => {}
}
match crate::std::kv_cache_std::register_module_init_hooks() {
Ok(()) => {}
Err(RegisterError::AlreadyInitialized) => {}
}
let mut r = registry().lock().unwrap();
if r.initialized {
return Ok(());
}
let mut errors: Vec<InitError> = Vec::new();
for h in &r.hooks {
match (h.hook)() {
Ok(()) => {}
Err(e) => errors.push(InitError::HookFailed(h.name, e.to_string())),
}
}
if !errors.is_empty() {
return Err(InitError::Aggregate(errors));
}
r.initialized = true;
Ok(())
}
pub fn try_init(subsystems: &[InitSubsystem]) -> Result<(), InitError> {
match crate::std::crypto_std::register_module_init_hooks() {
Ok(()) => {}
Err(RegisterError::AlreadyInitialized) => {}
}
match crate::std::rnn_std::register_module_init_hooks() {
Ok(()) => {}
Err(RegisterError::AlreadyInitialized) => {}
}
match crate::std::kv_cache_std::register_module_init_hooks() {
Ok(()) => {}
Err(RegisterError::AlreadyInitialized) => {}
}
let mut r = registry().lock().unwrap();
if r.initialized {
return Ok(());
}
let mut errors: Vec<InitError> = Vec::new();
for h in &r.hooks {
let mut run = false;
for s in subsystems {
match (s, &h.subsystem) {
(InitSubsystem::Other(a), InitSubsystem::Other(b)) => {
if a == b {
run = true;
break;
}
}
(InitSubsystem::Crypto, InitSubsystem::Crypto)
| (InitSubsystem::Rnn, InitSubsystem::Rnn)
| (InitSubsystem::KvCache, InitSubsystem::KvCache) => {
run = true;
break;
}
_ => {}
}
}
if run {
match (h.hook)() {
Ok(()) => {}
Err(e) => errors.push(InitError::HookFailed(h.name, e.to_string())),
}
}
}
if !errors.is_empty() {
return Err(InitError::Aggregate(errors));
}
r.initialized = true;
Ok(())
}
/// Reports whether the registry is currently marked initialized, i.e. an
/// `init`/`try_init` run finished without errors and `shutdown` has not been
/// called since.
pub fn is_initialized() -> bool {
    // Recover from poisoning (e.g. a hook that panicked during init) so a
    // read-only query never panics.
    registry()
        .lock()
        .unwrap_or_else(::std::sync::PoisonError::into_inner)
        .initialized
}
/// Resets the registry: clears the initialized flag and drops every registered
/// hook.
///
/// A later `init()` starts from an empty registry. It calls the std module
/// registration functions again, presumably re-adding their hooks (confirm).
pub fn shutdown() {
    // Recover from poisoning so that shutdown can still reset a registry whose
    // earlier init run panicked.
    let mut r = registry()
        .lock()
        .unwrap_or_else(::std::sync::PoisonError::into_inner);
    r.initialized = false;
    r.hooks.clear();
}