use super::*;
use super::projection::{apply_filter, apply_ordering, project_return};
/// Lazily resolves the [`EmbeddingClient`] used to embed query text.
///
/// Bundles the once-initialized client cell with an optional explicit
/// configuration; `resolve` combines the two.
pub(crate) struct EmbeddingResolver<'a> {
/// Cell that holds the client once it has been constructed.
cell: &'a tokio::sync::OnceCell<EmbeddingClient>,
/// Explicit configuration. When `None`, `resolve` builds the client with
/// `EmbeddingClient::from_env()` instead.
config: Option<&'a crate::embedding::EmbeddingConfig>,
}
impl EmbeddingResolver<'_> {
    /// Returns the shared embedding client, constructing it on first use.
    ///
    /// When the cell is empty the client is built from the explicit
    /// configuration if one was supplied, otherwise from
    /// `EmbeddingClient::from_env()`. Later calls return the cached client.
    ///
    /// # Errors
    /// Propagates the construction error; a failed initialization leaves the
    /// cell empty, so a later call retries.
    async fn resolve(&self) -> Result<&EmbeddingClient> {
        // `Option<&T>` is `Copy`, so the reference is moved into the init
        // future and the configuration is only cloned when the client really
        // has to be built, not on every query that hits the warm cell.
        let config = self.config;
        self.cell
            .get_or_try_init(|| async move {
                match config {
                    Some(cfg) => EmbeddingClient::new(cfg.clone()),
                    None => EmbeddingClient::from_env(),
                }
            })
            .await
    }
}
impl Omnigraph {
pub async fn query(
&self,
target: impl Into<ReadTarget>,
query_source: &str,
query_name: &str,
params: &ParamMap,
) -> Result<QueryResult> {
self.ensure_schema_state_valid().await?;
let resolved = self.resolved_target(target).await?;
let catalog = self.catalog();
let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
.map_err(|e| OmniError::manifest(e.to_string()))?;
let type_ctx = typecheck_query(&catalog, &query_decl)?;
let ir = lower_query(&catalog, &query_decl, &type_ctx)?;
let needs_graph = ir
.pipeline
.iter()
.any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
let graph_index = if needs_graph {
GraphIndexHandle::cached(self, &resolved)
} else {
GraphIndexHandle::none()
};
execute_query(
&ir,
params,
&resolved.snapshot,
&graph_index,
&catalog,
&EmbeddingResolver {
cell: self.embedding_cell(),
config: self.embedding_config_ref(),
},
)
.await
}
pub async fn run_query_at(
&self,
version: u64,
query_source: &str,
query_name: &str,
params: &ParamMap,
) -> Result<QueryResult> {
self.ensure_schema_state_valid().await?;
let snapshot = self.snapshot_at_version(version).await?;
let catalog = self.catalog();
let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
.map_err(|e| OmniError::manifest(e.to_string()))?;
let type_ctx = typecheck_query(&catalog, &query_decl)?;
let ir = lower_query(&catalog, &query_decl, &type_ctx)?;
let needs_graph = ir
.pipeline
.iter()
.any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
let graph_index = if needs_graph {
let edge_types = catalog
.edge_types
.iter()
.map(|(name, et)| (name.clone(), (et.from_type.clone(), et.to_type.clone())))
.collect();
GraphIndexHandle::direct(&snapshot, edge_types)
} else {
GraphIndexHandle::none()
};
execute_query(
&ir,
params,
&snapshot,
&graph_index,
&catalog,
&EmbeddingResolver {
cell: self.embedding_cell(),
config: self.embedding_config_ref(),
},
)
.await
}
}
/// Search strategy derived from a query's leading ordering expression.
///
/// The default value (all fields `None`) means no search-driven ordering.
#[derive(Debug, Default)]
struct SearchMode {
/// Vector search: `(variable, property, query vector, k)`.
nearest: Option<(String, String, Vec<f32>, usize)>,
/// BM25 text search: `(variable, property, query text)`.
bm25: Option<(String, String, String)>,
/// Reciprocal-rank fusion of two sub-searches.
rrf: Option<RrfMode>,
}
/// Parameters of a reciprocal-rank-fusion (`rrf`) ordering.
#[derive(Debug)]
struct RrfMode {
/// First ranked sub-search; also supplies the variable whose id column is fused on.
primary: Box<SearchMode>,
/// Second ranked sub-search.
secondary: Box<SearchMode>,
/// Smoothing constant: a zero-based rank `r` contributes `1 / (k + r + 1)`.
k: u32,
/// Maximum number of fused ids kept after scoring.
limit: usize,
}
/// Derives the [`SearchMode`] from the first `order by` expression of `ir`.
///
/// * `nearest(...)` resolves the query vector (embedding text if needed) and
///   requires a `limit`, which becomes the search's `k`.
/// * `bm25(...)` requires the field to be a property access and the query to
///   resolve to a string.
/// * `rrf(...)` requires a `limit`, resolves the optional smoothing constant
///   (default `60`) and derives both sub-searches.
/// * Anything else, or no ordering at all, yields the default (no search).
///
/// # Errors
/// Returns a manifest error when a required `limit` is missing, a bm25
/// argument has the wrong shape, or the rrf constant does not fit in a `u32`;
/// propagates failures from resolving the nearest query vector.
async fn extract_search_mode(
    ir: &QueryIR,
    params: &ParamMap,
    catalog: &Catalog,
    embedding: &EmbeddingResolver<'_>,
) -> Result<SearchMode> {
    let Some(ordering) = ir.order_by.first() else {
        return Ok(SearchMode::default());
    };
    match &ordering.expr {
        IRExpr::Nearest {
            variable,
            property,
            query,
        } => {
            let vec =
                resolve_nearest_query_vec(ir, catalog, variable, property, query, params, embedding)
                    .await?;
            let k = ir.limit.ok_or_else(|| {
                OmniError::manifest("nearest() ordering requires a limit clause".to_string())
            })? as usize;
            Ok(SearchMode {
                nearest: Some((variable.clone(), property.clone(), vec, k)),
                ..Default::default()
            })
        }
        IRExpr::Bm25 { field, query } => {
            let var = match field.as_ref() {
                IRExpr::PropAccess { variable, .. } => variable.clone(),
                _ => {
                    return Err(OmniError::manifest(
                        "bm25 field must be a property access".to_string(),
                    ));
                }
            };
            let prop = extract_property(field).ok_or_else(|| {
                OmniError::manifest("bm25 field must be a property access".to_string())
            })?;
            let text = resolve_to_string(query, params).ok_or_else(|| {
                OmniError::manifest("bm25 query must resolve to a string".to_string())
            })?;
            Ok(SearchMode {
                bm25: Some((var, prop, text)),
                ..Default::default()
            })
        }
        IRExpr::Rrf {
            primary,
            secondary,
            k,
        } => {
            let limit = ir.limit.ok_or_else(|| {
                OmniError::manifest("rrf() ordering requires a limit clause".to_string())
            })? as usize;
            // A plain `as u32` would silently wrap a negative or oversized
            // constant into an unrelated smoothing value; reject it instead.
            let k_val: u32 = match k.as_ref().and_then(|e| resolve_to_int(e, params)) {
                Some(raw) => u32::try_from(raw).map_err(|_| {
                    OmniError::manifest(format!(
                        "rrf() k must be a non-negative integer that fits in 32 bits, got {}",
                        raw
                    ))
                })?,
                None => 60,
            };
            let primary_mode =
                extract_sub_search_mode(ir, primary, params, catalog, ir.limit, embedding).await?;
            let secondary_mode =
                extract_sub_search_mode(ir, secondary, params, catalog, ir.limit, embedding)
                    .await?;
            Ok(SearchMode {
                rrf: Some(RrfMode {
                    primary: Box::new(primary_mode),
                    secondary: Box::new(secondary_mode),
                    k: k_val,
                    limit,
                }),
                ..Default::default()
            })
        }
        _ => Ok(SearchMode::default()),
    }
}
/// Derives the [`SearchMode`] for one operand of an `rrf(...)` ordering.
///
/// Handles `nearest(...)` (with `k` taken from `limit`, or `100` when absent)
/// and `bm25(...)`; any other expression yields the default (no search).
///
/// # Errors
/// Returns a manifest error when a bm25 argument has the wrong shape and
/// propagates failures from resolving the nearest query vector.
async fn extract_sub_search_mode(
    ir: &QueryIR,
    expr: &IRExpr,
    params: &ParamMap,
    catalog: &Catalog,
    limit: Option<u64>,
    embedding: &EmbeddingResolver<'_>,
) -> Result<SearchMode> {
    let mut mode = SearchMode::default();
    match expr {
        IRExpr::Nearest {
            variable,
            property,
            query,
        } => {
            let query_vec =
                resolve_nearest_query_vec(ir, catalog, variable, property, query, params, embedding)
                    .await?;
            let k = limit.unwrap_or(100) as usize;
            mode.nearest = Some((variable.clone(), property.clone(), query_vec, k));
        }
        IRExpr::Bm25 { field, query } => {
            let IRExpr::PropAccess { variable: var, .. } = field.as_ref() else {
                return Err(OmniError::manifest(
                    "bm25 field must be a property access".to_string(),
                ));
            };
            let Some(prop) = extract_property(field) else {
                return Err(OmniError::manifest(
                    "bm25 field must be a property access".to_string(),
                ));
            };
            let Some(text) = resolve_to_string(query, params) else {
                return Err(OmniError::manifest(
                    "bm25 query must resolve to a string".to_string(),
                ));
            };
            mode.bm25 = Some((var.clone(), prop, text));
        }
        _ => {}
    }
    Ok(mode)
}
/// Resolves the query argument of `nearest()` into a float vector.
///
/// A list literal/parameter is converted element-wise. A string is embedded
/// with the resolved embedding client, using the vector dimension declared
/// for `property`; if the catalog records the model that produced the stored
/// vectors, the client's model must match it.
///
/// # Errors
/// Returns a manifest error for unsupported argument kinds or a model
/// mismatch, and propagates catalog-lookup, client-resolution and embedding
/// failures.
async fn resolve_nearest_query_vec(
    ir: &QueryIR,
    catalog: &Catalog,
    variable: &str,
    property: &str,
    expr: &IRExpr,
    params: &ParamMap,
    embedding: &EmbeddingResolver<'_>,
) -> Result<Vec<f32>> {
    let text = match resolve_literal_or_param(expr, params)? {
        list @ Literal::List(_) => return literal_to_f32_vec(&list),
        Literal::String(text) => text,
        _ => {
            return Err(OmniError::manifest(
                "nearest query must be a string or list of floats".to_string(),
            ));
        }
    };

    let (expected_dim, recorded_model) =
        nearest_property_dim_and_model(ir, catalog, variable, property)?;
    let client = embedding.resolve().await?;

    // Comparing vectors from different embedding models is meaningless, so a
    // recorded model must agree with the one the client resolved to.
    if let Some(recorded) = recorded_model.as_ref() {
        let resolved = &client.config().model;
        if resolved != recorded {
            return Err(OmniError::manifest(format!(
                "nearest() on '{property}': its stored vectors were embedded with model \
                 '{recorded}', but the query embedder resolves to '{resolved}'. Set \
                 OMNIGRAPH_EMBED_MODEL='{recorded}' (and the matching provider) or re-embed \
                 the stored vectors."
            )));
        }
    }
    client.embed_query_text(&text, expected_dim).await
}
/// Resolves `expr` to an owned literal.
///
/// A literal expression is cloned; a parameter reference is looked up in
/// `params`.
///
/// # Errors
/// Returns a manifest error when the parameter is missing or `expr` is
/// neither a literal nor a parameter.
fn resolve_literal_or_param(expr: &IRExpr, params: &ParamMap) -> Result<Literal> {
    match expr {
        IRExpr::Literal(lit) => Ok(lit.clone()),
        IRExpr::Param(name) => match params.get(name) {
            Some(value) => Ok(value.clone()),
            None => Err(OmniError::manifest(format!(
                "parameter '{}' not provided",
                name
            ))),
        },
        _ => Err(OmniError::manifest(
            "nearest query must be a literal or parameter".to_string(),
        )),
    }
}
/// Converts a list literal of numbers into a `Vec<f32>`.
///
/// Float and integer elements are cast to `f32`.
///
/// # Errors
/// Returns a manifest error when `lit` is not a list or when the first
/// non-numeric element is encountered.
fn literal_to_f32_vec(lit: &Literal) -> Result<Vec<f32>> {
    let Literal::List(items) = lit else {
        return Err(OmniError::manifest(
            "nearest query must be a list of floats".to_string(),
        ));
    };
    let mut values = Vec::with_capacity(items.len());
    for item in items {
        let value = match item {
            Literal::Float(f) => *f as f32,
            Literal::Integer(n) => *n as f32,
            _ => {
                return Err(OmniError::manifest(
                    "vector elements must be numeric".to_string(),
                ));
            }
        };
        values.push(value);
    }
    Ok(values)
}
/// Looks up the vector dimension of `variable.property` and the embedding
/// model recorded for it, if any.
///
/// The variable is resolved to its node type through the lowered pipeline;
/// the property must be a non-list vector scalar.
///
/// # Errors
/// Returns an internal manifest error when the variable is unbound, the node
/// type or property is unknown, or the property is not a scalar vector.
fn nearest_property_dim_and_model(
    ir: &QueryIR,
    catalog: &Catalog,
    variable: &str,
    property: &str,
) -> Result<(usize, Option<String>)> {
    let Some(type_name) = resolve_binding_type_name(&ir.pipeline, variable) else {
        return Err(OmniError::manifest_internal(format!(
            "nearest() variable '${}' is not bound to a node type in the lowered pipeline",
            variable
        )));
    };
    let Some(node_type) = catalog.node_types.get(type_name) else {
        return Err(OmniError::manifest_internal(format!(
            "nearest() binding '${}' resolved unknown node type '{}'",
            variable, type_name
        )));
    };
    let Some(prop) = node_type.properties.get(property) else {
        return Err(OmniError::manifest_internal(format!(
            "nearest() property '{}.{}' is missing from the catalog",
            type_name, property
        )));
    };

    let dim = match (&prop.scalar, prop.list) {
        (ScalarType::Vector(dim), false) => *dim as usize,
        _ => {
            return Err(OmniError::manifest_internal(format!(
                "nearest() property '{}.{}' is not a scalar vector",
                type_name, property
            )));
        }
    };

    let recorded_model = match node_type.embed_sources.get(property) {
        Some(embed) => embed.model.clone(),
        None => None,
    };
    Ok((dim, recorded_model))
}
/// Finds the node type that `variable` is bound to in `pipeline`.
///
/// Ops are checked in order: a `NodeScan` binding the variable, an `Expand`
/// whose destination is the variable, or, recursively, the inner pipeline
/// of an `AntiJoin`. Returns `None` when nothing binds it.
fn resolve_binding_type_name<'a>(pipeline: &'a [IROp], variable: &str) -> Option<&'a str> {
    pipeline.iter().find_map(|op| match op {
        IROp::NodeScan {
            variable: bound_var,
            type_name,
            ..
        } if bound_var == variable => Some(type_name.as_str()),
        IROp::Expand {
            dst_var, dst_type, ..
        } if dst_var == variable => Some(dst_type.as_str()),
        IROp::AntiJoin { inner, .. } => resolve_binding_type_name(inner, variable),
        _ => None,
    })
}
/// Executes a lowered query and returns its projected result.
///
/// An `rrf` ordering is delegated to `execute_rrf_query`. Otherwise the
/// pipeline is run into a wide batch, the return expressions are projected,
/// a non-search ordering is applied, and the `limit` (if any) truncates the
/// result.
///
/// # Errors
/// Propagates search-mode, pipeline, projection and ordering failures.
pub async fn execute_query(
    ir: &QueryIR,
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: &GraphIndexHandle<'_>,
    catalog: &Catalog,
    embedding: &EmbeddingResolver<'_>,
) -> Result<QueryResult> {
    let search_mode = extract_search_mode(ir, params, catalog, embedding).await?;
    if let Some(rrf) = search_mode.rrf.as_ref() {
        return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
    }

    let mut wide: Option<RecordBatch> = None;
    execute_pipeline(
        &ir.pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut wide,
        &search_mode,
    )
    .await?;
    let wide_batch = match wide {
        Some(batch) => batch,
        None => RecordBatch::new_empty(Arc::new(Schema::empty())),
    };

    let has_aggregates = ir
        .return_exprs
        .iter()
        .any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
    let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;

    // Search-driven orderings are skipped here; every other ordering is
    // applied explicitly.
    let needs_explicit_sort = !(ir.order_by.is_empty() || is_search_ordered(&search_mode));
    if needs_explicit_sort {
        let sorted = if has_aggregates {
            // With aggregates, ordering expressions are evaluated against the
            // projected batch itself rather than the wide batch.
            let projected = result_batch.clone();
            apply_ordering(projected, &ir.order_by, &result_batch, params)?
        } else {
            apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
        };
        result_batch = sorted;
    }

    if let Some(limit) = ir.limit {
        let keep = result_batch.num_rows().min(limit as usize);
        result_batch = result_batch.slice(0, keep);
    }
    Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
/// Reports whether `search_mode` carries a `nearest` or `bm25` search.
fn is_search_ordered(search_mode: &SearchMode) -> bool {
    !(search_mode.nearest.is_none() && search_mode.bm25.is_none())
}
async fn execute_rrf_query(
ir: &QueryIR,
params: &ParamMap,
snapshot: &Snapshot,
graph_index: &GraphIndexHandle<'_>,
catalog: &Catalog,
rrf: &RrfMode,
) -> Result<QueryResult> {
let mut primary_wide: Option<RecordBatch> = None;
execute_pipeline(
&ir.pipeline,
params,
snapshot,
graph_index,
catalog,
&mut primary_wide,
&rrf.primary,
)
.await?;
let mut secondary_wide: Option<RecordBatch> = None;
execute_pipeline(
&ir.pipeline,
params,
snapshot,
graph_index,
catalog,
&mut secondary_wide,
&rrf.secondary,
)
.await?;
let primary_var = rrf
.primary
.nearest
.as_ref()
.map(|(v, ..)| v.as_str())
.or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
.ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;
let primary_batch = primary_wide.as_ref().ok_or_else(|| {
OmniError::manifest(format!(
"rrf primary variable '{}' not in bindings",
primary_var
))
})?;
let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
OmniError::manifest(format!(
"rrf secondary variable '{}' not in bindings",
primary_var
))
})?;
let id_col_name = format!("{}.id", primary_var);
let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;
let mut primary_rank: HashMap<String, usize> = HashMap::new();
for (i, id) in primary_ids.iter().enumerate() {
primary_rank.entry(id.clone()).or_insert(i);
}
let mut secondary_rank: HashMap<String, usize> = HashMap::new();
for (i, id) in secondary_ids.iter().enumerate() {
secondary_rank.entry(id.clone()).or_insert(i);
}
let mut all_ids: Vec<String> = primary_ids.clone();
for id in &secondary_ids {
if !primary_rank.contains_key(id) {
all_ids.push(id.clone());
}
}
let k = rrf.k as f64;
let mut scored: Vec<(String, f64)> = all_ids
.iter()
.map(|id| {
let p = primary_rank
.get(id)
.map(|&r| 1.0 / (k + r as f64 + 1.0))
.unwrap_or(0.0);
let s = secondary_rank
.get(id)
.map(|&r| 1.0 / (k + r as f64 + 1.0))
.unwrap_or(0.0);
(id.clone(), p + s)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(rrf.limit);
let winning_ids: Vec<String> = scored.iter().map(|(id, _)| id.clone()).collect();
let mut id_to_batch_row: HashMap<String, (&RecordBatch, usize)> = HashMap::new();
for (i, id) in primary_ids.iter().enumerate() {
id_to_batch_row
.entry(id.clone())
.or_insert((primary_batch, i));
}
for (i, id) in secondary_ids.iter().enumerate() {
id_to_batch_row
.entry(id.clone())
.or_insert((secondary_batch, i));
}
let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;
let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;
Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
/// Copies the string column `col_name` of `batch` into a `Vec<String>`.
///
/// # Errors
/// Returns a manifest error when the column is absent or is not a
/// `StringArray`.
fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
    let Some(column) = batch.column_by_name(col_name) else {
        return Err(OmniError::manifest(format!(
            "batch missing '{}' column for RRF",
            col_name
        )));
    };
    let Some(ids) = column.as_any().downcast_ref::<StringArray>() else {
        return Err(OmniError::manifest(format!(
            "'{}' column is not Utf8",
            col_name
        )));
    };
    let row_count = ids.len();
    let mut out = Vec::with_capacity(row_count);
    for row in 0..row_count {
        out.push(ids.value(row).to_string());
    }
    Ok(out)
}
/// Assembles one batch containing the rows for `ordered_ids`, in that order.
///
/// Each id is looked up in `id_to_batch_row` and contributes a one-row slice
/// of its source batch; ids without an entry are skipped. When no row is
/// selected an empty batch with `schema` is returned; otherwise the slices
/// are concatenated under the schema of the first slice.
///
/// # Errors
/// Maps a concatenation failure to `OmniError::Lance`.
fn build_fused_batch(
    ordered_ids: &[String],
    id_to_batch_row: &HashMap<String, (&RecordBatch, usize)>,
    schema: SchemaRef,
) -> Result<RecordBatch> {
    let row_slices: Vec<RecordBatch> = ordered_ids
        .iter()
        .filter_map(|id| id_to_batch_row.get(id))
        .map(|&(batch, row_idx)| batch.slice(row_idx, 1))
        .collect();

    let Some(first) = row_slices.first() else {
        return Ok(RecordBatch::new_empty(schema));
    };
    let concat_schema = first.schema();
    match arrow_select::concat::concat_batches(&concat_schema, &row_slices) {
        Ok(fused) => Ok(fused),
        Err(e) => Err(OmniError::Lance(e.to_string())),
    }
}
/// Reports whether the left-hand side of `filter` is a `Search`, `Fuzzy` or
/// `MatchText` expression.
fn is_search_filter(filter: &IRFilter) -> bool {
    match &filter.left {
        IRExpr::Search { .. } | IRExpr::Fuzzy { .. } | IRExpr::MatchText { .. } => true,
        _ => false,
    }
}
/// Returns the variable a search filter targets.
///
/// Yields `Some` only when the filter's left-hand side is a `Search`, `Fuzzy`
/// or `MatchText` expression whose field is a property access.
fn search_filter_variable(filter: &IRFilter) -> Option<&str> {
    let field = match &filter.left {
        IRExpr::Search { field, .. }
        | IRExpr::Fuzzy { field, .. }
        | IRExpr::MatchText { field, .. } => field,
        _ => return None,
    };
    if let IRExpr::PropAccess { variable, .. } = field.as_ref() {
        Some(variable.as_str())
    } else {
        None
    }
}
/// Runs the ops of `pipeline` in order, accumulating bindings in `wide`.
///
/// `wide` starts as whatever the caller passes (callers in this file pass
/// `None`); the first `NodeScan` creates it and later scans are cross-joined
/// onto it. `Filter`, `Expand` and `AntiJoin` ops only act when `wide` already
/// holds a batch.
///
/// Returns a boxed, `Send` future instead of being an `async fn`.
/// NOTE(review): presumably boxed because `execute_anti_join` re-enters this
/// function for the inner pipeline — confirm.
///
/// # Errors
/// Propagates failures from scans, filters, expansion and anti-joins.
fn execute_pipeline<'a>(
pipeline: &'a [IROp],
params: &'a ParamMap,
snapshot: &'a Snapshot,
graph_index: &'a GraphIndexHandle<'a>,
catalog: &'a Catalog,
wide: &'a mut Option<RecordBatch>,
search_mode: &'a SearchMode,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
Box::pin(async move {
// Pass 1: collect search filters (Search / Fuzzy / MatchText on a property
// access), grouped by the variable they target, and remember their
// positions so pass 2 skips them.
let mut hoisted_search_filters: HashMap<String, Vec<IRFilter>> = HashMap::new();
let mut hoisted_indices: HashSet<usize> = HashSet::new();
for (i, op) in pipeline.iter().enumerate() {
if let IROp::Filter(filter) = op {
if is_search_filter(filter) {
if let Some(var) = search_filter_variable(filter) {
hoisted_search_filters
.entry(var.to_string())
.or_default()
.push(filter.clone());
hoisted_indices.insert(i);
}
}
}
}
// Pass 2: execute the remaining ops in order.
// NOTE(review): a hoisted filter is only consumed by a `NodeScan` binding the
// same variable; if no such scan exists in this pipeline the filter is not
// applied anywhere in this function — confirm lowering never produces that.
for (i, op) in pipeline.iter().enumerate() {
if hoisted_indices.contains(&i) {
continue;
}
match op {
IROp::NodeScan {
variable,
type_name,
filters,
} => {
// The scan receives its own filters plus any hoisted search filters
// targeting its variable.
let mut all_filters: Vec<IRFilter> = filters.clone();
if let Some(extra) = hoisted_search_filters.get(variable) {
all_filters.extend(extra.iter().cloned());
}
let batch = execute_node_scan(
type_name,
variable,
&all_filters,
params,
snapshot,
catalog,
search_mode,
)
.await?;
let prefixed = prefix_batch(&batch, variable)?;
// First scan seeds the wide batch; later scans are cross-joined onto it.
*wide = Some(match wide.take() {
None => prefixed,
Some(existing) => cross_join_batches(&existing, &prefixed)?,
});
}
IROp::Filter(filter) => {
if let Some(batch) = wide.as_mut() {
apply_filter(batch, filter, params)?;
}
}
IROp::Expand {
src_var,
dst_var,
edge_type,
direction,
dst_type,
min_hops,
max_hops,
dst_filters,
} => {
if let Some(batch) = wide.as_mut() {
execute_expand(
batch,
graph_index,
snapshot,
catalog,
src_var,
dst_var,
edge_type,
*direction,
dst_type,
*min_hops,
*max_hops,
dst_filters,
params,
)
.await?;
}
}
IROp::AntiJoin { outer_var, inner } => {
let gi = graph_index;
if let Some(batch) = wide.as_mut() {
execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
.await?;
}
}
}
}
Ok(())
})
}
/// Handle to a graph index that is produced lazily, at most once, on the
/// first call to `get`.
pub struct GraphIndexHandle<'a> {
/// Caches the outcome of running `builder`; the inner `None` means this
/// handle provides no index.
cell: tokio::sync::OnceCell<Option<Arc<GraphIndex>>>,
/// How the index is obtained when it is first requested.
builder: GraphIndexBuilder<'a>,
}
/// Strategy used by `GraphIndexHandle::get` to obtain the index.
enum GraphIndexBuilder<'a> {
/// No index is available; `get` yields `None`.
None,
/// Ask the database for the index of a resolved target
/// (`graph_index_for_resolved`).
Cached(&'a Omnigraph, &'a crate::db::ResolvedTarget),
/// Build the index with `GraphIndex::build` from a snapshot and a map of
/// edge type name to its `(from_type, to_type)` endpoints.
Direct(&'a Snapshot, HashMap<String, (String, String)>),
}
impl<'a> GraphIndexHandle<'a> {
    /// Creates a handle that never provides an index.
    fn none() -> Self {
        let builder = GraphIndexBuilder::None;
        Self {
            cell: tokio::sync::OnceCell::new(),
            builder,
        }
    }

