use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tensorlogic_ir::TLExpr;
use crate::error::AdapterError;
/// A named predicate whose definition is given by a [`PredicateBody`].
///
/// Instances are usually created with [`CompositePredicate::new`] and checked
/// with [`CompositePredicate::validate`] before use.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompositePredicate {
/// Predicate name; `CompositeRegistry::register` uses it as the lookup key.
pub name: String,
/// Formal parameter names. `validate` rejects duplicate entries.
pub parameters: Vec<String>,
/// Definition of the predicate, expressed over `parameters`.
pub body: PredicateBody,
/// Optional free-form description; `None` unless set explicitly.
pub description: Option<String>,
}
/// The definition of a composite predicate.
///
/// This file only validates and rewrites bodies (see `validate` and
/// `substitute`); it does not evaluate them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PredicateBody {
/// A boxed `TLExpr`. Validation accepts it without inspection and
/// substitution clones it unchanged.
Expression(Box<TLExpr>),
/// A reference to `name` applied to `args`. Each argument must be a
/// parameter of the enclosing predicate or start with `_` to pass
/// validation; substitution replaces arguments that match a parameter.
Reference { name: String, args: Vec<String> },
/// Logical AND over the contained bodies.
And(Vec<PredicateBody>),
/// Logical OR over the contained bodies.
Or(Vec<PredicateBody>),
/// Logical negation of the contained body.
Not(Box<PredicateBody>),
}
/// A collection of [`CompositePredicate`]s indexed by predicate name.
///
/// The `Default` value is an empty registry.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct CompositeRegistry {
// Keyed by `CompositePredicate::name` (see `register`).
predicates: HashMap<String, CompositePredicate>,
}
impl CompositePredicate {
    /// Creates a predicate from its name, formal parameters and body.
    ///
    /// The description is left as `None`. No validation happens here; call
    /// [`CompositePredicate::validate`] to check the definition.
    pub fn new(name: impl Into<String>, parameters: Vec<String>, body: PredicateBody) -> Self {
        let name = name.into();
        Self {
            name,
            parameters,
            body,
            description: None,
        }
    }

