plexus-engine 0.3.4

Engine integration traits for consuming Plexus plans
Documentation
use crate::*;
use std::collections::{BTreeMap, HashMap};

/// One node position of a MERGE pattern.
#[derive(Debug)]
struct MergeNodeSpec {
    /// Variable name taken from the pattern's `var` entry; `None` when it was absent or null.
    /// When the name is a column of the input row schema, that column's node is used here.
    var: Option<String>,
    /// Labels a graph node must carry to match this position; a node created for this
    /// position is given these labels.
    labels: Vec<String>,
}

/// Direction of a pattern relationship relative to its `src` / `dst` endpoints, parsed from
/// the pattern's `"out"`, `"in"` and `"both"` strings.
#[derive(Debug, Clone, Copy)]
enum MergeDir {
    /// The relationship runs from the `src` node to the `dst` node.
    Out,
    /// The relationship runs from the `dst` node to the `src` node.
    In,
    /// Either direction matches; a newly created relationship points `src` -> `dst`.
    Both,
}

/// One relationship of a MERGE pattern.
#[derive(Debug)]
struct MergeRelSpec {
    /// Index into `MergePatternSpec::nodes` of the first endpoint.
    src_idx: usize,
    /// Index into `MergePatternSpec::nodes` of the second endpoint.
    dst_idx: usize,
    /// Relationship type used both for matching and for creation. It comes from the first
    /// entry of the pattern's `types` list; an empty string matches relationships of any type.
    rel_type: String,
    /// How the relationship is oriented between `src_idx` and `dst_idx`.
    dir: MergeDir,
}

/// A parsed MERGE pattern: its node positions plus the relationships between them.
#[derive(Debug)]
struct MergePatternSpec {
    /// Node positions in pattern order.
    nodes: Vec<MergeNodeSpec>,
    /// Relationships in pattern order; their endpoint indices are checked against `nodes`
    /// when the pattern is parsed.
    rels: Vec<MergeRelSpec>,
}

/// Looks up `key` in a decoded pattern map, borrowing the stored value when present.
fn map_value<'m>(map: &'m BTreeMap<String, Value>, key: &str) -> Option<&'m Value> {
    BTreeMap::get(map, key)
}

/// Converts a `Value::List` of strings into owned `String`s (used for node labels).
///
/// # Errors
///
/// Returns `UnsupportedOp("merge.pattern.labels")` when `v` is not a list or when any
/// element is not a string.
fn parse_string_list(v: &Value) -> Result<Vec<String>, ExecutionError> {
    match v {
        // Collecting into `Result` stops at the first non-string element.
        Value::List(items) => items
            .iter()
            .map(|item| match item {
                Value::String(s) => Ok(s.clone()),
                _ => Err(ExecutionError::UnsupportedOp("merge.pattern.labels")),
            })
            .collect(),
        _ => Err(ExecutionError::UnsupportedOp("merge.pattern.labels")),
    }
}