    /// Creates a handle that obtains the index from `db` for `resolved` on
    /// first use.
    fn cached(db: &'a Omnigraph, resolved: &'a crate::db::ResolvedTarget) -> Self {
        let builder = GraphIndexBuilder::Cached(db, resolved);
        Self {
            cell: tokio::sync::OnceCell::new(),
            builder,
        }
    }

    /// Creates a handle that builds the index from `snapshot` and the given
    /// edge-type endpoints on first use.
    fn direct(snapshot: &'a Snapshot, edge_types: HashMap<String, (String, String)>) -> Self {
        let builder = GraphIndexBuilder::Direct(snapshot, edge_types);
        Self {
            cell: tokio::sync::OnceCell::new(),
            builder,
        }
    }

    /// Returns the graph index, obtaining it on the first call.
    ///
    /// Yields `Ok(None)` for a handle created with `none()`.
    ///
    /// # Errors
    /// Propagates a failure to obtain or build the index; the cell stays
    /// empty in that case, so a later call tries again.
    async fn get(&self) -> Result<Option<&GraphIndex>> {
        let slot = self
            .cell
            .get_or_try_init(|| async {
                let built: Option<Arc<GraphIndex>> = match &self.builder {
                    GraphIndexBuilder::None => None,
                    GraphIndexBuilder::Cached(db, resolved) => {
                        Some(db.graph_index_for_resolved(resolved).await?)
                    }
                    GraphIndexBuilder::Direct(snapshot, edge_types) => {
                        Some(Arc::new(GraphIndex::build(snapshot, edge_types).await?))
                    }
                };
                Ok::<Option<Arc<GraphIndex>>, OmniError>(built)
            })
            .await?;
        Ok(slot.as_deref())
    }

