1use super::*;
2
3use super::projection::{apply_filter, apply_ordering, project_return};
4
5impl Omnigraph {
6 pub async fn query(
8 &self,
9 target: impl Into<ReadTarget>,
10 query_source: &str,
11 query_name: &str,
12 params: &ParamMap,
13 ) -> Result<QueryResult> {
14 self.ensure_schema_state_valid().await?;
15 let resolved = self.resolved_target(target).await?;
16
17 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
18 .map_err(|e| OmniError::manifest(e.to_string()))?;
19 let type_ctx = typecheck_query(self.catalog(), &query_decl)?;
20 let ir = lower_query(self.catalog(), &query_decl, &type_ctx)?;
21
22 let needs_graph = ir
23 .pipeline
24 .iter()
25 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
26 let graph_index = if needs_graph {
27 Some(self.graph_index_for_resolved(&resolved).await?)
28 } else {
29 None
30 };
31
32 execute_query(
33 &ir,
34 params,
35 &resolved.snapshot,
36 graph_index.as_deref(),
37 self.catalog(),
38 )
39 .await
40 }
41
42 pub async fn run_query_at(
47 &self,
48 version: u64,
49 query_source: &str,
50 query_name: &str,
51 params: &ParamMap,
52 ) -> Result<QueryResult> {
53 self.ensure_schema_state_valid().await?;
54 let snapshot = self.snapshot_at_version(version).await?;
55
56 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
57 .map_err(|e| OmniError::manifest(e.to_string()))?;
58 let type_ctx = typecheck_query(self.catalog(), &query_decl)?;
59 let ir = lower_query(self.catalog(), &query_decl, &type_ctx)?;
60
61 let needs_graph = ir
62 .pipeline
63 .iter()
64 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
65 let graph_index = if needs_graph {
66 let edge_types = self
67 .catalog()
68 .edge_types
69 .iter()
70 .map(|(name, et)| (name.clone(), (et.from_type.clone(), et.to_type.clone())))
71 .collect();
72 Some(Arc::new(GraphIndex::build(&snapshot, &edge_types).await?))
73 } else {
74 None
75 };
76
77 execute_query(
78 &ir,
79 params,
80 &snapshot,
81 graph_index.as_deref(),
82 self.catalog(),
83 )
84 .await
85 }
86}
87
/// Which (if any) index-backed search drives a query's ordering. Extracted
/// from the first `order by` expression; all fields `None` means plain
/// (non-search) execution.
#[derive(Debug, Default)]
struct SearchMode {
    /// Vector search for `nearest()`: (variable, property, query vector, k).
    nearest: Option<(String, String, Vec<f32>, usize)>,
    /// Full-text search for `bm25()`: (variable, property, query text).
    bm25: Option<(String, String, String)>,
    /// Reciprocal-rank fusion of two sub-searches.
    rrf: Option<RrfMode>,
}
100
/// Parameters for reciprocal-rank fusion (`rrf()`) ordering.
#[derive(Debug)]
struct RrfMode {
    /// First sub-search (must be nearest or bm25; see `execute_rrf_query`).
    primary: Box<SearchMode>,
    /// Second sub-search.
    secondary: Box<SearchMode>,
    /// Smoothing constant; each list contributes 1 / (k + rank + 1).
    k: u32,
    /// Number of fused rows to keep.
    limit: usize,
}
108
/// Inspect the query's first `order by` expression and, when it is a search
/// construct (`nearest()`, `bm25()`, or `rrf()`), resolve its arguments into
/// a `SearchMode` the scanner can execute. Any other ordering (or none)
/// yields the default (non-search) mode.
async fn extract_search_mode(
    ir: &QueryIR,
    params: &ParamMap,
    catalog: &Catalog,
) -> Result<SearchMode> {
    if ir.order_by.is_empty() {
        return Ok(SearchMode::default());
    }
    // Only the first ordering expression can drive a search.
    let ordering = &ir.order_by[0];
    match &ordering.expr {
        IRExpr::Nearest {
            variable,
            property,
            query,
        } => {
            // Resolve the query argument into a concrete f32 vector
            // (embedding text via the embedding client when necessary).
            let vec =
                resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
            // The ANN candidate count k comes from the query's limit clause.
            let k = ir.limit.ok_or_else(|| {
                OmniError::manifest("nearest() ordering requires a limit clause".to_string())
            })? as usize;
            Ok(SearchMode {
                nearest: Some((variable.clone(), property.clone(), vec, k)),
                ..Default::default()
            })
        }
        IRExpr::Bm25 { field, query } => {
            // The bm25 field must be `<variable>.<property>`.
            let var = match field.as_ref() {
                IRExpr::PropAccess { variable, .. } => variable.clone(),
                _ => {
                    return Err(OmniError::manifest(
                        "bm25 field must be a property access".to_string(),
                    ));
                }
            };
            let prop = extract_property(field).ok_or_else(|| {
                OmniError::manifest("bm25 field must be a property access".to_string())
            })?;
            let text = resolve_to_string(query, params).ok_or_else(|| {
                OmniError::manifest("bm25 query must resolve to a string".to_string())
            })?;
            Ok(SearchMode {
                bm25: Some((var, prop, text)),
                ..Default::default()
            })
        }
        IRExpr::Rrf {
            primary,
            secondary,
            k,
        } => {
            let limit = ir.limit.ok_or_else(|| {
                OmniError::manifest("rrf() ordering requires a limit clause".to_string())
            })? as usize;
            // Smoothing constant; defaults to the conventional 60.
            let k_val = k
                .as_ref()
                .and_then(|e| resolve_to_int(e, params))
                .unwrap_or(60) as u32;

            // Each side of the fusion is itself a nearest/bm25 sub-search.
            let primary_mode =
                extract_sub_search_mode(ir, primary, params, catalog, ir.limit).await?;
            let secondary_mode =
                extract_sub_search_mode(ir, secondary, params, catalog, ir.limit).await?;

            Ok(SearchMode {
                rrf: Some(RrfMode {
                    primary: Box::new(primary_mode),
                    secondary: Box::new(secondary_mode),
                    k: k_val,
                    limit,
                }),
                ..Default::default()
            })
        }
        _ => Ok(SearchMode::default()),
    }
}
186
187async fn extract_sub_search_mode(
189 ir: &QueryIR,
190 expr: &IRExpr,
191 params: &ParamMap,
192 catalog: &Catalog,
193 limit: Option<u64>,
194) -> Result<SearchMode> {
195 match expr {
196 IRExpr::Nearest {
197 variable,
198 property,
199 query,
200 } => {
201 let vec =
202 resolve_nearest_query_vec(ir, catalog, variable, property, query, params).await?;
203 let k = limit.unwrap_or(100) as usize;
204 Ok(SearchMode {
205 nearest: Some((variable.clone(), property.clone(), vec, k)),
206 ..Default::default()
207 })
208 }
209 IRExpr::Bm25 { field, query } => {
210 let var = match field.as_ref() {
211 IRExpr::PropAccess { variable, .. } => variable.clone(),
212 _ => {
213 return Err(OmniError::manifest(
214 "bm25 field must be a property access".to_string(),
215 ));
216 }
217 };
218 let prop = extract_property(field).ok_or_else(|| {
219 OmniError::manifest("bm25 field must be a property access".to_string())
220 })?;
221 let text = resolve_to_string(query, params).ok_or_else(|| {
222 OmniError::manifest("bm25 query must resolve to a string".to_string())
223 })?;
224 Ok(SearchMode {
225 bm25: Some((var, prop, text)),
226 ..Default::default()
227 })
228 }
229 _ => Ok(SearchMode::default()),
230 }
231}
232
233async fn resolve_nearest_query_vec(
235 ir: &QueryIR,
236 catalog: &Catalog,
237 variable: &str,
238 property: &str,
239 expr: &IRExpr,
240 params: &ParamMap,
241) -> Result<Vec<f32>> {
242 let lit = resolve_literal_or_param(expr, params)?;
243 match lit {
244 Literal::List(_) => literal_to_f32_vec(&lit),
245 Literal::String(text) => {
246 let expected_dim = nearest_property_dimension(ir, catalog, variable, property)?;
247 EmbeddingClient::from_env()?
248 .embed_query_text(&text, expected_dim)
249 .await
250 }
251 _ => Err(OmniError::manifest(
252 "nearest query must be a string or list of floats".to_string(),
253 )),
254 }
255}
256
257fn resolve_literal_or_param(expr: &IRExpr, params: &ParamMap) -> Result<Literal> {
258 Ok(match expr {
259 IRExpr::Literal(lit) => lit.clone(),
260 IRExpr::Param(name) => params
261 .get(name)
262 .cloned()
263 .ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?,
264 _ => {
265 return Err(OmniError::manifest(
266 "nearest query must be a literal or parameter".to_string(),
267 ));
268 }
269 })
270}
271
272fn literal_to_f32_vec(lit: &Literal) -> Result<Vec<f32>> {
274 match lit {
275 Literal::List(items) => items
276 .iter()
277 .map(|item| match item {
278 Literal::Float(f) => Ok(*f as f32),
279 Literal::Integer(n) => Ok(*n as f32),
280 _ => Err(OmniError::manifest(
281 "vector elements must be numeric".to_string(),
282 )),
283 })
284 .collect(),
285 _ => Err(OmniError::manifest(
286 "nearest query must be a list of floats".to_string(),
287 )),
288 }
289}
290
291fn nearest_property_dimension(
292 ir: &QueryIR,
293 catalog: &Catalog,
294 variable: &str,
295 property: &str,
296) -> Result<usize> {
297 let type_name = resolve_binding_type_name(&ir.pipeline, variable).ok_or_else(|| {
298 OmniError::manifest_internal(format!(
299 "nearest() variable '${}' is not bound to a node type in the lowered pipeline",
300 variable
301 ))
302 })?;
303 let node_type = catalog.node_types.get(type_name).ok_or_else(|| {
304 OmniError::manifest_internal(format!(
305 "nearest() binding '${}' resolved unknown node type '{}'",
306 variable, type_name
307 ))
308 })?;
309 let prop = node_type.properties.get(property).ok_or_else(|| {
310 OmniError::manifest_internal(format!(
311 "nearest() property '{}.{}' is missing from the catalog",
312 type_name, property
313 ))
314 })?;
315 match prop.scalar {
316 ScalarType::Vector(dim) if !prop.list => Ok(dim as usize),
317 _ => Err(OmniError::manifest_internal(format!(
318 "nearest() property '{}.{}' is not a scalar vector",
319 type_name, property
320 ))),
321 }
322}
323
324fn resolve_binding_type_name<'a>(pipeline: &'a [IROp], variable: &str) -> Option<&'a str> {
325 for op in pipeline {
326 match op {
327 IROp::NodeScan {
328 variable: bound_var,
329 type_name,
330 ..
331 } if bound_var == variable => return Some(type_name.as_str()),
332 IROp::Expand {
333 dst_var, dst_type, ..
334 } if dst_var == variable => return Some(dst_type.as_str()),
335 IROp::AntiJoin { inner, .. } => {
336 if let Some(type_name) = resolve_binding_type_name(inner, variable) {
337 return Some(type_name);
338 }
339 }
340 _ => {}
341 }
342 }
343 None
344}
345
/// Execute lowered query IR against a snapshot: run the pipeline into one
/// "wide" batch, project the return expressions, then apply ordering and
/// limit. RRF-ordered queries take a dedicated fusion path.
pub async fn execute_query(
    ir: &QueryIR,
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
) -> Result<QueryResult> {
    let search_mode = extract_search_mode(ir, params, catalog).await?;

    if let Some(ref rrf) = search_mode.rrf {
        return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
    }

    // An empty pipeline (or one that binds nothing) projects from an empty
    // zero-column batch.
    let mut wide: Option<RecordBatch> = None;
    execute_pipeline(&ir.pipeline, params, snapshot, graph_index, catalog, &mut wide, &search_mode).await?;
    let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));

    let has_aggregates = ir.return_exprs.iter().any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
    let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;

    // nearest/bm25 results are already ranked by the scanner, so only
    // re-sort for plain `order by` expressions.
    if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
        result_batch = if has_aggregates {
            // Aggregation collapsed rows, so order-by expressions must be
            // evaluated against the projected batch, not the wide batch.
            apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)?
        } else {
            apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
        };
    }

    // Limit applies after ordering.
    if let Some(limit) = ir.limit {
        let len = result_batch.num_rows().min(limit as usize);
        result_batch = result_batch.slice(0, len);
    }

    Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
