use super::Catalog;
use crate::catalog::{AtomArgumentSignature, AtomSignature, CatalogError};
use crate::parser::{Atom, AtomArg, FlowLogRule, Predicate};
impl Catalog {
    /// Renames the RHS atom identified by `atom_signature`.
    ///
    /// The replacement atom is named `new_atom_name`, carries
    /// `new_atom_fingerprint`, keeps the original argument list unchanged and
    /// keeps the original polarity (a positive atom stays positive, a negated
    /// atom stays negated). The rebuilt rule is installed via `update_rule`.
    ///
    /// # Errors
    ///
    /// Returns an internal `CatalogError` if the targeted RHS predicate is not
    /// an atom, or propagates any error from `update_rule`.
    pub(crate) fn map_modify(
        &mut self,
        atom_signature: AtomSignature,
        new_atom_name: String,
        new_atom_fingerprint: u64,
    ) -> Result<(), CatalogError> {
        let rhs_index = self.rhs_index_from_signature(atom_signature);
        let new_atom = match &self.rule.rhs()[rhs_index] {
            Predicate::PositiveAtom(atom) => Predicate::PositiveAtom(Atom::new(
                &new_atom_name,
                atom.arguments().to_vec(),
                new_atom_fingerprint,
            )),
            Predicate::NegativeAtom(atom) => Predicate::NegativeAtom(Atom::new(
                &new_atom_name,
                atom.arguments().to_vec(),
                new_atom_fingerprint,
            )),
            other => {
                return Err(CatalogError::internal(format!(
                    "map_modify: target predicate at rhs index {rhs_index} is not an atom: {other}"
                )));
            }
        };
        self.update_rule_in_place(rhs_index, new_atom)
    }

    /// Drops the arguments listed in `arguments_to_delete` from the RHS atom
    /// identified by `atom_signature`, and renames the atom.
    ///
    /// The remaining arguments keep their relative order. The result is named
    /// `new_atom_name`, carries `new_atom_fingerprint` and keeps the original
    /// polarity. An argument listed more than once is deleted only once.
    ///
    /// # Errors
    ///
    /// Returns an internal `CatalogError` if:
    /// - any argument signature belongs to a different atom;
    /// - any argument id is out of bounds for the atom's arity;
    /// - the targeted RHS predicate is not an atom.
    ///
    /// Also propagates any error from `update_rule`.
    pub(crate) fn projection_modify(
        &mut self,
        atom_signature: AtomSignature,
        arguments_to_delete: Vec<AtomArgumentSignature>,
        new_atom_name: String,
        new_atom_fingerprint: u64,
    ) -> Result<(), CatalogError> {
        for arg_sig in &arguments_to_delete {
            if *arg_sig.atom_signature() != atom_signature {
                return Err(CatalogError::internal(format!(
                    "projection_modify: argument signature {arg_sig:?} does not belong \
                     to target atom {atom_signature:?}"
                )));
            }
        }
        let rhs_index = self.rhs_index_from_signature(atom_signature);
        let mut arg_ids_to_delete: Vec<usize> = arguments_to_delete
            .iter()
            .map(|s| s.argument_id())
            .collect();
        // Remove from the highest id downwards so earlier removals do not
        // shift the positions of ids still pending. Dedup so that a repeated
        // id deletes one argument. Without it, the repeat would also delete
        // the argument that slid into its slot, or remove past the end.
        arg_ids_to_delete.sort_unstable_by(|a, b| b.cmp(a));
        arg_ids_to_delete.dedup();
        let build_projected_atom = |atom: &Atom| -> Result<Atom, CatalogError> {
            for &arg_id in &arg_ids_to_delete {
                if arg_id >= atom.arity() {
                    return Err(CatalogError::internal(format!(
                        "projection_modify: argument id {arg_id} out of bounds for atom \
                         `{}` with arity {}",
                        atom.name(),
                        atom.arity()
                    )));
                }
            }
            let mut new_args = atom.arguments().to_vec();
            for &arg_id in &arg_ids_to_delete {
                new_args.remove(arg_id);
            }
            Ok(Atom::new(&new_atom_name, new_args, new_atom_fingerprint))
        };
        let new_atom = match &self.rule.rhs()[rhs_index] {
            Predicate::PositiveAtom(atom) => Predicate::PositiveAtom(build_projected_atom(atom)?),
            Predicate::NegativeAtom(atom) => Predicate::NegativeAtom(build_projected_atom(atom)?),
            other => {
                return Err(CatalogError::internal(format!(
                    "projection_modify: target predicate at rhs index {rhs_index} \
                     is not an atom: {other}"
                )));
            }
        };
        self.update_rule_in_place(rhs_index, new_atom)
    }