    /// Reports whether an index has already been produced and is present.
    fn is_built(&self) -> bool {
        self.cell.get().map_or(false, Option::is_some)
    }
}
/// Reads the `OMNIGRAPH_TRAVERSAL_MODE` environment variable.
///
/// Returns `Some(true)` for `"indexed"`, `Some(false)` for `"csr"`, and
/// `None` when the variable is unset, unreadable, or holds any other value.
fn traversal_indexed_override() -> Option<bool> {
    let mode = std::env::var("OMNIGRAPH_TRAVERSAL_MODE").ok()?;
    if mode == "indexed" {
        Some(true)
    } else if mode == "csr" {
        Some(false)
    } else {
        None
    }
}
/// Frontier-size cap for indexed expansion, used when
/// `OMNIGRAPH_EXPAND_INDEXED_MAX_FRONTIER` is unset or does not parse.
const DEFAULT_EXPAND_INDEXED_MAX_FRONTIER: usize = 1024;
/// Hop cap for indexed expansion, used when
/// `OMNIGRAPH_EXPAND_INDEXED_MAX_HOPS` is unset, does not parse, or is zero.
const DEFAULT_EXPAND_INDEXED_MAX_HOPS: u32 = 6;
fn expand_indexed_max_frontier() -> usize {
std::env::var("OMNIGRAPH_EXPAND_INDEXED_MAX_FRONTIER")
.ok()
.and_then(|v| v.parse::<usize>().ok())
.unwrap_or(DEFAULT_EXPAND_INDEXED_MAX_FRONTIER)
}
/// Returns the largest hop count for which indexed expansion is still
/// considered.
///
/// Read from `OMNIGRAPH_EXPAND_INDEXED_MAX_HOPS`; falls back to
/// `DEFAULT_EXPAND_INDEXED_MAX_HOPS` when the variable is missing, is not a
/// valid `u32`, or is zero.
fn expand_indexed_max_hops() -> u32 {
    let configured = std::env::var("OMNIGRAPH_EXPAND_INDEXED_MAX_HOPS")
        .ok()
        .and_then(|raw| raw.parse::<u32>().ok());
    match configured {
        Some(hops) if hops > 0 => hops,
        _ => DEFAULT_EXPAND_INDEXED_MAX_HOPS,
    }
}
/// Execution strategy chosen for an `Expand` step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExpandMode {
/// Traverse by scanning the edge dataset per hop (`execute_expand_indexed`).
IndexedScan,
/// Traverse through the graph index adjacency (`execute_expand_csr`).
Csr,
}
/// Multiplier applied to the edge count to estimate the cost of the CSR path
/// when the graph index is not already built (see `choose_expand_mode`).
const CSR_BUILD_FACTOR: f64 = 1.5;
/// Inputs to the cost comparison performed by `choose_expand_mode`.
#[derive(Debug, Clone)]
struct ExpandCostInputs {
/// Number of rows in the wide batch being expanded.
frontier_rows: usize,
/// Row count of the edge table.
edge_count: u64,
/// Row count of the node table on the traversal's source side.
src_node_count: u64,
/// Hop count used for costing (already capped for cross-type edges).
effective_max_hops: u32,
/// Above this hop count the CSR path is always chosen.
max_hops_cap: u32,
/// Above this frontier size the CSR path is always chosen.
max_frontier_cap: usize,
/// Index coverage of the edge table's key column.
coverage: crate::table_store::IndexCoverage,
/// Whether the graph index is already built, making the CSR path free to start.
csr_cached: bool,
}
/// Picks the cheaper expansion strategy for the given inputs.
///
/// Exceeding either cap forces `Csr`. Otherwise the indexed cost is
/// `hops * frontier * (edges / sources)` with full index coverage, or
/// `hops * edges` when coverage is degraded; the CSR cost is zero when the
/// index is already built and `CSR_BUILD_FACTOR * edges` otherwise. The
/// indexed scan wins only when it is strictly cheaper.
fn choose_expand_mode(i: &ExpandCostInputs) -> ExpandMode {
    let over_hop_cap = i.effective_max_hops > i.max_hops_cap;
    let over_frontier_cap = i.frontier_rows > i.max_frontier_cap;
    if over_hop_cap || over_frontier_cap {
        return ExpandMode::Csr;
    }

    let hop_factor = i.effective_max_hops.max(1) as f64;
    let edge_total = i.edge_count as f64;

    let indexed_cost = match i.coverage {
        crate::table_store::IndexCoverage::Degraded { .. } => hop_factor * edge_total,
        crate::table_store::IndexCoverage::Indexed => {
            let frontier = i.frontier_rows as f64;
            let fanout = edge_total / (i.src_node_count.max(1) as f64);
            hop_factor * frontier * fanout
        }
    };
    let csr_cost = match i.csr_cached {
        true => 0.0,
        false => CSR_BUILD_FACTOR * edge_total,
    };

    if indexed_cost < csr_cost {
        ExpandMode::IndexedScan
    } else {
        ExpandMode::Csr
    }
}
/// Returns the hop count to use for costing an expansion.
///
/// Edges whose endpoints share a node type keep the requested hop count;
/// for any other edge the count is capped at one hop.
fn cost_effective_hops(requested_max_hops: u32, same_type: bool) -> u32 {
    match same_type {
        true => requested_max_hops,
        false => requested_max_hops.min(1),
    }
}
/// Collects the statistics needed to cost an expansion.
///
/// Returns `None` when the snapshot has no entry for the edge table or the
/// source node table, or when the catalog does not know the edge type. The
/// source side is `from_type` for outgoing traversal and `to_type` for
/// incoming traversal. The caps are read from the environment at call time.
fn gather_cost_inputs(
    snapshot: &Snapshot,
    catalog: &Catalog,
    edge_type: &str,
    direction: Direction,
    frontier_rows: usize,
    effective_max_hops: u32,
    coverage: crate::table_store::IndexCoverage,
    csr_cached: bool,
) -> Option<ExpandCostInputs> {
    let edge_entry = snapshot.entry(&format!("edge:{}", edge_type))?;
    let edge_def = catalog.edge_types.get(edge_type)?;

    let endpoints_share_type = edge_def.from_type == edge_def.to_type;
    let costed_hops = cost_effective_hops(effective_max_hops, endpoints_share_type);

    let src_type = match direction {
        Direction::In => &edge_def.to_type,
        Direction::Out => &edge_def.from_type,
    };
    let src_entry = snapshot.entry(&format!("node:{}", src_type))?;

    let inputs = ExpandCostInputs {
        frontier_rows,
        edge_count: edge_entry.row_count,
        src_node_count: src_entry.row_count,
        effective_max_hops: costed_hops,
        max_hops_cap: expand_indexed_max_hops(),
        max_frontier_cap: expand_indexed_max_frontier(),
        coverage,
        csr_cached,
    };
    Some(inputs)
}
/// Converts a coverage-check result into a value usable for costing.
///
/// A successful check is cloned as-is; a failed check is treated as degraded
/// coverage with the reason `"coverage check failed"`.
fn coverage_for_decision(
    coverage: &Result<crate::table_store::IndexCoverage>,
) -> crate::table_store::IndexCoverage {
    coverage
        .as_ref()
        .ok()
        .cloned()
        .unwrap_or_else(|| crate::table_store::IndexCoverage::Degraded {
            reason: "coverage check failed".to_string(),
        })
}
/// Logs the outcome of the index-coverage check for an indexed traversal.
///
/// Degraded coverage emits a warning, a failed check emits a debug event,
/// and full coverage logs nothing.
fn warn_on_degraded_coverage(
    coverage: &Result<crate::table_store::IndexCoverage>,
    key_col: &str,
    edge_type: &str,
) {
    let coverage = match coverage {
        Ok(coverage) => coverage,
        Err(e) => {
            tracing::debug!(
                target: "omnigraph::traverse",
                error = %e,
                "index-coverage check failed; proceeding with traversal"
            );
            return;
        }
    };
    if let crate::table_store::IndexCoverage::Degraded { reason } = coverage {
        tracing::warn!(
            target: "omnigraph::traverse",
            edge = %edge_type,
            key_col = key_col,
            reason = %reason,
            "indexed traversal falls back to a full edge scan (results correct, perf degraded)"
        );
    }
}
/// Returns the edge-table columns as `(key, opposite)` for a traversal
/// direction: `("src", "dst")` when outgoing, `("dst", "src")` when incoming.
fn endpoint_columns(direction: Direction) -> (&'static str, &'static str) {
    let outgoing = ("src", "dst");
    match direction {
        Direction::In => (outgoing.1, outgoing.0),
        Direction::Out => outgoing,
    }
}
/// Expands `src_var` along `edge_type`, choosing between the indexed-scan
/// and CSR strategies.
///
/// The choice is made in two steps unless `OMNIGRAPH_TRAVERSAL_MODE` forces
/// it: first assuming full index coverage, then — if the indexed scan was
/// picked — again with the coverage actually reported for the edge table's
/// key column. Either step can divert to the CSR path.
///
/// # Errors
/// Returns a manifest error when the CSR path is chosen but no graph index
/// is available; propagates dataset-open and traversal failures.
async fn execute_expand(
    wide: &mut RecordBatch,
    graph_index: &GraphIndexHandle<'_>,
    snapshot: &Snapshot,
    catalog: &Catalog,
    src_var: &str,
    dst_var: &str,
    edge_type: &str,
    direction: Direction,
    dst_type: &str,
    min_hops: u32,
    max_hops: Option<u32>,
    dst_filters: &[IRFilter],
    params: &ParamMap,
) -> Result<()> {
    let frontier_rows = wide.num_rows();
    let effective_max_hops = match max_hops {
        Some(limit) => limit,
        None => min_hops.max(1),
    };
    let (key_col, _) = endpoint_columns(direction);
    let forced = traversal_indexed_override();

    // Step 1: optimistic decision, assuming the key column is fully indexed.
    let lean_indexed = if let Some(forced_indexed) = forced {
        forced_indexed
    } else if let Some(inputs) = gather_cost_inputs(
        snapshot,
        catalog,
        edge_type,
        direction,
        frontier_rows,
        effective_max_hops,
        crate::table_store::IndexCoverage::Indexed,
        graph_index.is_built(),
    ) {
        choose_expand_mode(&inputs) == ExpandMode::IndexedScan
    } else {
        // Without statistics only the caps decide.
        frontier_rows <= expand_indexed_max_frontier()
            && effective_max_hops <= expand_indexed_max_hops()
    };

    if !lean_indexed {
        tracing::debug!(
            target: "omnigraph::traverse",
            edge = %edge_type,
            frontier = frontier_rows,
            hops = effective_max_hops,
            mode = "csr",
            "expand mode chosen",
        );
        let Some(gi) = graph_index.get().await? else {
            return Err(OmniError::manifest(
                "graph index required for CSR traversal".to_string(),
            ));
        };
        return execute_expand_csr(
            wide, gi, snapshot, catalog, src_var, dst_var, edge_type, direction, dst_type,
            min_hops, max_hops, dst_filters, params,
        )
        .await;
    }

    let edge_ds = snapshot.open(&format!("edge:{}", edge_type)).await?;
    let coverage =
        crate::table_store::TableStore::key_column_index_coverage(&edge_ds, key_col).await;

    // Step 2: re-cost with the observed coverage; skipped when the mode is forced.
    let degrade_to_csr = forced.is_none()
        && gather_cost_inputs(
            snapshot,
            catalog,
            edge_type,
            direction,
            frontier_rows,
            effective_max_hops,
            coverage_for_decision(&coverage),
            graph_index.is_built(),
        )
        .map_or(false, |inputs| {
            choose_expand_mode(&inputs) == ExpandMode::Csr
        });

    if degrade_to_csr {
        tracing::debug!(
            target: "omnigraph::traverse",
            edge = %edge_type,
            frontier = frontier_rows,
            hops = effective_max_hops,
            mode = "csr",
            reason = "index coverage degraded",
            "expand mode chosen",
        );
        let Some(gi) = graph_index.get().await? else {
            return Err(OmniError::manifest(
                "graph index required for CSR traversal".to_string(),
            ));
        };
        return execute_expand_csr(
            wide, gi, snapshot, catalog, src_var, dst_var, edge_type, direction, dst_type,
            min_hops, max_hops, dst_filters, params,
        )
        .await;
    }

    tracing::debug!(
        target: "omnigraph::traverse",
        edge = %edge_type,
        frontier = frontier_rows,
        hops = effective_max_hops,
        mode = "indexed",
        "expand mode chosen",
    );
    warn_on_degraded_coverage(&coverage, key_col, edge_type);
    execute_expand_indexed(
        wide, snapshot, catalog, src_var, dst_var, edge_type, direction, dst_type, min_hops,
        max_hops, dst_filters, params, edge_ds,
    )
    .await
}
/// Expands `src_var` along `edge_type` by scanning the edge dataset once per
/// hop, then hydrates and aligns the reached destinations.
///
/// Every source row runs its own breadth-first traversal, but the edge
/// lookups of all rows are batched: each hop issues a single
/// `scan_edges_by_endpoint` call for the union of all frontiers.
///
/// A `(source row, destination)` pair is emitted the first time the
/// destination is reached at a hop `>= min_hops`. Multi-hop traversal only
/// happens when the edge's endpoints share a node type; otherwise the hop
/// count is capped at one.
///
/// # Errors
/// Returns a manifest error when the source id column or an edge endpoint
/// column is missing or not Utf8, or the edge type is unknown; propagates
/// scan and hydration failures.
async fn execute_expand_indexed(
wide: &mut RecordBatch,
snapshot: &Snapshot,
catalog: &Catalog,
src_var: &str,
dst_var: &str,
edge_type: &str,
direction: Direction,
dst_type: &str,
min_hops: u32,
max_hops: Option<u32>,
dst_filters: &[IRFilter],
params: &ParamMap,
edge_ds: Dataset,
) -> Result<()> {
let src_id_col_name = format!("{}.id", src_var);
let src_ids = wide
.column_by_name(&src_id_col_name)
.ok_or_else(|| {
OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name))
})?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
.clone();
let edge_def = catalog
.edge_types
.get(edge_type)
.ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_type)))?;
let same_type = edge_def.from_type == edge_def.to_type;
let (key_col, opp_col) = endpoint_columns(direction);
// Default to a single hop (or `min_hops`) when no upper bound was given.
let max = max_hops.unwrap_or(min_hops.max(1));
let max = if same_type { max } else { max.min(1) };
// Maps string ids to dense u32 handles so the per-row sets stay cheap.
let mut interner = crate::graph_index::TypeIndex::new();
let n = src_ids.len();
// Per source row: current frontier, nodes already expanded through (only
// tracked for same-type edges), and destinations already emitted.
let mut frontiers: Vec<Vec<u32>> = Vec::with_capacity(n);
let mut visited: Vec<HashSet<u32>> = Vec::with_capacity(n);
let mut seen_dst: Vec<HashSet<u32>> = Vec::with_capacity(n);
for i in 0..n {
let sid = interner.get_or_insert(src_ids.value(i));
let mut v = HashSet::new();
if same_type {
// Marking the source as visited keeps it from being re-reached via a cycle.
v.insert(sid);
}
frontiers.push(vec![sid]);
visited.push(v);
seen_dst.push(HashSet::new());
}
// Parallel output vectors: source row index and dense destination id.
let mut src_indices: Vec<u32> = Vec::new();
let mut dst_dense: Vec<u32> = Vec::new();
for hop in 1..=max {
// Deduplicated union of every row's frontier, in first-seen order.
let mut union_dense: Vec<u32> = Vec::new();
{
let mut seen: HashSet<u32> = HashSet::new();
for f in &frontiers {
for &node in f {
if seen.insert(node) {
union_dense.push(node);
}
}
}
}
if union_dense.is_empty() {
break;
}
let union_keys: Vec<String> = union_dense
.iter()
.map(|&u| {
interner
.to_id(u)
.expect("interned frontier id must resolve")
.to_string()
})
.collect();
// One edge lookup per hop for all rows together.
let batches = crate::table_store::TableStore::scan_edges_by_endpoint(
&edge_ds, key_col, opp_col, &union_keys,
)
.await?;
// Adjacency for this hop: key-side node -> opposite-side nodes.
let mut neighbor_map: HashMap<u32, Vec<u32>> = HashMap::new();
for batch in &batches {
let keys = batch
.column_by_name(key_col)
.ok_or_else(|| OmniError::manifest(format!("edge batch missing '{}'", key_col)))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| OmniError::manifest(format!("edge '{}' is not Utf8", key_col)))?;
let opps = batch
.column_by_name(opp_col)
.ok_or_else(|| OmniError::manifest(format!("edge batch missing '{}'", opp_col)))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| OmniError::manifest(format!("edge '{}' is not Utf8", opp_col)))?;
for r in 0..batch.num_rows() {
let k = interner.get_or_insert(keys.value(r));
let o = interner.get_or_insert(opps.value(r));
neighbor_map.entry(k).or_default().push(o);
}
}
// Advance each row's frontier independently using the shared adjacency.
for i in 0..n {
let cur = std::mem::take(&mut frontiers[i]);
let mut next: Vec<u32> = Vec::new();
for &node in &cur {
let Some(neighbors) = neighbor_map.get(&node) else {
continue;
};
for &neighbor in neighbors {
// `visited.insert` only runs for same-type edges (short-circuit), so
// cross-type expansion never filters neighbors here.
if !same_type || visited[i].insert(neighbor) {
next.push(neighbor);
// Emit each destination at most once per source row.
if hop >= min_hops && seen_dst[i].insert(neighbor) {
src_indices.push(i as u32);
dst_dense.push(neighbor);
}
}
}
}
frontiers[i] = next;
}
}
let dst_ids: Vec<String> = dst_dense
.iter()
.map(|&d| {
interner
.to_id(d)
.expect("interned dst id must resolve")
.to_string()
})
.collect();
expand_hydrate_and_align(
wide, src_indices, dst_ids, snapshot, catalog, dst_type, dst_var, dst_filters, params,
)
.await
}
/// Loads the destination nodes of an expansion and widens `wide` with them.
///
/// `src_indices[i]` and `dst_ids[i]` form one `(source row, destination id)`
/// pair. The distinct destination ids are hydrated (with `dst_filters` passed
/// along), pairs whose destination was not returned are dropped, and `wide`
/// is replaced by the surviving source rows joined column-wise with their
/// destination rows, prefixed by `dst_var`. Filters for which
/// `ir_filter_to_expr` yields no expression are applied afterwards on the
/// joined batch.
///
/// # Errors
/// Returns a manifest error when the hydrated batch lacks a Utf8 `id`
/// column; propagates hydration, take, concatenation and filter failures.
async fn expand_hydrate_and_align(
    wide: &mut RecordBatch,
    src_indices: Vec<u32>,
    dst_ids: Vec<String>,
    snapshot: &Snapshot,
    catalog: &Catalog,
    dst_type: &str,
    dst_var: &str,
    dst_filters: &[IRFilter],
    params: &ParamMap,
) -> Result<()> {
    let non_pushable: Vec<&IRFilter> = dst_filters
        .iter()
        .filter(|f| ir_filter_to_expr(f, params, None).is_none())
        .collect();

    // Distinct destination ids, in first-seen order.
    let unique_dst_list: Vec<String> = {
        let mut seen: HashSet<&str> = HashSet::with_capacity(dst_ids.len());
        dst_ids
            .iter()
            .filter(|id| seen.insert(id.as_str()))
            .cloned()
            .collect()
    };

    let dst_batch =
        hydrate_nodes(snapshot, catalog, dst_type, &unique_dst_list, dst_filters, params).await?;
    let Some(id_column) = dst_batch.column_by_name("id") else {
        return Err(OmniError::manifest(
            "hydrated batch missing 'id' column".to_string(),
        ));
    };
    let Some(dst_batch_id_col) = id_column.as_any().downcast_ref::<StringArray>() else {
        return Err(OmniError::manifest(
            "hydrated 'id' column is not Utf8".to_string(),
        ));
    };

    let id_to_row: HashMap<&str, u32> = (0..dst_batch_id_col.len())
        .map(|row| (dst_batch_id_col.value(row), row as u32))
        .collect();

    // Keep only the pairs whose destination survived hydration.
    let (final_src_indices, dst_indices): (Vec<u32>, Vec<u32>) = src_indices
        .iter()
        .zip(dst_ids.iter())
        .filter_map(|(&src_idx, dst_id)| {
            id_to_row
                .get(dst_id.as_str())
                .map(|&dst_row| (src_idx, dst_row))
        })
        .unzip();

    let src_take = UInt32Array::from(final_src_indices);
    let dst_take = UInt32Array::from(dst_indices);
    let expanded_wide = take_batch(wide, &src_take)?;
    let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
    let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
    *wide = hconcat_batches(&expanded_wide, &aligned_dst)?;

    for filter in non_pushable {
        apply_filter(wide, filter, params)?;
    }
    Ok(())
}
/// Expands `src_var` along `edge_type` using the adjacency stored in
/// `graph_index`, then hydrates and aligns the reached destinations.
///
/// Each source row runs its own breadth-first traversal over the `csr`
/// (outgoing) or `csc` (incoming) adjacency. A `(source row, destination)`
/// pair is emitted the first time the destination is reached at a hop
/// `>= min_hops`. Multi-hop traversal only happens when source and
/// destination node types are the same; otherwise the hop count is capped at
/// one. Source ids unknown to the type index are skipped.
///
/// # Errors
/// Returns a manifest error when the source id column is missing or not
/// Utf8, the edge type is unknown, or the graph index lacks a type index or
/// adjacency for it; propagates hydration failures.
async fn execute_expand_csr(
wide: &mut RecordBatch,
graph_index: &GraphIndex,
snapshot: &Snapshot,
catalog: &Catalog,
src_var: &str,
dst_var: &str,
edge_type: &str,
direction: Direction,
dst_type: &str,
min_hops: u32,
max_hops: Option<u32>,
dst_filters: &[IRFilter],
params: &ParamMap,
) -> Result<()> {
let src_id_col_name = format!("{}.id", src_var);
let src_ids = wide
.column_by_name(&src_id_col_name)
.ok_or_else(|| {
OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name))
})?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
.clone();
let edge_def = catalog
.edge_types
.get(edge_type)
.ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_type)))?;
// Traversing against the edge direction swaps which endpoint is the source.
let (src_type_name, dst_type_name) = match direction {
Direction::Out => (&edge_def.from_type, &edge_def.to_type),
Direction::In => (&edge_def.to_type, &edge_def.from_type),
};
let src_type_idx = graph_index
.type_index(src_type_name)
.ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", src_type_name)))?;
let dst_type_idx = graph_index
.type_index(dst_type_name)
.ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", dst_type_name)))?;
let adj = match direction {
Direction::Out => graph_index.csr(edge_type),
Direction::In => graph_index.csc(edge_type),
}
.ok_or_else(|| OmniError::manifest(format!("no adjacency index for edge '{}'", edge_type)))?;
// Default to a single hop (or `min_hops`) when no upper bound was given.
let max = max_hops.unwrap_or(min_hops.max(1));
let same_type = src_type_name == dst_type_name;
let max = if same_type { max } else { max.min(1) };
// Parallel output vectors: source row index and dense destination id.
let mut src_indices: Vec<u32> = Vec::new();
let mut dst_dense_list: Vec<u32> = Vec::new();
for i in 0..src_ids.len() {
let src_id = src_ids.value(i);
// A source without a dense id has no adjacency to follow.
let Some(src_dense) = src_type_idx.to_dense(src_id) else {
continue;
};
let mut frontier: Vec<u32> = vec![src_dense];
let mut visited: HashSet<u32> = HashSet::new();
let mut seen_dst_dense: HashSet<u32> = HashSet::new();
if same_type {
// Marking the source as visited keeps it from being re-reached via a cycle.
visited.insert(src_dense);
}
for hop in 1..=max {
let mut next_frontier = Vec::new();
for &node in &frontier {
for &neighbor in adj.neighbors(node) {
// `visited.insert` only runs for same-type edges (short-circuit).
if !same_type || visited.insert(neighbor) {
next_frontier.push(neighbor);
// Emit each destination at most once per source row.
if hop >= min_hops && seen_dst_dense.insert(neighbor) {
src_indices.push(i as u32);
dst_dense_list.push(neighbor);
}
}
}
}
frontier = next_frontier;
if frontier.is_empty() {
break;
}
}
}
// Translate dense destination ids back to string ids, dropping pairs whose
// destination cannot be resolved.
let mut tail_src_indices: Vec<u32> = Vec::with_capacity(src_indices.len());
let mut dst_ids: Vec<String> = Vec::with_capacity(dst_dense_list.len());
for (&s, &d) in src_indices.iter().zip(dst_dense_list.iter()) {
if let Some(id) = dst_type_idx.to_id(d) {
tail_src_indices.push(s);
dst_ids.push(id.to_string());
}
}
expand_hydrate_and_align(
wide,
tail_src_indices,
dst_ids,
snapshot,
catalog,
dst_type,
dst_var,
dst_filters,
params,
)
.await
}
/// Loads the rows of node type `type_name` whose `id` is one of `ids`.
///
/// The `id IN (...)` predicate is AND-ed with whatever `dst_filters` lower to
/// via `build_lance_filter_expr`. When the node type declares blob
/// properties, they are left out of the scan projection and re-added as null
/// columns by `add_null_blob_columns`.
///
/// # Errors
/// Returns a manifest error for an unknown node type, and a Lance error when
/// collecting or concatenating the scanned batches fails.
async fn hydrate_nodes(
    snapshot: &Snapshot,
    catalog: &Catalog,
    type_name: &str,
    ids: &[String],
    dst_filters: &[IRFilter],
    params: &ParamMap,
) -> Result<RecordBatch> {
    use datafusion::prelude::{col, lit};

    let Some(node_type) = catalog.node_types.get(type_name) else {
        return Err(OmniError::manifest(format!(
            "unknown node type '{}'",
            type_name
        )));
    };
    // Nothing to look up: skip opening the table entirely.
    if ids.is_empty() {
        return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
    }

    let table_key = format!("node:{}", type_name);
    let ds = snapshot.open(&table_key).await?;

    let id_literals: Vec<datafusion::prelude::Expr> =
        ids.iter().map(|id| lit(id.clone())).collect();
    let id_predicate = col("id").in_list(id_literals, false);
    let predicate =
        match build_lance_filter_expr(dst_filters, params, Some(&node_type.arrow_schema)) {
            Some(extra) => id_predicate.and(extra),
            None => id_predicate,
        };

    // Only restrict the projection when there are blob columns to avoid.
    let has_blobs = !node_type.blob_properties.is_empty();
    let non_blob_cols: Vec<&str> = node_type
        .arrow_schema
        .fields()
        .iter()
        .filter(|f| !node_type.blob_properties.contains(f.name()))
        .map(|f| f.name().as_str())
        .collect();
    let projection = has_blobs.then_some(non_blob_cols.as_slice());

    let batches = crate::table_store::TableStore::scan_stream_with(
        &ds,
        projection,
        None,
        None,
        false,
        |scanner| {
            scanner.filter_expr(predicate);
            Ok(())
        },
    )
    .await?
    .try_collect::<Vec<RecordBatch>>()
    .await
    .map_err(|e| OmniError::Lance(e.to_string()))?;

    let scanned = match batches.len() {
        0 => return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone())),
        1 => batches.into_iter().next().unwrap(),
        _ => {
            let schema = batches[0].schema();
            arrow_select::concat::concat_batches(&schema, &batches)
                .map_err(|e| OmniError::Lance(e.to_string()))?
        }
    };

    if has_blobs {
        add_null_blob_columns(&scanned, node_type)
    } else {
        Ok(scanned)
    }
}
/// Reports whether an anti-join's inner pipeline has the shape handled by the
/// bulk path: exactly one `Expand` op that starts at `outer_var`, has no
/// destination filters, and covers exactly one hop.
fn bulk_anti_join_applies(inner_pipeline: &[IROp], outer_var: &str) -> bool {
    let [IROp::Expand {
        src_var,
        dst_filters,
        min_hops,
        max_hops,
        ..
    }] = inner_pipeline
    else {
        return false;
    };
    // A missing upper bound counts as a single hop.
    let single_hop = *min_hops == 1 && (*max_hops).unwrap_or(1) == 1;
    src_var == outer_var && dst_filters.is_empty() && single_hop
}
/// Builds the anti-join keep-mask straight from the adjacency index.
///
/// Returns `None` whenever the fast path cannot be used: the pipeline shape
/// is not accepted by `bulk_anti_join_applies`, no graph index was supplied,
/// the edge type / adjacency / type index is missing, or `wide` has no
/// `"{outer_var}.id"` column stored as a `StringArray`.
///
/// In the returned mask a row is kept (`true`) when its id has no neighbours
/// along the edge, or when the id has no dense mapping in the type index.
fn try_bulk_anti_join_mask(
    wide: &RecordBatch,
    inner_pipeline: &[IROp],
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
    outer_var: &str,
) -> Option<BooleanArray> {
    let (edge_type, direction) = match inner_pipeline.first() {
        Some(IROp::Expand {
            edge_type,
            direction,
            ..
        }) if bulk_anti_join_applies(inner_pipeline, outer_var) => (edge_type, direction),
        _ => return None,
    };

    let graph = graph_index?;
    let edge_def = catalog.edge_types.get(edge_type.as_str())?;
    // The outer variable sits on the edge's source side for `Out` and on its
    // target side for `In`; pick the matching endpoint type and adjacency.
    let (src_type_name, adjacency) = match direction {
        Direction::Out => (&edge_def.from_type, graph.csr(edge_type)),
        Direction::In => (&edge_def.to_type, graph.csc(edge_type)),
    };
    let adjacency = adjacency?;
    let type_idx = graph.type_index(src_type_name)?;

    let id_col_name = format!("{}.id", outer_var);
    let outer_ids = wide
        .column_by_name(&id_col_name)?
        .as_any()
        .downcast_ref::<StringArray>()?;

    let keep_mask: Vec<bool> = (0..outer_ids.len())
        .map(|row| {
            type_idx
                .to_dense(outer_ids.value(row))
                .map_or(true, |dense| !adjacency.has_neighbors(dense))
        })
        .collect();
    Some(BooleanArray::from(keep_mask))
}
/// Removes from `wide` every row for which `inner_pipeline` produces at
/// least one result.
///
/// Fast path: when `bulk_anti_join_applies` accepts the pipeline, the graph
/// index is fetched and `try_bulk_anti_join_mask` computes the mask directly.
///
/// General path: each outer row is tagged with its row number in a fresh
/// `__antijoin_outer_row_{n}` column, the inner pipeline runs over the tagged
/// batch, and rows whose tag survives are dropped from `wide`.
///
/// # Errors
/// Propagates errors from the graph index and the inner pipeline. Returns a
/// manifest error if the inner pipeline drops or retypes the tag column, and
/// a Lance error if building or filtering a batch fails.
async fn execute_anti_join(
    wide: &mut RecordBatch,
    inner_pipeline: &[IROp],
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: &GraphIndexHandle<'_>,
    catalog: &Catalog,
    outer_var: &str,
) -> Result<()> {
    // Only touch the graph index when the fast path can actually use it.
    let gi = match bulk_anti_join_applies(inner_pipeline, outer_var) {
        true => graph_index.get().await?,
        false => None,
    };
    if let Some(mask) = try_bulk_anti_join_mask(wide, inner_pipeline, gi, catalog, outer_var) {
        *wide = arrow_select::filter::filter_record_batch(wide, &mask)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        return Ok(());
    }

    let num_rows = wide.num_rows();
    if num_rows == 0 {
        return Ok(());
    }

    // Pick the first tag-column name that does not clash with an existing column.
    let outer_schema = wide.schema();
    let mut suffix = 0usize;
    let mut tag_col = format!("__antijoin_outer_row_{suffix}");
    while outer_schema.column_with_name(&tag_col).is_some() {
        suffix += 1;
        tag_col = format!("__antijoin_outer_row_{suffix}");
    }

    // Copy of `wide` with the row-number tag appended as the last column.
    let tagged = {
        let mut fields: Vec<Field> = outer_schema
            .fields()
            .iter()
            .map(|f| f.as_ref().clone())
            .collect();
        fields.push(Field::new(tag_col.as_str(), DataType::UInt32, false));
        let mut columns: Vec<ArrayRef> = wide.columns().to_vec();
        columns.push(Arc::new(UInt32Array::from_iter_values(0..num_rows as u32)));
        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
            .map_err(|e| OmniError::Lance(e.to_string()))?
    };

    let mut inner_wide: Option<RecordBatch> = Some(tagged);
    let no_search = SearchMode::default();
    execute_pipeline(
        inner_pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut inner_wide,
        &no_search,
    )
    .await?;

    // Tags still present after the inner pipeline identify outer rows that matched.
    let matched: HashSet<u32> = match inner_wide {
        Some(batch) if batch.num_rows() > 0 => {
            let tags = batch
                .column_by_name(tag_col.as_str())
                .ok_or_else(|| {
                    OmniError::manifest(
                        "anti-join inner pipeline dropped the correlation column".to_string(),
                    )
                })?
                .as_any()
                .downcast_ref::<UInt32Array>()
                .ok_or_else(|| {
                    OmniError::manifest(format!("'{}' column is not UInt32", tag_col))
                })?;
            (0..tags.len()).map(|row| tags.value(row)).collect()
        }
        _ => HashSet::new(),
    };

    let keep_mask: Vec<bool> = (0..num_rows as u32)
        .map(|row| !matched.contains(&row))
        .collect();
    let mask = BooleanArray::from(keep_mask);
    *wide = arrow_select::filter::filter_record_batch(wide, &mask)
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    Ok(())
}
async fn execute_node_scan(
type_name: &str,
variable: &str,
filters: &[IRFilter],
params: &ParamMap,
snapshot: &Snapshot,
catalog: &Catalog,
search_mode: &SearchMode,
) -> Result<RecordBatch> {
let table_key = format!("node:{}", type_name);
let ds = snapshot.open(&table_key).await?;
let node_type = &catalog.node_types[type_name];
let filter_expr = build_lance_filter_expr(filters, params, Some(&node_type.arrow_schema));
let has_blobs = !node_type.blob_properties.is_empty();
let non_blob_cols: Vec<&str> = node_type
.arrow_schema
.fields()
.iter()
.filter(|f| !node_type.blob_properties.contains(f.name()))
.map(|f| f.name().as_str())
.collect();
let projection = has_blobs.then_some(non_blob_cols.as_slice());
let batches = crate::table_store::TableStore::scan_stream_with(
&ds,
projection,
None,
None,
false,
|scanner| {
if let Some(ref expr) = filter_expr {
scanner.filter_expr(expr.clone());
}
for filter in filters {
if is_search_filter(filter) {
if let Some(fts_query) = build_fts_query(&filter.left, params) {
scanner.full_text_search(fts_query).map_err(|e| {
OmniError::Lance(format!("full_text_search filter: {}", e))
})?;
}
}
}
if let Some((ref var, ref prop, ref vec, k)) = search_mode.nearest {
if var == variable {
let query_arr = Float32Array::from(vec.clone());
scanner
.nearest(prop, &query_arr, k)
.map_err(|e| OmniError::Lance(format!("nearest: {}", e)))?;
}
}
if let Some((ref var, ref prop, ref text)) = search_mode.bm25 {
if var == variable {
let fts_query = lance_index::scalar::FullTextSearchQuery::new(text.clone())
.with_column(prop.clone())
.map_err(|e| OmniError::Lance(format!("fts with_column: {}", e)))?;
scanner
.full_text_search(fts_query)
.map_err(|e| OmniError::Lance(format!("full_text_search: {}", e)))?;
}
}
Ok(())
},
)
.await?
.try_collect::<Vec<RecordBatch>>()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let scan_result = if batches.is_empty() {
RecordBatch::new_empty(batches.first().map(|b| b.schema()).unwrap_or_else(|| {
let fields: Vec<_> = node_type
.arrow_schema
.fields()
.iter()
.filter(|f| !node_type.blob_properties.contains(f.name()))
.map(|f| f.as_ref().clone())
.collect();
Arc::new(Schema::new(fields))
}))
} else if batches.len() == 1 {
batches.into_iter().next().unwrap()
} else {
let schema = batches[0].schema();
arrow_select::concat::concat_batches(&schema, &batches)
.map_err(|e| OmniError::Lance(e.to_string()))?
};
if has_blobs {
return add_null_blob_columns(&scan_result, node_type);
}
Ok(scan_result)
}
/// Rebuilds `batch` in the column order of `node_type.arrow_schema`, adding
/// an all-null, nullable Utf8 column for every blob property.
///
/// Non-blob fields keep the field definition and data found in `batch`; a
/// non-blob field that `batch` does not contain is left out of the result.
///
/// # Errors
/// Returns a Lance error if a field lookup on the batch schema fails or the
/// assembled columns do not form a valid `RecordBatch`.
fn add_null_blob_columns(
    batch: &RecordBatch,
    node_type: &omnigraph_compiler::catalog::NodeType,
) -> Result<RecordBatch> {
    let num_rows = batch.num_rows();
    let batch_schema = batch.schema();
    let declared = node_type.arrow_schema.fields();

    let mut fields = Vec::with_capacity(declared.len());
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(declared.len());
    for field in declared {
        let name = field.name();
        if node_type.blob_properties.contains(name) {
            // Blob payloads were not scanned; stand in with nulls.
            fields.push(Field::new(name, DataType::Utf8, true));
            columns.push(Arc::new(StringArray::from(vec![None::<&str>; num_rows])));
            continue;
        }
        let Some(col) = batch.column_by_name(name) else {
            continue;
        };
        let batch_field = batch_schema
            .field_with_name(name)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        fields.push(batch_field.clone());
        columns.push(col.clone());
    }

    RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
        .map_err(|e| OmniError::Lance(e.to_string()))
}
/// Lowers a search-style IR expression to a Lance full-text query.
///
/// `Search` and `MatchText` become a plain query on the accessed property;
/// `Fuzzy` becomes a fuzzy query whose edit distance comes from `max_edits`
/// (default 2 when it is absent or does not resolve to an integer).
///
/// Returns `None` for any other expression, when the field is not a property
/// access, when the query text does not resolve to a string, or when Lance
/// rejects the column.
fn build_fts_query(
    expr: &IRExpr,
    params: &ParamMap,
) -> Option<lance_index::scalar::FullTextSearchQuery> {
    use lance_index::scalar::FullTextSearchQuery;

    /// Edit distance used when `max_edits` is missing or unresolvable.
    const DEFAULT_MAX_EDITS: i64 = 2;

    match expr {
        IRExpr::Search { field, query } => {
            let prop = extract_property(field)?;
            let q = resolve_to_string(query, params)?;
            FullTextSearchQuery::new(q).with_column(prop).ok()
        }
        IRExpr::Fuzzy {
            field,
            query,
            max_edits,
        } => {
            let prop = extract_property(field)?;
            let q = resolve_to_string(query, params)?;
            let requested = max_edits
                .as_ref()
                .and_then(|e| resolve_to_int(e, params))
                .unwrap_or(DEFAULT_MAX_EDITS);
            // Clamp before narrowing: a bare `as u32` would wrap a negative
            // value into a huge edit distance and truncate oversized ones.
            let edits = requested.clamp(0, i64::from(u32::MAX)) as u32;
            FullTextSearchQuery::new_fuzzy(q, Some(edits))
                .with_column(prop)
                .ok()
        }
        IRExpr::MatchText { field, query } => {
            let prop = extract_property(field)?;
            let q = resolve_to_string(query, params)?;
            FullTextSearchQuery::new(q).with_column(prop).ok()
        }
        _ => None,
    }
}
/// Returns the property name when `expr` is a property access, else `None`.
fn extract_property(expr: &IRExpr) -> Option<String> {
    if let IRExpr::PropAccess { property, .. } = expr {
        Some(property.clone())
    } else {
        None
    }
}
/// Resolves `expr` to a string value.
///
/// Accepts a string literal or a parameter bound to a string literal in
/// `params`; every other expression, an unbound parameter, or a non-string
/// value yields `None`.
fn resolve_to_string(expr: &IRExpr, params: &ParamMap) -> Option<String> {
    let value = match expr {
        IRExpr::Literal(value) => value,
        IRExpr::Param(name) => params.get(name)?,
        _ => return None,
    };
    match value {
        Literal::String(text) => Some(text.clone()),
        _ => None,
    }
}
/// Resolves `expr` to an integer value.
///
/// Accepts an integer literal or a parameter bound to an integer literal in
/// `params`; every other expression, an unbound parameter, or a non-integer
/// value yields `None`.
fn resolve_to_int(expr: &IRExpr, params: &ParamMap) -> Option<i64> {
    let value = match expr {
        IRExpr::Literal(value) => value,
        IRExpr::Param(name) => params.get(name)?,
        _ => return None,
    };
    if let Literal::Integer(number) = value {
        Some(*number)
    } else {
        None
    }
}
/// Renders a literal as SQL text.
///
/// String, date and datetime values are wrapped in single quotes with any
/// embedded single quote doubled. Numbers and booleans use their `to_string`
/// form. `Null` and list literals both render as `NULL`.
pub(super) fn literal_to_sql(lit: &Literal) -> String {
    /// Single-quotes `text`, doubling embedded single quotes.
    fn quoted(text: &str) -> String {
        format!("'{}'", text.replace('\'', "''"))
    }

    match lit {
        Literal::Null | Literal::List(_) => "NULL".to_string(),
        Literal::String(text) => quoted(text),
        Literal::Date(text) => quoted(text),
        Literal::DateTime(text) => quoted(text),
        Literal::Integer(number) => number.to_string(),
        Literal::Float(number) => number.to_string(),
        Literal::Bool(flag) => flag.to_string(),
    }
}
/// AND-combines every filter that `ir_filter_to_expr` can lower.
///
/// Filters that do not lower are skipped. Returns `None` when nothing
/// lowered; a single lowered filter is returned unchanged, and further ones
/// are chained left-to-right with `AND`.
pub(super) fn build_lance_filter_expr(
    filters: &[IRFilter],
    params: &ParamMap,
    schema: Option<&Schema>,
) -> Option<datafusion::prelude::Expr> {
    use datafusion::logical_expr::Operator;
    use datafusion::prelude::Expr;

    filters
        .iter()
        .filter_map(|filter| ir_filter_to_expr(filter, params, schema))
        .reduce(|conjunction, next| {
            Expr::BinaryExpr(datafusion::logical_expr::BinaryExpr::new(
                Box::new(conjunction),
                Operator::And,
                Box::new(next),
            ))
        })
}
/// Lowers one IR filter to a DataFusion predicate.
///
/// Returns `None` for search filters and for filters whose operands cannot be
/// lowered. `Contains` becomes `array_has(left, right)` with no literal
/// coercion. For comparison operators, a literal operand is coerced towards
/// the column type of the opposite operand when that operand is a property
/// found in `schema`.
pub(super) fn ir_filter_to_expr(
    filter: &IRFilter,
    params: &ParamMap,
    schema: Option<&Schema>,
) -> Option<datafusion::prelude::Expr> {
    use datafusion::functions_nested::expr_fn::array_has;

    if is_search_filter(filter) {
        return None;
    }

    // Each side is typed after the column on the *other* side; `Contains`
    // opts out of coercion entirely.
    let (left_hint, right_hint) = if matches!(filter.op, CompOp::Contains) {
        (None, None)
    } else {
        (
            prop_data_type(&filter.right, schema),
            prop_data_type(&filter.left, schema),
        )
    };
    let left = ir_expr_to_expr(&filter.left, params, left_hint.as_ref())?;
    let right = ir_expr_to_expr(&filter.right, params, right_hint.as_ref())?;

    let predicate = match filter.op {
        CompOp::Contains => array_has(left, right),
        CompOp::Eq => left.eq(right),
        CompOp::Ne => left.not_eq(right),
        CompOp::Gt => left.gt(right),
        CompOp::Lt => left.lt(right),
        CompOp::Ge => left.gt_eq(right),
        CompOp::Le => left.lt_eq(right),
    };
    Some(predicate)
}
/// Lowers an IR operand to a DataFusion expression.
///
/// A property access becomes a column reference named after the property.
/// A literal, or a parameter bound in `params`, goes through
/// `literal_to_expr_coerced` with `target` as the preferred type. Unbound
/// parameters and all other expressions yield `None`.
pub(super) fn ir_expr_to_expr(
    expr: &IRExpr,
    params: &ParamMap,
    target: Option<&arrow_schema::DataType>,
) -> Option<datafusion::prelude::Expr> {
    let value = match expr {
        IRExpr::PropAccess { property, .. } => return Some(datafusion::prelude::col(property)),
        IRExpr::Literal(value) => value,
        IRExpr::Param(name) => params.get(name)?,
        _ => return None,
    };
    literal_to_expr_coerced(value, target)
}
/// Looks up the Arrow data type of the column a property access refers to.
///
/// Returns `None` when `expr` is not a property access, no schema was given,
/// or the schema has no field with that name.
fn prop_data_type(expr: &IRExpr, schema: Option<&Schema>) -> Option<arrow_schema::DataType> {
    let IRExpr::PropAccess { property, .. } = expr else {
        return None;
    };
    let field = schema?.field_with_name(property).ok()?;
    Some(field.data_type().clone())
}
/// Lowers a literal, preferring a scalar typed as `target`.
///
/// When `target` is given and `literal_to_typed_expr` succeeds, that typed
/// expression is returned; otherwise the literal is lowered with its natural
/// type via `literal_to_expr`.
fn literal_to_expr_coerced(
    lit: &Literal,
    target: Option<&arrow_schema::DataType>,
) -> Option<datafusion::prelude::Expr> {
    target
        .and_then(|wanted| literal_to_typed_expr(lit, wanted))
        .or_else(|| literal_to_expr(lit))
}
/// Lowers a literal to a scalar of exactly the `target` Arrow type.
///
/// The literal is materialised as a one-element array and cast to `target`
/// if its type differs. For integer targets the cast must survive a round
/// trip back to the original type unchanged, otherwise `None` is returned.
/// Any array-building, cast or scalar-extraction failure also yields `None`.
fn literal_to_typed_expr(
    lit: &Literal,
    target: &arrow_schema::DataType,
) -> Option<datafusion::prelude::Expr> {
    use datafusion::prelude::lit as df_lit;
    use datafusion::scalar::ScalarValue;

    let natural = super::projection::literal_to_array(lit, 1).ok()?;
    if natural.data_type() != target {
        let converted = arrow_cast::cast::cast(&natural, target).ok()?;
        if target.is_integer() {
            // Reject conversions that changed the value.
            let restored = arrow_cast::cast::cast(&converted, natural.data_type()).ok()?;
            let before = ScalarValue::try_from_array(&natural, 0).ok()?;
            let after = ScalarValue::try_from_array(&restored, 0).ok()?;
            if before != after {
                return None;
            }
        }
        return ScalarValue::try_from_array(&converted, 0).ok().map(df_lit);
    }
    // Already the requested type: no cast needed.
    ScalarValue::try_from_array(&natural, 0).ok().map(df_lit)
}
/// Lowers a literal to a DataFusion literal of its natural type.
///
/// Dates and datetimes are passed through as their string form. List
/// literals are not supported and yield `None`.
fn literal_to_expr(lit: &Literal) -> Option<datafusion::prelude::Expr> {
    use datafusion::prelude::lit as df_lit;

    let lowered = match lit {
        Literal::List(_) => return None,
        Literal::Null => df_lit(datafusion::scalar::ScalarValue::Null),
        Literal::Integer(number) => df_lit(*number),
        Literal::Float(number) => df_lit(*number),
        Literal::Bool(flag) => df_lit(*flag),
        Literal::String(text) => df_lit(text.clone()),
        Literal::Date(text) => df_lit(text.clone()),
        Literal::DateTime(text) => df_lit(text.clone()),
    };
    Some(lowered)
}
/// Returns `batch` with every column renamed to `"{variable}.{name}"`.
///
/// Column data, data types and nullability are carried over unchanged.
///
/// # Errors
/// Returns a Lance error if the renamed schema and the columns do not form a
/// valid `RecordBatch`.
fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
    let source_schema = batch.schema();
    let mut renamed: Vec<Field> = Vec::with_capacity(source_schema.fields().len());
    for field in source_schema.fields() {
        renamed.push(Field::new(
            format!("{}.{}", variable, field.name()),
            field.data_type().clone(),
            field.is_nullable(),
        ));
    }
    RecordBatch::try_new(Arc::new(Schema::new(renamed)), batch.columns().to_vec())
        .map_err(|e| OmniError::Lance(e.to_string()))
}
/// Builds the Cartesian product of `left` and `right`.
///
/// Output rows are ordered left-major: each left row is paired with every
/// right row in turn. If either side is empty the result is an empty batch
/// carrying the left fields followed by the right fields.
///
/// # Errors
/// Propagates errors from `take_batch` and `hconcat_batches`.
fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
    let n = left.num_rows();
    let m = right.num_rows();

    if n == 0 || m == 0 {
        let left_schema = left.schema();
        let right_schema = right.schema();
        let combined: Vec<Field> = left_schema
            .fields()
            .iter()
            .chain(right_schema.fields().iter())
            .map(|f| f.as_ref().clone())
            .collect();
        return Ok(RecordBatch::new_empty(Arc::new(Schema::new(combined))));
    }

    // Left row i repeated m times: 0,0,..,1,1,..
    let mut left_take: Vec<u32> = Vec::new();
    for row in 0..n as u32 {
        left_take.extend(std::iter::repeat(row).take(m));
    }
    // Right rows cycled once per left row: 0,1,..,m-1,0,1,..
    let mut right_take: Vec<u32> = Vec::new();
    for _ in 0..n {
        right_take.extend(0..m as u32);
    }

    let left_expanded = take_batch(left, &UInt32Array::from(left_take))?;
    let right_expanded = take_batch(right, &UInt32Array::from(right_take))?;
    hconcat_batches(&left_expanded, &right_expanded)
}
/// Places the columns of `right` after the columns of `left` in one batch.
///
/// In builds with debug assertions enabled, a column name that appears on
/// both sides triggers a panic; other builds perform no such check here.
///
/// # Errors
/// Returns a Lance error if the combined schema and columns do not form a
/// valid `RecordBatch`.
fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
    let left_schema = left.schema();
    let right_schema = right.schema();

    if cfg!(debug_assertions) {
        let left_names: HashSet<&str> = left_schema
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        for f in right_schema.fields() {
            debug_assert!(
                !left_names.contains(f.name().as_str()),
                "hconcat_batches: duplicate column '{}'",
                f.name()
            );
        }
    }

    let fields: Vec<Field> = left_schema
        .fields()
        .iter()
        .chain(right_schema.fields().iter())
        .map(|f| f.as_ref().clone())
        .collect();
    let columns: Vec<ArrayRef> = left
        .columns()
        .iter()
        .chain(right.columns())
        .cloned()
        .collect();

    RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
        .map_err(|e| OmniError::Lance(e.to_string()))
}
/// Gathers the rows at `indices` from every column of `batch`, keeping the
/// batch's schema.
///
/// # Errors
/// Returns a Lance error if taking from any column fails or the gathered
/// columns do not form a valid `RecordBatch`.
fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
    let mut gathered: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
    for column in batch.columns() {
        let picked = arrow_select::take::take(column.as_ref(), indices, None)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        gathered.push(picked);
    }
    RecordBatch::try_new(batch.schema(), gathered).map_err(|e| OmniError::Lance(e.to_string()))
}
#[cfg(test)]
mod expand_chooser_tests {
    use super::*;
    use crate::table_store::IndexCoverage;

