//! omnigraph/exec/query.rs — named-query compilation and execution.
1use super::*;
2
3use super::projection::{apply_filter, apply_ordering, project_return};
4
5impl Omnigraph {
6    /// Run a named query against an explicit branch or snapshot target.
7    pub async fn query(
8        &self,
9        target: impl Into<ReadTarget>,
10        query_source: &str,
11        query_name: &str,
12        params: &ParamMap,
13    ) -> Result<QueryResult> {
14        self.ensure_schema_state_valid().await?;
15        let resolved = self.resolved_target(target).await?;
16
17        let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
18            .map_err(|e| OmniError::manifest(e.to_string()))?;
19        let type_ctx = typecheck_query(self.catalog(), &query_decl)?;
20        let ir = lower_query(self.catalog(), &query_decl, &type_ctx)?;
21
22        let needs_graph = ir
23            .pipeline
24            .iter()
25            .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
26        let graph_index = if needs_graph {
27            Some(self.graph_index_for_resolved(&resolved).await?)
28        } else {
29            None
30        };
31
32        execute_query(
33            &ir,
34            params,
35            &resolved.snapshot,
36            graph_index.as_deref(),
37            self.catalog(),
38        )
39        .await
40    }
41
42    /// Run a named query against the graph as it existed at a prior manifest version.
43    ///
44    /// Compiles the query normally, builds a temporary (non-cached) graph index
45    /// if traversal is needed, and executes against the historical snapshot.
46    pub async fn run_query_at(
47        &self,
48        version: u64,
49        query_source: &str,
50        query_name: &str,
51        params: &ParamMap,
52    ) -> Result<QueryResult> {
53        self.ensure_schema_state_valid().await?;
54        let snapshot = self.snapshot_at_version(version).await?;
55
56        let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
57            .map_err(|e| OmniError::manifest(e.to_string()))?;
58        let type_ctx = typecheck_query(self.catalog(), &query_decl)?;
59        let ir = lower_query(self.catalog(), &query_decl, &type_ctx)?;
60
61        let needs_graph = ir
62            .pipeline
63            .iter()
64            .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
65        let graph_index = if needs_graph {
66            let edge_types = self
67                .catalog()
68                .edge_types
69                .iter()
70                .map(|(name, et)| (name.clone(), (et.from_type.clone(), et.to_type.clone())))
71                .collect();
72            Some(Arc::new(GraphIndex::build(&snapshot, &edge_types).await?))
73        } else {
74            None
75        };
76
77        execute_query(
78            &ir,
79            params,
80            &snapshot,
81            graph_index.as_deref(),
82            self.catalog(),
83        )
84        .await
85    }
86}
87
88// ─── Search mode ─────────────────────────────────────────────────────────────
89
/// Describes how the query's ordering changes the scan mode.
///
/// Populated by `extract_search_mode` from the first `order by` expression;
/// the all-`None` default (see `Default`) means a plain, non-search scan.
#[derive(Debug, Default)]
struct SearchMode {
    /// Vector ANN search: (variable, property, query_vector, k).
    nearest: Option<(String, String, Vec<f32>, usize)>,
    /// BM25 full-text search: (variable, property, query_text).
    bm25: Option<(String, String, String)>,
    /// RRF fusion: (primary, secondary, k_constant, limit).
    rrf: Option<RrfMode>,
}
100
/// Reciprocal Rank Fusion parameters extracted from an `rrf(...)` ordering.
#[derive(Debug)]
struct RrfMode {
    /// First ranked sub-search (a nearest- or bm25-mode `SearchMode`).
    primary: Box<SearchMode>,
    /// Second ranked sub-search (a nearest- or bm25-mode `SearchMode`).
    secondary: Box<SearchMode>,
    /// RRF smoothing constant; defaults to 60 when not given in the query.
    k: u32,
    /// Maximum number of fused results to keep.
    limit: usize,
}
108
/// Extract search ordering mode from the IR.
///
/// Only the first `order by` expression is inspected. `nearest()`, `bm25()`
/// and `rrf()` orderings switch the scan into the corresponding search mode;
/// any other ordering (or none at all) yields the plain default mode.
async fn extract_search_mode(
    ir: &QueryIR,
    params: &ParamMap,
    catalog: &Catalog,
) -> Result<SearchMode> {
    if ir.order_by.is_empty() {
        return Ok(SearchMode::default());
    }
    let ordering = &ir.order_by[0];
    match &ordering.expr {
        IRExpr::Nearest {
            variable,
            property,
            query,
        } => {
            // Resolve the query expression to a concrete f32 vector
            // (literal list, param, or text embedded via the embedding service).
            let vec =
                resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
            // ANN search needs an explicit k: take it from the limit clause.
            let k = ir.limit.ok_or_else(|| {
                OmniError::manifest("nearest() ordering requires a limit clause".to_string())
            })? as usize;
            Ok(SearchMode {
                nearest: Some((variable.clone(), property.clone(), vec, k)),
                ..Default::default()
            })
        }
        IRExpr::Bm25 { field, query } => {
            // The searched field must be a `$var.prop` access: pull the variable…
            let var = match field.as_ref() {
                IRExpr::PropAccess { variable, .. } => variable.clone(),
                _ => {
                    return Err(OmniError::manifest(
                        "bm25 field must be a property access".to_string(),
                    ));
                }
            };
            // …and the property name from the same expression.
            let prop = extract_property(field).ok_or_else(|| {
                OmniError::manifest("bm25 field must be a property access".to_string())
            })?;
            let text = resolve_to_string(query, params).ok_or_else(|| {
                OmniError::manifest("bm25 query must resolve to a string".to_string())
            })?;
            Ok(SearchMode {
                bm25: Some((var, prop, text)),
                ..Default::default()
            })
        }
        IRExpr::Rrf {
            primary,
            secondary,
            k,
        } => {
            // Fusion needs a limit for truncating the fused ranking.
            let limit = ir.limit.ok_or_else(|| {
                OmniError::manifest("rrf() ordering requires a limit clause".to_string())
            })? as usize;
            // Optional smoothing constant; 60 is the conventional RRF default.
            let k_val = k
                .as_ref()
                .and_then(|e| resolve_to_int(e, params))
                .unwrap_or(60) as u32;

            // Each side of the fusion is itself a nearest/bm25 sub-search,
            // capped by the same limit clause.
            let primary_mode =
                extract_sub_search_mode(ir, primary, params, catalog, ir.limit).await?;
            let secondary_mode =
                extract_sub_search_mode(ir, secondary, params, catalog, ir.limit).await?;

            Ok(SearchMode {
                rrf: Some(RrfMode {
                    primary: Box::new(primary_mode),
                    secondary: Box::new(secondary_mode),
                    k: k_val,
                    limit,
                }),
                ..Default::default()
            })
        }
        _ => Ok(SearchMode::default()),
    }
}
186
/// Extract a sub-search mode from a nested RRF expression (nearest or bm25).
///
/// Unlike the top-level `nearest()` handling, a missing limit is not an
/// error here: `k` falls back to 100 candidates. Any expression other than
/// nearest/bm25 yields the plain default mode.
async fn extract_sub_search_mode(
    ir: &QueryIR,
    expr: &IRExpr,
    params: &ParamMap,
    catalog: &Catalog,
    limit: Option<u64>,
) -> Result<SearchMode> {
    match expr {
        IRExpr::Nearest {
            variable,
            property,
            query,
        } => {
            let vec =
                resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
            // Candidate count for this leg of the fusion; default 100.
            let k = limit.unwrap_or(100) as usize;
            Ok(SearchMode {
                nearest: Some((variable.clone(), property.clone(), vec, k)),
                ..Default::default()
            })
        }
        IRExpr::Bm25 { field, query } => {
            // Same field/query extraction as the top-level bm25 arm.
            let var = match field.as_ref() {
                IRExpr::PropAccess { variable, .. } => variable.clone(),
                _ => {
                    return Err(OmniError::manifest(
                        "bm25 field must be a property access".to_string(),
                    ));
                }
            };
            let prop = extract_property(field).ok_or_else(|| {
                OmniError::manifest("bm25 field must be a property access".to_string())
            })?;
            let text = resolve_to_string(query, params).ok_or_else(|| {
                OmniError::manifest("bm25 query must resolve to a string".to_string())
            })?;
            Ok(SearchMode {
                bm25: Some((var, prop, text)),
                ..Default::default()
            })
        }
        _ => Ok(SearchMode::default()),
    }
}
232
/// Resolve an expression to a nearest() query vector.
///
/// A literal/param list is converted element-wise to `f32`; a string is sent
/// to the embedding service, pinned to the dimension declared in the catalog
/// for the target vector property. Anything else is rejected.
async fn resolve_nearest_query_vec(
    ir: &QueryIR,
    catalog: &Catalog,
    variable: &str,
    property: &str,
    expr: &IRExpr,
    params: &ParamMap,
) -> Result<Vec<f32>> {
    let lit = resolve_literal_or_param(expr, params)?;
    match lit {
        Literal::List(_) => literal_to_f32_vec(&lit),
        Literal::String(text) => {
            // Look up the declared vector dimension so the embedding matches.
            let expected_dim = nearest_property_dimension(ir, catalog, variable, property)?;
            // NOTE(review): embedding client is configured from env vars, so
            // text queries depend on process environment — confirm intended.
            EmbeddingClient::from_env()?
                .embed_query_text(&text, expected_dim)
                .await
        }
        _ => Err(OmniError::manifest(
            "nearest query must be a string or list of floats".to_string(),
        )),
    }
}
256
257fn resolve_literal_or_param(expr: &IRExpr, params: &ParamMap) -> Result<Literal> {
258    Ok(match expr {
259        IRExpr::Literal(lit) => lit.clone(),
260        IRExpr::Param(name) => params
261            .get(name)
262            .cloned()
263            .ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?,
264        _ => {
265            return Err(OmniError::manifest(
266                "nearest query must be a literal or parameter".to_string(),
267            ));
268        }
269    })
270}
271
272/// Resolve a literal vector expression to a Vec<f32>.
273fn literal_to_f32_vec(lit: &Literal) -> Result<Vec<f32>> {
274    match lit {
275        Literal::List(items) => items
276            .iter()
277            .map(|item| match item {
278                Literal::Float(f) => Ok(*f as f32),
279                Literal::Integer(n) => Ok(*n as f32),
280                _ => Err(OmniError::manifest(
281                    "vector elements must be numeric".to_string(),
282                )),
283            })
284            .collect(),
285        _ => Err(OmniError::manifest(
286            "nearest query must be a list of floats".to_string(),
287        )),
288    }
289}
290
/// Look up the declared dimension of the vector property targeted by a
/// `nearest()` ordering.
///
/// Walks the lowered pipeline to find the node type the variable is bound
/// to, then checks the catalog that the property is a non-list vector.
/// All failures are internal errors: the typechecker should have already
/// guaranteed these bindings exist.
fn nearest_property_dimension(
    ir: &QueryIR,
    catalog: &Catalog,
    variable: &str,
    property: &str,
) -> Result<usize> {
    let type_name = resolve_binding_type_name(&ir.pipeline, variable).ok_or_else(|| {
        OmniError::manifest_internal(format!(
            "nearest() variable '${}' is not bound to a node type in the lowered pipeline",
            variable
        ))
    })?;
    let node_type = catalog.node_types.get(type_name).ok_or_else(|| {
        OmniError::manifest_internal(format!(
            "nearest() binding '${}' resolved unknown node type '{}'",
            variable, type_name
        ))
    })?;
    let prop = node_type.properties.get(property).ok_or_else(|| {
        OmniError::manifest_internal(format!(
            "nearest() property '{}.{}' is missing from the catalog",
            type_name, property
        ))
    })?;
    // Only a scalar (non-list) vector property has a usable ANN dimension.
    match prop.scalar {
        ScalarType::Vector(dim) if !prop.list => Ok(dim as usize),
        _ => Err(OmniError::manifest_internal(format!(
            "nearest() property '{}.{}' is not a scalar vector",
            type_name, property
        ))),
    }
}
323
324fn resolve_binding_type_name<'a>(pipeline: &'a [IROp], variable: &str) -> Option<&'a str> {
325    for op in pipeline {
326        match op {
327            IROp::NodeScan {
328                variable: bound_var,
329                type_name,
330                ..
331            } if bound_var == variable => return Some(type_name.as_str()),
332            IROp::Expand {
333                dst_var, dst_type, ..
334            } if dst_var == variable => return Some(dst_type.as_str()),
335            IROp::AntiJoin { inner, .. } => {
336                if let Some(type_name) = resolve_binding_type_name(inner, variable) {
337                    return Some(type_name);
338                }
339            }
340            _ => {}
341        }
342    }
343    None
344}
345
/// Execute a lowered QueryIR. Pure function — no state, no caches.
///
/// Flow: derive the search mode from the ordering, run the pipeline into a
/// single "wide" batch of binding-prefixed columns, project the return
/// expressions, then apply ordering and limit. RRF orderings fork into a
/// dedicated fusion path before the normal pipeline runs.
pub async fn execute_query(
    ir: &QueryIR,
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
) -> Result<QueryResult> {
    let search_mode = extract_search_mode(ir, params, catalog).await?;

    // RRF requires forked execution
    if let Some(ref rrf) = search_mode.rrf {
        return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
    }

    // The pipeline accumulates into `wide`; a pipeline that produced nothing
    // degrades to an empty zero-column batch.
    let mut wide: Option<RecordBatch> = None;
    execute_pipeline(&ir.pipeline, params, snapshot, graph_index, catalog, &mut wide, &search_mode).await?;
    let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));

    // Project return expressions
    let has_aggregates = ir.return_exprs.iter().any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
    let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;

    // Apply ordering (skip if search mode already ordered the results).
    // With aggregates, order-by expressions refer to projected columns, so
    // sort against the result batch itself (the clone is shallow — Arrow
    // batches share column buffers); otherwise sort against the wide batch.
    if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
        result_batch = if has_aggregates {
            apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)?
        } else {
            apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
        };
    }

    // Apply limit by truncating the already-ordered batch.
    if let Some(limit) = ir.limit {
        let len = result_batch.num_rows().min(limit as usize);
        result_batch = result_batch.slice(0, len);
    }

    Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
