use std::sync::Arc;
use std::sync::OnceLock;
use arrow_schema::DataType;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::logical_expr::ColumnarValue;
use uni_plugin::traits::procedure::{
NamedArgType, ProcedureContext, ProcedureMode, ProcedurePlugin, ProcedureSignature,
};
use uni_plugin::traits::scalar::ArgType;
use uni_plugin::{FnError, PluginError, PluginRegistrar, QName, SideEffects};
use crate::procedures_plugin::vector::{hybrid_search_yields, run_search_procedure};
use crate::query::df_graph::search_procedures::run_hybrid_search;
fn signature() -> &'static ProcedureSignature {
static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
SIG.get_or_init(|| ProcedureSignature {
args: vec![
NamedArgType {
name: smol_str::SmolStr::new("label"),
ty: ArgType::Primitive(DataType::Utf8),
default: None,
doc: "Vertex label to search.".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("properties"),
ty: ArgType::CypherValue,
default: None,
doc: "Either a property name (used for both vector and fts) or a map `{vector: '...', fts: '...'}`."
.to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("query_text"),
ty: ArgType::Primitive(DataType::Utf8),
default: None,
doc: "Free-text query (used for FTS and, optionally, auto-embedding).".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("query_vector"),
ty: ArgType::CypherValue,
default: None,
doc: "Optional pre-computed query vector (List<Float>); omit to auto-embed.".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("k"),
ty: ArgType::Primitive(DataType::Int64),
default: None,
doc: "Number of fused results to return.".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("filter"),
ty: ArgType::Primitive(DataType::Utf8),
default: None,
doc: "Optional pushdown filter expression.".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("options"),
ty: ArgType::CypherValue,
default: None,
doc: "Optional options map (fusion method, alpha, rrf_k, reranker, …).".to_owned(),
},
],
yields: hybrid_search_yields(),
mode: ProcedureMode::Read,
side_effects: SideEffects::ReadOnly,
retry_contract: None,
batch_input: None,
docs: "Hybrid vector + FTS search with RRF (or weighted) fusion and optional rerank."
.to_owned(),
})
}
/// Plugin handler for the `uni.search` hybrid (vector + full-text) search procedure.
///
/// This is a unit struct with no fields. Everything a call needs comes from the shared
/// `signature()` and from the per-call context and arguments passed to `invoke`.
#[derive(Debug)]
struct HybridSearchProc;
impl ProcedurePlugin for HybridSearchProc {
    /// Returns the shared, lazily-initialised `uni.search` signature.
    fn signature(&self) -> &ProcedureSignature {
        signature()
    }

    /// Executes one `uni.search` call and returns its result stream.
    ///
    /// The generic plumbing is delegated to `run_search_procedure`: it receives the
    /// procedure name, the context, the raw arguments and the signature. The closure supplied
    /// here does the search-specific work. It clones the host's target properties and forwards
    /// everything to `run_hybrid_search`.
    ///
    /// # Errors
    /// Returns any `FnError` produced by `run_search_procedure`. NOTE(review): presumably
    /// this covers argument binding and search failures — confirm against that helper.
    fn invoke(
        &self,
        ctx: ProcedureContext<'_>,
        args: &[ColumnarValue],
    ) -> Result<SendableRecordBatchStream, FnError> {
        const PROCEDURE_NAME: &str = "uni.search";
        let sig = signature();

        run_search_procedure(
            PROCEDURE_NAME,
            &ctx,
            args,
            sig,
            |search_host, search_args, yields, schema| async move {
                // An owned copy is taken before the search starts; NOTE(review): the reason
                // for cloning rather than borrowing is not visible here.
                let props = search_host.target_properties().clone();
                run_hybrid_search(&search_host, &search_args, &yields, &props, &schema).await
            },
        )
    }
}
/// Registers the `uni.search` hybrid search procedure with the given registrar.
///
/// The procedure is registered under the qualified name `uni.search`. It receives a copy of
/// the shared signature and a reference-counted `HybridSearchProc` handler.
///
/// # Errors
/// Propagates, via `?`, any error returned by `PluginRegistrar::procedure`.
pub fn register_into(r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
    let qualified_name = QName::new("uni", "search");
    let declared_signature = signature().clone();
    let handler = Arc::new(HybridSearchProc);

    r.procedure(qualified_name, declared_signature, handler)?;
    Ok(())
}