    /// Cost-model inputs with `max_hops_cap = 6`, `max_frontier_cap = 1024`
    /// and `csr_cached = false`; the varying fields come from the arguments.
    fn inputs(
        frontier_rows: usize,
        edge_count: u64,
        src_node_count: u64,
        effective_max_hops: u32,
        coverage: IndexCoverage,
    ) -> ExpandCostInputs {
        ExpandCostInputs {
            frontier_rows,
            edge_count,
            src_node_count,
            effective_max_hops,
            max_hops_cap: 6,
            max_frontier_cap: 1024,
            coverage,
            csr_cached: false,
        }
    }

    /// Degraded coverage with the reason `"no btree"`.
    fn degraded() -> IndexCoverage {
        IndexCoverage::Degraded {
            reason: "no btree".into(),
        }
    }

    #[test]
    fn selective_frontier_on_large_graph_picks_indexed() {
        let case = inputs(50, 10_000_000, 1_000_000, 1, IndexCoverage::Indexed);
        assert_eq!(choose_expand_mode(&case), ExpandMode::IndexedScan);
    }

    #[test]
    fn flat_in_edge_count_same_selectivity_same_choice() {
        // Same frontier and source size, edge counts three orders of magnitude apart.
        for edge_count in [100_000, 100_000_000] {
            let case = inputs(50, edge_count, 1_000_000, 1, IndexCoverage::Indexed);
            assert_eq!(choose_expand_mode(&case), ExpandMode::IndexedScan);
        }
    }

