use crate::error::{DSLCompileError, Result};
use crate::interval_domain::IntervalDomain;
use std::fs;
use std::path::PathBuf;
/// A named group of egglog rules; each group lives in its own `.egg` file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RuleCategory {
    /// Core mathematical expression datatypes.
    CoreDatatypes,
    /// Basic arithmetic operations and identities.
    BasicArithmetic,
    /// Arithmetic rules guarded by domain preconditions.
    DomainAwareArithmetic,
    /// Exponential and logarithmic functions.
    Transcendental,
    /// Trigonometric functions and identities.
    Trigonometric,
    /// Summation linearity and algebraic rules.
    Summation,
}

impl RuleCategory {
    /// File name, relative to the rules directory, holding this category's rules.
    #[must_use]
    pub fn filename(&self) -> &'static str {
        match *self {
            Self::CoreDatatypes => "core_datatypes.egg",
            Self::BasicArithmetic => "basic_arithmetic.egg",
            Self::DomainAwareArithmetic => "domain_aware_arithmetic.egg",
            Self::Transcendental => "transcendental.egg",
            Self::Trigonometric => "trigonometric.egg",
            Self::Summation => "summation.egg",
        }
    }

    /// Human-readable one-line summary, used for section headers in generated programs.
    #[must_use]
    pub fn description(&self) -> &'static str {
        match *self {
            Self::CoreDatatypes => "Core mathematical expression datatypes",
            Self::BasicArithmetic => "Basic arithmetic operations and identities",
            Self::DomainAwareArithmetic => "Domain-aware arithmetic with preconditions",
            Self::Transcendental => "Exponential and logarithmic functions",
            Self::Trigonometric => "Trigonometric functions and identities",
            Self::Summation => "Summation linearity and algebraic rules",
        }
    }

    /// Every category, in declaration order.
    #[must_use]
    pub fn all() -> Vec<RuleCategory> {
        Vec::from([
            Self::CoreDatatypes,
            Self::BasicArithmetic,
            Self::DomainAwareArithmetic,
            Self::Transcendental,
            Self::Trigonometric,
            Self::Summation,
        ])
    }

    /// Categories used by the default configuration: core datatypes, basic arithmetic
    /// and transcendental rules.
    #[must_use]
    pub fn default_set() -> Vec<RuleCategory> {
        [
            Self::CoreDatatypes,
            Self::BasicArithmetic,
            Self::Transcendental,
        ]
        .to_vec()
    }

    /// Categories used by the domain-aware configuration: the plain arithmetic rules
    /// are replaced by their domain-aware counterparts.
    #[must_use]
    pub fn domain_aware_set() -> Vec<RuleCategory> {
        [
            Self::CoreDatatypes,
            Self::DomainAwareArithmetic,
            Self::Transcendental,
        ]
        .to_vec()
    }
}
/// Settings controlling which rule categories `RuleLoader` loads and how it
/// assembles the final egglog program.
#[derive(Debug, Clone)]
pub struct RuleConfig {
    /// Rule categories to load, in order.
    pub categories: Vec<RuleCategory>,
    /// Directory holding the `.egg` files; `RuleLoader::new` falls back to `rules` when `None`.
    pub rules_directory: Option<PathBuf>,
    /// Whether `RuleLoader::load_rules` structurally validates the assembled program.
    pub validate_syntax: bool,
    /// Whether section-header comments are emitted into the assembled program.
    pub include_comments: bool,
    /// Whether rules generated from `variable_domains` are appended to the program.
    pub generate_domain_aware: bool,
    /// Known value domains of named variables, keyed by variable name.
    pub variable_domains: std::collections::HashMap<String, IntervalDomain<f64>>,
}

impl Default for RuleConfig {
    /// Default rule set, syntax validation on, no comments, no domain-aware generation.
    fn default() -> Self {
        Self {
            categories: RuleCategory::default_set(),
            rules_directory: None,
            validate_syntax: true,
            include_comments: false,
            generate_domain_aware: false,
            variable_domains: std::collections::HashMap::default(),
        }
    }
}

impl RuleConfig {
    /// Configuration using the domain-aware category set with domain-aware rule
    /// generation enabled; every other setting matches [`RuleConfig::default`].
    #[must_use]
    pub fn domain_aware() -> Self {
        let mut config = Self::default();
        config.categories = RuleCategory::domain_aware_set();
        config.generate_domain_aware = true;
        config
    }

