Skip to main content

awaken_runtime/plugins/
registry.rs

1use std::any::TypeId;
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4
5use crate::phase::{
6    EffectHandlerArc, PhaseHook, PhaseHookArc, ScheduledActionHandlerArc, ToolGateHook,
7    ToolGateHookArc, ToolPolicyGateHook, ToolPolicyHook, TypedEffectAdapter, TypedEffectHandler,
8    TypedScheduledActionAdapter, TypedScheduledActionHandler,
9};
10use crate::state::{KeyScope, MergeStrategy, StateKey, StateKeyOptions, StateMap};
11use awaken_contract::StateError;
12use awaken_contract::contract::profile_store::ProfileKey;
13use awaken_contract::contract::tool::Tool;
14use awaken_contract::model::{EffectSpec, JsonValue, Phase, ScheduledActionSpec};
15
16#[derive(Clone)]
17pub(crate) struct KeyRegistration {
18    pub(crate) type_id: TypeId,
19    pub(crate) key: String,
20    pub(crate) options: StateKeyOptions,
21    pub(crate) merge_strategy: MergeStrategy,
22    pub(crate) scope: KeyScope,
23    pub(crate) export: fn(&StateMap) -> Result<Option<JsonValue>, StateError>,
24    pub(crate) import: fn(&mut StateMap, JsonValue) -> Result<(), StateError>,
25    pub(crate) clear: fn(&mut StateMap),
26}
27
28impl KeyRegistration {
29    pub(crate) fn new<K: StateKey>(options: StateKeyOptions) -> Self {
30        Self {
31            type_id: TypeId::of::<K>(),
32            key: K::KEY.into(),
33            options,
34            merge_strategy: K::MERGE,
35            scope: options.scope,
36            export: |map| match map.get::<K>() {
37                Some(value) => K::encode(value).map(Some),
38                None => Ok(None),
39            },
40            import: |map, json| {
41                let value = K::decode(json)?;
42                map.insert::<K>(value);
43                Ok(())
44            },
45            clear: |map| {
46                let _ = map.remove::<K>();
47            },
48        }
49    }
50}
51
/// Registration record for a profile key.
///
/// Stores the `TypeId` of the key type together with its string name.
#[derive(Clone, Debug)]
pub struct ProfileKeyRegistration {
    /// `TypeId` of the registered profile key type.
    pub type_id: TypeId,
    /// String name of the profile key (the type's `KEY` constant).
    pub key: String,
}
57
58pub(crate) struct ScheduledActionHandlerRegistration {
59    pub(crate) key: String,
60    pub(crate) handler: ScheduledActionHandlerArc,
61}
62
63pub(crate) struct EffectHandlerRegistration {
64    pub(crate) key: String,
65    pub(crate) handler: EffectHandlerArc,
66}
67
68pub(crate) struct PhaseHookRegistration {
69    pub(crate) phase: Phase,
70    pub(crate) plugin_id: String,
71    pub(crate) hook: PhaseHookArc,
72}
73
74pub(crate) struct ToolGateHookRegistration {
75    pub(crate) plugin_id: String,
76    pub(crate) hook: ToolGateHookArc,
77}
78
79pub(crate) type RequestTransformArc =
80    std::sync::Arc<dyn awaken_contract::contract::transform::InferenceRequestTransform>;
81
82pub(crate) struct RequestTransformRegistration {
83    pub(crate) plugin_id: String,
84    pub(crate) transform: RequestTransformArc,
85}
86
87pub(crate) struct ToolRegistration {
88    pub(crate) id: String,
89    pub(crate) tool: Arc<dyn Tool>,
90}
91
92#[derive(Default)]
93pub struct PluginRegistry {
94    pub(crate) plugins: HashMap<TypeId, InstalledPlugin>,
95    pub(crate) keys_by_type: HashMap<TypeId, KeyRegistration>,
96    pub(crate) keys_by_name: HashMap<String, KeyRegistration>,
97}
98
/// Bookkeeping entry for a plugin held in [`PluginRegistry`].
pub struct InstalledPlugin {
    /// `TypeId`s recorded for this plugin (presumably the state keys it owns;
    /// confirm against the install path).
    pub(crate) owned_key_type_ids: Vec<TypeId>,
}
102
103impl PluginRegistry {
104    pub(crate) fn merge_strategy(&self, key: &str) -> MergeStrategy {
105        self.keys_by_name
106            .get(key)
107            .map(|reg| reg.merge_strategy)
108            .unwrap_or(MergeStrategy::Exclusive)
109    }
110
111    pub(crate) fn ensure_key(&self, key: &str) -> Result<(), StateError> {
112        if self.keys_by_name.contains_key(key) {
113            Ok(())
114        } else {
115            Err(StateError::UnknownKey { key: key.into() })
116        }
117    }
118}
119
120pub struct PluginRegistrar {
121    pub(crate) keys: Vec<KeyRegistration>,
122    key_type_ids: HashSet<TypeId>,
123    key_names: HashSet<String>,
124    pub profile_keys: Vec<ProfileKeyRegistration>,
125    profile_key_type_ids: HashSet<TypeId>,
126    profile_key_names: HashSet<String>,
127    pub(crate) scheduled_actions: Vec<ScheduledActionHandlerRegistration>,
128    scheduled_action_keys: HashSet<String>,
129    pub(crate) effects: Vec<EffectHandlerRegistration>,
130    effect_keys: HashSet<String>,
131    pub(crate) phase_hooks: Vec<PhaseHookRegistration>,
132    pub(crate) tool_gate_hooks: Vec<ToolGateHookRegistration>,
133    pub(crate) request_transforms: Vec<RequestTransformRegistration>,
134    pub(crate) tools: Vec<ToolRegistration>,
135    tool_ids: HashSet<String>,
136}
137
138impl PluginRegistrar {
139    pub(crate) fn new() -> Self {
140        Self {
141            keys: Vec::new(),
142            key_type_ids: HashSet::new(),
143            key_names: HashSet::new(),
144            profile_keys: Vec::new(),
145            profile_key_type_ids: HashSet::new(),
146            profile_key_names: HashSet::new(),
147            scheduled_actions: Vec::new(),
148            scheduled_action_keys: HashSet::new(),
149            effects: Vec::new(),
150            effect_keys: HashSet::new(),
151            phase_hooks: Vec::new(),
152            tool_gate_hooks: Vec::new(),
153            request_transforms: Vec::new(),
154            tools: Vec::new(),
155            tool_ids: HashSet::new(),
156        }
157    }
158
159    pub fn register_key<K>(&mut self, options: StateKeyOptions) -> Result<(), StateError>
160    where
161        K: StateKey,
162    {
163        let type_id = TypeId::of::<K>();
164        if !self.key_type_ids.insert(type_id) || !self.key_names.insert(K::KEY.to_string()) {
165            return Err(StateError::KeyAlreadyRegistered {
166                key: K::KEY.to_string(),
167            });
168        }
169
170        self.keys.push(KeyRegistration::new::<K>(options));
171        Ok(())
172    }
173
174    pub fn register_scheduled_action<A, H>(&mut self, handler: H) -> Result<(), StateError>
175    where
176        A: ScheduledActionSpec,
177        H: TypedScheduledActionHandler<A>,
178    {
179        let key = A::KEY.to_string();
180        if !self.scheduled_action_keys.insert(key.clone()) {
181            return Err(StateError::HandlerAlreadyRegistered { key });
182        }
183
184        self.scheduled_actions
185            .push(ScheduledActionHandlerRegistration {
186                key,
187                handler: Arc::new(TypedScheduledActionAdapter::<A, H> {
188                    handler,
189                    _marker: std::marker::PhantomData,
190                }),
191            });
192        Ok(())
193    }
194
195    pub fn register_effect<E, H>(&mut self, handler: H) -> Result<(), StateError>
196    where
197        E: EffectSpec,
198        H: TypedEffectHandler<E>,
199    {
200        let key = E::KEY.to_string();
201        if !self.effect_keys.insert(key.clone()) {
202            return Err(StateError::EffectHandlerAlreadyRegistered { key });
203        }
204
205        self.effects.push(EffectHandlerRegistration {
206            key,
207            handler: Arc::new(TypedEffectAdapter::<E, H> {
208                handler,
209                _marker: std::marker::PhantomData,
210            }),
211        });
212        Ok(())
213    }
214
215    pub fn register_phase_hook<H>(
216        &mut self,
217        plugin_id: impl Into<String>,
218        phase: Phase,
219        hook: H,
220    ) -> Result<(), StateError>
221    where
222        H: PhaseHook,
223    {
224        self.phase_hooks.push(PhaseHookRegistration {
225            phase,
226            plugin_id: plugin_id.into(),
227            hook: Arc::new(hook),
228        });
229        Ok(())
230    }
231
232    pub fn register_tool_gate_hook<H>(
233        &mut self,
234        plugin_id: impl Into<String>,
235        hook: H,
236    ) -> Result<(), StateError>
237    where
238        H: ToolGateHook,
239    {
240        self.tool_gate_hooks.push(ToolGateHookRegistration {
241            plugin_id: plugin_id.into(),
242            hook: Arc::new(hook),
243        });
244        Ok(())
245    }
246
247    /// Register a typed tool policy hook. Policy hooks are executed through the
248    /// existing ToolGate phase, so ordering and conflict resolution stay unified.
249    pub fn register_tool_policy_hook<H>(
250        &mut self,
251        plugin_id: impl Into<String>,
252        hook: H,
253    ) -> Result<(), StateError>
254    where
255        H: ToolPolicyHook,
256    {
257        self.tool_gate_hooks.push(ToolGateHookRegistration {
258            plugin_id: plugin_id.into(),
259            hook: Arc::new(ToolPolicyGateHook::new(Arc::new(hook))),
260        });
261        Ok(())
262    }
263
264    /// Register a tool provided by this plugin.
265    ///
266    /// The tool becomes available to agents that activate this plugin.
267    /// Tool IDs must be unique across all plugins; duplicates cause a resolve error.
268    pub fn register_tool(
269        &mut self,
270        id: impl Into<String>,
271        tool: Arc<dyn Tool>,
272    ) -> Result<(), StateError> {
273        let id = id.into();
274        if !self.tool_ids.insert(id.clone()) {
275            return Err(StateError::ToolAlreadyRegistered { tool_id: id });
276        }
277        self.tools.push(ToolRegistration { id, tool });
278        Ok(())
279    }
280
281    /// Register a request transform applied after message assembly, before LLM call.
282    pub fn register_request_transform<T>(&mut self, plugin_id: impl Into<String>, transform: T)
283    where
284        T: awaken_contract::contract::transform::InferenceRequestTransform + 'static,
285    {
286        self.request_transforms.push(RequestTransformRegistration {
287            plugin_id: plugin_id.into(),
288            transform: Arc::new(transform),
289        });
290    }
291
292    /// Register a profile key for typed profile storage access.
293    pub fn register_profile_key<K: ProfileKey>(&mut self) -> Result<(), StateError> {
294        let type_id = TypeId::of::<K>();
295        if !self.profile_key_type_ids.insert(type_id)
296            || !self.profile_key_names.insert(K::KEY.to_string())
297        {
298            return Err(StateError::KeyAlreadyRegistered {
299                key: K::KEY.to_string(),
300            });
301        }
302        self.profile_keys.push(ProfileKeyRegistration {
303            type_id,
304            key: K::KEY.to_string(),
305        });
306        Ok(())
307    }
308
309    #[cfg(any(test, feature = "test-utils"))]
310    pub fn new_for_test() -> Self {
311        Self::new()
312    }
313
314    #[cfg(any(test, feature = "test-utils"))]
315    pub fn profile_keys_for_test(&self) -> Vec<ProfileKeyRegistration> {
316        self.profile_keys.clone()
317    }
318
319    /// Returns the list of registered tool IDs (test helper).
320    #[cfg(any(test, feature = "test-utils"))]
321    pub fn tool_ids_for_test(&self) -> Vec<String> {
322        self.tools.iter().map(|t| t.id.clone()).collect()
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::profile_store::ProfileKey;

    /// Minimal profile key used to exercise registration.
    struct TestLocale;

    impl ProfileKey for TestLocale {
        const KEY: &'static str = "locale";
        type Value = String;
    }

    /// A first registration is accepted and recorded under the key's name.
    #[test]
    fn register_profile_key_succeeds() {
        let mut registrar = PluginRegistrar::new_for_test();
        registrar.register_profile_key::<TestLocale>().unwrap();

        let registered = registrar.profile_keys_for_test();
        let names: Vec<&str> = registered.iter().map(|entry| entry.key.as_str()).collect();
        assert_eq!(names, ["locale"]);
    }

    /// Registering the same profile key twice is rejected, and the error
    /// message names the offending key.
    #[test]
    fn register_duplicate_profile_key_errors() {
        let mut registrar = PluginRegistrar::new_for_test();
        registrar.register_profile_key::<TestLocale>().unwrap();

        let message = registrar
            .register_profile_key::<TestLocale>()
            .unwrap_err()
            .to_string();
        assert!(message.contains("locale"));
    }
}
353}