1use crate::functions::properties::FunctionProperties;
8use std::collections::HashMap;
9
10pub trait FunctionFamilyExtension: Send + Sync {
21 fn family_name(&self) -> &'static str;
23
24 fn get_properties(&self) -> HashMap<String, FunctionProperties>;
26
27 fn has_function(&self, name: &str) -> bool;
29
30 fn version(&self) -> (u32, u32, u32) {
32 (1, 0, 0) }
34
35 fn dependencies(&self) -> Vec<&'static str> {
37 vec![] }
39}
40
41pub struct ExtensionRegistry {
50 extensions: HashMap<&'static str, Box<dyn FunctionFamilyExtension>>,
52
53 cached_properties: Option<HashMap<String, FunctionProperties>>,
55
56 cache_version: u64,
58}
59
60impl Default for ExtensionRegistry {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl ExtensionRegistry {
67 pub fn new() -> Self {
69 Self {
70 extensions: HashMap::with_capacity(16), cached_properties: None,
72 cache_version: 0,
73 }
74 }
75
76 pub fn register_extension(
101 &mut self,
102 extension: Box<dyn FunctionFamilyExtension>,
103 ) -> Result<(), ExtensionError> {
104 let family_name = extension.family_name();
105
106 if self.extensions.contains_key(family_name) {
108 return Err(ExtensionError::FamilyAlreadyRegistered(
109 family_name.to_owned(),
110 ));
111 }
112
113 for dep in extension.dependencies() {
115 if !self.extensions.contains_key(dep) {
116 return Err(ExtensionError::MissingDependency {
117 extension: family_name.to_owned(),
118 dependency: dep.to_owned(),
119 });
120 }
121 }
122
123 self.extensions.insert(family_name, extension);
125
126 self.cached_properties = None;
128 self.cache_version += 1;
129
130 Ok(())
131 }
132
133 pub fn get_all_properties(&mut self) -> &HashMap<String, FunctionProperties> {
138 if self.cached_properties.is_none() {
139 let mut combined = HashMap::with_capacity(256);
140
141 for extension in self.extensions.values() {
142 combined.extend(extension.get_properties());
143 }
144
145 self.cached_properties = Some(combined);
146 }
147
148 self.cached_properties.as_ref().unwrap()
149 }
150
151 pub fn has_function(&self, name: &str) -> bool {
153 self.extensions.values().any(|ext| ext.has_function(name))
154 }
155
156 pub fn registered_families(&self) -> Vec<&'static str> {
158 self.extensions.keys().copied().collect()
159 }
160
161 pub fn get_extension(&self, family_name: &str) -> Option<&dyn FunctionFamilyExtension> {
163 self.extensions.get(family_name).map(|ext| ext.as_ref())
164 }
165}
166
/// Errors produced while registering a function-family extension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtensionError {
    /// A family with this name is already registered.
    FamilyAlreadyRegistered(String),

    /// `extension` declares a dependency on `dependency`, which is not registered.
    MissingDependency {
        /// Family that was being registered.
        extension: String,
        /// Declared dependency that could not be found.
        dependency: String,
    },

    /// `extension` requires version `required` but `found` is available.
    ///
    /// NOTE(review): the registry in this module never constructs this variant;
    /// presumably it is produced elsewhere — confirm before relying on it.
    IncompatibleVersion {
        /// Family whose version requirement failed.
        extension: String,
        /// Required version as `(major, minor, patch)`.
        required: (u32, u32, u32),
        /// Version actually found as `(major, minor, patch)`.
        found: (u32, u32, u32),
    },
}

impl std::fmt::Display for ExtensionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FamilyAlreadyRegistered(name) => {
                write!(f, "Function family '{}' is already registered", name)
            }
            Self::MissingDependency {
                extension,
                dependency,
            } => write!(
                f,
                "Extension '{}' requires '{}' which is not registered",
                extension, dependency
            ),
            // Versions are rendered with `Debug`, e.g. "(1, 0, 0)".
            Self::IncompatibleVersion {
                extension,
                required,
                found,
            } => write!(
                f,
                "Extension '{}' requires version {:?} but found {:?}",
                extension, required, found
            ),
        }
    }
}