    #[test]
    fn frontier_large_fraction_of_source_picks_csr() {
        let case = inputs(200, 1_000, 100, 1, IndexCoverage::Indexed);
        assert_eq!(choose_expand_mode(&case), ExpandMode::Csr);
    }

    #[test]
    fn frontier_over_hard_cap_picks_csr() {
        let case = inputs(2000, 10_000_000, 1_000_000, 1, IndexCoverage::Indexed);
        assert_eq!(choose_expand_mode(&case), ExpandMode::Csr);
    }

    #[test]
    fn hops_over_hard_cap_picks_csr() {
        let case = inputs(10, 10_000_000, 1_000_000, 8, IndexCoverage::Indexed);
        assert_eq!(choose_expand_mode(&case), ExpandMode::Csr);
    }

    #[test]
    fn degraded_single_hop_tiny_frontier_stays_indexed() {
        let case = inputs(5, 10_000, 10_000, 1, degraded());
        assert_eq!(choose_expand_mode(&case), ExpandMode::IndexedScan);
    }

    #[test]
    fn degraded_multi_hop_picks_csr() {
        let case = inputs(5, 10_000, 10_000, 2, degraded());
        assert_eq!(choose_expand_mode(&case), ExpandMode::Csr);
    }

    #[test]
    fn warm_csr_is_always_reused() {
        let case = ExpandCostInputs {
            csr_cached: true,
            ..inputs(1, 10_000_000, 1_000_000, 1, IndexCoverage::Indexed)
        };
        assert_eq!(choose_expand_mode(&case), ExpandMode::Csr);
    }

