use std::collections::HashMap;
/// A working-memory element: a `(subject, predicate, object)` triple of strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WME {
    /// The triple, stored as `[subject, predicate, object]`.
    pub fields: [String; 3],
}

impl WME {
    /// Builds a WME from its subject, predicate and object.
    pub fn new(s: impl Into<String>, p: impl Into<String>, o: impl Into<String>) -> Self {
        let fields = [s.into(), p.into(), o.into()];
        Self { fields }
    }

    /// The subject (`fields[0]`).
    pub fn subject(&self) -> &str {
        self.fields[0].as_str()
    }

    /// The predicate (`fields[1]`).
    pub fn predicate(&self) -> &str {
        self.fields[1].as_str()
    }

    /// The object (`fields[2]`).
    pub fn object(&self) -> &str {
        self.fields[2].as_str()
    }

    /// Returns whether this WME passes `cond` considered on its own.
    ///
    /// Only a `Constant` test can fail: the selected field must equal the constant
    /// exactly. `Variable` and `Any` accept every value; this method does not look at
    /// variable bindings.
    pub fn matches_condition(&self, cond: &Condition) -> bool {
        let slot = match cond.field {
            CondField::Subject => 0,
            CondField::Predicate => 1,
            CondField::Object => 2,
        };
        match &cond.test {
            CondTest::Constant(expected) => self.fields[slot] == *expected,
            CondTest::Variable(_) | CondTest::Any => true,
        }
    }
}
/// Selects which position of a `WME` triple a condition inspects.
///
/// A plain fieldless selector, so it is `Copy` and `Hash` and can be passed by value
/// or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CondField {
    /// The first position of the triple.
    Subject,
    /// The second position of the triple.
    Predicate,
    /// The third position of the triple.
    Object,
}
/// The test a condition applies to the field it selects.
///
/// Derives `Hash` alongside `Eq` so tests can be used as set/map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CondTest {
    /// The field must equal this string exactly.
    Constant(String),
    /// Accepts any value and binds it to the named variable.
    Variable(String),
    /// Accepts any value without binding anything.
    Any,
}
/// A single-field test applied to one position of a WME.
///
/// Both fields implement `PartialEq`/`Eq`, so conditions can be compared directly
/// (e.g. to detect duplicate conditions in a production).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Condition {
    /// Which WME position is inspected.
    pub field: CondField,
    /// What the inspected position must satisfy.
    pub test: CondTest,
}

impl Condition {
    /// A condition requiring `field` to equal `value` exactly.
    pub fn constant(field: CondField, value: impl Into<String>) -> Self {
        Self {
            field,
            test: CondTest::Constant(value.into()),
        }
    }

    /// A condition that accepts any value in `field` and binds it to variable `name`.
    pub fn variable(field: CondField, name: impl Into<String>) -> Self {
        Self {
            field,
            test: CondTest::Variable(name.into()),
        }
    }

    /// A wildcard condition: any value in `field` is accepted and nothing is bound.
    pub fn any(field: CondField) -> Self {
        Self {
            field,
            test: CondTest::Any,
        }
    }
}
/// A rule: an identifier, the conditions that must hold together, and an action string.
#[derive(Debug, Clone)]
pub struct Production {
    /// Identifier reported when the production is triggered.
    pub id: String,
    /// Conditions that must all be satisfied; an empty list is always satisfied.
    pub conditions: Vec<Condition>,
    /// Action text; stored as-is and not read by the matcher.
    pub action: String,
}

impl Production {
    /// Builds a production from its id, conditions and action text.
    pub fn new(
        id: impl Into<String>,
        conditions: Vec<Condition>,
        action: impl Into<String>,
    ) -> Self {
        let id = id.into();
        let action = action.into();
        Self {
            id,
            conditions,
            action,
        }
    }
}
/// A bucket of WMEs; `ReteNetwork` keeps one per condition key.
#[derive(Debug, Clone, Default)]
pub struct AlphaMemory {
    /// Stored WMEs in insertion order; duplicates are kept.
    pub wmes: Vec<WME>,
}

impl AlphaMemory {
    /// Creates an empty alpha memory.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `wme` without any de-duplication.
    pub fn add(&mut self, wme: WME) {
        self.wmes.push(wme)
    }

    /// Removes the first stored WME equal to `wme`.
    ///
    /// Returns `true` if one was found and removed, `false` otherwise.
    pub fn remove(&mut self, wme: &WME) -> bool {
        match self.wmes.iter().position(|stored| stored == wme) {
            Some(index) => {
                self.wmes.remove(index);
                true
            }
            None => false,
        }
    }

