use crate::forward::Substitution;
use crate::{RuleAtom, Term};
use anyhow::{anyhow, Result};
use std::collections::{HashMap, HashSet, VecDeque};
use std::time::{Duration, Instant};
use tracing::warn;
/// A partial match flowing through the join network, carrying the metadata
/// that memory management (`timestamp`) and conflict resolution (`timestamp`,
/// `priority`, `specificity`) read.
#[derive(Debug, Clone)]
pub struct EnhancedToken {
/// Variable bindings accumulated by this token.
pub bindings: Substitution,
/// Facts carried by this token.
pub facts: Vec<RuleAtom>,
/// Instant associated with the token: `new` stamps the creation time and
/// `merge` keeps the earlier of its two inputs.
pub timestamp: Instant,
/// Priority of the token; `merge` keeps the larger of its two inputs.
pub priority: i32,
/// Specificity of the token; `merge` adds the two inputs together.
pub specificity: usize,
/// Justification entries; `merge` concatenates left entries, then right ones.
pub justification: Vec<String>,
}
impl Default for EnhancedToken {
fn default() -> Self {
Self::new()
}
}
impl EnhancedToken {
    /// Builds an empty token: no bindings, facts or justification, zero
    /// priority and specificity, stamped with the current instant.
    pub fn new() -> Self {
        let created_at = Instant::now();
        Self {
            timestamp: created_at,
            priority: 0,
            specificity: 0,
            bindings: HashMap::new(),
            facts: Vec::new(),
            justification: Vec::new(),
        }
    }

    /// Builds a token like [`EnhancedToken::new`] whose only fact is `fact`.
    pub fn with_fact(fact: RuleAtom) -> Self {
        Self {
            facts: vec![fact],
            ..Self::new()
        }
    }

