1use super::*;
2
3use super::projection::{apply_filter, apply_ordering, project_return};
4
5pub(crate) struct EmbeddingResolver<'a> {
11 cell: &'a tokio::sync::OnceCell<EmbeddingClient>,
12 config: Option<&'a crate::embedding::EmbeddingConfig>,
13}
14
15impl EmbeddingResolver<'_> {
16 async fn resolve(&self) -> Result<&EmbeddingClient> {
17 let config = self.config.cloned();
18 self.cell
19 .get_or_try_init(|| async move {
20 match config {
21 Some(cfg) => EmbeddingClient::new(cfg),
22 None => EmbeddingClient::from_env(),
23 }
24 })
25 .await
26 }
27}
28
29impl Omnigraph {
30 pub async fn query(
32 &self,
33 target: impl Into<ReadTarget>,
34 query_source: &str,
35 query_name: &str,
36 params: &ParamMap,
37 ) -> Result<QueryResult> {
38 self.ensure_schema_state_valid().await?;
39 let resolved = self.resolved_target(target).await?;
40 let catalog = self.catalog();
41
42 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
43 .map_err(|e| OmniError::manifest(e.to_string()))?;
44 let type_ctx = typecheck_query(&catalog, &query_decl)?;
45 let ir = lower_query(&catalog, &query_decl, &type_ctx)?;
46
47 let needs_graph = ir
48 .pipeline
49 .iter()
50 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
51 let graph_index = if needs_graph {
53 GraphIndexHandle::cached(self, &resolved)
54 } else {
55 GraphIndexHandle::none()
56 };
57
58 execute_query(
59 &ir,
60 params,
61 &resolved.snapshot,
62 &graph_index,
63 &catalog,
64 &EmbeddingResolver {
65 cell: self.embedding_cell(),
66 config: self.embedding_config_ref(),
67 },
68 )
69 .await
70 }
71
72 pub async fn run_query_at(
77 &self,
78 version: u64,
79 query_source: &str,
80 query_name: &str,
81 params: &ParamMap,
82 ) -> Result<QueryResult> {
83 self.ensure_schema_state_valid().await?;
84 let snapshot = self.snapshot_at_version(version).await?;
85 let catalog = self.catalog();
86
87 let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
88 .map_err(|e| OmniError::manifest(e.to_string()))?;
89 let type_ctx = typecheck_query(&catalog, &query_decl)?;
90 let ir = lower_query(&catalog, &query_decl, &type_ctx)?;
91
92 let needs_graph = ir
93 .pipeline
94 .iter()
95 .any(|op| matches!(op, IROp::Expand { .. } | IROp::AntiJoin { .. }));
96 let graph_index = if needs_graph {
100 let edge_types = catalog
101 .edge_types
102 .iter()
103 .map(|(name, et)| (name.clone(), (et.from_type.clone(), et.to_type.clone())))
104 .collect();
105 GraphIndexHandle::direct(&snapshot, edge_types)
106 } else {
107 GraphIndexHandle::none()
108 };
109
110 execute_query(
111 &ir,
112 params,
113 &snapshot,
114 &graph_index,
115 &catalog,
116 &EmbeddingResolver {
117 cell: self.embedding_cell(),
118 config: self.embedding_config_ref(),
119 },
120 )
121 .await
122 }
123}
124
125#[derive(Debug, Default)]
129struct SearchMode {
130 nearest: Option<(String, String, Vec<f32>, usize)>,
132 bm25: Option<(String, String, String)>,
134 rrf: Option<RrfMode>,
136}
137
138#[derive(Debug)]
139struct RrfMode {
140 primary: Box<SearchMode>,
141 secondary: Box<SearchMode>,
142 k: u32,
143 limit: usize,
144}
145
146async fn extract_search_mode(
148 ir: &QueryIR,
149 params: &ParamMap,
150 catalog: &Catalog,
151 embedding: &EmbeddingResolver<'_>,
152) -> Result<SearchMode> {
153 if ir.order_by.is_empty() {
154 return Ok(SearchMode::default());
155 }
156 let ordering = &ir.order_by[0];
157 match &ordering.expr {
158 IRExpr::Nearest {
159 variable,
160 property,
161 query,
162 } => {
163 let vec =
164 resolve_nearest_query_vec(ir, catalog, variable, property, query, params, embedding)
165 .await?;
166 let k = ir.limit.ok_or_else(|| {
167 OmniError::manifest("nearest() ordering requires a limit clause".to_string())
168 })? as usize;
169 Ok(SearchMode {
170 nearest: Some((variable.clone(), property.clone(), vec, k)),
171 ..Default::default()
172 })
173 }
174 IRExpr::Bm25 { field, query } => {
175 let var = match field.as_ref() {
176 IRExpr::PropAccess { variable, .. } => variable.clone(),
177 _ => {
178 return Err(OmniError::manifest(
179 "bm25 field must be a property access".to_string(),
180 ));
181 }
182 };
183 let prop = extract_property(field).ok_or_else(|| {
184 OmniError::manifest("bm25 field must be a property access".to_string())
185 })?;
186 let text = resolve_to_string(query, params).ok_or_else(|| {
187 OmniError::manifest("bm25 query must resolve to a string".to_string())
188 })?;
189 Ok(SearchMode {
190 bm25: Some((var, prop, text)),
191 ..Default::default()
192 })
193 }
194 IRExpr::Rrf {
195 primary,
196 secondary,
197 k,
198 } => {
199 let limit = ir.limit.ok_or_else(|| {
200 OmniError::manifest("rrf() ordering requires a limit clause".to_string())
201 })? as usize;
202 let k_val = k
203 .as_ref()
204 .and_then(|e| resolve_to_int(e, params))
205 .unwrap_or(60) as u32;
206
207 let primary_mode =
208 extract_sub_search_mode(ir, primary, params, catalog, ir.limit, embedding).await?;
209 let secondary_mode =
210 extract_sub_search_mode(ir, secondary, params, catalog, ir.limit, embedding)
211 .await?;
212
213 Ok(SearchMode {
214 rrf: Some(RrfMode {
215 primary: Box::new(primary_mode),
216 secondary: Box::new(secondary_mode),
217 k: k_val,
218 limit,
219 }),
220 ..Default::default()
221 })
222 }
223 _ => Ok(SearchMode::default()),
224 }
225}
226
227async fn extract_sub_search_mode(
229 ir: &QueryIR,
230 expr: &IRExpr,
231 params: &ParamMap,
232 catalog: &Catalog,
233 limit: Option<u64>,
234 embedding: &EmbeddingResolver<'_>,
235) -> Result<SearchMode> {
236 match expr {
237 IRExpr::Nearest {
238 variable,
239 property,
240 query,
241 } => {
242 let vec =
243 resolve_nearest_query_vec(ir, catalog, variable, property, query, params, embedding)
244 .await?;
245 let k = limit.unwrap_or(100) as usize;
246 Ok(SearchMode {
247 nearest: Some((variable.clone(), property.clone(), vec, k)),
248 ..Default::default()
249 })
250 }
251 IRExpr::Bm25 { field, query } => {
252 let var = match field.as_ref() {
253 IRExpr::PropAccess { variable, .. } => variable.clone(),
254 _ => {
255 return Err(OmniError::manifest(
256 "bm25 field must be a property access".to_string(),
257 ));
258 }
259 };
260 let prop = extract_property(field).ok_or_else(|| {
261 OmniError::manifest("bm25 field must be a property access".to_string())
262 })?;
263 let text = resolve_to_string(query, params).ok_or_else(|| {
264 OmniError::manifest("bm25 query must resolve to a string".to_string())
265 })?;
266 Ok(SearchMode {
267 bm25: Some((var, prop, text)),
268 ..Default::default()
269 })
270 }
271 _ => Ok(SearchMode::default()),
272 }
273}
274
275async fn resolve_nearest_query_vec(
277 ir: &QueryIR,
278 catalog: &Catalog,
279 variable: &str,
280 property: &str,
281 expr: &IRExpr,
282 params: &ParamMap,
283 embedding: &EmbeddingResolver<'_>,
284) -> Result<Vec<f32>> {
285 let lit = resolve_literal_or_param(expr, params)?;
286 match lit {
287 Literal::List(_) => literal_to_f32_vec(&lit),
288 Literal::String(text) => {
289 let (expected_dim, recorded_model) =
290 nearest_property_dim_and_model(ir, catalog, variable, property)?;
291 let client = embedding.resolve().await?;
295 if let Some(recorded) = &recorded_model {
300 let resolved = &client.config().model;
301 if resolved != recorded {
302 return Err(OmniError::manifest(format!(
303 "nearest() on '{property}': its stored vectors were embedded with model \
304 '{recorded}', but the query embedder resolves to '{resolved}'. Set \
305 OMNIGRAPH_EMBED_MODEL='{recorded}' (and the matching provider) or re-embed \
306 the stored vectors."
307 )));
308 }
309 }
310 client.embed_query_text(&text, expected_dim).await
311 }
312 _ => Err(OmniError::manifest(
313 "nearest query must be a string or list of floats".to_string(),
314 )),
315 }
316}
317
318fn resolve_literal_or_param(expr: &IRExpr, params: &ParamMap) -> Result<Literal> {
319 Ok(match expr {
320 IRExpr::Literal(lit) => lit.clone(),
321 IRExpr::Param(name) => params
322 .get(name)
323 .cloned()
324 .ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name)))?,
325 _ => {
326 return Err(OmniError::manifest(
327 "nearest query must be a literal or parameter".to_string(),
328 ));
329 }
330 })
331}
332
333fn literal_to_f32_vec(lit: &Literal) -> Result<Vec<f32>> {
335 match lit {
336 Literal::List(items) => items
337 .iter()
338 .map(|item| match item {
339 Literal::Float(f) => Ok(*f as f32),
340 Literal::Integer(n) => Ok(*n as f32),
341 _ => Err(OmniError::manifest(
342 "vector elements must be numeric".to_string(),
343 )),
344 })
345 .collect(),
346 _ => Err(OmniError::manifest(
347 "nearest query must be a list of floats".to_string(),
348 )),
349 }
350}
351
352fn nearest_property_dim_and_model(
355 ir: &QueryIR,
356 catalog: &Catalog,
357 variable: &str,
358 property: &str,
359) -> Result<(usize, Option<String>)> {
360 let type_name = resolve_binding_type_name(&ir.pipeline, variable).ok_or_else(|| {
361 OmniError::manifest_internal(format!(
362 "nearest() variable '${}' is not bound to a node type in the lowered pipeline",
363 variable
364 ))
365 })?;
366 let node_type = catalog.node_types.get(type_name).ok_or_else(|| {
367 OmniError::manifest_internal(format!(
368 "nearest() binding '${}' resolved unknown node type '{}'",
369 variable, type_name
370 ))
371 })?;
372 let prop = node_type.properties.get(property).ok_or_else(|| {
373 OmniError::manifest_internal(format!(
374 "nearest() property '{}.{}' is missing from the catalog",
375 type_name, property
376 ))
377 })?;
378 let dim = match prop.scalar {
379 ScalarType::Vector(dim) if !prop.list => dim as usize,
380 _ => {
381 return Err(OmniError::manifest_internal(format!(
382 "nearest() property '{}.{}' is not a scalar vector",
383 type_name, property
384 )));
385 }
386 };
387 let recorded_model = node_type
388 .embed_sources
389 .get(property)
390 .and_then(|embed| embed.model.clone());
391 Ok((dim, recorded_model))
392}
393
394fn resolve_binding_type_name<'a>(pipeline: &'a [IROp], variable: &str) -> Option<&'a str> {
395 for op in pipeline {
396 match op {
397 IROp::NodeScan {
398 variable: bound_var,
399 type_name,
400 ..
401 } if bound_var == variable => return Some(type_name.as_str()),
402 IROp::Expand {
403 dst_var, dst_type, ..
404 } if dst_var == variable => return Some(dst_type.as_str()),
405 IROp::AntiJoin { inner, .. } => {
406 if let Some(type_name) = resolve_binding_type_name(inner, variable) {
407 return Some(type_name);
408 }
409 }
410 _ => {}
411 }
412 }
413 None
414}
415
416pub async fn execute_query(
418 ir: &QueryIR,
419 params: &ParamMap,
420 snapshot: &Snapshot,
421 graph_index: &GraphIndexHandle<'_>,
422 catalog: &Catalog,
423 embedding: &EmbeddingResolver<'_>,
424) -> Result<QueryResult> {
425 let search_mode = extract_search_mode(ir, params, catalog, embedding).await?;
426
427 if let Some(ref rrf) = search_mode.rrf {
429 return execute_rrf_query(ir, params, snapshot, graph_index, catalog, rrf).await;
430 }
431
432 let mut wide: Option<RecordBatch> = None;
433 execute_pipeline(
434 &ir.pipeline,
435 params,
436 snapshot,
437 graph_index,
438 catalog,
439 &mut wide,
440 &search_mode,
441 )
442 .await?;
443 let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));
444
445 let has_aggregates = ir
447 .return_exprs
448 .iter()
449 .any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
450 let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;
451
452 if !ir.order_by.is_empty() && !is_search_ordered(&search_mode) {
454 result_batch = if has_aggregates {
455 apply_ordering(result_batch.clone(), &ir.order_by, &result_batch, params)?
456 } else {
457 apply_ordering(result_batch, &ir.order_by, &wide_batch, params)?
458 };
459 }
460
461 if let Some(limit) = ir.limit {
463 let len = result_batch.num_rows().min(limit as usize);
464 result_batch = result_batch.slice(0, len);
465 }
466
467 Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
468}
469
470fn is_search_ordered(search_mode: &SearchMode) -> bool {
472 search_mode.nearest.is_some() || search_mode.bm25.is_some()
473}
474
475async fn execute_rrf_query(
477 ir: &QueryIR,
478 params: &ParamMap,
479 snapshot: &Snapshot,
480 graph_index: &GraphIndexHandle<'_>,
481 catalog: &Catalog,
482 rrf: &RrfMode,
483) -> Result<QueryResult> {
484 let mut primary_wide: Option<RecordBatch> = None;
486 execute_pipeline(
487 &ir.pipeline,
488 params,
489 snapshot,
490 graph_index,
491 catalog,
492 &mut primary_wide,
493 &rrf.primary,
494 )
495 .await?;
496
497 let mut secondary_wide: Option<RecordBatch> = None;
499 execute_pipeline(
500 &ir.pipeline,
501 params,
502 snapshot,
503 graph_index,
504 catalog,
505 &mut secondary_wide,
506 &rrf.secondary,
507 )
508 .await?;
509
510 let primary_var = rrf
513 .primary
514 .nearest
515 .as_ref()
516 .map(|(v, ..)| v.as_str())
517 .or_else(|| rrf.primary.bm25.as_ref().map(|(v, ..)| v.as_str()))
518 .ok_or_else(|| OmniError::manifest("rrf primary must be nearest or bm25".to_string()))?;
519
520 let primary_batch = primary_wide.as_ref().ok_or_else(|| {
521 OmniError::manifest(format!(
522 "rrf primary variable '{}' not in bindings",
523 primary_var
524 ))
525 })?;
526 let secondary_batch = secondary_wide.as_ref().ok_or_else(|| {
527 OmniError::manifest(format!(
528 "rrf secondary variable '{}' not in bindings",
529 primary_var
530 ))
531 })?;
532
533 let id_col_name = format!("{}.id", primary_var);
535 let primary_ids = extract_id_column_by_name(primary_batch, &id_col_name)?;
536 let secondary_ids = extract_id_column_by_name(secondary_batch, &id_col_name)?;
537
538 let mut primary_rank: HashMap<String, usize> = HashMap::new();
539 for (i, id) in primary_ids.iter().enumerate() {
540 primary_rank.entry(id.clone()).or_insert(i);
541 }
542 let mut secondary_rank: HashMap<String, usize> = HashMap::new();
543 for (i, id) in secondary_ids.iter().enumerate() {
544 secondary_rank.entry(id.clone()).or_insert(i);
545 }
546
547 let mut all_ids: Vec<String> = primary_ids.clone();
549 for id in &secondary_ids {
550 if !primary_rank.contains_key(id) {
551 all_ids.push(id.clone());
552 }
553 }
554
555 let k = rrf.k as f64;
557 let mut scored: Vec<(String, f64)> = all_ids
558 .iter()
559 .map(|id| {
560 let p = primary_rank
561 .get(id)
562 .map(|&r| 1.0 / (k + r as f64 + 1.0))
563 .unwrap_or(0.0);
564 let s = secondary_rank
565 .get(id)
566 .map(|&r| 1.0 / (k + r as f64 + 1.0))
567 .unwrap_or(0.0);
568 (id.clone(), p + s)
569 })
570 .collect();
571 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
572 scored.truncate(rrf.limit);
573
574 let winning_ids: Vec<String> = scored.iter().map(|(id, _)| id.clone()).collect();
576
577 let mut id_to_batch_row: HashMap<String, (&RecordBatch, usize)> = HashMap::new();
579 for (i, id) in primary_ids.iter().enumerate() {
580 id_to_batch_row
581 .entry(id.clone())
582 .or_insert((primary_batch, i));
583 }
584 for (i, id) in secondary_ids.iter().enumerate() {
585 id_to_batch_row
586 .entry(id.clone())
587 .or_insert((secondary_batch, i));
588 }
589
590 let fused_batch = build_fused_batch(&winning_ids, &id_to_batch_row, primary_batch.schema())?;
592
593 let result_batch = project_return(&fused_batch, &ir.return_exprs, params)?;
595
596 Ok(QueryResult::new(result_batch.schema(), vec![result_batch]))
598}
599
600fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
601 let col = batch.column_by_name(col_name).ok_or_else(|| {
602 OmniError::manifest(format!("batch missing '{}' column for RRF", col_name))
603 })?;
604 let ids = col
605 .as_any()
606 .downcast_ref::<StringArray>()
607 .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", col_name)))?;
608 Ok((0..ids.len()).map(|i| ids.value(i).to_string()).collect())
609}
610
611fn build_fused_batch(
612 ordered_ids: &[String],
613 id_to_batch_row: &HashMap<String, (&RecordBatch, usize)>,
614 schema: SchemaRef,
615) -> Result<RecordBatch> {
616 if ordered_ids.is_empty() {
617 return Ok(RecordBatch::new_empty(schema));
618 }
619
620 let mut row_slices: Vec<RecordBatch> = Vec::with_capacity(ordered_ids.len());
622 for id in ordered_ids {
623 if let Some(&(batch, row_idx)) = id_to_batch_row.get(id) {
624 row_slices.push(batch.slice(row_idx, 1));
625 }
626 }
627
628 if row_slices.is_empty() {
629 return Ok(RecordBatch::new_empty(schema));
630 }
631
632 let schema = row_slices[0].schema();
633 arrow_select::concat::concat_batches(&schema, &row_slices)
634 .map_err(|e| OmniError::Lance(e.to_string()))
635}
636
637fn is_search_filter(filter: &IRFilter) -> bool {
639 matches!(
640 &filter.left,
641 IRExpr::Search { .. } | IRExpr::Fuzzy { .. } | IRExpr::MatchText { .. }
642 )
643}
644
645fn search_filter_variable(filter: &IRFilter) -> Option<&str> {
647 let field = match &filter.left {
648 IRExpr::Search { field, .. } => field,
649 IRExpr::Fuzzy { field, .. } => field,
650 IRExpr::MatchText { field, .. } => field,
651 _ => return None,
652 };
653 match field.as_ref() {
654 IRExpr::PropAccess { variable, .. } => Some(variable.as_str()),
655 _ => None,
656 }
657}
658
659fn execute_pipeline<'a>(
660 pipeline: &'a [IROp],
661 params: &'a ParamMap,
662 snapshot: &'a Snapshot,
663 graph_index: &'a GraphIndexHandle<'a>,
664 catalog: &'a Catalog,
665 wide: &'a mut Option<RecordBatch>,
666 search_mode: &'a SearchMode,
667) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
668 Box::pin(async move {
669 let mut hoisted_search_filters: HashMap<String, Vec<IRFilter>> = HashMap::new();
671 let mut hoisted_indices: HashSet<usize> = HashSet::new();
672 for (i, op) in pipeline.iter().enumerate() {
673 if let IROp::Filter(filter) = op {
674 if is_search_filter(filter) {
675 if let Some(var) = search_filter_variable(filter) {
676 hoisted_search_filters
677 .entry(var.to_string())
678 .or_default()
679 .push(filter.clone());
680 hoisted_indices.insert(i);
681 }
682 }
683 }
684 }
685
686 for (i, op) in pipeline.iter().enumerate() {
687 if hoisted_indices.contains(&i) {
689 continue;
690 }
691 match op {
692 IROp::NodeScan {
693 variable,
694 type_name,
695 filters,
696 } => {
697 let mut all_filters: Vec<IRFilter> = filters.clone();
699 if let Some(extra) = hoisted_search_filters.get(variable) {
700 all_filters.extend(extra.iter().cloned());
701 }
702 let batch = execute_node_scan(
703 type_name,
704 variable,
705 &all_filters,
706 params,
707 snapshot,
708 catalog,
709 search_mode,
710 )
711 .await?;
712 let prefixed = prefix_batch(&batch, variable)?;
713 *wide = Some(match wide.take() {
714 None => prefixed,
715 Some(existing) => cross_join_batches(&existing, &prefixed)?,
716 });
717 }
718 IROp::Filter(filter) => {
719 if let Some(batch) = wide.as_mut() {
720 apply_filter(batch, filter, params)?;
721 }
722 }
723 IROp::Expand {
724 src_var,
725 dst_var,
726 edge_type,
727 direction,
728 dst_type,
729 min_hops,
730 max_hops,
731 dst_filters,
732 } => {
733 if let Some(batch) = wide.as_mut() {
734 execute_expand(
735 batch,
736 graph_index,
737 snapshot,
738 catalog,
739 src_var,
740 dst_var,
741 edge_type,
742 *direction,
743 dst_type,
744 *min_hops,
745 *max_hops,
746 dst_filters,
747 params,
748 )
749 .await?;
750 }
751 }
752 IROp::AntiJoin { outer_var, inner } => {
753 let gi = graph_index;
754 if let Some(batch) = wide.as_mut() {
755 execute_anti_join(batch, inner, params, snapshot, gi, catalog, outer_var)
756 .await?;
757 }
758 }
759 }
760 }
761 Ok(())
762 })
763}
764
765pub struct GraphIndexHandle<'a> {
773 cell: tokio::sync::OnceCell<Option<Arc<GraphIndex>>>,
774 builder: GraphIndexBuilder<'a>,
775}
776
777enum GraphIndexBuilder<'a> {
778 None,
779 Cached(&'a Omnigraph, &'a crate::db::ResolvedTarget),
780 Direct(&'a Snapshot, HashMap<String, (String, String)>),
781}
782
783impl<'a> GraphIndexHandle<'a> {
784 fn none() -> Self {
785 Self {
786 cell: tokio::sync::OnceCell::new(),
787 builder: GraphIndexBuilder::None,
788 }
789 }
790
791 fn cached(db: &'a Omnigraph, resolved: &'a crate::db::ResolvedTarget) -> Self {
792 Self {
793 cell: tokio::sync::OnceCell::new(),
794 builder: GraphIndexBuilder::Cached(db, resolved),
795 }
796 }
797
798 fn direct(snapshot: &'a Snapshot, edge_types: HashMap<String, (String, String)>) -> Self {
799 Self {
800 cell: tokio::sync::OnceCell::new(),
801 builder: GraphIndexBuilder::Direct(snapshot, edge_types),
802 }
803 }
804
805 async fn get(&self) -> Result<Option<&GraphIndex>> {
808 let built = self
809 .cell
810 .get_or_try_init(|| async {
811 match &self.builder {
812 GraphIndexBuilder::None => Ok::<Option<Arc<GraphIndex>>, OmniError>(None),
813 GraphIndexBuilder::Cached(db, resolved) => {
814 Ok(Some(db.graph_index_for_resolved(resolved).await?))
815 }
816 GraphIndexBuilder::Direct(snapshot, edge_types) => {
817 Ok(Some(Arc::new(GraphIndex::build(snapshot, edge_types).await?)))
818 }
819 }
820 })
821 .await?;
822 Ok(built.as_deref())
823 }
824
825 fn is_built(&self) -> bool {
829 matches!(self.cell.get(), Some(Some(_)))
830 }
831}
832
/// Reads a forced expand mode from `OMNIGRAPH_TRAVERSAL_MODE`.
///
/// Returns `Some(true)` for `indexed`, `Some(false)` for `csr`, and `None` (let the cost model
/// decide) when the variable is unset, not valid Unicode, or any other value. Matching is
/// case-sensitive.
fn traversal_indexed_override() -> Option<bool> {
    let mode = std::env::var("OMNIGRAPH_TRAVERSAL_MODE").ok()?;
    match mode.as_str() {
        "indexed" => Some(true),
        "csr" => Some(false),
        _ => None,
    }
}
843
/// Default largest frontier (rows entering an expand) for which an indexed-scan traversal is
/// considered. Overridable via `OMNIGRAPH_EXPAND_INDEXED_MAX_FRONTIER`.
const DEFAULT_EXPAND_INDEXED_MAX_FRONTIER: usize = 1_024;
/// Default largest hop count for which an indexed-scan traversal is considered. Overridable
/// via `OMNIGRAPH_EXPAND_INDEXED_MAX_HOPS`.
const DEFAULT_EXPAND_INDEXED_MAX_HOPS: u32 = 6;
851
852fn expand_indexed_max_frontier() -> usize {
853 std::env::var("OMNIGRAPH_EXPAND_INDEXED_MAX_FRONTIER")
854 .ok()
855 .and_then(|v| v.parse::<usize>().ok())
856 .unwrap_or(DEFAULT_EXPAND_INDEXED_MAX_FRONTIER)
857}
858
859fn expand_indexed_max_hops() -> u32 {
860 std::env::var("OMNIGRAPH_EXPAND_INDEXED_MAX_HOPS")
861 .ok()
862 .and_then(|v| v.parse::<u32>().ok())
863 .filter(|&v| v > 0)
864 .unwrap_or(DEFAULT_EXPAND_INDEXED_MAX_HOPS)
865}
866
/// Strategy used to execute an `Expand` operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExpandMode {
    /// Per-hop lookups of the frontier's ids against the edge table.
    IndexedScan,
    /// Traversal over the prebuilt graph (CSR) index.
    Csr,
}
880
/// Per-edge cost multiplier charged for building the CSR index when it is not yet cached.
const CSR_BUILD_FACTOR: f64 = 1.5;
886
887#[derive(Debug, Clone)]
891struct ExpandCostInputs {
892 frontier_rows: usize,
894 edge_count: u64,
896 src_node_count: u64,
898 effective_max_hops: u32,
900 max_hops_cap: u32,
903 max_frontier_cap: usize,
906 coverage: crate::table_store::IndexCoverage,
909 csr_cached: bool,
913}
914
915fn choose_expand_mode(i: &ExpandCostInputs) -> ExpandMode {
926 if i.effective_max_hops > i.max_hops_cap || i.frontier_rows > i.max_frontier_cap {
930 return ExpandMode::Csr;
931 }
932
933 let hops = i.effective_max_hops.max(1) as f64;
934 let frontier = i.frontier_rows as f64;
935 let edges = i.edge_count as f64;
936 let src = i.src_node_count.max(1) as f64;
937 let fanout = edges / src;
938
939 let indexed_cost = match i.coverage {
942 crate::table_store::IndexCoverage::Indexed => hops * frontier * fanout,
943 crate::table_store::IndexCoverage::Degraded { .. } => hops * edges,
944 };
945 let csr_cost = if i.csr_cached {
947 0.0
948 } else {
949 CSR_BUILD_FACTOR * edges
950 };
951
952 if indexed_cost < csr_cost {
953 ExpandMode::IndexedScan
954 } else {
955 ExpandMode::Csr
956 }
957}
958
/// Number of hops the cost model should assume for an expand.
///
/// Edges whose endpoints share a node type can chain, so the requested count stands. Edges
/// between different types can only be followed once, so the count is capped at 1.
fn cost_effective_hops(requested_max_hops: u32, same_type: bool) -> u32 {
    if !same_type && requested_max_hops > 1 {
        1
    } else {
        requested_max_hops
    }
}
970
971fn gather_cost_inputs(
976 snapshot: &Snapshot,
977 catalog: &Catalog,
978 edge_type: &str,
979 direction: Direction,
980 frontier_rows: usize,
981 effective_max_hops: u32,
982 coverage: crate::table_store::IndexCoverage,
983 csr_cached: bool,
984) -> Option<ExpandCostInputs> {
985 let edge_entry = snapshot.entry(&format!("edge:{}", edge_type))?;
986 let edge_def = catalog.edge_types.get(edge_type)?;
987 let effective_max_hops =
990 cost_effective_hops(effective_max_hops, edge_def.from_type == edge_def.to_type);
991 let src_type = match direction {
994 Direction::Out => &edge_def.from_type,
995 Direction::In => &edge_def.to_type,
996 };
997 let src_entry = snapshot.entry(&format!("node:{}", src_type))?;
998 Some(ExpandCostInputs {
999 frontier_rows,
1000 edge_count: edge_entry.row_count,
1001 src_node_count: src_entry.row_count,
1002 effective_max_hops,
1003 max_hops_cap: expand_indexed_max_hops(),
1004 max_frontier_cap: expand_indexed_max_frontier(),
1005 coverage,
1006 csr_cached,
1007 })
1008}
1009
1010fn coverage_for_decision(
1014 coverage: &Result<crate::table_store::IndexCoverage>,
1015) -> crate::table_store::IndexCoverage {
1016 match coverage {
1017 Ok(c) => c.clone(),
1018 Err(_) => crate::table_store::IndexCoverage::Degraded {
1019 reason: "coverage check failed".to_string(),
1020 },
1021 }
1022}
1023
1024fn warn_on_degraded_coverage(
1028 coverage: &Result<crate::table_store::IndexCoverage>,
1029 key_col: &str,
1030 edge_type: &str,
1031) {
1032 match coverage {
1033 Ok(crate::table_store::IndexCoverage::Degraded { reason }) => tracing::warn!(
1034 target: "omnigraph::traverse",
1035 edge = %edge_type,
1036 key_col = key_col,
1037 reason = %reason,
1038 "indexed traversal falls back to a full edge scan (results correct, perf degraded)"
1039 ),
1040 Ok(crate::table_store::IndexCoverage::Indexed) => {}
1041 Err(e) => tracing::debug!(
1042 target: "omnigraph::traverse",
1043 error = %e,
1044 "index-coverage check failed; proceeding with traversal"
1045 ),
1046 }
1047}
1048
1049fn endpoint_columns(direction: Direction) -> (&'static str, &'static str) {
1053 match direction {
1054 Direction::Out => ("src", "dst"),
1055 Direction::In => ("dst", "src"),
1056 }
1057}
1058
/// Expands the wide batch along `edge_type` from `src_var` to `dst_var`.
///
/// Chooses between CSR traversal over the graph index and per-hop indexed scans of the edge
/// table:
/// 1. If `OMNIGRAPH_TRAVERSAL_MODE` is set to `indexed` or `csr`, that mode is forced.
/// 2. Otherwise a provisional choice is made with `choose_expand_mode`, assuming the key
///    column is fully indexed. If cost inputs are unavailable, the static frontier and hop caps
///    decide instead.
/// 3. When the provisional choice is the indexed scan, the edge dataset is opened and its real
///    key-column coverage is measured. Unless the mode was forced, the cost model is re-run with
///    that coverage and may still switch to CSR.
///
/// # Errors
/// Returns an error in any of these cases:
/// - CSR is chosen but no graph index is available;
/// - the edge dataset cannot be opened;
/// - the chosen traversal fails.
async fn execute_expand(
    wide: &mut RecordBatch,
    graph_index: &GraphIndexHandle<'_>,
    snapshot: &Snapshot,
    catalog: &Catalog,
    src_var: &str,
    dst_var: &str,
    edge_type: &str,
    direction: Direction,
    dst_type: &str,
    min_hops: u32,
    max_hops: Option<u32>,
    dst_filters: &[IRFilter],
    params: &ParamMap,
) -> Result<()> {
    let frontier_rows = wide.num_rows();
    let effective_max_hops = max_hops.unwrap_or(min_hops.max(1));
    let (key_col, _) = endpoint_columns(direction);
    let edge_table_key = format!("edge:{}", edge_type);

    let forced = traversal_indexed_override();
    // Provisional decision, made before opening the edge dataset: assume the key column is
    // fully indexed. This only decides whether the indexed path is worth checking at all.
    let lean_indexed = match forced {
        Some(v) => v,
        None => match gather_cost_inputs(
            snapshot,
            catalog,
            edge_type,
            direction,
            frontier_rows,
            effective_max_hops,
            crate::table_store::IndexCoverage::Indexed,
            graph_index.is_built(),
        ) {
            Some(inputs) => choose_expand_mode(&inputs) == ExpandMode::IndexedScan,
            None => {
                // Cost inputs unavailable (missing table entry or edge definition): fall back
                // to the static caps alone.
                frontier_rows <= expand_indexed_max_frontier()
                    && effective_max_hops <= expand_indexed_max_hops()
            }
        },
    };

    if !lean_indexed {
        tracing::debug!(
            target: "omnigraph::traverse",
            edge = %edge_type,
            frontier = frontier_rows,
            hops = effective_max_hops,
            mode = "csr",
            "expand mode chosen",
        );
        let gi = graph_index.get().await?.ok_or_else(|| {
            OmniError::manifest("graph index required for CSR traversal".to_string())
        })?;
        return execute_expand_csr(
            wide, gi, snapshot, catalog, src_var, dst_var, edge_type, direction, dst_type,
            min_hops, max_hops, dst_filters, params,
        )
        .await;
    }

    // The indexed path looks viable, so measure the key column's actual index coverage. The
    // opened dataset is reused by the indexed traversal below.
    let edge_ds = snapshot.open(&edge_table_key).await?;
    let coverage =
        crate::table_store::TableStore::key_column_index_coverage(&edge_ds, key_col).await;

    // Unless forced, re-run the cost model with the measured coverage. Only a coverage
    // difference from the provisional assumption can flip the decision to CSR here.
    if forced.is_none() {
        if let Some(inputs) = gather_cost_inputs(
            snapshot,
            catalog,
            edge_type,
            direction,
            frontier_rows,
            effective_max_hops,
            coverage_for_decision(&coverage),
            graph_index.is_built(),
        ) {
            if choose_expand_mode(&inputs) == ExpandMode::Csr {
                tracing::debug!(
                    target: "omnigraph::traverse",
                    edge = %edge_type,
                    frontier = frontier_rows,
                    hops = effective_max_hops,
                    mode = "csr",
                    reason = "index coverage degraded",
                    "expand mode chosen",
                );
                let gi = graph_index.get().await?.ok_or_else(|| {
                    OmniError::manifest("graph index required for CSR traversal".to_string())
                })?;
                return execute_expand_csr(
                    wide, gi, snapshot, catalog, src_var, dst_var, edge_type, direction, dst_type,
                    min_hops, max_hops, dst_filters, params,
                )
                .await;
            }
        }
    }

    tracing::debug!(
        target: "omnigraph::traverse",
        edge = %edge_type,
        frontier = frontier_rows,
        hops = effective_max_hops,
        mode = "indexed",
        "expand mode chosen",
    );
    warn_on_degraded_coverage(&coverage, key_col, edge_type);
    // Hand over the already-opened edge dataset so it is not opened a second time.
    execute_expand_indexed(
        wide, snapshot, catalog, src_var, dst_var, edge_type, direction, dst_type, min_hops,
        max_hops, dst_filters, params, edge_ds,
    )
    .await
}
1186
/// Expands the wide batch by repeatedly scanning the edge table for the current frontiers.
///
/// How the expansion works:
/// - Every row of `wide` runs its own breadth-first expansion starting from its
///   `<src_var>.id`.
/// - Each hop issues a single `scan_edges_by_endpoint` call covering the union of all rows'
///   frontiers.
/// - Ids are interned to dense `u32`s via `TypeIndex`.
/// - For edges whose endpoints share a node type, a node is visited at most once per row (the
///   source counts as visited) and expansion continues up to `max_hops`.
/// - For edges between different node types, at most one hop is taken.
/// - A destination reached at a hop `>= min_hops` is emitted at most once per row.
///
/// The resulting `(row index, destination id)` pairs are passed to `expand_hydrate_and_align`.
///
/// # Panics
/// Only if an interned id cannot be resolved back to its string, which would indicate an
/// interner bug.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `<src_var>.id` is missing or not Utf8;
/// - `edge_type` is unknown;
/// - an edge scan fails or returns missing/non-Utf8 endpoint columns;
/// - hydration fails.
async fn execute_expand_indexed(
    wide: &mut RecordBatch,
    snapshot: &Snapshot,
    catalog: &Catalog,
    src_var: &str,
    dst_var: &str,
    edge_type: &str,
    direction: Direction,
    dst_type: &str,
    min_hops: u32,
    max_hops: Option<u32>,
    dst_filters: &[IRFilter],
    params: &ParamMap,
    edge_ds: Dataset,
) -> Result<()> {
    let src_id_col_name = format!("{}.id", src_var);
    let src_ids = wide
        .column_by_name(&src_id_col_name)
        .ok_or_else(|| {
            OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name))
        })?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
        .clone();

    let edge_def = catalog
        .edge_types
        .get(edge_type)
        .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_type)))?;
    let same_type = edge_def.from_type == edge_def.to_type;
    // key_col: the endpoint we look up by; opp_col: the endpoint we reach.
    let (key_col, opp_col) = endpoint_columns(direction);

    let max = max_hops.unwrap_or(min_hops.max(1));
    // Edges between different node types cannot chain, so never take more than one hop.
    let max = if same_type { max } else { max.min(1) };

    // Per-row BFS state, indexed by row of `wide`:
    // - frontiers[i]: nodes to expand at the next hop.
    // - visited[i]: nodes already reached (same-type edges only; seeded with the source).
    // - seen_dst[i]: destinations already emitted for this row.
    let mut interner = crate::graph_index::TypeIndex::new();
    let n = src_ids.len();
    let mut frontiers: Vec<Vec<u32>> = Vec::with_capacity(n);
    let mut visited: Vec<HashSet<u32>> = Vec::with_capacity(n);
    let mut seen_dst: Vec<HashSet<u32>> = Vec::with_capacity(n);
    for i in 0..n {
        let sid = interner.get_or_insert(src_ids.value(i));
        let mut v = HashSet::new();
        if same_type {
            v.insert(sid);
        }
        frontiers.push(vec![sid]);
        visited.push(v);
        seen_dst.push(HashSet::new());
    }

    // Output pairs: src_indices[j] is a row of `wide`, dst_dense[j] the interned destination.
    let mut src_indices: Vec<u32> = Vec::new();
    let mut dst_dense: Vec<u32> = Vec::new();

    for hop in 1..=max {
        // Deduplicated union of every row's frontier, in first-seen order, so each hop needs
        // only one edge scan.
        let mut union_dense: Vec<u32> = Vec::new();
        {
            let mut seen: HashSet<u32> = HashSet::new();
            for f in &frontiers {
                for &node in f {
                    if seen.insert(node) {
                        union_dense.push(node);
                    }
                }
            }
        }
        if union_dense.is_empty() {
            break;
        }
        let union_keys: Vec<String> = union_dense
            .iter()
            .map(|&u| {
                interner
                    .to_id(u)
                    .expect("interned frontier id must resolve")
                    .to_string()
            })
            .collect();

        let batches = crate::table_store::TableStore::scan_edges_by_endpoint(
            &edge_ds, key_col, opp_col, &union_keys,
        )
        .await?;

        // Adjacency for this hop: interned key endpoint -> interned opposite endpoints.
        let mut neighbor_map: HashMap<u32, Vec<u32>> = HashMap::new();
        for batch in &batches {
            let keys = batch
                .column_by_name(key_col)
                .ok_or_else(|| OmniError::manifest(format!("edge batch missing '{}'", key_col)))?
                .as_any()
                .downcast_ref::<StringArray>()
                .ok_or_else(|| OmniError::manifest(format!("edge '{}' is not Utf8", key_col)))?;
            let opps = batch
                .column_by_name(opp_col)
                .ok_or_else(|| OmniError::manifest(format!("edge batch missing '{}'", opp_col)))?
                .as_any()
                .downcast_ref::<StringArray>()
                .ok_or_else(|| OmniError::manifest(format!("edge '{}' is not Utf8", opp_col)))?;
            for r in 0..batch.num_rows() {
                let k = interner.get_or_insert(keys.value(r));
                let o = interner.get_or_insert(opps.value(r));
                neighbor_map.entry(k).or_default().push(o);
            }
        }

        // Advance each row's frontier independently and emit newly reached destinations once
        // the minimum hop count has been reached.
        for i in 0..n {
            let cur = std::mem::take(&mut frontiers[i]);
            let mut next: Vec<u32> = Vec::new();
            for &node in &cur {
                let Some(neighbors) = neighbor_map.get(&node) else {
                    continue;
                };
                for &neighbor in neighbors {
                    if !same_type || visited[i].insert(neighbor) {
                        next.push(neighbor);
                        if hop >= min_hops && seen_dst[i].insert(neighbor) {
                            src_indices.push(i as u32);
                            dst_dense.push(neighbor);
                        }
                    }
                }
            }
            frontiers[i] = next;
        }
    }

    // Translate interned destinations back to their string ids for hydration.
    let dst_ids: Vec<String> = dst_dense
        .iter()
        .map(|&d| {
            interner
                .to_id(d)
                .expect("interned dst id must resolve")
                .to_string()
        })
        .collect();

    expand_hydrate_and_align(
        wide, src_indices, dst_ids, snapshot, catalog, dst_type, dst_var, dst_filters, params,
    )
    .await
}
1356
1357async fn expand_hydrate_and_align(
1361 wide: &mut RecordBatch,
1362 src_indices: Vec<u32>,
1363 dst_ids: Vec<String>,
1364 snapshot: &Snapshot,
1365 catalog: &Catalog,
1366 dst_type: &str,
1367 dst_var: &str,
1368 dst_filters: &[IRFilter],
1369 params: &ParamMap,
1370) -> Result<()> {
1371 let non_pushable: Vec<&IRFilter> = dst_filters
1376 .iter()
1377 .filter(|f| ir_filter_to_expr(f, params, None).is_none())
1378 .collect();
1379
1380 let mut unique_dst_list: Vec<String> = Vec::new();
1382 {
1383 let mut seen: HashSet<&str> = HashSet::with_capacity(dst_ids.len());
1384 for id in &dst_ids {
1385 if seen.insert(id.as_str()) {
1386 unique_dst_list.push(id.clone());
1387 }
1388 }
1389 }
1390 let dst_batch =
1391 hydrate_nodes(snapshot, catalog, dst_type, &unique_dst_list, dst_filters, params).await?;
1392
1393 let dst_batch_id_col = dst_batch
1395 .column_by_name("id")
1396 .ok_or_else(|| OmniError::manifest("hydrated batch missing 'id' column".to_string()))?
1397 .as_any()
1398 .downcast_ref::<StringArray>()
1399 .ok_or_else(|| OmniError::manifest("hydrated 'id' column is not Utf8".to_string()))?;
1400 let mut id_to_row: HashMap<&str, u32> = HashMap::with_capacity(dst_batch_id_col.len());
1401 for row in 0..dst_batch_id_col.len() {
1402 id_to_row.insert(dst_batch_id_col.value(row), row as u32);
1403 }
1404
1405 let mut final_src_indices: Vec<u32> = Vec::with_capacity(src_indices.len());
1407 let mut dst_indices: Vec<u32> = Vec::with_capacity(src_indices.len());
1408 for (&src_idx, dst_id) in src_indices.iter().zip(dst_ids.iter()) {
1409 if let Some(&dst_row) = id_to_row.get(dst_id.as_str()) {
1410 final_src_indices.push(src_idx);
1411 dst_indices.push(dst_row);
1412 }
1413 }
1414
1415 let src_take = UInt32Array::from(final_src_indices);
1416 let dst_take = UInt32Array::from(dst_indices);
1417 let expanded_wide = take_batch(wide, &src_take)?;
1418 let dst_prefixed = prefix_batch(&dst_batch, dst_var)?;
1419 let aligned_dst = take_batch(&dst_prefixed, &dst_take)?;
1420 *wide = hconcat_batches(&expanded_wide, &aligned_dst)?;
1421
1422 for f in &non_pushable {
1423 apply_filter(wide, f, params)?;
1424 }
1425 Ok(())
1426}
1427
/// Expands every row of `wide` across `edge_type` using the in-memory graph
/// index, then hydrates and joins the reached destination nodes.
///
/// The source id for each row is read from the `"{src_var}.id"` column.
/// Traversal uses the CSR adjacency for `Direction::Out` and the CSC adjacency
/// for `Direction::In`, over dense ids from the source type's index. For each
/// source row, a destination is recorded at most once, and only when it was
/// reached at a hop count `>= min_hops`. The collected (source row,
/// destination id) pairs are handed to `expand_hydrate_and_align`.
///
/// # Errors
/// Missing or non-Utf8 source id column, unknown edge type, missing type
/// index for either endpoint type, missing adjacency for the edge, or any
/// error from `expand_hydrate_and_align`.
async fn execute_expand_csr(
    wide: &mut RecordBatch,
    graph_index: &GraphIndex,
    snapshot: &Snapshot,
    catalog: &Catalog,
    src_var: &str,
    dst_var: &str,
    edge_type: &str,
    direction: Direction,
    dst_type: &str,
    min_hops: u32,
    max_hops: Option<u32>,
    dst_filters: &[IRFilter],
    params: &ParamMap,
) -> Result<()> {
    let src_id_col_name = format!("{}.id", src_var);
    // Cloned so the ids stay readable after `wide` is replaced at the end.
    let src_ids = wide
        .column_by_name(&src_id_col_name)
        .ok_or_else(|| {
            OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name))
        })?
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
        .clone();

    let edge_def = catalog
        .edge_types
        .get(edge_type)
        .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_type)))?;

    // Traversing In swaps which endpoint type is the "source" of the walk.
    let (src_type_name, dst_type_name) = match direction {
        Direction::Out => (&edge_def.from_type, &edge_def.to_type),
        Direction::In => (&edge_def.to_type, &edge_def.from_type),
    };

    let src_type_idx = graph_index
        .type_index(src_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", src_type_name)))?;
    let dst_type_idx = graph_index
        .type_index(dst_type_name)
        .ok_or_else(|| OmniError::manifest(format!("no type index for '{}'", dst_type_name)))?;

    let adj = match direction {
        Direction::Out => graph_index.csr(edge_type),
        Direction::In => graph_index.csc(edge_type),
    }
    .ok_or_else(|| OmniError::manifest(format!("no adjacency index for edge '{}'", edge_type)))?;

    // Unbounded max defaults to min_hops (at least 1).
    let max = max_hops.unwrap_or(min_hops.max(1));

    // Multi-hop walks are only allowed when both endpoints share a type.
    // Otherwise at most one hop is taken; presumably because a second hop would
    // feed destination-type dense ids back into a source-type adjacency.
    // NOTE(review): confirm that is the reason.
    let same_type = src_type_name == dst_type_name;
    let max = if same_type { max } else { max.min(1) };

    // Parallel vectors: src_indices[k] (row in `wide`) reached dst_dense_list[k].
    let mut src_indices: Vec<u32> = Vec::new();
    let mut dst_dense_list: Vec<u32> = Vec::new();
    for i in 0..src_ids.len() {
        let src_id = src_ids.value(i);
        // Rows whose id is not in the source type index produce no pairs, and
        // therefore drop out of the expanded result.
        let Some(src_dense) = src_type_idx.to_dense(src_id) else {
            continue;
        };

        let mut frontier: Vec<u32> = vec![src_dense];
        let mut visited: HashSet<u32> = HashSet::new();
        let mut seen_dst_dense: HashSet<u32> = HashSet::new();
        // For same-type walks the start node is pre-visited, so it is never
        // re-entered or emitted as its own destination.
        if same_type {
            visited.insert(src_dense);
        }

        for hop in 1..=max {
            let mut next_frontier = Vec::new();
            for &node in &frontier {
                for &neighbor in adj.neighbors(node) {
                    // Cross-type walks skip the visited check; with a single hop
                    // there is nothing to revisit.
                    if !same_type || visited.insert(neighbor) {
                        next_frontier.push(neighbor);
                        // Emit only within the hop window, and only once per
                        // (source row, destination), e.g. across parallel edges.
                        if hop >= min_hops && seen_dst_dense.insert(neighbor) {
                            src_indices.push(i as u32);
                            dst_dense_list.push(neighbor);
                        }
                    }
                }
            }
            frontier = next_frontier;
            if frontier.is_empty() {
                break;
            }
        }
    }

    // Translate dense destination ids back to string ids. Pairs whose dense id
    // has no id mapping are dropped together with their source index.
    let mut tail_src_indices: Vec<u32> = Vec::with_capacity(src_indices.len());
    let mut dst_ids: Vec<String> = Vec::with_capacity(dst_dense_list.len());
    for (&s, &d) in src_indices.iter().zip(dst_dense_list.iter()) {
        if let Some(id) = dst_type_idx.to_id(d) {
            tail_src_indices.push(s);
            dst_ids.push(id.to_string());
        }
    }

    expand_hydrate_and_align(
        wide,
        tail_src_indices,
        dst_ids,
        snapshot,
        catalog,
        dst_type,
        dst_var,
        dst_filters,
        params,
    )
    .await
}
1555
1556async fn hydrate_nodes(
1567 snapshot: &Snapshot,
1568 catalog: &Catalog,
1569 type_name: &str,
1570 ids: &[String],
1571 dst_filters: &[IRFilter],
1572 params: &ParamMap,
1573) -> Result<RecordBatch> {
1574 use datafusion::prelude::{col, lit};
1575
1576 let node_type = catalog
1577 .node_types
1578 .get(type_name)
1579 .ok_or_else(|| OmniError::manifest(format!("unknown node type '{}'", type_name)))?;
1580
1581 if ids.is_empty() {
1582 return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
1583 }
1584
1585 let table_key = format!("node:{}", type_name);
1586 let ds = snapshot.open(&table_key).await?;
1587
1588 let id_list: Vec<datafusion::prelude::Expr> = ids.iter().map(|id| lit(id.clone())).collect();
1590 let mut filter_expr = col("id").in_list(id_list, false);
1591 if let Some(dst_expr) = build_lance_filter_expr(dst_filters, params, Some(&node_type.arrow_schema))
1592 {
1593 filter_expr = filter_expr.and(dst_expr);
1594 }
1595
1596 let has_blobs = !node_type.blob_properties.is_empty();
1597 let non_blob_cols: Vec<&str> = node_type
1598 .arrow_schema
1599 .fields()
1600 .iter()
1601 .filter(|f| !node_type.blob_properties.contains(f.name()))
1602 .map(|f| f.name().as_str())
1603 .collect();
1604 let projection = has_blobs.then_some(non_blob_cols.as_slice());
1605 let batches = crate::table_store::TableStore::scan_stream_with(
1606 &ds,
1607 projection,
1608 None,
1609 None,
1610 false,
1611 |scanner| {
1612 scanner.filter_expr(filter_expr);
1613 Ok(())
1614 },
1615 )
1616 .await?
1617 .try_collect::<Vec<RecordBatch>>()
1618 .await
1619 .map_err(|e| OmniError::Lance(e.to_string()))?;
1620
1621 let scan_result = if batches.is_empty() {
1622 return Ok(RecordBatch::new_empty(node_type.arrow_schema.clone()));
1623 } else if batches.len() == 1 {
1624 batches.into_iter().next().unwrap()
1625 } else {
1626 let schema = batches[0].schema();
1627 arrow_select::concat::concat_batches(&schema, &batches)
1628 .map_err(|e| OmniError::Lance(e.to_string()))?
1629 };
1630
1631 if has_blobs {
1632 return add_null_blob_columns(&scan_result, node_type);
1633 }
1634 Ok(scan_result)
1635}
1636
1637fn bulk_anti_join_applies(inner_pipeline: &[IROp], outer_var: &str) -> bool {
1642 matches!(
1643 inner_pipeline,
1644 [IROp::Expand { src_var, dst_filters, min_hops, max_hops, .. }]
1645 if src_var == outer_var
1646 && dst_filters.is_empty()
1647 && *min_hops == 1
1652 && (*max_hops).unwrap_or(1) == 1
1653 )
1654}
1655
1656fn try_bulk_anti_join_mask(
1659 wide: &RecordBatch,
1660 inner_pipeline: &[IROp],
1661 graph_index: Option<&GraphIndex>,
1662 catalog: &Catalog,
1663 outer_var: &str,
1664) -> Option<BooleanArray> {
1665 if !bulk_anti_join_applies(inner_pipeline, outer_var) {
1666 return None;
1667 }
1668 let IROp::Expand {
1669 edge_type,
1670 direction,
1671 ..
1672 } = &inner_pipeline[0]
1673 else {
1674 return None;
1675 };
1676 let gi = graph_index?;
1677 let edge_def = catalog.edge_types.get(edge_type.as_str())?;
1678
1679 let src_type_name = match direction {
1680 Direction::Out => &edge_def.from_type,
1681 Direction::In => &edge_def.to_type,
1682 };
1683 let adj = match direction {
1684 Direction::Out => gi.csr(edge_type),
1685 Direction::In => gi.csc(edge_type),
1686 }?;
1687 let type_idx = gi.type_index(src_type_name)?;
1688
1689 let id_col_name = format!("{}.id", outer_var);
1690 let outer_ids = wide
1691 .column_by_name(&id_col_name)?
1692 .as_any()
1693 .downcast_ref::<StringArray>()?;
1694
1695 let keep_mask: Vec<bool> = (0..outer_ids.len())
1696 .map(|i| {
1697 let id = outer_ids.value(i);
1698 match type_idx.to_dense(id) {
1699 Some(dense) => !adj.has_neighbors(dense),
1700 None => true, }
1702 })
1703 .collect();
1704
1705 Some(BooleanArray::from(keep_mask))
1706}
1707
1708async fn execute_anti_join(
1710 wide: &mut RecordBatch,
1711 inner_pipeline: &[IROp],
1712 params: &ParamMap,
1713 snapshot: &Snapshot,
1714 graph_index: &GraphIndexHandle<'_>,
1715 catalog: &Catalog,
1716 outer_var: &str,
1717) -> Result<()> {
1718 let gi = if bulk_anti_join_applies(inner_pipeline, outer_var) {
1723 graph_index.get().await?
1724 } else {
1725 None
1726 };
1727 if let Some(mask) = try_bulk_anti_join_mask(wide, inner_pipeline, gi, catalog, outer_var) {
1729 *wide = arrow_select::filter::filter_record_batch(wide, &mask)
1730 .map_err(|e| OmniError::Lance(e.to_string()))?;
1731 return Ok(());
1732 }
1733
1734 let num_rows = wide.num_rows();
1742 if num_rows == 0 {
1743 return Ok(());
1744 }
1745
1746 let tag_col: String = {
1755 let mut n = 0usize;
1756 loop {
1757 let candidate = format!("__antijoin_outer_row_{n}");
1758 if wide.schema().column_with_name(&candidate).is_none() {
1759 break candidate;
1760 }
1761 n += 1;
1762 }
1763 };
1764 let mut fields: Vec<Field> = wide
1765 .schema()
1766 .fields()
1767 .iter()
1768 .map(|f| f.as_ref().clone())
1769 .collect();
1770 fields.push(Field::new(tag_col.as_str(), DataType::UInt32, false));
1771 let mut columns: Vec<ArrayRef> = wide.columns().to_vec();
1772 columns.push(Arc::new(UInt32Array::from_iter_values(0..num_rows as u32)));
1773 let tagged = RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
1774 .map_err(|e| OmniError::Lance(e.to_string()))?;
1775
1776 let mut inner_wide: Option<RecordBatch> = Some(tagged);
1777 let no_search = SearchMode::default();
1778 execute_pipeline(
1779 inner_pipeline,
1780 params,
1781 snapshot,
1782 graph_index,
1783 catalog,
1784 &mut inner_wide,
1785 &no_search,
1786 )
1787 .await?;
1788
1789 let mut matched: HashSet<u32> = HashSet::new();
1793 if let Some(batch) = inner_wide {
1794 if batch.num_rows() > 0 {
1795 let tags = batch
1796 .column_by_name(tag_col.as_str())
1797 .ok_or_else(|| {
1798 OmniError::manifest(
1799 "anti-join inner pipeline dropped the correlation column".to_string(),
1800 )
1801 })?
1802 .as_any()
1803 .downcast_ref::<UInt32Array>()
1804 .ok_or_else(|| {
1805 OmniError::manifest(format!("'{}' column is not UInt32", tag_col))
1806 })?;
1807 for i in 0..tags.len() {
1808 matched.insert(tags.value(i));
1809 }
1810 }
1811 }
1812
1813 let keep_mask: Vec<bool> = (0..num_rows as u32).map(|i| !matched.contains(&i)).collect();
1814 let mask = BooleanArray::from(keep_mask);
1815 *wide = arrow_select::filter::filter_record_batch(wide, &mask)
1816 .map_err(|e| OmniError::Lance(e.to_string()))?;
1817 Ok(())
1818}
1819
1820async fn execute_node_scan(
1822 type_name: &str,
1823 variable: &str,
1824 filters: &[IRFilter],
1825 params: &ParamMap,
1826 snapshot: &Snapshot,
1827 catalog: &Catalog,
1828 search_mode: &SearchMode,
1829) -> Result<RecordBatch> {
1830 let table_key = format!("node:{}", type_name);
1831 let ds = snapshot.open(&table_key).await?;
1832
1833 let node_type = &catalog.node_types[type_name];
1834
1835 let filter_expr = build_lance_filter_expr(filters, params, Some(&node_type.arrow_schema));
1846
1847 let has_blobs = !node_type.blob_properties.is_empty();
1851 let non_blob_cols: Vec<&str> = node_type
1852 .arrow_schema
1853 .fields()
1854 .iter()
1855 .filter(|f| !node_type.blob_properties.contains(f.name()))
1856 .map(|f| f.name().as_str())
1857 .collect();
1858 let projection = has_blobs.then_some(non_blob_cols.as_slice());
1859 let batches = crate::table_store::TableStore::scan_stream_with(
1860 &ds,
1861 projection,
1862 None,
1863 None,
1864 false,
1865 |scanner| {
1866 if let Some(ref expr) = filter_expr {
1868 scanner.filter_expr(expr.clone());
1869 }
1870
1871 for filter in filters {
1873 if is_search_filter(filter) {
1874 if let Some(fts_query) = build_fts_query(&filter.left, params) {
1875 scanner.full_text_search(fts_query).map_err(|e| {
1876 OmniError::Lance(format!("full_text_search filter: {}", e))
1877 })?;
1878 }
1879 }
1880 }
1881
1882 if let Some((ref var, ref prop, ref vec, k)) = search_mode.nearest {
1884 if var == variable {
1885 let query_arr = Float32Array::from(vec.clone());
1886 scanner
1887 .nearest(prop, &query_arr, k)
1888 .map_err(|e| OmniError::Lance(format!("nearest: {}", e)))?;
1889 }
1890 }
1891
1892 if let Some((ref var, ref prop, ref text)) = search_mode.bm25 {
1894 if var == variable {
1895 let fts_query = lance_index::scalar::FullTextSearchQuery::new(text.clone())
1896 .with_column(prop.clone())
1897 .map_err(|e| OmniError::Lance(format!("fts with_column: {}", e)))?;
1898 scanner
1899 .full_text_search(fts_query)
1900 .map_err(|e| OmniError::Lance(format!("full_text_search: {}", e)))?;
1901 }
1902 }
1903 Ok(())
1904 },
1905 )
1906 .await?
1907 .try_collect::<Vec<RecordBatch>>()
1908 .await
1909 .map_err(|e| OmniError::Lance(e.to_string()))?;
1910
1911 let scan_result = if batches.is_empty() {
1912 RecordBatch::new_empty(batches.first().map(|b| b.schema()).unwrap_or_else(|| {
1913 let fields: Vec<_> = node_type
1915 .arrow_schema
1916 .fields()
1917 .iter()
1918 .filter(|f| !node_type.blob_properties.contains(f.name()))
1919 .map(|f| f.as_ref().clone())
1920 .collect();
1921 Arc::new(Schema::new(fields))
1922 }))
1923 } else if batches.len() == 1 {
1924 batches.into_iter().next().unwrap()
1925 } else {
1926 let schema = batches[0].schema();
1927 arrow_select::concat::concat_batches(&schema, &batches)
1928 .map_err(|e| OmniError::Lance(e.to_string()))?
1929 };
1930
1931 if has_blobs {
1933 return add_null_blob_columns(&scan_result, node_type);
1934 }
1935 Ok(scan_result)
1936}
1937
1938fn add_null_blob_columns(
1941 batch: &RecordBatch,
1942 node_type: &omnigraph_compiler::catalog::NodeType,
1943) -> Result<RecordBatch> {
1944 let num_rows = batch.num_rows();
1945 let mut fields = Vec::with_capacity(node_type.arrow_schema.fields().len());
1946 let mut columns: Vec<ArrayRef> = Vec::with_capacity(node_type.arrow_schema.fields().len());
1947
1948 for field in node_type.arrow_schema.fields() {
1949 if node_type.blob_properties.contains(field.name()) {
1950 fields.push(Field::new(field.name(), DataType::Utf8, true));
1951 columns.push(Arc::new(StringArray::from(vec![None::<&str>; num_rows])));
1952 } else if let Some(col) = batch.column_by_name(field.name()) {
1953 let batch_schema = batch.schema();
1954 let batch_field = batch_schema
1955 .field_with_name(field.name())
1956 .map_err(|e| OmniError::Lance(e.to_string()))?;
1957 fields.push(batch_field.clone());
1958 columns.push(col.clone());
1959 }
1960 }
1961
1962 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
1963 .map_err(|e| OmniError::Lance(e.to_string()))
1964}
1965
1966fn build_fts_query(
1968 expr: &IRExpr,
1969 params: &ParamMap,
1970) -> Option<lance_index::scalar::FullTextSearchQuery> {
1971 match expr {
1972 IRExpr::Search { field, query } => {
1973 let prop = extract_property(field)?;
1974 let q = resolve_to_string(query, params)?;
1975 lance_index::scalar::FullTextSearchQuery::new(q)
1976 .with_column(prop)
1977 .ok()
1978 }
1979 IRExpr::Fuzzy {
1980 field,
1981 query,
1982 max_edits,
1983 } => {
1984 let prop = extract_property(field)?;
1985 let q = resolve_to_string(query, params)?;
1986 let edits = max_edits
1987 .as_ref()
1988 .and_then(|e| resolve_to_int(e, params))
1989 .unwrap_or(2) as u32;
1990 lance_index::scalar::FullTextSearchQuery::new_fuzzy(q, Some(edits))
1991 .with_column(prop)
1992 .ok()
1993 }
1994 IRExpr::MatchText { field, query } => {
1995 let prop = extract_property(field)?;
1997 let q = resolve_to_string(query, params)?;
1998 lance_index::scalar::FullTextSearchQuery::new(q)
1999 .with_column(prop)
2000 .ok()
2001 }
2002 _ => None,
2003 }
2004}
2005
2006fn extract_property(expr: &IRExpr) -> Option<String> {
2008 match expr {
2009 IRExpr::PropAccess { property, .. } => Some(property.clone()),
2010 _ => None,
2011 }
2012}
2013
2014fn resolve_to_string(expr: &IRExpr, params: &ParamMap) -> Option<String> {
2016 match expr {
2017 IRExpr::Literal(Literal::String(s)) => Some(s.clone()),
2018 IRExpr::Param(name) => match params.get(name)? {
2019 Literal::String(s) => Some(s.clone()),
2020 _ => None,
2021 },
2022 _ => None,
2023 }
2024}
2025
2026fn resolve_to_int(expr: &IRExpr, params: &ParamMap) -> Option<i64> {
2028 match expr {
2029 IRExpr::Literal(Literal::Integer(n)) => Some(*n),
2030 IRExpr::Param(name) => match params.get(name)? {
2031 Literal::Integer(n) => Some(*n),
2032 _ => None,
2033 },
2034 _ => None,
2035 }
2036}
2037
2038pub(super) fn literal_to_sql(lit: &Literal) -> String {
2039 match lit {
2040 Literal::Null => "NULL".to_string(),
2041 Literal::String(s) => format!("'{}'", s.replace('\'', "''")),
2042 Literal::Integer(n) => n.to_string(),
2043 Literal::Float(f) => f.to_string(),
2044 Literal::Bool(b) => b.to_string(),
2045 Literal::Date(s) => format!("'{}'", s.replace('\'', "''")),
2046 Literal::DateTime(s) => format!("'{}'", s.replace('\'', "''")),
2047 Literal::List(_) => "NULL".to_string(), }
2049}
2050
2051pub(super) fn build_lance_filter_expr(
2076 filters: &[IRFilter],
2077 params: &ParamMap,
2078 schema: Option<&Schema>,
2079) -> Option<datafusion::prelude::Expr> {
2080 use datafusion::logical_expr::Operator;
2081 use datafusion::prelude::Expr;
2082
2083 let mut acc: Option<Expr> = None;
2084 for f in filters {
2085 let Some(e) = ir_filter_to_expr(f, params, schema) else {
2086 continue;
2087 };
2088 acc = Some(match acc {
2089 None => e,
2090 Some(prev) => Expr::BinaryExpr(datafusion::logical_expr::BinaryExpr::new(
2091 Box::new(prev),
2092 Operator::And,
2093 Box::new(e),
2094 )),
2095 });
2096 }
2097 acc
2098}
2099
2100pub(super) fn ir_filter_to_expr(
2104 filter: &IRFilter,
2105 params: &ParamMap,
2106 schema: Option<&Schema>,
2107) -> Option<datafusion::prelude::Expr> {
2108 use datafusion::functions_nested::expr_fn::array_has;
2109
2110 if is_search_filter(filter) {
2111 return None;
2112 }
2113
2114 if matches!(filter.op, CompOp::Contains) {
2120 let left = ir_expr_to_expr(&filter.left, params, None)?;
2121 let right = ir_expr_to_expr(&filter.right, params, None)?;
2122 return Some(array_has(left, right));
2123 }
2124
2125 let left_col_type = prop_data_type(&filter.left, schema);
2130 let right_col_type = prop_data_type(&filter.right, schema);
2131 let left = ir_expr_to_expr(&filter.left, params, right_col_type.as_ref())?;
2132 let right = ir_expr_to_expr(&filter.right, params, left_col_type.as_ref())?;
2133 Some(match filter.op {
2134 CompOp::Eq => left.eq(right),
2135 CompOp::Ne => left.not_eq(right),
2136 CompOp::Gt => left.gt(right),
2137 CompOp::Lt => left.lt(right),
2138 CompOp::Ge => left.gt_eq(right),
2139 CompOp::Le => left.lt_eq(right),
2140 CompOp::Contains => unreachable!("handled above"),
2141 })
2142}
2143
2144pub(super) fn ir_expr_to_expr(
2148 expr: &IRExpr,
2149 params: &ParamMap,
2150 target: Option<&arrow_schema::DataType>,
2151) -> Option<datafusion::prelude::Expr> {
2152 use datafusion::prelude::col;
2153 match expr {
2154 IRExpr::PropAccess { property, .. } => Some(col(property)),
2155 IRExpr::Literal(l) => literal_to_expr_coerced(l, target),
2156 IRExpr::Param(name) => params
2157 .get(name)
2158 .and_then(|l| literal_to_expr_coerced(l, target)),
2159 _ => None,
2160 }
2161}
2162
2163fn prop_data_type(expr: &IRExpr, schema: Option<&Schema>) -> Option<arrow_schema::DataType> {
2166 match expr {
2167 IRExpr::PropAccess { property, .. } => schema?
2168 .field_with_name(property)
2169 .ok()
2170 .map(|f| f.data_type().clone()),
2171 _ => None,
2172 }
2173}
2174
2175fn literal_to_expr_coerced(
2181 lit: &Literal,
2182 target: Option<&arrow_schema::DataType>,
2183) -> Option<datafusion::prelude::Expr> {
2184 if let Some(target) = target {
2185 if let Some(e) = literal_to_typed_expr(lit, target) {
2186 return Some(e);
2187 }
2188 }
2189 literal_to_expr(lit)
2190}
2191
2192fn literal_to_typed_expr(
2207 lit: &Literal,
2208 target: &arrow_schema::DataType,
2209) -> Option<datafusion::prelude::Expr> {
2210 use datafusion::prelude::lit as df_lit;
2211 use datafusion::scalar::ScalarValue;
2212
2213 let arr = super::projection::literal_to_array(lit, 1).ok()?;
2214 if arr.data_type() == target {
2215 return Some(df_lit(ScalarValue::try_from_array(&arr, 0).ok()?));
2216 }
2217 let casted = arrow_cast::cast::cast(&arr, target).ok()?;
2218 if target.is_integer() {
2219 let back = arrow_cast::cast::cast(&casted, arr.data_type()).ok()?;
2220 let original = ScalarValue::try_from_array(&arr, 0).ok()?;
2221 let round_tripped = ScalarValue::try_from_array(&back, 0).ok()?;
2222 if original != round_tripped {
2223 return None;
2224 }
2225 }
2226 Some(df_lit(ScalarValue::try_from_array(&casted, 0).ok()?))
2227}
2228
2229fn literal_to_expr(lit: &Literal) -> Option<datafusion::prelude::Expr> {
2236 use datafusion::prelude::lit as df_lit;
2237 Some(match lit {
2238 Literal::Null => df_lit(datafusion::scalar::ScalarValue::Null),
2239 Literal::String(s) => df_lit(s.clone()),
2240 Literal::Integer(n) => df_lit(*n),
2241 Literal::Float(f) => df_lit(*f),
2242 Literal::Bool(b) => df_lit(*b),
2243 Literal::Date(s) => df_lit(s.clone()),
2250 Literal::DateTime(s) => df_lit(s.clone()),
2251 Literal::List(_) => return None,
2252 })
2253}
2254
2255fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
2256 let fields: Vec<Field> = batch
2257 .schema()
2258 .fields()
2259 .iter()
2260 .map(|f| {
2261 Field::new(
2262 format!("{}.{}", variable, f.name()),
2263 f.data_type().clone(),
2264 f.is_nullable(),
2265 )
2266 })
2267 .collect();
2268 let schema = Arc::new(Schema::new(fields));
2269 RecordBatch::try_new(schema, batch.columns().to_vec())
2270 .map_err(|e| OmniError::Lance(e.to_string()))
2271}
2272
2273fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
2274 let n = left.num_rows();
2275 let m = right.num_rows();
2276 if n == 0 || m == 0 {
2277 let mut fields: Vec<Field> = left
2278 .schema()
2279 .fields()
2280 .iter()
2281 .map(|f| f.as_ref().clone())
2282 .collect();
2283 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
2284 return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
2285 }
2286 let left_indices: Vec<u32> = (0..n as u32)
2287 .flat_map(|i| std::iter::repeat(i).take(m))
2288 .collect();
2289 let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
2290 let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
2291 let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
2292 hconcat_batches(&left_expanded, &right_expanded)
2293}
2294
2295fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
2296 let mut fields: Vec<Field> = left
2297 .schema()
2298 .fields()
2299 .iter()
2300 .map(|f| f.as_ref().clone())
2301 .collect();
2302 if cfg!(debug_assertions) {
2303 let left_schema = left.schema();
2304 let left_names: HashSet<&str> = left_schema
2305 .fields()
2306 .iter()
2307 .map(|f| f.name().as_str())
2308 .collect();
2309 let right_schema = right.schema();
2310 for f in right_schema.fields() {
2311 debug_assert!(
2312 !left_names.contains(f.name().as_str()),
2313 "hconcat_batches: duplicate column '{}'",
2314 f.name()
2315 );
2316 }
2317 }
2318 fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
2319 let mut columns: Vec<ArrayRef> = left.columns().to_vec();
2320 columns.extend(right.columns().to_vec());
2321 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
2322 .map_err(|e| OmniError::Lance(e.to_string()))
2323}
2324
2325fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
2326 let columns: Vec<ArrayRef> = batch
2327 .columns()
2328 .iter()
2329 .map(|col| arrow_select::take::take(col.as_ref(), indices, None))
2330 .collect::<std::result::Result<Vec<_>, _>>()
2331 .map_err(|e| OmniError::Lance(e.to_string()))?;
2332 RecordBatch::try_new(batch.schema(), columns).map_err(|e| OmniError::Lance(e.to_string()))
2333}
2334
#[cfg(test)]
mod expand_chooser_tests {
    use super::*;
    use crate::table_store::IndexCoverage;