    /// Number of stored WMEs, counting duplicates.
    pub fn len(&self) -> usize {
        self.wmes.len()
    }

    /// Whether the memory holds no WMEs.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A partial match: the WMEs matched so far and the variable bindings they produced.
#[derive(Debug, Clone, Default)]
pub struct Token {
    /// Matched WMEs, in the order they were added via `extend`.
    pub wmes: Vec<WME>,
    /// Variable name -> bound value.
    pub bindings: HashMap<String, String>,
}

impl Token {
    /// Creates an empty token (no WMEs, no bindings); same as `Token::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of this token with `wme` appended and `new_bindings` merged in.
    ///
    /// On a name collision the value from `new_bindings` wins, as with
    /// `HashMap::extend`. `self` is left unchanged.
    pub fn extend(&self, wme: WME, new_bindings: HashMap<String, String>) -> Token {
        let mut wmes = self.wmes.clone();
        wmes.push(wme);
        let mut bindings = self.bindings.clone();
        bindings.extend(new_bindings);
        Token { wmes, bindings }
    }
}
/// A list of partial-match tokens.
#[derive(Debug, Clone, Default)]
pub struct BetaMemory {
    /// Stored tokens.
    pub tokens: Vec<Token>,
}

impl BetaMemory {
    /// Creates an empty beta memory.
    pub fn new() -> Self {
        Self::default()
    }
}
/// A forward-chaining matcher over triple-shaped WMEs.
///
/// Productions are evaluated by searching the flat working-memory list; the alpha
/// memories are maintained alongside it and exposed through `get_alpha_memory`.
pub struct ReteNetwork {
    /// Alpha memories, keyed by `alpha_key` of a condition.
    alpha_memories: HashMap<String, AlphaMemory>,
    /// NOTE(review): nothing visible here populates this; it is only cleared.
    /// Confirm whether it is still needed.
    beta_memories: Vec<BetaMemory>,
    /// Registered productions, in registration order.
    productions: Vec<Production>,
    /// Working memory in insertion order; duplicate WMEs are allowed.
    all_wmes: Vec<WME>,
}

impl Default for ReteNetwork {
    /// Same as `ReteNetwork::new`: an empty network.
    fn default() -> Self {
        ReteNetwork::new()
    }
}
impl ReteNetwork {
pub fn new() -> Self {
ReteNetwork {
alpha_memories: HashMap::new(),
beta_memories: Vec::new(),
productions: Vec::new(),
all_wmes: Vec::new(),
}
}
pub fn add_production(&mut self, production: Production) {
self.productions.push(production);
}
pub fn add_wme(&mut self, wme: WME) -> Vec<String> {
self.all_wmes.push(wme.clone());
for prod in &self.productions {
for cond in &prod.conditions {
if wme.matches_condition(cond) {
let key = alpha_key(cond);
self.alpha_memories.entry(key).or_default().add(wme.clone());
}
}
}
self.evaluate_productions()
}
pub fn remove_wme(&mut self, wme: &WME) -> bool {
let pos = self.all_wmes.iter().position(|w| w == wme);
if let Some(p) = pos {
self.all_wmes.remove(p);
for am in self.alpha_memories.values_mut() {
am.remove(wme);
}
true
} else {
false
}
}
pub fn production_count(&self) -> usize {
self.productions.len()
}
pub fn wme_count(&self) -> usize {
self.all_wmes.len()
}
pub fn get_alpha_memory(&self, key: &str) -> Option<&AlphaMemory> {
self.alpha_memories.get(key)
}
pub fn clear(&mut self) {
self.all_wmes.clear();
self.alpha_memories.clear();
self.beta_memories.clear();
}
fn evaluate_productions(&self) -> Vec<String> {
let mut triggered = Vec::new();
'prod: for prod in &self.productions {
if prod.conditions.is_empty() {
triggered.push(prod.id.clone());
continue;
}
let first_cond = &prod.conditions[0];
let candidates = self.wmes_matching_condition(first_cond);
for first_wme in candidates {
if let Some(token) =
self.try_match_conditions(&prod.conditions, 0, Token::new(), first_wme)
{
let _ = token; triggered.push(prod.id.clone());
continue 'prod; }
}
}
triggered
}
fn try_match_conditions(
&self,
conds: &[Condition],
depth: usize,
token: Token,
wme: &WME,
) -> Option<Token> {
if !wme.matches_condition(&conds[depth]) {
return None;
}
let new_bindings = collect_bindings(&conds[depth], wme);
for (var, val) in &new_bindings {
if let Some(existing) = token.bindings.get(var.as_str()) {
if existing != val {
return None; }
}
}
let new_token = token.extend(wme.clone(), new_bindings);
if depth + 1 == conds.len() {
return Some(new_token);
}
let next_cond = &conds[depth + 1];
for candidate in self.wmes_matching_condition(next_cond) {
if let Some(t) =
self.try_match_conditions(conds, depth + 1, new_token.clone(), candidate)
{
return Some(t);
}
}
None
}
fn wmes_matching_condition<'a>(&'a self, cond: &Condition) -> Vec<&'a WME> {
self.all_wmes
.iter()
.filter(|w| w.matches_condition(cond))
.collect()
}
}
/// Builds the string key that identifies a condition's alpha memory.
///
/// Format is `<field>:<test>`. `<field>` is `s`, `p` or `o`. `<test>` is `=<value>` for a
/// constant, `?<name>` for a variable, or `*` for a wildcard. Examples: `s:=alice`,
/// `p:?p`, `o:*`.
fn alpha_key(cond: &Condition) -> String {
    let field = match cond.field {
        CondField::Subject => 's',
        CondField::Predicate => 'p',
        CondField::Object => 'o',
    };
    match &cond.test {
        CondTest::Constant(value) => format!("{field}:={value}"),
        CondTest::Variable(name) => format!("{field}:?{name}"),
        CondTest::Any => format!("{field}:*"),
    }
}
/// Returns the variable bindings `cond` produces for `wme`.
///
/// A `Variable` condition yields one entry: the variable name mapped to the value of
/// the selected field. `Constant` and `Any` conditions yield an empty map.
fn collect_bindings(cond: &Condition, wme: &WME) -> HashMap<String, String> {
    let name = match &cond.test {
        CondTest::Variable(name) => name,
        CondTest::Constant(_) | CondTest::Any => return HashMap::new(),
    };
    let bound = match cond.field {
        CondField::Subject => wme.subject(),
        CondField::Predicate => wme.predicate(),
        CondField::Object => wme.object(),
    };
    let mut bindings = HashMap::new();
    bindings.insert(name.clone(), bound.to_owned());
    bindings
}
#[cfg(test)]
mod tests {
    //! Unit tests for the WME/condition primitives, the memories and `ReteNetwork`.
    use super::*;