    /// Combines two tokens into a new one.
    ///
    /// Bindings start from `left`; a binding of `right` is added when the
    /// variable is not bound yet, and must be compatible (per
    /// `terms_compatible`) with the existing value otherwise. On a compatible
    /// overlap the value from `left` is kept.
    ///
    /// Facts and justification are concatenated (left first), the earlier
    /// timestamp and the higher priority win, and specificities are added.
    ///
    /// # Errors
    ///
    /// Returns an error when both tokens bind the same variable to
    /// incompatible terms.
    pub fn merge(left: &Self, right: &Self) -> Result<Self> {
        let mut bindings = left.bindings.clone();
        for (var, value) in &right.bindings {
            match bindings.get(var) {
                Some(existing) if !terms_compatible(existing, value) => {
                    return Err(anyhow!(
                        "Binding conflict for variable {}: {:?} vs {:?}",
                        var,
                        existing,
                        value
                    ));
                }
                // Compatible overlap: the left-hand value stays in place.
                Some(_) => {}
                None => {
                    bindings.insert(var.clone(), value.clone());
                }
            }
        }

        let facts = left.facts.iter().chain(&right.facts).cloned().collect();
        let justification = left
            .justification
            .iter()
            .chain(&right.justification)
            .cloned()
            .collect();

        Ok(Self {
            bindings,
            facts,
            timestamp: left.timestamp.min(right.timestamp),
            priority: left.priority.max(right.priority),
            specificity: left.specificity + right.specificity,
            justification,
        })
    }
}
/// Policy that bounds how many tokens a [`BetaMemory`] retains.
///
/// The policy is applied each time a token is added to the memory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryStrategy {
    /// Never evict anything.
    Unlimited,
    /// Keep at most this many tokens across both sides, evicting from the
    /// front of the larger side (the right side on ties).
    LimitCount(usize),
    /// Drop tokens whose `timestamp` is at least this old.
    LimitAge(Duration),
    /// Keep at most this many tokens across both sides, choosing victims from
    /// the recorded access times.
    LRU(usize),
    /// Evict one token per insertion while fewer than 10% of lookups found a
    /// match and more than 1000 tokens are stored.
    Adaptive,
}
/// Rule used by a join node to pick a single winner when one activation
/// produces more than one joined token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConflictResolution {
    /// Keep the first joined token.
    First,
    /// Keep the token with the latest `timestamp`.
    Recency,
    /// Keep the token with the highest `specificity`.
    Specificity,
    /// Keep the token with the highest `priority`.
    Priority,
    /// Keep the token with the highest combined score:
    /// `priority * 1000 + specificity * 10 - age_in_seconds`.
    Combined,
}
/// Two-sided token store for a join node, with per-variable value indexes and
/// the bookkeeping needed by the configured memory strategy.
#[derive(Debug)]
pub struct BetaMemory {
/// Tokens received on the left input, in insertion order.
left_tokens: VecDeque<EnhancedToken>,
/// Tokens received on the right input, in insertion order.
right_tokens: VecDeque<EnhancedToken>,
/// variable name -> bound value -> positions in `left_tokens`.
left_index: HashMap<String, HashMap<Term, Vec<usize>>>,
/// variable name -> bound value -> positions in `right_tokens`.
right_index: HashMap<String, HashMap<Term, Vec<usize>>>,
/// Policy applied whenever a token is added.
memory_strategy: MemoryStrategy,
/// Last-touched instant per stored token. Keys are integers that encode the
/// side and position of the token; the encoding is private to the impl.
access_times: HashMap<usize, Instant>,
/// Lookup and eviction counters.
stats: MemoryStats,
}
/// Counters describing how a [`BetaMemory`] has been used.
///
/// `Clone` and `PartialEq` let callers snapshot the counters handed out by
/// reference and compare snapshots.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct MemoryStats {
    /// Number of lookups performed.
    pub total_joins: usize,
    /// Number of lookups that returned at least one token.
    pub successful_joins: usize,
    /// Number of tokens evicted by the memory strategy.
    pub evictions: usize,
    /// Largest single-side token count observed right after an insertion
    /// (measured once memory management has run).
    pub peak_size: usize,
}
impl BetaMemory {
pub fn new(strategy: MemoryStrategy) -> Self {
Self {
left_tokens: VecDeque::new(),
right_tokens: VecDeque::new(),
left_index: HashMap::new(),
right_index: HashMap::new(),
memory_strategy: strategy,
access_times: HashMap::new(),
stats: MemoryStats::default(),
}
}
pub fn add_left(&mut self, token: EnhancedToken, join_vars: &[String]) -> usize {
let idx = self.left_tokens.len();
self.left_tokens.push_back(token.clone());
for var in join_vars {
if let Some(value) = token.bindings.get(var) {
self.left_index
.entry(var.clone())
.or_default()
.entry(value.clone())
.or_default()
.push(idx);
}
}
self.access_times.insert(idx, Instant::now());
self.apply_memory_management();
self.stats.peak_size = self.stats.peak_size.max(self.left_tokens.len());
idx
}
pub fn add_right(&mut self, token: EnhancedToken, join_vars: &[String]) -> usize {
let idx = self.right_tokens.len();
self.right_tokens.push_back(token.clone());
for var in join_vars {
if let Some(value) = token.bindings.get(var) {
self.right_index
.entry(var.clone())
.or_default()
.entry(value.clone())
.or_default()
.push(idx);
}
}
self.access_times.insert(idx + 1000000, Instant::now()); self.apply_memory_management();
self.stats.peak_size = self.stats.peak_size.max(self.right_tokens.len());
idx
}
pub fn find_matches_indexed(
&mut self,
token: &EnhancedToken,
is_left: bool,
join_vars: &[String],
) -> Vec<EnhancedToken> {
self.stats.total_joins += 1;
let mut matches = Vec::new();
if join_vars.is_empty() {
let tokens = if is_left {
&self.right_tokens
} else {
&self.left_tokens
};
matches.extend(tokens.iter().cloned());
} else {
let indices = if is_left {
&self.right_index
} else {
&self.left_index
};
let mut candidate_indices = HashSet::new();
for var in join_vars {
if let Some(value) = token.bindings.get(var) {
if let Some(var_index) = indices.get(var) {
if let Some(token_indices) = var_index.get(value) {
if candidate_indices.is_empty() {
candidate_indices.extend(token_indices);
} else {
candidate_indices.retain(|idx| token_indices.contains(idx));
}
}
}
}
}
let tokens = if is_left {
&self.right_tokens
} else {
&self.left_tokens
};
for &idx in &candidate_indices {
if let Some(match_token) = tokens.get(idx) {
matches.push(match_token.clone());
let access_key = if is_left { idx + 1000000 } else { idx };
self.access_times.insert(access_key, Instant::now());
}
}
}
if !matches.is_empty() {
self.stats.successful_joins += 1;
}
matches
}
pub fn memory_strategy(&self) -> &MemoryStrategy {
&self.memory_strategy
}
pub fn set_memory_strategy(&mut self, strategy: MemoryStrategy) {
self.memory_strategy = strategy;
}
fn apply_memory_management(&mut self) {
match self.memory_strategy {
MemoryStrategy::LimitCount(max_count) => {
while self.left_tokens.len() + self.right_tokens.len() > max_count {
self.evict_oldest();
}
}
MemoryStrategy::LimitAge(max_age) => {
let now = Instant::now();
self.left_tokens
.retain(|token| now.duration_since(token.timestamp) < max_age);
self.right_tokens
.retain(|token| now.duration_since(token.timestamp) < max_age);
self.rebuild_indices();
}
MemoryStrategy::LRU(max_count) => {
while self.left_tokens.len() + self.right_tokens.len() > max_count {
self.evict_lru();
}
}
MemoryStrategy::Adaptive => {
let success_rate = if self.stats.total_joins > 0 {
self.stats.successful_joins as f64 / self.stats.total_joins as f64
} else {
1.0
};
if success_rate < 0.1 && self.left_tokens.len() + self.right_tokens.len() > 1000 {
self.evict_oldest();
}
}
MemoryStrategy::Unlimited => {}
}
}
fn evict_oldest(&mut self) {
if self.left_tokens.len() > self.right_tokens.len() {
self.left_tokens.pop_front();
} else if !self.right_tokens.is_empty() {
self.right_tokens.pop_front();
}
self.stats.evictions += 1;
self.rebuild_indices();
}
fn evict_lru(&mut self) {
if let Some((&oldest_key, _)) = self.access_times.iter().min_by_key(|&(_, &time)| time) {
if oldest_key < 1000000 {
if oldest_key < self.left_tokens.len() {
self.left_tokens.remove(oldest_key);
}
} else {
let idx = oldest_key - 1000000;
if idx < self.right_tokens.len() {
self.right_tokens.remove(idx);
}
}
self.stats.evictions += 1;
self.rebuild_indices();
}
}
fn rebuild_indices(&mut self) {
self.left_index.clear();
self.right_index.clear();
self.access_times.clear();
for (idx, token) in self.left_tokens.iter().enumerate() {
for (var, value) in &token.bindings {
self.left_index
.entry(var.clone())
.or_default()
.entry(value.clone())
.or_default()
.push(idx);
}
self.access_times.insert(idx, Instant::now());
}
for (idx, token) in self.right_tokens.iter().enumerate() {
for (var, value) in &token.bindings {
self.right_index
.entry(var.clone())
.or_default()
.entry(value.clone())
.or_default()
.push(idx);
}
self.access_times.insert(idx + 1000000, Instant::now());
}
}
}
/// A join node: pairs tokens arriving from its two parents, filters the pairs
/// with `conditions`, and keeps the tokens it has seen in `memory`.
#[derive(Debug)]
pub struct BetaJoinNode {
/// Identifier of this node (not read by the code in this file).
pub id: usize,
/// Identifier of the left parent; when it equals `right_parent` the node
/// only filters tokens instead of joining them.
pub left_parent: usize,
/// Identifier of the right parent.
pub right_parent: usize,
/// Variables used for the memory index lookup and checked for compatibility
/// when both tokens bind them.
pub join_variables: Vec<String>,
/// Conditions that must all hold for a pair (or a filtered token) to pass.
pub conditions: Vec<JoinCondition>,
/// Storage for the tokens received from both parents.
pub memory: BetaMemory,
/// Presumably identifiers of downstream nodes; not read by the code in this
/// file -- verify against the network builder.
pub children: Vec<usize>,
/// Rule used to pick one result when a join yields several tokens.
pub conflict_resolution: ConflictResolution,
}
/// A predicate evaluated against a (left token, right token) pair.
#[derive(Debug, Clone)]
pub enum JoinCondition {
/// Holds when `left_var` is bound in the left token, `right_var` is bound in
/// the right token, and the two terms are compatible.
VarEquality { left_var: String, right_var: String },
/// Holds when both variables are bound (left token / right token
/// respectively) and the comparison `op` holds between their values.
VarComparison {
left_var: String,
right_var: String,
op: ComparisonOp,
},
/// Holds when `var` is bound (looked up in the left token first, then in the
/// right token) and the comparison `op` holds against `constant`.
VarConstComparison {
var: String,
constant: Term,
op: ComparisonOp,
},
/// Holds when the named builtin predicate accepts the resolved arguments.
Builtin {
predicate: String,
args: Vec<JoinArg>,
},
/// Logical negation of the inner condition.
Not(Box<JoinCondition>),
/// Holds when every inner condition holds (true for an empty list).
And(Vec<JoinCondition>),
/// Holds when at least one inner condition holds (false for an empty list).
Or(Vec<JoinCondition>),
}
/// An argument passed to a builtin predicate.
#[derive(Debug, Clone)]
pub enum JoinArg {
/// Value bound to the named variable in the left token.
LeftVar(String),
/// Value bound to the named variable in the right token.
RightVar(String),
/// A fixed term.
Constant(Term),
}
/// Relational operator used by comparison conditions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComparisonOp {
    /// `left == right`
    Equal,
    /// `left != right`
    NotEqual,
    /// `left < right`
    Less,
    /// `left <= right`
    LessEqual,
    /// `left > right`
    Greater,
    /// `left >= right`
    GreaterEqual,
}
impl BetaJoinNode {
    /// Creates a node with no join variables, conditions or children, and an
    /// empty memory governed by `memory_strategy`.
    pub fn new(
        id: usize,
        left_parent: usize,
        right_parent: usize,
        memory_strategy: MemoryStrategy,
        conflict_resolution: ConflictResolution,
    ) -> Self {
        Self {
            id,
            left_parent,
            right_parent,
            join_variables: Vec::new(),
            conditions: Vec::new(),
            memory: BetaMemory::new(memory_strategy),
            children: Vec::new(),
            conflict_resolution,
        }
    }

