use ipfrs_core::Cid;
use serde::{Deserialize, Serialize};
use std::fmt;
/// A term that can appear as an argument of a [`Predicate`].
///
/// Terms are either variables, constants, compound terms (a function symbol
/// applied to argument terms), or references to content identified by a CID.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum Term {
    /// A named variable; rendered as `?name` by `Display`.
    Var(String),
    /// A constant value.
    Const(Constant),
    /// A compound term: function symbol plus its argument terms, rendered as
    /// `name(arg1, arg2, ...)`.
    Fun(String, Vec<Term>),
    /// A reference to content identified by a CID; rendered as `@cid`.
    /// `Term::is_ground` treats it as ground without inspecting the target.
    Ref(TermRef),
}
/// A literal constant value carried by [`Term::Const`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum Constant {
    /// A string literal; `Display` wraps it in double quotes.
    /// NOTE(review): embedded `"` characters are not escaped by `Display`.
    String(String),
    /// A signed 64-bit integer.
    Int(i64),
    /// A boolean.
    Bool(bool),
    /// A floating-point literal kept in textual form. Storing it as a `String`
    /// rather than `f64` is what lets this enum derive `Eq` and `Hash`
    /// (`f64` implements neither).
    Float(String),
}
/// A reference to content addressed by a [`Cid`], with an optional hint.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct TermRef {
    /// Content identifier of the referenced data. Serde goes through the
    /// crate-level `serialize_cid` / `deserialize_cid` helpers for this field.
    #[serde(
        serialize_with = "crate::serialize_cid",
        deserialize_with = "crate::deserialize_cid"
    )]
    pub cid: Cid,
    /// Optional free-form hint attached to the reference; not included in
    /// `Term`'s `Display` output.
    /// NOTE(review): intended meaning of the hint is not visible here — confirm.
    pub hint: Option<String>,
}
impl TermRef {
pub fn new(cid: Cid) -> Self {
Self { cid, hint: None }
}
pub fn with_hint(cid: Cid, hint: String) -> Self {
Self {
cid,
hint: Some(hint),
}
}
}
impl Term {
    /// Returns `true` if this term is a variable.
    #[inline]
    pub fn is_var(&self) -> bool {
        matches!(self, Term::Var(_))
    }

    /// Returns `true` if this term is a constant.
    #[inline]
    pub fn is_const(&self) -> bool {
        matches!(self, Term::Const(_))
    }

    /// Returns `true` if the term contains no variables at any depth.
    ///
    /// References are treated as ground; the referenced content is not inspected.
    #[inline]
    pub fn is_ground(&self) -> bool {
        match self {
            Term::Var(_) => false,
            Term::Const(_) | Term::Ref(_) => true,
            Term::Fun(_, args) => args.iter().all(Term::is_ground),
        }
    }

    /// Returns the distinct variable names occurring in this term, sorted ascending.
    pub fn variables(&self) -> Vec<String> {
        let mut vars = Vec::with_capacity(self.estimate_var_count());
        self.collect_vars(&mut vars);
        vars.sort_unstable();
        vars.dedup();
        vars
    }

    /// Counts variable occurrences (duplicates included); used only as a
    /// capacity hint before collecting.
    fn estimate_var_count(&self) -> usize {
        match self {
            Term::Var(_) => 1,
            Term::Const(_) | Term::Ref(_) => 0,
            Term::Fun(_, args) => args.iter().map(Term::estimate_var_count).sum(),
        }
    }

    /// Appends every variable occurrence (left to right, duplicates included)
    /// to `vars`.
    fn collect_vars(&self, vars: &mut Vec<String>) {
        match self {
            Term::Var(v) => vars.push(v.clone()),
            Term::Fun(_, args) => {
                for arg in args {
                    arg.collect_vars(vars);
                }
            }
            // Listed explicitly instead of `_` so that adding a new variant
            // forces a decision about whether it can contain variables.
            Term::Const(_) | Term::Ref(_) => {}
        }
    }