    /// Builder-style setter recording the domain of `var_name`; an existing entry for
    /// the same name is replaced.
    #[must_use]
    pub fn with_variable_domain(mut self, var_name: &str, domain: IntervalDomain<f64>) -> Self {
        let key = String::from(var_name);
        self.variable_domains.insert(key, domain);
        self
    }
}
/// Assembles an egglog program from `.egg` rule files according to a [`RuleConfig`].
pub struct RuleLoader {
    /// Settings selecting the rule categories and controlling program generation.
    config: RuleConfig,
    /// Directory the `.egg` rule files are read from.
    rules_dir: PathBuf,
}

impl RuleLoader {
    /// Creates a loader for `config`.
    ///
    /// Rule files are read from `config.rules_directory` when it is set, otherwise
    /// from a relative `rules` directory.
    #[must_use]
    pub fn new(config: RuleConfig) -> Self {
        let rules_dir = config
            .rules_directory
            .clone()
            .unwrap_or_else(|| PathBuf::from("rules"));
        Self { config, rules_dir }
    }

    /// Creates a loader using [`RuleConfig::default`].
    #[must_use]
    pub fn default() -> Self {
        Self::new(RuleConfig::default())
    }

    /// Creates a loader using [`RuleConfig::domain_aware`].
    #[must_use]
    pub fn domain_aware() -> Self {
        Self::new(RuleConfig::domain_aware())
    }

    /// Builds the complete egglog program text.
    ///
    /// The program consists of, in order:
    /// 1. an optional header comment;
    /// 2. the core datatypes, when `CoreDatatypes` is not among the configured categories;
    /// 3. each configured category's rule file;
    /// 4. rules generated from the variable domains, when `generate_domain_aware` is set.
    ///
    /// # Errors
    ///
    /// Returns [`DSLCompileError::Optimization`] in two cases:
    /// - a rule file cannot be read;
    /// - `validate_syntax` is enabled and the assembled program fails the structural check.
    pub fn load_rules(&self) -> Result<String> {
        let mut program = String::new();
        if self.config.include_comments {
            program.push_str("; Combined Egglog Program for DSLCompile\n");
            program.push_str("; Generated by RuleLoader\n\n");
        }
        // Core datatypes are always loaded; prepend them when the configuration omits them.
        if self.needs_implicit_core() {
            program.push_str(&self.load_rule_file(&RuleCategory::CoreDatatypes)?);
            program.push('\n');
        }
        for category in &self.config.categories {
            if self.config.include_comments {
                program.push_str("; ========================================\n");
                program.push_str(&format!("; {}\n", category.description()));
                program.push_str("; ========================================\n\n");
            }
            program.push_str(&self.load_rule_file(category)?);
            program.push('\n');
        }
        if self.config.generate_domain_aware {
            if self.config.include_comments {
                program.push_str("; ========================================\n");
                program.push_str("; DYNAMICALLY GENERATED DOMAIN-AWARE RULES\n");
                program.push_str("; ========================================\n\n");
            }
            program.push_str(&self.generate_domain_aware_rules()?);
            program.push('\n');
        }
        if self.config.validate_syntax {
            self.validate_program_syntax(&program)?;
        }
        Ok(program)
    }