    /// Builds a WME from three string slices.
    fn wme(s: &str, p: &str, o: &str) -> WME {
        WME::new(s, p, o)
    }

    /// Builds a constant-test condition.
    fn const_cond(field: CondField, val: &str) -> Condition {
        Condition::constant(field, val)
    }

    /// Builds a variable-binding condition.
    fn var_cond(field: CondField, name: &str) -> Condition {
        Condition::variable(field, name)
    }

    /// Builds a wildcard condition.
    fn any_cond(field: CondField) -> Condition {
        Condition::any(field)
    }

    #[test]
    fn test_wme_fields() {
        let w = wme("alice", "knows", "bob");
        assert_eq!(w.subject(), "alice");
        assert_eq!(w.predicate(), "knows");
        assert_eq!(w.object(), "bob");
        assert_eq!(w.fields[0], "alice");
        assert_eq!(w.fields[1], "knows");
        assert_eq!(w.fields[2], "bob");
    }

    #[test]
    fn test_wme_matches_constant_subject() {
        let w = wme("alice", "knows", "bob");
        let c = const_cond(CondField::Subject, "alice");
        assert!(w.matches_condition(&c));
    }

    #[test]
    fn test_wme_no_match_constant_subject() {
        let w = wme("alice", "knows", "bob");
        let c = const_cond(CondField::Subject, "carol");
        assert!(!w.matches_condition(&c));
    }

    #[test]
    fn test_wme_matches_variable_always() {
        let w = wme("alice", "knows", "bob");
        let c = var_cond(CondField::Subject, "x");
        assert!(w.matches_condition(&c));
    }

    #[test]
    fn test_wme_matches_any_always() {
        let w = wme("alice", "knows", "bob");
        assert!(w.matches_condition(&any_cond(CondField::Subject)));
        assert!(w.matches_condition(&any_cond(CondField::Predicate)));
        assert!(w.matches_condition(&any_cond(CondField::Object)));
    }