    /// Cost-model inputs with the caps shared by every test here (6 hops,
    /// 1024 frontier rows) and a cold CSR cache.
    fn inputs(
        frontier_rows: usize,
        edge_count: u64,
        src_node_count: u64,
        effective_max_hops: u32,
        coverage: IndexCoverage,
    ) -> ExpandCostInputs {
        ExpandCostInputs {
            coverage,
            csr_cached: false,
            edge_count,
            effective_max_hops,
            frontier_rows,
            max_frontier_cap: 1024,
            max_hops_cap: 6,
            src_node_count,
        }
    }

    /// Degraded coverage with the fixed reason used by these tests.
    fn degraded() -> IndexCoverage {
        IndexCoverage::Degraded {
            reason: "no btree".into(),
        }
    }

    /// Shorthand for running the chooser on freshly built inputs.
    fn mode_for(
        frontier_rows: usize,
        edge_count: u64,
        src_node_count: u64,
        effective_max_hops: u32,
        coverage: IndexCoverage,
    ) -> ExpandMode {
        choose_expand_mode(&inputs(
            frontier_rows,
            edge_count,
            src_node_count,
            effective_max_hops,
            coverage,
        ))
    }

    #[test]
    fn selective_frontier_on_large_graph_picks_indexed() {
        assert_eq!(
            mode_for(50, 10_000_000, 1_000_000, 1, IndexCoverage::Indexed),
            ExpandMode::IndexedScan
        );
    }