386
387/// Check if the search mode already returns results in the correct order.
388fn is_search_ordered(search_mode: &SearchMode) -> bool {
389    search_mode.nearest.is_some() || search_mode.bm25.is_some()
390}
391
/// Execute a query with RRF (Reciprocal Rank Fusion) ordering.
///
/// Runs the pipeline twice — once per sub-search — then fuses the two
/// rankings: each id scores `sum(1 / (k + rank + 1))` over the lists it
/// appears in, results are sorted by score descending and truncated to the
/// rrf limit before projection.
async fn execute_rrf_query(
    ir: &QueryIR,
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
    rrf: &RrfMode,
) -> Result<QueryResult> {
    // Execute primary search
    let mut primary_wide: Option<RecordBatch> = None;
    execute_pipeline(
        &ir.pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut primary_wide,
        &rrf.primary,
    )
    .await?;

    // Execute secondary search
    let mut secondary_wide: Option<RecordBatch> = None;
    execute_pipeline(
        &ir.pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut secondary_wide,
        &rrf.secondary,
    )
    .await?;

    // For RRF, we need to find the main binding variable
    // (the one that both searches operate on)
    let primary_var = rrf
        .primary
        .nearest
        .as_ref()
        .map(|(v, ..)| v.as_str())
        .or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
        .ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;

    let primary_batch = primary_wide.as_ref().ok_or_else(|| {
        OmniError::manifest(format!(
            "rrf primary variable '{}' not in bindings",
            primary_var
        ))
    })?;
    // NOTE(review): this error message also reports `primary_var` even though
    // it fires when the *secondary* pipeline produced no batch — confirm
    // whether the secondary's own variable should be named here instead.
    let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
        OmniError::manifest(format!(
            "rrf secondary variable '{}' not in bindings",
            primary_var
        ))
    })?;

    // Build ID → rank maps (first occurrence of an id wins its rank)
    let id_col_name = format!("{}.id", primary_var);
    let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
    let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;

    let mut primary_rank: HashMap<String, usize> = HashMap::new();
    for (i, id) in primary_ids.iter().enumerate() {
        primary_rank.entry(id.clone()).or_insert(i);
    }
    let mut secondary_rank: HashMap<String, usize> = HashMap::new();
    for (i, id) in secondary_ids.iter().enumerate() {
        secondary_rank.entry(id.clone()).or_insert(i);
    }

    // Collect all unique IDs (primary order first, then secondary-only ids)
    let mut all_ids: Vec<String> = primary_ids.clone();
    for id in &secondary_ids {
        if !primary_rank.contains_key(id) {
            all_ids.push(id.clone());
        }
    }

    // Compute RRF scores: 1/(k + rank + 1) per list, summed; absent → 0.
    let k = rrf.k as f64;
    let mut scored: Vec<(String, f64)> = all_ids
        .iter()
        .map(|id| {
            let p = primary_rank
                .get(id)
                .map(|&r| 1.0 / (k + r as f64 + 1.0))
                .unwrap_or(0.0);
            let s = secondary_rank
                .get(id)
                .map(|&r| 1.0 / (k + r as f64 + 1.0))
                .unwrap_or(0.0);
            (id.clone(), p + s)
        })
        .collect();
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(rrf.limit);

    // Collect winning IDs in order — look up rows from primary or secondary batch
    let winning_ids: Vec<String> = scored.iter().map(|(id, _)| id.clone()).collect();

    // Build a combined row source: merge primary and secondary by id
    // (primary rows win when an id appears in both batches)
    let mut id_to_batch_row: HashMap<String, (&RecordBatch, usize)> = HashMap::new();
    for (i, id) in primary_ids.iter().enumerate() {
        id_to_batch_row
            .entry(id.clone())
            .or_insert((primary_batch, i));
    }
    for (i, id) in secondary_ids.iter().enumerate() {
        id_to_batch_row
            .entry(id.clone())
            .or_insert((secondary_batch, i));
    }

    // Reconstruct a combined batch for the binding in winning order
    let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;

    // Project directly from fused batch
    let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;

    // Already ordered by RRF score + already limited
    Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