/// Decodes the planner's map-literal MERGE pattern into a `MergePatternSpec`.
///
/// The evaluated map must contain:
/// - `kind`: `"linear_v1"` (relationship `i` joins nodes `i` and `i + 1`) or `"graph_v1"`
///   (each relationship names its endpoints with non-negative integer `src` / `dst` indices);
/// - `nodes`: a list of maps with an optional `var` (string or null) and optional `labels`
///   (list of strings);
/// - `rels`: a list of maps with a required `dir` (`"out"`, `"in"` or `"both"`) and optional
///   `types`. Only the first string in `types` is used; a missing or non-string entry yields
///   an empty type, which matches any relationship type.
///
/// # Errors
///
/// Returns `ExecutionError::UnsupportedOp` naming the offending field when the pattern is not
/// a map literal, a field is missing or has the wrong shape, `kind` is unknown (even if
/// `rels` is empty), or a relationship endpoint index is negative or out of range.
/// Errors from evaluating the map literal are propagated.
fn parse_pattern_expr(
    engine: &InMemoryEngine,
    pattern: &Expr,
) -> Result<MergePatternSpec, ExecutionError> {
    /// How relationship endpoints are encoded in the pattern map.
    #[derive(Clone, Copy)]
    enum PatternKind {
        /// `"linear_v1"`: relationship `i` joins node `i` and node `i + 1`.
        Linear,
        /// `"graph_v1"`: each relationship names its endpoints with `src` / `dst` indices.
        Graph,
    }

    // The pattern arrives as a map literal; evaluate it with an empty row.
    let parsed = match pattern {
        Expr::MapLiteral { .. } => engine.eval_expr(&Vec::new(), pattern)?,
        _ => return Err(ExecutionError::UnsupportedOp("merge.pattern")),
    };
    let Value::Map(map) = parsed else {
        return Err(ExecutionError::UnsupportedOp("merge.pattern"));
    };

    // Validate `kind` up front. It used to be checked only per relationship, so an unknown
    // kind slipped through whenever `rels` was empty.
    let kind = match map_value(&map, "kind") {
        Some(Value::String(k)) if k == "linear_v1" => PatternKind::Linear,
        Some(Value::String(k)) if k == "graph_v1" => PatternKind::Graph,
        _ => return Err(ExecutionError::UnsupportedOp("merge.pattern.kind")),
    };

    let nodes_value =
        map_value(&map, "nodes").ok_or(ExecutionError::UnsupportedOp("merge.pattern.nodes"))?;
    let Value::List(nodes) = nodes_value else {
        return Err(ExecutionError::UnsupportedOp("merge.pattern.nodes"));
    };
    let mut node_specs = Vec::with_capacity(nodes.len());
    for node in nodes {
        let Value::Map(node_map) = node else {
            return Err(ExecutionError::UnsupportedOp("merge.pattern.node"));
        };
        let var = match map_value(node_map, "var") {
            Some(Value::String(s)) => Some(s.clone()),
            Some(Value::Null) | None => None,
            _ => return Err(ExecutionError::UnsupportedOp("merge.pattern.node.var")),
        };
        let labels = match map_value(node_map, "labels") {
            Some(v) => parse_string_list(v)?,
            None => Vec::new(),
        };
        node_specs.push(MergeNodeSpec { var, labels });
    }

    let rels_value =
        map_value(&map, "rels").ok_or(ExecutionError::UnsupportedOp("merge.pattern.rels"))?;
    let Value::List(rels) = rels_value else {
        return Err(ExecutionError::UnsupportedOp("merge.pattern.rels"));
    };
    let mut rel_specs = Vec::with_capacity(rels.len());
    for (idx, rel) in rels.iter().enumerate() {
        let Value::Map(rel_map) = rel else {
            return Err(ExecutionError::UnsupportedOp("merge.pattern.rel"));
        };
        // Only the first listed type is used; anything else means "any type".
        let rel_type = match map_value(rel_map, "types") {
            Some(Value::List(types)) => match types.first() {
                Some(Value::String(s)) => s.clone(),
                _ => String::new(),
            },
            _ => String::new(),
        };
        let dir = match map_value(rel_map, "dir") {
            Some(Value::String(s)) if s == "out" => MergeDir::Out,
            Some(Value::String(s)) if s == "in" => MergeDir::In,
            Some(Value::String(s)) if s == "both" => MergeDir::Both,
            _ => return Err(ExecutionError::UnsupportedOp("merge.pattern.rel.dir")),
        };

        let (src_idx, dst_idx) = match kind {
            PatternKind::Linear => (idx, idx + 1),
            PatternKind::Graph => {
                let src = match map_value(rel_map, "src") {
                    Some(Value::Int(v)) if *v >= 0 => *v as usize,
                    _ => return Err(ExecutionError::UnsupportedOp("merge.pattern.rel.src")),
                };
                let dst = match map_value(rel_map, "dst") {
                    Some(Value::Int(v)) if *v >= 0 => *v as usize,
                    _ => return Err(ExecutionError::UnsupportedOp("merge.pattern.rel.dst")),
                };
                (src, dst)
            }
        };
        // Later stages index `nodes` directly with these, so reject out-of-range endpoints here.
        if src_idx >= node_specs.len() || dst_idx >= node_specs.len() {
            return Err(ExecutionError::UnsupportedOp(
                "merge.pattern.rel.endpoint_idx",
            ));
        }
        rel_specs.push(MergeRelSpec {
            src_idx,
            dst_idx,
            rel_type,
            dir,
        });
    }

    Ok(MergePatternSpec {
        nodes: node_specs,
        rels: rel_specs,
    })
}

/// Returns the position of the first schema column called `name`, if any.
fn schema_index(schema: &[plexus_serde::ColDef], name: &str) -> Option<usize> {
    schema
        .iter()
        .enumerate()
        .find_map(|(pos, col)| (col.name == name).then_some(pos))
}

/// Turns an evaluated ON CREATE / ON MATCH payload into a property map.
///
/// `Null` yields an empty map and a `Map` is copied entry by entry.
///
/// # Errors
///
/// Any other value yields `ExecutionError::ExpectedMapPayload`.
fn as_props_map(v: Value) -> Result<HashMap<String, Value>, ExecutionError> {
    let entries = match v {
        Value::Map(entries) => entries,
        Value::Null => return Ok(HashMap::new()),
        _ => return Err(ExecutionError::ExpectedMapPayload),
    };
    Ok(entries.into_iter().collect())
}

