use super::events::GraphEvent;
use hashbrown::HashMap;
use parking_lot::RwLock;
use polaris_system::api::API;
use polaris_system::param::SystemContext;
use polaris_system::plugin::{IntoScheduleIds, ScheduleId};
use polaris_system::resource::LocalResource;
use std::any::TypeId;
use std::fmt;
use std::sync::Arc;
pub struct BoxedHook {
pub(crate) handler: Box<dyn Fn(&mut SystemContext<'_>, &GraphEvent) + Send + Sync>,
pub(crate) provided_resources: Vec<TypeId>,
}
impl BoxedHook {
#[must_use = "BoxedHook must be registered via HooksAPI::register_boxed to take effect"]
pub fn new(
handler: impl Fn(&mut SystemContext<'_>, &GraphEvent) + Send + Sync + 'static,
provided_resources: Vec<TypeId>,
) -> Self {
Self {
handler: Box::new(handler),
provided_resources,
}
}
pub fn invoke(&self, ctx: &mut SystemContext<'_>, event: &GraphEvent) {
(self.handler)(ctx, event);
}
#[must_use]
pub fn provided_resources(&self) -> &[TypeId] {
&self.provided_resources
}
}
/// Reasons a hook registration can be rejected by [`HooksAPI`].
#[derive(Debug, Clone)]
pub enum HookRegistrationError {
    /// The target schedule already has a hook registered under `name`.
    DuplicateName { schedule: ScheduleId, name: String },
}

impl fmt::Display for HookRegistrationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DuplicateName { schedule, name } => {
                let schedule_name = schedule.type_name();
                write!(
                    f,
                    "hook '{name}' already registered for schedule '{schedule_name}'"
                )
            }
        }
    }
}

impl std::error::Error for HookRegistrationError {}
/// A registered hook paired with the name it was registered under.
struct HookEntry {
    /// Name of the hook; registration rejects a second entry with the same
    /// name on the same schedule.
    name: String,
    /// The hook run for this entry.
    hook: BoxedHook,
}

/// Registry of named hooks, grouped by schedule.
///
/// Cloning is cheap: every clone shares the same `Arc`-backed storage, so a
/// hook registered through one handle is visible through all of them.
#[derive(Clone, Default)]
pub struct HooksAPI {
    /// Per-schedule hook lists, appended to in registration order.
    hooks: Arc<RwLock<HashMap<ScheduleId, Vec<HookEntry>>>>,
}

impl API for HooksAPI {}
impl HooksAPI {
    /// Creates an empty hook registry.
    ///
    /// Clones of the returned value share the same underlying registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            hooks: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers an observer that is called with every [`GraphEvent`] fired on
    /// the schedule(s) named by `S`.
    ///
    /// When `S` names more than one schedule, each registration is stored as
    /// `"{name}@{schedule type name}"`; with a single schedule `name` is used
    /// verbatim. Registration is all-or-nothing: if any target name is already
    /// taken, no schedule is modified.
    ///
    /// Returns `&self` on success so calls can be chained.
    ///
    /// # Errors
    ///
    /// Returns [`HookRegistrationError::DuplicateName`] if one of the target
    /// schedules already has a hook with the resulting name.
    pub fn register_observer<S, F>(
        &self,
        name: impl Into<String>,
        hook: F,
    ) -> Result<&Self, HookRegistrationError>
    where
        S: IntoScheduleIds,
        F: Fn(&GraphEvent) + Send + Sync + 'static,
    {
        let schedules = S::schedule_ids();
        let name = name.into();
        // One shared closure backs every per-schedule registration.
        let shared = Arc::new(hook);
        self.register_for_schedules(&schedules, &name, || {
            let observer = Arc::clone(&shared);
            BoxedHook::new(
                move |_ctx, event: &GraphEvent| {
                    observer(event);
                },
                Vec::new(),
            )
        })?;
        Ok(self)
    }

