#[allow(unused_imports)]
use crate::prelude::*;
use oxiz_core::ast::{TermId, TermKind, TermManager};
use oxiz_core::interner::Spur;
use super::QuantifiedFormula;
/// A trigger pattern for a quantified formula: one or more terms plus
/// derived information used to rank it against other candidates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Pattern {
/// Term ids making up the pattern; more than one term forms a multi-pattern.
pub terms: Vec<TermId>,
/// Variable names found in `terms` by `Pattern::extract_variables`
/// (empty until that method runs).
pub variables: FxHashSet<Spur>,
/// Heuristic ranking score set by `Pattern::calculate_quality`; higher
/// scores are preferred by `PatternGenerator::generate`.
pub quality: u32,
/// How the pattern was obtained (see `PatternType`).
pub pattern_type: PatternType,
}
impl Pattern {
pub fn new(terms: Vec<TermId>) -> Self {
Self {
terms,
variables: FxHashSet::default(),
quality: 0,
pattern_type: PatternType::MultiPattern,
}
}
pub fn extract_variables(&mut self, manager: &TermManager) {
self.variables.clear();
let terms: Vec<_> = self.terms.to_vec();
for term in terms {
self.extract_vars_rec(term, manager);
}
}
fn extract_vars_rec(&mut self, term: TermId, manager: &TermManager) {
let mut visited = FxHashSet::default();
self.extract_vars_helper(term, manager, &mut visited);
}
fn extract_vars_helper(
&mut self,
term: TermId,
manager: &TermManager,
visited: &mut FxHashSet<TermId>,
) {
if visited.contains(&term) {
return;
}
visited.insert(term);
let Some(t) = manager.get(term) else {
return;
};
if let TermKind::Var(name) = t.kind {
self.variables.insert(name);
return;
}
match &t.kind {
TermKind::Apply { args, .. } => {
for &arg in args.iter() {
self.extract_vars_helper(arg, manager, visited);
}
}
TermKind::Not(arg) | TermKind::Neg(arg) => {
self.extract_vars_helper(*arg, manager, visited);
}
TermKind::And(args) | TermKind::Or(args) => {
for &arg in args {
self.extract_vars_helper(arg, manager, visited);
}
}
_ => {}
}
}
pub fn calculate_quality(&mut self, manager: &TermManager) {
let num_funcs = self.count_function_symbols(manager);
let num_vars = self.variables.len();
let complexity_penalty = self.terms.len();
self.quality = (num_funcs * 100 + num_vars * 50) as u32 - complexity_penalty as u32;
}
fn count_function_symbols(&self, manager: &TermManager) -> usize {
let mut count = 0;
let mut visited = FxHashSet::default();
for &term in &self.terms {
count += self.count_funcs_rec(term, manager, &mut visited);
}
count
}
fn count_funcs_rec(
&self,
term: TermId,
manager: &TermManager,
visited: &mut FxHashSet<TermId>,
) -> usize {
if visited.contains(&term) {
return 0;
}
visited.insert(term);
let Some(t) = manager.get(term) else {
return 0;
};
match &t.kind {
TermKind::Apply { args, .. } => {
1 + args
.iter()
.map(|&arg| self.count_funcs_rec(arg, manager, visited))
.sum::<usize>()
}
_ => 0,
}
}
}
/// Where a [`Pattern`] came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PatternType {
/// Not assigned anywhere in this module; presumably a pattern consisting of
/// exactly one term — TODO confirm intended use.
SingleTerm,
/// Initial type given to every pattern by `Pattern::new`.
MultiPattern,
/// Explicit pattern taken from the quantifier itself by
/// `PatternGenerator::generate`.
UserSpecified,
/// Pattern derived heuristically from the quantifier body by
/// `PatternGenerator`.
AutoGenerated,
}
/// Builds and ranks trigger patterns for quantified formulas.
#[derive(Debug)]
pub struct PatternGenerator {
/// Upper bound on the number of auto-generated patterns returned by a single
/// `generate` call (10 when built with `new`).
max_patterns: usize,
/// Auto-generated patterns whose `quality` is below this are discarded
/// (0 when built with `new`, i.e. nothing is discarded).
min_quality: u32,
/// Running counters, exposed through `stats()`.
stats: GeneratorStats,
}
impl PatternGenerator {
pub fn new() -> Self {
Self {
max_patterns: 10,
min_quality: 0,
stats: GeneratorStats::default(),
}
}
pub fn generate(
&mut self,
quantifier: &QuantifiedFormula,
manager: &TermManager,
) -> Vec<Pattern> {
self.stats.num_generations += 1;
if !quantifier.patterns.is_empty() {
return self.user_patterns_to_patterns(&quantifier.patterns, manager);
}
let mut patterns = Vec::new();
patterns.extend(self.generate_function_patterns(quantifier.body, manager));
patterns.extend(self.generate_equality_patterns(quantifier.body, manager));
patterns.extend(self.generate_arithmetic_patterns(quantifier.body, manager));
patterns.retain(|p| p.quality >= self.min_quality);
patterns.sort_by_key(|p| std::cmp::Reverse(p.quality));
patterns.truncate(self.max_patterns);
self.stats.num_patterns_generated += patterns.len();
patterns
}
fn user_patterns_to_patterns(
&self,
user_patterns: &[Vec<TermId>],
manager: &TermManager,
) -> Vec<Pattern> {
let mut patterns = Vec::new();
for pattern_terms in user_patterns {
let mut pattern = Pattern::new(pattern_terms.clone());
pattern.extract_variables(manager);
pattern.calculate_quality(manager);
pattern.pattern_type = PatternType::UserSpecified;
patterns.push(pattern);
}
patterns
}
fn generate_function_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
let mut patterns = Vec::new();
let func_apps = self.collect_function_applications(body, manager);
for func_app in func_apps {
let mut pattern = Pattern::new(vec![func_app]);
pattern.extract_variables(manager);
pattern.calculate_quality(manager);
pattern.pattern_type = PatternType::AutoGenerated;
patterns.push(pattern);
}
patterns
}
fn generate_equality_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
let mut patterns = Vec::new();
let equalities = self.collect_equalities(body, manager);
for eq_term in equalities {
let mut pattern = Pattern::new(vec![eq_term]);
pattern.extract_variables(manager);
pattern.calculate_quality(manager);
pattern.pattern_type = PatternType::AutoGenerated;
patterns.push(pattern);
}
patterns
}
fn generate_arithmetic_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
let mut patterns = Vec::new();
let arith_terms = self.collect_arithmetic_terms(body, manager);
for arith_term in arith_terms {
let mut pattern = Pattern::new(vec![arith_term]);
pattern.extract_variables(manager);
pattern.calculate_quality(manager);
pattern.pattern_type = PatternType::AutoGenerated;
patterns.push(pattern);
}
patterns
}
fn collect_function_applications(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
let mut results = Vec::new();
let mut visited = FxHashSet::default();
self.collect_funcs_rec(term, &mut results, &mut visited, manager);
results
}
fn collect_funcs_rec(
&self,
term: TermId,
results: &mut Vec<TermId>,
visited: &mut FxHashSet<TermId>,
manager: &TermManager,
) {
if visited.contains(&term) {
return;
}
visited.insert(term);
let Some(t) = manager.get(term) else {
return;
};
if let TermKind::Apply { args, .. } = &t.kind {
results.push(term);
for &arg in args.iter() {
self.collect_funcs_rec(arg, results, visited, manager);
}
}
match &t.kind {
TermKind::Not(arg) | TermKind::Neg(arg) => {
self.collect_funcs_rec(*arg, results, visited, manager);
}
TermKind::And(args) | TermKind::Or(args) => {
for &arg in args {
self.collect_funcs_rec(arg, results, visited, manager);
}
}
_ => {}
}
}
fn collect_equalities(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
let mut results = Vec::new();
let mut visited = FxHashSet::default();
self.collect_eqs_rec(term, &mut results, &mut visited, manager);
results
}
fn collect_eqs_rec(
&self,
term: TermId,
results: &mut Vec<TermId>,
visited: &mut FxHashSet<TermId>,
manager: &TermManager,
) {
if visited.contains(&term) {
return;
}
visited.insert(term);
let Some(t) = manager.get(term) else {
return;
};
if matches!(t.kind, TermKind::Eq(_, _)) {
results.push(term);
}
match &t.kind {
TermKind::Not(arg) | TermKind::Neg(arg) => {
self.collect_eqs_rec(*arg, results, visited, manager);
}
TermKind::And(args) | TermKind::Or(args) => {
for &arg in args {
self.collect_eqs_rec(arg, results, visited, manager);
}
}
_ => {}
}
}
fn collect_arithmetic_terms(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
let mut results = Vec::new();
let mut visited = FxHashSet::default();
self.collect_arith_rec(term, &mut results, &mut visited, manager);
results
}
fn collect_arith_rec(
&self,
term: TermId,
results: &mut Vec<TermId>,
visited: &mut FxHashSet<TermId>,
manager: &TermManager,
) {
if visited.contains(&term) {
return;
}
visited.insert(term);
let Some(t) = manager.get(term) else {
return;
};
match &t.kind {
TermKind::Lt(_, _) | TermKind::Le(_, _) | TermKind::Gt(_, _) | TermKind::Ge(_, _) => {
results.push(term);
}
TermKind::Not(arg) | TermKind::Neg(arg) => {
self.collect_arith_rec(*arg, results, visited, manager);
}
TermKind::And(args) | TermKind::Or(args) => {
for &arg in args {
self.collect_arith_rec(arg, results, visited, manager);
}
}
_ => {}
}
}
pub fn stats(&self) -> &GeneratorStats {
&self.stats
}
}
impl Default for PatternGenerator {
fn default() -> Self {
Self::new()
}
}
/// Counters maintained by [`PatternGenerator`].
#[derive(Debug, Clone, Default)]
pub struct GeneratorStats {
/// Number of `PatternGenerator::generate` calls.
pub num_generations: usize,
/// Total auto-generated patterns returned by `generate`; user-specified
/// patterns are returned before this counter is updated and are not counted.
pub num_patterns_generated: usize,
}
/// Groups pattern sets and looks up their matches in a per-term cache.
#[derive(Debug)]
pub struct MultiPatternCoordinator {
/// Pattern sets registered through `add_pattern_set`.
pattern_sets: Vec<PatternSet>,
/// Cached matches keyed by term id, read by `find_matches`.
/// NOTE(review): nothing in this module inserts into this map — confirm how
/// it is meant to be populated.
match_cache: FxHashMap<TermId, Vec<PatternMatch>>,
}
impl MultiPatternCoordinator {
pub fn new() -> Self {
Self {
pattern_sets: Vec::new(),
match_cache: FxHashMap::default(),
}
}
pub fn add_pattern_set(&mut self, patterns: Vec<Pattern>) {
self.pattern_sets.push(PatternSet {
patterns,
matches: Vec::new(),
});
}
pub fn find_matches(&mut self, _manager: &TermManager) -> Vec<MultiMatch> {
let mut multi_matches = Vec::new();
for pattern_set in &self.pattern_sets {
let mut set_matches = Vec::new();
for pattern in &pattern_set.patterns {
for &term in &pattern.terms {
if let Some(cached) = self.match_cache.get(&term) {
set_matches.extend(cached.clone());
}
}
}
if !set_matches.is_empty() {
multi_matches.push(MultiMatch {
pattern_set: pattern_set.patterns.clone(),
matches: set_matches,
});
}
}
multi_matches
}
pub fn clear_cache(&mut self) {
self.match_cache.clear();
}
}
impl Default for MultiPatternCoordinator {
fn default() -> Self {
Self::new()
}
}
/// A group of patterns registered with a [`MultiPatternCoordinator`].
#[derive(Debug, Clone)]
struct PatternSet {
/// The patterns in this set.
patterns: Vec<Pattern>,
/// Initialised empty by `add_pattern_set`; not read or updated anywhere in
/// this module.
matches: Vec<PatternMatch>,
}
/// A single match of a pattern against a term.
#[derive(Debug, Clone)]
pub struct PatternMatch {
/// The pattern that matched.
pub pattern: Pattern,
/// The term the pattern matched.
pub matched_term: TermId,
/// Presumably maps each pattern variable name to the term it was bound to
/// in this match — TODO confirm against the matcher that builds these.
pub bindings: FxHashMap<Spur, TermId>,
}
/// The matches found for one pattern set by
/// `MultiPatternCoordinator::find_matches`.
#[derive(Debug, Clone)]
pub struct MultiMatch {
/// Copy of the set's patterns.
pub pattern_set: Vec<Pattern>,
/// Cached matches for all terms of those patterns, concatenated in order.
pub matches: Vec<PatternMatch>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh pattern keeps its terms and has no variables yet.
    #[test]
    fn test_pattern_creation() {
        let fresh = Pattern::new(vec![TermId::new(1)]);
        assert_eq!(fresh.terms.len(), 1);
        assert!(fresh.variables.is_empty());
    }

    /// Pattern types compare by variant.
    #[test]
    fn test_pattern_type_equality() {
        let single = PatternType::SingleTerm;
        assert_eq!(single, PatternType::SingleTerm);
        assert_ne!(single, PatternType::MultiPattern);
    }

    /// `new` caps output at ten patterns.
    #[test]
    fn test_pattern_generator_creation() {
        let cap = PatternGenerator::new().max_patterns;
        assert_eq!(cap, 10);
    }

    /// Registering a pattern set (even an empty one) records it.
    #[test]
    fn test_multi_pattern_coordinator() {
        let mut coordinator = MultiPatternCoordinator::new();
        coordinator.add_pattern_set(Vec::new());
        assert_eq!(coordinator.pattern_sets.len(), 1);
    }

    /// Patterns built from the same terms are equal.
    #[test]
    fn test_pattern_equality() {
        let id = TermId::new(1);
        assert_eq!(Pattern::new(vec![id]), Pattern::new(vec![id]));
    }
}