    #[test]
    fn cost_model_caps_cross_type_hops() {
        assert_eq!(cost_effective_hops(5, true), 5);
        assert_eq!(cost_effective_hops(5, false), 1);
        assert_eq!(cost_effective_hops(1, false), 1);

        // With the capped hop count the indexed scan is chosen ...
        let mut case = inputs(
            50,
            10_000,
            100,
            cost_effective_hops(5, false),
            IndexCoverage::Indexed,
        );
        assert_eq!(choose_expand_mode(&case), ExpandMode::IndexedScan);

        // ... while the uncapped hop count on the same inputs selects CSR.
        case.effective_max_hops = 5;
        assert_eq!(choose_expand_mode(&case), ExpandMode::Csr);
    }
}
#[cfg(test)]
mod literal_lowering_tests {
    use super::*;
    use arrow_schema::DataType;
    use datafusion::prelude::Expr;
    use datafusion::scalar::ScalarValue;

    /// Lowers `lit` against an optional target type, panicking if lowering
    /// yields nothing.
    fn coerce(lit: Literal, target: Option<&DataType>) -> Expr {
        literal_to_expr_coerced(&lit, target).unwrap()
    }

    #[test]
    fn date_literals_coerce_to_typed_arrow_scalars() {
        let dt = coerce(
            Literal::DateTime("2024-06-01T12:00:00Z".into()),
            Some(&DataType::Date64),
        );
        assert!(
            matches!(dt, Expr::Literal(ScalarValue::Date64(Some(_)), ..)),
            "DateTime vs Date64 column must coerce to a typed Date64, got {dt:?}"
        );

        let d = coerce(Literal::Date("2024-06-01".into()), Some(&DataType::Date32));
        assert!(
            matches!(d, Expr::Literal(ScalarValue::Date32(Some(_)), ..)),
            "Date vs Date32 column must coerce to a typed Date32, got {d:?}"
        );

        let nat = coerce(Literal::Date("2024-06-01".into()), None);
        assert!(
            matches!(nat, Expr::Literal(ScalarValue::Utf8(Some(_)), ..)),
            "no target should keep the natural Utf8 date literal, got {nat:?}"
        );
    }

