Skip to main content

ergo_runtime/action/
registry.rs

1use std::collections::HashMap;
2
3use semver::Version;
4
5use super::{
6    ActionKind, ActionPrimitive, ActionPrimitiveManifest, ActionValidationError, ActionValueType,
7    OutputSpec, ParameterType,
8};
9use crate::common::{is_valid_id, ValueType};
10
/// Registry of action primitives, keyed by manifest id.
///
/// Primitives are added through [`ActionRegistry::register`], which validates
/// each primitive's manifest before storing it.
pub struct ActionRegistry {
    /// Registered primitives, keyed by the `id` of their manifest.
    primitives: HashMap<String, Box<dyn ActionPrimitive>>,
}
14
15impl ActionRegistry {
16    pub fn new() -> Self {
17        Self {
18            primitives: HashMap::new(),
19        }
20    }
21
22    pub fn validate_manifest(
23        manifest: &ActionPrimitiveManifest,
24    ) -> Result<(), ActionValidationError> {
25        if !is_valid_id(&manifest.id) {
26            return Err(ActionValidationError::InvalidId {
27                id: manifest.id.clone(),
28            });
29        }
30
31        if Version::parse(&manifest.version).is_err() {
32            return Err(ActionValidationError::InvalidVersion {
33                version: manifest.version.clone(),
34            });
35        }
36
37        if manifest.kind != ActionKind::Action {
38            return Err(ActionValidationError::WrongKind {
39                expected: ActionKind::Action,
40                got: manifest.kind.clone(),
41            });
42        }
43
44        let mut seen_inputs: HashMap<&str, usize> = HashMap::new();
45        for (index, input) in manifest.inputs.iter().enumerate() {
46            if let Some(&first_index) = seen_inputs.get(input.name.as_str()) {
47                return Err(ActionValidationError::DuplicateInput {
48                    name: input.name.clone(),
49                    first_index,
50                    second_index: index,
51                });
52            }
53            seen_inputs.insert(&input.name, index);
54        }
55
56        for parameter in &manifest.parameters {
57            if let Some(default) = &parameter.default {
58                let got = default.value_type();
59                if got != parameter.value_type {
60                    return Err(ActionValidationError::InvalidParameterType {
61                        parameter: parameter.name.clone(),
62                        expected: parameter.value_type.clone(),
63                        got,
64                    });
65                }
66            }
67        }
68
69        if !manifest.side_effects {
70            return Err(ActionValidationError::SideEffectsRequired);
71        }
72
73        if manifest.execution.retryable {
74            return Err(ActionValidationError::RetryNotAllowed);
75        }
76
77        if !manifest.execution.deterministic {
78            return Err(ActionValidationError::NonDeterministicExecution);
79        }
80
81        if manifest.state.allowed {
82            return Err(ActionValidationError::StateNotAllowed);
83        }
84
85        if !manifest
86            .inputs
87            .iter()
88            .any(|input| input.value_type == ActionValueType::Event)
89        {
90            return Err(ActionValidationError::EventInputRequired);
91        }
92
93        let mut seen_writes: HashMap<&str, usize> = HashMap::new();
94        for (index, write) in manifest.effects.writes.iter().enumerate() {
95            if let Some(&first_index) = seen_writes.get(write.name.as_str()) {
96                return Err(ActionValidationError::DuplicateWriteName {
97                    name: write.name.clone(),
98                    first_index,
99                    second_index: index,
100                });
101            }
102            seen_writes.insert(&write.name, index);
103
104            if !matches!(
105                write.value_type,
106                ValueType::Number | ValueType::Series | ValueType::Bool | ValueType::String
107            ) {
108                return Err(ActionValidationError::InvalidWriteType {
109                    name: write.name.clone(),
110                    got: write.value_type.clone(),
111                });
112            }
113
114            // ACT-20/ACT-21: Validate $key references in write specs.
115            if let Some(param_name) = write.name.strip_prefix('$') {
116                let found = manifest.parameters.iter().find(|p| p.name == param_name);
117                match found {
118                    None => {
119                        return Err(ActionValidationError::UnboundWriteKeyReference {
120                            name: write.name.clone(),
121                            referenced_param: param_name.to_string(),
122                        });
123                    }
124                    Some(p) if p.value_type != ParameterType::String => {
125                        return Err(ActionValidationError::WriteKeyReferenceNotString {
126                            name: write.name.clone(),
127                            referenced_param: param_name.to_string(),
128                        });
129                    }
130                    _ => {}
131                }
132            }
133
134            // ACT-22: from_input must reference a declared input.
135            // ACT-23: Referenced input must be scalar (not Event) and type-compatible.
136            {
137                let input = manifest.inputs.iter().find(|i| i.name == write.from_input);
138                match input {
139                    None => {
140                        return Err(ActionValidationError::WriteFromInputNotFound {
141                            write_name: write.name.clone(),
142                            from_input: write.from_input.clone(),
143                        });
144                    }
145                    Some(inp)
146                        if !value_type_matches_action_input(&write.value_type, &inp.value_type) =>
147                    {
148                        return Err(ActionValidationError::WriteFromInputTypeMismatch {
149                            write_name: write.name.clone(),
150                            from_input: write.from_input.clone(),
151                            expected: write.value_type.clone(),
152                            found: inp.value_type.clone(),
153                        });
154                    }
155                    Some(_) => {}
156                }
157            }
158        }
159
160        let mut seen_intents: HashMap<&str, usize> = HashMap::new();
161        for (intent_index, intent) in manifest.effects.intents.iter().enumerate() {
162            if let Some(&first_index) = seen_intents.get(intent.name.as_str()) {
163                return Err(ActionValidationError::DuplicateIntentName {
164                    name: intent.name.clone(),
165                    first_index,
166                    second_index: intent_index,
167                });
168            }
169            seen_intents.insert(&intent.name, intent_index);
170
171            let mut seen_fields: HashMap<&str, usize> = HashMap::new();
172            let mut field_types: HashMap<&str, ValueType> = HashMap::new();
173
174            for (field_index, field) in intent.fields.iter().enumerate() {
175                if let Some(&first_index) = seen_fields.get(field.name.as_str()) {
176                    return Err(ActionValidationError::DuplicateIntentFieldName {
177                        intent_name: intent.name.clone(),
178                        field_name: field.name.clone(),
179                        first_index,
180                        second_index: field_index,
181                    });
182                }
183                seen_fields.insert(&field.name, field_index);
184                field_types.insert(&field.name, field.value_type.clone());
185
186                match (field.from_input.as_ref(), field.from_param.as_ref()) {
187                    (Some(_), Some(_)) => {
188                        return Err(ActionValidationError::IntentFieldMultipleSources {
189                            intent_name: intent.name.clone(),
190                            field_name: field.name.clone(),
191                        });
192                    }
193                    (None, None) => {
194                        return Err(ActionValidationError::IntentFieldMissingSource {
195                            intent_name: intent.name.clone(),
196                            field_name: field.name.clone(),
197                        });
198                    }
199                    (Some(from_input), None) => {
200                        let input = manifest.inputs.iter().find(|i| i.name == *from_input);
201                        match input {
202                            None => {
203                                return Err(ActionValidationError::IntentFieldFromInputNotFound {
204                                    intent_name: intent.name.clone(),
205                                    field_name: field.name.clone(),
206                                    from_input: from_input.clone(),
207                                });
208                            }
209                            Some(inp)
210                                if !value_type_matches_action_input(
211                                    &field.value_type,
212                                    &inp.value_type,
213                                ) =>
214                            {
215                                return Err(
216                                    ActionValidationError::IntentFieldFromInputTypeMismatch {
217                                        intent_name: intent.name.clone(),
218                                        field_name: field.name.clone(),
219                                        from_input: from_input.clone(),
220                                        expected: field.value_type.clone(),
221                                        found: inp.value_type.clone(),
222                                    },
223                                );
224                            }
225                            Some(_) => {}
226                        }
227                    }
228                    (None, Some(from_param)) => {
229                        let parameter = manifest.parameters.iter().find(|p| p.name == *from_param);
230                        match parameter {
231                            None => {
232                                return Err(ActionValidationError::IntentFieldFromParamNotFound {
233                                    intent_name: intent.name.clone(),
234                                    field_name: field.name.clone(),
235                                    from_param: from_param.clone(),
236                                });
237                            }
238                            Some(param)
239                                if !value_type_matches_parameter(
240                                    &field.value_type,
241                                    &param.value_type,
242                                ) =>
243                            {
244                                return Err(
245                                    ActionValidationError::IntentFieldFromParamTypeMismatch {
246                                        intent_name: intent.name.clone(),
247                                        field_name: field.name.clone(),
248                                        from_param: from_param.clone(),
249                                        expected: field.value_type.clone(),
250                                        found: param.value_type.clone(),
251                                    },
252                                );
253                            }
254                            Some(_) => {}
255                        }
256                    }
257                }
258            }
259
260            for mirror in &intent.mirror_writes {
261                let Some(field_type) = field_types.get(mirror.from_field.as_str()) else {
262                    return Err(ActionValidationError::MirrorWriteFromFieldNotFound {
263                        intent_name: intent.name.clone(),
264                        write_name: mirror.name.clone(),
265                        from_field: mirror.from_field.clone(),
266                    });
267                };
268
269                if mirror.value_type != *field_type {
270                    return Err(ActionValidationError::MirrorWriteTypeMismatch {
271                        intent_name: intent.name.clone(),
272                        write_name: mirror.name.clone(),
273                        from_field: mirror.from_field.clone(),
274                        expected: mirror.value_type.clone(),
275                        found: field_type.clone(),
276                    });
277                }
278            }
279        }
280
281        Self::validate_outputs(&manifest.outputs)?;
282
283        Ok(())
284    }
285
286    fn validate_outputs(outputs: &[OutputSpec]) -> Result<(), ActionValidationError> {
287        if outputs.len() != 1 {
288            return Err(ActionValidationError::UndeclaredOutput {
289                primitive: "action".to_string(),
290                output: "expected exactly one outcome event".to_string(),
291            });
292        }
293
294        let index = 0;
295        let output = &outputs[index];
296        if output.name != "outcome" {
297            return Err(ActionValidationError::OutputNotOutcome {
298                name: output.name.clone(),
299                index,
300            });
301        }
302
303        if output.value_type != ActionValueType::Event {
304            return Err(ActionValidationError::InvalidOutputType {
305                output: output.name.clone(),
306                expected: ActionValueType::Event,
307                got: output.value_type.clone(),
308            });
309        }
310
311        Ok(())
312    }
313
314    pub fn register(
315        &mut self,
316        primitive: Box<dyn ActionPrimitive>,
317    ) -> Result<(), ActionValidationError> {
318        let manifest = primitive.manifest();
319
320        Self::validate_manifest(manifest)?;
321
322        if self.primitives.contains_key(&manifest.id) {
323            return Err(ActionValidationError::DuplicateId(manifest.id.clone()));
324        }
325
326        self.primitives.insert(manifest.id.clone(), primitive);
327        Ok(())
328    }
329
330    pub fn get(&self, id: &str) -> Option<&dyn ActionPrimitive> {
331        self.primitives.get(id).map(|b| b.as_ref())
332    }
333
334    pub fn keys(&self) -> Vec<(String, String)> {
335        let mut keys: Vec<(String, String)> = self
336            .primitives
337            .values()
338            .map(|primitive| {
339                let manifest = primitive.manifest();
340                (manifest.id.clone(), manifest.version.clone())
341            })
342            .collect();
343        keys.sort();
344        keys
345    }
346}
347
348impl Default for ActionRegistry {
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
354fn value_type_matches_action_input(expected: &ValueType, found: &ActionValueType) -> bool {
355    matches!(
356        (expected, found),
357        (ValueType::Number, ActionValueType::Number)
358            | (ValueType::Series, ActionValueType::Series)
359            | (ValueType::Bool, ActionValueType::Bool)
360            | (ValueType::String, ActionValueType::String)
361    )
362}
363
364fn value_type_matches_parameter(expected: &ValueType, found: &ParameterType) -> bool {
365    matches!(
366        (expected, found),
367        (ValueType::Number, ParameterType::Number)
368            | (ValueType::Bool, ParameterType::Bool)
369            | (ValueType::String, ParameterType::String)
370    )
371}
372
373#[cfg(test)]
374mod tests;