    #[test]
    fn flat_in_edge_count_same_selectivity_same_choice() {
        for edges in [100_000, 100_000_000] {
            assert_eq!(
                mode_for(50, edges, 1_000_000, 1, IndexCoverage::Indexed),
                ExpandMode::IndexedScan
            );
        }
    }

    #[test]
    fn frontier_large_fraction_of_source_picks_csr() {
        assert_eq!(
            mode_for(200, 1_000, 100, 1, IndexCoverage::Indexed),
            ExpandMode::Csr
        );
    }

    #[test]
    fn frontier_over_hard_cap_picks_csr() {
        assert_eq!(
            mode_for(2000, 10_000_000, 1_000_000, 1, IndexCoverage::Indexed),
            ExpandMode::Csr
        );
    }

    #[test]
    fn hops_over_hard_cap_picks_csr() {
        assert_eq!(
            mode_for(10, 10_000_000, 1_000_000, 8, IndexCoverage::Indexed),
            ExpandMode::Csr
        );
    }

    #[test]
    fn degraded_single_hop_tiny_frontier_stays_indexed() {
        assert_eq!(
            mode_for(5, 10_000, 10_000, 1, degraded()),
            ExpandMode::IndexedScan
        );
    }

    #[test]
    fn degraded_multi_hop_picks_csr() {
        assert_eq!(mode_for(5, 10_000, 10_000, 2, degraded()), ExpandMode::Csr);
    }