    /// Registers a provider on the schedule(s) named by `S`: whenever it runs
    /// and returns `Some(resource)`, that resource is inserted into the
    /// [`SystemContext`]. Returning `None` inserts nothing.
    ///
    /// Each registration declares `T` as a provided resource (see
    /// [`HooksAPI::provided_resources_for`]). Naming and all-or-nothing
    /// semantics are the same as for [`HooksAPI::register_observer`].
    ///
    /// Returns `&self` on success so calls can be chained.
    ///
    /// # Errors
    ///
    /// Returns [`HookRegistrationError::DuplicateName`] if one of the target
    /// schedules already has a hook with the resulting name.
    pub fn register_provider<S, T, F>(
        &self,
        name: impl Into<String>,
        hook: F,
    ) -> Result<&Self, HookRegistrationError>
    where
        S: IntoScheduleIds,
        T: LocalResource,
        F: Fn(&GraphEvent) -> Option<T> + Send + Sync + 'static,
    {
        let schedules = S::schedule_ids();
        let name = name.into();
        // One shared closure backs every per-schedule registration.
        let shared = Arc::new(hook);
        self.register_for_schedules(&schedules, &name, || {
            let provider = Arc::clone(&shared);
            BoxedHook::new(
                move |ctx, event: &GraphEvent| {
                    if let Some(resource) = provider(event) {
                        ctx.insert(resource);
                    }
                },
                vec![TypeId::of::<T>()],
            )
        })?;
        Ok(self)
    }

    /// Registers one hook per entry of `schedules` under a single write lock.
    ///
    /// Names are qualified as `"{name}@{schedule type name}"` when more than
    /// one schedule is given. Every target is validated before anything is
    /// inserted, so a conflict leaves the registry untouched. `make_hook` is
    /// only called once validation has succeeded, once per schedule.
    fn register_for_schedules(
        &self,
        schedules: &[ScheduleId],
        name: &str,
        mut make_hook: impl FnMut() -> BoxedHook,
    ) -> Result<(), HookRegistrationError> {
        let qualify = schedules.len() > 1;
        let planned: Vec<(ScheduleId, String)> = schedules
            .iter()
            .map(|&schedule| {
                let hook_name = if qualify {
                    format!("{}@{}", name, schedule.type_name())
                } else {
                    name.to_owned()
                };
                (schedule, hook_name)
            })
            .collect();

        let mut hooks = self.hooks.write();

        for (index, (schedule, hook_name)) in planned.iter().enumerate() {
            let already_registered = hooks
                .get(schedule)
                .is_some_and(|entries| entries.iter().any(|entry| entry.name == *hook_name));
            // `schedules` could list the same schedule twice; treat that as a
            // duplicate too instead of inserting one copy and then failing.
            let repeated_in_batch = planned[..index]
                .iter()
                .any(|(earlier, earlier_name)| earlier == schedule && earlier_name == hook_name);
            if already_registered || repeated_in_batch {
                return Err(HookRegistrationError::DuplicateName {
                    schedule: *schedule,
                    name: hook_name.clone(),
                });
            }
        }

        for (schedule, hook_name) in planned {
            hooks.entry(schedule).or_default().push(HookEntry {
                name: hook_name,
                hook: make_hook(),
            });
        }
        Ok(())
    }

    /// Registers a pre-built [`BoxedHook`] on a single `schedule` under `name`.
    ///
    /// `name` is used verbatim (no schedule suffix). Hooks on a schedule run in
    /// the order they were registered.
    ///
    /// # Errors
    ///
    /// Returns [`HookRegistrationError::DuplicateName`] if `schedule` already
    /// has a hook named `name`; the registry is left unchanged in that case.
    pub fn register_boxed(
        &self,
        schedule: ScheduleId,
        name: impl Into<String>,
        hook: BoxedHook,
    ) -> Result<(), HookRegistrationError> {
        let name = name.into();
        let mut hooks = self.hooks.write();
        let entries = hooks.entry(schedule).or_default();
        if entries.iter().any(|entry| entry.name == name) {
            return Err(HookRegistrationError::DuplicateName { schedule, name });
        }
        entries.push(HookEntry { name, hook });
        Ok(())
    }

    /// Runs every hook registered for `schedule`, in registration order.
    /// Does nothing if `schedule` has no hooks.
    ///
    /// The registry's read lock is held while the hooks run. A hook must
    /// therefore not register hooks on this `HooksAPI` (or any clone sharing its
    /// registry) from inside its handler: that would request the write lock on
    /// the same thread and deadlock.
    pub fn invoke(&self, schedule: ScheduleId, ctx: &mut SystemContext<'_>, event: &GraphEvent) {
        let hooks = self.hooks.read();
        if let Some(entries) = hooks.get(&schedule) {
            for entry in entries {
                entry.hook.invoke(ctx, event);
            }
        }
    }

    /// Returns how many hooks are registered for `schedule` (0 if none).
    #[must_use]
    pub fn hook_count(&self, schedule: ScheduleId) -> usize {
        let hooks = self.hooks.read();
        hooks.get(&schedule).map_or(0, Vec::len)
    }

