1use super::*;
2
3use super::projection::{apply_filter, apply_ordering, project_return};
4
5impl Omnigraph {
6 pub async fn query(
8 &self,
9 target: impl Into<ReadTarget>,
10 query_source: &str,
11 query_name: &str,
12 params: &ParamMap,
13 ) -> Result<QueryResult> {
14 self.ensure_schema_state_valid().await?;
15 let resolved = self.resolved_target(target).await?;
16 let catalog = self.catalog();
17
18 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
19 .map_err(|e| OmniError::manifest(e.to_string()))?;
20 let type_ctx = typecheck_query(&catalog, &query_decl)?;
21 let ir = lower_query(&catalog, &query_decl, &type_ctx)?;
22
23 let needs_graph = ir
24 .pipeline
25 .iter()
26 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
27 let graph_index = if needs_graph {
28 Some(self.graph_index_for_resolved(&resolved).await?)
29 } else {
30 None
31 };
32
33 execute_query(
34 &ir,
35 params,
36 &resolved.snapshot,
37 graph_index.as_deref(),
38 &catalog,
39 )
40 .await
41 }
42
43 pub async fn run_query_at(
48 &self,
49 version: u64,
50 query_source: &str,
51 query_name: &str,
52 params: &ParamMap,
53 ) -> Result<QueryResult> {
54 self.ensure_schema_state_valid().await?;
55 let snapshot = self.snapshot_at_version(version).await?;
56 let catalog = self.catalog();
57
58 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
59 .map_err(|e| OmniError::manifest(e.to_string()))?;
60 let type_ctx = typecheck_query(&catalog, &query_decl)?;
61 let ir = lower_query(&catalog, &query_decl, &type_ctx)?;
62
63 let needs_graph = ir
64 .pipeline
65 .iter()
66 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
67 let graph_index = if needs_graph {
68 let edge_types = catalog
69 .edge_types
70 .iter()
71 .map(|(name, et)| (name.clone(), (et.from_type.clone(), et.to_type.clone())))
72 .collect();
73 Some(Arc::new(GraphIndex::build(&snapshot, &edge_types).await?))
74 } else {
75 None
76 };
77
78 execute_query(&ir, params, &snapshot, graph_index.as_deref(), &catalog).await
79 }
80}
81
/// Search behaviour selected by a query's ordering expression.
///
/// All fields default to `None`, which means "no search ordering".
#[derive(Default, Debug)]
struct SearchMode {
    /// Vector search: `(variable, property, query vector, k)`.
    nearest: Option<(String, String, Vec<f32>, usize)>,
    /// BM25 full-text search: `(variable, property, query text)`.
    bm25: Option<(String, String, String)>,
    /// Reciprocal-rank fusion of two nested search modes.
    rrf: Option<RrfMode>,
}