    #[test]
    fn warm_csr_is_always_reused() {
        let warm = ExpandCostInputs {
            csr_cached: true,
            ..inputs(1, 10_000_000, 1_000_000, 1, IndexCoverage::Indexed)
        };
        assert_eq!(choose_expand_mode(&warm), ExpandMode::Csr);
    }

    #[test]
    fn cost_model_caps_cross_type_hops() {
        assert_eq!(cost_effective_hops(5, true), 5);
        assert_eq!(cost_effective_hops(5, false), 1);
        assert_eq!(cost_effective_hops(1, false), 1);

        let mut probe = inputs(
            50,
            10_000,
            100,
            cost_effective_hops(5, false),
            IndexCoverage::Indexed,
        );
        assert_eq!(choose_expand_mode(&probe), ExpandMode::IndexedScan);
        probe.effective_max_hops = 5;
        assert_eq!(choose_expand_mode(&probe), ExpandMode::Csr);
    }
}
2458
2459#[cfg(test)]
2460mod literal_lowering_tests {
2461 use super::*;
2462 use datafusion::prelude::Expr;
2463 use datafusion::scalar::ScalarValue;
2464
2465 #[test]
2471 fn date_literals_coerce_to_typed_arrow_scalars() {
2472 use arrow_schema::DataType;
2473 let dt = literal_to_expr_coerced(
2474 &Literal::DateTime("2024-06-01T12:00:00Z".into()),
2475 Some(&DataType::Date64),
2476 )
2477 .unwrap();
2478 assert!(
2479 matches!(dt, Expr::Literal(ScalarValue::Date64(Some(_)), ..)),
2480 "DateTime vs Date64 column must coerce to a typed Date64, got {dt:?}"
2481 );
2482 let d = literal_to_expr_coerced(&Literal::Date("2024-06-01".into()), Some(&DataType::Date32))
2483 .unwrap();
2484 assert!(
2485 matches!(d, Expr::Literal(ScalarValue::Date32(Some(_)), ..)),
2486 "Date vs Date32 column must coerce to a typed Date32, got {d:?}"
2487 );
2488 let nat = literal_to_expr_coerced(&Literal::Date("2024-06-01".into()), None).unwrap();
2489 assert!(
2490 matches!(nat, Expr::Literal(ScalarValue::Utf8(Some(_)), ..)),
2491 "no target should keep the natural Utf8 date literal, got {nat:?}"
2492 );
2493 }
2494
2495 #[test]
2498 fn malformed_date_literal_falls_back_to_string() {
2499 use arrow_schema::DataType;
2500 let bad = literal_to_expr_coerced(
2501 &Literal::DateTime("not-a-date".into()),
2502 Some(&DataType::Date64),
2503 )
2504 .unwrap();
2505 assert!(
2506 matches!(bad, Expr::Literal(ScalarValue::Utf8(Some(_)), ..)),
2507 "malformed DateTime literal should fall back to a Utf8 literal, got {bad:?}"
2508 );
2509 }
2510
2511 #[test]
2516 fn integer_literal_coerces_to_narrow_column_type() {
2517 use arrow_schema::DataType;
2518 let i32_lit = literal_to_expr_coerced(&Literal::Integer(5), Some(&DataType::Int32)).unwrap();
2519 assert!(
2520 matches!(i32_lit, Expr::Literal(ScalarValue::Int32(Some(5)), ..)),
2521 "integer literal vs Int32 column must lower to Int32, got {i32_lit:?}"
2522 );
2523 let u32_lit = literal_to_expr_coerced(&Literal::Integer(7), Some(&DataType::UInt32)).unwrap();
2524 assert!(
2525 matches!(u32_lit, Expr::Literal(ScalarValue::UInt32(Some(7)), ..)),
2526 "integer literal vs UInt32 column must lower to UInt32, got {u32_lit:?}"
2527 );
2528 }
2529
2530 #[test]
2531 fn float_literal_coerces_to_f32_column_type() {
2532 use arrow_schema::DataType;
2533 let f32_lit =
2534 literal_to_expr_coerced(&Literal::Float(1.5), Some(&DataType::Float32)).unwrap();
2535 assert!(
2536 matches!(f32_lit, Expr::Literal(ScalarValue::Float32(Some(_)), ..)),
2537 "float literal vs Float32 column must lower to Float32, got {f32_lit:?}"
2538 );
2539 }
2540
2541 #[test]
2545 fn fractional_float_vs_int_column_falls_back_not_truncate() {
2546 use arrow_schema::DataType;
2547 let e = literal_to_expr_coerced(&Literal::Float(2.7), Some(&DataType::Int32)).unwrap();
2548 assert!(
2549 matches!(e, Expr::Literal(ScalarValue::Float64(Some(_)), ..)),
2550 "fractional float vs Int32 must fall back to natural Float64, got {e:?}"
2551 );
2552 }
2553
2554 #[test]
2556 fn whole_float_vs_int_column_coerces() {
2557 use arrow_schema::DataType;
2558 let e = literal_to_expr_coerced(&Literal::Float(2.0), Some(&DataType::Int32)).unwrap();
2559 assert!(
2560 matches!(e, Expr::Literal(ScalarValue::Int32(Some(2)), ..)),
2561 "whole-number float vs Int32 is lossless and must coerce to Int32(2), got {e:?}"
2562 );
2563 }
2564
2565 #[test]
2568 fn out_of_range_int_vs_narrow_column_falls_back() {
2569 use arrow_schema::DataType;
2570 let e = literal_to_expr_coerced(&Literal::Integer(3_000_000_000), Some(&DataType::Int32))
2571 .unwrap();
2572 assert!(
2573 matches!(e, Expr::Literal(ScalarValue::Int64(Some(3_000_000_000)), ..)),
2574 "out-of-range integer vs Int32 must fall back to natural Int64, got {e:?}"
2575 );
2576 }
2577
2578 #[test]
2582 fn float_vs_f32_column_coerces_even_when_not_exactly_representable() {
2583 use arrow_schema::DataType;
2584 let e = literal_to_expr_coerced(&Literal::Float(0.1), Some(&DataType::Float32)).unwrap();
2585 assert!(
2586 matches!(e, Expr::Literal(ScalarValue::Float32(Some(_)), ..)),
2587 "float target must coerce 0.1 to Float32 (exempt from lossless guard), got {e:?}"
2588 );
2589 }
2590
2591 #[test]
2594 fn literal_without_target_keeps_natural_width() {
2595 let nat = literal_to_expr_coerced(&Literal::Integer(5), None).unwrap();
2596 assert!(
2597 matches!(nat, Expr::Literal(ScalarValue::Int64(Some(5)), ..)),
2598 "no target should keep the natural Int64 width, got {nat:?}"
2599 );
2600 }
2601
2602 fn binary_has_int32_literal(e: &Expr) -> bool {
2604 if let Expr::BinaryExpr(b) = e {
2605 [b.left.as_ref(), b.right.as_ref()]
2606 .iter()
2607 .any(|side| matches!(side, Expr::Literal(ScalarValue::Int32(Some(_)), ..)))
2608 } else {
2609 false
2610 }
2611 }
2612
2613 fn int32_schema() -> arrow_schema::Schema {
2614 use arrow_schema::{DataType, Field};
2615 arrow_schema::Schema::new(vec![Field::new("count", DataType::Int32, true)])
2616 }
2617
2618 fn count_prop() -> IRExpr {
2619 IRExpr::PropAccess {
2620 variable: "m".into(),
2621 property: "count".into(),
2622 }
2623 }
2624
2625 #[test]
2629 fn ir_filter_coerces_literal_for_range_op() {
2630 let schema = int32_schema();
2631 let filter = IRFilter {
2632 left: count_prop(),
2633 op: CompOp::Ge,
2634 right: IRExpr::Literal(Literal::Integer(2)),
2635 };
2636 let expr = ir_filter_to_expr(&filter, &ParamMap::new(), Some(&schema)).unwrap();
2637 assert!(
2638 binary_has_int32_literal(&expr),
2639 "range-op literal must coerce to the Int32 column type, got {expr:?}"
2640 );
2641 }
2642
2643 #[test]
2646 fn ir_filter_coerces_literal_when_column_is_on_the_right() {
2647 let schema = int32_schema();
2648 let filter = IRFilter {
2649 left: IRExpr::Literal(Literal::Integer(2)),
2650 op: CompOp::Lt,
2651 right: count_prop(),
2652 };
2653 let expr = ir_filter_to_expr(&filter, &ParamMap::new(), Some(&schema)).unwrap();
2654 assert!(
2655 binary_has_int32_literal(&expr),
2656 "reversed-operand literal must coerce to the Int32 column type, got {expr:?}"
2657 );
2658 }
2659}