    /// Returns the resource type IDs declared by all hooks on `schedule`, in
    /// registration order.
    ///
    /// No deduplication is done: if several hooks provide the same type, it
    /// appears once per hook.
    #[must_use]
    pub fn provided_resources_for(&self, schedule: ScheduleId) -> Vec<TypeId> {
        let hooks = self.hooks.read();
        hooks
            .get(&schedule)
            .map(|entries| {
                entries
                    .iter()
                    .flat_map(|entry| entry.hook.provided_resources().iter().copied())
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Returns `true` if `schedule` has a hook registered under exactly `name`.
    ///
    /// For multi-schedule registrations, `name` must be the qualified
    /// `"{name}@{schedule type name}"` form.
    #[must_use]
    pub fn contains_hook(&self, schedule: ScheduleId, name: &str) -> bool {
        let hooks = self.hooks.read();
        hooks
            .get(&schedule)
            .is_some_and(|entries| entries.iter().any(|entry| entry.name == name))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::{RunId, RunLabels};
    use crate::hooks::schedule::{OnSystemComplete, OnSystemStart};
    use crate::node::NodeId;
    use polaris_system::plugin::Schedule;
    use polaris_system::resource::LocalResource;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Builds a `SystemStart` event with fresh IDs for the given node name.
    fn sample_system_start(node_name: &'static str) -> GraphEvent {
        GraphEvent::SystemStart {
            run_id: RunId::new(),
            labels: RunLabels::empty(),
            node_id: NodeId::new(),
            node_name,
        }
    }

    #[test]
    fn hooks_api_register_increments_count() {
        let api = HooksAPI::new();
        let schedule = OnSystemStart::schedule_id();
        api.register_observer::<OnSystemStart, _>("test_hook", |_: &GraphEvent| {})
            .expect("registration should succeed");
        assert_eq!(api.hook_count(schedule), 1);
        api.register_observer::<OnSystemStart, _>("another_hook", |_: &GraphEvent| {})
            .expect("registration should succeed");
        assert_eq!(api.hook_count(schedule), 2);
    }

    #[test]
    fn hooks_api_invoke_calls_hooks() {
        let api = HooksAPI::new();
        let schedule = OnSystemStart::schedule_id();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        api.register_observer::<OnSystemStart, _>("counting_hook", move |_: &GraphEvent| {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        })
        .expect("registration should succeed");
        let mut ctx = SystemContext::new();
        let event = sample_system_start("test");
        api.invoke(schedule, &mut ctx, &event);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        api.invoke(schedule, &mut ctx, &event);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn hooks_api_invoke_calls_all_hooks_in_order() {
        let api = HooksAPI::new();
        let schedule = OnSystemStart::schedule_id();
        let execution_order = Arc::new(Mutex::new(Vec::new()));
        for name in ["first", "second", "third"] {
            let order_clone = Arc::clone(&execution_order);
            let name_owned = name.to_owned();
            api.register_observer::<OnSystemStart, _>(name, move |_: &GraphEvent| {
                order_clone.lock().unwrap().push(name_owned.clone());
            })
            .expect("registration should succeed");
        }
        let mut ctx = SystemContext::new();
        let event = sample_system_start("test");
        api.invoke(schedule, &mut ctx, &event);
        let order = execution_order.lock().unwrap();
        assert_eq!(
            *order,
            vec!["first", "second", "third"],
            "hooks should execute in registration order"
        );
    }

    #[test]
    fn hooks_api_invoke_unknown_schedule_is_noop() {
        let api = HooksAPI::new();
        let mut ctx = SystemContext::new();
        let event = sample_system_start("test");
        // Must not panic when nothing has been registered for the schedule.
        api.invoke(OnSystemStart::schedule_id(), &mut ctx, &event);
        assert_eq!(api.hook_count(OnSystemStart::schedule_id()), 0);
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct TestResource {
        value: i32,
    }

    impl LocalResource for TestResource {}

    #[test]
    fn register_provider_inserts_resource() {
        let api = HooksAPI::new();
        api.register_provider::<OnSystemStart, TestResource, _>(
            "provider",
            |_event: &GraphEvent| Some(TestResource { value: 42 }),
        )
        .expect("registration should succeed");
        let mut ctx = SystemContext::new();
        let event = sample_system_start("test");
        assert!(!ctx.contains_resource::<TestResource>());
        api.invoke(OnSystemStart::schedule_id(), &mut ctx, &event);
        let resource = ctx
            .get_resource::<TestResource>()
            .expect("resource should be inserted");
        assert_eq!(resource.value, 42);
    }

    #[test]
    fn register_provider_returning_none_inserts_nothing() {
        let api = HooksAPI::new();
        api.register_provider::<OnSystemStart, TestResource, _>("provider", |_: &GraphEvent| None)
            .expect("registration should succeed");
        let mut ctx = SystemContext::new();
        let event = sample_system_start("test");
        api.invoke(OnSystemStart::schedule_id(), &mut ctx, &event);
        assert!(
            !ctx.contains_resource::<TestResource>(),
            "a provider returning None must not insert a resource"
        );
    }

    #[test]
    fn provided_resources_for_returns_provider_types() {
        let api = HooksAPI::new();
        let schedule = OnSystemStart::schedule_id();
        assert!(api.provided_resources_for(schedule).is_empty());
        api.register_observer::<OnSystemStart, _>("observer", |_: &GraphEvent| {})
            .unwrap();
        assert!(
            api.provided_resources_for(schedule).is_empty(),
            "observers provide no resources"
        );
        api.register_provider::<OnSystemStart, TestResource, _>("provider", |_: &GraphEvent| {
            Some(TestResource { value: 0 })
        })
        .unwrap();
        let provided = api.provided_resources_for(schedule);
        assert_eq!(provided.len(), 1);
        assert_eq!(provided[0], TypeId::of::<TestResource>());
    }

    #[test]
    fn register_boxed_rejects_duplicate_names() {
        let api = HooksAPI::new();
        let schedule = OnSystemStart::schedule_id();
        api.register_boxed(
            schedule,
            "my_hook",
            BoxedHook::new(move |_ctx, _event| {}, Vec::new()),
        )
        .expect("first registration should succeed");
        let result = api.register_boxed(
            schedule,
            "my_hook",
            BoxedHook::new(move |_ctx, _event| {}, Vec::new()),
        );
        assert!(result.is_err());
        if let Err(HookRegistrationError::DuplicateName { name, .. }) = result {
            assert_eq!(name, "my_hook");
        } else {
            panic!("expected DuplicateName error");
        }
    }

    #[test]
    fn register_provider_rejects_duplicate_names() {
        let api = HooksAPI::new();
        api.register_provider::<OnSystemStart, TestResource, _>("provider", |_: &GraphEvent| None)
            .expect("first registration should succeed");
        let result = api
            .register_provider::<OnSystemStart, TestResource, _>("provider", |_: &GraphEvent| None);
        assert!(matches!(
            result,
            Err(HookRegistrationError::DuplicateName { .. })
        ));
        assert_eq!(api.hook_count(OnSystemStart::schedule_id()), 1);
    }

    #[test]
    fn same_name_different_schedules_allowed() {
        let api = HooksAPI::new();
        api.register_observer::<OnSystemStart, _>("logger", |_: &GraphEvent| {})
            .expect("first registration should succeed");
        api.register_observer::<OnSystemComplete, _>("logger", |_: &GraphEvent| {})
            .expect("same name on different schedule should succeed");
        assert_eq!(api.hook_count(OnSystemStart::schedule_id()), 1);
        assert_eq!(api.hook_count(OnSystemComplete::schedule_id()), 1);
    }

    #[test]
    fn register_observer_chaining() {
        let api = HooksAPI::new();
        api.register_observer::<OnSystemStart, _>("first", |_: &GraphEvent| {})
            .unwrap()
            .register_observer::<OnSystemStart, _>("second", |_: &GraphEvent| {})
            .unwrap();
        assert_eq!(api.hook_count(OnSystemStart::schedule_id()), 2);
    }

    #[test]
    fn contains_hook() {
        let api = HooksAPI::new();
        let schedule = OnSystemStart::schedule_id();
        assert!(!api.contains_hook(schedule, "my_hook"));
        api.register_observer::<OnSystemStart, _>("my_hook", |_: &GraphEvent| {})
            .unwrap();
        assert!(api.contains_hook(schedule, "my_hook"));
        assert!(!api.contains_hook(schedule, "other_hook"));
    }

    #[test]
    fn multiple_providers_last_write_wins() {
        let api = HooksAPI::new();
        let schedule = OnSystemStart::schedule_id();
        api.register_provider::<OnSystemStart, TestResource, _>(
            "first_provider",
            |_: &GraphEvent| Some(TestResource { value: 1 }),
        )
        .unwrap();
        api.register_provider::<OnSystemStart, TestResource, _>(
            "second_provider",
            |_: &GraphEvent| Some(TestResource { value: 2 }),
        )
        .unwrap();
        api.register_provider::<OnSystemStart, TestResource, _>(
            "third_provider",
            |_: &GraphEvent| Some(TestResource { value: 3 }),
        )
        .unwrap();
        let mut ctx = SystemContext::new();
        let event = sample_system_start("test");
        api.invoke(schedule, &mut ctx, &event);
        let resource = ctx
            .get_resource::<TestResource>()
            .expect("resource should exist");
        assert_eq!(resource.value, 3, "last provider's value should win");
    }

    #[test]
    fn register_observer_multiple_schedules() {
        let api = HooksAPI::new();
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);
        api.register_observer::<(OnSystemStart, OnSystemComplete), _>(
            "tracker",
            move |event: &GraphEvent| {
                events_clone
                    .lock()
                    .unwrap()
                    .push(event.schedule_name().to_string());
            },
        )
        .unwrap();
        assert_eq!(api.hook_count(OnSystemStart::schedule_id()), 1);
        assert_eq!(api.hook_count(OnSystemComplete::schedule_id()), 1);
        let mut ctx = SystemContext::new();
        api.invoke(
            OnSystemStart::schedule_id(),
            &mut ctx,
            &GraphEvent::SystemStart {
                run_id: RunId::new(),
                labels: RunLabels::empty(),
                node_id: NodeId::new(),
                node_name: "test",
            },
        );
        api.invoke(
            OnSystemComplete::schedule_id(),
            &mut ctx,
            &GraphEvent::SystemComplete {
                run_id: RunId::new(),
                labels: RunLabels::empty(),
                node_id: NodeId::new(),
                node_name: "test",
                duration: Duration::ZERO,
            },
        );
        let names = events.lock().unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"OnSystemStart".to_string()));
        assert!(names.contains(&"OnSystemComplete".to_string()));
    }

    #[test]
    fn register_observer_multiple_schedules_qualifies_names() {
        let api = HooksAPI::new();
        api.register_observer::<(OnSystemStart, OnSystemComplete), _>(
            "tracker",
            |_: &GraphEvent| {},
        )
        .unwrap();
        for schedule in [OnSystemStart::schedule_id(), OnSystemComplete::schedule_id()] {
            let qualified = format!("tracker@{}", schedule.type_name());
            assert!(
                api.contains_hook(schedule, &qualified),
                "multi-schedule hooks are stored under name@schedule"
            );
            assert!(!api.contains_hook(schedule, "tracker"));
        }
    }

    #[test]
    fn register_observer_multiple_schedules_reports_qualified_duplicate() {
        let api = HooksAPI::new();
        let complete = OnSystemComplete::schedule_id();
        let qualified = format!("tracker@{}", complete.type_name());
        api.register_boxed(
            complete,
            qualified.clone(),
            BoxedHook::new(|_ctx, _event| {}, Vec::new()),
        )
        .expect("pre-registration should succeed");

        let result = api.register_observer::<(OnSystemStart, OnSystemComplete), _>(
            "tracker",
            |_: &GraphEvent| {},
        );
        match result {
            Err(HookRegistrationError::DuplicateName { schedule, name }) => {
                assert_eq!(schedule, complete);
                assert_eq!(name, qualified);
            }
            Ok(_) => panic!("expected DuplicateName error"),
        }
    }

    #[test]
    fn graph_event_provides_typed_access_in_hook() {
        let api = HooksAPI::new();
        let captured = Arc::new(Mutex::new(None));
        let captured_clone = Arc::clone(&captured);
        api.register_observer::<OnSystemStart, _>("capture", move |event: &GraphEvent| {
            if let GraphEvent::SystemStart {
                node_name: system_name,
                ..
            } = event
            {
                *captured_clone.lock().unwrap() = Some(system_name.to_string());
            }
        })
        .unwrap();
        let mut ctx = SystemContext::new();
        api.invoke(
            OnSystemStart::schedule_id(),
            &mut ctx,
            &GraphEvent::SystemStart {
                run_id: RunId::new(),
                labels: RunLabels::empty(),
                node_id: NodeId::new(),
                node_name: "my_system",
            },
        );
        let name = captured.lock().unwrap().take().unwrap();
        assert_eq!(name, "my_system");
    }
}