    /// Replaces the positive RHS atom identified by `right_atom_signature`
    /// with a new positive atom.
    ///
    /// The new atom is named `new_atom_name` and carries
    /// `new_atom_fingerprint`. Its arguments are the variables that
    /// `signature_to_argument_str_map` associates with each entry of
    /// `new_argument_list`, in order.
    ///
    /// # Errors
    ///
    /// Returns an internal `CatalogError` if:
    /// - the target is not a positive atom;
    /// - any argument signature is missing from the signature map.
    ///
    /// Also propagates any error from `update_rule`.
    pub(crate) fn sip_modify(
        &mut self,
        right_atom_signature: AtomSignature,
        new_argument_list: Vec<AtomArgumentSignature>,
        new_atom_name: String,
        new_atom_fingerprint: u64,
    ) -> Result<(), CatalogError> {
        let rhs_index = self.rhs_index_from_signature(right_atom_signature);
        if !matches!(self.rule.rhs()[rhs_index], Predicate::PositiveAtom(_)) {
            return Err(CatalogError::internal(format!(
                "sip_modify: target predicate at rhs index {rhs_index} is not a positive atom: {}",
                self.rule.rhs()[rhs_index]
            )));
        }
        let new_atom_args = self.lookup_arg_vars(&new_argument_list, "sip_modify")?;
        let new_atom = Atom::new(&new_atom_name, new_atom_args, new_atom_fingerprint);
        self.update_rule_in_place(rhs_index, Predicate::PositiveAtom(new_atom))
    }

    /// Removes the left atom from the RHS and rewrites each right atom.
    ///
    /// For position `k`, the atom identified by `right_atom_signatures[k]` is
    /// replaced by a new *positive* atom with these parts:
    /// - arguments: the variables mapped from `new_arguments_list[k]`;
    /// - name: `new_names[k]`;
    /// - fingerprint: `new_fingerprints[k]`.
    ///
    /// # Errors
    ///
    /// Returns an internal `CatalogError` if:
    /// - the four parallel parameter lists differ in length;
    /// - the left or any right predicate is not an atom;
    /// - an argument signature is missing from the signature map;
    /// - the left atom is also listed among the right atoms.
    ///
    /// Also propagates any error from `update_rule`.
    pub(crate) fn join_modify(
        &mut self,
        left_atom_signature: AtomSignature,
        right_atom_signatures: Vec<AtomSignature>,
        new_arguments_list: Vec<Vec<AtomArgumentSignature>>,
        new_names: Vec<String>,
        new_fingerprints: Vec<u64>,
    ) -> Result<(), CatalogError> {
        let num_right_atoms = right_atom_signatures.len();
        if new_arguments_list.len() != num_right_atoms
            || new_names.len() != num_right_atoms
            || new_fingerprints.len() != num_right_atoms
        {
            return Err(CatalogError::internal(format!(
                "join_modify: parameter length mismatch — right_atom_signatures={}, \
                 new_arguments_list={}, new_names={}, new_fingerprints={}",
                num_right_atoms,
                new_arguments_list.len(),
                new_names.len(),
                new_fingerprints.len()
            )));
        }
        let left_rhs_index = self.rhs_index_from_signature(left_atom_signature);
        match &self.rule.rhs()[left_rhs_index] {
            Predicate::PositiveAtom(_) | Predicate::NegativeAtom(_) => {}
            other => {
                return Err(CatalogError::internal(format!(
                    "join_modify: left predicate at rhs index {left_rhs_index} \
                     is not an atom: {other}"
                )));
            }
        }
        let right_indices =
            self.validate_atom_rhs_indices(&right_atom_signatures, "join_modify")?;
        let new_joined_atoms = new_arguments_list
            .iter()
            .zip(&new_names)
            .zip(&new_fingerprints)
            .map(|((arg_sigs, name), &fingerprint)| {
                let new_atom_args = self.lookup_arg_vars(arg_sigs, "join_modify")?;
                Ok(Predicate::PositiveAtom(Atom::new(
                    name,
                    new_atom_args,
                    fingerprint,
                )))
            })
            .collect::<Result<Vec<_>, CatalogError>>()?;
        self.remove_and_update_rule(left_rhs_index, right_indices, new_joined_atoms)
    }