    #[test]
    fn malformed_date_literal_falls_back_to_string() {
        let bad = coerce(
            Literal::DateTime("not-a-date".into()),
            Some(&DataType::Date64),
        );
        assert!(
            matches!(bad, Expr::Literal(ScalarValue::Utf8(Some(_)), ..)),
            "malformed DateTime literal should fall back to a Utf8 literal, got {bad:?}"
        );
    }

    #[test]
    fn integer_literal_coerces_to_narrow_column_type() {
        let i32_lit = coerce(Literal::Integer(5), Some(&DataType::Int32));
        assert!(
            matches!(i32_lit, Expr::Literal(ScalarValue::Int32(Some(5)), ..)),
            "integer literal vs Int32 column must lower to Int32, got {i32_lit:?}"
        );

        let u32_lit = coerce(Literal::Integer(7), Some(&DataType::UInt32));
        assert!(
            matches!(u32_lit, Expr::Literal(ScalarValue::UInt32(Some(7)), ..)),
            "integer literal vs UInt32 column must lower to UInt32, got {u32_lit:?}"
        );
    }

    #[test]
    fn float_literal_coerces_to_f32_column_type() {
        let f32_lit = coerce(Literal::Float(1.5), Some(&DataType::Float32));
        assert!(
            matches!(f32_lit, Expr::Literal(ScalarValue::Float32(Some(_)), ..)),
            "float literal vs Float32 column must lower to Float32, got {f32_lit:?}"
        );
    }

