use crate::dsl_impl::{
code_generators::{
generate_feature_engineering_code, generate_hyperparameter_code, generate_pipeline_code,
},
parsers::{parse_feature_engineering, parse_hyperparameter_config, parse_ml_pipeline},
};
use proc_macro2::TokenStream;
pub fn ml_pipeline_impl(input: TokenStream) -> TokenStream {
match parse_ml_pipeline(input) {
Ok(pipeline) => generate_pipeline_code(pipeline),
Err(err) => err.to_compile_error(),
}
}
pub fn feature_engineering_impl(input: TokenStream) -> TokenStream {
match parse_feature_engineering(input) {
Ok(config) => generate_feature_engineering_code(config),
Err(err) => err.to_compile_error(),
}
}
pub fn hyperparameter_config_impl(input: TokenStream) -> TokenStream {
match parse_hyperparameter_config(input) {
Ok(config) => generate_hyperparameter_code(config),
Err(err) => err.to_compile_error(),
}
}
/// Placeholder expansion for the `model_evaluation!` DSL.
///
/// The macro is not implemented yet. The input is ignored and the expansion
/// is a `compile_error!` invocation, so any use fails to compile with a clear
/// message.
pub fn model_evaluation_impl(_input: TokenStream) -> TokenStream {
    const NOT_IMPLEMENTED: &str = "model_evaluation! macro not yet implemented";
    quote::quote!(compile_error!(#NOT_IMPLEMENTED);)
}
/// Placeholder expansion for the `data_pipeline!` DSL.
///
/// The macro is not implemented yet. The input is ignored and the expansion
/// is a `compile_error!` invocation, so any use fails to compile with a clear
/// message.
pub fn data_pipeline_impl(_input: TokenStream) -> TokenStream {
    const NOT_IMPLEMENTED: &str = "data_pipeline! macro not yet implemented";
    quote::quote!(compile_error!(#NOT_IMPLEMENTED);)
}
/// Placeholder expansion for the `experiment_config!` DSL.
///
/// The macro is not implemented yet. The input is ignored and the expansion
/// is a `compile_error!` invocation, so any use fails to compile with a clear
/// message.
pub fn experiment_config_impl(_input: TokenStream) -> TokenStream {
    const NOT_IMPLEMENTED: &str = "experiment_config! macro not yet implemented";
    quote::quote!(compile_error!(#NOT_IMPLEMENTED);)
}
pub fn handle_macro_error(error: syn::Error, context: &str) -> TokenStream {
let error_msg = format!("DSL macro error in {}: {}", context, error);
quote::quote! {
compile_error!(#error_msg);
}
}
/// Name-based dispatch table for DSL macro implementations.
///
/// Maps a macro name (for example `"ml_pipeline"`) to the function that
/// expands its input tokens.
#[derive(Debug, Clone)]
pub struct MacroRegistry {
    /// Expansion function for each registered macro, keyed by macro name.
    implementations: std::collections::HashMap<String, fn(TokenStream) -> TokenStream>,
}
impl MacroRegistry {
    /// Creates a registry pre-populated with every built-in DSL macro.
    ///
    /// The built-ins are:
    /// - `ml_pipeline`, `feature_engineering` and `hyperparameter_config`;
    /// - `model_evaluation`, `data_pipeline` and `experiment_config`, whose
    ///   implementations in this module currently expand to `compile_error!`.
    #[must_use]
    pub fn new() -> Self {
        let mut registry = Self {
            implementations: std::collections::HashMap::new(),
        };
        registry.register("ml_pipeline", ml_pipeline_impl);
        registry.register("feature_engineering", feature_engineering_impl);
        registry.register("hyperparameter_config", hyperparameter_config_impl);
        registry.register("model_evaluation", model_evaluation_impl);
        registry.register("data_pipeline", data_pipeline_impl);
        registry.register("experiment_config", experiment_config_impl);
        registry
    }

    /// Registers `implementation` under `name`.
    ///
    /// If `name` is already registered, the previous implementation is
    /// replaced.
    pub fn register(&mut self, name: &str, implementation: fn(TokenStream) -> TokenStream) {
        self.implementations
            .insert(name.to_string(), implementation);
    }

    /// Expands `input` with the macro registered under `name`.
    ///
    /// If no macro is registered under `name`, the returned tokens are a
    /// `compile_error!` invocation with the message `Unknown macro: <name>`.
    #[must_use]
    pub fn execute(&self, name: &str, input: TokenStream) -> TokenStream {
        match self.implementations.get(name) {
            Some(implementation) => implementation(input),
            None => {
                let error_msg = format!("Unknown macro: {}", name);
                quote::quote! {
                    compile_error!(#error_msg);
                }
            }
        }
    }

    /// Returns the names of all registered macros in ascending sorted order.
    ///
    /// The names are sorted because `HashMap` iteration order is unspecified.
    /// Without sorting, listings (and anything built from them) would vary
    /// from run to run.
    #[must_use]
    pub fn list_macros(&self) -> Vec<String> {
        let mut names: Vec<String> = self.implementations.keys().cloned().collect();
        names.sort_unstable();
        names
    }
}
impl Default for MacroRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Every built-in macro is registered by `MacroRegistry::new`.
    #[test]
    fn test_macro_registry_creation() {
        let registry = MacroRegistry::new();
        let macros = registry.list_macros();
        for builtin in &[
            "ml_pipeline",
            "feature_engineering",
            "hyperparameter_config",
            "model_evaluation",
            "data_pipeline",
            "experiment_config",
        ] {
            assert!(
                macros.contains(&builtin.to_string()),
                "missing builtin macro: {}",
                builtin
            );
        }
    }

    #[test]
    fn test_macro_registry_custom_registration() {
        let mut registry = MacroRegistry::new();
        fn test_macro(_input: TokenStream) -> TokenStream {
            quote! { println!("test macro executed"); }
        }
        registry.register("test_macro", test_macro);
        let macros = registry.list_macros();
        assert!(macros.contains(&"test_macro".to_string()));
    }

    /// `execute` dispatches to the function registered under the given name.
    #[test]
    fn test_registered_macro_is_dispatched() {
        let mut registry = MacroRegistry::new();
        fn test_macro(_input: TokenStream) -> TokenStream {
            quote! { println!("test macro executed"); }
        }
        registry.register("test_macro", test_macro);
        let result = registry.execute("test_macro", TokenStream::new()).to_string();
        assert!(result.contains("test macro executed"));
    }

    /// Registering an existing name replaces the old implementation.
    #[test]
    fn test_register_replaces_existing_entry() {
        let mut registry = MacroRegistry::new();
        fn replacement(_input: TokenStream) -> TokenStream {
            quote! { replaced }
        }
        let count_before = registry.list_macros().len();
        registry.register("ml_pipeline", replacement);
        assert_eq!(registry.list_macros().len(), count_before);
        assert_eq!(
            registry.execute("ml_pipeline", TokenStream::new()).to_string(),
            "replaced"
        );
    }

    #[test]
    fn test_unknown_macro_execution() {
        let registry = MacroRegistry::new();
        let result = registry.execute("unknown_macro", TokenStream::new());
        let result_str = result.to_string();
        assert!(result_str.contains("compile_error"));
        assert!(result_str.contains("Unknown macro"));
    }

    /// The unimplemented macros expand to a descriptive `compile_error!`.
    #[test]
    fn test_stub_macros_emit_compile_error() {
        let registry = MacroRegistry::new();
        for &(name, message) in &[
            ("model_evaluation", "model_evaluation! macro not yet implemented"),
            ("data_pipeline", "data_pipeline! macro not yet implemented"),
            ("experiment_config", "experiment_config! macro not yet implemented"),
        ] {
            let output = registry.execute(name, TokenStream::new()).to_string();
            assert!(output.contains("compile_error"), "{} did not emit compile_error", name);
            assert!(output.contains(message), "{} emitted unexpected message", name);
        }
    }

    /// `Default` produces the same set of macros as `new`.
    #[test]
    fn test_default_matches_new() {
        let mut from_default = MacroRegistry::default().list_macros();
        let mut from_new = MacroRegistry::new().list_macros();
        from_default.sort();
        from_new.sort();
        assert_eq!(from_default, from_new);
    }

    /// `handle_macro_error` prefixes the message with the DSL context.
    #[test]
    fn test_handle_macro_error_includes_context() {
        let error = syn::Error::new(proc_macro2::Span::call_site(), "unexpected token");
        let output = handle_macro_error(error, "ml_pipeline").to_string();
        assert!(output.contains("compile_error"));
        assert!(output.contains("DSL macro error in ml_pipeline: unexpected token"));
    }
}