386
/// True when results are already ranked by an index search (nearest/bm25),
/// making an explicit sort redundant. RRF is handled on a separate path.
fn is_search_ordered(search_mode: &SearchMode) -> bool {
    search_mode.nearest.is_some() || search_mode.bm25.is_some()
}
391
/// Execute a query whose ordering is `rrf(primary, secondary)`.
///
/// Runs the pipeline twice — once per sub-search — then fuses the two ranked
/// candidate lists with reciprocal-rank fusion (each list contributes
/// 1 / (k + rank + 1)), keeps the top `rrf.limit` ids, and projects the
/// corresponding rows.
async fn execute_rrf_query(
    ir: &QueryIR,
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
    rrf: &RrfMode,
) -> Result<QueryResult> {
    // Candidate list ranked by the primary sub-search.
    let mut primary_wide: Option<RecordBatch> = None;
    execute_pipeline(
        &ir.pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut primary_wide,
        &rrf.primary,
    )
    .await?;

    // Candidate list ranked by the secondary sub-search.
    let mut secondary_wide: Option<RecordBatch> = None;
    execute_pipeline(
        &ir.pipeline,
        params,
        snapshot,
        graph_index,
        catalog,
        &mut secondary_wide,
        &rrf.secondary,
    )
    .await?;

    // The fused entity is the variable targeted by the primary sub-search.
    let primary_var = rrf
        .primary
        .nearest
        .as_ref()
        .map(|(v, ..)| v.as_str())
        .or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
        .ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;

    let primary_batch = primary_wide.as_ref().ok_or_else(|| {
        OmniError::manifest(format!(
            "rrf primary variable '{}' not in bindings",
            primary_var
        ))
    })?;
    // NOTE(review): this message says "secondary variable" but interpolates
    // `primary_var` (no secondary-side variable is derived here) — confirm
    // intent or derive the secondary sub-search's variable.
    let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
        OmniError::manifest(format!(
            "rrf secondary variable '{}' not in bindings",
            primary_var
        ))
    })?;

    // Rank position of each id in either list; first occurrence wins.
    let id_col_name = format!("{}.id", primary_var);
    let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
    let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;

    let mut primary_rank: HashMap<String, usize> = HashMap::new();
    for (i, id) in primary_ids.iter().enumerate() {
        primary_rank.entry(id.clone()).or_insert(i);
    }
    let mut secondary_rank: HashMap<String, usize> = HashMap::new();
    for (i, id) in secondary_ids.iter().enumerate() {
        secondary_rank.entry(id.clone()).or_insert(i);
    }

    // Union of candidates, primary-list order first.
    let mut all_ids: Vec<String> = primary_ids.clone();
    for id in &secondary_ids {
        if !primary_rank.contains_key(id) {
            all_ids.push(id.clone());
        }
    }

    // RRF score: sum over both lists of 1/(k + rank + 1); absence scores 0.
    let k = rrf.k as f64;
    let mut scored: Vec<(String, f64)> = all_ids
        .iter()
        .map(|id| {
            let p = primary_rank
                .get(id)
                .map(|&r| 1.0 / (k + r as f64 + 1.0))
                .unwrap_or(0.0);
            let s = secondary_rank
                .get(id)
                .map(|&r| 1.0 / (k + r as f64 + 1.0))
                .unwrap_or(0.0);
            (id.clone(), p + s)
        })
        .collect();
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(rrf.limit);

    let winning_ids: Vec<String> = scored.iter().map(|(id, _)| id.clone()).collect();

    // For each winning id, remember one source row to copy; the primary
    // batch wins when an id appears in both.
    let mut id_to_batch_row: HashMap<String, (&RecordBatch, usize)> = HashMap::new();
    for (i, id) in primary_ids.iter().enumerate() {
        id_to_batch_row
            .entry(id.clone())
            .or_insert((primary_batch, i));
    }
    for (i, id) in secondary_ids.iter().enumerate() {
        id_to_batch_row
            .entry(id.clone())
            .or_insert((secondary_batch, i));
    }

    let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;

    let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;

    Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
}
516
517fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
518 let col = batch
519 .column_by_name(col_name)
520 .ok_or_else(|| OmniError::manifest(format!("batch missing '{}' column for RRF", col_name)))?;
521 let ids = col
522 .as_any()
523 .downcast_ref::<StringArray>()
524 .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?;
525 Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
526}
527
528fn build_fused_batch(
529 ordered_ids: &[String],
530 id_to_batch_row: &HashMap<String, (&RecordBatch, usize)>,
531 schema: SchemaRef,
532) -> Result<RecordBatch> {
533 if ordered_ids.is_empty() {
534 return Ok(RecordBatch::new_empty(schema));
535 }
536
537 let mut row_slices: Vec<RecordBatch> = Vec::with_capacity(ordered_ids.len());
539 for id in ordered_ids {
540 if let Some(&(batch, row_idx)) = id_to_batch_row.get(id) {
541 row_slices.push(batch.slice(row_idx, 1));
542 }
543 }
544
545 if row_slices.is_empty() {
546 return Ok(RecordBatch::new_empty(schema));
547 }
548
549 let schema = row_slices[0].schema();
550 arrow_select::concat::concat_batches(&schema, &row_slices)
551 .map_err(|e| OmniError::Lance(e.to_string()))
552}
553
/// True when the filter's left side is an index-backed text-search predicate
/// (executed by the scanner rather than as a SQL pushdown or batch filter).
fn is_search_filter(filter: &IRFilter) -> bool {
    matches!(
        &filter.left,
        IRExpr::Search { .. } | IRExpr::Fuzzy { .. } | IRExpr::MatchText { .. }
    )
}
561
562fn search_filter_variable(filter: &IRFilter) -> Option<&str> {
564 let field = match &filter.left {
565 IRExpr::Search { field, .. } => field,
566 IRExpr::Fuzzy { field, .. } => field,
567 IRExpr::MatchText { field, .. } => field,
568 _ => return None,
569 };
570 match field.as_ref() {
571 IRExpr::PropAccess { variable, .. } => Some(variable.as_str()),
572 _ => None,
573 }
574}
575
/// Execute the lowered pipeline ops in order, threading the accumulated
/// "wide" batch (columns named `<variable>.<column>`) through `wide`.
///
/// Returns a boxed future because `IROp::AntiJoin` recursively re-enters
/// this function for its inner pipeline.
fn execute_pipeline<'a>(
    pipeline: &'a [IROp],
    params: &'a ParamMap,
    snapshot: &'a Snapshot,
    graph_index: Option<&'a GraphIndex>,
    catalog: &'a Catalog,
    wide: &'a mut Option<RecordBatch>,
    search_mode: &'a SearchMode,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
    Box::pin(async move {
        // Pass 1: hoist search-style filters (search/fuzzy/match_text) out of
        // the Filter ops. They must execute inside the node scan as scanner
        // FTS options, not as post-hoc batch filters, so they are re-attached
        // to the matching variable's NodeScan below and skipped in pass 2.
        let mut hoisted_search_filters: HashMap<String, Vec<IRFilter>> = HashMap::new();
        let mut hoisted_indices: HashSet<usize> = HashSet::new();
        for (i, op) in pipeline.iter().enumerate() {
            if let IROp::Filter(filter) = op {
                if is_search_filter(filter) {
                    if let Some(var) = search_filter_variable(filter) {
                        hoisted_search_filters
                            .entry(var.to_string())
                            .or_default()
                            .push(filter.clone());
                        hoisted_indices.insert(i);
                    }
                }
            }
        }

        // Pass 2: run the ops.
        for (i, op) in pipeline.iter().enumerate() {
            if hoisted_indices.contains(&i) {
                // Already folded into its NodeScan.
                continue;
            }
            match op {
                IROp::NodeScan {
                    variable,
                    type_name,
                    filters,
                } => {
                    // Scan filters plus any search filters hoisted for this
                    // variable.
                    let mut all_filters: Vec<IRFilter> = filters.clone();
                    if let Some(extra) = hoisted_search_filters.get(variable) {
                        all_filters.extend(extra.iter().cloned());
                    }
                    let batch = execute_node_scan(
                        type_name,
                        variable,
                        &all_filters,
                        params,
                        snapshot,
                        catalog,
                        search_mode,
                    )
                    .await?;
                    // Namespace columns as `<variable>.<column>`; a second
                    // scan cross-joins against the batch built so far.
                    let prefixed = prefix_batch(&batch, variable)?;
                    *wide = Some(match wide.take() {
                        None => prefixed,
                        Some(existing) => cross_join_batches(&existing, &prefixed)?,
                    });
                }
                IROp::Filter(filter) => {
                    if let Some(batch) = wide.as_mut() {
                        apply_filter(batch, filter, params)?;
                    }
                }
                IROp::Expand {
                    src_var,
                    dst_var,
                    edge_type,
                    direction,
                    dst_type,
                    min_hops,
                    max_hops,
                    dst_filters,
                } => {
                    // Traversal requires the prebuilt adjacency index.
                    let gi = graph_index.ok_or_else(|| {
                        OmniError::manifest("graph index required for traversal".to_string())
                    })?;
                    if let Some(batch) = wide.as_mut() {
                        execute_expand(
                            batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
                            dst_type, *min_hops, *max_hops, dst_filters, params,
                        )
                        .await?;
                    }
                }
                IROp::AntiJoin { outer_var, inner } => {
                    let gi = graph_index;
                    if let Some(batch) = wide.as_mut() {
                        execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
                            .await?;
                    }
                }
            }
        }
        Ok(())
    })
}
673
/// Expand each row's `src_var` binding along `edge_type` into a new
/// `dst_var` binding, joining hydrated destination rows onto `wide`.
///
/// BFS over the prebuilt adjacency index collects (source row, destination
/// node) pairs for hops in `min_hops..=max`; destinations are then hydrated
/// from the snapshot in one batch (SQL-translatable `dst_filters` pushed
/// into the scan) and aligned back to their source rows. Filters that can't
/// be pushed down are applied to the joined batch afterwards.
async fn execute_expand(
    wide: &mut RecordBatch,
    graph_index: &GraphIndex,
    snapshot: &Snapshot,
    catalog: &Catalog,
    src_var: &str,
    dst_var: &str,
    edge_type: &str,
    direction: Direction,
    dst_type: &str,
    min_hops: u32,
    max_hops: Option<u32>,
    dst_filters: &[IRFilter],
    params: &ParamMap,
) -> Result<()> {
    // Source ids come from the already-bound `<src_var>.id` column.
    let src_id_col_name = format!("{}.id", src_var);
    let src_ids = wide
        .column_by_name(&src_id_col_name)
        .ok_or_else(|| OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name)))?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
        .clone();

    let edge_def = catalog
        .edge_types
        .get(edge_type)
        .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_type)))?;

    // An `In` traversal walks the edge backwards, so endpoint types swap.
    let (src_type_name, dst_type_name) = match direction {
        Direction::Out => (&edge_def.from_type, &edge_def.to_type),
        Direction::In => (&edge_def.to_type, &edge_def.from_type),
    };

    let src_type_idx = graph_index
        .type_index(src_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", src_type_name)))?;
    let dst_type_idx = graph_index
        .type_index(dst_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", dst_type_name)))?;

    // CSR serves forward (out) neighbors, CSC reverse (in) neighbors.
    let adj = match direction {
        Direction::Out => graph_index.csr(edge_type),
        Direction::In => graph_index.csc(edge_type),
    }
    .ok_or_else(|| OmniError::manifest(format!("no adjacency index for edge '{}'", edge_type)))?;

    // Without an explicit upper bound, walk exactly min_hops (at least 1).
    let max = max_hops.unwrap_or(min_hops.max(1));

    let same_type = src_type_name == dst_type_name;

    // BFS per source row: emit each destination once, at its first hop count
    // within [min_hops, max].
    let mut src_indices: Vec<u32> = Vec::new();
    let mut dst_dense_list: Vec<u32> = Vec::new();
    for i in 0..src_ids.len() {
        let src_id = src_ids.value(i);
        let Some(src_dense) = src_type_idx.to_dense(src_id) else {
            continue;
        };

        let mut frontier: Vec<u32> = vec![src_dense];
        let mut visited: HashSet<u32> = HashSet::new();
        let mut seen_dst_dense: HashSet<u32> = HashSet::new();
        // For self-type edges, mark the start node visited so cycles back to
        // it are not re-expanded.
        if same_type {
            visited.insert(src_dense);
        }

        for hop in 1..=max {
            let mut next_frontier = Vec::new();
            for &node in &frontier {
                for &neighbor in adj.neighbors(node) {
                    // NOTE(review): when src and dst types differ, `visited`
                    // is bypassed, so multi-hop frontiers could revisit
                    // nodes — presumably multi-hop expands are only lowered
                    // for same-type edges; confirm.
                    if !same_type || visited.insert(neighbor) {
                        next_frontier.push(neighbor);
                        if hop >= min_hops && seen_dst_dense.insert(neighbor) {
                            src_indices.push(i as u32);
                            dst_dense_list.push(neighbor);
                        }
                    }
                }
            }
            frontier = next_frontier;
            if frontier.is_empty() {
                break;
            }
        }
    }

    // Split dst filters: SQL-pushdown vs. post-join application.
    let pushdown_sql = build_lance_filter(dst_filters, params);
    let non_pushable: Vec<&IRFilter> = dst_filters
        .iter()
        .filter(|f| ir_filter_to_sql(f, params).is_none())
        .collect();

    // Hydrate each distinct destination exactly once.
    let mut unique_dst_list: Vec<String> = Vec::new();
    {
        let mut seen: HashSet<u32> = HashSet::with_capacity(dst_dense_list.len());
        for &d in &dst_dense_list {
            if seen.insert(d) {
                if let Some(id) = dst_type_idx.to_id(d) {
                    unique_dst_list.push(id.to_string());
                }
            }
        }
    }
    let dst_batch = hydrate_nodes(
        snapshot,
        catalog,
        dst_type,
        &unique_dst_list,
        pushdown_sql.as_deref(),
    )
    .await?;

    // Map dense node index -> row of the hydrated batch; `None` means the
    // node was filtered out by the pushdown predicate.
    let dst_batch_id_col = dst_batch
        .column_by_name("id")
        .ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?;
    let mut dense_to_row: Vec<Option<u32>> = vec![None; dst_type_idx.len()];
    for row in 0..dst_batch_id_col.len() {
        let id_str = dst_batch_id_col.value(row);
        if let Some(dense) = dst_type_idx.to_dense(id_str) {
            dense_to_row[dense as usize] = Some(row as u32);
        }
    }

    // Keep only (src, dst) pairs whose destination survived hydration.
    let mut final_src_indices: Vec<u32> = Vec::new();
    let mut dst_indices: Vec<u32> = Vec::new();
    for (src_idx, dst_dense) in src_indices.iter().zip(dst_dense_list.iter()) {
        if let Some(dst_row) = dense_to_row[*dst_dense as usize] {
            final_src_indices.push(*src_idx);
            dst_indices.push(dst_row);
        }
    }

    // Materialize the join: repeat source rows, align destination rows, then
    // concatenate horizontally.
    let src_take = UInt32Array::from(final_src_indices);
    let dst_take = UInt32Array::from(dst_indices);
    let expanded_wide = take_batch(wide, &src_take)?;
    let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
    let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
    *wide = hconcat_batches(&expanded_wide, &aligned_dst)?;

    // Filters that couldn't be pushed down run against the joined batch.
    for f in &non_pushable {
        apply_filter(wide, f, params)?;
    }

    Ok(())
}
838
/// Fetch rows for the given node ids from the snapshot's table for
/// `type_name`, optionally AND-ing `extra_filter_sql` into the scan filter.
///
/// Blob-typed properties are excluded from the scan projection and re-added
/// afterwards as all-null columns.
async fn hydrate_nodes(
    snapshot: &Snapshot,
    catalog: &Catalog,
    type_name: &str,
    ids: &[String],
    extra_filter_sql: Option<&str>,
) -> Result<RecordBatch> {
    let node_type = catalog
        .node_types
        .get(type_name)
        .ok_or_else(|| OmniError::manifest(format!("unknown node type '{}'", type_name)))?;

    if ids.is_empty() {
        return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
    }

    let table_key = format!("node:{}", type_name);
    let ds = snapshot.open(&table_key).await?;

    // Build `id IN (...)`; single quotes are doubled to keep each id a valid
    // SQL string literal.
    let escaped: Vec<String> = ids
        .iter()
        .map(|id| format!("'{}'", id.replace('\'', "''")))
        .collect();
    let mut filter_sql = format!("id IN ({})", escaped.join(", "));
    if let Some(extra) = extra_filter_sql {
        filter_sql = format!("({}) AND ({})", filter_sql, extra);
    }
    // Project blob columns out of the scan; they are re-attached as nulls.
    let has_blobs = !node_type.blob_properties.is_empty();
    let non_blob_cols: Vec<&str> = node_type
        .arrow_schema
        .fields()
        .iter()
        .filter(|f| !node_type.blob_properties.contains(f.name()))
        .map(|f| f.name().as_str())
        .collect();
    let projection = has_blobs.then_some(non_blob_cols.as_slice());
    let batches = crate::table_store::TableStore::scan_stream(
        &ds,
        projection,
        Some(&filter_sql),
        None,
        false,
    )
    .await?
    .try_collect::<Vec<RecordBatch>>()
    .await
    .map_err(|e| OmniError::Lance(e.to_string()))?;

    // NOTE(review): an empty result uses the full declared schema (blob
    // fields with their declared types), while non-empty results carry blob
    // fields as null Utf8 via add_null_blob_columns — confirm downstream
    // consumers tolerate the difference.
    let scan_result = if batches.is_empty() {
        return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
    } else if batches.len() == 1 {
        batches.into_iter().next().unwrap()
    } else {
        let schema = batches[0].schema();
        arrow_select::concat::concat_batches(&schema, &batches)
            .map_err(|e| OmniError::Lance(e.to_string()))?
    };

    if has_blobs {
        return add_null_blob_columns(&scan_result, node_type);
    }
    Ok(scan_result)
}
908
909fn try_bulk_anti_join_mask(
912 wide: &RecordBatch,
913 inner_pipeline: &[IROp],
914 graph_index: Option<&GraphIndex>,
915 catalog: &Catalog,
916 outer_var: &str,
917) -> Option<BooleanArray> {
918 if inner_pipeline.len() != 1 {
919 return None;
920 }
921 let IROp::Expand {
922 src_var,
923 edge_type,
924 direction,
925 dst_filters,
926 ..
927 } = &inner_pipeline[0]
928 else {
929 return None;
930 };
931 if src_var != outer_var {
932 return None;
933 }
934 if !dst_filters.is_empty() {
937 return None;
938 }
939 let gi = graph_index?;
940 let edge_def = catalog.edge_types.get(edge_type.as_str())?;
941
942 let src_type_name = match direction {
943 Direction::Out => &edge_def.from_type,
944 Direction::In => &edge_def.to_type,
945 };
946 let adj = match direction {
947 Direction::Out => gi.csr(edge_type),
948 Direction::In => gi.csc(edge_type),
949 }?;
950 let type_idx = gi.type_index(src_type_name)?;
951
952 let id_col_name = format!("{}.id", outer_var);
953 let outer_ids = wide
954 .column_by_name(&id_col_name)?
955 .as_any()
956 .downcast_ref::<StringArray>()?;
957
958 let keep_mask: Vec<bool> = (0..outer_ids.len())
959 .map(|i| {
960 let id = outer_ids.value(i);
961 match type_idx.to_dense(id) {
962 Some(dense) => !adj.has_neighbors(dense),
963 None => true, }
965 })
966 .collect();
967
968 Some(BooleanArray::from(keep_mask))
969}
970
/// NOT-EXISTS semantics: drop every row of `wide` for which the inner
/// pipeline yields at least one result.
///
/// Tries a bulk adjacency-based mask first; otherwise the inner pipeline is
/// run once per row of `wide`, i.e. O(rows) pipeline executions.
async fn execute_anti_join(
    wide: &mut RecordBatch,
    inner_pipeline: &[IROp],
    params: &ParamMap,
    snapshot: &Snapshot,
    graph_index: Option<&GraphIndex>,
    catalog: &Catalog,
    outer_var: &str,
) -> Result<()> {
    // Fast path: mask computed directly from the adjacency index.
    if let Some(mask) =
        try_bulk_anti_join_mask(wide, inner_pipeline, graph_index, catalog, outer_var)
    {
        *wide = arrow_select::filter::filter_record_batch(wide, &mask)
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        return Ok(());
    }

    // Fallback: seed the inner pipeline with each single row in turn.
    let num_rows = wide.num_rows();
    let mut keep_mask = vec![true; num_rows];

    for i in 0..num_rows {
        let single_row = wide.slice(i, 1);
        let mut inner_wide: Option<RecordBatch> = Some(single_row);

        // Inner runs carry no search ordering.
        let no_search = SearchMode::default();
        execute_pipeline(
            inner_pipeline,
            params,
            snapshot,
            graph_index,
            catalog,
            &mut inner_wide,
            &no_search,
        )
        .await?;

        // Any surviving inner row means the outer row must be dropped.
        let has_match = inner_wide
            .as_ref()
            .map(|batch| batch.num_rows() > 0)
            .unwrap_or(false);

        if has_match {
            keep_mask[i] = false;
        }
    }

    let mask = BooleanArray::from(keep_mask);
    *wide = arrow_select::filter::filter_record_batch(wide, &mask)
        .map_err(|e| OmniError::Lance(e.to_string()))?;
    Ok(())
}
1025
1026async fn execute_node_scan(
1028 type_name: &str,
1029 variable: &str,
1030 filters: &[IRFilter],
1031 params: &ParamMap,
1032 snapshot: &Snapshot,
1033 catalog: &Catalog,
1034 search_mode: &SearchMode,
1035) -> Result<RecordBatch> {
1036 let table_key = format!("node:{}", type_name);
1037 let ds = snapshot.open(&table_key).await?;
1038
1039 let filter_sql = build_lance_filter(filters, params);
1041
1042 let node_type = &catalog.node_types[type_name];
1046 let has_blobs = !node_type.blob_properties.is_empty();
1047 let non_blob_cols: Vec<&str> = node_type
1048 .arrow_schema
1049 .fields()
1050 .iter()
1051 .filter(|f| !node_type.blob_properties.contains(f.name()))
1052 .map(|f| f.name().as_str())
1053 .collect();
1054 let projection = has_blobs.then_some(non_blob_cols.as_slice());
1055 let batches = crate::table_store::TableStore::scan_stream_with(
1056 &ds,
1057 projection,
1058 filter_sql.as_deref(),
1059 None,
1060 false,
1061 |scanner| {
1062 for filter in filters {
1064 if is_search_filter(filter) {
1065 if let Some(fts_query) = build_fts_query(&filter.left, params) {
1066 scanner.full_text_search(fts_query).map_err(|e| {
1067 OmniError::Lance(format!("full_text_search filter: {}", e))
1068 })?;
1069 }
1070 }
1071 }
1072
1073 if let Some((ref var, ref prop, ref vec, k)) = search_mode.nearest {
1075 if var == variable {
1076 let query_arr = Float32Array::from(vec.clone());
1077 scanner
1078 .nearest(prop, &query_arr, k)
1079 .map_err(|e| OmniError::Lance(format!("nearest: {}", e)))?;
1080 }
1081 }
1082
1083 if let Some((ref var, ref prop, ref text)) = search_mode.bm25 {
1085 if var == variable {
1086 let fts_query = lance_index::scalar::FullTextSearchQuery::new(text.clone())
1087 .with_column(prop.clone())
1088 .map_err(|e| OmniError::Lance(format!("fts with_column: {}", e)))?;
1089 scanner
1090 .full_text_search(fts_query)
1091 .map_err(|e| OmniError::Lance(format!("full_text_search: {}", e)))?;
1092 }
1093 }
1094 Ok(())
1095 },
1096 )
1097 .await?
1098 .try_collect::<Vec<RecordBatch>>()
1099 .await
1100 .map_err(|e| OmniError::Lance(e.to_string()))?;
1101
1102 let scan_result = if batches.is_empty() {
1103 RecordBatch::new_empty(batches.first().map(|b| b.schema()).unwrap_or_else(|| {
1104 let fields: Vec<_> = node_type
1106 .arrow_schema
1107 .fields()
1108 .iter()
1109 .filter(|f| !node_type.blob_properties.contains(f.name()))
1110 .map(|f| f.as_ref().clone())
1111 .collect();
1112 Arc::new(Schema::new(fields))
1113 }))
1114 } else if batches.len() == 1 {
1115 batches.into_iter().next().unwrap()
1116 } else {
1117 let schema = batches[0].schema();
1118 arrow_select::concat::concat_batches(&schema, &batches)
1119 .map_err(|e| OmniError::Lance(e.to_string()))?
1120 };
1121
1122 if has_blobs {
1124 return add_null_blob_columns(&scan_result, node_type);
1125 }
1126 Ok(scan_result)
1127}
1128
1129fn add_null_blob_columns(
1132 batch: &RecordBatch,
1133 node_type: &omnigraph_compiler::catalog::NodeType,
1134) -> Result<RecordBatch> {
1135 let num_rows = batch.num_rows();
1136 let mut fields = Vec::with_capacity(node_type.arrow_schema.fields().len());
1137 let mut columns: Vec<ArrayRef> = Vec::with_capacity(node_type.arrow_schema.fields().len());
1138
1139 for field in node_type.arrow_schema.fields() {
1140 if node_type.blob_properties.contains(field.name()) {
1141 fields.push(Field::new(field.name(), DataType::Utf8, true));
1142 columns.push(Arc::new(StringArray::from(vec![None::<&str>; num_rows])));
1143 } else if let Some(col) = batch.column_by_name(field.name()) {
1144 let batch_schema = batch.schema();
1145 let batch_field = batch_schema
1146 .field_with_name(field.name())
1147 .map_err(|e| OmniError::Lance(e.to_string()))?;
1148 fields.push(batch_field.clone());
1149 columns.push(col.clone());
1150 }
1151 }
1152
1153 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
1154 .map_err(|e| OmniError::Lance(e.to_string()))
1155}
1156
1157fn build_lance_filter(filters: &[IRFilter], params: &ParamMap) -> Option<String> {
1159 if filters.is_empty() {
1160 return None;
1161 }
1162
1163 let parts: Vec<String> = filters
1164 .iter()
1165 .filter_map(|f| ir_filter_to_sql(f, params))
1166 .collect();
1167
1168 if parts.is_empty() {
1169 return None;
1170 }
1171
1172 Some(parts.join(" AND "))
1173}
1174
1175fn ir_filter_to_sql(filter: &IRFilter, params: &ParamMap) -> Option<String> {
1176 if is_search_filter(filter) {
1179 return None;
1180 }
1181
1182 let left = ir_expr_to_sql(&filter.left, params)?;
1183 let right = ir_expr_to_sql(&filter.right, params)?;
1184 let op = match filter.op {
1185 CompOp::Eq => "=",
1186 CompOp::Ne => "!=",
1187 CompOp::Gt => ">",
1188 CompOp::Lt => "<",
1189 CompOp::Ge => ">=",
1190 CompOp::Le => "<=",
1191 CompOp::Contains => return None, };
1193 Some(format!("{} {} {}", left, op, right))
1194}
1195
1196fn build_fts_query(
1198 expr: &IRExpr,
1199 params: &ParamMap,
1200) -> Option<lance_index::scalar::FullTextSearchQuery> {
1201 match expr {
1202 IRExpr::Search { field, query } => {
1203 let prop = extract_property(field)?;
1204 let q = resolve_to_string(query, params)?;
1205 lance_index::scalar::FullTextSearchQuery::new(q)
1206 .with_column(prop)
1207 .ok()
1208 }
1209 IRExpr::Fuzzy {
1210 field,
1211 query,
1212 max_edits,
1213 } => {
1214 let prop = extract_property(field)?;
1215 let q = resolve_to_string(query, params)?;
1216 let edits = max_edits
1217 .as_ref()
1218 .and_then(|e| resolve_to_int(e, params))
1219 .unwrap_or(2) as u32;
1220 lance_index::scalar::FullTextSearchQuery::new_fuzzy(q, Some(edits))
1221 .with_column(prop)
1222 .ok()
1223 }
1224 IRExpr::MatchText { field, query } => {
1225 let prop = extract_property(field)?;
1227 let q = resolve_to_string(query, params)?;
1228 lance_index::scalar::FullTextSearchQuery::new(q)
1229 .with_column(prop)
1230 .ok()
1231 }
1232 _ => None,
1233 }
1234}
1235
1236fn extract_property(expr: &IRExpr) -> Option<String> {
1238 match expr {
1239 IRExpr::PropAccess { property, .. } => Some(property.clone()),
1240 _ => None,
1241 }
1242}
1243
1244fn resolve_to_string(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1246 match expr {
1247 IRExpr::Literal(Literal::String(s)) => Some(s.clone()),
1248 IRExpr::Param(name) => match params.get(name)? {
1249 Literal::String(s) => Some(s.clone()),
1250 _ => None,
1251 },
1252 _ => None,
1253 }
1254}
1255
1256fn resolve_to_int(expr: &IRExpr, params: &ParamMap) -> Option<i64> {
1258 match expr {
1259 IRExpr::Literal(Literal::Integer(n)) => Some(*n),
1260 IRExpr::Param(name) => match params.get(name)? {
1261 Literal::Integer(n) => Some(*n),
1262 _ => None,
1263 },
1264 _ => None,
1265 }
1266}
1267
1268fn ir_expr_to_sql(expr: &IRExpr, params: &ParamMap) -> Option<String> {
1269 match expr {
1270 IRExpr::PropAccess { property, .. } => Some(property.clone()),
1271 IRExpr::Literal(lit) => Some(literal_to_sql(lit)),
1272 IRExpr::Param(name) => params.get(name).map(literal_to_sql),
1273 _ => None,
1274 }
1275}
1276
1277pub(super) fn literal_to_sql(lit: &Literal) -> String {
1278 match lit {
1279 Literal::Null => "NULL".to_string(),
1280 Literal::String(s) => format!("'{}'", s.replace('\'', "''")),
1281 Literal::Integer(n) => n.to_string(),
1282 Literal::Float(f) => f.to_string(),
1283 Literal::Bool(b) => b.to_string(),
1284 Literal::Date(s) => format!("'{}'", s.replace('\'', "''")),
1285 Literal::DateTime(s) => format!("'{}'", s.replace('\'', "''")),
1286 Literal::List(_) => "NULL".to_string(), }
1288}
1289
1290fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
1291 let fields: Vec<Field> = batch.schema().fields().iter().map(|f| {
1292 Field::new(format!("{}.{}", variable, f.name()), f.data_type().clone(), f.is_nullable())
1293 }).collect();
1294 let schema = Arc::new(Schema::new(fields));
1295 RecordBatch::try_new(schema, batch.columns().to_vec()).map_err(|e| OmniError::Lance(e.to_string()))
1296}
1297
1298fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1299 let n = left.num_rows();
1300 let m = right.num_rows();
1301 if n == 0 || m == 0 {
1302 let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
1303 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1304 return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
1305 }
1306 let left_indices: Vec<u32> = (0..n as u32).flat_map(|i| std::iter::repeat(i).take(m)).collect();
1307 let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
1308 let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
1309 let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
1310 hconcat_batches(&left_expanded, &right_expanded)
1311}
1312
1313fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
1314 let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
1315 if cfg!(debug_assertions) {
1316 let left_schema = left.schema();
1317 let left_names: HashSet<&str> = left_schema.fields().iter().map(|f| f.name().as_str()).collect();
1318 let right_schema = right.schema();
1319 for f in right_schema.fields() {
1320 debug_assert!(!left_names.contains(f.name().as_str()), "hconcat_batches: duplicate column '{}'", f.name());
1321 }
1322 }
1323 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
1324 let mut columns: Vec<ArrayRef> = left.columns().to_vec();
1325 columns.extend(right.columns().to_vec());
1326 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).map_err(|e| OmniError::Lance(e.to_string()))
1327}
1328
1329fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
1330 let columns: Vec<ArrayRef> = batch.columns().iter()
1331 .map(|col| arrow_select::take::take(col.as_ref(), indices, None))
1332 .collect::<std::result::Result<Vec<_>, _>>()
1333 .map_err(|e| OmniError::Lance(e.to_string()))?;
1334 RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
1335}