Skip to main content

panproto_inst/
poly.rs

1//! Polynomial functor operations on instances.
2//!
3//! Schemas are polynomial functors; instances are W-types (initial algebras).
4//! This module exposes the derived operations that arise from the adjoint
5//! triple Σ ⊣ Δ ⊣ Π applied to the polynomial interpretation:
6//!
7//! - **Fiber**: preimage of a migration at a target anchor (Δ at a point)
8//! - **Group-by**: partition source nodes by fiber (explicit Π on trees)
9//! - **Join**: pullback of two instances along shared projections
10//! - **Section**: construct an enriched instance from base + annotation data
11
12use std::collections::{HashMap, VecDeque};
13
14use panproto_gat::Name;
15use panproto_schema::Edge;
16use rustc_hash::{FxHashMap, FxHashSet};
17
18use crate::metadata::Node;
19use crate::value::{FieldPresence, Value};
20use crate::wtype::{
21    CompiledMigration, WInstance, apply_field_transforms, build_env_from_extra_fields,
22    collect_scalar_child_values, reconstruct_fans, resolve_edge, value_to_expr_literal,
23};
24
25// ---------------------------------------------------------------------------
26// Complement infrastructure
27// ---------------------------------------------------------------------------
28
29/// A node that was dropped during restriction, with provenance info.
30#[derive(Debug, Clone)]
31pub struct DroppedNode {
32    /// Original node ID in the source instance.
33    pub original_id: u32,
34    /// Anchor of the dropped node.
35    pub anchor: Name,
36    /// The surviving node this was contracted into (nearest surviving ancestor).
37    pub contracted_into: Option<u32>,
38}
39
40/// Complement data from a restrict operation. Stores everything needed
41/// to reconstruct the original instance from the restricted result.
42#[derive(Debug, Clone, Default)]
43pub struct Complement {
44    /// Nodes that were dropped during restriction.
45    pub dropped_nodes: Vec<DroppedNode>,
46    /// Arcs that were dropped (both endpoints must have been in the source).
47    pub dropped_arcs: Vec<(u32, u32, Edge)>,
48    /// Pre-transform `extra_fields` for nodes that had `field_transforms` applied.
49    /// Used by backward migration to restore original field values.
50    pub original_extra_fields: HashMap<u32, HashMap<String, crate::value::Value>>,
51}
52
53/// An enrichment to add when constructing a section.
54#[derive(Debug, Clone)]
55pub struct SectionEnrichment {
56    /// Node ID in the base instance that this enrichment annotates.
57    pub base_node_id: u32,
58    /// Anchor for the new enrichment node (must be a vertex in the
59    /// source schema but not in the target schema).
60    pub anchor: Name,
61    /// Edge connecting the base node to this enrichment.
62    pub edge: Edge,
63    /// Value for the enrichment node.
64    pub value: Option<FieldPresence>,
65    /// Extra fields for the enrichment node.
66    pub extra_fields: FxHashMap<String, Value>,
67}
68
69/// Compute the fiber of a compiled migration at a specific target anchor.
70///
71/// Given migration m: S → T and target anchor `a` in T, returns all source
72/// node IDs whose remapped anchor equals `a`. This is the `Δ_f` operation
73/// applied to a representable (a single point).
74#[must_use]
75pub fn fiber_at_anchor(
76    compiled: &CompiledMigration,
77    source: &WInstance,
78    target_anchor: &Name,
79) -> Vec<u32> {
80    source
81        .nodes
82        .iter()
83        .filter(|(_, node)| {
84            compiled
85                .vertex_remap
86                .get(&node.anchor)
87                .is_some_and(|remapped| remapped == target_anchor)
88        })
89        .map(|(id, _)| *id)
90        .collect()
91}
92
93/// Compute fibers for ALL target anchors simultaneously.
94///
95/// Returns a map: target anchor → source node IDs. This is the complete
96/// fiber decomposition of the migration. Every source node appears in
97/// exactly one fiber (the fibers partition the source).
98#[must_use]
99pub fn fiber_decomposition(
100    compiled: &CompiledMigration,
101    source: &WInstance,
102) -> FxHashMap<Name, Vec<u32>> {
103    let mut fibers: FxHashMap<Name, Vec<u32>> = FxHashMap::default();
104    for (&id, node) in &source.nodes {
105        if let Some(target) = compiled.vertex_remap.get(&node.anchor) {
106            fibers.entry(target.clone()).or_default().push(id);
107        }
108    }
109    fibers
110}
111
112/// Fiber with a value predicate: compute f⁻¹(a) ∩ {x | pred(x)}.
113///
114/// Combines `Δ_f` (pullback) with conditional survival. The predicate is
115/// evaluated against each source node's `extra_fields`, with all fields
116/// bound as expression variables.
117#[must_use]
118pub fn fiber_with_predicate(
119    compiled: &CompiledMigration,
120    source: &WInstance,
121    target_anchor: &Name,
122    predicate: &panproto_expr::Expr,
123    eval_config: &panproto_expr::EvalConfig,
124) -> Vec<u32> {
125    fiber_at_anchor(compiled, source, target_anchor)
126        .into_iter()
127        .filter(|&id| {
128            let Some(node) = source.nodes.get(&id) else {
129                return false;
130            };
131            let mut env = build_env_from_extra_fields(&node.extra_fields);
132            if let Some(FieldPresence::Present(ref v)) = node.value {
133                env = env.extend(std::sync::Arc::from("_value"), value_to_expr_literal(v));
134            }
135            env = env.extend(
136                std::sync::Arc::from("_anchor"),
137                panproto_expr::Literal::Str(node.anchor.as_ref().into()),
138            );
139            matches!(
140                panproto_expr::eval(predicate, &env, eval_config),
141                Ok(panproto_expr::Literal::Bool(true))
142            )
143        })
144        .collect()
145}
146
147/// Group source nodes by their image under a migration.
148///
149/// Returns: for each target anchor, a sub-instance containing only the
150/// source nodes in that fiber (with internal arcs preserved).
151///
152/// This is the dependent product `Π_f` computed explicitly on trees.
153#[must_use]
154pub fn group_by(compiled: &CompiledMigration, source: &WInstance) -> FxHashMap<Name, WInstance> {
155    let fibers = fiber_decomposition(compiled, source);
156    fibers
157        .into_iter()
158        .map(|(anchor, node_ids)| {
159            let sub = extract_subinstance(source, &node_ids);
160            (anchor, sub)
161        })
162        .collect()
163}
164
165/// Join two instances along a shared projection.
166///
167/// Given A →f C ←g B, compute the pullback A ×_C B: pairs (a, b) where
168/// f(a) and g(b) map to the same target anchor.
169///
170/// Returns all matching pairs as (`left_node_id`, `right_node_id`).
171#[must_use]
172pub fn join(
173    left: &WInstance,
174    right: &WInstance,
175    left_compiled: &CompiledMigration,
176    right_compiled: &CompiledMigration,
177) -> Vec<(u32, u32)> {
178    let left_fibers = fiber_decomposition(left_compiled, left);
179    let right_fibers = fiber_decomposition(right_compiled, right);
180
181    let mut pairs = Vec::new();
182    for (anchor, left_ids) in &left_fibers {
183        if let Some(right_ids) = right_fibers.get(anchor) {
184            for &l in left_ids {
185                for &r in right_ids {
186                    pairs.push((l, r));
187                }
188            }
189        }
190    }
191    pairs
192}
193
194// ---------------------------------------------------------------------------
195// Restrict with complement
196// ---------------------------------------------------------------------------
197
198/// Restrict an instance and collect complement data.
199///
200/// Like `wtype_restrict` but also returns a [`Complement`] recording all
201/// dropped nodes and their nearest surviving ancestors. This complement
202/// is needed for [`fiber_at_node`] and for backward data migration.
203///
204/// # Errors
205///
206/// Returns [`RestrictError`](crate::error::RestrictError) if the root is
207/// pruned or edge resolution fails during ancestor contraction.
208pub fn restrict_with_complement(
209    instance: &WInstance,
210    _src_schema: &panproto_schema::Schema,
211    tgt_schema: &panproto_schema::Schema,
212    migration: &CompiledMigration,
213) -> Result<(WInstance, Complement), crate::error::RestrictError> {
214    use crate::error::RestrictError;
215
216    let root_node = instance
217        .nodes
218        .get(&instance.root)
219        .ok_or(RestrictError::RootPruned)?;
220    let root_target_anchor = migration
221        .vertex_remap
222        .get(&root_node.anchor)
223        .unwrap_or(&root_node.anchor);
224    if !migration.surviving_verts.contains(root_target_anchor) {
225        return Err(RestrictError::RootPruned);
226    }
227
228    let mut new_nodes: HashMap<u32, Node> = HashMap::new();
229    let mut new_arcs: Vec<(u32, u32, Edge)> = Vec::new();
230    let mut surviving_set: FxHashSet<u32> = FxHashSet::default();
231    let mut complement = Complement::default();
232    let mut queue: VecDeque<(u32, Option<u32>)> = VecDeque::new();
233
234    // Process root
235    let mut root_node_cloned = root_node.clone();
236    if let Some(remapped) = migration.vertex_remap.get(&root_node.anchor) {
237        root_node_cloned.anchor.clone_from(remapped);
238    }
239    // Check conditional survival for root
240    if let Some(pred) = migration.conditional_survival.get(&root_node.anchor) {
241        let env = build_env_from_extra_fields(&root_node.extra_fields);
242        let config = panproto_expr::EvalConfig::default();
243        if matches!(
244            panproto_expr::eval(pred, &env, &config),
245            Ok(panproto_expr::Literal::Bool(false))
246        ) {
247            return Err(RestrictError::RootPruned);
248        }
249    }
250    // Apply field transforms to root
251    if let Some(transforms) = migration.field_transforms.get(&root_node.anchor) {
252        complement
253            .original_extra_fields
254            .insert(instance.root, root_node.extra_fields.clone());
255        let scalars = collect_scalar_child_values(instance, instance.root);
256        apply_field_transforms(&mut root_node_cloned, transforms, &scalars);
257    }
258    new_nodes.insert(instance.root, root_node_cloned);
259    surviving_set.insert(instance.root);
260    queue.push_back((instance.root, None));
261
262    // BFS: visit each node, tracking nearest surviving ancestor
263    while let Some((current_id, ancestor_id)) = queue.pop_front() {
264        let child_ancestor = if surviving_set.contains(&current_id) {
265            Some(current_id)
266        } else {
267            ancestor_id
268        };
269        restrict_bfs_step(
270            instance,
271            tgt_schema,
272            migration,
273            current_id,
274            child_ancestor,
275            &mut new_nodes,
276            &mut new_arcs,
277            &mut surviving_set,
278            &mut complement,
279            &mut queue,
280        )?;
281    }
282
283    // Record dropped arcs
284    collect_dropped_arcs(instance, &surviving_set, &mut complement);
285
286    // Fan reconstruction
287    let empty_ancestors = FxHashMap::default();
288    let new_fans = reconstruct_fans(
289        instance,
290        &surviving_set,
291        &empty_ancestors,
292        migration,
293        tgt_schema,
294    )?;
295
296    let new_schema_root = migration
297        .vertex_remap
298        .get(&instance.schema_root)
299        .cloned()
300        .unwrap_or_else(|| instance.schema_root.clone());
301
302    let restricted = WInstance::new(
303        new_nodes,
304        new_arcs,
305        new_fans,
306        instance.root,
307        new_schema_root,
308    );
309    Ok((restricted, complement))
310}
311
312/// Process one BFS level: check children of `current_id` for survival.
313#[allow(clippy::too_many_arguments)]
314fn restrict_bfs_step(
315    instance: &WInstance,
316    tgt_schema: &panproto_schema::Schema,
317    migration: &CompiledMigration,
318    current_id: u32,
319    child_ancestor: Option<u32>,
320    new_nodes: &mut HashMap<u32, Node>,
321    new_arcs: &mut Vec<(u32, u32, Edge)>,
322    surviving_set: &mut FxHashSet<u32>,
323    complement: &mut Complement,
324    queue: &mut VecDeque<(u32, Option<u32>)>,
325) -> Result<(), crate::error::RestrictError> {
326    use crate::error::RestrictError;
327
328    for &child_id in instance.children(current_id) {
329        let Some(child_node) = instance.nodes.get(&child_id) else {
330            continue;
331        };
332
333        let target_anchor = migration
334            .vertex_remap
335            .get(&child_node.anchor)
336            .unwrap_or(&child_node.anchor);
337        let mut child_survives = migration.surviving_verts.contains(target_anchor);
338
339        if child_survives {
340            if let Some(pred) = migration.conditional_survival.get(&child_node.anchor) {
341                let env = build_env_from_extra_fields(&child_node.extra_fields);
342                let config = panproto_expr::EvalConfig::default();
343                if matches!(
344                    panproto_expr::eval(pred, &env, &config),
345                    Ok(panproto_expr::Literal::Bool(false))
346                ) {
347                    child_survives = false;
348                }
349            }
350        }
351
352        if child_survives {
353            surviving_set.insert(child_id);
354            let mut new_node = child_node.clone();
355            if let Some(remapped) = migration.vertex_remap.get(&child_node.anchor) {
356                new_node.anchor.clone_from(remapped);
357            }
358            if let Some(transforms) = migration.field_transforms.get(&child_node.anchor) {
359                // Capture pre-transform extra_fields before applying transforms
360                complement
361                    .original_extra_fields
362                    .insert(child_id, child_node.extra_fields.clone());
363                let scalars = collect_scalar_child_values(instance, child_id);
364                apply_field_transforms(&mut new_node, transforms, &scalars);
365            }
366            new_nodes.insert(child_id, new_node.clone());
367
368            if let Some(anc_id) = child_ancestor {
369                let anc_node = new_nodes.get(&anc_id).ok_or(RestrictError::RootPruned)?;
370                let edge = resolve_edge(
371                    tgt_schema,
372                    &migration.resolver,
373                    &anc_node.anchor,
374                    &new_node.anchor,
375                )?;
376                new_arcs.push((anc_id, child_id, edge));
377            }
378        } else {
379            complement.dropped_nodes.push(DroppedNode {
380                original_id: child_id,
381                anchor: child_node.anchor.clone(),
382                contracted_into: child_ancestor,
383            });
384        }
385
386        queue.push_back((child_id, child_ancestor));
387    }
388    Ok(())
389}
390
391/// Collect arcs from the source instance where at least one endpoint
392/// did not survive restriction.
393fn collect_dropped_arcs(
394    instance: &WInstance,
395    surviving_set: &FxHashSet<u32>,
396    complement: &mut Complement,
397) {
398    for (src, tgt, edge) in &instance.arcs {
399        if !surviving_set.contains(src) || !surviving_set.contains(tgt) {
400            complement.dropped_arcs.push((*src, *tgt, edge.clone()));
401        }
402    }
403}
404
405// ---------------------------------------------------------------------------
406// Fiber at node
407// ---------------------------------------------------------------------------
408
409/// Compute fiber at a specific node ID in the restricted (target) instance.
410///
411/// Given source instance S, target instance T = restrict(S, migration),
412/// and a node n in T, find all nodes in S that were either:
413/// (a) remapped to n's anchor, or
414/// (b) contracted into n during ancestor contraction.
415#[must_use]
416pub fn fiber_at_node(
417    source: &WInstance,
418    target: &WInstance,
419    target_node_id: u32,
420    complement: &Complement,
421) -> Vec<u32> {
422    let Some(target_node) = target.nodes.get(&target_node_id) else {
423        return vec![];
424    };
425
426    let mut result = Vec::new();
427
428    // Direct preimage: source nodes with matching anchor
429    for (&id, node) in &source.nodes {
430        if node.anchor == target_node.anchor {
431            result.push(id);
432        }
433    }
434
435    // Contracted nodes
436    for dropped in &complement.dropped_nodes {
437        if dropped.contracted_into == Some(target_node_id) {
438            result.push(dropped.original_id);
439        }
440    }
441
442    result
443}
444
445// ---------------------------------------------------------------------------
446// Section construction
447// ---------------------------------------------------------------------------
448
449/// Construct a section of a projection.
450///
451/// Given:
452/// - `base`: an instance of the target schema T
453/// - `projection`: a compiled migration S -> T
454/// - `enrichments`: nodes to add in the S-instance fibers
455///
456/// Produces an S-instance that:
457/// 1. Contains all base nodes (with anchors inverse-mapped to source schema)
458/// 2. Contains all enrichment nodes attached at the correct positions
459/// 3. Projects back to base under the migration:
460///    restrict(section(base, projection, enrichments), projection) = base
461///
462/// # Errors
463///
464/// Returns [`InstError::NodeNotFound`](crate::error::InstError::NodeNotFound)
465/// if an enrichment references a base node ID that does not exist.
466pub fn section(
467    base: &WInstance,
468    projection: &CompiledMigration,
469    enrichments: Vec<SectionEnrichment>,
470) -> Result<WInstance, crate::error::InstError> {
471    // Build inverse vertex remap: target_anchor -> source_anchor
472    let inverse_remap: HashMap<Name, Name> = projection
473        .vertex_remap
474        .iter()
475        .map(|(src, tgt)| (tgt.clone(), src.clone()))
476        .collect();
477
478    // Step 1: copy base nodes, remapping anchors back to source schema
479    let mut nodes: HashMap<u32, Node> = HashMap::new();
480    let mut next_id: u32 = base.nodes.keys().max().copied().unwrap_or(0) + 1;
481
482    for (&id, node) in &base.nodes {
483        let mut new_node = node.clone();
484        if let Some(src_anchor) = inverse_remap.get(&node.anchor) {
485            new_node.anchor = src_anchor.clone();
486        }
487        nodes.insert(id, new_node);
488    }
489
490    // Remap arcs to use source anchors
491    let arcs: Vec<_> = base
492        .arcs
493        .iter()
494        .map(|(src_id, tgt_id, edge)| {
495            let mut new_edge = edge.clone();
496            if let Some(src_anchor) = inverse_remap.get(&new_edge.src) {
497                new_edge.src = src_anchor.clone();
498            }
499            if let Some(tgt_anchor) = inverse_remap.get(&new_edge.tgt) {
500                new_edge.tgt = tgt_anchor.clone();
501            }
502            (*src_id, *tgt_id, new_edge)
503        })
504        .collect();
505
506    let mut all_arcs = arcs;
507
508    // Step 2: add enrichment nodes
509    for enrichment in enrichments {
510        if !base.nodes.contains_key(&enrichment.base_node_id) {
511            return Err(crate::error::InstError::NodeNotFound(
512                enrichment.base_node_id,
513            ));
514        }
515
516        let enrichment_id = next_id;
517        next_id += 1;
518
519        let mut new_node = Node::new(enrichment_id, enrichment.anchor.clone());
520        if let Some(value) = enrichment.value {
521            new_node = new_node.with_value(value);
522        }
523        for (k, v) in enrichment.extra_fields {
524            new_node.extra_fields.insert(k, v);
525        }
526
527        nodes.insert(enrichment_id, new_node);
528        all_arcs.push((enrichment.base_node_id, enrichment_id, enrichment.edge));
529    }
530
531    let schema_root = inverse_remap
532        .get(&base.schema_root)
533        .cloned()
534        .unwrap_or_else(|| base.schema_root.clone());
535
536    Ok(WInstance::new(
537        nodes,
538        all_arcs,
539        base.fans.clone(),
540        base.root,
541        schema_root,
542    ))
543}
544
545/// Extract a sub-instance containing only the specified nodes and arcs
546/// between them.
547#[must_use]
548fn extract_subinstance(source: &WInstance, node_ids: &[u32]) -> WInstance {
549    let id_set: FxHashSet<u32> = node_ids.iter().copied().collect();
550    let nodes: HashMap<u32, Node> = source
551        .nodes
552        .iter()
553        .filter(|(id, _)| id_set.contains(id))
554        .map(|(&id, n)| (id, n.clone()))
555        .collect();
556    let arcs: Vec<_> = source
557        .arcs
558        .iter()
559        .filter(|(src, tgt, _)| id_set.contains(src) && id_set.contains(tgt))
560        .cloned()
561        .collect();
562    let root = node_ids.first().copied().unwrap_or(0);
563    WInstance::new(nodes, arcs, vec![], root, source.schema_root.clone())
564}
565
566#[cfg(test)]
567#[allow(clippy::unwrap_used, clippy::cast_possible_truncation)]
568mod tests {
569    use super::*;
570    use crate::value::Value;
571    use crate::wtype::wtype_restrict;
572    use panproto_schema::Edge;
573
574    /// Build a simple test instance: root with two annotation children.
575    fn make_test_instance() -> (WInstance, CompiledMigration) {
576        let mut nodes = HashMap::new();
577        let root = Node::new(0, "root");
578        nodes.insert(0, root);
579
580        let mut node_a = Node::new(1, "annotation");
581        node_a
582            .extra_fields
583            .insert("label".into(), Value::Str("ingredient".into()));
584        node_a
585            .extra_fields
586            .insert("confidence".into(), Value::Float(0.9));
587        nodes.insert(1, node_a);
588
589        let mut node_b = Node::new(2, "annotation");
590        node_b
591            .extra_fields
592            .insert("label".into(), Value::Str("step".into()));
593        node_b
594            .extra_fields
595            .insert("confidence".into(), Value::Float(0.5));
596        nodes.insert(2, node_b);
597
598        let edge = Edge {
599            src: Name::from("root"),
600            tgt: Name::from("annotation"),
601            kind: Name::from("child"),
602            name: None,
603        };
604        let arcs = vec![(0, 1, edge.clone()), (0, 2, edge)];
605
606        let inst = WInstance::new(nodes, arcs, vec![], 0, Name::from("root"));
607
608        let mut vertex_remap = HashMap::new();
609        vertex_remap.insert(Name::from("root"), Name::from("document"));
610        vertex_remap.insert(Name::from("annotation"), Name::from("span"));
611
612        let compiled = CompiledMigration {
613            surviving_verts: ["root", "annotation"]
614                .iter()
615                .map(|s| Name::from(*s))
616                .collect(),
617            surviving_edges: std::collections::HashSet::new(),
618            vertex_remap,
619            edge_remap: HashMap::new(),
620            resolver: HashMap::new(),
621            hyper_resolver: HashMap::new(),
622            field_transforms: HashMap::new(),
623            conditional_survival: HashMap::new(),
624            expansion_path: HashMap::new(),
625        };
626
627        (inst, compiled)
628    }
629
630    #[test]
631    fn fiber_at_anchor_basic() {
632        let (inst, compiled) = make_test_instance();
633        let fiber = fiber_at_anchor(&compiled, &inst, &Name::from("span"));
634        assert_eq!(fiber.len(), 2);
635        assert!(fiber.contains(&1));
636        assert!(fiber.contains(&2));
637    }
638
639    #[test]
640    fn fiber_at_anchor_root() {
641        let (inst, compiled) = make_test_instance();
642        let fiber = fiber_at_anchor(&compiled, &inst, &Name::from("document"));
643        assert_eq!(fiber, vec![0]);
644    }
645
646    #[test]
647    fn fiber_at_anchor_nonexistent() {
648        let (inst, compiled) = make_test_instance();
649        let fiber = fiber_at_anchor(&compiled, &inst, &Name::from("nonexistent"));
650        assert!(fiber.is_empty());
651    }
652
653    #[test]
654    fn fiber_decomposition_partitions() {
655        let (inst, compiled) = make_test_instance();
656        let fibers = fiber_decomposition(&compiled, &inst);
657
658        // All source nodes appear in exactly one fiber
659        let mut all_ids: Vec<u32> = fibers.values().flatten().copied().collect();
660        all_ids.sort_unstable();
661        assert_eq!(all_ids, vec![0, 1, 2]);
662
663        // Two fibers: document and span
664        assert_eq!(fibers.len(), 2);
665        assert_eq!(fibers[&Name::from("document")].len(), 1);
666        assert_eq!(fibers[&Name::from("span")].len(), 2);
667    }
668
669    #[test]
670    fn fiber_with_predicate_filters() {
671        let (inst, compiled) = make_test_instance();
672        let config = panproto_expr::EvalConfig::default();
673
674        // Filter: confidence > 0.8
675        let predicate = panproto_expr::Expr::Builtin(
676            panproto_expr::BuiltinOp::Gt,
677            vec![
678                panproto_expr::Expr::Var("confidence".into()),
679                panproto_expr::Expr::Lit(panproto_expr::Literal::Float(0.8)),
680            ],
681        );
682
683        let filtered =
684            fiber_with_predicate(&compiled, &inst, &Name::from("span"), &predicate, &config);
685        // Only node 1 has confidence 0.9 > 0.8
686        assert_eq!(filtered, vec![1]);
687    }
688
689    #[test]
690    fn group_by_partitions() {
691        let (inst, compiled) = make_test_instance();
692        let groups = group_by(&compiled, &inst);
693
694        assert_eq!(groups.len(), 2);
695        assert_eq!(groups[&Name::from("document")].nodes.len(), 1);
696        assert_eq!(groups[&Name::from("span")].nodes.len(), 2);
697    }
698
699    #[test]
700    fn join_computes_pullback() {
701        let (left, left_compiled) = make_test_instance();
702        let (right, right_compiled) = make_test_instance();
703
704        let pairs = join(&left, &right, &left_compiled, &right_compiled);
705
706        // Both instances have 2 "span" nodes and 1 "document" node.
707        // Span × Span = 4 pairs, Document × Document = 1 pair → 5 total.
708        assert_eq!(pairs.len(), 5);
709    }
710
711    #[test]
712    fn fiber_at_node_basic() {
713        // Source: root(0) -> annotation(1), root(0) -> annotation(2), root(0) -> text(3)
714        let mut nodes = HashMap::new();
715        nodes.insert(0, Node::new(0, "root"));
716        nodes.insert(1, Node::new(1, "annotation"));
717        nodes.insert(2, Node::new(2, "annotation"));
718        nodes.insert(3, Node::new(3, "text"));
719
720        let edge_ann = Edge {
721            src: Name::from("root"),
722            tgt: Name::from("annotation"),
723            kind: Name::from("child"),
724            name: None,
725        };
726        let edge_txt = Edge {
727            src: Name::from("root"),
728            tgt: Name::from("text"),
729            kind: Name::from("child"),
730            name: None,
731        };
732        let source = WInstance::new(
733            nodes,
734            vec![(0, 1, edge_ann.clone()), (0, 2, edge_ann), (0, 3, edge_txt)],
735            vec![],
736            0,
737            Name::from("root"),
738        );
739
740        // Target: root(0) -> text(3), annotation nodes dropped
741        let mut tgt_nodes = HashMap::new();
742        tgt_nodes.insert(0, Node::new(0, "root"));
743        tgt_nodes.insert(3, Node::new(3, "text"));
744
745        let tgt_edge = Edge {
746            src: Name::from("root"),
747            tgt: Name::from("text"),
748            kind: Name::from("child"),
749            name: None,
750        };
751        let target = WInstance::new(
752            tgt_nodes,
753            vec![(0, 3, tgt_edge)],
754            vec![],
755            0,
756            Name::from("root"),
757        );
758
759        // Complement: annotations contracted into root
760        let complement = Complement {
761            dropped_nodes: vec![
762                DroppedNode {
763                    original_id: 1,
764                    anchor: Name::from("annotation"),
765                    contracted_into: Some(0),
766                },
767                DroppedNode {
768                    original_id: 2,
769                    anchor: Name::from("annotation"),
770                    contracted_into: Some(0),
771                },
772            ],
773            dropped_arcs: vec![],
774            original_extra_fields: HashMap::new(),
775        };
776
777        // Fiber at root (id=0): direct match on anchor "root" (node 0) + contracted (1, 2)
778        let fiber = fiber_at_node(&source, &target, 0, &complement);
779        assert!(fiber.contains(&0)); // direct preimage
780        assert!(fiber.contains(&1)); // contracted
781        assert!(fiber.contains(&2)); // contracted
782        assert_eq!(fiber.len(), 3);
783
784        // Fiber at text (id=3): direct match only
785        let fiber_text = fiber_at_node(&source, &target, 3, &complement);
786        assert!(fiber_text.contains(&3));
787        assert_eq!(fiber_text.len(), 1);
788    }
789
790    /// Build a minimal test schema with the given vertex names and edges.
791    fn make_schema(vertices: &[&str], edges: &[Edge]) -> panproto_schema::Schema {
792        use smallvec::smallvec;
793        let mut between: HashMap<(Name, Name), smallvec::SmallVec<Edge, 2>> = HashMap::new();
794        for edge in edges {
795            between
796                .entry((Name::from(&*edge.src), Name::from(&*edge.tgt)))
797                .or_insert_with(|| smallvec![])
798                .push(edge.clone());
799        }
800        panproto_schema::Schema {
801            protocol: "test".into(),
802            vertices: vertices
803                .iter()
804                .map(|&v| {
805                    (
806                        Name::from(v),
807                        panproto_schema::Vertex {
808                            id: Name::from(v),
809                            kind: Name::from("object"),
810                            nsid: None,
811                        },
812                    )
813                })
814                .collect(),
815            edges: HashMap::new(),
816            hyper_edges: HashMap::new(),
817            constraints: HashMap::new(),
818            required: HashMap::new(),
819            nsids: HashMap::new(),
820            entries: Vec::new(),
821            variants: HashMap::new(),
822            orderings: HashMap::new(),
823            recursion_points: HashMap::new(),
824            spans: HashMap::new(),
825            usage_modes: HashMap::new(),
826            nominal: HashMap::new(),
827            coercions: HashMap::new(),
828            mergers: HashMap::new(),
829            defaults: HashMap::new(),
830            policies: HashMap::new(),
831            outgoing: HashMap::new(),
832            incoming: HashMap::new(),
833            between,
834        }
835    }
836
837    #[test]
838    fn restrict_with_complement_tracks_drops() {
839        let doc_ann_edge = Edge {
840            src: Name::from("doc"),
841            tgt: Name::from("annotation"),
842            kind: Name::from("child"),
843            name: None,
844        };
845        let doc_text_edge = Edge {
846            src: Name::from("doc"),
847            tgt: Name::from("text"),
848            kind: Name::from("child"),
849            name: None,
850        };
851
852        let tgt_schema = make_schema(&["doc", "text"], std::slice::from_ref(&doc_text_edge));
853        let src_schema = make_schema(
854            &["doc", "annotation", "text"],
855            &[doc_ann_edge, doc_text_edge],
856        );
857
858        // Migration: doc -> doc, text -> text, annotation not in surviving_verts
859        let mut vertex_remap = HashMap::new();
860        vertex_remap.insert(Name::from("doc"), Name::from("doc"));
861        vertex_remap.insert(Name::from("text"), Name::from("text"));
862        let migration = CompiledMigration {
863            surviving_verts: ["doc", "text"].iter().map(|s| Name::from(*s)).collect(),
864            surviving_edges: std::collections::HashSet::new(),
865            vertex_remap,
866            edge_remap: HashMap::new(),
867            resolver: HashMap::new(),
868            hyper_resolver: HashMap::new(),
869            field_transforms: HashMap::new(),
870            conditional_survival: HashMap::new(),
871            expansion_path: HashMap::new(),
872        };
873
874        // Instance: doc(0) -> annotation(1), doc(0) -> text(2)
875        let mut nodes = HashMap::new();
876        nodes.insert(0, Node::new(0, "doc"));
877        nodes.insert(1, Node::new(1, "annotation"));
878        nodes.insert(2, Node::new(2, "text"));
879
880        let instance = WInstance::new(
881            nodes,
882            vec![
883                (
884                    0,
885                    1,
886                    Edge {
887                        src: Name::from("doc"),
888                        tgt: Name::from("annotation"),
889                        kind: Name::from("child"),
890                        name: None,
891                    },
892                ),
893                (
894                    0,
895                    2,
896                    Edge {
897                        src: Name::from("doc"),
898                        tgt: Name::from("text"),
899                        kind: Name::from("child"),
900                        name: None,
901                    },
902                ),
903            ],
904            vec![],
905            0,
906            Name::from("doc"),
907        );
908
909        let (restricted, complement) =
910            restrict_with_complement(&instance, &src_schema, &tgt_schema, &migration).unwrap();
911
912        // Restricted should have 2 nodes: doc and text
913        assert_eq!(restricted.nodes.len(), 2);
914        assert!(restricted.nodes.contains_key(&0));
915        assert!(restricted.nodes.contains_key(&2));
916
917        // Complement should have 1 dropped node: annotation
918        assert_eq!(complement.dropped_nodes.len(), 1);
919        assert_eq!(complement.dropped_nodes[0].original_id, 1);
920        assert_eq!(complement.dropped_nodes[0].anchor, Name::from("annotation"));
921        assert_eq!(complement.dropped_nodes[0].contracted_into, Some(0));
922
923        // Dropped arcs: the arc from doc -> annotation
924        assert_eq!(complement.dropped_arcs.len(), 1);
925    }
926
927    #[test]
928    fn section_roundtrip() {
929        let doc_ann_edge = Edge {
930            src: Name::from("doc"),
931            tgt: Name::from("annotation"),
932            kind: Name::from("child"),
933            name: None,
934        };
935        let doc_text_edge = Edge {
936            src: Name::from("doc"),
937            tgt: Name::from("text"),
938            kind: Name::from("child"),
939            name: None,
940        };
941
942        let tgt_schema = make_schema(&["doc", "text"], std::slice::from_ref(&doc_text_edge));
943        let src_schema = make_schema(
944            &["doc", "annotation", "text"],
945            &[doc_ann_edge, doc_text_edge],
946        );
947
948        // Migration: doc -> doc, text -> text
949        let mut vertex_remap = HashMap::new();
950        vertex_remap.insert(Name::from("doc"), Name::from("doc"));
951        vertex_remap.insert(Name::from("text"), Name::from("text"));
952        let migration = CompiledMigration {
953            surviving_verts: ["doc", "text"].iter().map(|s| Name::from(*s)).collect(),
954            surviving_edges: std::collections::HashSet::new(),
955            vertex_remap,
956            edge_remap: HashMap::new(),
957            resolver: HashMap::new(),
958            hyper_resolver: HashMap::new(),
959            field_transforms: HashMap::new(),
960            conditional_survival: HashMap::new(),
961            expansion_path: HashMap::new(),
962        };
963
964        // Base instance (target schema): doc(0) -> text(1)
965        let mut base_nodes = HashMap::new();
966        base_nodes.insert(0, Node::new(0, "doc"));
967        base_nodes.insert(1, Node::new(1, "text"));
968
969        let base = WInstance::new(
970            base_nodes,
971            vec![(
972                0,
973                1,
974                Edge {
975                    src: Name::from("doc"),
976                    tgt: Name::from("text"),
977                    kind: Name::from("child"),
978                    name: None,
979                },
980            )],
981            vec![],
982            0,
983            Name::from("doc"),
984        );
985
986        // Add one enrichment: an annotation node attached to doc
987        let enrichments = vec![SectionEnrichment {
988            base_node_id: 0,
989            anchor: Name::from("annotation"),
990            edge: Edge {
991                src: Name::from("doc"),
992                tgt: Name::from("annotation"),
993                kind: Name::from("child"),
994                name: None,
995            },
996            value: Some(FieldPresence::Present(Value::Str("test".into()))),
997            extra_fields: FxHashMap::default(),
998        }];
999
1000        let section_inst = section(&base, &migration, enrichments).unwrap();
1001
1002        // Section should have 3 nodes: doc, text, annotation
1003        assert_eq!(section_inst.nodes.len(), 3);
1004
1005        // Restricting the section back should match the base
1006        let restricted =
1007            wtype_restrict(&section_inst, &src_schema, &tgt_schema, &migration).unwrap();
1008        assert_eq!(restricted.nodes.len(), base.nodes.len());
1009
1010        // Verify anchors match: both should have doc and text
1011        let restricted_anchors: FxHashSet<_> = restricted
1012            .nodes
1013            .values()
1014            .map(|n| n.anchor.clone())
1015            .collect();
1016        let base_anchors: FxHashSet<_> = base.nodes.values().map(|n| n.anchor.clone()).collect();
1017        assert_eq!(restricted_anchors, base_anchors);
1018    }
1019}