// ergo_runtime/source/registry.rs

1use std::collections::HashMap;
2
3use semver::Version;
4
5use super::{Cadence, SourceKind, SourcePrimitive, SourcePrimitiveManifest, SourceValidationError};
6use crate::common::{is_valid_id, ValueType};
7
8pub struct SourceRegistry {
9    primitives: HashMap<String, Box<dyn SourcePrimitive>>,
10}
11
12impl SourceRegistry {
13    pub fn new() -> Self {
14        Self {
15            primitives: HashMap::new(),
16        }
17    }
18
19    pub fn validate_manifest(
20        manifest: &SourcePrimitiveManifest,
21    ) -> Result<(), SourceValidationError> {
22        if !is_valid_id(&manifest.id) {
23            return Err(SourceValidationError::InvalidId {
24                id: manifest.id.clone(),
25            });
26        }
27
28        if Version::parse(&manifest.version).is_err() {
29            return Err(SourceValidationError::InvalidVersion {
30                version: manifest.version.clone(),
31            });
32        }
33
34        if manifest.kind != SourceKind::Source {
35            return Err(SourceValidationError::WrongKind {
36                expected: SourceKind::Source,
37                got: manifest.kind.clone(),
38            });
39        }
40
41        if !manifest.inputs.is_empty() {
42            return Err(SourceValidationError::InputsNotAllowed);
43        }
44
45        if manifest.outputs.is_empty() {
46            return Err(SourceValidationError::OutputsRequired);
47        }
48
49        let mut seen: HashMap<&str, usize> = HashMap::new();
50        for (index, output) in manifest.outputs.iter().enumerate() {
51            if let Some(&first_index) = seen.get(output.name.as_str()) {
52                return Err(SourceValidationError::DuplicateOutput {
53                    name: output.name.clone(),
54                    first_index,
55                    second_index: index,
56                });
57            }
58            seen.insert(&output.name, index);
59        }
60
61        for output in &manifest.outputs {
62            match output.value_type {
63                ValueType::Number | ValueType::Series | ValueType::Bool | ValueType::String => {}
64            }
65        }
66
67        for parameter in &manifest.parameters {
68            if let Some(default) = &parameter.default {
69                let got = default.value_type();
70                if got != parameter.value_type {
71                    return Err(SourceValidationError::InvalidParameterType {
72                        parameter: parameter.name.clone(),
73                        expected: parameter.value_type.clone(),
74                        got,
75                    });
76                }
77            }
78        }
79
80        // SRC-16/SRC-17: Validate $key references in context requirements.
81        for req in &manifest.requires.context {
82            if let Some(param_name) = req.name.strip_prefix('$') {
83                let found = manifest.parameters.iter().find(|p| p.name == param_name);
84                match found {
85                    None => {
86                        return Err(SourceValidationError::UnboundContextKeyReference {
87                            name: req.name.clone(),
88                            referenced_param: param_name.to_string(),
89                        });
90                    }
91                    Some(p) if p.value_type != super::ParameterType::String => {
92                        return Err(SourceValidationError::ContextKeyReferenceNotString {
93                            name: req.name.clone(),
94                            referenced_param: param_name.to_string(),
95                        });
96                    }
97                    _ => {}
98                }
99            }
100        }
101
102        if manifest.side_effects {
103            return Err(SourceValidationError::SideEffectsNotAllowed);
104        }
105
106        if !manifest.execution.deterministic {
107            return Err(SourceValidationError::NonDeterministicExecution);
108        }
109
110        if manifest.execution.cadence != Cadence::Continuous {
111            return Err(SourceValidationError::InvalidCadence);
112        }
113
114        if manifest.state.allowed {
115            return Err(SourceValidationError::StateNotAllowed);
116        }
117
118        Ok(())
119    }
120
121    pub fn register(
122        &mut self,
123        primitive: Box<dyn SourcePrimitive>,
124    ) -> Result<(), SourceValidationError> {
125        let manifest = primitive.manifest();
126
127        Self::validate_manifest(manifest)?;
128
129        if self.primitives.contains_key(&manifest.id) {
130            return Err(SourceValidationError::DuplicateId(manifest.id.clone()));
131        }
132
133        self.primitives.insert(manifest.id.clone(), primitive);
134        Ok(())
135    }
136
137    pub fn get(&self, id: &str) -> Option<&dyn SourcePrimitive> {
138        self.primitives.get(id).map(|b| b.as_ref())
139    }
140
141    pub fn keys(&self) -> Vec<(String, String)> {
142        let mut keys: Vec<(String, String)> = self
143            .primitives
144            .values()
145            .map(|primitive| {
146                let manifest = primitive.manifest();
147                (manifest.id.clone(), manifest.version.clone())
148            })
149            .collect();
150        keys.sort();
151        keys
152    }
153}
154
155impl Default for SourceRegistry {
156    fn default() -> Self {
157        Self::new()
158    }
159}