ergo_runtime/compute/
registry.rs1use std::collections::HashMap;
2
3use semver::Version;
4
5use crate::common::{is_valid_id, PrimitiveKind, ValidationError, ValueType};
6use crate::compute::{Cardinality, ComputePrimitive, ComputePrimitiveManifest};
7
/// Registry of compute primitives, keyed by the `id` field of each primitive's manifest.
///
/// Primitives are added through [`PrimitiveRegistry::register`], which validates the
/// manifest with [`PrimitiveRegistry::validate_manifest`] and rejects duplicate ids.
pub struct PrimitiveRegistry {
    /// Registered primitives, keyed by `primitive.manifest().id`.
    primitives: HashMap<String, Box<dyn ComputePrimitive>>,
}
11
12impl PrimitiveRegistry {
13 pub fn new() -> Self {
14 Self {
15 primitives: HashMap::new(),
16 }
17 }
18
19 pub fn validate_manifest(manifest: &ComputePrimitiveManifest) -> Result<(), ValidationError> {
20 if !is_valid_id(&manifest.id) {
21 return Err(ValidationError::InvalidId {
22 id: manifest.id.clone(),
23 });
24 }
25
26 if Version::parse(&manifest.version).is_err() {
27 return Err(ValidationError::InvalidVersion {
28 version: manifest.version.clone(),
29 });
30 }
31
32 if manifest.kind != PrimitiveKind::Compute {
33 return Err(ValidationError::WrongKind {
34 expected: PrimitiveKind::Compute,
35 got: manifest.kind.clone(),
36 });
37 }
38
39 if manifest.execution.cadence != crate::compute::Cadence::Continuous {
40 return Err(ValidationError::InvalidCadence {
41 primitive: manifest.id.clone(),
42 });
43 }
44
45 if manifest.side_effects {
46 return Err(ValidationError::SideEffectsNotAllowed);
47 }
48
49 if !manifest.execution.deterministic {
50 return Err(ValidationError::NonDeterministicExecution);
51 }
52
53 if manifest.errors.allowed && !manifest.errors.deterministic {
54 return Err(ValidationError::NonDeterministicErrors {
55 primitive: manifest.id.clone(),
56 });
57 }
58
59 if manifest.inputs.is_empty() {
61 return Err(ValidationError::NoInputsDeclared {
62 primitive: manifest.id.clone(),
63 });
64 }
65
66 if manifest.outputs.is_empty() {
67 return Err(ValidationError::NoOutputsDeclared {
68 primitive: manifest.id.clone(),
69 });
70 }
71
72 let mut seen_inputs: HashMap<&str, usize> = HashMap::new();
73 for (index, input) in manifest.inputs.iter().enumerate() {
74 if let Some(&first_index) = seen_inputs.get(input.name.as_str()) {
75 return Err(ValidationError::DuplicateInput {
76 name: input.name.clone(),
77 first_index,
78 second_index: index,
79 });
80 }
81 seen_inputs.insert(&input.name, index);
82
83 if input.value_type == ValueType::String {
84 return Err(ValidationError::InvalidInputType {
85 input: input.name.clone(),
86 expected: ValueType::Number,
87 got: input.value_type.clone(),
88 });
89 }
90
91 if input.cardinality != Cardinality::Single {
92 return Err(ValidationError::InvalidInputCardinality {
93 primitive: manifest.id.clone(),
94 input: input.name.clone(),
95 got: format!("{:?}", input.cardinality),
96 });
97 }
98 }
99
100 let mut seen_outputs: HashMap<&str, usize> = HashMap::new();
101 for (index, output) in manifest.outputs.iter().enumerate() {
102 if let Some(&first_index) = seen_outputs.get(output.name.as_str()) {
103 return Err(ValidationError::DuplicateOutput {
104 name: output.name.clone(),
105 first_index,
106 second_index: index,
107 });
108 }
109 seen_outputs.insert(&output.name, index);
110
111 match output.value_type {
112 ValueType::Number | ValueType::Series | ValueType::Bool | ValueType::String => {}
113 }
114 }
115
116 for parameter in &manifest.parameters {
117 if parameter.value_type == ValueType::Series
118 || parameter.value_type == ValueType::String
119 {
120 return Err(ValidationError::UnsupportedParameterType {
121 primitive: manifest.id.clone(),
122 version: manifest.version.clone(),
123 parameter: parameter.name.clone(),
124 got: parameter.value_type.clone(),
125 });
126 }
127
128 if let Some(default) = ¶meter.default {
129 let got = default.value_type();
130 if got != parameter.value_type {
131 return Err(ValidationError::InvalidParameterType {
132 parameter: parameter.name.clone(),
133 expected: parameter.value_type.clone(),
134 got,
135 });
136 }
137 }
138 }
139
140 if manifest.state.allowed && !manifest.state.resettable {
141 return Err(ValidationError::StateNotResettable {
142 primitive: manifest.id.clone(),
143 });
144 }
145
146 Ok(())
147 }
148
149 pub fn register(
150 &mut self,
151 primitive: Box<dyn ComputePrimitive>,
152 ) -> Result<(), ValidationError> {
153 let manifest = primitive.manifest();
154
155 Self::validate_manifest(manifest)?;
156
157 if self.primitives.contains_key(&manifest.id) {
158 return Err(ValidationError::DuplicateId(manifest.id.clone()));
159 }
160
161 self.primitives.insert(manifest.id.clone(), primitive);
162 Ok(())
163 }
164
165 pub fn get(&self, id: &str) -> Option<&dyn ComputePrimitive> {
166 self.primitives.get(id).map(|b| b.as_ref())
167 }
168
169 pub fn keys(&self) -> Vec<(String, String)> {
170 let mut keys: Vec<(String, String)> = self
171 .primitives
172 .values()
173 .map(|primitive| {
174 let manifest = primitive.manifest();
175 (manifest.id.clone(), manifest.version.clone())
176 })
177 .collect();
178 keys.sort();
179 keys
180 }
181}
182
183impl Default for PrimitiveRegistry {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
// Unit tests, compiled only for test builds. The module body lives out of line;
// under Rust's module path rules for a non-`mod.rs` file, that is `registry/tests.rs`
// next to this file.
#[cfg(test)]
mod tests;