use crate::query::df_graph::GraphExecutionContext;
use crate::query::similar_to::{
FusionMethod, SimilarToOptions, fuse_scores, normalize_bm25, parse_options, score_vectors,
validate_options, value_to_f32_vec,
};
use crate::types::QueryWarning;
use arrow_array::builder::Float64Builder;
use arrow_array::{Array, UInt64Array};
use arrow_schema::{DataType, Schema};
use datafusion::physical_plan::{ColumnarValue, DisplayAs, DisplayFormatType, PhysicalExpr};
use std::collections::HashMap;
use std::sync::Arc;
use uni_common::Value;
use uni_common::core::id::Vid;
use uni_common::core::schema::DistanceMetric;
/// DataFusion physical expression that computes a fused `similar_to` score
/// for every row of a record batch.
///
/// Source and query children are paired by position; each pair is scored on
/// its own and the per-pair scores are then combined with `fuse_scores`.
pub(crate) struct SimilarToExecExpr {
/// Expressions producing the source values, one per scored source.
source_children: Vec<Arc<dyn PhysicalExpr>>,
/// Expressions producing the query values, paired by index with `source_children`.
query_children: Vec<Arc<dyn PhysicalExpr>>,
/// Optional expression producing the options value handed to `parse_options`.
options_child: Option<Arc<dyn PhysicalExpr>>,
/// Execution context giving access to storage/schema, the embedding runtime
/// and the query-warning sink.
graph_ctx: Arc<GraphExecutionContext>,
/// Name of the bound variable; used to locate the `<var>._vid` and
/// `<var>._labels` columns in the batch.
source_variable: Option<String>,
/// Per-source property name (when known), used to resolve vector / full-text indexes.
source_property_names: Vec<Option<String>>,
/// Per-source distance metric; a missing entry falls back to cosine.
source_metrics: Vec<Option<DistanceMetric>>,
}
impl SimilarToExecExpr {
pub(crate) fn new(
source_children: Vec<Arc<dyn PhysicalExpr>>,
query_children: Vec<Arc<dyn PhysicalExpr>>,
options_child: Option<Arc<dyn PhysicalExpr>>,
graph_ctx: Arc<GraphExecutionContext>,
source_variable: Option<String>,
source_property_names: Vec<Option<String>>,
source_metrics: Vec<Option<DistanceMetric>>,
) -> Self {
Self {
source_children,
query_children,
options_child,
graph_ctx,
source_variable,
source_property_names,
source_metrics,
}
}
}
/// Compact debug output: only the source count and the bound variable are
/// shown; the remaining fields are elided.
impl std::fmt::Debug for SimilarToExecExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let num_sources = self.source_children.len();
        let mut out = f.debug_struct("SimilarToExecExpr");
        out.field("num_sources", &num_sources);
        out.field("source_variable", &self.source_variable);
        out.finish_non_exhaustive()
    }
}
/// Human-readable form used in plan output: `similar_to(<N sources>)`.
impl std::fmt::Display for SimilarToExecExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let source_count = self.source_children.len();
        write!(f, "similar_to(<{} sources>)", source_count)
    }
}
/// Comparison against an arbitrary physical expression.
///
/// This expression never reports itself equal to another expression through
/// this impl, whatever `rhs` is.
impl PartialEq<dyn PhysicalExpr> for SimilarToExecExpr {
    fn eq(&self, _rhs: &dyn PhysicalExpr) -> bool {
        // No structural comparison is attempted.
        false
    }
}
/// Identity-based equality.
///
/// The expression holds trait-object children and an execution context, so no
/// structural comparison is attempted: two values are equal only when they
/// are the very same instance. Comparing by address (rather than always
/// answering `false`) keeps equality reflexive, which the `Eq` marker below
/// requires; distinct instances still never compare equal.
impl PartialEq for SimilarToExecExpr {
    fn eq(&self, other: &Self) -> bool {
        std::ptr::eq(self, other)
    }
}