    /// Removes a comparison predicate from the RHS and renames the right atoms.
    ///
    /// The removed predicate is `self.comparison_predicates[comparison_index]`.
    /// Each atom named by `right_atom_signatures` is replaced with a renamed
    /// copy: same arguments, name `new_names[k]`, fingerprint
    /// `new_fingerprints[k]`.
    ///
    /// NOTE(review): presumably the renamed relations already have the
    /// comparison applied elsewhere in the planner. Confirm against the caller.
    ///
    /// # Errors
    ///
    /// Returns an internal `CatalogError` if:
    /// - the parallel lists differ in length;
    /// - `comparison_index` is out of bounds;
    /// - the comparison predicate is not present in the RHS;
    /// - any right predicate is not an atom.
    ///
    /// Also propagates any error from `update_rule`.
    pub(crate) fn comparison_modify(
        &mut self,
        comparison_index: usize,
        right_atom_signatures: Vec<AtomSignature>,
        new_names: Vec<String>,
        new_fingerprints: Vec<u64>,
    ) -> Result<(), CatalogError> {
        let num_atoms = right_atom_signatures.len();
        if new_names.len() != num_atoms || new_fingerprints.len() != num_atoms {
            return Err(CatalogError::internal(format!(
                "comparison_modify: parameter length mismatch — right_atom_signatures={}, \
                 new_names={}, new_fingerprints={}",
                num_atoms,
                new_names.len(),
                new_fingerprints.len()
            )));
        }
        // Checked lookup: a bad index is reported as an internal error
        // instead of panicking, matching every other failure path here.
        let comparison_predicate = self
            .comparison_predicates
            .get(comparison_index)
            .ok_or_else(|| {
                CatalogError::internal(format!(
                    "comparison_modify: comparison index {comparison_index} out of bounds \
                     ({} comparison predicates)",
                    self.comparison_predicates.len()
                ))
            })?;
        let comparison_rhs_index = self
            .rule
            .rhs()
            .iter()
            .enumerate()
            .find_map(|(idx, p)| match p {
                Predicate::Compare(expr) if expr == comparison_predicate => Some(idx),
                _ => None,
            })
            .ok_or_else(|| {
                CatalogError::internal(format!(
                    "comparison_modify: comparison predicate at index {comparison_index} \
                     not found in rule RHS"
                ))
            })?;
        let right_indices =
            self.validate_atom_rhs_indices(&right_atom_signatures, "comparison_modify")?;
        let new_filtered_atoms = self.build_renamed_atom_copies(
            &right_indices,
            &new_names,
            &new_fingerprints,
            "comparison_modify",
        )?;
        self.remove_and_update_rule(comparison_rhs_index, right_indices, new_filtered_atoms)
    }

