1use super::*;
2
3use super::projection::{apply_filter, apply_ordering, project_return};
4
5impl Omnigraph {
6 pub async fn query(
8 &self,
9 target: impl Into<ReadTarget>,
10 query_source: &str,
11 query_name: &str,
12 params: &ParamMap,
13 ) -> Result<QueryResult> {
14 self.ensure_schema_state_valid().await?;
15 let resolved = self.resolved_target(target).await?;
16
17 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
18 .map_err(|e| OmniError::manifest(e.to_string()))?;
19 let type_ctx = typecheck_query(self.catalog(), &query_decl)?;
20 let ir = lower_query(self.catalog(), &query_decl, &type_ctx)?;
21
22 let needs_graph = ir
23 .pipeline
24 .iter()
25 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
26 let graph_index = if needs_graph {
27 Some(self.graph_index_for_resolved(&resolved).await?)
28 } else {
29 None
30 };
31
32 execute_query(
33 &ir,
34 params,
35 &resolved.snapshot,
36 graph_index.as_deref(),
37 self.catalog(),
38 )
39 .await
40 }
41
42 pub async fn run_query_at(
47 &self,
48 version: u64,
49 query_source: &str,
50 query_name: &str,
51 params: &ParamMap,
52 ) -> Result<QueryResult> {
53 self.ensure_schema_state_valid().await?;
54 let snapshot = self.snapshot_at_version(version).await?;
55
56 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
57 .map_err(|e| OmniError::manifest(e.to_string()))?;
58 let type_ctx = typecheck_query(self.catalog(), &query_decl)?;
59 let ir = lower_query(self.catalog(), &query_decl, &type_ctx)?;
60
61 let needs_graph = ir
62 .pipeline
63 .iter()
64 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
65 let graph_index = if needs_graph {
66 let edge_types = self
67 .catalog()
68 .edge_types
69 .iter()
70 .map(|(name, et)| (name.clone(), (et.from_type.clone(), et.to_type.clone())))
71 .collect();
72 Some(Arc::new(GraphIndex::build(&snapshot, &edge_types).await?))
73 } else {
74 None
75 };
76
77 execute_query(
78 &ir,
79 params,
80 &snapshot,
81 graph_index.as_deref(),
82 self.catalog(),
83 )
84 .await
85 }
86}
87
/// Search behaviour derived from a query's ordering expression.
///
/// The `Default` value (every field `None`) means "no search-driven
/// ordering". The extraction functions in this file populate at most one
/// field at a time.
#[derive(Debug, Default)]
struct SearchMode {
    /// `nearest()` ordering: `(variable, property, query vector, k)`.
    nearest: Option<(String, String, Vec<f32>, usize)>,
    /// `bm25()` ordering: `(variable, property, query text)`.
    bm25: Option<(String, String, String)>,
    /// `rrf()` ordering combining two sub-searches.
    rrf: Option<RrfMode>,
}
100
/// Parameters of an `rrf()` ordering: two sub-searches whose rankings are
/// fused by `execute_rrf_query`.
#[derive(Debug)]
struct RrfMode {
    /// Search mode extracted from the first `rrf()` argument.
    primary: Box<SearchMode>,
    /// Search mode extracted from the second `rrf()` argument.
    secondary: Box<SearchMode>,
    /// Constant added to the rank in each reciprocal-rank term.
    k: u32,
    /// Maximum number of fused rows kept (taken from the query's limit).
    limit: usize,
}
108
109async fn extract_search_mode(
111 ir: &QueryIR,
112 params: &ParamMap,
113 catalog: &Catalog,
114) -> Result<SearchMode> {
115 if ir.order_by.is_empty() {
116 return Ok(SearchMode::default());
117 }
118 let ordering = &ir.order_by[0];
119 match &ordering.expr {
120 IRExpr::Nearest {
121 variable,
122 property,
123 query,
124 } => {
125 let vec =
126 resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
127 let k = ir.limit.ok_or_else(|| {
128 OmniError::manifest("nearest() ordering requires a limit clause".to_string())
129 })? as usize;
130 Ok(SearchMode {
131 nearest: Some((variable.clone(), property.clone(), vec, k)),
132 ..Default::default()
133 })
134 }
135 IRExpr::Bm25 { field, query } => {
136 let var = match field.as_ref() {
137 IRExpr::PropAccess { variable, .. } => variable.clone(),
138 _ => {
139 return Err(OmniError::manifest(
140 "bm25 field must be a property access".to_string(),
141 ));
142 }
143 };
144 let prop = extract_property(field).ok_or_else(|| {
145 OmniError::manifest("bm25 field must be a property access".to_string())
146 })?;
147 let text = resolve_to_string(query, params).ok_or_else(|| {
148 OmniError::manifest("bm25 query must resolve to a string".to_string())
149 })?;
150 Ok(SearchMode {
151 bm25: Some((var, prop, text)),
152 ..Default::default()
153 })
154 }
155 IRExpr::Rrf {
156 primary,
157 secondary,
158 k,
159 } => {
160 let limit = ir.limit.ok_or_else(|| {
161 OmniError::manifest("rrf() ordering requires a limit clause".to_string())
162 })? as usize;
163 let k_val = k
164 .as_ref()
165 .and_then(|e| resolve_to_int(e, params))
166 .unwrap_or(60) as u32;
167
168 let primary_mode =
169 extract_sub_search_mode(ir, primary, params, catalog, ir.limit).await?;
170 let secondary_mode =
171 extract_sub_search_mode(ir, secondary, params, catalog, ir.limit).await?;
172
173 Ok(SearchMode {
174 rrf: Some(RrfMode {
175 primary: Box::new(primary_mode),
176 secondary: Box::new(secondary_mode),
177 k: k_val,
178 limit,
179 }),
180 ..Default::default()
181 })
182 }
183 _ => Ok(SearchMode::default()),
184 }
185}
186
187async fn extract_sub_search_mode(
189 ir: &QueryIR,
190 expr: &IRExpr,
191 params: &ParamMap,
192 catalog: &Catalog,
193 limit: Option<u64>,
194) -> Result<SearchMode> {
195 match expr {
196 IRExpr::Nearest {
197 variable,
198 property,
199 query,
200 } => {
201 let vec =
202 resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
203 let k = limit.unwrap_or(100) as usize;
204 Ok(SearchMode {
205 nearest: Some((variable.clone(), property.clone(), vec, k)),
206 ..Default::default()
207 })
208 }
209 IRExpr::Bm25 { field, query } => {
210 let var = match field.as_ref() {
211 IRExpr::PropAccess { variable, .. } => variable.clone(),
212 _ => {
213 return Err(OmniError::manifest(
214 "bm25 field must be a property access".to_string(),
215 ));
216 }
217 };
218 let prop = extract_property(field).ok_or_else(|| {
219 OmniError::manifest("bm25 field must be a property access".to_string())
220 })?;
221 let text = resolve_to_string(query, params).ok_or_else(|| {
222 OmniError::manifest("bm25 query must resolve to a string".to_string())
223 })?;
224 Ok(SearchMode {
225 bm25: Some((var, prop, text)),
226 ..Default::default()
227 })
228 }
229 _ => Ok(SearchMode::default()),
230 }
231}
232
233async fn resolve_nearest_query_vec(
235 ir: &QueryIR,
236 catalog: &Catalog,
237 variable: &str,
238 property: &str,
239 expr: &IRExpr,
240 params: &ParamMap,
241) -> Result<Vec<f32>> {
242 let lit = resolve_literal_or_param(expr, params)?;
243 match lit {
244 Literal::List(_) => literal_to_f32_vec(&lit),
245 Literal::String(text) => {
246 let expected_dim = nearest_property_dimension(ir, catalog, variable, property)?;
247 EmbeddingClient::from_env()?
248 .embed_query_text(&text, expected_dim)
249 .await
250 }
251 _ => Err(OmniError::manifest(
252 "nearest query must be a string or list of floats".to_string(),
253 )),
254 }
255}
256
257fn resolve_literal_or_param(expr: &IRExpr, params: &ParamMap) -> Result<Literal> {
258 Ok(match expr {
259 IRExpr::Literal(lit) => lit.clone(),
260 IRExpr::Param(name) => params
261 .get(name)
262 .cloned()
263 .ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?,
264 _ => {
265 return Err(OmniError::manifest(
266 "nearest query must be a literal or parameter".to_string(),
267 ));
268 }
269 })
270}
271
272fn literal_to_f32_vec(lit: &Literal) -> Result<Vec<f32>> {
274 match lit {
275 Literal::List(items) => items
276 .iter()
277 .map(|item| match item {
278 Literal::Float(f) => Ok(*f as f32),
279 Literal::Integer(n) => Ok(*n as f32),
280 _ => Err(OmniError::manifest(
281 "vector elements must be numeric".to_string(),
282 )),
283 })
284 .collect(),
285 _ => Err(OmniError::manifest(
286 "nearest query must be a list of floats".to_string(),
287 )),
288 }
289}
290
291fn nearest_property_dimension(
292 ir: &QueryIR,
293 catalog: &Catalog,
294 variable: &str,
295 property: &str,
296) -> Result<usize> {
297 let type_name = resolve_binding_type_name(&ir.pipeline, variable).ok_or_else(|| {
298 OmniError::manifest_internal(format!(
299 "nearest() variable '${}' is not bound to a node type in the lowered pipeline",
300 variable
301 ))
302 })?;
303 let node_type = catalog.node_types.get(type_name).ok_or_else(|| {
304 OmniError::manifest_internal(format!(
305 "nearest() binding '${}' resolved unknown node type '{}'",
306 variable, type_name
307 ))
308 })?;
309 let prop = node_type.properties.get(property).ok_or_else(|| {
310 OmniError::manifest_internal(format!(
311 "nearest() property '{}.{}' is missing from the catalog",
312 type_name, property
313 ))
314 })?;
315 match prop.scalar {
316 ScalarType::Vector(dim) if !prop.list => Ok(dim as usize),
317 _ => Err(OmniError::manifest_internal(format!(
318 "nearest() property '{}.{}' is not a scalar vector",
319 type_name, property
320 ))),
321 }
322}
323
324fn resolve_binding_type_name<'a>(pipeline: &'a [IROp], variable: &str) -> Option<&'a str> {
325 for op in pipeline {
326 match op {
327 IROp::NodeScan {
328 variable: bound_var,
329 type_name,
330 ..
331 } if bound_var == variable => return Some(type_name.as_str()),
332 IROp::Expand {
333 dst_var, dst_type, ..
334 } if dst_var == variable => return Some(dst_type.as_str()),
335 IROp::AntiJoin { inner, .. } => {
336 if let Some(type_name) = resolve_binding_type_name(inner, variable) {
337 return Some(type_name);
338 }
339 }
340 _ => {}
341 }
342 }
343 None
344}
345
346pub async fn execute_query(
348 ir: &QueryIR,
349 params: &ParamMap,
350 snapshot: &Snapshot,
351 graph_index: Option<&GraphIndex>,
352 catalog: &Catalog,
353) -> Result<QueryResult> {
354 let search_mode = extract_search_mode(ir, params, catalog).await?;
355
356 if let Some(ref rrf) = search_mode.rrf {
358 return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
359 }
360
361 let mut wide: Option<RecordBatch> = None;
362 execute_pipeline(&ir.pipeline, params, snapshot, graph_index, catalog, &mut wide, &search_mode).await?;
363 let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));
364
365 let has_aggregates = ir.return_exprs.iter().any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
367 let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;
368
369 if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
371 result_batch = if has_aggregates {
372 apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)?
373 } else {
374 apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
375 };
376 }
377
378 if let Some(limit) = ir.limit {
380 let len = result_batch.num_rows().min(limit as usize);
381 result_batch = result_batch.slice(0, len);
382 }
383
384 Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
385}
386
387fn is_search_ordered(search_mode: &SearchMode) -> bool {
389 search_mode.nearest.is_some() || search_mode.bm25.is_some()
390}
391
/// Executes an `rrf()`-ordered query by running the pipeline once per
/// sub-search and fusing the two rankings.
///
/// Each id found in either result receives the score
/// `1 / (k + rank + 1)` summed over the result sets it appears in, where
/// `rank` is the zero-based position of its first occurrence. The ids are
/// sorted by descending score, truncated to `rrf.limit`, materialised as rows
/// (preferring the primary result's row) and finally projected through the
/// query's return expressions.
///
/// # Errors
/// Returns a manifest error when the primary sub-search is neither `nearest`
/// nor `bm25`, when either pipeline run produced no batch, or when the
/// `<variable>.id` column is missing or not Utf8; propagates pipeline,
/// concatenation and projection failures.
async fn execute_rrf_query(
    ir: &QueryIR,
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
    rrf: &RrfMode,
) -> Result<QueryResult> {
    // Run the same pipeline twice, once under each sub-search mode.
    let mut primary_wide: Option<RecordBatch> = None;
    execute_pipeline(
        &ir.pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut primary_wide,
        &rrf.primary,
    )
    .await?;

    let mut secondary_wide: Option<RecordBatch> = None;
    execute_pipeline(
        &ir.pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut secondary_wide,
        &rrf.secondary,
    )
    .await?;

    // The variable searched by the primary mode identifies the id column used
    // to match rows across both result sets.
    let primary_var = rrf
        .primary
        .nearest
        .as_ref()
        .map(|(v, ..)| v.as_str())
        .or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
        .ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;

    let primary_batch = primary_wide.as_ref().ok_or_else(|| {
        OmniError::manifest(format!(
            "rrf primary variable '{}' not in bindings",
            primary_var
        ))
    })?;
    // NOTE(review): this message also interpolates `primary_var`; confirm that
    // both sub-searches are always expected to target the same variable.
    let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
        OmniError::manifest(format!(
            "rrf secondary variable '{}' not in bindings",
            primary_var
        ))
    })?;

    let id_col_name = format!("{}.id", primary_var);
    let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
    let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;

    // Zero-based rank of the first occurrence of each id in each result set.
    let mut primary_rank: HashMap<String, usize> = HashMap::new();
    for (i, id) in primary_ids.iter().enumerate() {
        primary_rank.entry(id.clone()).or_insert(i);
    }
    let mut secondary_rank: HashMap<String, usize> = HashMap::new();
    for (i, id) in secondary_ids.iter().enumerate() {
        secondary_rank.entry(id.clone()).or_insert(i);
    }

    // Candidate ids: every primary id, followed by secondary ids that the
    // primary result did not contain.
    let mut all_ids: Vec<String> = primary_ids.clone();
    for id in &secondary_ids {
        if !primary_rank.contains_key(id) {
            all_ids.push(id.clone());
        }
    }

    let k = rrf.k as f64;
    // An id missing from one result set contributes 0.0 for that set.
    let mut scored: Vec<(String, f64)> = all_ids
        .iter()
        .map(|id| {
            let p = primary_rank
                .get(id)
                .map(|&r| 1.0 / (k + r as f64 + 1.0))
                .unwrap_or(0.0);
            let s = secondary_rank
                .get(id)
                .map(|&r| 1.0 / (k + r as f64 + 1.0))
                .unwrap_or(0.0);
            (id.clone(), p + s)
        })
        .collect();
    // Descending by score; incomparable scores are treated as equal.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(rrf.limit);

    let winning_ids: Vec<String> = scored.iter().map(|(id, _)| id.clone()).collect();

    // Locate a source row for each id; primary rows are inserted first, so
    // they take precedence over secondary rows with the same id.
    let mut id_to_batch_row: HashMap<String, (&RecordBatch, usize)> = HashMap::new();
    for (i, id) in primary_ids.iter().enumerate() {
        id_to_batch_row
            .entry(id.clone())
            .or_insert((primary_batch, i));
    }
    for (i, id) in secondary_ids.iter().enumerate() {
        id_to_batch_row
            .entry(id.clone())
            .or_insert((secondary_batch, i));
    }

    let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;

    let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;

    Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
