Skip to main content

ergo_runtime/compute/
registry.rs

1use std::collections::HashMap;
2
3use semver::Version;
4
5use crate::common::{is_valid_id, PrimitiveKind, ValidationError, ValueType};
6use crate::compute::{Cardinality, ComputePrimitive, ComputePrimitiveManifest};
7
8pub struct PrimitiveRegistry {
9    primitives: HashMap<String, Box<dyn ComputePrimitive>>,
10}
11
12impl PrimitiveRegistry {
13    pub fn new() -> Self {
14        Self {
15            primitives: HashMap::new(),
16        }
17    }
18
19    pub fn validate_manifest(manifest: &ComputePrimitiveManifest) -> Result<(), ValidationError> {
20        if !is_valid_id(&manifest.id) {
21            return Err(ValidationError::InvalidId {
22                id: manifest.id.clone(),
23            });
24        }
25
26        if Version::parse(&manifest.version).is_err() {
27            return Err(ValidationError::InvalidVersion {
28                version: manifest.version.clone(),
29            });
30        }
31
32        if manifest.kind != PrimitiveKind::Compute {
33            return Err(ValidationError::WrongKind {
34                expected: PrimitiveKind::Compute,
35                got: manifest.kind.clone(),
36            });
37        }
38
39        if manifest.execution.cadence != crate::compute::Cadence::Continuous {
40            return Err(ValidationError::InvalidCadence {
41                primitive: manifest.id.clone(),
42            });
43        }
44
45        if manifest.side_effects {
46            return Err(ValidationError::SideEffectsNotAllowed);
47        }
48
49        if !manifest.execution.deterministic {
50            return Err(ValidationError::NonDeterministicExecution);
51        }
52
53        if manifest.errors.allowed && !manifest.errors.deterministic {
54            return Err(ValidationError::NonDeterministicErrors {
55                primitive: manifest.id.clone(),
56            });
57        }
58
59        // X.7: Compute primitives must declare at least one input.
60        if manifest.inputs.is_empty() {
61            return Err(ValidationError::NoInputsDeclared {
62                primitive: manifest.id.clone(),
63            });
64        }
65
66        if manifest.outputs.is_empty() {
67            return Err(ValidationError::NoOutputsDeclared {
68                primitive: manifest.id.clone(),
69            });
70        }
71
72        let mut seen_inputs: HashMap<&str, usize> = HashMap::new();
73        for (index, input) in manifest.inputs.iter().enumerate() {
74            if let Some(&first_index) = seen_inputs.get(input.name.as_str()) {
75                return Err(ValidationError::DuplicateInput {
76                    name: input.name.clone(),
77                    first_index,
78                    second_index: index,
79                });
80            }
81            seen_inputs.insert(&input.name, index);
82
83            if input.value_type == ValueType::String {
84                return Err(ValidationError::InvalidInputType {
85                    input: input.name.clone(),
86                    expected: ValueType::Number,
87                    got: input.value_type.clone(),
88                });
89            }
90
91            if input.cardinality != Cardinality::Single {
92                return Err(ValidationError::InvalidInputCardinality {
93                    primitive: manifest.id.clone(),
94                    input: input.name.clone(),
95                    got: format!("{:?}", input.cardinality),
96                });
97            }
98        }
99
100        let mut seen_outputs: HashMap<&str, usize> = HashMap::new();
101        for (index, output) in manifest.outputs.iter().enumerate() {
102            if let Some(&first_index) = seen_outputs.get(output.name.as_str()) {
103                return Err(ValidationError::DuplicateOutput {
104                    name: output.name.clone(),
105                    first_index,
106                    second_index: index,
107                });
108            }
109            seen_outputs.insert(&output.name, index);
110
111            match output.value_type {
112                ValueType::Number | ValueType::Series | ValueType::Bool | ValueType::String => {}
113            }
114        }
115
116        for parameter in &manifest.parameters {
117            if parameter.value_type == ValueType::Series
118                || parameter.value_type == ValueType::String
119            {
120                return Err(ValidationError::UnsupportedParameterType {
121                    primitive: manifest.id.clone(),
122                    version: manifest.version.clone(),
123                    parameter: parameter.name.clone(),
124                    got: parameter.value_type.clone(),
125                });
126            }
127
128            if let Some(default) = &parameter.default {
129                let got = default.value_type();
130                if got != parameter.value_type {
131                    return Err(ValidationError::InvalidParameterType {
132                        parameter: parameter.name.clone(),
133                        expected: parameter.value_type.clone(),
134                        got,
135                    });
136                }
137            }
138        }
139
140        if manifest.state.allowed && !manifest.state.resettable {
141            return Err(ValidationError::StateNotResettable {
142                primitive: manifest.id.clone(),
143            });
144        }
145
146        Ok(())
147    }
148
149    pub fn register(
150        &mut self,
151        primitive: Box<dyn ComputePrimitive>,
152    ) -> Result<(), ValidationError> {
153        let manifest = primitive.manifest();
154
155        Self::validate_manifest(manifest)?;
156
157        if self.primitives.contains_key(&manifest.id) {
158            return Err(ValidationError::DuplicateId(manifest.id.clone()));
159        }
160
161        self.primitives.insert(manifest.id.clone(), primitive);
162        Ok(())
163    }
164
165    pub fn get(&self, id: &str) -> Option<&dyn ComputePrimitive> {
166        self.primitives.get(id).map(|b| b.as_ref())
167    }
168
169    pub fn keys(&self) -> Vec<(String, String)> {
170        let mut keys: Vec<(String, String)> = self
171            .primitives
172            .values()
173            .map(|primitive| {
174                let manifest = primitive.manifest();
175                (manifest.id.clone(), manifest.version.clone())
176            })
177            .collect();
178        keys.sort();
179        keys
180    }
181}
182
183impl Default for PrimitiveRegistry {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
// Unit tests for this module, declared in a separate file and compiled only for test builds.
#[cfg(test)]
mod tests;