    #[test]
    fn fractional_float_vs_int_column_falls_back_not_truncate() {
        let e = coerce(Literal::Float(2.7), Some(&DataType::Int32));
        assert!(
            matches!(e, Expr::Literal(ScalarValue::Float64(Some(_)), ..)),
            "fractional float vs Int32 must fall back to natural Float64, got {e:?}"
        );
    }

    #[test]
    fn whole_float_vs_int_column_coerces() {
        let e = coerce(Literal::Float(2.0), Some(&DataType::Int32));
        assert!(
            matches!(e, Expr::Literal(ScalarValue::Int32(Some(2)), ..)),
            "whole-number float vs Int32 is lossless and must coerce to Int32(2), got {e:?}"
        );
    }

    #[test]
    fn out_of_range_int_vs_narrow_column_falls_back() {
        let e = coerce(Literal::Integer(3_000_000_000), Some(&DataType::Int32));
        assert!(
            matches!(e, Expr::Literal(ScalarValue::Int64(Some(3_000_000_000)), ..)),
            "out-of-range integer vs Int32 must fall back to natural Int64, got {e:?}"
        );
    }

    #[test]
    fn float_vs_f32_column_coerces_even_when_not_exactly_representable() {
        let e = coerce(Literal::Float(0.1), Some(&DataType::Float32));
        assert!(
            matches!(e, Expr::Literal(ScalarValue::Float32(Some(_)), ..)),
            "float target must coerce 0.1 to Float32 (exempt from lossless guard), got {e:?}"
        );
    }

    #[test]
    fn literal_without_target_keeps_natural_width() {
        let nat = coerce(Literal::Integer(5), None);
        assert!(
            matches!(nat, Expr::Literal(ScalarValue::Int64(Some(5)), ..)),
            "no target should keep the natural Int64 width, got {nat:?}"
        );
    }

    /// True when `e` is a binary expression with an `Int32` literal on
    /// either side.
    fn binary_has_int32_literal(e: &Expr) -> bool {
        let Expr::BinaryExpr(b) = e else {
            return false;
        };
        let is_int32 =
            |side: &Expr| matches!(side, Expr::Literal(ScalarValue::Int32(Some(_)), ..));
        is_int32(b.left.as_ref()) || is_int32(b.right.as_ref())
    }

    /// Schema with a single nullable `Int32` column named `count`.
    fn int32_schema() -> arrow_schema::Schema {
        use arrow_schema::Field;
        arrow_schema::Schema::new(vec![Field::new("count", DataType::Int32, true)])
    }

    /// Property access `m.count`.
    fn count_prop() -> IRExpr {
        IRExpr::PropAccess {
            variable: "m".into(),
            property: "count".into(),
        }
    }

    /// Lowers `left op right` against `int32_schema()` with no parameters.
    fn lower_filter(left: IRExpr, op: CompOp, right: IRExpr) -> Expr {
        let schema = int32_schema();
        let filter = IRFilter { left, op, right };
        ir_filter_to_expr(&filter, &ParamMap::new(), Some(&schema)).unwrap()
    }

    #[test]
    fn ir_filter_coerces_literal_for_range_op() {
        let expr = lower_filter(
            count_prop(),
            CompOp::Ge,
            IRExpr::Literal(Literal::Integer(2)),
        );
        assert!(
            binary_has_int32_literal(&expr),
            "range-op literal must coerce to the Int32 column type, got {expr:?}"
        );
    }

    #[test]
    fn ir_filter_coerces_literal_when_column_is_on_the_right() {
        let expr = lower_filter(
            IRExpr::Literal(Literal::Integer(2)),
            CompOp::Lt,
            count_prop(),
        );
        assert!(
            binary_has_int32_literal(&expr),
            "reversed-operand literal must coerce to the Int32 column type, got {expr:?}"
        );
    }
}