    /// A node whose two parents are the same node only filters tokens.
    pub fn is_filter_only(&self) -> bool {
        self.left_parent == self.right_parent
    }

    /// Processes a token arriving from the left (`from_left == true`) or the
    /// right parent.
    ///
    /// Filter-only nodes return the token itself when every condition holds.
    /// Otherwise the token is stored on its own side, matched against the
    /// opposite side, and every pair that joins successfully is collected.
    /// Pairs that fail to join -- including pairs whose condition evaluation
    /// returned an error -- are skipped. When more than one pair joins, the
    /// configured conflict resolution reduces the result to a single token.
    ///
    /// # Errors
    ///
    /// Only the filter-only path propagates condition-evaluation errors.
    pub fn join(&mut self, token: EnhancedToken, from_left: bool) -> Result<Vec<EnhancedToken>> {
        if self.is_filter_only() {
            return self.apply_filter_only(token);
        }

        if from_left {
            self.memory.add_left(token.clone(), &self.join_variables);
        } else {
            self.memory.add_right(token.clone(), &self.join_variables);
        }
        let partners = self
            .memory
            .find_matches_indexed(&token, from_left, &self.join_variables);

        let mut results: Vec<EnhancedToken> = partners
            .iter()
            .filter_map(|partner| {
                // `try_join` always takes the left-side token first.
                let joined = if from_left {
                    self.try_join(&token, partner)
                } else {
                    self.try_join(partner, &token)
                };
                joined.ok()
            })
            .collect();

        if results.len() > 1 {
            results = self.apply_conflict_resolution(results);
        }
        Ok(results)
    }

