Skip to main content

lash_core/plugin/
trigger_registry.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, Mutex};
3
4use serde::{Deserialize, Serialize};
5
6use super::{
7    PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, PluginSnapshotMeta,
8    SessionPlugin, SnapshotReader, SnapshotWriter,
9};
10
/// Plugin id reported by both [`SessionTriggerPluginFactory`] and the
/// [`SessionTriggerPlugin`] it builds.
pub(crate) const SESSION_TRIGGER_PLUGIN_ID: &str = "lash.session_triggers";
12
13#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
14#[serde(transparent)]
15pub struct TriggerSourceType(String);
16
17impl TriggerSourceType {
18    pub fn new(value: impl Into<String>) -> Self {
19        Self(value.into())
20    }
21
22    pub fn as_str(&self) -> &str {
23        &self.0
24    }
25}
26
27impl From<String> for TriggerSourceType {
28    fn from(value: String) -> Self {
29        Self::new(value)
30    }
31}
32
33impl From<&str> for TriggerSourceType {
34    fn from(value: &str) -> Self {
35        Self::new(value)
36    }
37}
38
39impl AsRef<str> for TriggerSourceType {
40    fn as_ref(&self) -> &str {
41        self.as_str()
42    }
43}
44
45impl std::fmt::Display for TriggerSourceType {
46    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        formatter.write_str(self.as_str())
48    }
49}
50
/// Public, serializable summary of a registered trigger.
///
/// Produced from an internal [`SessionTriggerRoute`] via `From`. The module,
/// required-surface and process references and the event type are not
/// carried over.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TriggerRegistration {
    /// Registry-assigned handle (`trigger:<n>` for routes created by
    /// `SessionTriggerRegistry::register_route`).
    pub handle: String,
    /// Optional name from the registration request; omitted from serialized
    /// output when `None` and defaulted to `None` when absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Source type the trigger was registered with.
    pub source_type: TriggerSourceType,
    /// Source configuration as JSON.
    pub source: serde_json::Value,
    /// Summary of the trigger's target (process name and input template).
    pub target: TriggerTargetSummary,
    /// Mirrors the route's `enabled` flag, which
    /// `SessionTriggerRegistry::cancel` sets to `false`. Defaults to `true`
    /// when absent on deserialization.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
}
62
/// Public summary of a trigger's target: the process name and the input
/// template used to build its inputs.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TriggerTargetSummary {
    /// Name of the target process.
    pub process_name: String,
    /// Input template as returned by target validation at registration time.
    pub inputs: lashlang::TriggerInputTemplate,
}
68
/// Internal record for one registered trigger, as stored in the registry and
/// in its snapshots.
///
/// Holds everything `register_route` resolved at registration time. The
/// public projection is [`TriggerRegistration`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub(crate) struct SessionTriggerRoute {
    /// Registry-assigned handle; also this route's key in the registry map.
    pub handle: String,
    /// Optional name from the registration request; omitted from serialized
    /// output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Source type string taken from the registration request.
    pub source_type: String,
    /// Source configuration, as produced by the request source's `to_json()`.
    pub source: serde_json::Value,
    /// Event type returned by target validation.
    pub event_ty: lashlang::TypeExpr,
    /// Module whose artifact was loaded to validate the target.
    pub module_ref: lashlang::ModuleRef,
    /// Target's required-surface reference, copied from the request.
    pub required_surface_ref: lashlang::RequiredSurfaceRef,
    /// Target process reference, copied from the request.
    pub process_ref: lashlang::ProcessRef,
    /// Target process name, copied from the request.
    pub process_name: String,
    /// Input template returned by target validation.
    pub input_template: lashlang::TriggerInputTemplate,
    /// `true` when registered; set to `false` by `cancel`. Defaults to `true`
    /// when absent on deserialization.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
}
85
86impl From<&SessionTriggerRoute> for TriggerRegistration {
87    fn from(route: &SessionTriggerRoute) -> Self {
88        Self {
89            handle: route.handle.clone(),
90            name: route.name.clone(),
91            source_type: TriggerSourceType::new(route.source_type.clone()),
92            source: route.source.clone(),
93            target: TriggerTargetSummary {
94                process_name: route.process_name.clone(),
95                inputs: route.input_template.clone(),
96            },
97            enabled: route.enabled,
98        }
99    }
100}
101
/// Serde default for the `enabled` fields: records that omit the flag
/// deserialize as enabled.
fn default_enabled() -> bool {
    true
}
105
/// Mutable contents of a [`SessionTriggerRegistry`]; cloned wholesale for
/// snapshots and replaced wholesale on restore.
///
/// Every field has `#[serde(default)]`, so a snapshot missing a field still
/// decodes.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
struct SessionTriggerRegistryState {
    /// Bumped on each registration and on each `cancel` that actually
    /// disables a route.
    #[serde(default)]
    revision: u64,
    /// Numeric suffix of the most recently allocated `trigger:<n>` handle.
    /// NOTE(review): decodes as 0 when absent, even if `routes` is non-empty.
    #[serde(default)]
    next_id: u64,
    /// Routes keyed by handle. A `BTreeMap<String, _>` iterates in
    /// lexicographic handle order (e.g. `trigger:10` before `trigger:2`), not
    /// registration order.
    #[serde(default)]
    routes: BTreeMap<String, SessionTriggerRoute>,
}
115
/// Per-session store of trigger routes, with its state guarded by a
/// `std::sync::Mutex`. `Default` yields an empty registry.
#[derive(Default)]
pub(crate) struct SessionTriggerRegistry {
    state: Mutex<SessionTriggerRegistryState>,
}
120
121impl SessionTriggerRegistry {
122    pub(crate) async fn register_route(
123        &self,
124        request: serde_json::Value,
125        resources: &lashlang::ResourceCatalog,
126        artifact_store: &dyn lashlang::LashlangArtifactStore,
127    ) -> Result<SessionTriggerRoute, PluginError> {
128        let request = lashlang::TriggerRegistrationRequest::decode(&request)
129            .map_err(|err| PluginError::Session(err.to_string()))?;
130        let source_type = request.source.source_type.clone();
131        let source = request.source.to_json();
132        let event_type = lashlang::event_type_for_source(resources, &source_type)
133            .map_err(|err| PluginError::Session(err.to_string()))?;
134        let target = request.target;
135        let validation =
136            validate_target_process(&target, &event_type, &request.inputs, artifact_store).await?;
137
138        let mut state = self
139            .state
140            .lock()
141            .map_err(|_| PluginError::Session("trigger registry lock poisoned".to_string()))?;
142        state.next_id = state.next_id.saturating_add(1);
143        let handle = format!("trigger:{}", state.next_id);
144        let route = SessionTriggerRoute {
145            handle: handle.clone(),
146            name: request.name,
147            source_type,
148            source,
149            event_ty: validation.event_ty,
150            module_ref: target.module_ref,
151            required_surface_ref: target.required_surface_ref,
152            process_ref: target.process_ref,
153            process_name: target.process_name,
154            input_template: validation.inputs,
155            enabled: true,
156        };
157        state.routes.insert(handle, route.clone());
158        state.revision = state.revision.saturating_add(1);
159        Ok(route)
160    }
161
162    pub(crate) fn list(
163        &self,
164        request: serde_json::Value,
165    ) -> Result<Vec<TriggerRegistration>, PluginError> {
166        let request = lashlang::TriggerListRequest::decode(&request)
167            .map_err(|err| PluginError::Session(err.to_string()))?;
168        let state = self
169            .state
170            .lock()
171            .map_err(|_| PluginError::Session("trigger registry lock poisoned".to_string()))?;
172        Ok(state
173            .routes
174            .values()
175            .filter(|route| {
176                request.target.as_ref().is_none_or(|target| {
177                    target.matches(
178                        &route.module_ref,
179                        &route.required_surface_ref,
180                        &route.process_ref,
181                        &route.process_name,
182                    )
183                }) && request
184                    .name
185                    .as_deref()
186                    .is_none_or(|name| route.name.as_deref() == Some(name))
187                    && request
188                        .source_type
189                        .as_deref()
190                        .is_none_or(|source_type| route.source_type == source_type)
191                    && request
192                        .enabled
193                        .is_none_or(|enabled| route.enabled == enabled)
194            })
195            .map(TriggerRegistration::from)
196            .collect())
197    }
198
199    pub(crate) fn list_all(&self) -> Result<Vec<TriggerRegistration>, PluginError> {
200        let state = self
201            .state
202            .lock()
203            .map_err(|_| PluginError::Session("trigger registry lock poisoned".to_string()))?;
204        Ok(state
205            .routes
206            .values()
207            .map(TriggerRegistration::from)
208            .collect())
209    }
210
211    pub(crate) fn routes_by_source_type(
212        &self,
213        source_type: &TriggerSourceType,
214    ) -> Result<Vec<TriggerRegistration>, PluginError> {
215        let state = self
216            .state
217            .lock()
218            .map_err(|_| PluginError::Session("trigger registry lock poisoned".to_string()))?;
219        Ok(state
220            .routes
221            .values()
222            .filter(|route| route.source_type == source_type.as_str())
223            .map(TriggerRegistration::from)
224            .collect())
225    }
226
227    pub(crate) fn activation_routes_by_source_type(
228        &self,
229        source_type: &str,
230    ) -> Result<Vec<SessionTriggerRoute>, PluginError> {
231        let state = self
232            .state
233            .lock()
234            .map_err(|_| PluginError::Session("trigger registry lock poisoned".to_string()))?;
235        Ok(state
236            .routes
237            .values()
238            .filter(|route| route.source_type == source_type)
239            .cloned()
240            .collect())
241    }
242
243    pub(crate) fn cancel(&self, request: serde_json::Value) -> Result<bool, PluginError> {
244        let request = lashlang::TriggerCancelRequest::decode(&request)
245            .map_err(|err| PluginError::Session(err.to_string()))?;
246        let mut state = self
247            .state
248            .lock()
249            .map_err(|_| PluginError::Session("trigger registry lock poisoned".to_string()))?;
250        let Some(route) = state.routes.get_mut(&request.handle) else {
251            return Ok(false);
252        };
253        let changed = route.enabled;
254        route.enabled = false;
255        if changed {
256            state.revision = state.revision.saturating_add(1);
257        }
258        Ok(changed)
259    }
260
261    pub(crate) fn route(&self, handle: &str) -> Result<Option<SessionTriggerRoute>, PluginError> {
262        let state = self
263            .state
264            .lock()
265            .map_err(|_| PluginError::Session("trigger registry lock poisoned".to_string()))?;
266        Ok(state.routes.get(handle).cloned())
267    }
268
269    fn snapshot_state(&self) -> Result<SessionTriggerRegistryState, PluginError> {
270        self.state
271            .lock()
272            .map(|state| state.clone())
273            .map_err(|_| PluginError::Session("trigger registry lock poisoned".to_string()))
274    }
275
276    fn restore_state(&self, state: SessionTriggerRegistryState) -> Result<(), PluginError> {
277        let mut current = self
278            .state
279            .lock()
280            .map_err(|_| PluginError::Session("trigger registry lock poisoned".to_string()))?;
281        *current = state;
282        Ok(())
283    }
284
285    fn revision(&self) -> u64 {
286        self.state
287            .lock()
288            .map(|state| state.revision)
289            .unwrap_or_default()
290    }
291}
292
293pub(crate) struct SessionTriggerPluginFactory;
294
295impl PluginFactory for SessionTriggerPluginFactory {
296    fn id(&self) -> &'static str {
297        SESSION_TRIGGER_PLUGIN_ID
298    }
299
300    fn lashlang_resources(&self) -> lashlang::ResourceCatalog {
301        let mut resources = lashlang::ResourceCatalog::new();
302        lashlang::add_trigger_resource_operations(&mut resources);
303        resources
304    }
305
306    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
307        Ok(Arc::new(SessionTriggerPlugin {
308            registry: Arc::new(SessionTriggerRegistry::default()),
309        }))
310    }
311}
312
313struct SessionTriggerPlugin {
314    registry: Arc<SessionTriggerRegistry>,
315}
316
317impl SessionPlugin for SessionTriggerPlugin {
318    fn id(&self) -> &'static str {
319        SESSION_TRIGGER_PLUGIN_ID
320    }
321
322    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
323        reg.triggers().registry(Arc::clone(&self.registry))
324    }
325
326    fn snapshot(
327        &self,
328        _writer: &mut dyn SnapshotWriter,
329    ) -> Result<PluginSnapshotMeta, PluginError> {
330        Ok(PluginSnapshotMeta {
331            plugin_id: self.id().to_string(),
332            plugin_version: self.version().to_string(),
333            revision: self.snapshot_revision(),
334            state: Some(
335                serde_json::to_value(self.registry.snapshot_state()?).map_err(|err| {
336                    PluginError::Session(format!(
337                        "failed to encode trigger registry snapshot: {err}"
338                    ))
339                })?,
340            ),
341        })
342    }
343
344    fn snapshot_revision(&self) -> u64 {
345        self.registry.revision()
346    }
347
348    fn restore(
349        &self,
350        meta: &PluginSnapshotMeta,
351        _reader: &dyn SnapshotReader,
352    ) -> Result<(), PluginError> {
353        let Some(value) = meta.state.clone() else {
354            return self
355                .registry
356                .restore_state(SessionTriggerRegistryState::default());
357        };
358        let state: SessionTriggerRegistryState = serde_json::from_value(value).map_err(|err| {
359            PluginError::Session(format!("failed to decode trigger registry snapshot: {err}"))
360        })?;
361        self.registry.restore_state(state)
362    }
363}
364
365pub(super) fn trigger_handle_json(handle: &str) -> serde_json::Value {
366    serde_json::json!({
367        "type": "trigger_handle",
368        "id": handle,
369    })
370}
371
372async fn validate_target_process(
373    target: &lashlang::TriggerTargetIdentity,
374    event_ty: &lashlang::NamedDataType,
375    inputs: &lashlang::TriggerInputTemplate,
376    artifact_store: &dyn lashlang::LashlangArtifactStore,
377) -> Result<lashlang::TriggerTargetValidation, PluginError> {
378    let artifact = artifact_store
379        .get_module_artifact(&target.module_ref)
380        .await
381        .map_err(|err| PluginError::Session(format!("load trigger target artifact: {err}")))?
382        .ok_or_else(|| {
383            PluginError::Session(format!(
384                "missing trigger target artifact `{}`",
385                target.module_ref
386            ))
387        })?;
388    let validation = lashlang::validate_trigger_target(target, event_ty, inputs, &artifact)
389        .map_err(|err| PluginError::Session(err.to_string()))?;
390    Ok(validation)
391}