    #[test]
    fn test_wme_matches_constant_predicate() {
        let w = wme("a", "type", "Person");
        assert!(w.matches_condition(&const_cond(CondField::Predicate, "type")));
        assert!(!w.matches_condition(&const_cond(CondField::Predicate, "label")));
    }

    #[test]
    fn test_wme_matches_constant_object() {
        let w = wme("a", "type", "Person");
        assert!(w.matches_condition(&const_cond(CondField::Object, "Person")));
        assert!(!w.matches_condition(&const_cond(CondField::Object, "Animal")));
    }

    #[test]
    fn test_new_network_empty() {
        let net = ReteNetwork::new();
        assert_eq!(net.production_count(), 0);
        assert_eq!(net.wme_count(), 0);
    }

    #[test]
    fn test_add_production_increases_count() {
        let mut net = ReteNetwork::new();
        net.add_production(Production::new("p1", vec![], "action1"));
        assert_eq!(net.production_count(), 1);
    }

    #[test]
    fn test_add_wme_increases_count() {
        let mut net = ReteNetwork::new();
        net.add_wme(wme("a", "b", "c"));
        assert_eq!(net.wme_count(), 1);
    }

    #[test]
    fn test_remove_wme_decreases_count() {
        let mut net = ReteNetwork::new();
        let w = wme("a", "b", "c");
        net.add_wme(w.clone());
        assert_eq!(net.wme_count(), 1);
        assert!(net.remove_wme(&w));
        assert_eq!(net.wme_count(), 0);
    }

    #[test]
    fn test_remove_nonexistent_wme_returns_false() {
        let mut net = ReteNetwork::new();
        let w = wme("x", "y", "z");
        assert!(!net.remove_wme(&w));
    }

    #[test]
    fn test_remove_wme_twice_returns_false_second_time() {
        let mut net = ReteNetwork::new();
        let w = wme("a", "b", "c");
        net.add_wme(w.clone());
        assert!(net.remove_wme(&w));
        assert!(!net.remove_wme(&w));
        assert_eq!(net.wme_count(), 0);
    }

    #[test]
    fn test_clear_removes_all_wmes() {
        let mut net = ReteNetwork::new();
        net.add_wme(wme("a", "b", "c"));
        net.add_wme(wme("d", "e", "f"));
        net.clear();
        assert_eq!(net.wme_count(), 0);
    }

    #[test]
    fn test_clear_keeps_productions() {
        let mut net = ReteNetwork::new();
        net.add_production(Production::new("p1", vec![], "a"));
        net.add_wme(wme("a", "b", "c"));
        net.clear();
        assert_eq!(net.production_count(), 1);
        assert_eq!(net.wme_count(), 0);
    }

    #[test]
    fn test_single_cond_production_fires() {
        let mut net = ReteNetwork::new();
        let cond = const_cond(CondField::Predicate, "type");
        net.add_production(Production::new("p_type", vec![cond], "assert type"));
        let triggered = net.add_wme(wme("alice", "type", "Person"));
        assert!(triggered.contains(&"p_type".to_string()));
    }

    #[test]
    fn test_single_cond_production_no_fire_on_mismatch() {
        let mut net = ReteNetwork::new();
        let cond = const_cond(CondField::Predicate, "type");
        net.add_production(Production::new("p_type", vec![cond], "assert type"));
        let triggered = net.add_wme(wme("alice", "label", "Alice"));
        assert!(!triggered.contains(&"p_type".to_string()));
    }

    #[test]
    fn test_variable_cond_always_fires() {
        let mut net = ReteNetwork::new();
        let cond = var_cond(CondField::Subject, "x");
        net.add_production(Production::new("p_any_s", vec![cond], "any subject"));
        let triggered = net.add_wme(wme("whatever", "p", "o"));
        assert!(triggered.contains(&"p_any_s".to_string()));
    }

    #[test]
    fn test_any_cond_always_fires() {
        let mut net = ReteNetwork::new();
        let cond = any_cond(CondField::Predicate);
        net.add_production(Production::new("p_wildcard", vec![cond], "any"));
        let triggered = net.add_wme(wme("s", "anything", "o"));
        assert!(triggered.contains(&"p_wildcard".to_string()));
    }

