use std::future::Future;
use std::sync::Arc;
use std::sync::OnceLock;
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::error::Result as DFResult;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::stream;
use uni_common::Value;
use uni_plugin::traits::procedure::{
NamedArgType, ProcedureContext, ProcedureMode, ProcedurePlugin, ProcedureSignature,
};
use uni_plugin::traits::scalar::ArgType;
use uni_plugin::{FnError, PluginError, PluginRegistrar, QName, SideEffects};
use crate::procedures_plugin::host_args::{columnar_args_to_values, require_host};
use crate::query::df_graph::search_procedures::run_vector_query;
use crate::query::executor::procedure_host::QueryProcedureHost;
fn signature() -> &'static ProcedureSignature {
static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
SIG.get_or_init(|| ProcedureSignature {
args: vec![
NamedArgType {
name: smol_str::SmolStr::new("label"),
ty: ArgType::Primitive(DataType::Utf8),
default: None,
doc: "Vertex label to search.".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("property"),
ty: ArgType::Primitive(DataType::Utf8),
default: None,
doc: "Vector property name on the label.".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("query"),
ty: ArgType::CypherValue,
default: None,
doc: "Query vector (List<Float>) or query text (String, auto-embedded).".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("k"),
ty: ArgType::Primitive(DataType::Int64),
default: None,
doc: "Number of nearest neighbours to return.".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("filter"),
ty: ArgType::Primitive(DataType::Utf8),
default: None,
doc: "Optional pushdown filter expression.".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("threshold"),
ty: ArgType::Primitive(DataType::Float64),
default: None,
doc: "Optional maximum distance threshold (post-filter).".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("options"),
ty: ArgType::CypherValue,
default: None,
doc: "Optional reranker / extra options map.".to_owned(),
},
],
yields: vector_query_yields(),
mode: ProcedureMode::Read,
side_effects: SideEffects::ReadOnly,
retry_contract: None,
batch_input: None,
docs:
"Approximate-nearest-neighbour over a vector index with optional cross-encoder rerank."
.to_owned(),
})
}
/// Yield columns of `uni.vector.query`, in output order:
/// `vid`, `distance` (Float64), `score` (Float32), `rerank_score` (Float32).
///
/// All columns are declared nullable.
fn vector_query_yields() -> Vec<Field> {
    const NULLABLE: bool = true;
    Vec::from([
        vid_field(),
        Field::new("distance", DataType::Float64, NULLABLE),
        Field::new("score", DataType::Float32, NULLABLE),
        Field::new("rerank_score", DataType::Float32, NULLABLE),
    ])
}
/// Yield columns for the full-text-search query shape, in output order:
/// `vid`, `score` (Float32), `rerank_score` (Float32).
///
/// All columns are declared nullable.
pub(super) fn fts_query_yields() -> Vec<Field> {
    let mut fields = Vec::with_capacity(3);
    fields.push(vid_field());
    fields.push(Field::new("score", DataType::Float32, true));
    fields.push(Field::new("rerank_score", DataType::Float32, true));
    fields
}
/// Yield columns for the hybrid-search shape, in output order:
/// `vid`, `score`, `rerank_score`, `vector_score`, `fts_score` (all Float32)
/// followed by `distance` (Float64).
///
/// All columns are declared nullable.
pub(super) fn hybrid_search_yields() -> Vec<Field> {
    // Name / type table for every column after the leading `vid`.
    let score_columns = [
        ("score", DataType::Float32),
        ("rerank_score", DataType::Float32),
        ("vector_score", DataType::Float32),
        ("fts_score", DataType::Float32),
        ("distance", DataType::Float64),
    ];

    std::iter::once(vid_field())
        .chain(
            score_columns
                .into_iter()
                .map(|(name, data_type)| Field::new(name, data_type, true)),
        )
        .collect()
}
/// Builds the nullable Int64 `vid` column shared by all search yield lists.
///
/// The field carries the metadata entry `_yield_kind = node_vid_source`.
// NOTE(review): presumably downstream planning uses this tag to treat the
// column as a source of node ids — confirm against the consumer.
fn vid_field() -> Field {
    let metadata = std::collections::HashMap::from([(
        "_yield_kind".to_owned(),
        "node_vid_source".to_owned(),
    )]);
    Field::new("vid", DataType::Int64, true).with_metadata(metadata)
}
/// Stateless plugin object implementing the `uni.vector.query` procedure.
#[derive(Debug)]
struct VectorQueryProc;

impl ProcedurePlugin for VectorQueryProc {
    /// Returns the shared, cached procedure signature.
    fn signature(&self) -> &ProcedureSignature {
        // Resolves to the free function `signature`, not this method.
        signature()
    }

