1use super::*;
2
3use super::projection::{apply_filter, apply_ordering, project_return};
4
5impl Omnigraph {
6 pub async fn query(
8 &self,
9 target: impl Into<ReadTarget>,
10 query_source: &str,
11 query_name: &str,
12 params: &ParamMap,
13 ) -> Result<QueryResult> {
14 self.ensure_schema_state_valid().await?;
15 let resolved = self.resolved_target(target).await?;
16 let catalog = self.catalog();
17
18 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
19 .map_err(|e| OmniError::manifest(e.to_string()))?;
20 let type_ctx = typecheck_query(&catalog, &query_decl)?;
21 let ir = lower_query(&catalog, &query_decl, &type_ctx)?;
22
23 let needs_graph = ir
24 .pipeline
25 .iter()
26 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
27 let graph_index = if needs_graph {
28 Some(self.graph_index_for_resolved(&resolved).await?)
29 } else {
30 None
31 };
32
33 execute_query(
34 &ir,
35 params,
36 &resolved.snapshot,
37 graph_index.as_deref(),
38 &catalog,
39 )
40 .await
41 }
42
43 pub async fn run_query_at(
48 &self,
49 version: u64,
50 query_source: &str,
51 query_name: &str,
52 params: &ParamMap,
53 ) -> Result<QueryResult> {
54 self.ensure_schema_state_valid().await?;
55 let snapshot = self.snapshot_at_version(version).await?;
56 let catalog = self.catalog();
57
58 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
59 .map_err(|e| OmniError::manifest(e.to_string()))?;
60 let type_ctx = typecheck_query(&catalog, &query_decl)?;
61 let ir = lower_query(&catalog, &query_decl, &type_ctx)?;
62
63 let needs_graph = ir
64 .pipeline
65 .iter()
66 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
67 let graph_index = if needs_graph {
68 let edge_types = catalog
69 .edge_types
70 .iter()
71 .map(|(name, et)| (name.clone(), (et.from_type.clone(), et.to_type.clone())))
72 .collect();
73 Some(Arc::new(GraphIndex::build(&snapshot, &edge_types).await?))
74 } else {
75 None
76 };
77
78 execute_query(
79 &ir,
80 params,
81 &snapshot,
82 graph_index.as_deref(),
83 &catalog,
84 )
85 .await
86 }
87}
88
/// How (if at all) the query's ordering is served by a search index instead
/// of an ordinary sort. Built by `extract_search_mode`; consumed by the
/// scanner configuration in `execute_node_scan`. At most one field is set.
#[derive(Debug, Default)]
struct SearchMode {
    /// ANN ordering: (variable, vector property, query vector, top-k).
    nearest: Option<(String, String, Vec<f32>, usize)>,
    /// BM25 full-text ordering: (variable, text property, query text).
    bm25: Option<(String, String, String)>,
    /// Reciprocal-rank fusion of two sub-searches.
    rrf: Option<RrfMode>,
}
101
/// Parameters for reciprocal-rank fusion: two ranked sub-searches whose rank
/// lists are combined with score `1 / (k + rank + 1)` (see
/// `execute_rrf_query`) and truncated to `limit` rows.
#[derive(Debug)]
struct RrfMode {
    /// First ranked sub-search (nearest or bm25).
    primary: Box<SearchMode>,
    /// Second ranked sub-search (nearest or bm25).
    secondary: Box<SearchMode>,
    /// RRF smoothing constant; defaults to 60 when not given in the query.
    k: u32,
    /// Number of fused rows to keep.
    limit: usize,
}
109
/// Inspect the first `order by` expression and, when it is a search
/// expression (`nearest`, `bm25`, or `rrf`), resolve it into a concrete
/// `SearchMode` that the node scan can push down to the storage layer.
///
/// Returns `SearchMode::default()` (no search) when there is no ordering or
/// the ordering is an ordinary expression. Only `ir.order_by[0]` is
/// considered here.
async fn extract_search_mode(
    ir: &QueryIR,
    params: &ParamMap,
    catalog: &Catalog,
) -> Result<SearchMode> {
    if ir.order_by.is_empty() {
        return Ok(SearchMode::default());
    }
    let ordering = &ir.order_by[0];
    match &ordering.expr {
        IRExpr::Nearest {
            variable,
            property,
            query,
        } => {
            // May call out to an embedding service when the query is a string.
            let vec =
                resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
            // The query limit doubles as the ANN top-k, so it is mandatory.
            let k = ir.limit.ok_or_else(|| {
                OmniError::manifest("nearest() ordering requires a limit clause".to_string())
            })? as usize;
            Ok(SearchMode {
                nearest: Some((variable.clone(), property.clone(), vec, k)),
                ..Default::default()
            })
        }
        IRExpr::Bm25 { field, query } => {
            // The field must be `$var.prop` so we know which scan to attach to.
            let var = match field.as_ref() {
                IRExpr::PropAccess { variable, .. } => variable.clone(),
                _ => {
                    return Err(OmniError::manifest(
                        "bm25 field must be a property access".to_string(),
                    ));
                }
            };
            let prop = extract_property(field).ok_or_else(|| {
                OmniError::manifest("bm25 field must be a property access".to_string())
            })?;
            let text = resolve_to_string(query, params).ok_or_else(|| {
                OmniError::manifest("bm25 query must resolve to a string".to_string())
            })?;
            Ok(SearchMode {
                bm25: Some((var, prop, text)),
                ..Default::default()
            })
        }
        IRExpr::Rrf {
            primary,
            secondary,
            k,
        } => {
            let limit = ir.limit.ok_or_else(|| {
                OmniError::manifest("rrf() ordering requires a limit clause".to_string())
            })? as usize;
            // 60 is the conventional default for the RRF smoothing constant.
            let k_val = k
                .as_ref()
                .and_then(|e| resolve_to_int(e, params))
                .unwrap_or(60) as u32;

            let primary_mode =
                extract_sub_search_mode(ir, primary, params, catalog, ir.limit).await?;
            let secondary_mode =
                extract_sub_search_mode(ir, secondary, params, catalog, ir.limit).await?;

            Ok(SearchMode {
                rrf: Some(RrfMode {
                    primary: Box::new(primary_mode),
                    secondary: Box::new(secondary_mode),
                    k: k_val,
                    limit,
                }),
                ..Default::default()
            })
        }
        _ => Ok(SearchMode::default()),
    }
}
187
/// Resolve one operand of an `rrf(...)` ordering into a standalone
/// `SearchMode`. Unlike the top-level extractor, a missing query limit is
/// tolerated: the nearest top-k falls back to 100. `rrf` does not nest, so
/// any non-search expression yields the empty mode.
async fn extract_sub_search_mode(
    ir: &QueryIR,
    expr: &IRExpr,
    params: &ParamMap,
    catalog: &Catalog,
    limit: Option<u64>,
) -> Result<SearchMode> {
    match expr {
        IRExpr::Nearest {
            variable,
            property,
            query,
        } => {
            let vec =
                resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
            // Default top-k of 100 when the query carries no limit.
            let k = limit.unwrap_or(100) as usize;
            Ok(SearchMode {
                nearest: Some((variable.clone(), property.clone(), vec, k)),
                ..Default::default()
            })
        }
        IRExpr::Bm25 { field, query } => {
            // The field must be `$var.prop` so we know which scan to attach to.
            let var = match field.as_ref() {
                IRExpr::PropAccess { variable, .. } => variable.clone(),
                _ => {
                    return Err(OmniError::manifest(
                        "bm25 field must be a property access".to_string(),
                    ));
                }
            };
            let prop = extract_property(field).ok_or_else(|| {
                OmniError::manifest("bm25 field must be a property access".to_string())
            })?;
            let text = resolve_to_string(query, params).ok_or_else(|| {
                OmniError::manifest("bm25 query must resolve to a string".to_string())
            })?;
            Ok(SearchMode {
                bm25: Some((var, prop, text)),
                ..Default::default()
            })
        }
        _ => Ok(SearchMode::default()),
    }
}
233
234async fn resolve_nearest_query_vec(
236 ir: &QueryIR,
237 catalog: &Catalog,
238 variable: &str,
239 property: &str,
240 expr: &IRExpr,
241 params: &ParamMap,
242) -> Result<Vec<f32>> {
243 let lit = resolve_literal_or_param(expr, params)?;
244 match lit {
245 Literal::List(_) => literal_to_f32_vec(&lit),
246 Literal::String(text) => {
247 let expected_dim = nearest_property_dimension(ir, catalog, variable, property)?;
248 EmbeddingClient::from_env()?
249 .embed_query_text(&text, expected_dim)
250 .await
251 }
252 _ => Err(OmniError::manifest(
253 "nearest query must be a string or list of floats".to_string(),
254 )),
255 }
256}
257
258fn resolve_literal_or_param(expr: &IRExpr, params: &ParamMap) -> Result<Literal> {
259 Ok(match expr {
260 IRExpr::Literal(lit) => lit.clone(),
261 IRExpr::Param(name) => params
262 .get(name)
263 .cloned()
264 .ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?,
265 _ => {
266 return Err(OmniError::manifest(
267 "nearest query must be a literal or parameter".to_string(),
268 ));
269 }
270 })
271}
272
273fn literal_to_f32_vec(lit: &Literal) -> Result<Vec<f32>> {
275 match lit {
276 Literal::List(items) => items
277 .iter()
278 .map(|item| match item {
279 Literal::Float(f) => Ok(*f as f32),
280 Literal::Integer(n) => Ok(*n as f32),
281 _ => Err(OmniError::manifest(
282 "vector elements must be numeric".to_string(),
283 )),
284 })
285 .collect(),
286 _ => Err(OmniError::manifest(
287 "nearest query must be a list of floats".to_string(),
288 )),
289 }
290}
291
/// Look up the declared dimension of the vector property a `nearest()`
/// ordering targets, so a text query can be embedded at the right size.
///
/// Errors use `manifest_internal`: the typechecker is presumed to have
/// already validated the binding and property, so failures here indicate an
/// internal inconsistency rather than a user error.
fn nearest_property_dimension(
    ir: &QueryIR,
    catalog: &Catalog,
    variable: &str,
    property: &str,
) -> Result<usize> {
    // Find which node type the variable is bound to in the lowered pipeline.
    let type_name = resolve_binding_type_name(&ir.pipeline, variable).ok_or_else(|| {
        OmniError::manifest_internal(format!(
            "nearest() variable '${}' is not bound to a node type in the lowered pipeline",
            variable
        ))
    })?;
    let node_type = catalog.node_types.get(type_name).ok_or_else(|| {
        OmniError::manifest_internal(format!(
            "nearest() binding '${}' resolved unknown node type '{}'",
            variable, type_name
        ))
    })?;
    let prop = node_type.properties.get(property).ok_or_else(|| {
        OmniError::manifest_internal(format!(
            "nearest() property '{}.{}' is missing from the catalog",
            type_name, property
        ))
    })?;
    // Only a non-list vector column has a usable ANN dimension.
    match prop.scalar {
        ScalarType::Vector(dim) if !prop.list => Ok(dim as usize),
        _ => Err(OmniError::manifest_internal(format!(
            "nearest() property '{}.{}' is not a scalar vector",
            type_name, property
        ))),
    }
}
324
325fn resolve_binding_type_name<'a>(pipeline: &'a [IROp], variable: &str) -> Option<&'a str> {
326 for op in pipeline {
327 match op {
328 IROp::NodeScan {
329 variable: bound_var,
330 type_name,
331 ..
332 } if bound_var == variable => return Some(type_name.as_str()),
333 IROp::Expand {
334 dst_var, dst_type, ..
335 } if dst_var == variable => return Some(dst_type.as_str()),
336 IROp::AntiJoin { inner, .. } => {
337 if let Some(type_name) = resolve_binding_type_name(inner, variable) {
338 return Some(type_name);
339 }
340 }
341 _ => {}
342 }
343 }
344 None
345}
346
/// Execute a lowered query against a snapshot and return the projected rows.
///
/// Flow: resolve any search-driven ordering, run the pipeline into a single
/// "wide" batch of prefixed variable columns, project the return
/// expressions, apply non-search ordering, then the limit.
pub async fn execute_query(
    ir: &QueryIR,
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
) -> Result<QueryResult> {
    let search_mode = extract_search_mode(ir, params, catalog).await?;

    // RRF runs the pipeline twice and fuses ranks, so it has its own path.
    if let Some(ref rrf) = search_mode.rrf {
        return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
    }

    let mut wide: Option<RecordBatch> = None;
    execute_pipeline(&ir.pipeline, params, snapshot, graph_index, catalog, &mut wide, &search_mode).await?;
    // A pipeline with no scans leaves `wide` empty; fall back to a 0-column batch.
    let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));

    let has_aggregates = ir.return_exprs.iter().any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
    let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;

    // Search orderings (nearest/bm25) already come back ordered by the scan.
    if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
        result_batch = if has_aggregates {
            // Aggregates collapse rows, so sort keys must come from the
            // projected batch itself, not the pre-aggregation wide batch.
            apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)?
        } else {
            // Sort keys may reference pre-projection (wide) columns.
            apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
        };
    }

    if let Some(limit) = ir.limit {
        let len = result_batch.num_rows().min(limit as usize);
        result_batch = result_batch.slice(0, len);
    }

    Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