516
517fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
518    let col = batch
519        .column_by_name(col_name)
520        .ok_or_else(|| OmniError::manifest(format!("batch missing '{}' column for RRF", col_name)))?;
521    let ids = col
522        .as_any()
523        .downcast_ref::<StringArray>()
524        .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?;
525    Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
526}
527
528fn build_fused_batch(
529    ordered_ids: &[String],
530    id_to_batch_row: &HashMap<String, (&RecordBatch, usize)>,
531    schema: SchemaRef,
532) -> Result<RecordBatch> {
533    if ordered_ids.is_empty() {
534        return Ok(RecordBatch::new_empty(schema));
535    }
536
537    // Gather indices from source batches, collecting rows in the right order
538    let mut row_slices: Vec<RecordBatch> = Vec::with_capacity(ordered_ids.len());
539    for id in ordered_ids {
540        if let Some(&(batch, row_idx)) = id_to_batch_row.get(id) {
541            row_slices.push(batch.slice(row_idx, 1));
542        }
543    }
544
545    if row_slices.is_empty() {
546        return Ok(RecordBatch::new_empty(schema));
547    }
548
549    let schema = row_slices[0].schema();
550    arrow_select::concat::concat_batches(&schema, &row_slices)
551        .map_err(|e| OmniError::Lance(e.to_string()))
552}
553
554/// Check if a filter is a text search filter that needs Lance SQL pushdown.
555fn is_search_filter(filter: &IRFilter) -> bool {
556    matches!(
557        &filter.left,
558        IRExpr::Search { .. } | IRExpr::Fuzzy { .. } | IRExpr::MatchText { .. }
559    )
560}
561
562/// Extract the variable name from a search filter's field expression.
563fn search_filter_variable(filter: &IRFilter) -> Option<&str> {
564    let field = match &filter.left {
565        IRExpr::Search { field, .. } => field,
566        IRExpr::Fuzzy { field, .. } => field,
567        IRExpr::MatchText { field, .. } => field,
568        _ => return None,
569    };
570    match field.as_ref() {
571        IRExpr::PropAccess { variable, .. } => Some(variable.as_str()),
572        _ => None,
573    }
574}
575
/// Drive the IR pipeline op-by-op, accumulating joined results into `wide`.
///
/// Search filters (search/fuzzy/match-text) are first hoisted out of their
/// Filter ops and merged into the NodeScan for their variable, so they can
/// be pushed down to the scan instead of evaluated in memory.
///
/// Returns a boxed future because an `async fn` cannot be directly
/// recursive; recursion happens indirectly through AntiJoin handling
/// (presumably `execute_anti_join` re-enters this pipeline driver —
/// confirm against its definition).
fn execute_pipeline<'a>(
    pipeline: &'a [IROp],
    params: &'a ParamMap,
    snapshot: &'a Snapshot,
    graph_index: Option<&'a GraphIndex>,
    catalog: &'a Catalog,
    wide: &'a mut Option<RecordBatch>,
    search_mode: &'a SearchMode,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
    Box::pin(async move {
        // Pre-pass: collect search filters that need to be hoisted to NodeScan
        let mut hoisted_search_filters: HashMap<String, Vec<IRFilter>> = HashMap::new();
        let mut hoisted_indices: HashSet<usize> = HashSet::new();
        for (i, op) in pipeline.iter().enumerate() {
            if let IROp::Filter(filter) = op {
                if is_search_filter(filter) {
                    if let Some(var) = search_filter_variable(filter) {
                        hoisted_search_filters
                            .entry(var.to_string())
                            .or_default()
                            .push(filter.clone());
                        hoisted_indices.insert(i);
                    }
                }
            }
        }

        for (i, op) in pipeline.iter().enumerate() {
            // Skip hoisted search filters
            if hoisted_indices.contains(&i) {
                continue;
            }
            match op {
                IROp::NodeScan {
                    variable,
                    type_name,
                    filters,
                } => {
                    // Merge inline filters with hoisted search filters
                    let mut all_filters: Vec<IRFilter> = filters.clone();
                    if let Some(extra) = hoisted_search_filters.get(variable) {
                        all_filters.extend(extra.iter().cloned());
                    }
                    let batch = execute_node_scan(
                        type_name,
                        variable,
                        &all_filters,
                        params,
                        snapshot,
                        catalog,
                        search_mode,
                    )
                    .await?;
                    // Columns are prefixed '<variable>.'; the first scan seeds
                    // the wide batch, later scans cross-join onto it.
                    let prefixed = prefix_batch(&batch, variable)?;
                    *wide = Some(match wide.take() {
                        None => prefixed,
                        Some(existing) => cross_join_batches(&existing, &prefixed)?,
                    });
                }
                IROp::Filter(filter) => {
                    // In-memory filter over whatever has been joined so far;
                    // a no-op when no scan has produced rows yet.
                    if let Some(batch) = wide.as_mut() {
                        apply_filter(batch, filter, params)?;
                    }
                }
                IROp::Expand {
                    src_var,
                    dst_var,
                    edge_type,
                    direction,
                    dst_type,
                    min_hops,
                    max_hops,
                    dst_filters,
                } => {
                    // Traversal is only possible when a graph index was built.
                    let gi = graph_index.ok_or_else(|| {
                        OmniError::manifest("graph index required for traversal".to_string())
                    })?;
                    if let Some(batch) = wide.as_mut() {
                        execute_expand(
                            batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
                            dst_type, *min_hops, *max_hops, dst_filters, params,
                        )
                        .await?;
                    }
                }
                IROp::AntiJoin { outer_var, inner } => {
                    let gi = graph_index;
                    if let Some(batch) = wide.as_mut() {
                        execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
                            .await?;
                    }
                }
            }
        }
        Ok(())
    })
}
673
/// Execute a graph traversal (Expand).
///
/// For every row of `wide`, runs a BFS from that row's `src_var` id along
/// `edge_type` in `direction`, collecting distinct destination ids reached
/// within `[min_hops, max_hops]` hops. Matching destinations are hydrated
/// from the snapshot (with SQL-pushable destination filters applied at scan
/// time), prefixed as `dst_var` columns, and joined back onto `wide` — one
/// output row per (source row, reached destination) pair. Remaining
/// non-pushable filters run in memory afterwards.
async fn execute_expand(
    wide: &mut RecordBatch,
    graph_index: &GraphIndex,
    snapshot: &Snapshot,
    catalog: &Catalog,
    src_var: &str,
    dst_var: &str,
    edge_type: &str,
    direction: Direction,
    dst_type: &str,
    min_hops: u32,
    max_hops: Option<u32>,
    dst_filters: &[IRFilter],
    params: &ParamMap,
) -> Result<()> {
    // Source ids come from the prefixed '<src_var>.id' column.
    let src_id_col_name = format!("{}.id", src_var);
    let src_ids = wide
        .column_by_name(&src_id_col_name)
        .ok_or_else(|| OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name)))?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
        .clone();

    // Determine which type index to use for source and destination
    let edge_def = catalog
        .edge_types
        .get(edge_type)
        .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_type)))?;

    // Traversing "in" swaps the roles of the edge's endpoint types.
    let (src_type_name, dst_type_name) = match direction {
        Direction::Out => (&edge_def.from_type, &edge_def.to_type),
        Direction::In => (&edge_def.to_type, &edge_def.from_type),
    };

    let src_type_idx = graph_index
        .type_index(src_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", src_type_name)))?;
    let dst_type_idx = graph_index
        .type_index(dst_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", dst_type_name)))?;

    // CSR serves forward (out) adjacency, CSC the reverse (in) adjacency.
    let adj = match direction {
        Direction::Out => graph_index.csr(edge_type),
        Direction::In => graph_index.csc(edge_type),
    }
    .ok_or_else(|| OmniError::manifest(format!("no adjacency index for edge '{}'", edge_type)))?;

    // Without an explicit max, traverse exactly min_hops (at least one hop).
    let max = max_hops.unwrap_or(min_hops.max(1));

    let same_type = src_type_name == dst_type_name;

    // BFS to collect (src_row_idx, dst_id) pairs with per-source dedup
    let mut src_indices: Vec<u32> = Vec::new();
    let mut dst_id_list: Vec<String> = Vec::new();
    for i in 0..src_ids.len() {
        let src_id = src_ids.value(i);
        // Sources absent from the index simply produce no expansion rows.
        let Some(src_dense) = src_type_idx.to_dense(src_id) else {
            continue;
        };

        // BFS with hop tracking
        let mut frontier: Vec<u32> = vec![src_dense];
        let mut visited: HashSet<u32> = HashSet::new();
        let mut seen_dst_ids: HashSet<String> = HashSet::new();
        // Only track visited in the destination namespace for same-type edges
        // (to avoid revisiting the source). For cross-type edges, dense indices
        // are in different namespaces so collision is impossible.
        if same_type {
            visited.insert(src_dense);
        }

        for hop in 1..=max {
            let mut next_frontier = Vec::new();
            for &node in &frontier {
                for &neighbor in adj.neighbors(node) {
                    if !same_type || visited.insert(neighbor) {
                        next_frontier.push(neighbor);
                        // Only emit destinations once the minimum hop depth
                        // is reached; seen_dst_ids dedups per source row.
                        if hop >= min_hops {
                            if let Some(dst_id) = dst_type_idx.to_id(neighbor) {
                                let dst_id = dst_id.to_string();
                                if seen_dst_ids.insert(dst_id.clone()) {
                                    src_indices.push(i as u32);
                                    dst_id_list.push(dst_id);
                                }
                            }
                        }
                    }
                }
            }
            frontier = next_frontier;
            if frontier.is_empty() {
                break;
            }
        }
    }

    // Split dst_filters: SQL-pushable go to Lance, the rest applied post-hconcat
    let pushdown_sql = build_lance_filter(dst_filters, params);
    let non_pushable: Vec<&IRFilter> = dst_filters
        .iter()
        .filter(|f| ir_filter_to_sql(f, params).is_none())
        .collect();

    // Hydrate destination nodes from the snapshot (with pushed-down filters)
    let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list, pushdown_sql.as_deref()).await?;

    // Build a mapping from dst_id to row index in dst_batch
    let dst_batch_id_col = dst_batch
        .column_by_name("id")
        .ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?;
    let dst_id_to_row: HashMap<&str, usize> = (0..dst_batch_id_col.len())
        .map(|i| (dst_batch_id_col.value(i), i))
        .collect();

    // Build aligned src/dst index arrays (only for IDs that exist in hydrated batch)
    // — ids filtered out by the pushed-down SQL drop their pairs here.
    let mut final_src_indices: Vec<u32> = Vec::new();
    let mut dst_indices: Vec<u32> = Vec::new();
    for (src_idx, dst_id) in src_indices.iter().zip(dst_id_list.iter()) {
        if let Some(&dst_row) = dst_id_to_row.get(dst_id.as_str()) {
            final_src_indices.push(*src_idx);
            dst_indices.push(dst_row as u32);
        }
    }

    // Duplicate source rows per destination (take), then glue the prefixed
    // destination columns alongside them.
    let src_take = UInt32Array::from(final_src_indices);
    let dst_take = UInt32Array::from(dst_indices);
    let expanded_wide = take_batch(wide, &src_take)?;
    let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
    let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
    *wide = hconcat_batches(&expanded_wide, &aligned_dst)?;

    // Apply any non-pushable destination filters (e.g. list-contains) in memory
    for f in &non_pushable {
        apply_filter(wide, f, params)?;
    }

    Ok(())
}
817
/// Load full node rows for a set of IDs from a snapshot.
///
/// When `extra_filter_sql` is provided (from deferred destination-binding
/// filters), it is ANDed with the `id IN (...)` clause so that Lance can
/// skip non-matching rows at the storage level.
///
/// Blob-typed properties are excluded from the scan projection and re-added
/// as null columns afterwards, so the returned batch always matches the
/// node type's full Arrow schema. An empty `ids` slice (or an empty scan)
/// short-circuits to an empty batch with that schema.
async fn hydrate_nodes(
    snapshot: &Snapshot,
    catalog: &Catalog,
    type_name: &str,
    ids: &[String],
    extra_filter_sql: Option<&str>,
) -> Result<RecordBatch> {
    let node_type = catalog
        .node_types
        .get(type_name)
        .ok_or_else(|| OmniError::manifest(format!("unknown node type '{}'", type_name)))?;

    if ids.is_empty() {
        return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
    }

    let table_key = format!("node:{}", type_name);
    let ds = snapshot.open(&table_key).await?;

    // Build filter: id IN ('a', 'b', 'c')
    // Single quotes inside ids are doubled (SQL string-escape convention).
    let escaped: Vec<String> = ids
        .iter()
        .map(|id| format!("'{}'", id.replace('\'', "''")))
        .collect();
    let mut filter_sql = format!("id IN ({})", escaped.join(", "));
    if let Some(extra) = extra_filter_sql {
        filter_sql = format!("({}) AND ({})", filter_sql, extra);
    }
    // Project away blob columns — they are not needed for query evaluation
    // and would be expensive to materialize here.
    let has_blobs = !node_type.blob_properties.is_empty();
    let non_blob_cols: Vec<&str> = node_type
        .arrow_schema
        .fields()
        .iter()
        .filter(|f| !node_type.blob_properties.contains(f.name()))
        .map(|f| f.name().as_str())
        .collect();
    let projection = has_blobs.then_some(non_blob_cols.as_slice());
    let batches = crate::table_store::TableStore::scan_stream(
        &ds,
        projection,
        Some(&filter_sql),
        None,
        false,
    )
    .await?
    .try_collect::<Vec<RecordBatch>>()
    .await
    .map_err(|e| OmniError::Lance(e.to_string()))?;

    // Collapse the stream's batches into one (zero → empty, one → as-is).
    let scan_result = if batches.is_empty() {
        return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
    } else if batches.len() == 1 {
        batches.into_iter().next().unwrap()
    } else {
        let schema = batches[0].schema();
        arrow_select::concat::concat_batches(&schema, &batches)
            .map_err(|e| OmniError::Lance(e.to_string()))?
    };

    // Restore the full schema by padding skipped blob columns with nulls.
    if has_blobs {
        return add_null_blob_columns(&scan_result, node_type);
    }
    Ok(scan_result)
}
887
888/// Try bulk anti-join via CSR existence check. Returns Some(mask) if the inner
889/// pipeline is a single Expand from outer_var (the common negation pattern).
890fn try_bulk_anti_join_mask(
891    wide: &RecordBatch,
892    inner_pipeline: &[IROp],
893    graph_index: Option<&GraphIndex>,
894    catalog: &Catalog,
895    outer_var: &str,
896) -> Option<BooleanArray> {
897    if inner_pipeline.len() != 1 {
898        return None;
899    }
900    let IROp::Expand {
901        src_var,
902        edge_type,
903        direction,
904        dst_filters,
905        ..
906    } = &inner_pipeline[0]
907    else {
908        return None;
909    };
910    if src_var != outer_var {
911        return None;
912    }
913    // Bulk CSR check only tests neighbor existence, not destination
914    // properties.  Fall back to the slow path when dst_filters are present.
915    if !dst_filters.is_empty() {
916        return None;
917    }
918    let gi = graph_index?;
919    let edge_def = catalog.edge_types.get(edge_type.as_str())?;
920
921    let src_type_name = match direction {
922        Direction::Out => &edge_def.from_type,
923        Direction::In => &edge_def.to_type,
924    };
925    let adj = match direction {
926        Direction::Out => gi.csr(edge_type),
927        Direction::In => gi.csc(edge_type),
928    }?;
929    let type_idx = gi.type_index(src_type_name)?;
930
931    let id_col_name = format!("{}.id", outer_var);
932    let outer_ids = wide
933        .column_by_name(&id_col_name)?
934        .as_any()
935        .downcast_ref::<StringArray>()?;
936
937    let keep_mask: Vec<bool> = (0..outer_ids.len())
938        .map(|i| {
939            let id = outer_ids.value(i);
940            match type_idx.to_dense(id) {
941                Some(dense) => !adj.has_neighbors(dense),
942                None => true, // not in graph index = no edges = keep
943            }
944        })
945        .collect();
946
947    Some(BooleanArray::from(keep_mask))
948}
949
950/// Execute an AntiJoin: remove rows from wide batch where the inner pipeline finds matches.
951async fn execute_anti_join(
952    wide: &mut RecordBatch,
953    inner_pipeline: &[IROp],
954    params: &ParamMap,
955    snapshot: &Snapshot,
956    graph_index: Option<&GraphIndex>,
957    catalog: &Catalog,
958    outer_var: &str,
959) -> Result<()> {
960    // Fast path: bulk CSR existence check (O(N), zero Lance I/O)
961    if let Some(mask) =
962        try_bulk_anti_join_mask(wide, inner_pipeline, graph_index, catalog, outer_var)
963    {
964        *wide = arrow_select::filter::filter_record_batch(wide, &mask)
965            .map_err(|e| OmniError::Lance(e.to_string()))?;
966        return Ok(());
967    }
968
969    // Slow path: per-row inner pipeline execution
970    let num_rows = wide.num_rows();
971    let mut keep_mask = vec![true; num_rows];
972
973    for i in 0..num_rows {
974        let single_row = wide.slice(i, 1);
975        let mut inner_wide: Option<RecordBatch> = Some(single_row);
976
977        let no_search = SearchMode::default();
978        execute_pipeline(
979            inner_pipeline,
980            params,
981            snapshot,
982            graph_index,
983            catalog,
984            &mut inner_wide,
985            &no_search,
986        )
987        .await?;
988
989        let has_match = inner_wide
990            .as_ref()
991            .map(|batch| batch.num_rows() > 0)
992            .unwrap_or(false);
993
994        if has_match {
995            keep_mask[i] = false;
996        }
997    }
998
999    let mask = BooleanArray::from(keep_mask);
1000    *wide = arrow_select::filter::filter_record_batch(wide, &mask)
1001        .map_err(|e| OmniError::Lance(e.to_string()))?;
1002    Ok(())
1003}
1004
1005/// Scan a node type's Lance dataset with optional filter pushdown and search modes.
1006async fn execute_node_scan(
1007    type_name: &str,
1008    variable: &str,
1009    filters: &[IRFilter],
1010    params: &ParamMap,
1011    snapshot: &Snapshot,
1012    catalog: &Catalog,
1013    search_mode: &SearchMode,
1014) -> Result<RecordBatch> {
1015    let table_key = format!("node:{}", type_name);
1016    let ds = snapshot.open(&table_key).await?;
1017
1018    // Build Lance SQL filter string from non-search IR filters
1019    let filter_sql = build_lance_filter(filters, params);
1020
1021    // Blob columns must be excluded from scan when a filter is present
1022    // (Lance bug: BlobsDescriptions + filter triggers a projection assertion).
1023    // We exclude blob columns and add metadata post-scan via take_blobs_by_indices.
1024    let node_type = &catalog.node_types[type_name];
1025    let has_blobs = !node_type.blob_properties.is_empty();
1026    let non_blob_cols: Vec<&str> = node_type
1027        .arrow_schema
1028        .fields()
1029        .iter()
1030        .filter(|f| !node_type.blob_properties.contains(f.name()))
1031        .map(|f| f.name().as_str())
1032        .collect();
1033    let projection = has_blobs.then_some(non_blob_cols.as_slice());
1034    let batches = crate::table_store::TableStore::scan_stream_with(
1035        &ds,
1036        projection,
1037        filter_sql.as_deref(),
1038        None,
1039        false,
1040        |scanner| {
1041            // Apply FTS queries from hoisted search filters (search/fuzzy/match_text in match clause)
1042            for filter in filters {
1043                if is_search_filter(filter) {
1044                    if let Some(fts_query) = build_fts_query(&filter.left, params) {
1045                        scanner.full_text_search(fts_query).map_err(|e| {
1046                            OmniError::Lance(format!("full_text_search filter: {}", e))
1047                        })?;
1048                    }
1049                }
1050            }
1051
1052            // Apply nearest vector search if this variable is the target
1053            if let Some((ref var, ref prop, ref vec, k)) = search_mode.nearest {
1054                if var == variable {
1055                    let query_arr = Float32Array::from(vec.clone());
1056                    scanner
1057                        .nearest(prop, &query_arr, k)
1058                        .map_err(|e| OmniError::Lance(format!("nearest: {}", e)))?;
1059                }
1060            }
1061
1062            // Apply BM25 full-text search if this variable is the target
1063            if let Some((ref var, ref prop, ref text)) = search_mode.bm25 {
1064                if var == variable {
1065                    let fts_query = lance_index::scalar::FullTextSearchQuery::new(text.clone())
1066                        .with_column(prop.clone())
1067                        .map_err(|e| OmniError::Lance(format!("fts with_column: {}", e)))?;
1068                    scanner
1069                        .full_text_search(fts_query)
1070                        .map_err(|e| OmniError::Lance(format!("full_text_search: {}", e)))?;
1071                }
1072            }
1073            Ok(())
1074        },
1075    )
1076    .await?
1077    .try_collect::<Vec<RecordBatch>>()
1078    .await
1079    .map_err(|e| OmniError::Lance(e.to_string()))?;
1080
1081    let scan_result = if batches.is_empty() {
1082        RecordBatch::new_empty(batches.first().map(|b| b.schema()).unwrap_or_else(|| {
1083            // Build a non-blob schema for empty result
1084            let fields: Vec<_> = node_type
1085                .arrow_schema
1086                .fields()
1087                .iter()
1088                .filter(|f| !node_type.blob_properties.contains(f.name()))
1089                .map(|f| f.as_ref().clone())
1090                .collect();
1091            Arc::new(Schema::new(fields))
1092        }))
1093    } else if batches.len() == 1 {
1094        batches.into_iter().next().unwrap()
1095    } else {
1096        let schema = batches[0].schema();
1097        arrow_select::concat::concat_batches(&schema, &batches)
1098            .map_err(|e| OmniError::Lance(e.to_string()))?
1099    };
1100
1101    // Add null placeholder columns for excluded blob properties
1102    if has_blobs {
1103        return add_null_blob_columns(&scan_result, node_type);
1104    }
1105    Ok(scan_result)
1106}
1107
1108/// Add null Utf8 columns for blob properties excluded from a scan.
1109/// Uses column_by_name (not positional) so it's order-independent.
1110fn add_null_blob_columns(
1111    batch: &RecordBatch,
1112    node_type: &omnigraph_compiler::catalog::NodeType,
1113) -> Result<RecordBatch> {
1114    let num_rows = batch.num_rows();
1115    let mut fields = Vec::with_capacity(node_type.arrow_schema.fields().len());
1116    let mut columns: Vec<ArrayRef> = Vec::with_capacity(node_type.arrow_schema.fields().len());
1117
1118    for field in node_type.arrow_schema.fields() {
1119        if node_type.blob_properties.contains(field.name()) {
1120            fields.push(Field::new(field.name(), DataType::Utf8, true));
1121            columns.push(Arc::new(StringArray::from(vec![None::<&str>; num_rows])));
1122        } else if let Some(col) = batch.column_by_name(field.name()) {
1123            let batch_schema = batch.schema();
1124            let batch_field = batch_schema
1125                .field_with_name(field.name())
1126                .map_err(|e| OmniError::Lance(e.to_string()))?;
1127            fields.push(batch_field.clone());
1128            columns.push(col.clone());
1129        }
1130    }
1131
1132    RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
1133        .map_err(|e| OmniError::Lance(e.to_string()))
1134}
1135
1136/// Convert IR filters to a Lance SQL filter string.
1137fn build_lance_filter(filters: &[IRFilter], params: &ParamMap) -> Option<String> {
1138    if filters.is_empty() {
1139        return None;
1140    }
1141
1142    let parts: Vec<String> = filters
1143        .iter()
1144        .filter_map(|f| ir_filter_to_sql(f, params))
1145        .collect();
1146
1147    if parts.is_empty() {
1148        return None;
1149    }
1150
1151    Some(parts.join(" AND "))
1152}
1153
1154fn ir_filter_to_sql(filter: &IRFilter, params: &ParamMap) -> Option<String> {
1155    // Search predicates (search/fuzzy/match_text = true) are NOT converted to SQL.
1156    // They are handled via scanner.full_text_search() in execute_node_scan.
1157    if is_search_filter(filter) {
1158        return None;
1159    }
1160
1161    let left = ir_expr_to_sql(&filter.left, params)?;
1162    let right = ir_expr_to_sql(&filter.right, params)?;
1163    let op = match filter.op {
1164        CompOp::Eq => "=",
1165        CompOp::Ne => "!=",
1166        CompOp::Gt => ">",
1167        CompOp::Lt => "<",
1168        CompOp::Ge => ">=",
1169        CompOp::Le => "<=",
1170        CompOp::Contains => return None, // Can't pushdown list contains
1171    };
1172    Some(format!("{} {} {}", left, op, right))
1173}
1174
1175/// Build a FullTextSearchQuery from a search IR expression.
1176fn build_fts_query(
1177    expr: &IRExpr,
1178    params: &ParamMap,
1179) -> Option<lance_index::scalar::FullTextSearchQuery> {
1180    match expr {
1181        IRExpr::Search { field, query } => {
1182            let prop = extract_property(field)?;
1183            let q = resolve_to_string(query, params)?;
1184            lance_index::scalar::FullTextSearchQuery::new(q)
1185                .with_column(prop)
1186                .ok()
1187        }
1188        IRExpr::Fuzzy {
1189            field,
1190            query,
1191            max_edits,
1192        } => {
1193            let prop = extract_property(field)?;
1194            let q = resolve_to_string(query, params)?;
1195            let edits = max_edits
1196                .as_ref()
1197                .and_then(|e| resolve_to_int(e, params))
1198                .unwrap_or(2) as u32;
1199            lance_index::scalar::FullTextSearchQuery::new_fuzzy(q, Some(edits))
1200                .with_column(prop)
1201                .ok()
1202        }
1203        IRExpr::MatchText { field, query } => {
1204            // Use regular text search (phrase search not available in Lance 3.0 Rust API)
1205            let prop = extract_property(field)?;
1206            let q = resolve_to_string(query, params)?;
1207            lance_index::scalar::FullTextSearchQuery::new(q)
1208                .with_column(prop)
1209                .ok()
1210        }
1211        _ => None,
1212    }
1213}
1214
1215/// Extract the property name from a PropAccess expression.
1216fn extract_property(expr: &IRExpr) -> Option<String> {
1217    match expr {
1218        IRExpr::PropAccess { property, .. } => Some(property.clone()),
1219        _ => None,
1220    }
1221}
1222
1223/// Resolve an expression to a string value (literal or param).
1224fn resolve_to_string(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1225    match expr {
1226        IRExpr::Literal(Literal::String(s)) => Some(s.clone()),
1227        IRExpr::Param(name) => match params.get(name)? {
1228            Literal::String(s) => Some(s.clone()),
1229            _ => None,
1230        },
1231        _ => None,
1232    }
1233}
1234
1235/// Resolve an expression to an integer value (literal or param).
1236fn resolve_to_int(expr: &IRExpr, params: &ParamMap) -> Option<i64> {
1237    match expr {
1238        IRExpr::Literal(Literal::Integer(n)) => Some(*n),
1239        IRExpr::Param(name) => match params.get(name)? {
1240            Literal::Integer(n) => Some(*n),
1241            _ => None,
1242        },
1243        _ => None,
1244    }
1245}
1246
1247fn ir_expr_to_sql(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1248    match expr {
1249        IRExpr::PropAccess { property, .. } => Some(property.clone()),
1250        IRExpr::Literal(lit) => Some(literal_to_sql(lit)),
1251        IRExpr::Param(name) => params.get(name).map(literal_to_sql),
1252        _ => None,
1253    }
1254}
1255
1256pub(super) fn literal_to_sql(lit: &Literal) -> String {
1257    match lit {
1258        Literal::Null => "NULL".to_string(),
1259        Literal::String(s) => format!("'{}'", s.replace('\'', "''")),
1260        Literal::Integer(n) => n.to_string(),
1261        Literal::Float(f) => f.to_string(),
1262        Literal::Bool(b) => b.to_string(),
1263        Literal::Date(s) => format!("'{}'", s.replace('\'', "''")),
1264        Literal::DateTime(s) => format!("'{}'", s.replace('\'', "''")),
1265        Literal::List(_) => "NULL".to_string(), // Not supported in SQL pushdown
1266    }
1267}
1268
1269fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
1270    let fields: Vec<Field> = batch.schema().fields().iter().map(|f| {
1271        Field::new(format!("{}.{}", variable, f.name()), f.data_type().clone(), f.is_nullable())
1272    }).collect();
1273    let schema = Arc::new(Schema::new(fields));
1274    RecordBatch::try_new(schema, batch.columns().to_vec()).map_err(|e| OmniError::Lance(e.to_string()))
1275}
1276
1277fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1278    let n = left.num_rows();
1279    let m = right.num_rows();
1280    if n == 0 || m == 0 {
1281        let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
1282        fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1283        return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
1284    }
1285    let left_indices: Vec<u32> = (0..n as u32).flat_map(|i| std::iter::repeat(i).take(m)).collect();
1286    let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
1287    let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
1288    let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
1289    hconcat_batches(&left_expanded, &right_expanded)
1290}
1291
1292fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1293    let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
1294    if cfg!(debug_assertions) {
1295        let left_schema = left.schema();
1296        let left_names: HashSet<&str> = left_schema.fields().iter().map(|f| f.name().as_str()).collect();
1297        let right_schema = right.schema();
1298        for f in right_schema.fields() {
1299            debug_assert!(!left_names.contains(f.name().as_str()), "hconcat_batches: duplicate column '{}'", f.name());
1300        }
1301    }
1302    fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1303    let mut columns: Vec<ArrayRef> = left.columns().to_vec();
1304    columns.extend(right.columns().to_vec());
1305    RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).map_err(|e| OmniError::Lance(e.to_string()))
1306}
1307
1308fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
1309    let columns: Vec<ArrayRef> = batch.columns().iter()
1310        .map(|col| arrow_select::take::take(col.as_ref(), indices, None))
1311        .collect::<std::result::Result<Vec<_>, _>>()
1312        .map_err(|e| OmniError::Lance(e.to_string()))?;
1313    RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
1314}