1use std::any::TypeId;
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4
5use crate::phase::{
6 EffectHandlerArc, PhaseHook, PhaseHookArc, ScheduledActionHandlerArc, ToolGateHook,
7 ToolGateHookArc, ToolPolicyGateHook, ToolPolicyHook, TypedEffectAdapter, TypedEffectHandler,
8 TypedScheduledActionAdapter, TypedScheduledActionHandler,
9};
10use crate::state::{KeyScope, MergeStrategy, StateKey, StateKeyOptions, StateMap};
11use awaken_contract::StateError;
12use awaken_contract::contract::profile_store::ProfileKey;
13use awaken_contract::contract::tool::Tool;
14use awaken_contract::model::{EffectSpec, JsonValue, Phase, ScheduledActionSpec};
15
16#[derive(Clone)]
17pub(crate) struct KeyRegistration {
18 pub(crate) type_id: TypeId,
19 pub(crate) key: String,
20 pub(crate) options: StateKeyOptions,
21 pub(crate) merge_strategy: MergeStrategy,
22 pub(crate) scope: KeyScope,
23 pub(crate) export: fn(&StateMap) -> Result<Option<JsonValue>, StateError>,
24 pub(crate) import: fn(&mut StateMap, JsonValue) -> Result<(), StateError>,
25 pub(crate) clear: fn(&mut StateMap),
26}
27
28impl KeyRegistration {
29 pub(crate) fn new<K: StateKey>(options: StateKeyOptions) -> Self {
30 Self {
31 type_id: TypeId::of::<K>(),
32 key: K::KEY.into(),
33 options,
34 merge_strategy: K::MERGE,
35 scope: options.scope,
36 export: |map| match map.get::<K>() {
37 Some(value) => K::encode(value).map(Some),
38 None => Ok(None),
39 },
40 import: |map, json| {
41 let value = K::decode(json)?;
42 map.insert::<K>(value);
43 Ok(())
44 },
45 clear: |map| {
46 let _ = map.remove::<K>();
47 },
48 }
49 }
50}
51
/// Record of a profile key registered via `PluginRegistrar::register_profile_key`.
///
/// Derives `Debug`/`PartialEq`/`Eq` so registrations can be inspected and
/// compared (both fields are plain standard-library values).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProfileKeyRegistration {
    /// `TypeId` of the registered `ProfileKey` type.
    pub type_id: TypeId,
    /// String name of the profile key (its `KEY` constant).
    pub key: String,
}
57
58pub(crate) struct ScheduledActionHandlerRegistration {
59 pub(crate) key: String,
60 pub(crate) handler: ScheduledActionHandlerArc,
61}
62
63pub(crate) struct EffectHandlerRegistration {
64 pub(crate) key: String,
65 pub(crate) handler: EffectHandlerArc,
66}
67
68pub(crate) struct PhaseHookRegistration {
69 pub(crate) phase: Phase,
70 pub(crate) plugin_id: String,
71 pub(crate) hook: PhaseHookArc,
72}
73
74pub(crate) struct ToolGateHookRegistration {
75 pub(crate) plugin_id: String,
76 pub(crate) hook: ToolGateHookArc,
77}
78
79pub(crate) type RequestTransformArc =
80 std::sync::Arc<dyn awaken_contract::contract::transform::InferenceRequestTransform>;
81
82pub(crate) struct RequestTransformRegistration {
83 pub(crate) plugin_id: String,
84 pub(crate) transform: RequestTransformArc,
85}
86
87pub(crate) struct ToolRegistration {
88 pub(crate) id: String,
89 pub(crate) tool: Arc<dyn Tool>,
90}
91
92#[derive(Default)]
93pub struct PluginRegistry {
94 pub(crate) plugins: HashMap<TypeId, InstalledPlugin>,
95 pub(crate) keys_by_type: HashMap<TypeId, KeyRegistration>,
96 pub(crate) keys_by_name: HashMap<String, KeyRegistration>,
97}
98
/// Bookkeeping kept for each installed plugin.
pub struct InstalledPlugin {
    /// `TypeId`s of state keys associated with the plugin
    /// (presumably the keys it registered — confirm at the install site).
    pub(crate) owned_key_type_ids: Vec<TypeId>,
}
102
103impl PluginRegistry {
104 pub(crate) fn merge_strategy(&self, key: &str) -> MergeStrategy {
105 self.keys_by_name
106 .get(key)
107 .map(|reg| reg.merge_strategy)
108 .unwrap_or(MergeStrategy::Exclusive)
109 }
110
111 pub(crate) fn ensure_key(&self, key: &str) -> Result<(), StateError> {
112 if self.keys_by_name.contains_key(key) {
113 Ok(())
114 } else {
115 Err(StateError::UnknownKey { key: key.into() })
116 }
117 }
118}
119
120pub struct PluginRegistrar {
121 pub(crate) keys: Vec<KeyRegistration>,
122 key_type_ids: HashSet<TypeId>,
123 key_names: HashSet<String>,
124 pub profile_keys: Vec<ProfileKeyRegistration>,
125 profile_key_type_ids: HashSet<TypeId>,
126 profile_key_names: HashSet<String>,
127 pub(crate) scheduled_actions: Vec<ScheduledActionHandlerRegistration>,
128 scheduled_action_keys: HashSet<String>,
129 pub(crate) effects: Vec<EffectHandlerRegistration>,
130 effect_keys: HashSet<String>,
131 pub(crate) phase_hooks: Vec<PhaseHookRegistration>,
132 pub(crate) tool_gate_hooks: Vec<ToolGateHookRegistration>,
133 pub(crate) request_transforms: Vec<RequestTransformRegistration>,
134 pub(crate) tools: Vec<ToolRegistration>,
135 tool_ids: HashSet<String>,
136}
137
138impl PluginRegistrar {
139 pub(crate) fn new() -> Self {
140 Self {
141 keys: Vec::new(),
142 key_type_ids: HashSet::new(),
143 key_names: HashSet::new(),
144 profile_keys: Vec::new(),
145 profile_key_type_ids: HashSet::new(),
146 profile_key_names: HashSet::new(),
147 scheduled_actions: Vec::new(),
148 scheduled_action_keys: HashSet::new(),
149 effects: Vec::new(),
150 effect_keys: HashSet::new(),
151 phase_hooks: Vec::new(),
152 tool_gate_hooks: Vec::new(),
153 request_transforms: Vec::new(),
154 tools: Vec::new(),
155 tool_ids: HashSet::new(),
156 }
157 }
158
159 pub fn register_key<K>(&mut self, options: StateKeyOptions) -> Result<(), StateError>
160 where
161 K: StateKey,
162 {
163 let type_id = TypeId::of::<K>();
164 if !self.key_type_ids.insert(type_id) || !self.key_names.insert(K::KEY.to_string()) {
165 return Err(StateError::KeyAlreadyRegistered {
166 key: K::KEY.to_string(),
167 });
168 }
169
170 self.keys.push(KeyRegistration::new::<K>(options));
171 Ok(())
172 }
173
174 pub fn register_scheduled_action<A, H>(&mut self, handler: H) -> Result<(), StateError>
175 where
176 A: ScheduledActionSpec,
177 H: TypedScheduledActionHandler<A>,
178 {
179 let key = A::KEY.to_string();
180 if !self.scheduled_action_keys.insert(key.clone()) {
181 return Err(StateError::HandlerAlreadyRegistered { key });
182 }
183
184 self.scheduled_actions
185 .push(ScheduledActionHandlerRegistration {
186 key,
187 handler: Arc::new(TypedScheduledActionAdapter::<A, H> {
188 handler,
189 _marker: std::marker::PhantomData,
190 }),
191 });
192 Ok(())
193 }
194
195 pub fn register_effect<E, H>(&mut self, handler: H) -> Result<(), StateError>
196 where
197 E: EffectSpec,
198 H: TypedEffectHandler<E>,
199 {
200 let key = E::KEY.to_string();
201 if !self.effect_keys.insert(key.clone()) {
202 return Err(StateError::EffectHandlerAlreadyRegistered { key });
203 }
204
205 self.effects.push(EffectHandlerRegistration {
206 key,
207 handler: Arc::new(TypedEffectAdapter::<E, H> {
208 handler,
209 _marker: std::marker::PhantomData,
210 }),
211 });
212 Ok(())
213 }
214
215 pub fn register_phase_hook<H>(
216 &mut self,
217 plugin_id: impl Into<String>,
218 phase: Phase,
219 hook: H,
220 ) -> Result<(), StateError>
221 where
222 H: PhaseHook,
223 {
224 self.phase_hooks.push(PhaseHookRegistration {
225 phase,
226 plugin_id: plugin_id.into(),
227 hook: Arc::new(hook),
228 });
229 Ok(())
230 }
231
232 pub fn register_tool_gate_hook<H>(
233 &mut self,
234 plugin_id: impl Into<String>,
235 hook: H,
236 ) -> Result<(), StateError>
237 where
238 H: ToolGateHook,
239 {
240 self.tool_gate_hooks.push(ToolGateHookRegistration {
241 plugin_id: plugin_id.into(),
242 hook: Arc::new(hook),
243 });
244 Ok(())
245 }
246
247 pub fn register_tool_policy_hook<H>(
250 &mut self,
251 plugin_id: impl Into<String>,
252 hook: H,
253 ) -> Result<(), StateError>
254 where
255 H: ToolPolicyHook,
256 {
257 self.tool_gate_hooks.push(ToolGateHookRegistration {
258 plugin_id: plugin_id.into(),
259 hook: Arc::new(ToolPolicyGateHook::new(Arc::new(hook))),
260 });
261 Ok(())
262 }
263
264 pub fn register_tool(
269 &mut self,
270 id: impl Into<String>,
271 tool: Arc<dyn Tool>,
272 ) -> Result<(), StateError> {
273 let id = id.into();
274 if !self.tool_ids.insert(id.clone()) {
275 return Err(StateError::ToolAlreadyRegistered { tool_id: id });
276 }
277 self.tools.push(ToolRegistration { id, tool });
278 Ok(())
279 }
280
281 pub fn register_request_transform<T>(&mut self, plugin_id: impl Into<String>, transform: T)
283 where
284 T: awaken_contract::contract::transform::InferenceRequestTransform + 'static,
285 {
286 self.request_transforms.push(RequestTransformRegistration {
287 plugin_id: plugin_id.into(),
288 transform: Arc::new(transform),
289 });
290 }
291
292 pub fn register_profile_key<K: ProfileKey>(&mut self) -> Result<(), StateError> {
294 let type_id = TypeId::of::<K>();
295 if !self.profile_key_type_ids.insert(type_id)
296 || !self.profile_key_names.insert(K::KEY.to_string())
297 {
298 return Err(StateError::KeyAlreadyRegistered {
299 key: K::KEY.to_string(),
300 });
301 }
302 self.profile_keys.push(ProfileKeyRegistration {
303 type_id,
304 key: K::KEY.to_string(),
305 });
306 Ok(())
307 }
308
309 #[cfg(any(test, feature = "test-utils"))]
310 pub fn new_for_test() -> Self {
311 Self::new()
312 }
313
314 #[cfg(any(test, feature = "test-utils"))]
315 pub fn profile_keys_for_test(&self) -> Vec<ProfileKeyRegistration> {
316 self.profile_keys.clone()
317 }
318
319 #[cfg(any(test, feature = "test-utils"))]
321 pub fn tool_ids_for_test(&self) -> Vec<String> {
322 self.tools.iter().map(|t| t.id.clone()).collect()
323 }
324}
325
326#[cfg(test)]
327mod tests {
328 use super::*;
329 use awaken_contract::contract::profile_store::ProfileKey;
330
331 struct TestLocale;
332 impl ProfileKey for TestLocale {
333 const KEY: &'static str = "locale";
334 type Value = String;
335 }
336
337 #[test]
338 fn register_profile_key_succeeds() {
339 let mut registrar = PluginRegistrar::new_for_test();
340 registrar.register_profile_key::<TestLocale>().unwrap();
341 let keys = registrar.profile_keys_for_test();
342 assert_eq!(keys.len(), 1);
343 assert_eq!(keys[0].key, "locale");
344 }
345
346 #[test]
347 fn register_duplicate_profile_key_errors() {
348 let mut registrar = PluginRegistrar::new_for_test();
349 registrar.register_profile_key::<TestLocale>().unwrap();
350 let err = registrar.register_profile_key::<TestLocale>().unwrap_err();
351 assert!(err.to_string().contains("locale"));
352 }
353}