    /// Attaches a description, returning the updated predicate (builder style).
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        let text: String = desc.into();
        self.description = Some(text);
        self
    }

    /// Number of formal parameters.
    pub fn arity(&self) -> usize {
        self.parameters.len()
    }

    /// Checks that the definition is well formed.
    ///
    /// # Errors
    ///
    /// * `AdapterError::InvalidParametricType` if a parameter name occurs
    ///   more than once (the first repeated name is reported).
    /// * Any error produced while validating the body against the
    ///   parameter list.
    pub fn validate(&self) -> Result<(), AdapterError> {
        let mut names = std::collections::HashSet::new();
        // `insert` returns false for a name already in the set, so `find`
        // stops at the first repeated parameter.
        let repeated = self.parameters.iter().find(|p| !names.insert(*p));
        if let Some(dup) = repeated {
            return Err(AdapterError::InvalidParametricType(format!(
                "Duplicate parameter '{}' in predicate '{}'",
                dup, self.name
            )));
        }
        self.body.validate(&self.parameters)
    }

    /// Returns the body with every formal parameter replaced by the
    /// corresponding entry of `args` (matched by position).
    ///
    /// # Errors
    ///
    /// `AdapterError::ArityMismatch` if `args` does not have exactly one
    /// entry per parameter; otherwise any error from the substitution.
    pub fn expand(&self, args: &[String]) -> Result<PredicateBody, AdapterError> {
        let expected = self.parameters.len();
        let found = args.len();
        if found != expected {
            return Err(AdapterError::ArityMismatch {
                name: self.name.clone(),
                expected,
                found,
            });
        }

        // parameter name -> actual argument
        let bindings: HashMap<String, String> = self
            .parameters
            .iter()
            .cloned()
            .zip(args.iter().cloned())
            .collect();
        self.body.substitute(&bindings)
    }
}
impl PredicateBody {
fn validate(&self, parameters: &[String]) -> Result<(), AdapterError> {
match self {
PredicateBody::Expression(_) => Ok(()), PredicateBody::Reference { args, .. } => {
for arg in args {
if !parameters.contains(arg) && !arg.starts_with('_') {
return Err(AdapterError::UnboundVariable(arg.clone()));
}
}
Ok(())
}
PredicateBody::And(bodies) | PredicateBody::Or(bodies) => {
for body in bodies {
body.validate(parameters)?;
}
Ok(())
}
PredicateBody::Not(body) => body.validate(parameters),
}
}
fn substitute(
&self,
substitutions: &HashMap<String, String>,
) -> Result<PredicateBody, AdapterError> {
match self {
PredicateBody::Expression(expr) => {
Ok(PredicateBody::Expression(expr.clone()))
}
PredicateBody::Reference { name, args } => {
let new_args = args
.iter()
.map(|arg| {
substitutions
.get(arg)
.cloned()
.unwrap_or_else(|| arg.clone())
})
.collect();
Ok(PredicateBody::Reference {
name: name.clone(),
args: new_args,
})
}
PredicateBody::And(bodies) => {
let new_bodies: Result<Vec<_>, _> =
bodies.iter().map(|b| b.substitute(substitutions)).collect();
Ok(PredicateBody::And(new_bodies?))
}
PredicateBody::Or(bodies) => {
let new_bodies: Result<Vec<_>, _> =
bodies.iter().map(|b| b.substitute(substitutions)).collect();
Ok(PredicateBody::Or(new_bodies?))
}
PredicateBody::Not(body) => Ok(PredicateBody::Not(Box::new(
body.substitute(substitutions)?,
))),
}
}
}
impl CompositeRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Validates `predicate` and stores it under its name.
    ///
    /// If a predicate with the same name is already registered it is
    /// replaced by the new one.
    ///
    /// # Errors
    ///
    /// Returns the validation error and leaves the registry untouched when
    /// `predicate.validate()` fails.
    pub fn register(&mut self, predicate: CompositePredicate) -> Result<(), AdapterError> {
        predicate.validate()?;
        self.predicates.insert(predicate.name.clone(), predicate);
        Ok(())
    }

    /// Looks up a predicate by name.
    pub fn get(&self, name: &str) -> Option<&CompositePredicate> {
        self.predicates.get(name)
    }

    /// Returns `true` if a predicate called `name` is registered.
    pub fn contains(&self, name: &str) -> bool {
        self.predicates.contains_key(name)
    }

    /// Expands the predicate called `name` with the given arguments.
    ///
    /// # Errors
    ///
    /// `AdapterError::PredicateNotFound` if no such predicate is registered;
    /// otherwise any error from `CompositePredicate::expand`.
    pub fn expand(&self, name: &str, args: &[String]) -> Result<PredicateBody, AdapterError> {
        match self.get(name) {
            Some(predicate) => predicate.expand(args),
            None => Err(AdapterError::PredicateNotFound(name.to_string())),
        }
    }

    /// Number of registered predicates.
    pub fn len(&self) -> usize {
        self.predicates.len()
    }

    /// Returns `true` if no predicate is registered.
    pub fn is_empty(&self) -> bool {
        self.predicates.is_empty()
    }

    /// Names of all registered predicates, sorted in ascending order.
    ///
    /// `HashMap` iterates in arbitrary order, so the names are sorted to
    /// give callers a stable, reproducible listing.
    pub fn list_predicates(&self) -> Vec<String> {
        let mut names: Vec<String> = self.predicates.keys().cloned().collect();
        names.sort_unstable();
        names
    }
}
/// A predicate definition with placeholder type and value parameters that
/// `instantiate` turns into a concrete [`CompositePredicate`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredicateTemplate {
/// Template name; the instance name is built from it plus the type arguments.
pub name: String,
/// Placeholder names replaced by the type arguments on instantiation.
pub type_params: Vec<String>,
/// Placeholder names replaced by the value arguments on instantiation.
pub value_params: Vec<String>,
/// Body in which the placeholders above are substituted.
pub body: PredicateBody,
}
impl PredicateTemplate {
    /// Creates a template from its name, type parameters, value parameters
    /// and body. No validation is performed.
    pub fn new(
        name: impl Into<String>,
        type_params: Vec<String>,
        value_params: Vec<String>,
        body: PredicateBody,
    ) -> Self {
        let name = name.into();
        Self {
            name,
            type_params,
            value_params,
            body,
        }
    }

