daedalus_runtime/
capabilities.rs1use crate::executor::{EdgePayload, NodeError};
2use std::any::{Any, TypeId};
3use std::collections::HashMap;
4use std::sync::{OnceLock, RwLock};
5
6type CapabilityFn = dyn Fn(&[&dyn Any]) -> Result<EdgePayload, NodeError> + Send + Sync;
7
8pub struct CapabilityEntry {
9 pub type_ids: Vec<TypeId>,
10 pub func: Box<CapabilityFn>,
11}
12
13impl CapabilityEntry {
14 pub fn new(type_ids: Vec<TypeId>, func: Box<CapabilityFn>) -> Self {
15 Self { type_ids, func }
16 }
17}
18
19#[derive(Default)]
20pub struct CapabilityRegistry {
21 entries: HashMap<String, Vec<CapabilityEntry>>,
22}
23
24impl CapabilityRegistry {
25 pub fn new() -> Self {
26 Self {
27 entries: HashMap::new(),
28 }
29 }
30
31 pub fn register(
32 &mut self,
33 key: impl Into<String>,
34 type_ids: Vec<TypeId>,
35 func: Box<CapabilityFn>,
36 ) {
37 self.entries
38 .entry(key.into())
39 .or_default()
40 .push(CapabilityEntry { type_ids, func });
41 }
42
43 pub fn register_typed<T, F>(&mut self, key: impl Into<String>, f: F)
44 where
45 T: Send + Sync + 'static,
46 F: Fn(&T, &T) -> Result<T, NodeError> + Send + Sync + 'static,
47 {
48 let key = key.into();
49 self.register(
50 key,
51 vec![TypeId::of::<T>(), TypeId::of::<T>()],
52 Box::new(move |args: &[&dyn Any]| {
53 let a = args
54 .first()
55 .and_then(|v| v.downcast_ref::<T>())
56 .ok_or_else(|| NodeError::InvalidInput("lhs".into()))?;
57 let b = args
58 .get(1)
59 .and_then(|v| v.downcast_ref::<T>())
60 .ok_or_else(|| NodeError::InvalidInput("rhs".into()))?;
61 f(a, b).map(|out| EdgePayload::Any(std::sync::Arc::new(out)))
62 }),
63 );
64 }
65
66 pub fn register_typed3<T, F>(&mut self, key: impl Into<String>, f: F)
67 where
68 T: Send + Sync + 'static,
69 F: Fn(&T, &T, &T) -> Result<T, NodeError> + Send + Sync + 'static,
70 {
71 let key = key.into();
72 self.register(
73 key,
74 vec![TypeId::of::<T>(), TypeId::of::<T>(), TypeId::of::<T>()],
75 Box::new(move |args: &[&dyn Any]| {
76 let a = args
77 .first()
78 .and_then(|v| v.downcast_ref::<T>())
79 .ok_or_else(|| NodeError::InvalidInput("x".into()))?;
80 let b = args
81 .get(1)
82 .and_then(|v| v.downcast_ref::<T>())
83 .ok_or_else(|| NodeError::InvalidInput("lo".into()))?;
84 let c = args
85 .get(2)
86 .and_then(|v| v.downcast_ref::<T>())
87 .ok_or_else(|| NodeError::InvalidInput("hi".into()))?;
88 f(a, b, c).map(|out| EdgePayload::Any(std::sync::Arc::new(out)))
89 }),
90 );
91 }
92
93 pub fn get(&self, key: &str) -> Option<&[CapabilityEntry]> {
94 self.entries.get(key).map(|v| v.as_slice())
95 }
96
97 pub fn merge(&mut self, other: CapabilityRegistry) {
98 for (k, mut v) in other.entries {
99 self.entries.entry(k).or_default().append(&mut v);
100 }
101 }
102
103 pub fn register_primitive_arithmetic(&mut self) {
106 macro_rules! register_math_for {
107 ($ty:ty) => {
108 self.register_typed::<$ty, _>("Add", |a, b| Ok(a.clone() + b.clone()));
109 self.register_typed::<$ty, _>("Sub", |a, b| Ok(a.clone() - b.clone()));
110 self.register_typed::<$ty, _>("Mul", |a, b| Ok(a.clone() * b.clone()));
111 self.register_typed::<$ty, _>("Div", |a, b| Ok(a.clone() / b.clone()));
112 };
113 }
114 register_math_for!(i8);
115 register_math_for!(i16);
116 register_math_for!(i32);
117 register_math_for!(i64);
118 register_math_for!(i128);
119 register_math_for!(isize);
120 register_math_for!(u8);
121 register_math_for!(u16);
122 register_math_for!(u32);
123 register_math_for!(u64);
124 register_math_for!(u128);
125 register_math_for!(usize);
126 register_math_for!(f32);
127 register_math_for!(f64);
128 }
129}
130
131static GLOBAL: OnceLock<RwLock<CapabilityRegistry>> = OnceLock::new();
132
133pub fn global() -> &'static RwLock<CapabilityRegistry> {
134 GLOBAL.get_or_init(|| RwLock::new(CapabilityRegistry::new()))
135}