daedalus_runtime/
capabilities.rs

1use crate::executor::{EdgePayload, NodeError};
2use std::any::{Any, TypeId};
3use std::collections::HashMap;
4use std::sync::{OnceLock, RwLock};
5
6type CapabilityFn = dyn Fn(&[&dyn Any]) -> Result<EdgePayload, NodeError> + Send + Sync;
7
8pub struct CapabilityEntry {
9    pub type_ids: Vec<TypeId>,
10    pub func: Box<CapabilityFn>,
11}
12
13impl CapabilityEntry {
14    pub fn new(type_ids: Vec<TypeId>, func: Box<CapabilityFn>) -> Self {
15        Self { type_ids, func }
16    }
17}
18
19#[derive(Default)]
20pub struct CapabilityRegistry {
21    entries: HashMap<String, Vec<CapabilityEntry>>,
22}
23
24impl CapabilityRegistry {
25    pub fn new() -> Self {
26        Self {
27            entries: HashMap::new(),
28        }
29    }
30
31    pub fn register(
32        &mut self,
33        key: impl Into<String>,
34        type_ids: Vec<TypeId>,
35        func: Box<CapabilityFn>,
36    ) {
37        self.entries
38            .entry(key.into())
39            .or_default()
40            .push(CapabilityEntry { type_ids, func });
41    }
42
43    pub fn register_typed<T, F>(&mut self, key: impl Into<String>, f: F)
44    where
45        T: Send + Sync + 'static,
46        F: Fn(&T, &T) -> Result<T, NodeError> + Send + Sync + 'static,
47    {
48        let key = key.into();
49        self.register(
50            key,
51            vec![TypeId::of::<T>(), TypeId::of::<T>()],
52            Box::new(move |args: &[&dyn Any]| {
53                let a = args
54                    .first()
55                    .and_then(|v| v.downcast_ref::<T>())
56                    .ok_or_else(|| NodeError::InvalidInput("lhs".into()))?;
57                let b = args
58                    .get(1)
59                    .and_then(|v| v.downcast_ref::<T>())
60                    .ok_or_else(|| NodeError::InvalidInput("rhs".into()))?;
61                f(a, b).map(|out| EdgePayload::Any(std::sync::Arc::new(out)))
62            }),
63        );
64    }
65
66    pub fn register_typed3<T, F>(&mut self, key: impl Into<String>, f: F)
67    where
68        T: Send + Sync + 'static,
69        F: Fn(&T, &T, &T) -> Result<T, NodeError> + Send + Sync + 'static,
70    {
71        let key = key.into();
72        self.register(
73            key,
74            vec![TypeId::of::<T>(), TypeId::of::<T>(), TypeId::of::<T>()],
75            Box::new(move |args: &[&dyn Any]| {
76                let a = args
77                    .first()
78                    .and_then(|v| v.downcast_ref::<T>())
79                    .ok_or_else(|| NodeError::InvalidInput("x".into()))?;
80                let b = args
81                    .get(1)
82                    .and_then(|v| v.downcast_ref::<T>())
83                    .ok_or_else(|| NodeError::InvalidInput("lo".into()))?;
84                let c = args
85                    .get(2)
86                    .and_then(|v| v.downcast_ref::<T>())
87                    .ok_or_else(|| NodeError::InvalidInput("hi".into()))?;
88                f(a, b, c).map(|out| EdgePayload::Any(std::sync::Arc::new(out)))
89            }),
90        );
91    }
92
93    pub fn get(&self, key: &str) -> Option<&[CapabilityEntry]> {
94        self.entries.get(key).map(|v| v.as_slice())
95    }
96
97    pub fn merge(&mut self, other: CapabilityRegistry) {
98        for (k, mut v) in other.entries {
99            self.entries.entry(k).or_default().append(&mut v);
100        }
101    }
102
103    /// Register the common arithmetic capabilities for built-in primitives.
104    /// Keys are the trait names directly: "Add", "Sub", "Mul", "Div".
105    pub fn register_primitive_arithmetic(&mut self) {
106        macro_rules! register_math_for {
107            ($ty:ty) => {
108                self.register_typed::<$ty, _>("Add", |a, b| Ok(a.clone() + b.clone()));
109                self.register_typed::<$ty, _>("Sub", |a, b| Ok(a.clone() - b.clone()));
110                self.register_typed::<$ty, _>("Mul", |a, b| Ok(a.clone() * b.clone()));
111                self.register_typed::<$ty, _>("Div", |a, b| Ok(a.clone() / b.clone()));
112            };
113        }
114        register_math_for!(i8);
115        register_math_for!(i16);
116        register_math_for!(i32);
117        register_math_for!(i64);
118        register_math_for!(i128);
119        register_math_for!(isize);
120        register_math_for!(u8);
121        register_math_for!(u16);
122        register_math_for!(u32);
123        register_math_for!(u64);
124        register_math_for!(u128);
125        register_math_for!(usize);
126        register_math_for!(f32);
127        register_math_for!(f64);
128    }
129}
130
131static GLOBAL: OnceLock<RwLock<CapabilityRegistry>> = OnceLock::new();
132
133pub fn global() -> &'static RwLock<CapabilityRegistry> {
134    GLOBAL.get_or_init(|| RwLock::new(CapabilityRegistry::new()))
135}