use super::*;
use plexus_conformance::{
load_cases_from_paths, run_conformance_suite, ConformanceCase, ConformanceGraphMode,
QueryResult as ConfQueryResult, Value as ConfValue,
};
use plexus_ir::mlir_text_to_plan;
use plexus_serde::serialize_plan;
use std::fs;
use std::path::{Path, PathBuf};
fn case_dir() -> PathBuf {
Path::new(env!("CARGO_MANIFEST_DIR"))
.join("src")
.join("tests")
.join("fixtures")
.join("vector")
}
/// Lists every fixture file in [`case_dir`] whose extension is `feature` or
/// `case`, sorted by path so the order is stable across runs.
///
/// # Panics
///
/// Panics if the fixture directory cannot be read or if any directory entry
/// fails to load.
fn local_case_paths() -> Vec<PathBuf> {
    let entries = fs::read_dir(case_dir()).expect("read vector fixture dir");

    let mut found = Vec::new();
    for entry in entries {
        let path = entry.expect("read_dir entry").path();
        // Extensions that are missing or not valid UTF-8 never match.
        let is_case_file = match path.extension().and_then(|ext| ext.to_str()) {
            Some(ext) => ext == "feature" || ext == "case",
            None => false,
        };
        if is_case_file {
            found.push(path);
        }
    }

    found.sort();
    found
}
/// Converts a local [`Value`] into the conformance crate's `ConfValue`.
///
/// Each variant maps to the variant of the same name. Lists and maps are
/// converted recursively; map keys are cloned unchanged.
fn local_value_to_conf(v: &Value) -> ConfValue {
    match v {
        // Containers: convert every element, recursing as needed.
        Value::List(items) => {
            let converted = items.iter().map(local_value_to_conf).collect();
            ConfValue::List(converted)
        }
        Value::Map(entries) => {
            let converted = entries
                .iter()
                .map(|(key, inner)| (key.clone(), local_value_to_conf(inner)))
                .collect();
            ConfValue::Map(converted)
        }

        // Scalars and references: copy (or clone) the payload across.
        Value::String(text) => ConfValue::String(text.clone()),
        Value::NodeRef(id) => ConfValue::NodeRef(*id),
        Value::RelRef(id) => ConfValue::RelRef(*id),
        Value::Float(n) => ConfValue::Float(*n),
        Value::Int(n) => ConfValue::Int(*n),
        Value::Bool(flag) => ConfValue::Bool(*flag),
        Value::Null => ConfValue::Null,
    }
}
/// Converts an engine [`QueryResult`] into the conformance crate's result
/// type by converting every cell with [`local_value_to_conf`].
///
/// The `continuation` field of the returned result is always `None`.
fn local_result_to_conf(out: QueryResult) -> ConfQueryResult {
    let rows = out
        .rows
        .into_iter()
        .map(|cells| cells.iter().map(local_value_to_conf).collect())
        .collect();

    ConfQueryResult {
        rows,
        continuation: None,
    }
}
/// Reports whether `case` carries `tag`, comparing ASCII-case-insensitively.
fn has_tag(case: &ConformanceCase, tag: &str) -> bool {
    for candidate in case.tags.iter() {
        if candidate.eq_ignore_ascii_case(tag) {
            return true;
        }
    }
    false
}
/// Picks the graph a case runs against, based on `case.setup.graph_mode`.
///
/// `Empty` yields `Graph::default()`; `Fixture` or an unspecified mode yields
/// `fixture_vector_graph()`.
fn graph_for_case(case: &ConformanceCase) -> Graph {
    // Kept as an exhaustive match so a new graph mode forces a decision here.
    let wants_empty = match case.setup.graph_mode {
        Some(ConformanceGraphMode::Empty) => true,
        Some(ConformanceGraphMode::Fixture) | None => false,
    };

    if wants_empty {
        Graph::default()
    } else {
        fixture_vector_graph()
    }
}
/// Builds a capability set for `plan` with the vector features removed.
///
/// Starts from `EngineCapabilities::full` over a `VersionRange` built with the
/// plan's own version as both arguments, then removes the `VectorScan` and
/// `Rerank` ops and the `VectorSimilarity` expression.
fn vector_disabled_capabilities(plan: &plexus_serde::Plan) -> EngineCapabilities {
    let version: PlanSemver = (&plan.version).into();
    let exact_range = VersionRange::new(version, version);

    let mut capabilities = EngineCapabilities::full(exact_range);
    for op in [OpKind::VectorScan, OpKind::Rerank] {
        capabilities.supported_ops.remove(&op);
    }
    capabilities
        .supported_exprs
        .remove(&ExprKind::VectorSimilarity);
    capabilities
}
fn execute_vector_case(case: &ConformanceCase) -> Result<ConfQueryResult, String> {
let plan_mlir = case
.plan_mlir
.as_deref()
.ok_or_else(|| format!("vector case `{}` missing plan_mlir", case.name))?;
let plan = mlir_text_to_plan(plan_mlir).map_err(|e| format!("{e}"))?;
if has_tag(case, "cap-vector-disabled") {
validate_plan_against_capabilities(&plan, &vector_disabled_capabilities(&plan))
.map_err(|e| format!("{e}"))?;
}
let bytes = serialize_plan(&plan).map_err(|e| format!("{e}"))?;
let mut engine = mock_vector_engine_with_graph(graph_for_case(case));
execute_serialized(&mut engine, &bytes)
.map(local_result_to_conf)
.map_err(|e| format!("{e}"))
}
/// Runs every vector conformance fixture through [`execute_vector_case`] and
/// fails if any case fails.
///
/// Also fails when no fixture files are discovered. Without that check an
/// empty or mis-filtered fixture directory would make the suite pass with
/// zero cases executed.
#[test]
fn vector_conformance_corpus_cases() {
    let paths = local_case_paths();
    assert!(
        !paths.is_empty(),
        "no vector conformance fixtures found in {}",
        case_dir().display()
    );

    let cases = load_cases_from_paths(&paths).expect("load vector cases");
    let report = run_conformance_suite(&cases, execute_vector_case);

    // Built lazily: the closure only runs when the assertion below fails.
    let failures = || {
        report
            .cases
            .iter()
            .filter_map(|c| {
                c.outcome
                    .as_ref()
                    .err()
                    .map(|e| (c.case.clone(), e.to_string()))
            })
            .collect::<Vec<_>>()
    };

    assert_eq!(
        report.failed,
        0,
        "vector conformance failures: {:?}",
        failures()
    );
}