    /// Removes a function-call predicate from the RHS and renames the right atoms.
    ///
    /// The removed predicate is `self.fn_call_predicates[fn_call_index]`. Each
    /// atom named by `right_atom_signatures` is replaced with a renamed copy:
    /// same arguments, name `new_names[k]`, fingerprint `new_fingerprints[k]`.
    ///
    /// NOTE(review): presumably the renamed relations already have the
    /// function call applied elsewhere in the planner. Confirm against the
    /// caller.
    ///
    /// # Errors
    ///
    /// Returns an internal `CatalogError` if:
    /// - the parallel lists differ in length;
    /// - `fn_call_index` is out of bounds;
    /// - the fn-call predicate is not present in the RHS;
    /// - any right predicate is not an atom.
    ///
    /// Also propagates any error from `update_rule`.
    pub(crate) fn fn_call_modify(
        &mut self,
        fn_call_index: usize,
        right_atom_signatures: Vec<AtomSignature>,
        new_names: Vec<String>,
        new_fingerprints: Vec<u64>,
    ) -> Result<(), CatalogError> {
        let num_atoms = right_atom_signatures.len();
        if new_names.len() != num_atoms || new_fingerprints.len() != num_atoms {
            return Err(CatalogError::internal(format!(
                "fn_call_modify: parameter length mismatch — right_atom_signatures={}, \
                 new_names={}, new_fingerprints={}",
                num_atoms,
                new_names.len(),
                new_fingerprints.len()
            )));
        }
        // Checked lookup: a bad index is reported as an internal error
        // instead of panicking, matching every other failure path here.
        let fn_call_predicate = self
            .fn_call_predicates
            .get(fn_call_index)
            .ok_or_else(|| {
                CatalogError::internal(format!(
                    "fn_call_modify: fn_call index {fn_call_index} out of bounds \
                     ({} fn_call predicates)",
                    self.fn_call_predicates.len()
                ))
            })?;
        let fn_call_rhs_index = self
            .rule
            .rhs()
            .iter()
            .enumerate()
            .find_map(|(idx, p)| match p {
                Predicate::FnCall(fc) if fc == fn_call_predicate => Some(idx),
                _ => None,
            })
            .ok_or_else(|| {
                CatalogError::internal(format!(
                    "fn_call_modify: fn_call predicate at index {fn_call_index} \
                     not found in rule RHS"
                ))
            })?;
        let right_indices =
            self.validate_atom_rhs_indices(&right_atom_signatures, "fn_call_modify")?;
        let new_filtered_atoms = self.build_renamed_atom_copies(
            &right_indices,
            &new_names,
            &new_fingerprints,
            "fn_call_modify",
        )?;
        self.remove_and_update_rule(fn_call_rhs_index, right_indices, new_filtered_atoms)
    }

    /// Resolves each atom signature to its RHS index.
    ///
    /// Each resolved predicate must be an atom, either positive or negated.
    /// `context` prefixes any error message.
    fn validate_atom_rhs_indices(
        &self,
        signatures: &[AtomSignature],
        context: &str,
    ) -> Result<Vec<usize>, CatalogError> {
        signatures
            .iter()
            .map(|&sig| {
                let idx = self.rhs_index_from_signature(sig);
                match &self.rule.rhs()[idx] {
                    Predicate::PositiveAtom(_) | Predicate::NegativeAtom(_) => Ok(idx),
                    other => Err(CatalogError::internal(format!(
                        "{context}: right predicate at rhs index {idx} is not an atom: {other}"
                    ))),
                }
            })
            .collect()
    }

    /// Maps each argument signature to its variable via
    /// `signature_to_argument_str_map`.
    ///
    /// Each variable is wrapped as `AtomArg::Var`. Fails on the first
    /// signature that is not in the map. `context` prefixes the error message.
    fn lookup_arg_vars(
        &self,
        signatures: &[AtomArgumentSignature],
        context: &str,
    ) -> Result<Vec<AtomArg>, CatalogError> {
        signatures
            .iter()
            .map(|arg_sig| {
                self.signature_to_argument_str_map
                    .get(arg_sig)
                    .cloned()
                    .map(AtomArg::Var)
                    .ok_or_else(|| {
                        CatalogError::internal(format!(
                            "{context}: argument signature {arg_sig:?} not found in signature map"
                        ))
                    })
            })
            .collect()
    }