    /// Evaluates every condition with `token` playing both the left and the
    /// right role; returns the token if all hold, nothing otherwise.
    fn apply_filter_only(&self, token: EnhancedToken) -> Result<Vec<EnhancedToken>> {
        for condition in &self.conditions {
            if !self.evaluate_condition(condition, &token, &token)? {
                return Ok(Vec::new());
            }
        }
        Ok(vec![token])
    }

    /// Merges `left` and `right` if they agree on the join variables they both
    /// bind and satisfy every condition.
    ///
    /// An `Err` is returned both for an ordinary mismatch and for a failure
    /// while evaluating a condition or merging bindings.
    fn try_join(&self, left: &EnhancedToken, right: &EnhancedToken) -> Result<EnhancedToken> {
        for var in &self.join_variables {
            let both_bound = left.bindings.get(var).zip(right.bindings.get(var));
            if let Some((left_val, right_val)) = both_bound {
                if !terms_compatible(left_val, right_val) {
                    return Err(anyhow!("Join variable {} doesn't match", var));
                }
            }
        }
        for condition in &self.conditions {
            if !self.evaluate_condition(condition, left, right)? {
                return Err(anyhow!("Join condition failed"));
            }
        }
        EnhancedToken::merge(left, right)
    }

    /// Evaluates `condition` against the pair (`left`, `right`).
    ///
    /// Conditions that refer to an unbound variable evaluate to `false`.
    ///
    /// # Errors
    ///
    /// Propagates errors from comparison or builtin evaluation.
    // `self` is only threaded through the recursion; the function stays a
    // method so it keeps being invoked on the node.
    #[allow(clippy::only_used_in_recursion)]
    fn evaluate_condition(
        &self,
        condition: &JoinCondition,
        left: &EnhancedToken,
        right: &EnhancedToken,
    ) -> Result<bool> {
        match condition {
            JoinCondition::VarEquality {
                left_var,
                right_var,
            } => Ok(
                match (left.bindings.get(left_var), right.bindings.get(right_var)) {
                    (Some(lv), Some(rv)) => terms_compatible(lv, rv),
                    _ => false,
                },
            ),
            JoinCondition::VarComparison {
                left_var,
                right_var,
                op,
            } => match (left.bindings.get(left_var), right.bindings.get(right_var)) {
                (Some(lv), Some(rv)) => evaluate_comparison(lv, rv, *op),
                _ => Ok(false),
            },
            JoinCondition::VarConstComparison { var, constant, op } => {
                // The left token takes precedence when both sides bind `var`.
                match left.bindings.get(var).or_else(|| right.bindings.get(var)) {
                    Some(val) => evaluate_comparison(val, constant, *op),
                    None => Ok(false),
                }
            }
            JoinCondition::Builtin { predicate, args } => {
                evaluate_builtin(predicate, args, left, right)
            }
            JoinCondition::Not(cond) => Ok(!self.evaluate_condition(cond, left, right)?),
            JoinCondition::And(conds) => {
                for cond in conds {
                    if !self.evaluate_condition(cond, left, right)? {
                        return Ok(false);
                    }
                }
                Ok(true)
            }
            JoinCondition::Or(conds) => {
                for cond in conds {
                    if self.evaluate_condition(cond, left, right)? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
        }
    }

    /// Reduces `tokens` to at most one winner according to the node's
    /// conflict-resolution rule. Ties go to the earliest token in `tokens`.
    fn apply_conflict_resolution(&self, tokens: Vec<EnhancedToken>) -> Vec<EnhancedToken> {
        use std::cmp::Reverse;

        // `min_by_key` returns the FIRST minimum, so wrapping the key in
        // `Reverse` selects the first maximum -- the element a stable
        // descending sort would have placed at the front.
        let mut candidates = tokens.into_iter();
        let winner = match self.conflict_resolution {
            ConflictResolution::First => candidates.next(),
            ConflictResolution::Recency => candidates.min_by_key(|t| Reverse(t.timestamp)),
            ConflictResolution::Specificity => candidates.min_by_key(|t| Reverse(t.specificity)),
            ConflictResolution::Priority => candidates.min_by_key(|t| Reverse(t.priority)),
            ConflictResolution::Combined => {
                // Read the clock once so every token is scored against the
                // same instant.
                let now = Instant::now();
                candidates.min_by_key(|t| Reverse(Self::combined_score(t, now)))
            }
        };
        winner.into_iter().collect()
    }

    /// Score used by [`ConflictResolution::Combined`]:
    /// `priority * 1000 + specificity * 10 - age_in_seconds`, computed in
    /// saturating 64-bit arithmetic so extreme values cannot overflow.
    fn combined_score(token: &EnhancedToken, now: Instant) -> i64 {
        // Clamp before casting so the conversions to `i64` cannot wrap.
        let age_secs = now
            .saturating_duration_since(token.timestamp)
            .as_secs()
            .min(i64::MAX as u64) as i64;
        let specificity = token.specificity.min(i64::MAX as usize) as i64;

        i64::from(token.priority)
            .saturating_mul(1000)
            .saturating_add(specificity.saturating_mul(10))
            .saturating_sub(age_secs)
    }

    /// Returns the usage counters of the node's memory.
    pub fn get_stats(&self) -> &MemoryStats {
        &self.memory.stats
    }
}
/// Reports whether two terms may stand for the same value.
///
/// A variable is compatible with anything. Constants, literals and function
/// terms are compatible only with an equal term of the same kind; terms of
/// different kinds never are.
fn terms_compatible(t1: &Term, t2: &Term) -> bool {
    match (t1, t2) {
        (Term::Variable(_), _) | (_, Term::Variable(_)) => true,
        (Term::Constant(c1), Term::Constant(c2)) => c1 == c2,
        (Term::Literal(l1), Term::Literal(l2)) => l1 == l2,
        // Identical function terms used to fall through to `false`, which made
        // a token incompatible with an exact copy of itself.
        (Term::Function { .. }, Term::Function { .. }) => t1 == t2,
        _ => false,
    }
}
/// Compares two terms with `op`.
///
/// When both terms parse as numbers (NaN excluded) they are compared
/// numerically; two numbers are considered equal when they are identical
/// (which covers equal infinities) or differ by less than `f64::EPSILON`, and
/// the ordering operators are consistent with that notion of equality.
/// Otherwise the terms are compared by their string rendering.
///
/// # Errors
///
/// Currently never fails; the `Result` is part of the established signature.
fn evaluate_comparison(left: &Term, right: &Term, op: ComparisonOp) -> Result<bool> {
    // NaN has no meaningful order, so such terms are compared as text instead.
    let as_number = |term: &Term| parse_numeric(term).filter(|n| !n.is_nan());

    if let (Some(l), Some(r)) = (as_number(left), as_number(right)) {
        // `l == r` first: `inf - inf` is NaN and would fail the tolerance test.
        let equal = l == r || (l - r).abs() < f64::EPSILON;
        return Ok(match op {
            ComparisonOp::Equal => equal,
            ComparisonOp::NotEqual => !equal,
            ComparisonOp::Less => !equal && l < r,
            ComparisonOp::LessEqual => equal || l < r,
            ComparisonOp::Greater => !equal && l > r,
            ComparisonOp::GreaterEqual => equal || l > r,
        });
    }

    let left_str = term_to_string(left);
    let right_str = term_to_string(right);
    Ok(match op {
        ComparisonOp::Equal => left_str == right_str,
        ComparisonOp::NotEqual => left_str != right_str,
        ComparisonOp::Less => left_str < right_str,
        ComparisonOp::LessEqual => left_str <= right_str,
        ComparisonOp::Greater => left_str > right_str,
        ComparisonOp::GreaterEqual => left_str >= right_str,
    })
}
/// Parses the text of a literal or constant term as an `f64`.
///
/// Returns `None` for any other kind of term and for text that is not a
/// valid floating-point number.
fn parse_numeric(term: &Term) -> Option<f64> {
    let text = match term {
        Term::Literal(text) | Term::Constant(text) => text,
        _ => return None,
    };
    text.parse::<f64>().ok()
}
/// Renders a term as text: variables as `?name`, constants and literals as
/// their raw text, and function terms as `name(arg1,arg2,...)` with the
/// arguments rendered recursively.
fn term_to_string(term: &Term) -> String {
    match term {
        Term::Variable(v) => format!("?{v}"),
        Term::Constant(text) | Term::Literal(text) => text.clone(),
        Term::Function { name, args } => {
            let rendered = args
                .iter()
                .map(term_to_string)
                .collect::<Vec<String>>()
                .join(",");
            format!("{}({})", name, rendered)
        }
    }
}
/// Evaluates the builtin `predicate` over `args` resolved against the two
/// tokens.
///
/// Supported predicates (each uses only its leading arguments and evaluates
/// to `false` when given too few):
///
/// * `regex` -- both arguments must be literals; true when the second, taken
///   as a regular expression, matches the first.
/// * `contains` / `starts_with` -- compare the string renderings of the first
///   two arguments.
/// * `numeric_add` -- true when the first two arguments sum to the third
///   (within `f64::EPSILON`); false when any of them is not numeric.
///
/// Any argument that refers to an unbound variable makes the result `false`.
/// Unknown predicates log a warning and evaluate to `false`.
///
/// # Errors
///
/// Returns an error when the `regex` pattern does not compile.
fn evaluate_builtin(
    predicate: &str,
    args: &[JoinArg],
    left: &EnhancedToken,
    right: &EnhancedToken,
) -> Result<bool> {
    // Resolve by reference: collecting into `Option<Vec<_>>` yields `None` as
    // soon as one variable is unbound, without cloning any term.
    let resolved: Option<Vec<&Term>> = args
        .iter()
        .map(|arg| match arg {
            JoinArg::LeftVar(var) => left.bindings.get(var),
            JoinArg::RightVar(var) => right.bindings.get(var),
            JoinArg::Constant(term) => Some(term),
        })
        .collect();
    let values = match resolved {
        Some(values) => values,
        None => return Ok(false),
    };

    match predicate {
        "regex" => match values.as_slice() {
            [Term::Literal(text), Term::Literal(pattern), ..] => {
                let re = regex::Regex::new(pattern).map_err(|e| anyhow!("Invalid regex: {}", e))?;
                Ok(re.is_match(text))
            }
            _ => Ok(false),
        },
        "contains" => Ok(match values.as_slice() {
            [haystack, needle, ..] => term_to_string(haystack).contains(&term_to_string(needle)),
            _ => false,
        }),
        "starts_with" => Ok(match values.as_slice() {
            [text, prefix, ..] => term_to_string(text).starts_with(&term_to_string(prefix)),
            _ => false,
        }),
        "numeric_add" => Ok(match values.as_slice() {
            [first, second, total, ..] => {
                match (
                    parse_numeric(first),
                    parse_numeric(second),
                    parse_numeric(total),
                ) {
                    (Some(n1), Some(n2), Some(result)) => {
                        (n1 + n2 - result).abs() < f64::EPSILON
                    }
                    _ => false,
                }
            }
            _ => false,
        }),
        _ => {
            warn!("Unknown builtin predicate: {}", predicate);
            Ok(false)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a constant term.
    fn constant(text: &str) -> Term {
        Term::Constant(text.to_string())
    }

    /// Shorthand for a literal term.
    fn literal(text: &str) -> Term {
        Term::Literal(text.to_string())
    }

    /// Builds a token from `(variable, term)` pairs.
    fn token_with(bindings: Vec<(&str, Term)>) -> EnhancedToken {
        let mut token = EnhancedToken::new();
        for (var, term) in bindings {
            token.bindings.insert(var.to_string(), term);
        }
        token
    }

    // Merging keeps the bindings of both sides and accepts an agreeing overlap.
    #[test]
    fn test_enhanced_token_merge() -> Result<(), Box<dyn std::error::Error>> {
        let token1 = token_with(vec![("X", constant("a")), ("Y", constant("b"))]);
        let token2 = token_with(vec![("Y", constant("b")), ("Z", constant("c"))]);

        let merged = EnhancedToken::merge(&token1, &token2)?;

        assert_eq!(merged.bindings.len(), 3);
        assert_eq!(merged.bindings.get("X"), Some(&constant("a")));
        assert_eq!(merged.bindings.get("Y"), Some(&constant("b")));
        assert_eq!(merged.bindings.get("Z"), Some(&constant("c")));
        Ok(())
    }

    // A right-side probe finds only the left token bound to the same value.
    #[test]
    fn test_beta_memory_indexed_lookup() {
        let join_vars = ["X".to_string()];
        let mut memory = BetaMemory::new(MemoryStrategy::Unlimited);
        memory.add_left(token_with(vec![("X", constant("a"))]), &join_vars);
        memory.add_left(token_with(vec![("X", constant("b"))]), &join_vars);

        let search_token = token_with(vec![("X", constant("a"))]);
        let matches = memory.find_matches_indexed(&search_token, false, &join_vars);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].bindings.get("X"), Some(&constant("a")));
    }

    // A count limit keeps the memory within bounds and records evictions.
    #[test]
    fn test_memory_eviction_strategies() {
        let join_vars = ["X".to_string()];
        let mut memory = BetaMemory::new(MemoryStrategy::LimitCount(2));
        for i in 0..5 {
            let token = token_with(vec![("X", constant(&i.to_string()))]);
            memory.add_left(token, &join_vars);
        }

        assert!(memory.left_tokens.len() <= 2);
        assert!(memory.stats.evictions > 0);
    }

    // Numeric literals are compared as numbers: 5 < 10.
    #[test]
    fn test_join_conditions() -> Result<(), Box<dyn std::error::Error>> {
        let node = BetaJoinNode::new(
            1,
            0,
            0,
            MemoryStrategy::Unlimited,
            ConflictResolution::First,
        );
        let left = token_with(vec![("X", literal("5"))]);
        let right = token_with(vec![("Y", literal("10"))]);
        let cond = JoinCondition::VarComparison {
            left_var: "X".to_string(),
            right_var: "Y".to_string(),
            op: ComparisonOp::Less,
        };

        assert!(node.evaluate_condition(&cond, &left, &right)?);
        Ok(())
    }

    // The `regex` and `contains` builtins both accept "hello world".
    #[test]
    fn test_builtin_evaluation() -> Result<(), Box<dyn std::error::Error>> {
        let left = token_with(vec![("text", literal("hello world"))]);
        let right = EnhancedToken::new();

        let regex_args = vec![
            JoinArg::LeftVar("text".to_string()),
            JoinArg::Constant(literal("hello.*")),
        ];
        assert!(evaluate_builtin("regex", &regex_args, &left, &right)?);

        let contains_args = vec![
            JoinArg::LeftVar("text".to_string()),
            JoinArg::Constant(literal("world")),
        ];
        assert!(evaluate_builtin("contains", &contains_args, &left, &right)?);
        Ok(())
    }

    // The strategy getter reflects whatever the setter stored last.
    #[test]
    fn test_memory_strategy_getter_setter() {
        let mut memory = BetaMemory::new(MemoryStrategy::Unlimited);
        assert!(matches!(
            memory.memory_strategy(),
            &MemoryStrategy::Unlimited
        ));

        memory.set_memory_strategy(MemoryStrategy::LimitCount(100));
        assert!(matches!(
            memory.memory_strategy(),
            &MemoryStrategy::LimitCount(100)
        ));

        memory.set_memory_strategy(MemoryStrategy::LRU(50));
        assert!(matches!(memory.memory_strategy(), &MemoryStrategy::LRU(50)));
    }
}