/// Sound because the `PartialEq` impl above is reflexive, symmetric and
/// transitive (pointer identity).
impl Eq for SimilarToExecExpr {}
/// Every instance feeds the same constant tag into the hasher, so all
/// instances share one hash value; no field contributes to the hash.
impl std::hash::Hash for SimilarToExecExpr {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        const TYPE_TAG: &str = "SimilarToExecExpr";
        std::hash::Hash::hash(TYPE_TAG, state);
    }
}
/// Plan rendering delegates to the `Display` impl for every format type.
impl DisplayAs for SimilarToExecExpr {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}
/// Extracts the value at `row` from an evaluated expression result.
///
/// * For an array result the slot at `row` is converted.
/// * For a scalar result `row` is ignored: the scalar is materialised as a
///   one-element array and its single slot is converted. If materialisation
///   fails the result is `Value::Null`.
///
/// `_batch` is accepted for call-site symmetry but is not read.
fn columnar_value_to_value(
    cv: &ColumnarValue,
    _batch: &arrow_array::RecordBatch,
    row: usize,
) -> Value {
    match cv {
        ColumnarValue::Array(column) => arrow_to_value_at(column.as_ref(), row),
        ColumnarValue::Scalar(scalar) => match scalar.to_array_of_size(1) {
            Ok(single) => arrow_to_value_at(single.as_ref(), 0),
            Err(_) => Value::Null,
        },
    }
}
/// Converts the slot at `row` of an Arrow array into a `Value`.
///
/// Null slots become `Value::Null`. The data types handled directly are
/// `LargeBinary` (decoded with the cypher value codec; empty or undecodable
/// payloads become `Value::Null`), floats, strings, 64-bit integers,
/// fixed-size lists of `f32`/`f64` (returned as `Value::Vector`) and
/// `List`/`LargeList` (converted element by element). Every other type is
/// delegated to `uni_store`'s generic Arrow conversion.
///
/// # Panics
///
/// Panics if the array's concrete type does not match its reported
/// `DataType`.
fn arrow_to_value_at(col: &dyn Array, row: usize) -> Value {
    use arrow_array::*;

    /// Downcasts `col` to the concrete array type implied by its `DataType`.
    fn typed<T: 'static>(col: &dyn Array) -> &T {
        col.as_any().downcast_ref::<T>().unwrap()
    }

    /// Converts every slot of a list element's child array into a `Value::List`.
    fn list_value(items: ArrayRef) -> Value {
        Value::List(
            (0..items.len())
                .map(|i| arrow_to_value_at(items.as_ref(), i))
                .collect(),
        )
    }

    if col.is_null(row) {
        return Value::Null;
    }

    match col.data_type() {
        DataType::LargeBinary => match typed::<LargeBinaryArray>(col).value(row) {
            [] => Value::Null,
            encoded => uni_common::cypher_value_codec::decode(encoded).unwrap_or(Value::Null),
        },
        DataType::Float64 => Value::Float(typed::<Float64Array>(col).value(row)),
        DataType::Float32 => Value::Float(typed::<Float32Array>(col).value(row) as f64),
        DataType::Utf8 => Value::String(typed::<StringArray>(col).value(row).to_string()),
        DataType::LargeUtf8 => {
            Value::String(typed::<LargeStringArray>(col).value(row).to_string())
        }
        DataType::Int64 => Value::Int(typed::<Int64Array>(col).value(row)),
        // Plain `as` cast: values above `i64::MAX` wrap around.
        DataType::UInt64 => Value::Int(typed::<UInt64Array>(col).value(row) as i64),
        DataType::FixedSizeList(_, _) => {
            let element = typed::<FixedSizeListArray>(col).value(row);
            let element_any = element.as_any();
            // Child validity is not consulted: raw slot values are copied as-is.
            if let Some(floats) = element_any.downcast_ref::<Float32Array>() {
                Value::Vector(floats.values().iter().copied().collect())
            } else if let Some(doubles) = element_any.downcast_ref::<Float64Array>() {
                Value::Vector(doubles.values().iter().map(|&x| x as f32).collect())
            } else {
                Value::Null
            }
        }
        DataType::LargeList(_) => list_value(typed::<LargeListArray>(col).value(row)),
        DataType::List(_) => list_value(typed::<ListArray>(col).value(row)),
        _ => uni_store::storage::arrow_convert::arrow_to_value(col, row, None),
    }
}
/// Per-batch results of the asynchronous work done before row scoring.
///
/// Both vectors are indexed by source position and have one slot per source.
struct PrecomputedResources {
/// Embedding of the query text; `Some` only for sources scored in
/// `ScoringMode::AutoEmbed`.
embed_vectors: Vec<Option<Vec<f32>>>,
/// Normalized full-text scores keyed by vertex id; `Some` only for sources
/// scored in `ScoringMode::Fts`.
fts_results: Vec<Option<HashMap<Vid, f32>>>,
}
/// How one source/query pair is scored, as chosen by `determine_scoring_mode`.
enum ScoringMode {
/// Source and query are both vector-like; scored per row with the given metric.
Vector(DistanceMetric),
/// Source is vector-like and the query is text; the text is embedded before
/// row scoring and compared with the given metric.
AutoEmbed(DistanceMetric),
/// Source and query are both text; the score comes from a full-text search
/// result looked up by vertex id.
Fts,
}
impl PhysicalExpr for SimilarToExecExpr {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn data_type(&self, _input_schema: &Schema) -> datafusion::error::Result<DataType> {
Ok(DataType::Float64)
}
fn nullable(&self, _input_schema: &Schema) -> datafusion::error::Result<bool> {
Ok(true)
}
fn evaluate(
&self,
batch: &arrow_array::RecordBatch,
) -> datafusion::error::Result<ColumnarValue> {
let num_rows = batch.num_rows();
let num_sources = self.source_children.len();
let source_cvs: Vec<_> = self
.source_children
.iter()
.map(|c| c.evaluate(batch))
.collect::<datafusion::error::Result<Vec<_>>>()?;
let query_cvs: Vec<_> = self
.query_children
.iter()
.map(|c| c.evaluate(batch))
.collect::<datafusion::error::Result<Vec<_>>>()?;
let opts = if let Some(ref opts_child) = self.options_child {
let opts_cv = opts_child.evaluate(batch)?;
let opts_val = columnar_value_to_value(&opts_cv, batch, 0);
parse_options(&opts_val).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!("similar_to options: {}", e))
})?
} else {
SimilarToOptions::default()
};
validate_options(&opts, num_sources).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!("similar_to: {}", e))
})?;
if num_rows == 0 {
let mut builder = Float64Builder::with_capacity(0);
return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
}
let first_row_source_vals: Vec<Value> = source_cvs
.iter()
.map(|cv| columnar_value_to_value(cv, batch, 0))
.collect();
let first_row_query_vals: Vec<Value> = query_cvs
.iter()
.map(|cv| columnar_value_to_value(cv, batch, 0))
.collect();
let scoring_modes: Vec<ScoringMode> = first_row_source_vals
.iter()
.zip(first_row_query_vals.iter())
.enumerate()
.map(|(i, (s, q))| {
determine_scoring_mode(s, q, self.source_metrics.get(i).and_then(|m| m.as_ref()))
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| {
datafusion::error::DataFusionError::Execution(format!("similar_to: {}", e))
})?;
let vid_col_idx = self.source_variable.as_ref().and_then(|var| {
let vid_col_name = format!("{}._vid", var);
batch.schema().index_of(&vid_col_name).ok()
});
let label = self.source_variable.as_ref().and_then(|var| {
let labels_col_name = format!("{}._labels", var);
if let Ok(idx) = batch.schema().index_of(&labels_col_name) {
let col = batch.column(idx);
let val = arrow_to_value_at(col.as_ref(), 0);
match val {
Value::String(s) => Some(s),
Value::List(list) => list.first().and_then(|v| v.as_str()).map(String::from),
_ => None,
}
} else {
None
}
});
let graph_ctx = self.graph_ctx.clone();
let source_property_names = self.source_property_names.clone();
let query_strings: Vec<Option<String>> = first_row_query_vals
.iter()
.map(|v| match v {
Value::String(s) => Some(s.clone()),
_ => None,
})
.collect();
let pre_label = label.clone();
let opts_fts_k = opts.fts_k;
let precomputed = std::thread::scope(|s| {
s.spawn(|| {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"Failed to create runtime for similar_to: {}",
e
))
})?;
let mut embed_vectors = vec![None; num_sources];
let mut fts_results = vec![None; num_sources];
for (i, mode) in scoring_modes.iter().enumerate() {
match mode {
ScoringMode::AutoEmbed(_) => {
let query_text = query_strings[i].as_deref().unwrap_or("");
let vec = rt.block_on(auto_embed_query(
&graph_ctx,
pre_label.as_deref(),
source_property_names.get(i).and_then(|p| p.as_deref()),
query_text,
))?;
embed_vectors[i] = Some(vec);
}
ScoringMode::Fts => {
let query_text = query_strings[i].as_deref().unwrap_or("");
let (lbl, prop) = resolve_fts_label_property(
&graph_ctx,
pre_label.as_deref(),
source_property_names.get(i).and_then(|p| p.as_deref()),
)?;
let results = rt.block_on(fts_search_batch(
&graph_ctx, &lbl, &prop, query_text, opts_fts_k,
))?;
fts_results[i] = Some(results);
}
ScoringMode::Vector(_) => {} }
}
Ok::<_, datafusion::error::DataFusionError>(PrecomputedResources {
embed_vectors,
fts_results,
})
})
.join()
.unwrap_or_else(|_| {
Err(datafusion::error::DataFusionError::Execution(
"similar_to precomputation thread panicked".to_string(),
))
})
})?;
let mut builder = Float64Builder::with_capacity(num_rows);
for row_idx in 0..num_rows {
let mut scores = Vec::with_capacity(num_sources);
for (src_idx, mode) in scoring_modes.iter().enumerate() {
let score = match mode {
ScoringMode::Vector(metric) => {
let sv = columnar_value_to_value(&source_cvs[src_idx], batch, row_idx);
let qv = columnar_value_to_value(&query_cvs[src_idx], batch, row_idx);
score_vectors_from_values(&sv, &qv, metric).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"similar_to vector: {}",
e
))
})?
}
ScoringMode::AutoEmbed(metric) => {
let sv = columnar_value_to_value(&source_cvs[src_idx], batch, row_idx);
let embed_vec = precomputed.embed_vectors[src_idx]
.as_ref()
.expect("auto-embed should have been precomputed");
score_vectors_precomputed(&sv, embed_vec, metric).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"similar_to auto-embed: {}",
e
))
})?
}
ScoringMode::Fts => {
let fts_map = precomputed.fts_results[src_idx]
.as_ref()
.expect("FTS should have been precomputed");
let vid = vid_col_idx.and_then(|idx| {
let col = batch.column(idx);
col.as_any()
.downcast_ref::<UInt64Array>()
.map(|u| Vid::from(u.value(row_idx)))
});
match vid {
Some(v) => fts_map.get(&v).copied().unwrap_or(0.0),
None => 0.0,
}
}
};
scores.push(score);
}
let fused = fuse_scores(&scores, &opts).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!("similar_to fusion: {}", e))
})?;
builder.append_value(fused as f64);
}
if opts.method == FusionMethod::Rrf && num_sources > 1 {
self.graph_ctx.push_warning(QueryWarning::RrfPointContext);
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
self.source_children
.iter()
.chain(&self.query_children)
.chain(self.options_child.iter())
.collect()
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> datafusion::error::Result<Arc<dyn PhysicalExpr>> {
let ns = self.source_children.len();
let nq = self.query_children.len();
let has_opts = self.options_child.is_some();
let expected = ns + nq + if has_opts { 1 } else { 0 };
if children.len() != expected {
return Err(datafusion::error::DataFusionError::Plan(format!(
"SimilarToExecExpr expected {} children, got {}",
expected,
children.len()
)));
}
let source_children = children[..ns].to_vec();
let query_children = children[ns..ns + nq].to_vec();
let options_child = if has_opts {
Some(children[ns + nq].clone())
} else {
None
};
Ok(Arc::new(SimilarToExecExpr::new(
source_children,
query_children,
options_child,
self.graph_ctx.clone(),
self.source_variable.clone(),
self.source_property_names.clone(),
self.source_metrics.clone(),
)))
}
fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self)
}
}
/// Chooses how a source/query pair is scored from the kinds of the two values.
///
/// | source        | query         | result                         |
/// |---------------|---------------|--------------------------------|
/// | vector / list | vector / list | `ScoringMode::Vector(metric)`  |
/// | vector / list | string        | `ScoringMode::AutoEmbed(metric)` |
/// | string        | string        | `ScoringMode::Fts`             |
///
/// `metric` defaults to cosine when `None`.
///
/// # Errors
///
/// Returns a message for a string source paired with a vector query, and for
/// every other combination.
fn determine_scoring_mode(
    source: &Value,
    query: &Value,
    metric: Option<&DistanceMetric>,
) -> Result<ScoringMode, String> {
    let resolved_metric = metric.cloned().unwrap_or(DistanceMetric::Cosine);
    let is_vector_like = |v: &Value| matches!(v, Value::Vector(_) | Value::List(_));
    let is_text = |v: &Value| matches!(v, Value::String(_));

    if is_vector_like(source) {
        if is_vector_like(query) {
            return Ok(ScoringMode::Vector(resolved_metric));
        }
        if is_text(query) {
            return Ok(ScoringMode::AutoEmbed(resolved_metric));
        }
    } else if is_text(source) {
        if is_text(query) {
            return Ok(ScoringMode::Fts);
        }
        if is_vector_like(query) {
            return Err("FTS source cannot be scored against a vector query".to_string());
        }
    }

    Err(format!(
        "unsupported source/query type combination: {:?} vs {:?}",
        std::mem::discriminant(source),
        std::mem::discriminant(query)
    ))
}
/// Scores two vector-like values against each other with `metric`.
///
/// The source is converted first, then the query; the first conversion or
/// scoring failure is returned as its string message.
fn score_vectors_from_values(
    source: &Value,
    query: &Value,
    metric: &DistanceMetric,
) -> Result<f32, String> {
    let source_vec = match value_to_f32_vec(source) {
        Ok(v) => v,
        Err(e) => return Err(e.to_string()),
    };
    let query_vec = match value_to_f32_vec(query) {
        Ok(v) => v,
        Err(e) => return Err(e.to_string()),
    };
    match score_vectors(&source_vec, &query_vec, metric) {
        Ok(score) => Ok(score),
        Err(e) => Err(e.to_string()),
    }
}
/// Scores a vector-like source value against an already-materialised query
/// vector with `metric`.
///
/// Conversion or scoring failures are returned as their string message.
fn score_vectors_precomputed(
    source: &Value,
    query_vec: &[f32],
    metric: &DistanceMetric,
) -> Result<f32, String> {
    let source_vec = match value_to_f32_vec(source) {
        Ok(v) => v,
        Err(e) => return Err(e.to_string()),
    };
    match score_vectors(&source_vec, query_vec, metric) {
        Ok(score) => Ok(score),
        Err(e) => Err(e.to_string()),
    }
}
async fn auto_embed_query(
graph_ctx: &GraphExecutionContext,
label: Option<&str>,
property: Option<&str>,
query_text: &str,
) -> datafusion::error::Result<Vec<f32>> {
let storage = graph_ctx.storage();
let schema = storage.schema_manager().schema();
let embedding_alias = if let (Some(lbl), Some(prop)) = (label, property) {
schema
.vector_index_for_property(lbl, prop)
.and_then(|cfg| cfg.embedding_config.as_ref().map(|ec| ec.alias.clone()))
} else {
None
};
let embedding_alias = embedding_alias.or_else(|| {
schema.indexes.iter().find_map(|idx| {
if let uni_common::core::schema::IndexDefinition::Vector(config) = idx {
config.embedding_config.as_ref().map(|ec| ec.alias.clone())
} else {
None
}
})
});
let alias = embedding_alias.ok_or_else(|| {
datafusion::error::DataFusionError::Execution(
"similar_to: no vector index with embedding config found. \
Cannot auto-embed text query."
.to_string(),
)
})?;
let runtime = graph_ctx.xervo_runtime().ok_or_else(|| {
datafusion::error::DataFusionError::Execution(
"similar_to: cannot auto-embed text — Uni-Xervo runtime not configured. \
Provide a pre-computed vector instead."
.to_string(),
)
})?;
let embedder = runtime.embedding(&alias).await.map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"similar_to: failed to get embedder: {}",
e
))
})?;
let embeddings = embedder.embed(vec![query_text]).await.map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"similar_to: embedding failed: {}",
e
))
})?;
embeddings.into_iter().next().ok_or_else(|| {
datafusion::error::DataFusionError::Execution(
"similar_to: embedding service returned no results".to_string(),
)
})
}
/// Runs one full-text search for `query_text` on `label.property` and returns
/// the hits as a map from vertex id to BM25 score normalized with `fts_k`.
///
/// The search is issued with a limit argument of 1000.
///
/// # Errors
///
/// Returns `DataFusionError::Execution` when the storage search fails.
async fn fts_search_batch(
    graph_ctx: &GraphExecutionContext,
    label: &str,
    property: &str,
    query_text: &str,
    fts_k: f32,
) -> datafusion::error::Result<HashMap<Vid, f32>> {
    let hits = graph_ctx
        .storage()
        .fts_search(label, property, query_text, 1000, None, None)
        .await
        .map_err(|e| {
            datafusion::error::DataFusionError::Execution(format!(
                "similar_to: FTS search failed: {}",
                e
            ))
        })?;

    let mut normalized = HashMap::new();
    for (vid, raw_score) in hits {
        normalized.insert(vid, normalize_bm25(raw_score, fts_k));
    }
    Ok(normalized)
}
fn resolve_fts_label_property(
graph_ctx: &GraphExecutionContext,
label: Option<&str>,
property: Option<&str>,
) -> datafusion::error::Result<(String, String)> {
let lbl = label.unwrap_or("");
let schema = graph_ctx.storage().schema_manager().schema();
if let (Some(l), Some(p)) = (label, property) {
let has_fts = schema.indexes.iter().any(|idx| {
matches!(idx, uni_common::core::schema::IndexDefinition::FullText(config)
if config.label == l && config.properties.contains(&p.to_string()))
});
if has_fts {
return Ok((l.to_string(), p.to_string()));
}
return Err(datafusion::error::DataFusionError::Execution(format!(
"similar_to: no vector or full-text index found for property '{}.{}'. \
Cannot compute text similarity without an appropriate index.",
l, p
)));
}
find_fts_property_from_ctx(graph_ctx, lbl)
.map(|prop| (lbl.to_string(), prop))
.ok_or_else(|| {
datafusion::error::DataFusionError::Execution(format!(
"similar_to: no full-text index found for label '{}'. \
Cannot compute text similarity without an FTS index.",
lbl
))
})
}
/// Returns the first property of the first full-text index defined on
/// `label`, skipping full-text indexes that list no properties.
///
/// Returns `None` when the schema has no such index.
fn find_fts_property_from_ctx(graph_ctx: &GraphExecutionContext, label: &str) -> Option<String> {
    use uni_common::core::schema::IndexDefinition;

    let schema = graph_ctx.storage().schema_manager().schema();
    schema.indexes.iter().find_map(|idx| match idx {
        IndexDefinition::FullText(config) if config.label == label => {
            config.properties.first().cloned()
        }
        _ => None,
    })
}