387
388fn is_search_ordered(search_mode: &SearchMode) -> bool {
390 search_mode.nearest.is_some() || search_mode.bm25.is_some()
391}
392
/// Execute an `rrf(...)`-ordered query.
///
/// Runs the whole pipeline twice — once under the primary search mode, once
/// under the secondary — fuses the two ranked id lists with reciprocal-rank
/// scores `1 / (k + rank + 1)`, keeps the top `rrf.limit` ids, rebuilds a
/// wide batch in fused order, and projects the return expressions over it.
async fn execute_rrf_query(
    ir: &QueryIR,
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
    rrf: &RrfMode,
) -> Result<QueryResult> {
    // Pass 1: pipeline driven by the primary sub-search.
    let mut primary_wide: Option<RecordBatch> = None;
    execute_pipeline(
        &ir.pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut primary_wide,
        &rrf.primary,
    )
    .await?;

    // Pass 2: same pipeline driven by the secondary sub-search.
    let mut secondary_wide: Option<RecordBatch> = None;
    execute_pipeline(
        &ir.pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut secondary_wide,
        &rrf.secondary,
    )
    .await?;

    // The variable bound by the primary sub-search names the `<var>.id`
    // column used as the join key between the two result sets.
    let primary_var = rrf
        .primary
        .nearest
        .as_ref()
        .map(|(v, ..)| v.as_str())
        .or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
        .ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;

    let primary_batch = primary_wide.as_ref().ok_or_else(|| {
        OmniError::manifest(format!(
            "rrf primary variable '{}' not in bindings",
            primary_var
        ))
    })?;
    // NOTE(review): this message reports the secondary batch but formats
    // `primary_var` — looks like a copy-paste; confirm the intended variable.
    let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
        OmniError::manifest(format!(
            "rrf secondary variable '{}' not in bindings",
            primary_var
        ))
    })?;

    let id_col_name = format!("{}.id", primary_var);
    let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
    let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;

    // Rank of an id = index of its first occurrence in each result list.
    let mut primary_rank: HashMap<String, usize> = HashMap::new();
    for (i, id) in primary_ids.iter().enumerate() {
        primary_rank.entry(id.clone()).or_insert(i);
    }
    let mut secondary_rank: HashMap<String, usize> = HashMap::new();
    for (i, id) in secondary_ids.iter().enumerate() {
        secondary_rank.entry(id.clone()).or_insert(i);
    }

    // Candidate set: all primary ids, then secondary ids not already seen.
    let mut all_ids: Vec<String> = primary_ids.clone();
    for id in &secondary_ids {
        if !primary_rank.contains_key(id) {
            all_ids.push(id.clone());
        }
    }

    // RRF score: sum of 1/(k + rank + 1) over the lists the id appears in.
    let k = rrf.k as f64;
    let mut scored: Vec<(String, f64)> = all_ids
        .iter()
        .map(|id| {
            let p = primary_rank
                .get(id)
                .map(|&r| 1.0 / (k + r as f64 + 1.0))
                .unwrap_or(0.0);
            let s = secondary_rank
                .get(id)
                .map(|&r| 1.0 / (k + r as f64 + 1.0))
                .unwrap_or(0.0);
            (id.clone(), p + s)
        })
        .collect();
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(rrf.limit);

    let winning_ids: Vec<String> = scored.iter().map(|(id, _)| id.clone()).collect();

    // Remember one source row per id for re-materialization; primary rows
    // win because they are inserted first (`or_insert` keeps the first).
    let mut id_to_batch_row: HashMap<String, (&RecordBatch, usize)> = HashMap::new();
    for (i, id) in primary_ids.iter().enumerate() {
        id_to_batch_row
            .entry(id.clone())
            .or_insert((primary_batch, i));
    }
    for (i, id) in secondary_ids.iter().enumerate() {
        id_to_batch_row
            .entry(id.clone())
            .or_insert((secondary_batch, i));
    }

    let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;

    let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;

    Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