/// Depth-first search that binds pattern relationships `edge_idx..` to graph relationships.
///
/// `node_ids[i]` is the graph node already chosen for pattern position `i`. Each pattern
/// edge must be matched by a graph relationship that connects its chosen endpoints in the
/// required direction and has the required type (an empty type accepts any). Each graph
/// relationship binds at most one pattern edge, as in Cypher relationship uniqueness.
///
/// Matched relationship ids are pushed onto `rel_ids` in pattern-edge order. Returns `true`
/// when every remaining edge was bound. On `false`, `rel_ids` is restored to its length on
/// entry.
fn dfs_match_edges(
    engine: &InMemoryEngine,
    pattern_spec: &MergePatternSpec,
    edge_idx: usize,
    node_ids: &[u64],
    rel_ids: &mut Vec<u64>,
) -> bool {
    if edge_idx == pattern_spec.rels.len() {
        return true;
    }
    let rel_spec = &pattern_spec.rels[edge_idx];
    let lhs = node_ids[rel_spec.src_idx];
    let rhs = node_ids[rel_spec.dst_idx];
    for rel in &engine.graph.rels {
        if !rel_spec.rel_type.is_empty() && rel.typ != rel_spec.rel_type {
            continue;
        }
        let edge_ok = match rel_spec.dir {
            MergeDir::Out => rel.src == lhs && rel.dst == rhs,
            MergeDir::In => rel.src == rhs && rel.dst == lhs,
            MergeDir::Both => {
                (rel.src == lhs && rel.dst == rhs) || (rel.src == rhs && rel.dst == lhs)
            }
        };
        // A relationship already bound to an earlier pattern edge cannot be reused. Without
        // this check, e.g. two `both` edges between the same nodes would both match one
        // stored relationship and MERGE would wrongly report a match.
        if !edge_ok || rel_ids.contains(&rel.id) {
            continue;
        }
        rel_ids.push(rel.id);
        if dfs_match_edges(engine, pattern_spec, edge_idx + 1, node_ids, rel_ids) {
            return true;
        }
        rel_ids.pop();
    }
    false
}

/// Backtracking search that assigns graph nodes to the pattern's node positions.
///
/// Positions are filled left to right starting at `idx`.
/// - A position that already holds a node in `assigned` is never replaced. The caller seeds
///   these from the row bindings. The node is only validated with `node_matches`.
/// - An empty position tries each node of `engine.graph.nodes`, in storage order, that
///   `node_matches` accepts.
///
/// Once every position is filled, `dfs_match_edges` must bind all pattern relationships.
///
/// On success, the node ids (pattern order) and relationship ids (pattern-edge order) are
/// written to `out_nodes` / `out_rels` and `true` is returned. On failure the outputs are not
/// written, and positions filled during the search are reset to `None`.
fn find_match(
    engine: &InMemoryEngine,
    pattern_spec: &MergePatternSpec,
    node_matches: &impl Fn(u64, usize, &[Option<u64>]) -> bool,
    idx: usize,
    assigned: &mut Vec<Option<u64>>,
    out_nodes: &mut Vec<u64>,
    out_rels: &mut Vec<u64>,
) -> bool {
    if idx == assigned.len() {
        // All positions are filled. They only form a match if every pattern edge can be
        // bound to relationships between them.
        let Some(node_ids) = assigned.iter().copied().collect::<Option<Vec<u64>>>() else {
            return false;
        };
        let mut rel_ids = Vec::new();
        let edges_found = dfs_match_edges(engine, pattern_spec, 0, &node_ids, &mut rel_ids);
        if edges_found {
            *out_nodes = node_ids;
            *out_rels = rel_ids;
        }
        return edges_found;
    }

    let slot = assigned[idx];
    if let Some(pinned) = slot {
        // Pre-assigned position: validate it and move on without touching it.
        return node_matches(pinned, idx, assigned)
            && find_match(
                engine,
                pattern_spec,
                node_matches,
                idx + 1,
                assigned,
                out_nodes,
                out_rels,
            );
    }

    // Free position: the first candidate that leads to a full match wins.
    engine
        .graph
        .nodes
        .iter()
        .map(|node| node.id)
        .any(|candidate| {
            if !node_matches(candidate, idx, assigned) {
                return false;
            }
            assigned[idx] = Some(candidate);
            let found = find_match(
                engine,
                pattern_spec,
                node_matches,
                idx + 1,
                assigned,
                out_nodes,
                out_rels,
            );
            if !found {
                assigned[idx] = None;
            }
            found
        })
}