    /// Generates per-variable rewrite rules from the configured variable domains.
    ///
    /// For a variable whose domain is positive, emits `x^0 = 1`. For a variable whose
    /// domain is non-negative, emits `sqrt(x*x) = x`. It always appends the IEEE 754
    /// `0^0 = 1` rule.
    fn generate_domain_aware_rules(&self) -> Result<String> {
        let mut rules = String::new();
        rules.push_str("; Domain-aware power rules\n");
        // HashMap iteration order differs between runs; sort by variable name so the
        // generated program is reproducible.
        let mut domains: Vec<_> = self.config.variable_domains.iter().collect();
        domains.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        // NOTE(review): variable names are interpolated verbatim into egglog string
        // literals; a name containing `"` would produce a malformed program.
        for (var_name, domain) in domains {
            if domain.is_positive(0.0) {
                rules.push_str(&format!(
                    "; Variable {var_name} is positive, safe to use x^0 = 1\n"
                ));
                rules.push_str(&format!(
                    "(rewrite (Pow (Var \"{var_name}\") (Num 0.0)) (Num 1.0))\n"
                ));
            }
            if domain.is_non_negative(0.0) {
                rules.push_str(&format!(
                    "; Variable {var_name} is non-negative, safe to use sqrt(x^2) = x\n"
                ));
                rules.push_str(&format!(
                    "(rewrite (Sqrt (Mul (Var \"{var_name}\") (Var \"{var_name}\"))) (Var \"{var_name}\"))\n"
                ));
            }
        }
        rules.push_str("\n; IEEE 754 compliant rules (computational, not mathematical)\n");
        rules.push_str("; These follow IEEE 754 standard but may not be mathematically rigorous\n");
        rules.push_str("(rewrite (Pow (Num 0.0) (Num 0.0)) (Num 1.0)) ; IEEE 754: 0^0 = 1\n");
        Ok(rules)
    }

    /// Path of the rule file for `category` inside the rules directory.
    fn rule_file_path(&self, category: &RuleCategory) -> PathBuf {
        self.rules_dir.join(category.filename())
    }

    /// Whether `load_rules` must load `CoreDatatypes` even though the configuration
    /// does not list it.
    fn needs_implicit_core(&self) -> bool {
        !self
            .config
            .categories
            .contains(&RuleCategory::CoreDatatypes)
    }

    /// Reads the rule file for `category`, returning its full contents.
    fn load_rule_file(&self, category: &RuleCategory) -> Result<String> {
        let file_path = self.rule_file_path(category);
        fs::read_to_string(&file_path).map_err(|e| {
            DSLCompileError::Optimization(format!(
                "Failed to load rule file {}: {}",
                file_path.display(),
                e
            ))
        })
    }

    /// Performs a lightweight structural check on an assembled egglog program.
    ///
    /// Parentheses outside `;` comments and `"` string literals must balance, and a `)`
    /// must never close more than was opened. A `datatype Math` definition must also
    /// appear outside comments.
    ///
    /// # Errors
    ///
    /// Returns [`DSLCompileError::Optimization`] describing the first problem found.
    fn validate_program_syntax(&self, program: &str) -> Result<()> {
        let mut depth: usize = 0;
        for (index, line) in program.lines().enumerate() {
            // String state is tracked per line so that one stray quote cannot disable
            // checking for the rest of the program.
            let mut in_string = false;
            for ch in line.chars() {
                match ch {
                    '"' => in_string = !in_string,
                    _ if in_string => {}
                    // Everything after `;` on this line is a comment.
                    ';' => break,
                    '(' => depth += 1,
                    ')' => {
                        depth = depth.checked_sub(1).ok_or_else(|| {
                            DSLCompileError::Optimization(format!(
                                "Unbalanced parentheses in egglog program: unexpected ')' on line {}",
                                index + 1
                            ))
                        })?;
                    }
                    _ => {}
                }
            }
        }
        if depth != 0 {
            return Err(DSLCompileError::Optimization(format!(
                "Unbalanced parentheses in egglog program: {depth} unclosed"
            )));
        }
        // Only count a definition that appears in code, not one mentioned in a comment.
        let defines_math = program
            .lines()
            .map(|line| line.split_once(';').map_or(line, |(code, _)| code))
            .any(|code| code.contains("datatype Math"));
        if !defines_math {
            return Err(DSLCompileError::Optimization(
                "Missing required 'datatype Math' definition".to_string(),
            ));
        }
        Ok(())
    }

    /// Lists every known category with its description and whether its rule file is
    /// present in the rules directory.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API compatibility.
    pub fn list_available_rules(&self) -> Result<Vec<(RuleCategory, bool, String)>> {
        Ok(RuleCategory::all()
            .into_iter()
            .map(|category| {
                let exists = self.rule_file_path(&category).exists();
                let description = category.description().to_string();
                (category, exists, description)
            })
            .collect())
    }