    #[test]
    fn test_two_cond_production_fires_when_both_satisfied() {
        // A single WME may satisfy several conditions of the same production.
        let mut net = ReteNetwork::new();
        let conds = vec![
            const_cond(CondField::Predicate, "type"),
            const_cond(CondField::Object, "Person"),
        ];
        net.add_production(Production::new("p_person", conds, "is person"));
        let triggered = net.add_wme(wme("alice", "type", "Person"));
        assert!(triggered.contains(&"p_person".to_string()));
    }

    #[test]
    fn test_two_cond_production_needs_two_wmes() {
        let mut net = ReteNetwork::new();
        let conds = vec![
            const_cond(CondField::Predicate, "type"),
            const_cond(CondField::Predicate, "knows"),
        ];
        net.add_production(Production::new("p_knows_person", conds, "action"));
        let t1 = net.add_wme(wme("alice", "type", "Person"));
        assert!(!t1.contains(&"p_knows_person".to_string()));
        let t2 = net.add_wme(wme("alice", "knows", "bob"));
        assert!(t2.contains(&"p_knows_person".to_string()));
    }

    #[test]
    fn test_three_cond_production() {
        let mut net = ReteNetwork::new();
        let conds = vec![
            const_cond(CondField::Predicate, "a"),
            const_cond(CondField::Predicate, "b"),
            const_cond(CondField::Predicate, "c"),
        ];
        net.add_production(Production::new("p_three", conds, "action"));
        net.add_wme(wme("s1", "a", "o1"));
        net.add_wme(wme("s2", "b", "o2"));
        let t3 = net.add_wme(wme("s3", "c", "o3"));
        assert!(t3.contains(&"p_three".to_string()));
    }

    #[test]
    fn test_shared_variable_requires_consistent_binding() {
        // ?x must be the subject of one WME and the object of another (or the same) WME.
        let mut net = ReteNetwork::new();
        let conds = vec![
            var_cond(CondField::Subject, "x"),
            var_cond(CondField::Object, "x"),
        ];
        net.add_production(Production::new("p_loop", conds, "action"));
        let t1 = net.add_wme(wme("alice", "knows", "bob"));
        assert!(!t1.contains(&"p_loop".to_string()));
        let t2 = net.add_wme(wme("bob", "knows", "alice"));
        assert!(t2.contains(&"p_loop".to_string()));
    }

    #[test]
    fn test_multiple_productions_can_fire() {
        let mut net = ReteNetwork::new();
        net.add_production(Production::new(
            "p1",
            vec![const_cond(CondField::Predicate, "type")],
            "a1",
        ));
        net.add_production(Production::new(
            "p2",
            vec![any_cond(CondField::Subject)],
            "a2",
        ));
        let triggered = net.add_wme(wme("alice", "type", "Person"));
        assert!(triggered.contains(&"p1".to_string()));
        assert!(triggered.contains(&"p2".to_string()));
    }

    #[test]
    fn test_only_matching_productions_trigger() {
        let mut net = ReteNetwork::new();
        net.add_production(Production::new(
            "p_type",
            vec![const_cond(CondField::Predicate, "type")],
            "a",
        ));
        net.add_production(Production::new(
            "p_label",
            vec![const_cond(CondField::Predicate, "label")],
            "b",
        ));
        let triggered = net.add_wme(wme("x", "type", "Y"));
        assert!(triggered.contains(&"p_type".to_string()));
        assert!(!triggered.contains(&"p_label".to_string()));
    }

    #[test]
    fn test_alpha_memory_populated_after_add_wme() {
        let mut net = ReteNetwork::new();
        let cond = const_cond(CondField::Predicate, "type");
        net.add_production(Production::new("p", vec![cond.clone()], ""));
        net.add_wme(wme("alice", "type", "Person"));
        let key = super::alpha_key(&cond);
        let am = net
            .get_alpha_memory(&key)
            .expect("alpha memory should exist after a matching WME is added");
        assert_eq!(am.len(), 1);
    }

    #[test]
    fn test_get_alpha_memory_returns_none_for_unknown_key() {
        let net = ReteNetwork::new();
        assert!(net.get_alpha_memory("nonexistent").is_none());
    }

    #[test]
    fn test_remove_wme_from_alpha_memory() {
        let mut net = ReteNetwork::new();
        let cond = const_cond(CondField::Predicate, "type");
        net.add_production(Production::new("p", vec![cond.clone()], ""));
        let w = wme("alice", "type", "Person");
        net.add_wme(w.clone());
        assert!(net.remove_wme(&w));
        let key = super::alpha_key(&cond);
        // The memory was created by the add and must still exist, now empty; the old
        // `if let` form passed vacuously when the memory was missing.
        let am = net
            .get_alpha_memory(&key)
            .expect("alpha memory should still exist after removal");
        assert!(am.is_empty());
    }