pub(super) fn execute_merge_pattern(
    engine: &mut InMemoryEngine,
    pattern: &Expr,
    on_create_props: &Expr,
    on_match_props: &Expr,
    schema: &[plexus_serde::ColDef],
    row: &Row,
) -> Result<(), ExecutionError> {
    let pattern_spec = parse_pattern_expr(engine, pattern)?;
    let on_create_props_map = as_props_map(engine.eval_expr(row, on_create_props)?)?;
    let on_match_props_map = as_props_map(engine.eval_expr(row, on_match_props)?)?;
    if pattern_spec.nodes.is_empty() {
        return Err(ExecutionError::UnsupportedOp("merge.pattern.empty"));
    }

    let mut bound_node_ids = Vec::with_capacity(pattern_spec.nodes.len());
    for node in &pattern_spec.nodes {
        let mut bound = None;
        if let Some(var) = &node.var {
            if let Some(idx) = schema_index(schema, var) {
                match row.get(idx) {
                    Some(Value::NodeRef(id)) => bound = Some(*id),
                    Some(Value::Null) | None => {}
                    Some(_) => return Err(ExecutionError::ExpectedNodeRef { idx }),
                }
            }
        }
        bound_node_ids.push(bound);
    }

    let node_matches = |node_id: u64, i: usize, node_ids: &[Option<u64>]| -> bool {
        let node_spec = &pattern_spec.nodes[i];
        let Some(node) = engine.graph.node_by_id(node_id) else {
            return false;
        };
        if !node_spec.labels.iter().all(|l| node.labels.contains(l)) {
            return false;
        }
        if let Some(bound) = node_ids[i] {
            node_id == bound
        } else {
            true
        }
    };

    let mut assigned_nodes = bound_node_ids.clone();
    let mut matched_nodes = Vec::new();
    let mut matched_rels = Vec::new();
    let matched = find_match(
        engine,
        &pattern_spec,
        &node_matches,
        0,
        &mut assigned_nodes,
        &mut matched_nodes,
        &mut matched_rels,
    );

    if matched {
        if !on_match_props_map.is_empty() {
            if matched_rels.is_empty() {
                for (key, value) in &on_match_props_map {
                    MutationEngine::set_property(
                        engine,
                        &Value::NodeRef(matched_nodes[0]),
                        key,
                        value.clone(),
                    )?;
                }
            } else {
                for rel_id in matched_rels {
                    for (key, value) in &on_match_props_map {
                        MutationEngine::set_property(
                            engine,
                            &Value::RelRef(rel_id),
                            key,
                            value.clone(),
                        )?;
                    }
                }
            }
        }
        return Ok(());
    }

    let mut created_node_ids = Vec::with_capacity(pattern_spec.nodes.len());
    for (i, node_spec) in pattern_spec.nodes.iter().enumerate() {
        if let Some(node_id) = bound_node_ids[i] {
            if engine.graph.node_by_id(node_id).is_none() {
                return Err(ExecutionError::UnknownNode(node_id));
            }
            created_node_ids.push(node_id);
        } else {
            let node_id = MutationEngine::create_node(engine, &node_spec.labels, HashMap::new())?;
            created_node_ids.push(node_id);
        }
    }

    let mut created_rel_ids = Vec::with_capacity(pattern_spec.rels.len());
    for rel in &pattern_spec.rels {
        let lhs = created_node_ids[rel.src_idx];
        let rhs = created_node_ids[rel.dst_idx];
        let (src, dst) = match rel.dir {
            MergeDir::Out => (lhs, rhs),
            MergeDir::In => (rhs, lhs),
            MergeDir::Both => (lhs, rhs),
        };
        let rel_id = MutationEngine::create_rel(engine, src, dst, &rel.rel_type, HashMap::new())?;
        created_rel_ids.push(rel_id);
    }

    if !on_create_props_map.is_empty() {
        if created_rel_ids.is_empty() {
            for (key, value) in &on_create_props_map {
                MutationEngine::set_property(
                    engine,
                    &Value::NodeRef(created_node_ids[0]),
                    key,
                    value.clone(),
                )?;
            }
        } else {
            for rel_id in created_rel_ids {
                for (key, value) in &on_create_props_map {
                    MutationEngine::set_property(
                        engine,
                        &Value::RelRef(rel_id),
                        key,
                        value.clone(),
                    )?;
                }
            }
        }
    }
    Ok(())
}