use crate::RuleAtom;
use anyhow::Result;
use std::collections::{HashMap, HashSet, VecDeque};
use tracing::{debug, info};
use super::functions::NodeId;
use super::retenetwork_type::ReteNetwork;
use super::types::{ReteNode, Token};
impl ReteNetwork {
pub(super) fn is_left_token(&self, token: &Token, beta_id: NodeId) -> Result<bool> {
if let Some(ReteNode::Beta {
left_parent,
right_parent,
..
}) = self.nodes.get(&beta_id)
{
println!(
"is_left_token: beta_id={beta_id}, left_parent={left_parent}, right_parent={right_parent}"
);
let has_x = token.bindings.contains_key("X");
let has_z = token.bindings.contains_key("Z");
println!(
"Token bindings: {:?}, has_X: {has_x}, has_Z: {has_z}",
token.bindings
);
if has_x && !has_z {
println!("Token from left pattern (has X, no Z) - returning true");
return Ok(true);
} else if has_z && !has_x {
println!("Token from right pattern (has Z, no X) - returning false");
return Ok(false);
} else {
if let Some(last_fact) = token.facts.last() {
println!("Cannot determine from bindings, checking patterns...");
println!("Checking fact: {last_fact:?}");
let left_pattern = self.nodes.get(left_parent).and_then(|node| {
if let ReteNode::Alpha { pattern, .. } = node {
Some(pattern)
} else {
None
}
});
let right_pattern = self.nodes.get(right_parent).and_then(|node| {
if let ReteNode::Alpha { pattern, .. } = node {
Some(pattern)
} else {
None
}
});
if let (Some(left_pattern), Some(right_pattern)) = (left_pattern, right_pattern)
{
println!("Left pattern: {left_pattern:?}");
println!("Right pattern: {right_pattern:?}");
if (self.unify_atoms(right_pattern, last_fact, &HashMap::new())?).is_some()
{
println!("Fact matches right pattern - returning false");
return Ok(false);
} else if (self.unify_atoms(left_pattern, last_fact, &HashMap::new())?)
.is_some()
{
println!("Fact matches left pattern - returning true");
return Ok(true);
}
}
}
}
}
println!("Cannot determine token side - defaulting to left");
Ok(true)
}
pub(super) fn join_tokens(&self, left_token: &Token, right_token: &Token) -> Result<Token> {
let mut joined_token = Token::new();
joined_token.bindings.extend(left_token.bindings.clone());
for (var, value) in &right_token.bindings {
if let Some(existing_value) = joined_token.bindings.get(var) {
if !self.terms_equal(existing_value, value) {
return Err(anyhow::anyhow!("Binding conflict for variable {}", var));
}
} else {
joined_token.bindings.insert(var.clone(), value.clone());
}
}
joined_token.tags.extend(left_token.tags.clone());
joined_token.tags.extend(right_token.tags.clone());
joined_token.facts.extend(left_token.facts.clone());
joined_token.facts.extend(right_token.facts.clone());
Ok(joined_token)
}
pub(super) fn execute_production(
&self,
rule_name: &str,
rule_head: &[RuleAtom],
token: &Token,
) -> Result<Vec<RuleAtom>> {
if self.debug_mode {
debug!(
"Executing production '{}' with token: {:?}",
rule_name, token
);
}
let mut derived_facts = Vec::new();
for head_atom in rule_head {
let instantiated = self.apply_substitution(head_atom, &token.bindings)?;
derived_facts.push(instantiated);
}
if self.debug_mode && !derived_facts.is_empty() {
debug!(
"Production '{}' derived {} facts using bindings {:?}",
rule_name,
derived_facts.len(),
token.bindings
);
}
Ok(derived_facts)
}
pub fn remove_fact(&mut self, fact: &RuleAtom) -> Result<Vec<RuleAtom>> {
if self.debug_mode {
debug!("Removing fact from RETE network: {:?}", fact);
}
let retracted_facts = Vec::new();
let mut matching_alphas = Vec::new();
for (&node_id, node) in &self.nodes {
if let ReteNode::Alpha { pattern, .. } = node {
if self.unify_atoms(pattern, fact, &HashMap::new())?.is_some() {
matching_alphas.push(node_id);
}
}
}
for alpha_id in matching_alphas {
if let Some(memory) = self.alpha_memory.get_mut(&alpha_id) {
memory.remove(fact);
}
}
Ok(retracted_facts)
}
pub fn get_facts(&self) -> Vec<RuleAtom> {
let mut facts = Vec::new();
for memory in self.alpha_memory.values() {
facts.extend(memory.iter().cloned());
}
facts
}
pub fn forward_chain(&mut self, initial_facts: Vec<RuleAtom>) -> Result<Vec<RuleAtom>> {
info!(
"Starting RETE forward chaining with {} initial facts",
initial_facts.len()
);
let mut all_facts = HashSet::new();
let mut facts_to_process = VecDeque::from(initial_facts);
let mut iteration = 0;
while let Some(fact) = facts_to_process.pop_front() {
if all_facts.contains(&fact) {
continue;
}
all_facts.insert(fact.clone());
iteration += 1;
if self.debug_mode && iteration % 100 == 0 {
debug!(
"RETE iteration {}, {} facts processed",
iteration,
all_facts.len()
);
}
let derived = self.add_fact(fact)?;
for derived_fact in derived {
if !all_facts.contains(&derived_fact) {
facts_to_process.push_back(derived_fact);
}
}
}
info!(
"RETE forward chaining completed after {} iterations, {} facts total",
iteration,
all_facts.len()
);
Ok(all_facts.into_iter().collect())
}
}