    #[test]
    fn test_variable_binding_in_token() {
        let w = wme("alice", "type", "Person");
        let cond = var_cond(CondField::Subject, "x");
        let bindings = super::collect_bindings(&cond, &w);
        assert_eq!(bindings.get("x"), Some(&"alice".to_string()));
    }

    #[test]
    fn test_variable_binding_predicate() {
        let w = wme("a", "knows", "b");
        let cond = var_cond(CondField::Predicate, "pred");
        let bindings = super::collect_bindings(&cond, &w);
        assert_eq!(bindings.get("pred"), Some(&"knows".to_string()));
    }

    #[test]
    fn test_constant_cond_no_bindings() {
        let w = wme("a", "b", "c");
        let cond = const_cond(CondField::Subject, "a");
        let bindings = super::collect_bindings(&cond, &w);
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_no_trigger_when_only_partial_match() {
        let mut net = ReteNetwork::new();
        let conds = vec![
            const_cond(CondField::Predicate, "type"),
            const_cond(CondField::Predicate, "knows"),
        ];
        net.add_production(Production::new("p", conds, "action"));
        let triggered = net.add_wme(wme("alice", "type", "Person"));
        assert!(!triggered.contains(&"p".to_string()));
    }

    #[test]
    fn test_zero_condition_production_fires_on_any_wme() {
        let mut net = ReteNetwork::new();
        net.add_production(Production::new("p_empty", vec![], "always fires"));
        let triggered = net.add_wme(wme("anything", "here", "works"));
        assert!(triggered.contains(&"p_empty".to_string()));
    }

    #[test]
    fn test_alpha_key_constant() {
        let c = const_cond(CondField::Subject, "alice");
        assert_eq!(super::alpha_key(&c), "s:=alice");
    }

    #[test]
    fn test_alpha_key_variable() {
        let c = var_cond(CondField::Predicate, "p");
        assert_eq!(super::alpha_key(&c), "p:?p");
    }

    #[test]
    fn test_alpha_key_any() {
        let c = any_cond(CondField::Object);
        assert_eq!(super::alpha_key(&c), "o:*");
    }

    #[test]
    fn test_alpha_memory_add_remove() {
        let mut am = AlphaMemory::new();
        let w = wme("a", "b", "c");
        am.add(w.clone());
        assert_eq!(am.len(), 1);
        assert!(am.remove(&w));
        assert!(am.is_empty());
    }

    #[test]
    fn test_alpha_memory_remove_missing() {
        let mut am = AlphaMemory::new();
        let w = wme("a", "b", "c");
        assert!(!am.remove(&w));
    }

    #[test]
    fn test_alpha_memory_default() {
        let am = AlphaMemory::default();
        assert!(am.is_empty());
    }

    #[test]
    fn test_token_extend() {
        let t = Token::new();
        let w = wme("s", "p", "o");
        let mut bindings = HashMap::new();
        bindings.insert("x".to_string(), "s".to_string());
        let t2 = t.extend(w, bindings);
        assert_eq!(t2.wmes.len(), 1);
        assert_eq!(t2.bindings.get("x"), Some(&"s".to_string()));
        // `extend` must not modify the original token.
        assert!(t.wmes.is_empty());
        assert!(t.bindings.is_empty());
    }

    #[test]
    fn test_token_default() {
        let t = Token::default();
        assert!(t.wmes.is_empty());
        assert!(t.bindings.is_empty());
    }

    #[test]
    fn test_rete_network_default() {
        let net = ReteNetwork::default();
        assert_eq!(net.production_count(), 0);
    }

    #[test]
    fn test_production_id_preserved() {
        let mut net = ReteNetwork::new();
        let prod = Production::new("unique_id_42", vec![], "act");
        net.add_production(prod);
        assert_eq!(net.productions[0].id, "unique_id_42");
    }

    #[test]
    fn test_add_many_wmes() {
        let mut net = ReteNetwork::new();
        for i in 0..20_u32 {
            net.add_wme(wme(&format!("s{i}"), "pred", &format!("o{i}")));
        }
        assert_eq!(net.wme_count(), 20);
    }

    #[test]
    fn test_production_action_field() {
        let p = Production::new("p1", vec![], "ASSERT(fact)");
        assert_eq!(p.action, "ASSERT(fact)");
    }
}