517
518fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
519 let col = batch
520 .column_by_name(col_name)
521 .ok_or_else(|| OmniError::manifest(format!("batch missing '{}' column for RRF", col_name)))?;
522 let ids = col
523 .as_any()
524 .downcast_ref::<StringArray>()
525 .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?;
526 Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
527}
528
529fn build_fused_batch(
530 ordered_ids: &[String],
531 id_to_batch_row: &HashMap<String, (&RecordBatch, usize)>,
532 schema: SchemaRef,
533) -> Result<RecordBatch> {
534 if ordered_ids.is_empty() {
535 return Ok(RecordBatch::new_empty(schema));
536 }
537
538 let mut row_slices: Vec<RecordBatch> = Vec::with_capacity(ordered_ids.len());
540 for id in ordered_ids {
541 if let Some(&(batch, row_idx)) = id_to_batch_row.get(id) {
542 row_slices.push(batch.slice(row_idx, 1));
543 }
544 }
545
546 if row_slices.is_empty() {
547 return Ok(RecordBatch::new_empty(schema));
548 }
549
550 let schema = row_slices[0].schema();
551 arrow_select::concat::concat_batches(&schema, &row_slices)
552 .map_err(|e| OmniError::Lance(e.to_string()))
553}
554
555fn is_search_filter(filter: &IRFilter) -> bool {
557 matches!(
558 &filter.left,
559 IRExpr::Search { .. } | IRExpr::Fuzzy { .. } | IRExpr::MatchText { .. }
560 )
561}
562
563fn search_filter_variable(filter: &IRFilter) -> Option<&str> {
565 let field = match &filter.left {
566 IRExpr::Search { field, .. } => field,
567 IRExpr::Fuzzy { field, .. } => field,
568 IRExpr::MatchText { field, .. } => field,
569 _ => return None,
570 };
571 match field.as_ref() {
572 IRExpr::PropAccess { variable, .. } => Some(variable.as_str()),
573 _ => None,
574 }
575}
576
/// Run the lowered pipeline ops in order, threading the growing "wide"
/// batch (all bound variables' columns, name-prefixed) through `wide`.
///
/// Returns a boxed future because `execute_anti_join` recurses back into
/// this function for its inner pipeline.
fn execute_pipeline<'a>(
    pipeline: &'a [IROp],
    params: &'a ParamMap,
    snapshot: &'a Snapshot,
    graph_index: Option<&'a GraphIndex>,
    catalog: &'a Catalog,
    wide: &'a mut Option<RecordBatch>,
    search_mode: &'a SearchMode,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
    Box::pin(async move {
        // Full-text search filters can only run inside the scanner, so hoist
        // them out of the op list and attach them to the node scan of the
        // variable they target.
        let mut hoisted_search_filters: HashMap<String, Vec<IRFilter>> = HashMap::new();
        let mut hoisted_indices: HashSet<usize> = HashSet::new();
        for (i, op) in pipeline.iter().enumerate() {
            if let IROp::Filter(filter) = op {
                if is_search_filter(filter) {
                    if let Some(var) = search_filter_variable(filter) {
                        hoisted_search_filters
                            .entry(var.to_string())
                            .or_default()
                            .push(filter.clone());
                        hoisted_indices.insert(i);
                    }
                }
            }
        }

        for (i, op) in pipeline.iter().enumerate() {
            // Skip filters that were folded into a scan above.
            if hoisted_indices.contains(&i) {
                continue;
            }
            match op {
                IROp::NodeScan {
                    variable,
                    type_name,
                    filters,
                } => {
                    let mut all_filters: Vec<IRFilter> = filters.clone();
                    if let Some(extra) = hoisted_search_filters.get(variable) {
                        all_filters.extend(extra.iter().cloned());
                    }
                    let batch = execute_node_scan(
                        type_name,
                        variable,
                        &all_filters,
                        params,
                        snapshot,
                        catalog,
                        search_mode,
                    )
                    .await?;
                    // First scan seeds the wide batch; subsequent scans
                    // cross-join their rows against what is already bound.
                    let prefixed = prefix_batch(&batch, variable)?;
                    *wide = Some(match wide.take() {
                        None => prefixed,
                        Some(existing) => cross_join_batches(&existing, &prefixed)?,
                    });
                }
                IROp::Filter(filter) => {
                    if let Some(batch) = wide.as_mut() {
                        apply_filter(batch, filter, params)?;
                    }
                }
                IROp::Expand {
                    src_var,
                    dst_var,
                    edge_type,
                    direction,
                    dst_type,
                    min_hops,
                    max_hops,
                    dst_filters,
                } => {
                    // Traversal is impossible without the adjacency index.
                    let gi = graph_index.ok_or_else(|| {
                        OmniError::manifest("graph index required for traversal".to_string())
                    })?;
                    if let Some(batch) = wide.as_mut() {
                        execute_expand(
                            batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
                            dst_type, *min_hops, *max_hops, dst_filters, params,
                        )
                        .await?;
                    }
                }
                IROp::AntiJoin { outer_var, inner } => {
                    // The index is optional here; the anti-join falls back to
                    // per-row execution when it is unavailable.
                    let gi = graph_index;
                    if let Some(batch) = wide.as_mut() {
                        execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
                            .await?;
                    }
                }
            }
        }
        Ok(())
    })
}
674
/// Expand the wide batch along an edge type: for every bound `src_var` row,
/// walk the adjacency index `min_hops..=max_hops` hops and pair the row with
/// each reachable destination node, hydrating the destinations' columns
/// under the `dst_var.` prefix.
async fn execute_expand(
    wide: &mut RecordBatch,
    graph_index: &GraphIndex,
    snapshot: &Snapshot,
    catalog: &Catalog,
    src_var: &str,
    dst_var: &str,
    edge_type: &str,
    direction: Direction,
    dst_type: &str,
    min_hops: u32,
    max_hops: Option<u32>,
    dst_filters: &[IRFilter],
    params: &ParamMap,
) -> Result<()> {
    let src_id_col_name = format!("{}.id", src_var);
    let src_ids = wide
        .column_by_name(&src_id_col_name)
        .ok_or_else(|| OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name)))?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
        .clone();

    let edge_def = catalog
        .edge_types
        .get(edge_type)
        .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_type)))?;

    // An `In` traversal walks the edge backwards, so src/dst types swap.
    let (src_type_name, dst_type_name) = match direction {
        Direction::Out => (&edge_def.from_type, &edge_def.to_type),
        Direction::In => (&edge_def.to_type, &edge_def.from_type),
    };

    let src_type_idx = graph_index
        .type_index(src_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", src_type_name)))?;
    let dst_type_idx = graph_index
        .type_index(dst_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", dst_type_name)))?;

    // CSR serves outgoing neighbors, CSC incoming ones.
    let adj = match direction {
        Direction::Out => graph_index.csr(edge_type),
        Direction::In => graph_index.csc(edge_type),
    }
    .ok_or_else(|| OmniError::manifest(format!("no adjacency index for edge '{}'", edge_type)))?;

    // Unbounded expansion defaults to max(min_hops, 1) hops.
    let max = max_hops.unwrap_or(min_hops.max(1));

    let same_type = src_type_name == dst_type_name;

    // Parallel vectors of (source row index, destination dense id) pairs.
    let mut src_indices: Vec<u32> = Vec::new();
    let mut dst_dense_list: Vec<u32> = Vec::new();
    for i in 0..src_ids.len() {
        let src_id = src_ids.value(i);
        let Some(src_dense) = src_type_idx.to_dense(src_id) else {
            // Source id unknown to the index: expands to nothing.
            continue;
        };

        let mut frontier: Vec<u32> = vec![src_dense];
        let mut visited: HashSet<u32> = HashSet::new();
        let mut seen_dst_dense: HashSet<u32> = HashSet::new();
        // The cycle guard is only meaningful when src and dst share a dense
        // id space (same node type); seed it with the start node.
        if same_type {
            visited.insert(src_dense);
        }

        for hop in 1..=max {
            let mut next_frontier = Vec::new();
            for &node in &frontier {
                for &neighbor in adj.neighbors(node) {
                    // For same-type traversals, `visited.insert` both dedupes
                    // and blocks revisiting nodes (cycle protection).
                    if !same_type || visited.insert(neighbor) {
                        next_frontier.push(neighbor);
                        // Emit each destination at most once per source row,
                        // and only after reaching the minimum hop count.
                        if hop >= min_hops && seen_dst_dense.insert(neighbor) {
                            src_indices.push(i as u32);
                            dst_dense_list.push(neighbor);
                        }
                    }
                }
            }
            frontier = next_frontier;
            if frontier.is_empty() {
                break;
            }
        }
    }

    // Filters expressible as pushdown SQL travel with the hydration scan;
    // the rest run against the joined batch afterwards.
    let pushdown_sql = build_lance_filter(dst_filters, params);
    let non_pushable: Vec<&IRFilter> = dst_filters
        .iter()
        .filter(|f| ir_filter_to_sql(f, params).is_none())
        .collect();

    // Deduplicate destination ids so each node is hydrated only once.
    let mut unique_dst_list: Vec<String> = Vec::new();
    {
        let mut seen: HashSet<u32> = HashSet::with_capacity(dst_dense_list.len());
        for &d in &dst_dense_list {
            if seen.insert(d) {
                if let Some(id) = dst_type_idx.to_id(d) {
                    unique_dst_list.push(id.to_string());
                }
            }
        }
    }
    let dst_batch = hydrate_nodes(
        snapshot,
        catalog,
        dst_type,
        &unique_dst_list,
        pushdown_sql.as_deref(),
    )
    .await?;

    // Map each dense id back to its row in the hydrated batch. Nodes removed
    // by the pushdown filter stay `None` and drop out of the join below.
    let dst_batch_id_col = dst_batch
        .column_by_name("id")
        .ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?;
    let mut dense_to_row: Vec<Option<u32>> = vec![None; dst_type_idx.len()];
    for row in 0..dst_batch_id_col.len() {
        let id_str = dst_batch_id_col.value(row);
        if let Some(dense) = dst_type_idx.to_dense(id_str) {
            dense_to_row[dense as usize] = Some(row as u32);
        }
    }

    // Keep only pairs whose destination survived hydration/filtering.
    let mut final_src_indices: Vec<u32> = Vec::new();
    let mut dst_indices: Vec<u32> = Vec::new();
    for (src_idx, dst_dense) in src_indices.iter().zip(dst_dense_list.iter()) {
        if let Some(dst_row) = dense_to_row[*dst_dense as usize] {
            final_src_indices.push(*src_idx);
            dst_indices.push(dst_row);
        }
    }

    // Materialize the join: repeat source rows, align destination rows, and
    // concatenate the two column sets side by side.
    let src_take = UInt32Array::from(final_src_indices);
    let dst_take = UInt32Array::from(dst_indices);
    let expanded_wide = take_batch(wide, &src_take)?;
    let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
    let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
    *wide = hconcat_batches(&expanded_wide, &aligned_dst)?;

    // Apply destination filters that could not be pushed down.
    for f in &non_pushable {
        apply_filter(wide, f, params)?;
    }

    Ok(())
}
839
/// Fetch full rows for the given node `ids` of `type_name` from the
/// snapshot, optionally AND-ing `extra_filter_sql` into the pushed-down
/// predicate. Blob properties are excluded from the scan and re-attached as
/// all-null columns. Returns an empty batch when `ids` is empty.
async fn hydrate_nodes(
    snapshot: &Snapshot,
    catalog: &Catalog,
    type_name: &str,
    ids: &[String],
    extra_filter_sql: Option<&str>,
) -> Result<RecordBatch> {
    let node_type = catalog
        .node_types
        .get(type_name)
        .ok_or_else(|| OmniError::manifest(format!("unknown node type '{}'", type_name)))?;

    if ids.is_empty() {
        return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
    }

    let table_key = format!("node:{}", type_name);
    let ds = snapshot.open(&table_key).await?;

    // Build `id IN (...)`, escaping single quotes SQL-style ('' for ').
    let escaped: Vec<String> = ids
        .iter()
        .map(|id| format!("'{}'", id.replace('\'', "''")))
        .collect();
    let mut filter_sql = format!("id IN ({})", escaped.join(", "));
    if let Some(extra) = extra_filter_sql {
        filter_sql = format!("({}) AND ({})", filter_sql, extra);
    }
    // Project blob columns out of the scan; they are re-added as nulls below.
    let has_blobs = !node_type.blob_properties.is_empty();
    let non_blob_cols: Vec<&str> = node_type
        .arrow_schema
        .fields()
        .iter()
        .filter(|f| !node_type.blob_properties.contains(f.name()))
        .map(|f| f.name().as_str())
        .collect();
    let projection = has_blobs.then_some(non_blob_cols.as_slice());
    let batches = crate::table_store::TableStore::scan_stream(
        &ds,
        projection,
        Some(&filter_sql),
        None,
        false,
    )
    .await?
    .try_collect::<Vec<RecordBatch>>()
    .await
    .map_err(|e| OmniError::Lance(e.to_string()))?;

    let scan_result = if batches.is_empty() {
        return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
    } else if batches.len() == 1 {
        batches.into_iter().next().unwrap()
    } else {
        let schema = batches[0].schema();
        arrow_select::concat::concat_batches(&schema, &batches)
            .map_err(|e| OmniError::Lance(e.to_string()))?
    };

    if has_blobs {
        return add_null_blob_columns(&scan_result, node_type);
    }
    Ok(scan_result)
}
909
/// Fast path for `AntiJoin`: when the inner pipeline is exactly one
/// filter-free `Expand` rooted at the outer variable, the per-row question
/// "does any match exist?" reduces to "does this node have any neighbor?",
/// answered directly from the adjacency index without scanning tables.
///
/// Returns `None` whenever the shape does not qualify; the caller then falls
/// back to row-at-a-time execution.
fn try_bulk_anti_join_mask(
    wide: &RecordBatch,
    inner_pipeline: &[IROp],
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
    outer_var: &str,
) -> Option<BooleanArray> {
    // Shape check: a single op...
    if inner_pipeline.len() != 1 {
        return None;
    }
    // ...which must be an Expand...
    let IROp::Expand {
        src_var,
        edge_type,
        direction,
        dst_filters,
        ..
    } = &inner_pipeline[0]
    else {
        return None;
    };
    // ...starting from the anti-joined variable...
    if src_var != outer_var {
        return None;
    }
    // ...with no destination filters (filters would require hydration).
    if !dst_filters.is_empty() {
        return None;
    }
    let gi = graph_index?;
    let edge_def = catalog.edge_types.get(edge_type.as_str())?;

    let src_type_name = match direction {
        Direction::Out => &edge_def.from_type,
        Direction::In => &edge_def.to_type,
    };
    let adj = match direction {
        Direction::Out => gi.csr(edge_type),
        Direction::In => gi.csc(edge_type),
    }?;
    let type_idx = gi.type_index(src_type_name)?;

    let id_col_name = format!("{}.id", outer_var);
    let outer_ids = wide
        .column_by_name(&id_col_name)?
        .as_any()
        .downcast_ref::<StringArray>()?;

    // Keep a row iff its node has no neighbors along this edge.
    let keep_mask: Vec<bool> = (0..outer_ids.len())
        .map(|i| {
            let id = outer_ids.value(i);
            match type_idx.to_dense(id) {
                Some(dense) => !adj.has_neighbors(dense),
                // Unknown to the index: it cannot have edges, so keep it.
                None => true, }
        })
        .collect();

    Some(BooleanArray::from(keep_mask))
}
971
/// Remove wide-batch rows for which the inner pipeline produces any match
/// (anti-join / "not exists" semantics for `outer_var`).
///
/// Tries the adjacency-index fast path first; otherwise re-runs the inner
/// pipeline once per row, seeded with that single row, and drops rows that
/// yield any result.
async fn execute_anti_join(
    wide: &mut RecordBatch,
    inner_pipeline: &[IROp],
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
    outer_var: &str,
) -> Result<()> {
    if let Some(mask) =
        try_bulk_anti_join_mask(wide, inner_pipeline, graph_index, catalog, outer_var)
    {
        *wide = arrow_select::filter::filter_record_batch(wide, &mask)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        return Ok(());
    }

    // Slow path: one inner-pipeline execution per outer row.
    let num_rows = wide.num_rows();
    let mut keep_mask = vec![true; num_rows];

    for i in 0..num_rows {
        let single_row = wide.slice(i, 1);
        let mut inner_wide: Option<RecordBatch> = Some(single_row);

        // Search modes never apply inside an anti-join subquery.
        let no_search = SearchMode::default();
        execute_pipeline(
            inner_pipeline,
            params,
            snapshot,
            graph_index,
            catalog,
            &mut inner_wide,
            &no_search,
        )
        .await?;

        let has_match = inner_wide
            .as_ref()
            .map(|batch| batch.num_rows() > 0)
            .unwrap_or(false);

        if has_match {
            keep_mask[i] = false;
        }
    }

    let mask = BooleanArray::from(keep_mask);
    *wide = arrow_select::filter::filter_record_batch(wide, &mask)
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    Ok(())
}
1026
1027async fn execute_node_scan(
1029 type_name: &str,
1030 variable: &str,
1031 filters: &[IRFilter],
1032 params: &ParamMap,
1033 snapshot: &Snapshot,
1034 catalog: &Catalog,
1035 search_mode: &SearchMode,
1036) -> Result<RecordBatch> {
1037 let table_key = format!("node:{}", type_name);
1038 let ds = snapshot.open(&table_key).await?;
1039
1040 let filter_sql = build_lance_filter(filters, params);
1042
1043 let node_type = &catalog.node_types[type_name];
1047 let has_blobs = !node_type.blob_properties.is_empty();
1048 let non_blob_cols: Vec<&str> = node_type
1049 .arrow_schema
1050 .fields()
1051 .iter()
1052 .filter(|f| !node_type.blob_properties.contains(f.name()))
1053 .map(|f| f.name().as_str())
1054 .collect();
1055 let projection = has_blobs.then_some(non_blob_cols.as_slice());
1056 let batches = crate::table_store::TableStore::scan_stream_with(
1057 &ds,
1058 projection,
1059 filter_sql.as_deref(),
1060 None,
1061 false,
1062 |scanner| {
1063 for filter in filters {
1065 if is_search_filter(filter) {
1066 if let Some(fts_query) = build_fts_query(&filter.left, params) {
1067 scanner.full_text_search(fts_query).map_err(|e| {
1068 OmniError::Lance(format!("full_text_search filter: {}", e))
1069 })?;
1070 }
1071 }
1072 }
1073
1074 if let Some((ref var, ref prop, ref vec, k)) = search_mode.nearest {
1076 if var == variable {
1077 let query_arr = Float32Array::from(vec.clone());
1078 scanner
1079 .nearest(prop, &query_arr, k)
1080 .map_err(|e| OmniError::Lance(format!("nearest: {}", e)))?;
1081 }
1082 }
1083
1084 if let Some((ref var, ref prop, ref text)) = search_mode.bm25 {
1086 if var == variable {
1087 let fts_query = lance_index::scalar::FullTextSearchQuery::new(text.clone())
1088 .with_column(prop.clone())
1089 .map_err(|e| OmniError::Lance(format!("fts with_column: {}", e)))?;
1090 scanner
1091 .full_text_search(fts_query)
1092 .map_err(|e| OmniError::Lance(format!("full_text_search: {}", e)))?;
1093 }
1094 }
1095 Ok(())
1096 },
1097 )
1098 .await?
1099 .try_collect::<Vec<RecordBatch>>()
1100 .await
1101 .map_err(|e| OmniError::Lance(e.to_string()))?;
1102
1103 let scan_result = if batches.is_empty() {
1104 RecordBatch::new_empty(batches.first().map(|b| b.schema()).unwrap_or_else(|| {
1105 let fields: Vec<_> = node_type
1107 .arrow_schema
1108 .fields()
1109 .iter()
1110 .filter(|f| !node_type.blob_properties.contains(f.name()))
1111 .map(|f| f.as_ref().clone())
1112 .collect();
1113 Arc::new(Schema::new(fields))
1114 }))
1115 } else if batches.len() == 1 {
1116 batches.into_iter().next().unwrap()
1117 } else {
1118 let schema = batches[0].schema();
1119 arrow_select::concat::concat_batches(&schema, &batches)
1120 .map_err(|e| OmniError::Lance(e.to_string()))?
1121 };
1122
1123 if has_blobs {
1125 return add_null_blob_columns(&scan_result, node_type);
1126 }
1127 Ok(scan_result)
1128}
1129
1130fn add_null_blob_columns(
1133 batch: &RecordBatch,
1134 node_type: &omnigraph_compiler::catalog::NodeType,
1135) -> Result<RecordBatch> {
1136 let num_rows = batch.num_rows();
1137 let mut fields = Vec::with_capacity(node_type.arrow_schema.fields().len());
1138 let mut columns: Vec<ArrayRef> = Vec::with_capacity(node_type.arrow_schema.fields().len());
1139
1140 for field in node_type.arrow_schema.fields() {
1141 if node_type.blob_properties.contains(field.name()) {
1142 fields.push(Field::new(field.name(), DataType::Utf8, true));
1143 columns.push(Arc::new(StringArray::from(vec![None::<&str>; num_rows])));
1144 } else if let Some(col) = batch.column_by_name(field.name()) {
1145 let batch_schema = batch.schema();
1146 let batch_field = batch_schema
1147 .field_with_name(field.name())
1148 .map_err(|e| OmniError::Lance(e.to_string()))?;
1149 fields.push(batch_field.clone());
1150 columns.push(col.clone());
1151 }
1152 }
1153
1154 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
1155 .map_err(|e| OmniError::Lance(e.to_string()))
1156}
1157
1158fn build_lance_filter(filters: &[IRFilter], params: &ParamMap) -> Option<String> {
1160 if filters.is_empty() {
1161 return None;
1162 }
1163
1164 let parts: Vec<String> = filters
1165 .iter()
1166 .filter_map(|f| ir_filter_to_sql(f, params))
1167 .collect();
1168
1169 if parts.is_empty() {
1170 return None;
1171 }
1172
1173 Some(parts.join(" AND "))
1174}
1175
1176fn ir_filter_to_sql(filter: &IRFilter, params: &ParamMap) -> Option<String> {
1177 if is_search_filter(filter) {
1180 return None;
1181 }
1182
1183 let left = ir_expr_to_sql(&filter.left, params)?;
1184 let right = ir_expr_to_sql(&filter.right, params)?;
1185 let op = match filter.op {
1186 CompOp::Eq => "=",
1187 CompOp::Ne => "!=",
1188 CompOp::Gt => ">",
1189 CompOp::Lt => "<",
1190 CompOp::Ge => ">=",
1191 CompOp::Le => "<=",
1192 CompOp::Contains => return None, };
1194 Some(format!("{} {} {}", left, op, right))
1195}
1196
/// Translate a search-style IR expression into a Lance full-text query.
///
/// `Search` and `MatchText` both build a plain FTS query on the target
/// column; `Fuzzy` adds an edit-distance bound (default 2). Returns `None`
/// when the field is not a property access or the query text does not
/// resolve to a string.
fn build_fts_query(
    expr: &IRExpr,
    params: &ParamMap,
) -> Option<lance_index::scalar::FullTextSearchQuery> {
    match expr {
        IRExpr::Search { field, query } => {
            let prop = extract_property(field)?;
            let q = resolve_to_string(query, params)?;
            lance_index::scalar::FullTextSearchQuery::new(q)
                .with_column(prop)
                .ok()
        }
        IRExpr::Fuzzy {
            field,
            query,
            max_edits,
        } => {
            let prop = extract_property(field)?;
            let q = resolve_to_string(query, params)?;
            // Edit distance defaults to 2 when not specified in the query.
            let edits = max_edits
                .as_ref()
                .and_then(|e| resolve_to_int(e, params))
                .unwrap_or(2) as u32;
            lance_index::scalar::FullTextSearchQuery::new_fuzzy(q, Some(edits))
                .with_column(prop)
                .ok()
        }
        IRExpr::MatchText { field, query } => {
            let prop = extract_property(field)?;
            let q = resolve_to_string(query, params)?;
            lance_index::scalar::FullTextSearchQuery::new(q)
                .with_column(prop)
                .ok()
        }
        _ => None,
    }
}
1236
1237fn extract_property(expr: &IRExpr) -> Option<String> {
1239 match expr {
1240 IRExpr::PropAccess { property, .. } => Some(property.clone()),
1241 _ => None,
1242 }
1243}
1244
1245fn resolve_to_string(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1247 match expr {
1248 IRExpr::Literal(Literal::String(s)) => Some(s.clone()),
1249 IRExpr::Param(name) => match params.get(name)? {
1250 Literal::String(s) => Some(s.clone()),
1251 _ => None,
1252 },
1253 _ => None,
1254 }
1255}
1256
1257fn resolve_to_int(expr: &IRExpr, params: &ParamMap) -> Option<i64> {
1259 match expr {
1260 IRExpr::Literal(Literal::Integer(n)) => Some(*n),
1261 IRExpr::Param(name) => match params.get(name)? {
1262 Literal::Integer(n) => Some(*n),
1263 _ => None,
1264 },
1265 _ => None,
1266 }
1267}
1268
1269fn ir_expr_to_sql(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1270 match expr {
1271 IRExpr::PropAccess { property, .. } => Some(property.clone()),
1272 IRExpr::Literal(lit) => Some(literal_to_sql(lit)),
1273 IRExpr::Param(name) => params.get(name).map(literal_to_sql),
1274 _ => None,
1275 }
1276}
1277
1278pub(super) fn literal_to_sql(lit: &Literal) -> String {
1279 match lit {
1280 Literal::Null => "NULL".to_string(),
1281 Literal::String(s) => format!("'{}'", s.replace('\'', "''")),
1282 Literal::Integer(n) => n.to_string(),
1283 Literal::Float(f) => f.to_string(),
1284 Literal::Bool(b) => b.to_string(),
1285 Literal::Date(s) => format!("'{}'", s.replace('\'', "''")),
1286 Literal::DateTime(s) => format!("'{}'", s.replace('\'', "''")),
1287 Literal::List(_) => "NULL".to_string(), }
1289}
1290
1291fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
1292 let fields: Vec<Field> = batch.schema().fields().iter().map(|f| {
1293 Field::new(format!("{}.{}", variable, f.name()), f.data_type().clone(), f.is_nullable())
1294 }).collect();
1295 let schema = Arc::new(Schema::new(fields));
1296 RecordBatch::try_new(schema, batch.columns().to_vec()).map_err(|e| OmniError::Lance(e.to_string()))
1297}
1298
1299fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1300 let n = left.num_rows();
1301 let m = right.num_rows();
1302 if n == 0 || m == 0 {
1303 let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
1304 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1305 return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
1306 }
1307 let left_indices: Vec<u32> = (0..n as u32).flat_map(|i| std::iter::repeat(i).take(m)).collect();
1308 let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
1309 let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
1310 let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
1311 hconcat_batches(&left_expanded, &right_expanded)
1312}
1313
1314fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1315 let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
1316 if cfg!(debug_assertions) {
1317 let left_schema = left.schema();
1318 let left_names: HashSet<&str> = left_schema.fields().iter().map(|f| f.name().as_str()).collect();
1319 let right_schema = right.schema();
1320 for f in right_schema.fields() {
1321 debug_assert!(!left_names.contains(f.name().as_str()), "hconcat_batches: duplicate column '{}'", f.name());
1322 }
1323 }
1324 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1325 let mut columns: Vec<ArrayRef> = left.columns().to_vec();
1326 columns.extend(right.columns().to_vec());
1327 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).map_err(|e| OmniError::Lance(e.to_string()))
1328}
1329
1330fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
1331 let columns: Vec<ArrayRef> = batch.columns().iter()
1332 .map(|col| arrow_select::take::take(col.as_ref(), indices, None))
1333 .collect::<std::result::Result<Vec<_>, _>>()
1334 .map_err(|e| OmniError::Lance(e.to_string()))?;
1335 RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
1336}