use std::collections::HashMap;
use panproto_gat::Name;
use panproto_inst::CompiledMigration;
use panproto_schema::Edge;
use crate::Lens;
use crate::error::LensError;
/// Composes two lenses into one that runs `l1` followed by `l2`.
///
/// The composite lens reads `l1`'s source schema and produces `l2`'s target
/// schema. Its compiled migration is built by [`compose_compiled_migrations`].
///
/// # Errors
///
/// Returns [`LensError::CompositionMismatch`] when the schema `l1` produces and
/// the schema `l2` consumes disagree in any of the following:
/// - vertex count,
/// - protocol,
/// - the set of vertex names.
pub fn compose(l1: &Lens, l2: &Lens) -> Result<Lens, LensError> {
    let (produced, consumed) = (&l1.tgt_schema, &l2.src_schema);

    // Cheap structural guard first; the name-set comparison below is the
    // authoritative check.
    let same_shape = produced.vertex_count() == consumed.vertex_count()
        && produced.protocol == consumed.protocol;
    if !same_shape {
        return Err(LensError::CompositionMismatch);
    }

    let produced_names: std::collections::BTreeSet<_> = produced.vertices.keys().collect();
    let consumed_names: std::collections::BTreeSet<_> = consumed.vertices.keys().collect();
    if produced_names != consumed_names {
        return Err(LensError::CompositionMismatch);
    }

    Ok(Lens {
        compiled: compose_compiled_migrations(&l1.compiled, &l2.compiled),
        src_schema: l1.src_schema.clone(),
        tgt_schema: l2.tgt_schema.clone(),
    })
}
/// Composes two compiled migrations into one equivalent to running `m1` and
/// then `m2`.
///
/// Terminology used below:
/// - "source" names belong to `m1`'s input schema.
/// - "mid" names belong to `m1`'s output, which is also `m2`'s input.
/// - "final" names belong to `m2`'s output.
///
/// Survival:
/// - `surviving_verts` / `surviving_edges` are keyed by source names.
/// - An element survives the composite when it survives `m1` and its `m1` image
///   survives `m2`. An element `m1` does not rename is its own image.
///
/// Remaps:
/// - Vertex and edge remaps are chained `source -> mid -> final` by
///   [`compose_remaps`].
/// - Both kinds also keep `m2`'s renames of elements that `m1` passes through
///   unrenamed.
///
/// Other fields:
/// - `resolver`, `hyper_resolver` and `expansion_path` are merged with `m1`'s
///   entry winning on key collisions.
/// - An `m1` expansion path whose (remapped) target starts an `m2` expansion
///   path is concatenated with it.
/// - Field transforms are delegated to [`compose_field_transforms`].
/// - Conditional-survival predicates are delegated to
///   [`compose_conditional_survival`].
pub(crate) fn compose_compiled_migrations(
    m1: &CompiledMigration,
    m2: &CompiledMigration,
) -> CompiledMigration {
    let mut surviving_verts = std::collections::HashSet::new();
    for v in &m1.surviving_verts {
        let remapped = m1.vertex_remap.get(v).unwrap_or(v);
        if m2.surviving_verts.contains(remapped) {
            surviving_verts.insert(v.clone());
        }
    }

    let mut surviving_edges = std::collections::HashSet::new();
    for e in &m1.surviving_edges {
        let remapped = m1.edge_remap.get(e).unwrap_or(e);
        if m2.surviving_edges.contains(remapped) {
            surviving_edges.insert(e.clone());
        }
    }

    // Vertices and edges are composed the same way. Edges `m1` leaves
    // untouched must still pick up `m2`'s renames, exactly as vertices do.
    let vertex_remap = compose_remaps(&m1.vertex_remap, &m2.vertex_remap);
    let edge_remap: HashMap<Edge, Edge> = compose_remaps(&m1.edge_remap, &m2.edge_remap);

    // m1's resolvers take precedence; m2 only fills keys m1 does not define.
    let mut resolver = m1.resolver.clone();
    for (k, v) in &m2.resolver {
        resolver.entry(k.clone()).or_insert_with(|| v.clone());
    }
    let mut hyper_resolver = m1.hyper_resolver.clone();
    for (k, v) in &m2.hyper_resolver {
        hyper_resolver.entry(k.clone()).or_insert_with(|| v.clone());
    }

    let field_transforms = compose_field_transforms(m1, m2);
    let conditional_survival = compose_conditional_survival(m1, m2);

    let mut expansion_path: HashMap<(Name, Name), Vec<Name>> = HashMap::new();
    for ((src, tgt), mids) in &m1.expansion_path {
        let remapped_tgt = m1.vertex_remap.get(tgt).unwrap_or(tgt);
        let mut found_chain = false;
        for ((m2_src, m2_tgt), m2_mids) in &m2.expansion_path {
            // Splice an m2 path onto this m1 path when it starts where the m1
            // path ends (after m1's rename).
            if m2_src == remapped_tgt {
                let mut combined = mids.clone();
                combined.extend(m2_mids.iter().cloned());
                expansion_path.insert((src.clone(), m2_tgt.clone()), combined);
                found_chain = true;
            }
        }
        if !found_chain {
            expansion_path.insert((src.clone(), tgt.clone()), mids.clone());
        }
    }
    for (k, v) in &m2.expansion_path {
        expansion_path.entry(k.clone()).or_insert_with(|| v.clone());
    }

    CompiledMigration {
        surviving_verts,
        surviving_edges,
        vertex_remap,
        edge_remap,
        resolver,
        hyper_resolver,
        field_transforms,
        conditional_survival,
        expansion_path,
    }
}

