use crate::functions::properties::FunctionProperties;
use std::collections::HashMap;
/// A pluggable family of functions that can be registered with an `ExtensionRegistry`.
///
/// The `Send + Sync` supertrait bound allows boxed implementors to be moved and shared
/// across threads.
pub trait FunctionFamilyExtension: Send + Sync {
    /// Unique name of this family; the registry uses it as the lookup key.
    fn family_name(&self) -> &'static str;

    /// Properties of every function this family provides, keyed by function name.
    fn get_properties(&self) -> HashMap<String, FunctionProperties>;

    /// Returns `true` if this family provides a function called `name`.
    fn has_function(&self, name: &str) -> bool;

    /// Version triple `(major, minor, patch)` of this family; defaults to `(1, 0, 0)`.
    fn version(&self) -> (u32, u32, u32) {
        let (major, minor, patch) = (1, 0, 0);
        (major, minor, patch)
    }

    /// Names of families that must be registered before this one; defaults to none.
    fn dependencies(&self) -> Vec<&'static str> {
        Vec::new()
    }
}
/// Registry of function families, keyed by each family's `family_name()`.
///
/// The merged property table returned by [`ExtensionRegistry::get_all_properties`] is built
/// lazily and cached until the next successful registration.
pub struct ExtensionRegistry {
    /// Registered families, keyed by family name.
    extensions: HashMap<&'static str, Box<dyn FunctionFamilyExtension>>,
    /// Lazily built union of all families' properties; `None` means "rebuild on next request".
    cached_properties: Option<HashMap<String, FunctionProperties>>,
    /// Incremented on every successful registration.
    /// NOTE(review): not read anywhere in this chunk — confirm it is used elsewhere.
    cache_version: u64,
}

impl Default for ExtensionRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl ExtensionRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            extensions: HashMap::with_capacity(16),
            cached_properties: None,
            cache_version: 0,
        }
    }

    /// Registers `extension` under its `family_name()`.
    ///
    /// Dependencies must be registered first. On success the cached property table is
    /// invalidated.
    ///
    /// # Errors
    ///
    /// * [`ExtensionError::FamilyAlreadyRegistered`] if a family with the same name exists.
    /// * [`ExtensionError::MissingDependency`] for the first name in `dependencies()` that
    ///   is not yet registered.
    pub fn register_extension(
        &mut self,
        extension: Box<dyn FunctionFamilyExtension>,
    ) -> Result<(), ExtensionError> {
        let family_name = extension.family_name();
        if self.extensions.contains_key(family_name) {
            return Err(ExtensionError::FamilyAlreadyRegistered(
                family_name.to_owned(),
            ));
        }
        if let Some(missing) = extension
            .dependencies()
            .into_iter()
            .find(|dep| !self.extensions.contains_key(*dep))
        {
            return Err(ExtensionError::MissingDependency {
                extension: family_name.to_owned(),
                dependency: missing.to_owned(),
            });
        }
        self.extensions.insert(family_name, extension);
        self.cached_properties = None;
        self.cache_version += 1;
        Ok(())
    }

    /// Returns the union of every registered family's properties, keyed by function name.
    ///
    /// The table is built on first use and cached until the next successful
    /// [`register_extension`](Self::register_extension). Families are merged in ascending
    /// family-name order. When two families define the same function name, the entry from
    /// the family whose name sorts last wins. The result is therefore deterministic instead
    /// of depending on `HashMap` iteration order.
    pub fn get_all_properties(&mut self) -> &HashMap<String, FunctionProperties> {
        // Borrow the fields separately so the closure does not need all of `self`.
        let extensions = &self.extensions;
        self.cached_properties
            .get_or_insert_with(|| Self::merge_properties(extensions))
    }

    /// Merges every family's property map in ascending family-name order.
    fn merge_properties(
        extensions: &HashMap<&'static str, Box<dyn FunctionFamilyExtension>>,
    ) -> HashMap<String, FunctionProperties> {
        let mut families: Vec<_> = extensions.iter().collect();
        // Family names are unique keys, so an unstable sort is fully deterministic.
        families.sort_unstable_by(|a, b| a.0.cmp(b.0));
        let mut combined = HashMap::with_capacity(256);
        for (_, extension) in families {
            combined.extend(extension.get_properties());
        }
        combined
    }

    /// Returns `true` if any registered family provides a function called `name`.
    pub fn has_function(&self, name: &str) -> bool {
        self.extensions.values().any(|ext| ext.has_function(name))
    }

    /// Names of all registered families, in ascending order.
    pub fn registered_families(&self) -> Vec<&'static str> {
        let mut families: Vec<&'static str> = self.extensions.keys().copied().collect();
        families.sort_unstable();
        families
    }

    /// Looks up a registered family by name.
    pub fn get_extension(&self, family_name: &str) -> Option<&dyn FunctionFamilyExtension> {
        self.extensions.get(family_name).map(|ext| ext.as_ref())
    }
}
/// Reasons a function-family extension could not be registered.
#[derive(Debug, Clone)]
pub enum ExtensionError {
    /// A family with this name is already present.
    FamilyAlreadyRegistered(String),
    /// `extension` lists `dependency`, which has not been registered yet.
    MissingDependency {
        extension: String,
        dependency: String,
    },
    /// `extension` needs version `required` but `found` is available.
    /// NOTE(review): not constructed anywhere in this chunk.
    IncompatibleVersion {
        extension: String,
        required: (u32, u32, u32),
        found: (u32, u32, u32),
    },
}