    /// Number of nodes in the term tree: 1 for each leaf, and 1 plus the
    /// children's complexity for a compound term.
    #[inline]
    pub fn complexity(&self) -> usize {
        match self {
            Term::Var(_) | Term::Const(_) | Term::Ref(_) => 1,
            Term::Fun(_, args) => 1 + args.iter().map(Term::complexity).sum::<usize>(),
        }
    }
}
impl fmt::Display for Term {
    /// Renders variables as `?name`, constants via their own `Display`,
    /// references as `@cid`, and compound terms as `name(a, b, ...)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Term::Var(name) => write!(f, "?{}", name),
            Term::Const(value) => write!(f, "{}", value),
            Term::Ref(reference) => write!(f, "@{}", reference.cid),
            Term::Fun(symbol, args) => {
                write!(f, "{}(", symbol)?;
                let mut rest = args.iter();
                // Separator goes before every argument except the first.
                if let Some(head) = rest.next() {
                    write!(f, "{}", head)?;
                    for arg in rest {
                        write!(f, ", {}", arg)?;
                    }
                }
                f.write_str(")")
            }
        }
    }
}
impl fmt::Display for Constant {
    /// Renders strings in double quotes and all other constants verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // NOTE(review): embedded quotes are written as-is (no escaping).
            Constant::String(text) => write!(f, "\"{}\"", text),
            Constant::Int(n) => write!(f, "{}", n),
            Constant::Bool(flag) => write!(f, "{}", flag),
            // Floats are stored textually, so they are emitted unchanged.
            Constant::Float(text) => f.write_str(text),
        }
    }
}
/// A predicate application `name(arg1, ..., argN)`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Predicate {
    /// The predicate symbol.
    pub name: String,
    /// Argument terms; their count is the predicate's arity.
    pub args: Vec<Term>,
}
impl Predicate {
pub fn new(name: String, args: Vec<Term>) -> Self {
Self { name, args }
}
#[inline]
pub fn arity(&self) -> usize {
self.args.len()
}
#[inline]
pub fn is_ground(&self) -> bool {
self.args.iter().all(|t| t.is_ground())
}
pub fn variables(&self) -> Vec<String> {
let capacity: usize = self.args.iter().map(|t| t.estimate_var_count()).sum();
let mut vars = Vec::with_capacity(capacity);
for arg in &self.args {
arg.collect_vars(&mut vars);
}
vars.sort_unstable();
vars.dedup();
vars
}
}
impl fmt::Display for Predicate {
    /// Renders the predicate as `name(a, b, ...)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.name)?;
        f.write_str("(")?;
        let mut remaining = self.args.iter();
        if let Some(first) = remaining.next() {
            write!(f, "{}", first)?;
            for arg in remaining {
                write!(f, ", {}", arg)?;
            }
        }
        f.write_str(")")
    }
}
/// A rule `head :- body1, ..., bodyN.`; a rule with an empty body is a fact
/// (see `Rule::is_fact`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rule {
    /// The predicate the rule concludes.
    pub head: Predicate,
    /// Body predicates, in the order they are displayed.
    pub body: Vec<Predicate>,
}
impl Rule {
pub fn new(head: Predicate, body: Vec<Predicate>) -> Self {
Self { head, body }
}
pub fn fact(head: Predicate) -> Self {
Self {
head,
body: Vec::new(),
}
}
#[inline]
pub fn is_fact(&self) -> bool {
self.body.is_empty()
}
pub fn variables(&self) -> Vec<String> {
let mut vars = self.head.variables();
for pred in &self.body {
for var in pred.variables() {
if !vars.contains(&var) {
vars.push(var);
}
}
}
vars.sort_unstable();
vars
}
}
impl fmt::Display for Rule {
    /// Renders `head.` for facts and `head :- b1, b2, ....` otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.head)?;
        let mut body = self.body.iter();
        if let Some(first) = body.next() {
            write!(f, " :- {}", first)?;
            for pred in body {
                write!(f, ", {}", pred)?;
            }
        }
        f.write_str(".")
    }
}
/// A collection of facts and rules.
///
/// Both lists are plain vectors kept in insertion order; name lookups in
/// `KnowledgeBase` scan them linearly.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KnowledgeBase {
    /// Stored facts.
    pub facts: Vec<Predicate>,
    /// Stored rules.
    pub rules: Vec<Rule>,
}
impl KnowledgeBase {
pub fn new() -> Self {
Self::default()
}
pub fn add_fact(&mut self, fact: Predicate) {
self.facts.push(fact);
}
pub fn add_rule(&mut self, rule: Rule) {
self.rules.push(rule);
}
#[inline]
pub fn get_predicates(&self, name: &str) -> Vec<&Predicate> {
self.facts.iter().filter(|p| p.name == name).collect()
}
#[inline]
pub fn get_rules(&self, name: &str) -> Vec<&Rule> {
self.rules.iter().filter(|r| r.head.name == name).collect()
}
pub fn stats(&self) -> KnowledgeBaseStats {
KnowledgeBaseStats {
num_facts: self.facts.len(),
num_rules: self.rules.len(),
}
}
}
/// Counts returned by `KnowledgeBase::stats`.
#[derive(Debug, Clone)]
pub struct KnowledgeBaseStats {
    /// Number of stored facts.
    pub num_facts: usize,
    /// Number of stored rules.
    pub num_rules: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a variable term.
    fn var(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    /// Shorthand for a string-constant term.
    fn text(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    #[test]
    fn test_term_creation() {
        let x = var("X");
        assert!(x.is_var());
        assert!(!x.is_ground());

        let alice = text("Alice");
        assert!(alice.is_const());
        assert!(alice.is_ground());
    }

    #[test]
    fn test_predicate() {
        let pred = Predicate::new("parent".to_string(), vec![text("Alice"), var("X")]);
        assert_eq!(pred.arity(), 2);
        assert!(!pred.is_ground());
        assert_eq!(pred.variables(), vec!["X".to_string()]);
    }

    #[test]
    fn test_rule() {
        let parent =
            |a: &str, b: &str| Predicate::new("parent".to_string(), vec![var(a), var(b)]);
        let head = Predicate::new("grandparent".to_string(), vec![var("X"), var("Z")]);
        let rule = Rule::new(head, vec![parent("X", "Y"), parent("Y", "Z")]);

        assert!(!rule.is_fact());
        let expected: Vec<String> = ["X", "Y", "Z"].iter().copied().map(String::from).collect();
        assert_eq!(rule.variables(), expected);
    }

    #[test]
    fn test_knowledge_base() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(Predicate::new(
            "parent".to_string(),
            vec![text("Alice"), text("Bob")],
        ));

        let stats = kb.stats();
        assert_eq!(stats.num_facts, 1);
        assert_eq!(stats.num_rules, 0);
    }
}