/// Chains two rename maps.
///
/// `first` maps source -> mid names and `second` maps mid -> final names.
///
/// - Every key of `first` is sent through `second`, falling back to the mid
///   name when `second` does not rename it.
/// - A key of `second` that is not the image of any `first` rename refers to
///   an element that passed through the first step unrenamed, so its rename
///   is carried over.
/// - If such a key also appears as a key of `first`, the chained `first`
///   entry wins.
fn compose_remaps<K>(first: &HashMap<K, K>, second: &HashMap<K, K>) -> HashMap<K, K>
where
    K: Clone + Eq + std::hash::Hash,
{
    // Built once so the carry-over check below is O(1) per key rather than a
    // linear scan of `first.values()` each time.
    let renamed_into: std::collections::HashSet<&K> = first.values().collect();

    let mut composed: HashMap<K, K> = first
        .iter()
        .map(|(src, mid)| (src.clone(), second.get(mid).unwrap_or(mid).clone()))
        .collect();

    for (mid, fin) in second {
        if !renamed_into.contains(mid) {
            composed.entry(mid.clone()).or_insert_with(|| fin.clone());
        }
    }
    composed
}
/// Reports whether vertex `name` comes out of `m1` under the same name.
///
/// - If `m1.vertex_remap` has an entry for `name`, the answer is whether that
///   entry maps `name` to itself.
/// - Otherwise the answer is whether `name` is in `m1.surviving_verts`.
fn unchanged_by_m1(m1: &CompiledMigration, name: &panproto_gat::Name) -> bool {
    match m1.vertex_remap.get(name) {
        Some(target) => target == name,
        None => m1.surviving_verts.contains(name),
    }
}
/// Merges the per-vertex field transforms of `m1` and `m2` into a single table.
///
/// `m1`'s transforms are copied unchanged. Each list of `m2` transforms is
/// keyed by an intermediate vertex (`anchor`) and placed as follows:
/// - It is appended once for every `m1` source vertex that `m1.vertex_remap`
///   sends to `anchor`.
/// - If there is no such source vertex, it is appended under `anchor` itself,
///   but only when [`unchanged_by_m1`] holds for `anchor`.
/// - Otherwise it is discarded.
fn compose_field_transforms(
    m1: &CompiledMigration,
    m2: &CompiledMigration,
) -> HashMap<panproto_gat::Name, Vec<panproto_inst::wtype::FieldTransform>> {
    let mut merged = m1.field_transforms.clone();

    for (anchor, extra) in &m2.field_transforms {
        let preimages: Vec<&panproto_gat::Name> = m1
            .vertex_remap
            .iter()
            .filter_map(|(src, tgt)| (tgt == anchor).then_some(src))
            .collect();

        if preimages.is_empty() {
            if unchanged_by_m1(m1, anchor) {
                merged
                    .entry(anchor.clone())
                    .or_default()
                    .extend(extra.iter().cloned());
            }
            continue;
        }

        for src in preimages {
            merged
                .entry(src.clone())
                .or_default()
                .extend(extra.iter().cloned());
        }
    }

    merged
}
fn compose_conditional_survival(
m1: &CompiledMigration,
m2: &CompiledMigration,
) -> HashMap<panproto_gat::Name, panproto_expr::Expr> {
let mut result = m1.conditional_survival.clone();
for (m2_anchor, m2_pred) in &m2.conditional_survival {
let mut found = false;
for (m1_src, m1_tgt) in &m1.vertex_remap {
if m1_tgt == m2_anchor {
found = true;
let scoped = scope_check_predicate(m2_pred, m1, m1_src);
result
.entry(m1_src.clone())
.and_modify(|existing| {
*existing = panproto_expr::Expr::Builtin(
panproto_expr::BuiltinOp::And,
vec![existing.clone(), scoped.clone()],
);
})
.or_insert(scoped);
}
}
if !found && unchanged_by_m1(m1, m2_anchor) {
let scoped = scope_check_predicate(m2_pred, m1, m2_anchor);
result
.entry(m2_anchor.clone())
.and_modify(|existing| {
*existing = panproto_expr::Expr::Builtin(
panproto_expr::BuiltinOp::And,
vec![existing.clone(), scoped.clone()],
);
})
.or_insert(scoped);
}
}
result
}
fn scope_check_predicate(
pred: &panproto_expr::Expr,
m1: &CompiledMigration,
anchor: &panproto_gat::Name,
) -> panproto_expr::Expr {
let Some(m1_xforms) = m1.field_transforms.get(anchor) else {
return pred.clone();
};
let analysis = analyse_field_transforms(m1_xforms);
if analysis.dropped.is_empty() && analysis.keep_intersection.is_none() {
return pred.clone();
}
let free = panproto_expr::free_vars(pred);
let dropped_hit = free.iter().any(|v| analysis.dropped.contains(v.as_ref()));
let keep_violation = analysis
.keep_intersection
.as_ref()
.is_some_and(|keep| free.iter().any(|v| !keep.contains(v.as_ref())));
if dropped_hit || keep_violation {
return panproto_expr::Expr::Lit(panproto_expr::Literal::Bool(false));
}
pred.clone()
}
/// Summary of which field names a list of field transforms makes unavailable.
struct FieldDropAnalysis {
    /// Keys removed by `DropField`, plus the `old_key` of every `RenameField`.
    dropped: std::collections::HashSet<String>,
    /// Intersection of the key sets of every `KeepFields` transform; `None`
    /// when the list contains no `KeepFields`.
    keep_intersection: Option<std::collections::HashSet<String>>,
}

