plexus-engine 0.3.4

Engine integration traits for consuming Plexus plans
Documentation
use super::*;
use plexus_conformance::{
    load_cases_from_paths, run_conformance_suite, ConformanceCase, ConformanceGraphMode,
    QueryResult as ConfQueryResult, Value as ConfValue,
};
use plexus_ir::mlir_text_to_plan;
use plexus_serde::serialize_plan;
use std::fs;
use std::path::{Path, PathBuf};

fn case_dir() -> PathBuf {
    Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("src")
        .join("tests")
        .join("fixtures")
        .join("vector")
}

/// Collects the fixture files in [`case_dir`] whose extension is `feature` or
/// `case`, returned in sorted order.
///
/// # Panics
///
/// Panics if the fixture directory cannot be read or if any directory entry
/// fails to load.
fn local_case_paths() -> Vec<PathBuf> {
    let entries = fs::read_dir(case_dir()).expect("read vector fixture dir");

    let mut selected = Vec::new();
    for entry in entries {
        let path = entry.expect("read_dir entry").path();
        let extension = path.extension().and_then(|ext| ext.to_str());
        if matches!(extension, Some("feature") | Some("case")) {
            selected.push(path);
        }
    }

    // Sort so the corpus is loaded in a stable order regardless of how the
    // filesystem enumerates the directory.
    selected.sort();
    selected
}

/// Converts an engine-local [`Value`] into the matching conformance-harness
/// value (`ConfValue`).
///
/// Each variant maps to the variant of the same name; lists and maps are
/// converted element by element through recursion.
fn local_value_to_conf(v: &Value) -> ConfValue {
    match v {
        // Containers: rebuild with every nested value converted.
        Value::List(items) => {
            let converted = items.iter().map(local_value_to_conf).collect();
            ConfValue::List(converted)
        }
        Value::Map(entries) => {
            let converted = entries
                .iter()
                .map(|(key, inner)| (key.clone(), local_value_to_conf(inner)))
                .collect();
            ConfValue::Map(converted)
        }

        // Scalars and references: carry the payload across unchanged.
        Value::Null => ConfValue::Null,
        Value::Bool(flag) => ConfValue::Bool(*flag),
        Value::Int(number) => ConfValue::Int(*number),
        Value::Float(number) => ConfValue::Float(*number),
        Value::String(text) => ConfValue::String(text.clone()),
        Value::NodeRef(id) => ConfValue::NodeRef(*id),
        Value::RelRef(id) => ConfValue::RelRef(*id),
    }
}

/// Converts an engine [`QueryResult`] into the conformance-harness result
/// type, translating every cell of every row with [`local_value_to_conf`].
fn local_result_to_conf(out: QueryResult) -> ConfQueryResult {
    let rows = out
        .rows
        .into_iter()
        .map(|cells| cells.iter().map(local_value_to_conf).collect())
        .collect();
    ConfQueryResult { rows }
}

/// Reports whether `case` carries `tag`, comparing ASCII case-insensitively.
fn has_tag(case: &ConformanceCase, tag: &str) -> bool {
    for candidate in case.tags.iter() {
        if candidate.eq_ignore_ascii_case(tag) {
            return true;
        }
    }
    false
}

/// Builds the graph a conformance case should run against.
///
/// A case that does not specify a graph mode gets the fixture graph, the same
/// as one that asks for it explicitly; only the `Empty` mode yields a default
/// (empty) [`Graph`].
fn graph_for_case(case: &ConformanceCase) -> Graph {
    match case.setup.graph_mode {
        None | Some(ConformanceGraphMode::Fixture) => fixture_vector_graph(),
        Some(ConformanceGraphMode::Empty) => Graph::default(),
    }
}

/// Builds an [`EngineCapabilities`] set for `plan` with vector support
/// switched off.
///
/// Starts from `EngineCapabilities::full` over a version range whose both
/// bounds are the plan's own version, then removes the `VectorScan` and
/// `Rerank` ops and the `VectorSimilarity` expression.
fn vector_disabled_capabilities(plan: &plexus_serde::Plan) -> EngineCapabilities {
    let version: PlanSemver = (&plan.version).into();
    let exact_range = VersionRange::new(version, version);

    let mut caps = EngineCapabilities::full(exact_range);
    for op in [OpKind::VectorScan, OpKind::Rerank] {
        caps.supported_ops.remove(&op);
    }
    caps.supported_exprs.remove(&ExprKind::VectorSimilarity);
    caps
}

fn execute_vector_case(case: &ConformanceCase) -> Result<ConfQueryResult, String> {
    let plan_mlir = case
        .plan_mlir
        .as_deref()
        .ok_or_else(|| format!("vector case `{}` missing plan_mlir", case.name))?;
    let plan = mlir_text_to_plan(plan_mlir).map_err(|e| format!("{e}"))?;
    if has_tag(case, "cap-vector-disabled") {
        validate_plan_against_capabilities(&plan, &vector_disabled_capabilities(&plan))
            .map_err(|e| format!("{e}"))?;
    }
    let bytes = serialize_plan(&plan).map_err(|e| format!("{e}"))?;
    let mut engine = mock_vector_engine_with_graph(graph_for_case(case));
    execute_serialized(&mut engine, &bytes)
        .map(local_result_to_conf)
        .map_err(|e| format!("{e}"))
}

/// Loads every local vector fixture, runs it through the conformance suite
/// with [`execute_vector_case`], and requires that no case fails.
#[test]
fn vector_conformance_corpus_cases() {
    let paths = local_case_paths();
    let cases = load_cases_from_paths(&paths).expect("load vector cases");
    let report = run_conformance_suite(&cases, execute_vector_case);

    // Built lazily: the failure list is only assembled when the assertion
    // below actually needs to print it.
    let failures = || {
        report
            .cases
            .iter()
            .filter_map(|entry| {
                let err = entry.outcome.as_ref().err()?;
                Some((entry.case.clone(), err.to_string()))
            })
            .collect::<Vec<_>>()
    };

    assert_eq!(
        report.failed,
        0,
        "vector conformance failures: {:?}",
        failures()
    );
}