    /// Runs the vector query through the shared search-procedure driver.
    ///
    /// The actual search is delegated to `run_vector_query`, invoked with the
    /// host, the converted argument values, the resolved yield items, the
    /// host's target properties and the output schema.
    fn invoke(
        &self,
        ctx: ProcedureContext<'_>,
        args: &[ColumnarValue],
    ) -> Result<SendableRecordBatchStream, FnError> {
        const PROC_NAME: &str = "uni.vector.query";

        // The free function is used (rather than `self.signature()`) because
        // the driver requires a `'static` signature reference.
        run_search_procedure(
            PROC_NAME,
            &ctx,
            args,
            signature(),
            |proc_host, values, yields, schema| async move {
                let props = proc_host.target_properties().clone();
                run_vector_query(&proc_host, &values, &yields, &props, &schema).await
            },
        )
    }
}
/// Decides which yield items and which output schema a search procedure
/// should produce.
///
/// * When the host reports no yield items, every field of `sig.yields` is
///   yielded under its own name (no alias) and `fallback_schema` is used.
/// * Otherwise the host's yield items are copied as-is and the schema is the
///   host's expected schema, falling back to `fallback_schema` when the host
///   has none.
///
/// Returns `(yield_items, output_schema)`; each yield item is a
/// `(name, optional alias)` pair.
pub(super) fn resolve_yields_and_schema(
    host: &crate::query::executor::procedure_host::QueryProcedureHost,
    sig: &ProcedureSignature,
    fallback_schema: &Arc<Schema>,
) -> (Vec<(String, Option<String>)>, Arc<Schema>) {
    let requested = host.yield_items();

    // Nothing requested explicitly: expose the full declared yield list.
    if requested.is_empty() {
        let all_declared = sig
            .yields
            .iter()
            .map(|field| (field.name().to_owned(), None))
            .collect::<Vec<(String, Option<String>)>>();
        return (all_declared, Arc::clone(fallback_schema));
    }

    let schema = match host.expected_schema() {
        Some(expected) => Arc::clone(expected),
        None => Arc::clone(fallback_schema),
    };
    (requested.to_vec(), schema)
}
/// Shared driver for search-style procedures.
///
/// Synchronously: fetches the query host from `ctx` (failing with `FnError`
/// if `require_host` does), converts `args` to `Value`s, and resolves the
/// yield items / output schema via [`resolve_yields_and_schema`].
///
/// The search itself is deferred: `run_fn` is only called once the returned
/// stream is polled. The stream emits exactly one item — the batch produced
/// by `run_fn`, an empty batch with the output schema when `run_fn` returns
/// `Ok(None)`, or the error `run_fn` returned.
///
/// # Errors
///
/// Returns `FnError` when the host cannot be obtained from `ctx`. Errors from
/// `run_fn` surface as the stream's single item, not from this function.
pub(super) fn run_search_procedure<F, Fut>(
    proc_name: &'static str,
    ctx: &ProcedureContext<'_>,
    args: &[ColumnarValue],
    sig: &'static ProcedureSignature,
    run_fn: F,
) -> Result<SendableRecordBatchStream, FnError>
where
    F: FnOnce(QueryProcedureHost, Vec<Value>, Vec<(String, Option<String>)>, SchemaRef) -> Fut
        + Send
        + 'static,
    Fut: Future<Output = DFResult<Option<RecordBatch>>> + Send + 'static,
{
    let host = require_host(ctx, proc_name)?.clone();
    let values = columnar_args_to_values(args);

    // Schema built from the full declared yield list; used when the host does
    // not dictate one.
    let declared_schema = Arc::new(Schema::new(sig.yields.clone()));
    let (yield_items, output_schema) = resolve_yields_and_schema(&host, sig, &declared_schema);

    // One handle for the stream adapter, one moved into the async block.
    let adapter_schema = Arc::clone(&output_schema);

    let single_batch = stream::once(async move {
        let produced = run_fn(host, values, yield_items, Arc::clone(&output_schema)).await?;
        let batch = match produced {
            Some(batch) => batch,
            // No rows: still emit a well-typed, empty batch.
            None => RecordBatch::new_empty(output_schema),
        };
        Ok::<_, datafusion::error::DataFusionError>(batch)
    });

    let adapter = RecordBatchStreamAdapter::new(adapter_schema, single_batch);
    Ok(Box::pin(adapter))
}
/// Registers the `uni.vector.query` procedure with the given registrar.
///
/// The procedure is registered under the qualified name
/// `("uni", "vector.query")` with a clone of the cached signature and a
/// shared `VectorQueryProc` instance.
///
/// # Errors
///
/// Propagates any error returned by the registrar's `procedure` call.
pub fn register_into(r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
    let qualified_name = QName::new("uni", "vector.query");
    let sig = signature().clone();
    let plugin = Arc::new(VectorQueryProc);

    r.procedure(qualified_name, sig, plugin)?;
    Ok(())
}