    /// Builds a renamed copy of each atom at the given RHS `indices`.
    ///
    /// The copy for `indices[k]` keeps the original arguments and takes the
    /// name `new_names[k]` and fingerprint `new_fingerprints[k]`.
    ///
    /// NOTE(review): copies are always emitted as *positive* atoms, even when
    /// the source atom is negated. `map_modify` / `projection_modify`
    /// preserve polarity instead. Confirm callers never target negated atoms.
    fn build_renamed_atom_copies(
        &self,
        indices: &[usize],
        new_names: &[String],
        new_fingerprints: &[u64],
        context: &str,
    ) -> Result<Vec<Predicate>, CatalogError> {
        if new_names.len() != indices.len() || new_fingerprints.len() != indices.len() {
            return Err(CatalogError::internal(format!(
                "{context}: parameter length mismatch — indices={}, new_names={}, \
                 new_fingerprints={}",
                indices.len(),
                new_names.len(),
                new_fingerprints.len()
            )));
        }
        indices
            .iter()
            .zip(new_names)
            .zip(new_fingerprints)
            .map(|((&atom_idx, name), &fingerprint)| {
                let args = match &self.rule.rhs()[atom_idx] {
                    Predicate::PositiveAtom(atom) | Predicate::NegativeAtom(atom) => {
                        atom.arguments().to_vec()
                    }
                    other => {
                        return Err(CatalogError::internal(format!(
                            "{context}: expected atom predicate at rhs index {atom_idx}, got: {other}"
                        )));
                    }
                };
                Ok(Predicate::PositiveAtom(Atom::new(name, args, fingerprint)))
            })
            .collect()
    }

    /// Rebuilds the rule with `new_predicate` at RHS position `global_rhs_idx`.
    ///
    /// The head is cloned unchanged, and the new rule is installed via
    /// `update_rule`.
    fn update_rule_in_place(
        &mut self,
        global_rhs_idx: usize,
        new_predicate: Predicate,
    ) -> Result<(), CatalogError> {
        let mut new_rhs = self.rule.rhs().to_vec();
        new_rhs[global_rhs_idx] = new_predicate;
        let new_rule = FlowLogRule::new(self.rule.head().clone(), new_rhs);
        self.update_rule(&new_rule)
    }

    /// Rebuilds the rule by applying updates, then one removal, to the RHS.
    ///
    /// First `new_predicates[k]` is written at RHS position
    /// `global_rhs_indices_to_update[k]`. Then the predicate at
    /// `global_rhs_index_to_remove` is removed. The result is installed via
    /// `update_rule`.
    ///
    /// All indices refer to positions in the *current* RHS. Updates are
    /// applied before the removal so the removal does not shift them.
    ///
    /// # Errors
    ///
    /// Returns an internal `CatalogError` if:
    /// - the index and predicate lists differ in length;
    /// - the removed index is also an update target (that update would
    ///   otherwise be discarded).
    ///
    /// Also propagates any error from `update_rule`.
    fn remove_and_update_rule(
        &mut self,
        global_rhs_index_to_remove: usize,
        global_rhs_indices_to_update: Vec<usize>,
        new_predicates: Vec<Predicate>,
    ) -> Result<(), CatalogError> {
        if global_rhs_indices_to_update.len() != new_predicates.len() {
            return Err(CatalogError::internal(format!(
                "remove_and_update_rule: {} rhs indices to update but {} new predicates",
                global_rhs_indices_to_update.len(),
                new_predicates.len()
            )));
        }
        if global_rhs_indices_to_update.contains(&global_rhs_index_to_remove) {
            return Err(CatalogError::internal(format!(
                "remove_and_update_rule: rhs index {global_rhs_index_to_remove} is both \
                 removed and updated"
            )));
        }
        let mut new_rhs = self.rule.rhs().to_vec();
        for (idx, pred) in global_rhs_indices_to_update.into_iter().zip(new_predicates) {
            new_rhs[idx] = pred;
        }
        new_rhs.remove(global_rhs_index_to_remove);
        let new_rule = FlowLogRule::new(self.rule.head().clone(), new_rhs);
        self.update_rule(&new_rule)
    }
}