    /// Produces a concrete predicate by binding every type and value
    /// parameter to the supplied arguments (matched by position).
    ///
    /// The instance is named `"<template name><<type args joined by ", ">>"`,
    /// its parameters are `value_args`, and its description is `None`.
    /// When a type parameter and a value parameter share a name, the value
    /// argument is the one used, because value bindings are inserted last.
    ///
    /// # Errors
    ///
    /// `AdapterError::ArityMismatch` if either argument list has the wrong
    /// length (type arguments are checked first); otherwise any error from
    /// substituting into the body.
    pub fn instantiate(
        &self,
        type_args: &[String],
        value_args: &[String],
    ) -> Result<CompositePredicate, AdapterError> {
        let (expected_types, found_types) = (self.type_params.len(), type_args.len());
        if found_types != expected_types {
            return Err(AdapterError::ArityMismatch {
                name: format!("{}[type params]", self.name),
                expected: expected_types,
                found: found_types,
            });
        }

        let (expected_values, found_values) = (self.value_params.len(), value_args.len());
        if found_values != expected_values {
            return Err(AdapterError::ArityMismatch {
                name: format!("{}[value params]", self.name),
                expected: expected_values,
                found: found_values,
            });
        }

        // Type bindings first, then value bindings, so a value binding
        // overrides a type binding with the same key.
        let bindings: HashMap<String, String> = self
            .type_params
            .iter()
            .zip(type_args)
            .chain(self.value_params.iter().zip(value_args))
            .map(|(param, arg)| (param.clone(), arg.clone()))
            .collect();

        Ok(CompositePredicate {
            name: format!("{}<{}>", self.name, type_args.join(", ")),
            parameters: value_args.to_vec(),
            body: self.body.substitute(&bindings)?,
            description: None,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns a list of string literals into owned `String`s.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Builds a `PredicateBody::Reference` to `name` applied to `args`.
    fn reference(name: &str, args: &[&str]) -> PredicateBody {
        PredicateBody::Reference {
            name: name.to_string(),
            args: names(args),
        }
    }

    #[test]
    fn test_composite_predicate_creation() {
        let pred = CompositePredicate::new(
            "friend",
            names(&["x", "y"]),
            reference("knows", &["x", "y"]),
        );

        assert_eq!(pred.name, "friend");
        assert_eq!(pred.arity(), 2);
    }

    #[test]
    fn test_composite_predicate_validation() {
        let valid =
            CompositePredicate::new("test", names(&["x", "y"]), reference("knows", &["x", "y"]));
        assert!(valid.validate().is_ok());

        // The parameter list repeats "x", so validation must fail.
        let invalid =
            CompositePredicate::new("test", names(&["x", "x"]), reference("knows", &["x"]));
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_composite_registry() {
        let mut registry = CompositeRegistry::new();
        let pred = CompositePredicate::new(
            "friend",
            names(&["x", "y"]),
            reference("knows", &["x", "y"]),
        );

        registry.register(pred).expect("unwrap");

        assert!(registry.contains("friend"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_predicate_expansion() {
        let pred = CompositePredicate::new(
            "friend",
            names(&["x", "y"]),
            reference("knows", &["x", "y"]),
        );

        let expanded = pred.expand(&names(&["alice", "bob"])).expect("unwrap");

        if let PredicateBody::Reference { name, args } = expanded {
            assert_eq!(name, "knows");
            assert_eq!(args, names(&["alice", "bob"]));
        } else {
            panic!("Expected Reference");
        }
    }

    #[test]
    fn test_predicate_template() {
        let template = PredicateTemplate::new(
            "related",
            names(&["T"]),
            names(&["x", "y"]),
            reference("connected", &["x", "y"]),
        );

        let instance = template
            .instantiate(&names(&["Person"]), &names(&["a", "b"]))
            .expect("unwrap");

        assert_eq!(instance.name, "related<Person>");
        assert_eq!(instance.parameters, names(&["a", "b"]));
    }

    #[test]
    fn test_composite_and() {
        let body = PredicateBody::And(vec![
            reference("knows", &["x", "y"]),
            reference("trusts", &["x", "y"]),
        ]);
        let pred = CompositePredicate::new("friend", names(&["x", "y"]), body);
        assert!(pred.validate().is_ok());
    }

    #[test]
    fn test_composite_or() {
        let body = PredicateBody::Or(vec![
            reference("colleague", &["x", "y"]),
            reference("friend", &["x", "y"]),
        ]);
        let pred = CompositePredicate::new("connected", names(&["x", "y"]), body);
        assert!(pred.validate().is_ok());
    }

    #[test]
    fn test_composite_not() {
        let body = PredicateBody::Not(Box::new(reference("enemy", &["x", "y"])));
        let pred = CompositePredicate::new("not_enemy", names(&["x", "y"]), body);
        assert!(pred.validate().is_ok());
    }
}