impl std::error::Error for ExtensionError {}
219
220pub trait FunctionValidator {
225 fn validate_mathematical_correctness(
227 &self,
228 name: &str,
229 test_points: &[(Vec<f64>, f64)],
230 ) -> ValidationResult;
231
232 fn validate_performance(&self, name: &str, benchmark_size: usize) -> ValidationResult;
234
235 fn validate_numerical_stability(&self, name: &str, edge_cases: &[f64]) -> ValidationResult;
237}
238
/// Outcome of a single validation check.
#[derive(Debug, Clone, PartialEq)]
pub struct ValidationResult {
    /// Whether the check passed.
    pub passed: bool,

    /// Human-readable description of the check.
    pub report: String,

    /// Optional quantitative measurements attached to the check.
    pub metrics: Option<ValidationMetrics>,
}

/// Quantitative measurements attached to a [`ValidationResult`].
///
/// Only `PartialEq` (not `Eq`) is derived because the fields include `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct ValidationMetrics {
    /// Throughput figure; the name suggests operations per second.
    pub ops_per_second: f64,

    /// Memory usage. NOTE(review): unit is not stated here (bytes?) — confirm.
    pub memory_usage: usize,

    /// Accuracy figure. NOTE(review): presumably an error bound/tolerance — confirm.
    pub accuracy: f64,
}
264
/// Stateless validator providing the default [`FunctionValidator`] behaviour.
///
/// It has no fields, so it is freely `Copy` and `Default`-constructible.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DefaultValidator;
267
268impl FunctionValidator for DefaultValidator {
269 fn validate_mathematical_correctness(
270 &self,
271 name: &str,
272 test_points: &[(Vec<f64>, f64)],
273 ) -> ValidationResult {
274 ValidationResult {
276 passed: true,
277 report: format!(
278 "Mathematical correctness validated for {} with {} test points",
279 name,
280 test_points.len()
281 ),
282 metrics: None,
283 }
284 }
285
286 fn validate_performance(&self, name: &str, benchmark_size: usize) -> ValidationResult {
287 ValidationResult {
289 passed: true,
290 report: format!(
291 "Performance validated for {} with benchmark size {}",
292 name, benchmark_size
293 ),
294 metrics: Some(ValidationMetrics {
295 ops_per_second: 1_000_000.0, memory_usage: 1024, accuracy: 1e-15, }),
299 }
300 }
301
302 fn validate_numerical_stability(&self, name: &str, edge_cases: &[f64]) -> ValidationResult {
303 ValidationResult {
305 passed: true,
306 report: format!(
307 "Numerical stability validated for {} with {} edge cases",
308 name,
309 edge_cases.len()
310 ),
311 metrics: None,
312 }
313 }
314}
315
/// Declares a unit struct and implements `FunctionFamilyExtension` for it.
///
/// Syntax (trailing commas are accepted in both lists, and either list may be empty;
/// outer attributes such as doc comments may precede the struct name):
///
/// ```ignore
/// impl_function_family!(
///     /// Trigonometric functions.
///     TrigFamily,
///     family_name = "trig",
///     version = (1, 2, 0),
///     dependencies = ["core"],
///     functions = {
///         "sin" => sin_properties(),
///         "cos" => cos_properties(),
///     }
/// );
/// ```
///
/// Each `$func_props` expression is re-evaluated on every `get_properties` call. If a
/// function name is listed twice, the later entry's properties win.
#[macro_export]
macro_rules! impl_function_family {
    (
        $(#[$meta:meta])*
        $name:ident,
        family_name = $family_name:literal,
        version = ($major:literal, $minor:literal, $patch:literal),
        dependencies = [$($dep:literal),* $(,)?],
        functions = {
            $(
                $func_name:literal => $func_props:expr
            ),* $(,)?
        }
    ) => {
        $(#[$meta])*
        pub struct $name;

        impl $crate::functions::extensibility::FunctionFamilyExtension for $name {
            fn family_name(&self) -> &'static str {
                $family_name
            }

            fn version(&self) -> (u32, u32, u32) {
                ($major, $minor, $patch)
            }

            fn dependencies(&self) -> Vec<&'static str> {
                vec![$($dep),*]
            }

            fn get_properties(
                &self,
            ) -> std::collections::HashMap<String, $crate::functions::properties::FunctionProperties> {
                // `mut` is unused when the family declares no functions.
                #[allow(unused_mut)]
                let mut props = std::collections::HashMap::new();
                $(
                    props.insert($func_name.to_string(), $func_props);
                )*
                props
            }

            fn has_function(&self, name: &str) -> bool {
                // A slice lookup (instead of `matches!`) stays valid when the
                // function list is empty; `matches!(name, )` does not compile.
                const NAMES: &[&str] = &[$($func_name),*];
                NAMES.contains(&name)
            }
        }
    };
}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Extension without function data, used to exercise the registry's bookkeeping.
    struct StubFamily {
        name: &'static str,
        deps: Vec<&'static str>,
        functions: Vec<&'static str>,
    }

    impl FunctionFamilyExtension for StubFamily {
        fn family_name(&self) -> &'static str {
            self.name
        }

        fn get_properties(&self) -> HashMap<String, FunctionProperties> {
            // Empty so these tests never need to construct `FunctionProperties`.
            HashMap::new()
        }

        fn has_function(&self, name: &str) -> bool {
            self.functions.iter().any(|f| *f == name)
        }

        fn dependencies(&self) -> Vec<&'static str> {
            self.deps.clone()
        }
    }

    /// Boxes a [`StubFamily`] ready for registration.
    fn stub(
        name: &'static str,
        deps: &[&'static str],
        functions: &[&'static str],
    ) -> Box<dyn FunctionFamilyExtension> {
        Box::new(StubFamily {
            name,
            deps: deps.to_vec(),
            functions: functions.to_vec(),
        })
    }

    #[test]
    fn test_extension_registry() {
        let registry = ExtensionRegistry::new();
        assert_eq!(registry.registered_families().len(), 0);

        assert!(!registry.has_function("nonexistent"));
        assert!(registry.get_extension("nonexistent").is_none());
    }

    #[test]
    fn test_register_with_dependencies() {
        let mut registry = ExtensionRegistry::new();
        registry
            .register_extension(stub("core", &[], &["exp"]))
            .expect("core has no dependencies");
        registry
            .register_extension(stub("trig", &["core"], &["sin"]))
            .expect("core is already registered");

        // Sort because the ordering contract of `registered_families` is not pinned here.
        let mut families = registry.registered_families();
        families.sort_unstable();
        assert_eq!(families, ["core", "trig"]);

        assert!(registry.has_function("sin"));
        assert!(registry.has_function("exp"));
        assert!(!registry.has_function("tan"));
        assert_eq!(
            registry.get_extension("trig").map(|ext| ext.family_name()),
            Some("trig")
        );
        assert!(registry.get_all_properties().is_empty());
    }

    #[test]
    fn test_duplicate_family_is_rejected() {
        let mut registry = ExtensionRegistry::new();
        registry.register_extension(stub("core", &[], &[])).unwrap();

        let err = registry
            .register_extension(stub("core", &[], &[]))
            .unwrap_err();
        assert_eq!(err.to_string(), "Function family 'core' is already registered");
        assert_eq!(registry.registered_families().len(), 1);
    }

    #[test]
    fn test_missing_dependency_is_rejected() {
        let mut registry = ExtensionRegistry::new();

        let err = registry
            .register_extension(stub("trig", &["core"], &["sin"]))
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Extension 'trig' requires 'core' which is not registered"
        );
        // A rejected registration must leave the registry untouched.
        assert!(registry.registered_families().is_empty());
        assert!(!registry.has_function("sin"));
    }

    #[test]
    fn test_validation_result() {
        let result = ValidationResult {
            passed: true,
            report: "Test validation".to_string(),
            metrics: None,
        };

        assert!(result.passed);
        assert_eq!(result.report, "Test validation");
    }

    #[test]
    fn test_default_validator_performance_metrics() {
        let result = DefaultValidator.validate_performance("sin", 10);
        assert!(result.passed);
        assert_eq!(
            result.report,
            "Performance validated for sin with benchmark size 10"
        );
        assert!(result.metrics.is_some());
    }
}