/// Scans `xforms` and builds a [`FieldDropAnalysis`].
///
/// The scan ignores the position of each transform in the list and does not
/// track the `new_key` of a `RenameField`. Variants other than `DropField`,
/// `RenameField` and `KeepFields` are skipped. The result can therefore be
/// conservative: for example, a key renamed away and later reintroduced is
/// still reported as dropped.
fn analyse_field_transforms(xforms: &[panproto_inst::wtype::FieldTransform]) -> FieldDropAnalysis {
    use panproto_inst::wtype::FieldTransform;

    let mut analysis = FieldDropAnalysis {
        dropped: std::collections::HashSet::new(),
        keep_intersection: None,
    };

    for xform in xforms {
        match xform {
            FieldTransform::DropField { key }
            | FieldTransform::RenameField { old_key: key, .. } => {
                analysis.dropped.insert(key.clone());
            }
            FieldTransform::KeepFields { keys } => {
                let allowed: std::collections::HashSet<String> = keys.iter().cloned().collect();
                analysis.keep_intersection = Some(match analysis.keep_intersection.take() {
                    Some(prior) => prior.intersection(&allowed).cloned().collect(),
                    None => allowed,
                });
            }
            _ => {}
        }
    }

    analysis
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{identity_lens, three_node_schema};

    /// Composing two identity lenses over one schema must succeed, and both
    /// ends of the result must keep the schema's vertex count.
    #[test]
    fn compose_identity_with_identity() {
        let schema = three_node_schema();
        let (first, second) = (identity_lens(&schema), identity_lens(&schema));

        let outcome = compose(&first, &second);
        assert!(outcome.is_ok(), "composing identity lenses should succeed");
        let lens = match outcome {
            Ok(lens) => lens,
            Err(e) => panic!("compose failed: {e}"),
        };

        let expected = schema.vertex_count();
        assert_eq!(
            lens.src_schema.vertex_count(),
            expected,
            "composed src schema should match original"
        );
        assert_eq!(
            lens.tgt_schema.vertex_count(),
            expected,
            "composed tgt schema should match original"
        );
    }
}