impl std::fmt::Display for ExtensionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::FamilyAlreadyRegistered(ref family) => {
                write!(f, "Function family '{}' is already registered", family)
            }
            Self::MissingDependency {
                ref extension,
                ref dependency,
            } => write!(
                f,
                "Extension '{}' requires '{}' which is not registered",
                extension, dependency
            ),
            Self::IncompatibleVersion {
                ref extension,
                required,
                found,
            } => write!(
                f,
                "Extension '{}' requires version {:?} but found {:?}",
                extension, required, found
            ),
        }
    }
}

impl std::error::Error for ExtensionError {}
/// Validation hooks for a function identified by `name`.
pub trait FunctionValidator {
    /// Checks `name` against `test_points`, a slice of `(Vec<f64>, f64)` pairs
    /// (presumably argument vectors and expected results — confirm with implementors).
    fn validate_mathematical_correctness(
        &self,
        name: &str,
        test_points: &[(Vec<f64>, f64)],
    ) -> ValidationResult;

    /// Checks the performance of `name` using a benchmark of `benchmark_size`.
    fn validate_performance(&self, name: &str, benchmark_size: usize) -> ValidationResult;

    /// Checks the behaviour of `name` on the supplied `edge_cases` inputs.
    fn validate_numerical_stability(&self, name: &str, edge_cases: &[f64]) -> ValidationResult;
}

/// Outcome of a single validation check.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Whether the check passed.
    pub passed: bool,
    /// Human-readable summary of the check.
    pub report: String,
    /// Optional measurements gathered during the check.
    pub metrics: Option<ValidationMetrics>,
}

/// Measurements attached to a [`ValidationResult`].
#[derive(Debug, Clone)]
pub struct ValidationMetrics {
    pub ops_per_second: f64,
    pub memory_usage: usize,
    pub accuracy: f64,
}

/// Placeholder validator.
///
/// Every check reports `passed: true` without inspecting the function. The performance
/// metrics are fixed constants, not measurements.
pub struct DefaultValidator;

impl DefaultValidator {
    /// Builds an always-passing result carrying `report` and `metrics`.
    fn accepted(report: String, metrics: Option<ValidationMetrics>) -> ValidationResult {
        ValidationResult {
            passed: true,
            report,
            metrics,
        }
    }
}

impl FunctionValidator for DefaultValidator {
    fn validate_mathematical_correctness(
        &self,
        name: &str,
        test_points: &[(Vec<f64>, f64)],
    ) -> ValidationResult {
        let point_count = test_points.len();
        Self::accepted(
            format!(
                "Mathematical correctness validated for {} with {} test points",
                name, point_count
            ),
            None,
        )
    }

    fn validate_performance(&self, name: &str, benchmark_size: usize) -> ValidationResult {
        // Fixed placeholder figures — nothing is actually benchmarked here.
        let placeholder = ValidationMetrics {
            ops_per_second: 1_000_000.0,
            memory_usage: 1024,
            accuracy: 1e-15,
        };
        Self::accepted(
            format!(
                "Performance validated for {} with benchmark size {}",
                name, benchmark_size
            ),
            Some(placeholder),
        )
    }

    fn validate_numerical_stability(&self, name: &str, edge_cases: &[f64]) -> ValidationResult {
        let case_count = edge_cases.len();
        Self::accepted(
            format!(
                "Numerical stability validated for {} with {} edge cases",
                name, case_count
            ),
            None,
        )
    }
}
/// Declares a unit struct `$name` and implements `FunctionFamilyExtension` for it.
///
/// ```ignore
/// impl_function_family!(
///     TrigFamily,
///     family_name = "trig",
///     version = (1, 0, 0),
///     dependencies = ["core"],
///     functions = {
///         "sin" => sin_properties(),
///         "cos" => cos_properties(),
///     }
/// );
/// ```
///
/// * `family_name`, `version` and `dependencies` become the matching trait methods.
/// * Each `functions` entry maps a name literal to an expression that yields its
///   `FunctionProperties`. The expressions are evaluated on every `get_properties` call.
/// * Both `dependencies` and `functions` may be empty, and both accept a trailing comma.
#[macro_export]
macro_rules! impl_function_family {
    (
        $name:ident,
        family_name = $family_name:literal,
        version = ($major:literal, $minor:literal, $patch:literal),
        dependencies = [$($dep:literal),* $(,)?],
        functions = {
            $(
                $func_name:literal => $func_props:expr
            ),* $(,)?
        }
    ) => {
        pub struct $name;

        impl $crate::functions::extensibility::FunctionFamilyExtension for $name {
            fn family_name(&self) -> &'static str {
                $family_name
            }

            fn version(&self) -> (u32, u32, u32) {
                ($major, $minor, $patch)
            }

            fn dependencies(&self) -> Vec<&'static str> {
                vec![$($dep),*]
            }

            fn get_properties(
                &self,
            ) -> std::collections::HashMap<String, $crate::functions::properties::FunctionProperties>
            {
                // `mut` goes unused when the `functions` list is empty.
                #[allow(unused_mut)]
                let mut props = std::collections::HashMap::new();
                $(
                    props.insert($func_name.to_string(), $func_props);
                )*
                props
            }

            fn has_function(&self, name: &str) -> bool {
                // A slice lookup still compiles when `functions` is empty, whereas
                // `matches!(name, )` would be a syntax error (it needs a pattern).
                let known: &[&str] = &[$($func_name),*];
                known.contains(&name)
            }
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extension_registry() {
        let registry = ExtensionRegistry::new();

        // A freshly constructed registry knows no families and no functions.
        assert!(registry.registered_families().is_empty());
        assert!(!registry.has_function("nonexistent"));
    }

    #[test]
    fn test_validation_result() {
        let expected = "Test validation";
        let result = ValidationResult {
            passed: true,
            report: String::from(expected),
            metrics: None,
        };

        assert!(result.passed);
        assert_eq!(result.report, expected);
    }
}