1use std::collections::HashMap;
2
3use semver::Version;
4
5use super::{
6 ActionKind, ActionPrimitive, ActionPrimitiveManifest, ActionValidationError, ActionValueType,
7 OutputSpec, ParameterType,
8};
9use crate::common::{is_valid_id, ValueType};
10
11pub struct ActionRegistry {
12 primitives: HashMap<String, Box<dyn ActionPrimitive>>,
13}
14
15impl ActionRegistry {
16 pub fn new() -> Self {
17 Self {
18 primitives: HashMap::new(),
19 }
20 }
21
22 pub fn validate_manifest(
23 manifest: &ActionPrimitiveManifest,
24 ) -> Result<(), ActionValidationError> {
25 if !is_valid_id(&manifest.id) {
26 return Err(ActionValidationError::InvalidId {
27 id: manifest.id.clone(),
28 });
29 }
30
31 if Version::parse(&manifest.version).is_err() {
32 return Err(ActionValidationError::InvalidVersion {
33 version: manifest.version.clone(),
34 });
35 }
36
37 if manifest.kind != ActionKind::Action {
38 return Err(ActionValidationError::WrongKind {
39 expected: ActionKind::Action,
40 got: manifest.kind.clone(),
41 });
42 }
43
44 let mut seen_inputs: HashMap<&str, usize> = HashMap::new();
45 for (index, input) in manifest.inputs.iter().enumerate() {
46 if let Some(&first_index) = seen_inputs.get(input.name.as_str()) {
47 return Err(ActionValidationError::DuplicateInput {
48 name: input.name.clone(),
49 first_index,
50 second_index: index,
51 });
52 }
53 seen_inputs.insert(&input.name, index);
54 }
55
56 for parameter in &manifest.parameters {
57 if let Some(default) = ¶meter.default {
58 let got = default.value_type();
59 if got != parameter.value_type {
60 return Err(ActionValidationError::InvalidParameterType {
61 parameter: parameter.name.clone(),
62 expected: parameter.value_type.clone(),
63 got,
64 });
65 }
66 }
67 }
68
69 if !manifest.side_effects {
70 return Err(ActionValidationError::SideEffectsRequired);
71 }
72
73 if manifest.execution.retryable {
74 return Err(ActionValidationError::RetryNotAllowed);
75 }
76
77 if !manifest.execution.deterministic {
78 return Err(ActionValidationError::NonDeterministicExecution);
79 }
80
81 if manifest.state.allowed {
82 return Err(ActionValidationError::StateNotAllowed);
83 }
84
85 if !manifest
86 .inputs
87 .iter()
88 .any(|input| input.value_type == ActionValueType::Event)
89 {
90 return Err(ActionValidationError::EventInputRequired);
91 }
92
93 let mut seen_writes: HashMap<&str, usize> = HashMap::new();
94 for (index, write) in manifest.effects.writes.iter().enumerate() {
95 if let Some(&first_index) = seen_writes.get(write.name.as_str()) {
96 return Err(ActionValidationError::DuplicateWriteName {
97 name: write.name.clone(),
98 first_index,
99 second_index: index,
100 });
101 }
102 seen_writes.insert(&write.name, index);
103
104 if !matches!(
105 write.value_type,
106 ValueType::Number | ValueType::Series | ValueType::Bool | ValueType::String
107 ) {
108 return Err(ActionValidationError::InvalidWriteType {
109 name: write.name.clone(),
110 got: write.value_type.clone(),
111 });
112 }
113
114 if let Some(param_name) = write.name.strip_prefix('$') {
116 let found = manifest.parameters.iter().find(|p| p.name == param_name);
117 match found {
118 None => {
119 return Err(ActionValidationError::UnboundWriteKeyReference {
120 name: write.name.clone(),
121 referenced_param: param_name.to_string(),
122 });
123 }
124 Some(p) if p.value_type != ParameterType::String => {
125 return Err(ActionValidationError::WriteKeyReferenceNotString {
126 name: write.name.clone(),
127 referenced_param: param_name.to_string(),
128 });
129 }
130 _ => {}
131 }
132 }
133
134 {
137 let input = manifest.inputs.iter().find(|i| i.name == write.from_input);
138 match input {
139 None => {
140 return Err(ActionValidationError::WriteFromInputNotFound {
141 write_name: write.name.clone(),
142 from_input: write.from_input.clone(),
143 });
144 }
145 Some(inp)
146 if !value_type_matches_action_input(&write.value_type, &inp.value_type) =>
147 {
148 return Err(ActionValidationError::WriteFromInputTypeMismatch {
149 write_name: write.name.clone(),
150 from_input: write.from_input.clone(),
151 expected: write.value_type.clone(),
152 found: inp.value_type.clone(),
153 });
154 }
155 Some(_) => {}
156 }
157 }
158 }
159
160 let mut seen_intents: HashMap<&str, usize> = HashMap::new();
161 for (intent_index, intent) in manifest.effects.intents.iter().enumerate() {
162 if let Some(&first_index) = seen_intents.get(intent.name.as_str()) {
163 return Err(ActionValidationError::DuplicateIntentName {
164 name: intent.name.clone(),
165 first_index,
166 second_index: intent_index,
167 });
168 }
169 seen_intents.insert(&intent.name, intent_index);
170
171 let mut seen_fields: HashMap<&str, usize> = HashMap::new();
172 let mut field_types: HashMap<&str, ValueType> = HashMap::new();
173
174 for (field_index, field) in intent.fields.iter().enumerate() {
175 if let Some(&first_index) = seen_fields.get(field.name.as_str()) {
176 return Err(ActionValidationError::DuplicateIntentFieldName {
177 intent_name: intent.name.clone(),
178 field_name: field.name.clone(),
179 first_index,
180 second_index: field_index,
181 });
182 }
183 seen_fields.insert(&field.name, field_index);
184 field_types.insert(&field.name, field.value_type.clone());
185
186 match (field.from_input.as_ref(), field.from_param.as_ref()) {
187 (Some(_), Some(_)) => {
188 return Err(ActionValidationError::IntentFieldMultipleSources {
189 intent_name: intent.name.clone(),
190 field_name: field.name.clone(),
191 });
192 }
193 (None, None) => {
194 return Err(ActionValidationError::IntentFieldMissingSource {
195 intent_name: intent.name.clone(),
196 field_name: field.name.clone(),
197 });
198 }
199 (Some(from_input), None) => {
200 let input = manifest.inputs.iter().find(|i| i.name == *from_input);
201 match input {
202 None => {
203 return Err(ActionValidationError::IntentFieldFromInputNotFound {
204 intent_name: intent.name.clone(),
205 field_name: field.name.clone(),
206 from_input: from_input.clone(),
207 });
208 }
209 Some(inp)
210 if !value_type_matches_action_input(
211 &field.value_type,
212 &inp.value_type,
213 ) =>
214 {
215 return Err(
216 ActionValidationError::IntentFieldFromInputTypeMismatch {
217 intent_name: intent.name.clone(),
218 field_name: field.name.clone(),
219 from_input: from_input.clone(),
220 expected: field.value_type.clone(),
221 found: inp.value_type.clone(),
222 },
223 );
224 }
225 Some(_) => {}
226 }
227 }
228 (None, Some(from_param)) => {
229 let parameter = manifest.parameters.iter().find(|p| p.name == *from_param);
230 match parameter {
231 None => {
232 return Err(ActionValidationError::IntentFieldFromParamNotFound {
233 intent_name: intent.name.clone(),
234 field_name: field.name.clone(),
235 from_param: from_param.clone(),
236 });
237 }
238 Some(param)
239 if !value_type_matches_parameter(
240 &field.value_type,
241 ¶m.value_type,
242 ) =>
243 {
244 return Err(
245 ActionValidationError::IntentFieldFromParamTypeMismatch {
246 intent_name: intent.name.clone(),
247 field_name: field.name.clone(),
248 from_param: from_param.clone(),
249 expected: field.value_type.clone(),
250 found: param.value_type.clone(),
251 },
252 );
253 }
254 Some(_) => {}
255 }
256 }
257 }
258 }
259
260 for mirror in &intent.mirror_writes {
261 let Some(field_type) = field_types.get(mirror.from_field.as_str()) else {
262 return Err(ActionValidationError::MirrorWriteFromFieldNotFound {
263 intent_name: intent.name.clone(),
264 write_name: mirror.name.clone(),
265 from_field: mirror.from_field.clone(),
266 });
267 };
268
269 if mirror.value_type != *field_type {
270 return Err(ActionValidationError::MirrorWriteTypeMismatch {
271 intent_name: intent.name.clone(),
272 write_name: mirror.name.clone(),
273 from_field: mirror.from_field.clone(),
274 expected: mirror.value_type.clone(),
275 found: field_type.clone(),
276 });
277 }
278 }
279 }
280
281 Self::validate_outputs(&manifest.outputs)?;
282
283 Ok(())
284 }
285
286 fn validate_outputs(outputs: &[OutputSpec]) -> Result<(), ActionValidationError> {
287 if outputs.len() != 1 {
288 return Err(ActionValidationError::UndeclaredOutput {
289 primitive: "action".to_string(),
290 output: "expected exactly one outcome event".to_string(),
291 });
292 }
293
294 let index = 0;
295 let output = &outputs[index];
296 if output.name != "outcome" {
297 return Err(ActionValidationError::OutputNotOutcome {
298 name: output.name.clone(),
299 index,
300 });
301 }
302
303 if output.value_type != ActionValueType::Event {
304 return Err(ActionValidationError::InvalidOutputType {
305 output: output.name.clone(),
306 expected: ActionValueType::Event,
307 got: output.value_type.clone(),
308 });
309 }
310
311 Ok(())
312 }
313
314 pub fn register(
315 &mut self,
316 primitive: Box<dyn ActionPrimitive>,
317 ) -> Result<(), ActionValidationError> {
318 let manifest = primitive.manifest();
319
320 Self::validate_manifest(manifest)?;
321
322 if self.primitives.contains_key(&manifest.id) {
323 return Err(ActionValidationError::DuplicateId(manifest.id.clone()));
324 }
325
326 self.primitives.insert(manifest.id.clone(), primitive);
327 Ok(())
328 }
329
330 pub fn get(&self, id: &str) -> Option<&dyn ActionPrimitive> {
331 self.primitives.get(id).map(|b| b.as_ref())
332 }
333
334 pub fn keys(&self) -> Vec<(String, String)> {
335 let mut keys: Vec<(String, String)> = self
336 .primitives
337 .values()
338 .map(|primitive| {
339 let manifest = primitive.manifest();
340 (manifest.id.clone(), manifest.version.clone())
341 })
342 .collect();
343 keys.sort();
344 keys
345 }
346}
347
348impl Default for ActionRegistry {
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
354fn value_type_matches_action_input(expected: &ValueType, found: &ActionValueType) -> bool {
355 matches!(
356 (expected, found),
357 (ValueType::Number, ActionValueType::Number)
358 | (ValueType::Series, ActionValueType::Series)
359 | (ValueType::Bool, ActionValueType::Bool)
360 | (ValueType::String, ActionValueType::String)
361 )
362}
363
364fn value_type_matches_parameter(expected: &ValueType, found: &ParameterType) -> bool {
365 matches!(
366 (expected, found),
367 (ValueType::Number, ParameterType::Number)
368 | (ValueType::Bool, ParameterType::Bool)
369 | (ValueType::String, ParameterType::String)
370 )
371}
372
// Tests for this module. They are compiled only in test builds and are loaded
// from a separate `tests` module file.
#[cfg(test)]
mod tests;