/// Configuration for reciprocal-rank fusion of two searches.
#[derive(Debug)]
struct RrfMode {
    /// First ranking to fuse.
    primary: Box<SearchMode>,
    /// Second ranking to fuse.
    secondary: Box<SearchMode>,
    /// Constant added to each rank in the reciprocal score.
    k: u32,
    /// Maximum number of fused results kept.
    limit: usize,
}
102
103async fn extract_search_mode(
105 ir: &QueryIR,
106 params: &ParamMap,
107 catalog: &Catalog,
108) -> Result<SearchMode> {
109 if ir.order_by.is_empty() {
110 return Ok(SearchMode::default());
111 }
112 let ordering = &ir.order_by[0];
113 match &ordering.expr {
114 IRExpr::Nearest {
115 variable,
116 property,
117 query,
118 } => {
119 let vec =
120 resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
121 let k = ir.limit.ok_or_else(|| {
122 OmniError::manifest("nearest() ordering requires a limit clause".to_string())
123 })? as usize;
124 Ok(SearchMode {
125 nearest: Some((variable.clone(), property.clone(), vec, k)),
126 ..Default::default()
127 })
128 }
129 IRExpr::Bm25 { field, query } => {
130 let var = match field.as_ref() {
131 IRExpr::PropAccess { variable, .. } => variable.clone(),
132 _ => {
133 return Err(OmniError::manifest(
134 "bm25 field must be a property access".to_string(),
135 ));
136 }
137 };
138 let prop = extract_property(field).ok_or_else(|| {
139 OmniError::manifest("bm25 field must be a property access".to_string())
140 })?;
141 let text = resolve_to_string(query, params).ok_or_else(|| {
142 OmniError::manifest("bm25 query must resolve to a string".to_string())
143 })?;
144 Ok(SearchMode {
145 bm25: Some((var, prop, text)),
146 ..Default::default()
147 })
148 }
149 IRExpr::Rrf {
150 primary,
151 secondary,
152 k,
153 } => {
154 let limit = ir.limit.ok_or_else(|| {
155 OmniError::manifest("rrf() ordering requires a limit clause".to_string())
156 })? as usize;
157 let k_val = k
158 .as_ref()
159 .and_then(|e| resolve_to_int(e, params))
160 .unwrap_or(60) as u32;
161
162 let primary_mode =
163 extract_sub_search_mode(ir, primary, params, catalog, ir.limit).await?;
164 let secondary_mode =
165 extract_sub_search_mode(ir, secondary, params, catalog, ir.limit).await?;
166
167 Ok(SearchMode {
168 rrf: Some(RrfMode {
169 primary: Box::new(primary_mode),
170 secondary: Box::new(secondary_mode),
171 k: k_val,
172 limit,
173 }),
174 ..Default::default()
175 })
176 }
177 _ => Ok(SearchMode::default()),
178 }
179}
180
181async fn extract_sub_search_mode(
183 ir: &QueryIR,
184 expr: &IRExpr,
185 params: &ParamMap,
186 catalog: &Catalog,
187 limit: Option<u64>,
188) -> Result<SearchMode> {
189 match expr {
190 IRExpr::Nearest {
191 variable,
192 property,
193 query,
194 } => {
195 let vec =
196 resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
197 let k = limit.unwrap_or(100) as usize;
198 Ok(SearchMode {
199 nearest: Some((variable.clone(), property.clone(), vec, k)),
200 ..Default::default()
201 })
202 }
203 IRExpr::Bm25 { field, query } => {
204 let var = match field.as_ref() {
205 IRExpr::PropAccess { variable, .. } => variable.clone(),
206 _ => {
207 return Err(OmniError::manifest(
208 "bm25 field must be a property access".to_string(),
209 ));
210 }
211 };
212 let prop = extract_property(field).ok_or_else(|| {
213 OmniError::manifest("bm25 field must be a property access".to_string())
214 })?;
215 let text = resolve_to_string(query, params).ok_or_else(|| {
216 OmniError::manifest("bm25 query must resolve to a string".to_string())
217 })?;
218 Ok(SearchMode {
219 bm25: Some((var, prop, text)),
220 ..Default::default()
221 })
222 }
223 _ => Ok(SearchMode::default()),
224 }
225}
226
227async fn resolve_nearest_query_vec(
229 ir: &QueryIR,
230 catalog: &Catalog,
231 variable: &str,
232 property: &str,
233 expr: &IRExpr,
234 params: &ParamMap,
235) -> Result<Vec<f32>> {
236 let lit = resolve_literal_or_param(expr, params)?;
237 match lit {
238 Literal::List(_) => literal_to_f32_vec(&lit),
239 Literal::String(text) => {
240 let expected_dim = nearest_property_dimension(ir, catalog, variable, property)?;
241 EmbeddingClient::from_env()?
242 .embed_query_text(&text, expected_dim)
243 .await
244 }
245 _ => Err(OmniError::manifest(
246 "nearest query must be a string or list of floats".to_string(),
247 )),
248 }
249}
250
251fn resolve_literal_or_param(expr: &IRExpr, params: &ParamMap) -> Result<Literal> {
252 Ok(match expr {
253 IRExpr::Literal(lit) => lit.clone(),
254 IRExpr::Param(name) => params
255 .get(name)
256 .cloned()
257 .ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?,
258 _ => {
259 return Err(OmniError::manifest(
260 "nearest query must be a literal or parameter".to_string(),
261 ));
262 }
263 })
264}
265
266fn literal_to_f32_vec(lit: &Literal) -> Result<Vec<f32>> {
268 match lit {
269 Literal::List(items) => items
270 .iter()
271 .map(|item| match item {
272 Literal::Float(f) => Ok(*f as f32),
273 Literal::Integer(n) => Ok(*n as f32),
274 _ => Err(OmniError::manifest(
275 "vector elements must be numeric".to_string(),
276 )),
277 })
278 .collect(),
279 _ => Err(OmniError::manifest(
280 "nearest query must be a list of floats".to_string(),
281 )),
282 }
283}
284
285fn nearest_property_dimension(
286 ir: &QueryIR,
287 catalog: &Catalog,
288 variable: &str,
289 property: &str,
290) -> Result<usize> {
291 let type_name = resolve_binding_type_name(&ir.pipeline, variable).ok_or_else(|| {
292 OmniError::manifest_internal(format!(
293 "nearest() variable '${}' is not bound to a node type in the lowered pipeline",
294 variable
295 ))
296 })?;
297 let node_type = catalog.node_types.get(type_name).ok_or_else(|| {
298 OmniError::manifest_internal(format!(
299 "nearest() binding '${}' resolved unknown node type '{}'",
300 variable, type_name
301 ))
302 })?;
303 let prop = node_type.properties.get(property).ok_or_else(|| {
304 OmniError::manifest_internal(format!(
305 "nearest() property '{}.{}' is missing from the catalog",
306 type_name, property
307 ))
308 })?;
309 match prop.scalar {
310 ScalarType::Vector(dim) if !prop.list => Ok(dim as usize),
311 _ => Err(OmniError::manifest_internal(format!(
312 "nearest() property '{}.{}' is not a scalar vector",
313 type_name, property
314 ))),
315 }
316}
317
318fn resolve_binding_type_name<'a>(pipeline: &'a [IROp], variable: &str) -> Option<&'a str> {
319 for op in pipeline {
320 match op {
321 IROp::NodeScan {
322 variable: bound_var,
323 type_name,
324 ..
325 } if bound_var == variable => return Some(type_name.as_str()),
326 IROp::Expand {
327 dst_var, dst_type, ..
328 } if dst_var == variable => return Some(dst_type.as_str()),
329 IROp::AntiJoin { inner, .. } => {
330 if let Some(type_name) = resolve_binding_type_name(inner, variable) {
331 return Some(type_name);
332 }
333 }
334 _ => {}
335 }
336 }
337 None
338}
339
340pub async fn execute_query(
342 ir: &QueryIR,
343 params: &ParamMap,
344 snapshot: &Snapshot,
345 graph_index: Option<&GraphIndex>,
346 catalog: &Catalog,
347) -> Result<QueryResult> {
348 let search_mode = extract_search_mode(ir, params, catalog).await?;
349
350 if let Some(ref rrf) = search_mode.rrf {
352 return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
353 }
354
355 let mut wide: Option<RecordBatch> = None;
356 execute_pipeline(
357 &ir.pipeline,
358 params,
359 snapshot,
360 graph_index,
361 catalog,
362 &mut wide,
363 &search_mode,
364 )
365 .await?;
366 let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));
367
368 let has_aggregates = ir
370 .return_exprs
371 .iter()
372 .any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
373 let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;
374
375 if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
377 result_batch = if has_aggregates {
378 apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)?
379 } else {
380 apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
381 };
382 }
383
384 if let Some(limit) = ir.limit {
386 let len = result_batch.num_rows().min(limit as usize);
387 result_batch = result_batch.slice(0, len);
388 }
389
390 Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
391}
392
393fn is_search_ordered(search_mode: &SearchMode) -> bool {
395 search_mode.nearest.is_some() || search_mode.bm25.is_some()
396}
397
398async fn execute_rrf_query(
400 ir: &QueryIR,
401 params: &ParamMap,
402 snapshot: &Snapshot,
403 graph_index: Option<&GraphIndex>,
404 catalog: &Catalog,
405 rrf: &RrfMode,
406) -> Result<QueryResult> {
407 let mut primary_wide: Option<RecordBatch> = None;
409 execute_pipeline(
410 &ir.pipeline,
411 params,
412 snapshot,
413 graph_index,
414 catalog,
415 &mut primary_wide,
416 &rrf.primary,
417 )
418 .await?;
419
420 let mut secondary_wide: Option<RecordBatch> = None;
422 execute_pipeline(
423 &ir.pipeline,
424 params,
425 snapshot,
426 graph_index,
427 catalog,
428 &mut secondary_wide,
429 &rrf.secondary,
430 )
431 .await?;
432
433 let primary_var = rrf
436 .primary
437 .nearest
438 .as_ref()
439 .map(|(v, ..)| v.as_str())
440 .or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
441 .ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;
442
443 let primary_batch = primary_wide.as_ref().ok_or_else(|| {
444 OmniError::manifest(format!(
445 "rrf primary variable '{}' not in bindings",
446 primary_var
447 ))
448 })?;
449 let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
450 OmniError::manifest(format!(
451 "rrf secondary variable '{}' not in bindings",
452 primary_var
453 ))
454 })?;
455
456 let id_col_name = format!("{}.id", primary_var);
458 let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
459 let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;
460
461 let mut primary_rank: HashMap<String, usize> = HashMap::new();
462 for (i, id) in primary_ids.iter().enumerate() {
463 primary_rank.entry(id.clone()).or_insert(i);
464 }
465 let mut secondary_rank: HashMap<String, usize> = HashMap::new();
466 for (i, id) in secondary_ids.iter().enumerate() {
467 secondary_rank.entry(id.clone()).or_insert(i);
468 }
469
470 let mut all_ids: Vec<String> = primary_ids.clone();
472 for id in &secondary_ids {
473 if !primary_rank.contains_key(id) {
474 all_ids.push(id.clone());
475 }
476 }
477
478 let k = rrf.k as f64;
480 let mut scored: Vec<(String, f64)> = all_ids
481 .iter()
482 .map(|id| {
483 let p = primary_rank
484 .get(id)
485 .map(|&r| 1.0 / (k + r as f64 + 1.0))
486 .unwrap_or(0.0);
487 let s = secondary_rank
488 .get(id)
489 .map(|&r| 1.0 / (k + r as f64 + 1.0))
490 .unwrap_or(0.0);
491 (id.clone(), p + s)
492 })
493 .collect();
494 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
495 scored.truncate(rrf.limit);
496
497 let winning_ids: Vec<String> = scored.iter().map(|(id, _)| id.clone()).collect();
499
500 let mut id_to_batch_row: HashMap<String, (&RecordBatch, usize)> = HashMap::new();
502 for (i, id) in primary_ids.iter().enumerate() {
503 id_to_batch_row
504 .entry(id.clone())
505 .or_insert((primary_batch, i));
506 }
507 for (i, id) in secondary_ids.iter().enumerate() {
508 id_to_batch_row
509 .entry(id.clone())
510 .or_insert((secondary_batch, i));
511 }
512
513 let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;
515
516 let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;
518
519 Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
521}
522
523fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
524 let col = batch.column_by_name(col_name).ok_or_else(|| {
525 OmniError::manifest(format!("batch missing '{}' column for RRF", col_name))
526 })?;
527 let ids = col
528 .as_any()
529 .downcast_ref::<StringArray>()
530 .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?;
531 Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
532}
533
534fn build_fused_batch(
535 ordered_ids: &[String],
536 id_to_batch_row: &HashMap<String, (&RecordBatch, usize)>,
537 schema: SchemaRef,
538) -> Result<RecordBatch> {
539 if ordered_ids.is_empty() {
540 return Ok(RecordBatch::new_empty(schema));
541 }
542
543 let mut row_slices: Vec<RecordBatch> = Vec::with_capacity(ordered_ids.len());
545 for id in ordered_ids {
546 if let Some(&(batch, row_idx)) = id_to_batch_row.get(id) {
547 row_slices.push(batch.slice(row_idx, 1));
548 }
549 }
550
551 if row_slices.is_empty() {
552 return Ok(RecordBatch::new_empty(schema));
553 }
554
555 let schema = row_slices[0].schema();
556 arrow_select::concat::concat_batches(&schema, &row_slices)
557 .map_err(|e| OmniError::Lance(e.to_string()))
558}
559
560fn is_search_filter(filter: &IRFilter) -> bool {
562 matches!(
563 &filter.left,
564 IRExpr::Search { .. } | IRExpr::Fuzzy { .. } | IRExpr::MatchText { .. }
565 )
566}
567
568fn search_filter_variable(filter: &IRFilter) -> Option<&str> {
570 let field = match &filter.left {
571 IRExpr::Search { field, .. } => field,
572 IRExpr::Fuzzy { field, .. } => field,
573 IRExpr::MatchText { field, .. } => field,
574 _ => return None,
575 };
576 match field.as_ref() {
577 IRExpr::PropAccess { variable, .. } => Some(variable.as_str()),
578 _ => None,
579 }
580}
581
582fn execute_pipeline<'a>(
583 pipeline: &'a [IROp],
584 params: &'a ParamMap,
585 snapshot: &'a Snapshot,
586 graph_index: Option<&'a GraphIndex>,
587 catalog: &'a Catalog,
588 wide: &'a mut Option<RecordBatch>,
589 search_mode: &'a SearchMode,
590) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
591 Box::pin(async move {
592 let mut hoisted_search_filters: HashMap<String, Vec<IRFilter>> = HashMap::new();
594 let mut hoisted_indices: HashSet<usize> = HashSet::new();
595 for (i, op) in pipeline.iter().enumerate() {
596 if let IROp::Filter(filter) = op {
597 if is_search_filter(filter) {
598 if let Some(var) = search_filter_variable(filter) {
599 hoisted_search_filters
600 .entry(var.to_string())
601 .or_default()
602 .push(filter.clone());
603 hoisted_indices.insert(i);
604 }
605 }
606 }
607 }
608
609 for (i, op) in pipeline.iter().enumerate() {
610 if hoisted_indices.contains(&i) {
612 continue;
613 }
614 match op {
615 IROp::NodeScan {
616 variable,
617 type_name,
618 filters,
619 } => {
620 let mut all_filters: Vec<IRFilter> = filters.clone();
622 if let Some(extra) = hoisted_search_filters.get(variable) {
623 all_filters.extend(extra.iter().cloned());
624 }
625 let batch = execute_node_scan(
626 type_name,
627 variable,
628 &all_filters,
629 params,
630 snapshot,
631 catalog,
632 search_mode,
633 )
634 .await?;
635 let prefixed = prefix_batch(&batch, variable)?;
636 *wide = Some(match wide.take() {
637 None => prefixed,
638 Some(existing) => cross_join_batches(&existing, &prefixed)?,
639 });
640 }
641 IROp::Filter(filter) => {
642 if let Some(batch) = wide.as_mut() {
643 apply_filter(batch, filter, params)?;
644 }
645 }
646 IROp::Expand {
647 src_var,
648 dst_var,
649 edge_type,
650 direction,
651 dst_type,
652 min_hops,
653 max_hops,
654 dst_filters,
655 } => {
656 let gi = graph_index.ok_or_else(|| {
657 OmniError::manifest("graph index required for traversal".to_string())
658 })?;
659 if let Some(batch) = wide.as_mut() {
660 execute_expand(
661 batch,
662 gi,
663 snapshot,
664 catalog,
665 src_var,
666 dst_var,
667 edge_type,
668 *direction,
669 dst_type,
670 *min_hops,
671 *max_hops,
672 dst_filters,
673 params,
674 )
675 .await?;
676 }
677 }
678 IROp::AntiJoin { outer_var, inner } => {
679 let gi = graph_index;
680 if let Some(batch) = wide.as_mut() {
681 execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
682 .await?;
683 }
684 }
685 }
686 }
687 Ok(())
688 })
689}
690
691async fn execute_expand(
693 wide: &mut RecordBatch,
694 graph_index: &GraphIndex,
695 snapshot: &Snapshot,
696 catalog: &Catalog,
697 src_var: &str,
698 dst_var: &str,
699 edge_type: &str,
700 direction: Direction,
701 dst_type: &str,
702 min_hops: u32,
703 max_hops: Option<u32>,
704 dst_filters: &[IRFilter],
705 params: &ParamMap,
706) -> Result<()> {
707 let src_id_col_name = format!("{}.id", src_var);
708 let src_ids = wide
709 .column_by_name(&src_id_col_name)
710 .ok_or_else(|| {
711 OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name))
712 })?
713 .as_any()
714 .downcast_ref::<StringArray>()
715 .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
716 .clone();
717
718 let edge_def = catalog
720 .edge_types
721 .get(edge_type)
722 .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_type)))?;
723
724 let (src_type_name, dst_type_name) = match direction {
725 Direction::Out => (&edge_def.from_type, &edge_def.to_type),
726 Direction::In => (&edge_def.to_type, &edge_def.from_type),
727 };
728
729 let src_type_idx = graph_index
730 .type_index(src_type_name)
731 .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", src_type_name)))?;
732 let dst_type_idx = graph_index
733 .type_index(dst_type_name)
734 .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", dst_type_name)))?;
735
736 let adj = match direction {
737 Direction::Out => graph_index.csr(edge_type),
738 Direction::In => graph_index.csc(edge_type),
739 }
740 .ok_or_else(|| OmniError::manifest(format!("no adjacency index for edge '{}'", edge_type)))?;
741
742 let max = max_hops.unwrap_or(min_hops.max(1));
743
744 let same_type = src_type_name == dst_type_name;
745
746 let mut src_indices: Vec<u32> = Vec::new();
750 let mut dst_dense_list: Vec<u32> = Vec::new();
751 for i in 0..src_ids.len() {
752 let src_id = src_ids.value(i);
753 let Some(src_dense) = src_type_idx.to_dense(src_id) else {
754 continue;
755 };
756
757 let mut frontier: Vec<u32> = vec![src_dense];
759 let mut visited: HashSet<u32> = HashSet::new();
760 let mut seen_dst_dense: HashSet<u32> = HashSet::new();
761 if same_type {
765 visited.insert(src_dense);
766 }
767
768 for hop in 1..=max {
769 let mut next_frontier = Vec::new();
770 for &node in &frontier {
771 for &neighbor in adj.neighbors(node) {
772 if !same_type || visited.insert(neighbor) {
773 next_frontier.push(neighbor);
774 if hop >= min_hops && seen_dst_dense.insert(neighbor) {
775 src_indices.push(i as u32);
776 dst_dense_list.push(neighbor);
777 }
778 }
779 }
780 }
781 frontier = next_frontier;
782 if frontier.is_empty() {
783 break;
784 }
785 }
786 }
787
788 let pushdown_sql = build_lance_filter(dst_filters, params);
790 let non_pushable: Vec<&IRFilter> = dst_filters
791 .iter()
792 .filter(|f| ir_filter_to_sql(f, params).is_none())
793 .collect();
794
795 let mut unique_dst_list: Vec<String> = Vec::new();
799 {
800 let mut seen: HashSet<u32> = HashSet::with_capacity(dst_dense_list.len());
801 for &d in &dst_dense_list {
802 if seen.insert(d) {
803 if let Some(id) = dst_type_idx.to_id(d) {
804 unique_dst_list.push(id.to_string());
805 }
806 }
807 }
808 }
809 let dst_batch = hydrate_nodes(
810 snapshot,
811 catalog,
812 dst_type,
813 &unique_dst_list,
814 pushdown_sql.as_deref(),
815 )
816 .await?;
817
818 let dst_batch_id_col = dst_batch
820 .column_by_name("id")
821 .ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))?
822 .as_any()
823 .downcast_ref::<StringArray>()
824 .ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?;
825 let mut dense_to_row: Vec<Option<u32>> = vec![None; dst_type_idx.len()];
826 for row in 0..dst_batch_id_col.len() {
827 let id_str = dst_batch_id_col.value(row);
828 if let Some(dense) = dst_type_idx.to_dense(id_str) {
829 dense_to_row[dense as usize] = Some(row as u32);
830 }
831 }
832
833 let mut final_src_indices: Vec<u32> = Vec::new();
835 let mut dst_indices: Vec<u32> = Vec::new();
836 for (src_idx, dst_dense) in src_indices.iter().zip(dst_dense_list.iter()) {
837 if let Some(dst_row) = dense_to_row[*dst_dense as usize] {
838 final_src_indices.push(*src_idx);
839 dst_indices.push(dst_row);
840 }
841 }
842
843 let src_take = UInt32Array::from(final_src_indices);
844 let dst_take = UInt32Array::from(dst_indices);
845 let expanded_wide = take_batch(wide, &src_take)?;
846 let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
847 let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
848 *wide = hconcat_batches(&expanded_wide, &aligned_dst)?;
849
850 for f in &non_pushable {
852 apply_filter(wide, f, params)?;
853 }
854
855 Ok(())
856}
857
858async fn hydrate_nodes(
864 snapshot: &Snapshot,
865 catalog: &Catalog,
866 type_name: &str,
867 ids: &[String],
868 extra_filter_sql: Option<&str>,
869) -> Result<RecordBatch> {
870 let node_type = catalog
871 .node_types
872 .get(type_name)
873 .ok_or_else(|| OmniError::manifest(format!("unknown node type '{}'", type_name)))?;
874
875 if ids.is_empty() {
876 return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
877 }
878
879 let table_key = format!("node:{}", type_name);
880 let ds = snapshot.open(&table_key).await?;
881
882 let escaped: Vec<String> = ids
884 .iter()
885 .map(|id| format!("'{}'", id.replace('\'', "''")))
886 .collect();
887 let mut filter_sql = format!("id IN ({})", escaped.join(", "));
888 if let Some(extra) = extra_filter_sql {
889 filter_sql = format!("({}) AND ({})", filter_sql, extra);
890 }
891 let has_blobs = !node_type.blob_properties.is_empty();
892 let non_blob_cols: Vec<&str> = node_type
893 .arrow_schema
894 .fields()
895 .iter()
896 .filter(|f| !node_type.blob_properties.contains(f.name()))
897 .map(|f| f.name().as_str())
898 .collect();
899 let projection = has_blobs.then_some(non_blob_cols.as_slice());
900 let batches = crate::table_store::TableStore::scan_stream(
901 &ds,
902 projection,
903 Some(&filter_sql),
904 None,
905 false,
906 )
907 .await?
908 .try_collect::<Vec<RecordBatch>>()
909 .await
910 .map_err(|e| OmniError::Lance(e.to_string()))?;
911
912 let scan_result = if batches.is_empty() {
913 return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
914 } else if batches.len() == 1 {
915 batches.into_iter().next().unwrap()
916 } else {
917 let schema = batches[0].schema();
918 arrow_select::concat::concat_batches(&schema, &batches)
919 .map_err(|e| OmniError::Lance(e.to_string()))?
920 };
921
922 if has_blobs {
923 return add_null_blob_columns(&scan_result, node_type);
924 }
925 Ok(scan_result)
926}
927
928fn try_bulk_anti_join_mask(
931 wide: &RecordBatch,
932 inner_pipeline: &[IROp],
933 graph_index: Option<&GraphIndex>,
934 catalog: &Catalog,
935 outer_var: &str,
936) -> Option<BooleanArray> {
937 if inner_pipeline.len() != 1 {
938 return None;
939 }
940 let IROp::Expand {
941 src_var,
942 edge_type,
943 direction,
944 dst_filters,
945 ..
946 } = &inner_pipeline[0]
947 else {
948 return None;
949 };
950 if src_var != outer_var {
951 return None;
952 }
953 if !dst_filters.is_empty() {
956 return None;
957 }
958 let gi = graph_index?;
959 let edge_def = catalog.edge_types.get(edge_type.as_str())?;
960
961 let src_type_name = match direction {
962 Direction::Out => &edge_def.from_type,
963 Direction::In => &edge_def.to_type,
964 };
965 let adj = match direction {
966 Direction::Out => gi.csr(edge_type),
967 Direction::In => gi.csc(edge_type),
968 }?;
969 let type_idx = gi.type_index(src_type_name)?;
970
971 let id_col_name = format!("{}.id", outer_var);
972 let outer_ids = wide
973 .column_by_name(&id_col_name)?
974 .as_any()
975 .downcast_ref::<StringArray>()?;
976
977 let keep_mask: Vec<bool> = (0..outer_ids.len())
978 .map(|i| {
979 let id = outer_ids.value(i);
980 match type_idx.to_dense(id) {
981 Some(dense) => !adj.has_neighbors(dense),
982 None => true, }
984 })
985 .collect();
986
987 Some(BooleanArray::from(keep_mask))
988}
989
990async fn execute_anti_join(
992 wide: &mut RecordBatch,
993 inner_pipeline: &[IROp],
994 params: &ParamMap,
995 snapshot: &Snapshot,
996 graph_index: Option<&GraphIndex>,
997 catalog: &Catalog,
998 outer_var: &str,
999) -> Result<()> {
1000 if let Some(mask) =
1002 try_bulk_anti_join_mask(wide, inner_pipeline, graph_index, catalog, outer_var)
1003 {
1004 *wide = arrow_select::filter::filter_record_batch(wide, &mask)
1005 .map_err(|e| OmniError::Lance(e.to_string()))?;
1006 return Ok(());
1007 }
1008
1009 let num_rows = wide.num_rows();
1011 let mut keep_mask = vec![true; num_rows];
1012
1013 for i in 0..num_rows {
1014 let single_row = wide.slice(i, 1);
1015 let mut inner_wide: Option<RecordBatch> = Some(single_row);
1016
1017 let no_search = SearchMode::default();
1018 execute_pipeline(
1019 inner_pipeline,
1020 params,
1021 snapshot,
1022 graph_index,
1023 catalog,
1024 &mut inner_wide,
1025 &no_search,
1026 )
1027 .await?;
1028
1029 let has_match = inner_wide
1030 .as_ref()
1031 .map(|batch| batch.num_rows() > 0)
1032 .unwrap_or(false);
1033
1034 if has_match {
1035 keep_mask[i] = false;
1036 }
1037 }
1038
1039 let mask = BooleanArray::from(keep_mask);
1040 *wide = arrow_select::filter::filter_record_batch(wide, &mask)
1041 .map_err(|e| OmniError::Lance(e.to_string()))?;
1042 Ok(())
1043}
1044
1045async fn execute_node_scan(
1047 type_name: &str,
1048 variable: &str,
1049 filters: &[IRFilter],
1050 params: &ParamMap,
1051 snapshot: &Snapshot,
1052 catalog: &Catalog,
1053 search_mode: &SearchMode,
1054) -> Result<RecordBatch> {
1055 let table_key = format!("node:{}", type_name);
1056 let ds = snapshot.open(&table_key).await?;
1057
1058 let filter_expr = build_lance_filter_expr(filters, params);
1068
1069 let node_type = &catalog.node_types[type_name];
1073 let has_blobs = !node_type.blob_properties.is_empty();
1074 let non_blob_cols: Vec<&str> = node_type
1075 .arrow_schema
1076 .fields()
1077 .iter()
1078 .filter(|f| !node_type.blob_properties.contains(f.name()))
1079 .map(|f| f.name().as_str())
1080 .collect();
1081 let projection = has_blobs.then_some(non_blob_cols.as_slice());
1082 let batches = crate::table_store::TableStore::scan_stream_with(
1083 &ds,
1084 projection,
1085 None,
1086 None,
1087 false,
1088 |scanner| {
1089 if let Some(ref expr) = filter_expr {
1091 scanner.filter_expr(expr.clone());
1092 }
1093
1094 for filter in filters {
1096 if is_search_filter(filter) {
1097 if let Some(fts_query) = build_fts_query(&filter.left, params) {
1098 scanner.full_text_search(fts_query).map_err(|e| {
1099 OmniError::Lance(format!("full_text_search filter: {}", e))
1100 })?;
1101 }
1102 }
1103 }
1104
1105 if let Some((ref var, ref prop, ref vec, k)) = search_mode.nearest {
1107 if var == variable {
1108 let query_arr = Float32Array::from(vec.clone());
1109 scanner
1110 .nearest(prop, &query_arr, k)
1111 .map_err(|e| OmniError::Lance(format!("nearest: {}", e)))?;
1112 }
1113 }
1114
1115 if let Some((ref var, ref prop, ref text)) = search_mode.bm25 {
1117 if var == variable {
1118 let fts_query = lance_index::scalar::FullTextSearchQuery::new(text.clone())
1119 .with_column(prop.clone())
1120 .map_err(|e| OmniError::Lance(format!("fts with_column: {}", e)))?;
1121 scanner
1122 .full_text_search(fts_query)
1123 .map_err(|e| OmniError::Lance(format!("full_text_search: {}", e)))?;
1124 }
1125 }
1126 Ok(())
1127 },
1128 )
1129 .await?
1130 .try_collect::<Vec<RecordBatch>>()
1131 .await
1132 .map_err(|e| OmniError::Lance(e.to_string()))?;
1133
1134 let scan_result = if batches.is_empty() {
1135 RecordBatch::new_empty(batches.first().map(|b| b.schema()).unwrap_or_else(|| {
1136 let fields: Vec<_> = node_type
1138 .arrow_schema
1139 .fields()
1140 .iter()
1141 .filter(|f| !node_type.blob_properties.contains(f.name()))
1142 .map(|f| f.as_ref().clone())
1143 .collect();
1144 Arc::new(Schema::new(fields))
1145 }))
1146 } else if batches.len() == 1 {
1147 batches.into_iter().next().unwrap()
1148 } else {
1149 let schema = batches[0].schema();
1150 arrow_select::concat::concat_batches(&schema, &batches)
1151 .map_err(|e| OmniError::Lance(e.to_string()))?
1152 };
1153
1154 if has_blobs {
1156 return add_null_blob_columns(&scan_result, node_type);
1157 }
1158 Ok(scan_result)
1159}
1160
1161fn add_null_blob_columns(
1164 batch: &RecordBatch,
1165 node_type: &omnigraph_compiler::catalog::NodeType,
1166) -> Result<RecordBatch> {
1167 let num_rows = batch.num_rows();
1168 let mut fields = Vec::with_capacity(node_type.arrow_schema.fields().len());
1169 let mut columns: Vec<ArrayRef> = Vec::with_capacity(node_type.arrow_schema.fields().len());
1170
1171 for field in node_type.arrow_schema.fields() {
1172 if node_type.blob_properties.contains(field.name()) {
1173 fields.push(Field::new(field.name(), DataType::Utf8, true));
1174 columns.push(Arc::new(StringArray::from(vec![None::<&str>; num_rows])));
1175 } else if let Some(col) = batch.column_by_name(field.name()) {
1176 let batch_schema = batch.schema();
1177 let batch_field = batch_schema
1178 .field_with_name(field.name())
1179 .map_err(|e| OmniError::Lance(e.to_string()))?;
1180 fields.push(batch_field.clone());
1181 columns.push(col.clone());
1182 }
1183 }
1184
1185 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
1186 .map_err(|e| OmniError::Lance(e.to_string()))
1187}
1188
1189fn build_lance_filter(filters: &[IRFilter], params: &ParamMap) -> Option<String> {
1191 if filters.is_empty() {
1192 return None;
1193 }
1194
1195 let parts: Vec<String> = filters
1196 .iter()
1197 .filter_map(|f| ir_filter_to_sql(f, params))
1198 .collect();
1199
1200 if parts.is_empty() {
1201 return None;
1202 }
1203
1204 Some(parts.join(" AND "))
1205}
1206
1207fn ir_filter_to_sql(filter: &IRFilter, params: &ParamMap) -> Option<String> {
1208 if is_search_filter(filter) {
1211 return None;
1212 }
1213
1214 let left = ir_expr_to_sql(&filter.left, params)?;
1215 let right = ir_expr_to_sql(&filter.right, params)?;
1216 let op = match filter.op {
1217 CompOp::Eq => "=",
1218 CompOp::Ne => "!=",
1219 CompOp::Gt => ">",
1220 CompOp::Lt => "<",
1221 CompOp::Ge => ">=",
1222 CompOp::Le => "<=",
1223 CompOp::Contains => return None, };
1225 Some(format!("{} {} {}", left, op, right))
1226}
1227
1228fn build_fts_query(
1230 expr: &IRExpr,
1231 params: &ParamMap,
1232) -> Option<lance_index::scalar::FullTextSearchQuery> {
1233 match expr {
1234 IRExpr::Search { field, query } => {
1235 let prop = extract_property(field)?;
1236 let q = resolve_to_string(query, params)?;
1237 lance_index::scalar::FullTextSearchQuery::new(q)
1238 .with_column(prop)
1239 .ok()
1240 }
1241 IRExpr::Fuzzy {
1242 field,
1243 query,
1244 max_edits,
1245 } => {
1246 let prop = extract_property(field)?;
1247 let q = resolve_to_string(query, params)?;
1248 let edits = max_edits
1249 .as_ref()
1250 .and_then(|e| resolve_to_int(e, params))
1251 .unwrap_or(2) as u32;
1252 lance_index::scalar::FullTextSearchQuery::new_fuzzy(q, Some(edits))
1253 .with_column(prop)
1254 .ok()
1255 }
1256 IRExpr::MatchText { field, query } => {
1257 let prop = extract_property(field)?;
1259 let q = resolve_to_string(query, params)?;
1260 lance_index::scalar::FullTextSearchQuery::new(q)
1261 .with_column(prop)
1262 .ok()
1263 }
1264 _ => None,
1265 }
1266}
1267
1268fn extract_property(expr: &IRExpr) -> Option<String> {
1270 match expr {
1271 IRExpr::PropAccess { property, .. } => Some(property.clone()),
1272 _ => None,
1273 }
1274}
1275
1276fn resolve_to_string(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1278 match expr {
1279 IRExpr::Literal(Literal::String(s)) => Some(s.clone()),
1280 IRExpr::Param(name) => match params.get(name)? {
1281 Literal::String(s) => Some(s.clone()),
1282 _ => None,
1283 },
1284 _ => None,
1285 }
1286}
1287
1288fn resolve_to_int(expr: &IRExpr, params: &ParamMap) -> Option<i64> {
1290 match expr {
1291 IRExpr::Literal(Literal::Integer(n)) => Some(*n),
1292 IRExpr::Param(name) => match params.get(name)? {
1293 Literal::Integer(n) => Some(*n),
1294 _ => None,
1295 },
1296 _ => None,
1297 }
1298}
1299
1300fn ir_expr_to_sql(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1301 match expr {
1302 IRExpr::PropAccess { property, .. } => Some(property.clone()),
1303 IRExpr::Literal(lit) => Some(literal_to_sql(lit)),
1304 IRExpr::Param(name) => params.get(name).map(literal_to_sql),
1305 _ => None,
1306 }
1307}
1308
1309pub(super) fn literal_to_sql(lit: &Literal) -> String {
1310 match lit {
1311 Literal::Null => "NULL".to_string(),
1312 Literal::String(s) => format!("'{}'", s.replace('\'', "''")),
1313 Literal::Integer(n) => n.to_string(),
1314 Literal::Float(f) => f.to_string(),
1315 Literal::Bool(b) => b.to_string(),
1316 Literal::Date(s) => format!("'{}'", s.replace('\'', "''")),
1317 Literal::DateTime(s) => format!("'{}'", s.replace('\'', "''")),
1318 Literal::List(_) => "NULL".to_string(), }
1320}
1321
1322pub(super) fn build_lance_filter_expr(
1347 filters: &[IRFilter],
1348 params: &ParamMap,
1349) -> Option<datafusion::prelude::Expr> {
1350 use datafusion::logical_expr::Operator;
1351 use datafusion::prelude::Expr;
1352
1353 let mut acc: Option<Expr> = None;
1354 for f in filters {
1355 let Some(e) = ir_filter_to_expr(f, params) else {
1356 continue;
1357 };
1358 acc = Some(match acc {
1359 None => e,
1360 Some(prev) => Expr::BinaryExpr(datafusion::logical_expr::BinaryExpr::new(
1361 Box::new(prev),
1362 Operator::And,
1363 Box::new(e),
1364 )),
1365 });
1366 }
1367 acc
1368}
1369
1370pub(super) fn ir_filter_to_expr(
1374 filter: &IRFilter,
1375 params: &ParamMap,
1376) -> Option<datafusion::prelude::Expr> {
1377 use datafusion::functions_nested::expr_fn::array_has;
1378
1379 if is_search_filter(filter) {
1380 return None;
1381 }
1382
1383 if matches!(filter.op, CompOp::Contains) {
1387 let left = ir_expr_to_expr(&filter.left, params)?;
1388 let right = ir_expr_to_expr(&filter.right, params)?;
1389 return Some(array_has(left, right));
1390 }
1391
1392 let left = ir_expr_to_expr(&filter.left, params)?;
1393 let right = ir_expr_to_expr(&filter.right, params)?;
1394 Some(match filter.op {
1395 CompOp::Eq => left.eq(right),
1396 CompOp::Ne => left.not_eq(right),
1397 CompOp::Gt => left.gt(right),
1398 CompOp::Lt => left.lt(right),
1399 CompOp::Ge => left.gt_eq(right),
1400 CompOp::Le => left.lt_eq(right),
1401 CompOp::Contains => unreachable!("handled above"),
1402 })
1403}
1404
1405pub(super) fn ir_expr_to_expr(
1409 expr: &IRExpr,
1410 params: &ParamMap,
1411) -> Option<datafusion::prelude::Expr> {
1412 use datafusion::prelude::{col, lit};
1413 match expr {
1414 IRExpr::PropAccess { property, .. } => Some(col(property)),
1415 IRExpr::Literal(l) => literal_to_expr(l),
1416 IRExpr::Param(name) => params.get(name).and_then(literal_to_expr),
1417 _ => None,
1418 }
1419}
1420
1421fn literal_to_expr(lit: &Literal) -> Option<datafusion::prelude::Expr> {
1425 use datafusion::prelude::lit as df_lit;
1426 Some(match lit {
1427 Literal::Null => df_lit(datafusion::scalar::ScalarValue::Null),
1428 Literal::String(s) => df_lit(s.clone()),
1429 Literal::Integer(n) => df_lit(*n),
1430 Literal::Float(f) => df_lit(*f),
1431 Literal::Bool(b) => df_lit(*b),
1432 Literal::Date(s) => df_lit(s.clone()),
1436 Literal::DateTime(s) => df_lit(s.clone()),
1437 Literal::List(_) => return None,
1438 })
1439}
1440
1441fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
1442 let fields: Vec<Field> = batch
1443 .schema()
1444 .fields()
1445 .iter()
1446 .map(|f| {
1447 Field::new(
1448 format!("{}.{}", variable, f.name()),
1449 f.data_type().clone(),
1450 f.is_nullable(),
1451 )
1452 })
1453 .collect();
1454 let schema = Arc::new(Schema::new(fields));
1455 RecordBatch::try_new(schema, batch.columns().to_vec())
1456 .map_err(|e| OmniError::Lance(e.to_string()))
1457}
1458
1459fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1460 let n = left.num_rows();
1461 let m = right.num_rows();
1462 if n == 0 || m == 0 {
1463 let mut fields: Vec<Field> = left
1464 .schema()
1465 .fields()
1466 .iter()
1467 .map(|f| f.as_ref().clone())
1468 .collect();
1469 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1470 return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
1471 }
1472 let left_indices: Vec<u32> = (0..n as u32)
1473 .flat_map(|i| std::iter::repeat(i).take(m))
1474 .collect();
1475 let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
1476 let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
1477 let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
1478 hconcat_batches(&left_expanded, &right_expanded)
1479}
1480
1481fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1482 let mut fields: Vec<Field> = left
1483 .schema()
1484 .fields()
1485 .iter()
1486 .map(|f| f.as_ref().clone())
1487 .collect();
1488 if cfg!(debug_assertions) {
1489 let left_schema = left.schema();
1490 let left_names: HashSet<&str> = left_schema
1491 .fields()
1492 .iter()
1493 .map(|f| f.name().as_str())
1494 .collect();
1495 let right_schema = right.schema();
1496 for f in right_schema.fields() {
1497 debug_assert!(
1498 !left_names.contains(f.name().as_str()),
1499 "hconcat_batches: duplicate column '{}'",
1500 f.name()
1501 );
1502 }
1503 }
1504 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1505 let mut columns: Vec<ArrayRef> = left.columns().to_vec();
1506 columns.extend(right.columns().to_vec());
1507 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
1508 .map_err(|e| OmniError::Lance(e.to_string()))
1509}
1510
1511fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
1512 let columns: Vec<ArrayRef> = batch
1513 .columns()
1514 .iter()
1515 .map(|col| arrow_select::take::take(col.as_ref(), indices, None))
1516 .collect::<std::result::Result<Vec<_>, _>>()
1517 .map_err(|e| OmniError::Lance(e.to_string()))?;
1518 RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
1519}