use super::reference::RefRelationStore;
use crate::ast::{BodyLiteral, Rule, Term};
use std::collections::BTreeMap;
use xlog_core::ScalarType;
/// Errors produced while inferring column types for the predicates of an SCC.
#[derive(Clone, Debug, PartialEq)]
pub enum InferenceError {
    /// Two rules of the same predicate derived different scalar types for the
    /// same head column.
    ConflictingPredicateColumnType {
        /// Name of the predicate whose head column is contested.
        predicate: String,
        /// Zero-based head column index.
        column: usize,
        /// Index (within the predicate's rule group) of the rule that first
        /// fixed this column's type.
        first_rule_index: usize,
        /// Type recorded by the first rule.
        first_type: ScalarType,
        /// Index (within the predicate's rule group) of the rule that derived
        /// a different type.
        second_rule_index: usize,
        /// Type derived by the second rule.
        second_type: ScalarType,
    },
}

/// Inferred schemas keyed by predicate name: one slot per head column, where
/// `None` means no type has been inferred for that column (yet).
pub type InferredSchemas = BTreeMap<String, Vec<Option<ScalarType>>>;
/// Infers a scalar type for each head column of every predicate in `rules`.
///
/// The schemas start fully unknown and are refined to a fixpoint. In each
/// round, every rule's body variables are typed from `base_relations` and from
/// the schemas inferred so far. Any head column bound to a typed variable then
/// takes that type. A filled column is never cleared, so the process stops
/// after at most `total_columns + 1` rounds.
///
/// # Errors
///
/// Returns [`InferenceError::ConflictingPredicateColumnType`] when a rule
/// derives a type for a head column that differs from the type already
/// recorded for it.
pub fn infer_scc_predicate_schemas(
    rules: &BTreeMap<String, Vec<Rule>>,
    base_relations: &RefRelationStore,
) -> Result<InferredSchemas, InferenceError> {
    // Seed each predicate with an all-unknown schema. The arity is taken from
    // the first rule whose head has at least one term, or 0 if none does.
    let mut schemas: InferredSchemas = rules
        .iter()
        .map(|(predicate, group)| {
            let arity = group
                .iter()
                .map(|r| r.head.terms.len())
                .find(|&len| len != 0)
                .unwrap_or(0);
            (predicate.clone(), vec![None; arity])
        })
        .collect();

    // Remembers which rule first fixed each (predicate, column) type, for
    // conflict diagnostics.
    let mut first_setter: BTreeMap<(String, usize), usize> = BTreeMap::new();
    let max_iterations = schemas.values().map(Vec::len).sum::<usize>() + 1;
    let mut reached_fixpoint = false;

    for _round in 0..max_iterations {
        let mut progressed = false;
        for (predicate, group) in rules {
            for (rule_index, rule) in group.iter().enumerate() {
                let bound = derive_rule_var_types(rule, base_relations, &schemas);
                for (col, term) in rule.head.terms.iter().enumerate() {
                    // Only head variables with a known body type contribute.
                    let derived = match term {
                        Term::Variable(name) => bound.get(name).copied(),
                        _ => None,
                    };
                    let Some(derived) = derived else {
                        continue;
                    };

                    let schema = schemas
                        .get_mut(predicate)
                        .expect("predicate in initialized schemas");
                    // A rule whose head is wider than the recorded arity is skipped.
                    let Some(slot) = schema.get_mut(col) else {
                        continue;
                    };

                    match *slot {
                        None => {
                            *slot = Some(derived);
                            first_setter.insert((predicate.clone(), col), rule_index);
                            progressed = true;
                        }
                        Some(existing) if existing != derived => {
                            let first_rule_index = first_setter
                                .get(&(predicate.clone(), col))
                                .copied()
                                .unwrap_or(0);
                            return Err(InferenceError::ConflictingPredicateColumnType {
                                predicate: predicate.clone(),
                                column: col,
                                first_rule_index,
                                first_type: existing,
                                second_rule_index: rule_index,
                                second_type: derived,
                            });
                        }
                        // Same type as already recorded: nothing to do.
                        Some(_) => {}
                    }
                }
            }
        }
        if !progressed {
            reached_fixpoint = true;
            break;
        }
    }

    debug_assert!(
        reached_fixpoint,
        "type inference failed to converge within {max_iterations} iterations \
         (monotonicity invariant violated)"
    );
    Ok(schemas)
}
/// Assigns a scalar type to each variable that appears in a positive body atom
/// of `rule`.
///
/// If the atom's predicate is a base relation, its schema supplies the type
/// for each position. Otherwise, the type comes from the `inferred` schema
/// (when one exists and that slot is known). Only the positions covered by
/// both the atom and the schema are considered. If a variable occurs in
/// several atoms, the first typed occurrence wins; later occurrences are
/// ignored rather than checked for conflicts.
fn derive_rule_var_types(
    rule: &Rule,
    base_relations: &RefRelationStore,
    inferred: &InferredSchemas,
) -> BTreeMap<String, ScalarType> {
    let mut bindings = BTreeMap::new();
    let positive_atoms = rule.body.iter().filter_map(|literal| match literal {
        BodyLiteral::Positive(atom) => Some(atom),
        _ => None,
    });

    for atom in positive_atoms {
        if let Some(rel) = base_relations.get(&atom.predicate) {
            // Base relations take precedence; every covered position is typed.
            let limit = atom.terms.len().min(rel.schema.len());
            for (pos, term) in atom.terms.iter().take(limit).enumerate() {
                if let Term::Variable(name) = term {
                    bindings.entry(name.clone()).or_insert(rel.schema[pos]);
                }
            }
        } else if let Some(schema) = inferred.get(&atom.predicate) {
            // Inferred schemas may still contain unknown (`None`) slots.
            for (term, slot) in atom.terms.iter().zip(schema) {
                if let (Term::Variable(name), Some(ty)) = (term, *slot) {
                    bindings.entry(name.clone()).or_insert(ty);
                }
            }
        }
    }

    bindings
}
/// Computes the scalar type of every variable bound in a positive body atom of
/// `rule`.
///
/// Position types come from the base relation's schema when the atom's
/// predicate is a base relation. Otherwise they come from `inferred_schemas`.
/// Atoms whose predicate is in neither source are skipped, and so are
/// positions without a known type.
///
/// # Errors
///
/// Returns `RefEvalError::ConflictingVariableType` when the same variable is
/// given two different types. The error names the first typed occurrence and
/// the conflicting one.
pub(super) fn derive_vertex_types_with_inference(
    rule: &Rule,
    base_relations: &RefRelationStore,
    inferred_schemas: &InferredSchemas,
) -> Result<BTreeMap<String, ScalarType>, super::RefEvalError> {
    /// Where (predicate and argument position) a variable's type was first
    /// established.
    struct FirstSite {
        predicate: String,
        position: usize,
        ty: ScalarType,
    }

    let mut first_seen: BTreeMap<String, FirstSite> = BTreeMap::new();

    for literal in &rule.body {
        let BodyLiteral::Positive(atom) = literal else {
            continue;
        };

        // One entry per term of the atom; positions beyond the schema stay `None`.
        let column_types: Vec<Option<ScalarType>> = match base_relations.get(&atom.predicate) {
            Some(rel) => (0..atom.terms.len())
                .map(|i| (i < rel.schema.len()).then(|| rel.schema[i]))
                .collect(),
            None => match inferred_schemas.get(&atom.predicate) {
                Some(schema) => (0..atom.terms.len())
                    .map(|i| schema.get(i).copied().flatten())
                    .collect(),
                None => continue,
            },
        };

        for (position, (term, column_type)) in
            atom.terms.iter().zip(&column_types).enumerate()
        {
            let (Term::Variable(name), Some(ty)) = (term, *column_type) else {
                continue;
            };

            if let Some(prior) = first_seen.get(name) {
                if prior.ty != ty {
                    return Err(super::RefEvalError::ConflictingVariableType {
                        var: name.clone(),
                        first_predicate: prior.predicate.clone(),
                        first_position: prior.position,
                        first_type: prior.ty,
                        second_predicate: atom.predicate.clone(),
                        second_position: position,
                        second_type: ty,
                    });
                }
            } else {
                first_seen.insert(
                    name.clone(),
                    FirstSite {
                        predicate: atom.predicate.clone(),
                        position,
                        ty,
                    },
                );
            }
        }
    }

    let types = first_seen
        .into_iter()
        .map(|(name, site)| (name, site.ty))
        .collect();
    Ok(types)
}