516
517fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
518 let col = batch
519 .column_by_name(col_name)
520 .ok_or_else(|| OmniError::manifest(format!("batch missing '{}' column for RRF", col_name)))?;
521 let ids = col
522 .as_any()
523 .downcast_ref::<StringArray>()
524 .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?;
525 Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
526}
527
528fn build_fused_batch(
529 ordered_ids: &[String],
530 id_to_batch_row: &HashMap<String, (&RecordBatch, usize)>,
531 schema: SchemaRef,
532) -> Result<RecordBatch> {
533 if ordered_ids.is_empty() {
534 return Ok(RecordBatch::new_empty(schema));
535 }
536
537 let mut row_slices: Vec<RecordBatch> = Vec::with_capacity(ordered_ids.len());
539 for id in ordered_ids {
540 if let Some(&(batch, row_idx)) = id_to_batch_row.get(id) {
541 row_slices.push(batch.slice(row_idx, 1));
542 }
543 }
544
545 if row_slices.is_empty() {
546 return Ok(RecordBatch::new_empty(schema));
547 }
548
549 let schema = row_slices[0].schema();
550 arrow_select::concat::concat_batches(&schema, &row_slices)
551 .map_err(|e| OmniError::Lance(e.to_string()))
552}
553
554fn is_search_filter(filter: &IRFilter) -> bool {
556 matches!(
557 &filter.left,
558 IRExpr::Search { .. } | IRExpr::Fuzzy { .. } | IRExpr::MatchText { .. }
559 )
560}
561
562fn search_filter_variable(filter: &IRFilter) -> Option<&str> {
564 let field = match &filter.left {
565 IRExpr::Search { field, .. } => field,
566 IRExpr::Fuzzy { field, .. } => field,
567 IRExpr::MatchText { field, .. } => field,
568 _ => return None,
569 };
570 match field.as_ref() {
571 IRExpr::PropAccess { variable, .. } => Some(variable.as_str()),
572 _ => None,
573 }
574}
575
576fn execute_pipeline<'a>(
577 pipeline: &'a [IROp],
578 params: &'a ParamMap,
579 snapshot: &'a Snapshot,
580 graph_index: Option<&'a GraphIndex>,
581 catalog: &'a Catalog,
582 wide: &'a mut Option<RecordBatch>,
583 search_mode: &'a SearchMode,
584) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
585 Box::pin(async move {
586 let mut hoisted_search_filters: HashMap<String, Vec<IRFilter>> = HashMap::new();
588 let mut hoisted_indices: HashSet<usize> = HashSet::new();
589 for (i, op) in pipeline.iter().enumerate() {
590 if let IROp::Filter(filter) = op {
591 if is_search_filter(filter) {
592 if let Some(var) = search_filter_variable(filter) {
593 hoisted_search_filters
594 .entry(var.to_string())
595 .or_default()
596 .push(filter.clone());
597 hoisted_indices.insert(i);
598 }
599 }
600 }
601 }
602
603 for (i, op) in pipeline.iter().enumerate() {
604 if hoisted_indices.contains(&i) {
606 continue;
607 }
608 match op {
609 IROp::NodeScan {
610 variable,
611 type_name,
612 filters,
613 } => {
614 let mut all_filters: Vec<IRFilter> = filters.clone();
616 if let Some(extra) = hoisted_search_filters.get(variable) {
617 all_filters.extend(extra.iter().cloned());
618 }
619 let batch = execute_node_scan(
620 type_name,
621 variable,
622 &all_filters,
623 params,
624 snapshot,
625 catalog,
626 search_mode,
627 )
628 .await?;
629 let prefixed = prefix_batch(&batch, variable)?;
630 *wide = Some(match wide.take() {
631 None => prefixed,
632 Some(existing) => cross_join_batches(&existing, &prefixed)?,
633 });
634 }
635 IROp::Filter(filter) => {
636 if let Some(batch) = wide.as_mut() {
637 apply_filter(batch, filter, params)?;
638 }
639 }
640 IROp::Expand {
641 src_var,
642 dst_var,
643 edge_type,
644 direction,
645 dst_type,
646 min_hops,
647 max_hops,
648 dst_filters,
649 } => {
650 let gi = graph_index.ok_or_else(|| {
651 OmniError::manifest("graph index required for traversal".to_string())
652 })?;
653 if let Some(batch) = wide.as_mut() {
654 execute_expand(
655 batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
656 dst_type, *min_hops, *max_hops, dst_filters, params,
657 )
658 .await?;
659 }
660 }
661 IROp::AntiJoin { outer_var, inner } => {
662 let gi = graph_index;
663 if let Some(batch) = wide.as_mut() {
664 execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
665 .await?;
666 }
667 }
668 }
669 }
670 Ok(())
671 })
672}
673
/// Expands every row of `wide` along `edge_type`, appending the columns of the
/// reached destination nodes (prefixed with `dst_var`).
///
/// For each source row a breadth-first traversal of up to `max_hops` hops is
/// run over the adjacency index (`csr` for `Direction::Out`, `csc` for
/// `Direction::In`). Nodes reached at hop `>= min_hops` become destinations;
/// each destination id is emitted at most once per source row. Destination
/// rows are then loaded with `hydrate_nodes`, with the SQL-expressible part of
/// `dst_filters` pushed into the scan and the remaining filters applied to the
/// joined batch afterwards.
///
/// On success `wide` is replaced by the joined batch: one row per
/// (source row, surviving destination) pair.
///
/// # Errors
/// Returns a manifest error when the source id column is missing or not Utf8,
/// when the edge type, type indexes or adjacency index are unknown, or when
/// the hydrated batch lacks a Utf8 `id` column; propagates hydration, take,
/// concatenation and filter failures.
async fn execute_expand(
    wide: &mut RecordBatch,
    graph_index: &GraphIndex,
    snapshot: &Snapshot,
    catalog: &Catalog,
    src_var: &str,
    dst_var: &str,
    edge_type: &str,
    direction: Direction,
    dst_type: &str,
    min_hops: u32,
    max_hops: Option<u32>,
    dst_filters: &[IRFilter],
    params: &ParamMap,
) -> Result<()> {
    // Ids of the source variable, one per row of the wide batch.
    let src_id_col_name = format!("{}.id", src_var);
    let src_ids = wide
        .column_by_name(&src_id_col_name)
        .ok_or_else(|| OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name)))?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
        .clone();

    let edge_def = catalog
        .edge_types
        .get(edge_type)
        .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_type)))?;

    // Traversing against the edge direction swaps the roles of the endpoints.
    let (src_type_name, dst_type_name) = match direction {
        Direction::Out => (&edge_def.from_type, &edge_def.to_type),
        Direction::In => (&edge_def.to_type, &edge_def.from_type),
    };

    let src_type_idx = graph_index
        .type_index(src_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", src_type_name)))?;
    let dst_type_idx = graph_index
        .type_index(dst_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", dst_type_name)))?;

    let adj = match direction {
        Direction::Out => graph_index.csr(edge_type),
        Direction::In => graph_index.csc(edge_type),
    }
    .ok_or_else(|| OmniError::manifest(format!("no adjacency index for edge '{}'", edge_type)))?;

    // Without an explicit upper bound, traverse exactly `min_hops` hops (at
    // least one).
    let max = max_hops.unwrap_or(min_hops.max(1));

    // Cycle tracking via `visited` is only applied when both endpoints share a
    // node type.
    let same_type = src_type_name == dst_type_name;

    // Parallel vectors: `src_indices[n]` is the wide-batch row that reached
    // `dst_id_list[n]`.
    let mut src_indices: Vec<u32> = Vec::new();
    let mut dst_id_list: Vec<String> = Vec::new();
    for i in 0..src_ids.len() {
        let src_id = src_ids.value(i);
        // Source ids unknown to the type index produce no output rows.
        let Some(src_dense) = src_type_idx.to_dense(src_id) else {
            continue;
        };

        let mut frontier: Vec<u32> = vec![src_dense];
        let mut visited: HashSet<u32> = HashSet::new();
        // Guarantees each destination id is emitted once per source row.
        let mut seen_dst_ids: HashSet<String> = HashSet::new();
        if same_type {
            // The source itself is never revisited or reported.
            visited.insert(src_dense);
        }

        for hop in 1..=max {
            let mut next_frontier = Vec::new();
            for &node in &frontier {
                for &neighbor in adj.neighbors(node) {
                    if !same_type || visited.insert(neighbor) {
                        next_frontier.push(neighbor);
                        if hop >= min_hops {
                            if let Some(dst_id) = dst_type_idx.to_id(neighbor) {
                                let dst_id = dst_id.to_string();
                                if seen_dst_ids.insert(dst_id.clone()) {
                                    src_indices.push(i as u32);
                                    dst_id_list.push(dst_id);
                                }
                            }
                        }
                    }
                }
            }
            frontier = next_frontier;
            if frontier.is_empty() {
                break;
            }
        }
    }

    // Filters expressible as SQL are pushed into the scan; the rest are
    // applied to the joined batch below.
    let pushdown_sql = build_lance_filter(dst_filters, params);
    let non_pushable: Vec<&IRFilter> = dst_filters
        .iter()
        .filter(|f| ir_filter_to_sql(f, params).is_none())
        .collect();

    let dst_batch = hydrate_nodes(snapshot, catalog, dst_type, &dst_id_list, pushdown_sql.as_deref()).await?;

    // Map each hydrated id to its row so traversal pairs can be aligned with
    // the hydrated rows.
    let dst_batch_id_col = dst_batch
        .column_by_name("id")
        .ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?;
    let dst_id_to_row: HashMap<&str, usize> = (0..dst_batch_id_col.len())
        .map(|i| (dst_batch_id_col.value(i), i))
        .collect();

    // Keep only pairs whose destination was actually returned by the scan
    // (pairs are dropped when the id is absent from the hydrated batch).
    let mut final_src_indices: Vec<u32> = Vec::new();
    let mut dst_indices: Vec<u32> = Vec::new();
    for (src_idx, dst_id) in src_indices.iter().zip(dst_id_list.iter()) {
        if let Some(&dst_row) = dst_id_to_row.get(dst_id.as_str()) {
            final_src_indices.push(*src_idx);
            dst_indices.push(dst_row as u32);
        }
    }

    // Gather both sides row-aligned and glue them together column-wise.
    let src_take = UInt32Array::from(final_src_indices);
    let dst_take = UInt32Array::from(dst_indices);
    let expanded_wide = take_batch(wide, &src_take)?;
    let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
    let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
    *wide = hconcat_batches(&expanded_wide, &aligned_dst)?;

    for f in &non_pushable {
        apply_filter(wide, f, params)?;
    }

    Ok(())
}
817
818async fn hydrate_nodes(
824 snapshot: &Snapshot,
825 catalog: &Catalog,
826 type_name: &str,
827 ids: &[String],
828 extra_filter_sql: Option<&str>,
829) -> Result<RecordBatch> {
830 let node_type = catalog
831 .node_types
832 .get(type_name)
833 .ok_or_else(|| OmniError::manifest(format!("unknown node type '{}'", type_name)))?;
834
835 if ids.is_empty() {
836 return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
837 }
838
839 let table_key = format!("node:{}", type_name);
840 let ds = snapshot.open(&table_key).await?;
841
842 let escaped: Vec<String> = ids
844 .iter()
845 .map(|id| format!("'{}'", id.replace('\'', "''")))
846 .collect();
847 let mut filter_sql = format!("id IN ({})", escaped.join(", "));
848 if let Some(extra) = extra_filter_sql {
849 filter_sql = format!("({}) AND ({})", filter_sql, extra);
850 }
851 let has_blobs = !node_type.blob_properties.is_empty();
852 let non_blob_cols: Vec<&str> = node_type
853 .arrow_schema
854 .fields()
855 .iter()
856 .filter(|f| !node_type.blob_properties.contains(f.name()))
857 .map(|f| f.name().as_str())
858 .collect();
859 let projection = has_blobs.then_some(non_blob_cols.as_slice());
860 let batches = crate::table_store::TableStore::scan_stream(
861 &ds,
862 projection,
863 Some(&filter_sql),
864 None,
865 false,
866 )
867 .await?
868 .try_collect::<Vec<RecordBatch>>()
869 .await
870 .map_err(|e| OmniError::Lance(e.to_string()))?;
871
872 let scan_result = if batches.is_empty() {
873 return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
874 } else if batches.len() == 1 {
875 batches.into_iter().next().unwrap()
876 } else {
877 let schema = batches[0].schema();
878 arrow_select::concat::concat_batches(&schema, &batches)
879 .map_err(|e| OmniError::Lance(e.to_string()))?
880 };
881
882 if has_blobs {
883 return add_null_blob_columns(&scan_result, node_type);
884 }
885 Ok(scan_result)
886}
887
888fn try_bulk_anti_join_mask(
891 wide: &RecordBatch,
892 inner_pipeline: &[IROp],
893 graph_index: Option<&GraphIndex>,
894 catalog: &Catalog,
895 outer_var: &str,
896) -> Option<BooleanArray> {
897 if inner_pipeline.len() != 1 {
898 return None;
899 }
900 let IROp::Expand {
901 src_var,
902 edge_type,
903 direction,
904 dst_filters,
905 ..
906 } = &inner_pipeline[0]
907 else {
908 return None;
909 };
910 if src_var != outer_var {
911 return None;
912 }
913 if !dst_filters.is_empty() {
916 return None;
917 }
918 let gi = graph_index?;
919 let edge_def = catalog.edge_types.get(edge_type.as_str())?;
920
921 let src_type_name = match direction {
922 Direction::Out => &edge_def.from_type,
923 Direction::In => &edge_def.to_type,
924 };
925 let adj = match direction {
926 Direction::Out => gi.csr(edge_type),
927 Direction::In => gi.csc(edge_type),
928 }?;
929 let type_idx = gi.type_index(src_type_name)?;
930
931 let id_col_name = format!("{}.id", outer_var);
932 let outer_ids = wide
933 .column_by_name(&id_col_name)?
934 .as_any()
935 .downcast_ref::<StringArray>()?;
936
937 let keep_mask: Vec<bool> = (0..outer_ids.len())
938 .map(|i| {
939 let id = outer_ids.value(i);
940 match type_idx.to_dense(id) {
941 Some(dense) => !adj.has_neighbors(dense),
942 None => true, }
944 })
945 .collect();
946
947 Some(BooleanArray::from(keep_mask))
948}
949
950async fn execute_anti_join(
952 wide: &mut RecordBatch,
953 inner_pipeline: &[IROp],
954 params: &ParamMap,
955 snapshot: &Snapshot,
956 graph_index: Option<&GraphIndex>,
957 catalog: &Catalog,
958 outer_var: &str,
959) -> Result<()> {
960 if let Some(mask) =
962 try_bulk_anti_join_mask(wide, inner_pipeline, graph_index, catalog, outer_var)
963 {
964 *wide = arrow_select::filter::filter_record_batch(wide, &mask)
965 .map_err(|e| OmniError::Lance(e.to_string()))?;
966 return Ok(());
967 }
968
969 let num_rows = wide.num_rows();
971 let mut keep_mask = vec![true; num_rows];
972
973 for i in 0..num_rows {
974 let single_row = wide.slice(i, 1);
975 let mut inner_wide: Option<RecordBatch> = Some(single_row);
976
977 let no_search = SearchMode::default();
978 execute_pipeline(
979 inner_pipeline,
980 params,
981 snapshot,
982 graph_index,
983 catalog,
984 &mut inner_wide,
985 &no_search,
986 )
987 .await?;
988
989 let has_match = inner_wide
990 .as_ref()
991 .map(|batch| batch.num_rows() > 0)
992 .unwrap_or(false);
993
994 if has_match {
995 keep_mask[i] = false;
996 }
997 }
998
999 let mask = BooleanArray::from(keep_mask);
1000 *wide = arrow_select::filter::filter_record_batch(wide, &mask)
1001 .map_err(|e| OmniError::Lance(e.to_string()))?;
1002 Ok(())
1003}
1004
1005async fn execute_node_scan(
1007 type_name: &str,
1008 variable: &str,
1009 filters: &[IRFilter],
1010 params: &ParamMap,
1011 snapshot: &Snapshot,
1012 catalog: &Catalog,
1013 search_mode: &SearchMode,
1014) -> Result<RecordBatch> {
1015 let table_key = format!("node:{}", type_name);
1016 let ds = snapshot.open(&table_key).await?;
1017
1018 let filter_sql = build_lance_filter(filters, params);
1020
1021 let node_type = &catalog.node_types[type_name];
1025 let has_blobs = !node_type.blob_properties.is_empty();
1026 let non_blob_cols: Vec<&str> = node_type
1027 .arrow_schema
1028 .fields()
1029 .iter()
1030 .filter(|f| !node_type.blob_properties.contains(f.name()))
1031 .map(|f| f.name().as_str())
1032 .collect();
1033 let projection = has_blobs.then_some(non_blob_cols.as_slice());
1034 let batches = crate::table_store::TableStore::scan_stream_with(
1035 &ds,
1036 projection,
1037 filter_sql.as_deref(),
1038 None,
1039 false,
1040 |scanner| {
1041 for filter in filters {
1043 if is_search_filter(filter) {
1044 if let Some(fts_query) = build_fts_query(&filter.left, params) {
1045 scanner.full_text_search(fts_query).map_err(|e| {
1046 OmniError::Lance(format!("full_text_search filter: {}", e))
1047 })?;
1048 }
1049 }
1050 }
1051
1052 if let Some((ref var, ref prop, ref vec, k)) = search_mode.nearest {
1054 if var == variable {
1055 let query_arr = Float32Array::from(vec.clone());
1056 scanner
1057 .nearest(prop, &query_arr, k)
1058 .map_err(|e| OmniError::Lance(format!("nearest: {}", e)))?;
1059 }
1060 }
1061
1062 if let Some((ref var, ref prop, ref text)) = search_mode.bm25 {
1064 if var == variable {
1065 let fts_query = lance_index::scalar::FullTextSearchQuery::new(text.clone())
1066 .with_column(prop.clone())
1067 .map_err(|e| OmniError::Lance(format!("fts with_column: {}", e)))?;
1068 scanner
1069 .full_text_search(fts_query)
1070 .map_err(|e| OmniError::Lance(format!("full_text_search: {}", e)))?;
1071 }
1072 }
1073 Ok(())
1074 },
1075 )
1076 .await?
1077 .try_collect::<Vec<RecordBatch>>()
1078 .await
1079 .map_err(|e| OmniError::Lance(e.to_string()))?;
1080
1081 let scan_result = if batches.is_empty() {
1082 RecordBatch::new_empty(batches.first().map(|b| b.schema()).unwrap_or_else(|| {
1083 let fields: Vec<_> = node_type
1085 .arrow_schema
1086 .fields()
1087 .iter()
1088 .filter(|f| !node_type.blob_properties.contains(f.name()))
1089 .map(|f| f.as_ref().clone())
1090 .collect();
1091 Arc::new(Schema::new(fields))
1092 }))
1093 } else if batches.len() == 1 {
1094 batches.into_iter().next().unwrap()
1095 } else {
1096 let schema = batches[0].schema();
1097 arrow_select::concat::concat_batches(&schema, &batches)
1098 .map_err(|e| OmniError::Lance(e.to_string()))?
1099 };
1100
1101 if has_blobs {
1103 return add_null_blob_columns(&scan_result, node_type);
1104 }
1105 Ok(scan_result)
1106}
1107
1108fn add_null_blob_columns(
1111 batch: &RecordBatch,
1112 node_type: &omnigraph_compiler::catalog::NodeType,
1113) -> Result<RecordBatch> {
1114 let num_rows = batch.num_rows();
1115 let mut fields = Vec::with_capacity(node_type.arrow_schema.fields().len());
1116 let mut columns: Vec<ArrayRef> = Vec::with_capacity(node_type.arrow_schema.fields().len());
1117
1118 for field in node_type.arrow_schema.fields() {
1119 if node_type.blob_properties.contains(field.name()) {
1120 fields.push(Field::new(field.name(), DataType::Utf8, true));
1121 columns.push(Arc::new(StringArray::from(vec![None::<&str>; num_rows])));
1122 } else if let Some(col) = batch.column_by_name(field.name()) {
1123 let batch_schema = batch.schema();
1124 let batch_field = batch_schema
1125 .field_with_name(field.name())
1126 .map_err(|e| OmniError::Lance(e.to_string()))?;
1127 fields.push(batch_field.clone());
1128 columns.push(col.clone());
1129 }
1130 }
1131
1132 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
1133 .map_err(|e| OmniError::Lance(e.to_string()))
1134}
1135
1136fn build_lance_filter(filters: &[IRFilter], params: &ParamMap) -> Option<String> {
1138 if filters.is_empty() {
1139 return None;
1140 }
1141
1142 let parts: Vec<String> = filters
1143 .iter()
1144 .filter_map(|f| ir_filter_to_sql(f, params))
1145 .collect();
1146
1147 if parts.is_empty() {
1148 return None;
1149 }
1150
1151 Some(parts.join(" AND "))
1152}
1153
1154fn ir_filter_to_sql(filter: &IRFilter, params: &ParamMap) -> Option<String> {
1155 if is_search_filter(filter) {
1158 return None;
1159 }
1160
1161 let left = ir_expr_to_sql(&filter.left, params)?;
1162 let right = ir_expr_to_sql(&filter.right, params)?;
1163 let op = match filter.op {
1164 CompOp::Eq => "=",
1165 CompOp::Ne => "!=",
1166 CompOp::Gt => ">",
1167 CompOp::Lt => "<",
1168 CompOp::Ge => ">=",
1169 CompOp::Le => "<=",
1170 CompOp::Contains => return None, };
1172 Some(format!("{} {} {}", left, op, right))
1173}
1174
1175fn build_fts_query(
1177 expr: &IRExpr,
1178 params: &ParamMap,
1179) -> Option<lance_index::scalar::FullTextSearchQuery> {
1180 match expr {
1181 IRExpr::Search { field, query } => {
1182 let prop = extract_property(field)?;
1183 let q = resolve_to_string(query, params)?;
1184 lance_index::scalar::FullTextSearchQuery::new(q)
1185 .with_column(prop)
1186 .ok()
1187 }
1188 IRExpr::Fuzzy {
1189 field,
1190 query,
1191 max_edits,
1192 } => {
1193 let prop = extract_property(field)?;
1194 let q = resolve_to_string(query, params)?;
1195 let edits = max_edits
1196 .as_ref()
1197 .and_then(|e| resolve_to_int(e, params))
1198 .unwrap_or(2) as u32;
1199 lance_index::scalar::FullTextSearchQuery::new_fuzzy(q, Some(edits))
1200 .with_column(prop)
1201 .ok()
1202 }
1203 IRExpr::MatchText { field, query } => {
1204 let prop = extract_property(field)?;
1206 let q = resolve_to_string(query, params)?;
1207 lance_index::scalar::FullTextSearchQuery::new(q)
1208 .with_column(prop)
1209 .ok()
1210 }
1211 _ => None,
1212 }
1213}
1214
1215fn extract_property(expr: &IRExpr) -> Option<String> {
1217 match expr {
1218 IRExpr::PropAccess { property, .. } => Some(property.clone()),
1219 _ => None,
1220 }
1221}
1222
1223fn resolve_to_string(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1225 match expr {
1226 IRExpr::Literal(Literal::String(s)) => Some(s.clone()),
1227 IRExpr::Param(name) => match params.get(name)? {
1228 Literal::String(s) => Some(s.clone()),
1229 _ => None,
1230 },
1231 _ => None,
1232 }
1233}
1234
1235fn resolve_to_int(expr: &IRExpr, params: &ParamMap) -> Option<i64> {
1237 match expr {
1238 IRExpr::Literal(Literal::Integer(n)) => Some(*n),
1239 IRExpr::Param(name) => match params.get(name)? {
1240 Literal::Integer(n) => Some(*n),
1241 _ => None,
1242 },
1243 _ => None,
1244 }
1245}
1246
1247fn ir_expr_to_sql(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1248 match expr {
1249 IRExpr::PropAccess { property, .. } => Some(property.clone()),
1250 IRExpr::Literal(lit) => Some(literal_to_sql(lit)),
1251 IRExpr::Param(name) => params.get(name).map(literal_to_sql),
1252 _ => None,
1253 }
1254}
1255
1256pub(super) fn literal_to_sql(lit: &Literal) -> String {
1257 match lit {
1258 Literal::Null => "NULL".to_string(),
1259 Literal::String(s) => format!("'{}'", s.replace('\'', "''")),
1260 Literal::Integer(n) => n.to_string(),
1261 Literal::Float(f) => f.to_string(),
1262 Literal::Bool(b) => b.to_string(),
1263 Literal::Date(s) => format!("'{}'", s.replace('\'', "''")),
1264 Literal::DateTime(s) => format!("'{}'", s.replace('\'', "''")),
1265 Literal::List(_) => "NULL".to_string(), }
1267}
1268
1269fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
1270 let fields: Vec<Field> = batch.schema().fields().iter().map(|f| {
1271 Field::new(format!("{}.{}", variable, f.name()), f.data_type().clone(), f.is_nullable())
1272 }).collect();
1273 let schema = Arc::new(Schema::new(fields));
1274 RecordBatch::try_new(schema, batch.columns().to_vec()).map_err(|e| OmniError::Lance(e.to_string()))
1275}
1276
1277fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1278 let n = left.num_rows();
1279 let m = right.num_rows();
1280 if n == 0 || m == 0 {
1281 let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
1282 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1283 return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
1284 }
1285 let left_indices: Vec<u32> = (0..n as u32).flat_map(|i| std::iter::repeat(i).take(m)).collect();
1286 let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
1287 let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
1288 let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
1289 hconcat_batches(&left_expanded, &right_expanded)
1290}
1291
1292fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1293 let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
1294 if cfg!(debug_assertions) {
1295 let left_schema = left.schema();
1296 let left_names: HashSet<&str> = left_schema.fields().iter().map(|f| f.name().as_str()).collect();
1297 let right_schema = right.schema();
1298 for f in right_schema.fields() {
1299 debug_assert!(!left_names.contains(f.name().as_str()), "hconcat_batches: duplicate column '{}'", f.name());
1300 }
1301 }
1302 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1303 let mut columns: Vec<ArrayRef> = left.columns().to_vec();
1304 columns.extend(right.columns().to_vec());
1305 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).map_err(|e| OmniError::Lance(e.to_string()))
1306}
1307
1308fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
1309 let columns: Vec<ArrayRef> = batch.columns().iter()
1310 .map(|col| arrow_select::take::take(col.as_ref(), indices, None))
1311 .collect::<std::result::Result<Vec<_>, _>>()
1312 .map_err(|e| OmniError::Lance(e.to_string()))?;
1313 RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
1314}