    /// Checks that every rule file `load_rules` would read exists. This includes the
    /// core datatypes file, which is loaded even when not configured.
    ///
    /// # Errors
    ///
    /// Returns [`DSLCompileError::Optimization`] naming the first missing file.
    pub fn validate_rule_files(&self) -> Result<()> {
        if self.needs_implicit_core() {
            self.ensure_rule_file_exists(&RuleCategory::CoreDatatypes)?;
        }
        for category in &self.config.categories {
            self.ensure_rule_file_exists(category)?;
        }
        Ok(())
    }

    /// Errors if the rule file for `category` does not exist.
    fn ensure_rule_file_exists(&self, category: &RuleCategory) -> Result<()> {
        let file_path = self.rule_file_path(category);
        if file_path.exists() {
            Ok(())
        } else {
            Err(DSLCompileError::Optimization(format!(
                "Required rule file not found: {}",
                file_path.display()
            )))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_rule_category_filename() {
        assert_eq!(RuleCategory::CoreDatatypes.filename(), "core_datatypes.egg");
        assert_eq!(
            RuleCategory::BasicArithmetic.filename(),
            "basic_arithmetic.egg"
        );
        assert_eq!(
            RuleCategory::Transcendental.filename(),
            "transcendental.egg"
        );
    }

    #[test]
    fn test_rule_category_filenames_are_unique_egg_files() {
        let names: HashSet<&str> = RuleCategory::all()
            .iter()
            .map(RuleCategory::filename)
            .collect();
        assert_eq!(names.len(), RuleCategory::all().len());
        assert!(names.iter().all(|name| name.ends_with(".egg")));
    }

    #[test]
    fn test_default_rule_config() {
        let config = RuleConfig::default();
        assert!(config.categories.contains(&RuleCategory::CoreDatatypes));
        assert!(config.categories.contains(&RuleCategory::BasicArithmetic));
        assert!(config.validate_syntax);
    }

    #[test]
    fn test_domain_aware_rule_config() {
        let config = RuleConfig::domain_aware();
        assert!(config.generate_domain_aware);
        assert!(config
            .categories
            .contains(&RuleCategory::DomainAwareArithmetic));
        assert!(!config.categories.contains(&RuleCategory::BasicArithmetic));
        assert!(config.variable_domains.is_empty());
    }

    #[test]
    fn test_rule_loader_creation() {
        let loader = RuleLoader::default();
        assert_eq!(loader.rules_dir, PathBuf::from("rules"));
    }

    #[test]
    fn test_rule_loader_custom_directory() {
        let config = RuleConfig {
            rules_directory: Some(PathBuf::from("custom/rules")),
            ..RuleConfig::default()
        };
        let loader = RuleLoader::new(config);
        assert_eq!(loader.rules_dir, PathBuf::from("custom/rules"));
    }

    #[test]
    fn test_syntax_validation() {
        let loader = RuleLoader::default();
        let valid_program = "(datatype Math (Num f64))";
        assert!(loader.validate_program_syntax(valid_program).is_ok());
        let invalid_program = "(datatype Math (Num f64)";
        assert!(loader.validate_program_syntax(invalid_program).is_err());
        let missing_datatype = "(rewrite (Add ?x ?y) (Add ?y ?x))";
        assert!(loader.validate_program_syntax(missing_datatype).is_err());
    }

    #[test]
    fn test_syntax_validation_ignores_comments() {
        let loader = RuleLoader::default();
        let program = "; header comment (\n(datatype Math (Num f64)) ; trailing )\n";
        assert!(loader.validate_program_syntax(program).is_ok());
    }

    #[test]
    fn test_list_available_rules_covers_every_category() {
        let loader = RuleLoader::default();
        let listed = loader
            .list_available_rules()
            .expect("listing rules does not fail");
        let categories: Vec<RuleCategory> =
            listed.iter().map(|(category, _, _)| category.clone()).collect();
        assert_eq!(categories, RuleCategory::all());
        for (category, _, description) in &listed {
            assert_eq!(description.as_str(), category.description());
        }
    }

    #[test]
    fn test_missing_rules_directory_is_reported() {
        let config = RuleConfig {
            rules_directory: Some(PathBuf::from("this/rules/directory/does/not/exist")),
            ..RuleConfig::default()
        };
        let loader = RuleLoader::new(config);
        assert!(loader.load_rules().is_err());
        assert!(loader.validate_rule_files().is_err());
    }
}