1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use arrow::array::{ArrayRef, RecordBatch, UInt64Array};
5use arrow::buffer::BooleanBuffer;
6use arrow::datatypes::Schema;
7use croaring::Treemap;
8use llkv_column_map::store::{GatherNullPolicy, MultiGatherContext};
9use llkv_compute::analysis::computed_expr_requires_numeric;
10use llkv_compute::eval::{NumericArrayMap as ComputeNumericArrayMap, ScalarEvaluator};
11use llkv_compute::projection::{ComputedLiteralInfo, synthesize_computed_literal_array};
12use llkv_expr::ScalarExpr;
13use llkv_result::Result as LlkvResult;
14use llkv_types::{FieldId, LogicalFieldId, RowId, TableId};
15use rustc_hash::{FxHashMap, FxHashSet};
16use simd_r_drive_entry_handle::EntryHandle;
17
18use crate::ScanStorage;
19
/// Numeric-array cache keyed by this crate's `FieldId`; populated per window
/// when computed projections require numeric inputs (see
/// `materialize_row_window`).
pub type NumericArrayMap = ComputeNumericArrayMap<FieldId>;
21
/// Source of row ids for a scan: either a roaring treemap or an explicit
/// vector supplied by the caller.
pub enum RowIdSource {
    /// Row ids held in a `croaring::Treemap` (iterated in ascending order).
    Bitmap(Treemap),
    /// Row ids held as a plain vector, preserving the caller's ordering.
    Vector(Vec<RowId>),
}
26
27impl From<Treemap> for RowIdSource {
28 fn from(bitmap: Treemap) -> Self {
29 RowIdSource::Bitmap(bitmap)
30 }
31}
32
33impl From<Vec<RowId>> for RowIdSource {
34 fn from(vector: Vec<RowId>) -> Self {
35 RowIdSource::Vector(vector)
36 }
37}
38
39impl From<&Treemap> for RowIdSource {
40 fn from(bitmap: &Treemap) -> Self {
41 RowIdSource::Vector(bitmap.iter().collect())
42 }
43}
44
45impl From<&[RowId]> for RowIdSource {
46 fn from(slice: &[RowId]) -> Self {
47 RowIdSource::Vector(slice.to_vec())
48 }
49}
50
51impl From<&Vec<RowId>> for RowIdSource {
52 fn from(vec: &Vec<RowId>) -> Self {
53 RowIdSource::Vector(vec.clone())
54 }
55}
56
/// Read-only view over a set of Arrow columns that share a common lifetime.
pub trait ColumnSliceSet<'a> {
    /// Number of columns in the set.
    fn len(&self) -> usize;
    /// True when the set holds no columns.
    fn is_empty(&self) -> bool;
    /// Column at `idx`; implementations may panic when out of range
    /// (the slice-backed impl below does).
    fn column(&self, idx: usize) -> &'a ArrayRef;
    /// All columns as a contiguous slice.
    fn columns(&self) -> &'a [ArrayRef];
}
63
/// Borrowed-slice implementation of [`ColumnSliceSet`].
pub struct ColumnSlices<'a> {
    columns: &'a [ArrayRef],
}
67
68impl<'a> ColumnSlices<'a> {
69 pub fn new(columns: &'a [ArrayRef]) -> Self {
70 Self { columns }
71 }
72}
73
74impl<'a> ColumnSliceSet<'a> for ColumnSlices<'a> {
75 fn len(&self) -> usize {
76 self.columns.len()
77 }
78
79 fn is_empty(&self) -> bool {
80 self.columns.is_empty()
81 }
82
83 fn column(&self, idx: usize) -> &'a ArrayRef {
84 &self.columns[idx]
85 }
86
87 fn columns(&self) -> &'a [ArrayRef] {
88 self.columns
89 }
90}
91
/// One window of materialized rows handed out by a [`RowStream`].
pub struct RowChunk<'a, C> {
    /// Row ids for this chunk; present only when the stream was configured
    /// to emit them.
    pub row_ids: Option<&'a UInt64Array>,
    /// Projected columns for the chunk.
    pub columns: C,
    /// Optional per-row visibility mask. `ScanRowStream` always yields
    /// `None` here; other producers may populate it.
    pub visibility: Option<&'a BooleanBuffer>,
    // Backing batch the column views borrow from.
    record_batch: &'a RecordBatch,
}
98
99impl<'a, C> RowChunk<'a, C> {
100 pub fn record_batch(&self) -> &'a RecordBatch {
101 self.record_batch
102 }
103
104 pub fn to_record_batch(&self) -> RecordBatch {
105 self.record_batch.clone()
106 }
107
108 pub fn row_ids(&self) -> Option<&'a UInt64Array> {
109 self.row_ids
110 }
111}
112
/// Pull-based stream of row chunks.
pub trait RowStream {
    /// Column-set view yielded with each chunk.
    type Columns<'a>: ColumnSliceSet<'a>
    where
        Self: 'a;

    /// Schema of the batches this stream produces.
    fn schema(&self) -> &Arc<Schema>;

    /// Advances the stream. `Ok(None)` signals exhaustion. The returned
    /// chunk borrows from `self`, so it must be dropped before the next call.
    fn next_chunk<'a>(&'a mut self) -> LlkvResult<Option<RowChunk<'a, Self::Columns<'a>>>>;
}
122
/// Metadata for a projection that reads a stored column straight through.
#[derive(Clone)]
pub struct ColumnProjectionInfo {
    /// Fully-qualified identity of the stored column.
    pub logical_field_id: LogicalFieldId,
    /// Arrow data type of the stored column.
    pub data_type: arrow::datatypes::DataType,
    /// Column name to use in the output schema.
    pub output_name: String,
}
129
/// Computed-projection metadata specialized to this crate's `FieldId`.
pub type ComputedProjectionInfo = ComputedLiteralInfo<FieldId>;
131
/// How a single output column is produced: read from storage or computed
/// from a scalar expression.
#[derive(Clone)]
pub enum ProjectionEval {
    /// Direct read of a stored column.
    Column(ColumnProjectionInfo),
    /// Value derived from a scalar expression (literal, cast, struct field
    /// access, or numeric computation).
    Computed(ComputedProjectionInfo),
}
137
/// Pre-resolved execution step for one output column. Indices into the
/// gathered unique-column arrays are computed once at build time so the
/// per-chunk hot path avoids hash lookups.
#[derive(Clone)]
pub enum ProjectionPlan {
    /// Copy the gathered array at `source_idx` straight through.
    Column {
        source_idx: usize,
    },
    /// Evaluate `projection_evals[eval_idx]`. When `passthrough_idx` is set,
    /// the computed value is exactly the gathered array at that index and no
    /// evaluation is needed.
    Computed {
        eval_idx: usize,
        passthrough_idx: Option<usize>,
    },
}
148
/// Builder collecting everything a [`ScanRowStream`] needs; no storage work
/// happens until `build` is called.
pub struct RowStreamBuilder<'storage, P, S>
where
    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
    S: ScanStorage<P>,
{
    // Storage backend rows are gathered from.
    storage: &'storage S,
    // Table the scan targets.
    table_id: TableId,
    // Output schema of the emitted batches.
    schema: Arc<Schema>,
    // Deduplicated logical fields that must be gathered from storage.
    unique_lfids: Arc<Vec<LogicalFieldId>>,
    // One entry per output column describing how it is produced.
    projection_evals: Arc<Vec<ProjectionEval>>,
    // Per-projection optional field id whose gathered array can be passed
    // through unchanged for a computed projection (same length/order as
    // `projection_evals`).
    _passthrough_fields: Arc<Vec<Option<FieldId>>>,
    // Maps each unique logical field to its index in the gathered arrays.
    unique_index: Arc<FxHashMap<LogicalFieldId, usize>>,
    // Fields whose gathered arrays must be cached for numeric evaluation.
    numeric_fields: Arc<FxHashSet<FieldId>>,
    // True when any computed projection needs the numeric-array cache.
    requires_numeric: bool,
    // How gathered nulls are treated by the storage layer.
    null_policy: GatherNullPolicy,
    // Rows to scan (bitmap or explicit vector).
    row_ids: RowIdSource,
    // Target rows per emitted chunk.
    chunk_size: usize,
    // Optional pre-warmed gather context to reuse.
    gather_ctx: Option<MultiGatherContext>,
    phantom: PhantomData<P>,
    // Whether each chunk should also carry its row-id window.
    include_row_ids: bool,
}
170
171impl<'storage, P, S> RowStreamBuilder<'storage, P, S>
172where
173 P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
174 S: ScanStorage<P>,
175{
176 #[allow(clippy::too_many_arguments)]
177 pub fn new(
178 storage: &'storage S,
179 table_id: TableId,
180 schema: Arc<Schema>,
181 unique_lfids: Arc<Vec<LogicalFieldId>>,
182 projection_evals: Arc<Vec<ProjectionEval>>,
183 passthrough_fields: Arc<Vec<Option<FieldId>>>,
184 unique_index: Arc<FxHashMap<LogicalFieldId, usize>>,
185 numeric_fields: Arc<FxHashSet<FieldId>>,
186 requires_numeric: bool,
187 null_policy: GatherNullPolicy,
188 row_ids: impl Into<RowIdSource>,
189 chunk_size: usize,
190 include_row_ids: bool,
191 ) -> Self {
192 Self {
193 storage,
194 table_id,
195 schema,
196 unique_lfids,
197 projection_evals,
198 _passthrough_fields: passthrough_fields,
199 unique_index,
200 numeric_fields,
201 requires_numeric,
202 null_policy,
203 row_ids: row_ids.into(),
204 chunk_size,
205 gather_ctx: None,
206 phantom: PhantomData,
207 include_row_ids,
208 }
209 }
210
211 pub fn with_gather_context(mut self, ctx: MultiGatherContext) -> Self {
212 self.gather_ctx = Some(ctx);
213 self
214 }
215
216 pub fn build(self) -> LlkvResult<ScanRowStream<'storage, P, S>> {
217 let RowStreamBuilder {
218 storage,
219 table_id,
220 schema,
221 unique_lfids,
222 projection_evals,
223 _passthrough_fields: passthrough_fields,
224 unique_index,
225 numeric_fields,
226 requires_numeric,
227 null_policy,
228 row_ids,
229 chunk_size,
230 mut gather_ctx,
231 include_row_ids,
232 ..
233 } = self;
234
235 let gather_ctx = if unique_lfids.is_empty() {
236 None
237 } else if let Some(ctx) = gather_ctx.take() {
238 Some(ctx)
239 } else {
240 Some(storage.prepare_gather_context(unique_lfids.as_ref())?)
241 };
242
243 let (row_id_array, total_rows) = match row_ids {
244 RowIdSource::Bitmap(bitmap) => {
245 let len = bitmap.cardinality();
246 let array = Arc::new(UInt64Array::from_iter_values(bitmap.iter()));
247 (array, len as usize)
248 }
249 RowIdSource::Vector(vector) => {
250 let len = vector.len();
251 let array = Arc::new(UInt64Array::from(vector));
252 (array, len)
253 }
254 };
255
256 let columns_buf = Vec::with_capacity(projection_evals.len());
257 let numeric_arrays_cache = if requires_numeric {
258 Some(FxHashMap::default())
259 } else {
260 None
261 };
262
263 let mut projection_plan = Vec::with_capacity(projection_evals.len());
264 for (idx, eval) in projection_evals.iter().enumerate() {
265 match eval {
266 ProjectionEval::Column(info) => {
267 let arr_idx = *unique_index
268 .get(&info.logical_field_id)
269 .expect("logical field id missing from index");
270 projection_plan.push(ProjectionPlan::Column {
271 source_idx: arr_idx,
272 });
273 }
274 ProjectionEval::Computed(_) => {
275 let passthrough_idx = passthrough_fields[idx].as_ref().map(|fid| {
276 let lfid = LogicalFieldId::for_user(table_id, *fid);
277 *unique_index
278 .get(&lfid)
279 .expect("passthrough field missing from index")
280 });
281 projection_plan.push(ProjectionPlan::Computed {
282 eval_idx: idx,
283 passthrough_idx,
284 });
285 }
286 }
287 }
288
289 let chunk_ranges = if total_rows == 0 {
290 Vec::new()
291 } else {
292 (0..total_rows)
293 .step_by(chunk_size.max(1))
294 .map(|start| {
295 let end = (start + chunk_size).min(total_rows);
296 (start, end)
297 })
298 .collect()
299 };
300
301 Ok(ScanRowStream {
302 storage,
303 table_id,
304 schema,
305 unique_lfids,
306 projection_evals,
307 _passthrough_fields: passthrough_fields,
308 unique_index,
309 numeric_fields,
310 requires_numeric,
311 null_policy,
312 row_ids: row_id_array,
313 chunk_ranges,
314 range_idx: 0,
315 gather_ctx,
316 current_batch: None,
317 current_row_ids: None,
318 columns_buf,
319 numeric_arrays_cache,
320 phantom: PhantomData,
321 emit_row_ids: include_row_ids,
322 projection_plan,
323 })
324 }
325}
326
/// Chunked row stream over a storage scan; created by [`RowStreamBuilder`].
pub struct ScanRowStream<'storage, P, S>
where
    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
    S: ScanStorage<P>,
{
    // Storage backend rows are gathered from.
    storage: &'storage S,
    table_id: TableId,
    // Output schema of emitted batches.
    schema: Arc<Schema>,
    // Deduplicated logical fields gathered per window.
    unique_lfids: Arc<Vec<LogicalFieldId>>,
    projection_evals: Arc<Vec<ProjectionEval>>,
    _passthrough_fields: Arc<Vec<Option<FieldId>>>,
    unique_index: Arc<FxHashMap<LogicalFieldId, usize>>,
    numeric_fields: Arc<FxHashSet<FieldId>>,
    requires_numeric: bool,
    null_policy: GatherNullPolicy,
    // All row ids to scan, flattened into one array at build time.
    row_ids: Arc<UInt64Array>,
    // Half-open (start, end) windows into `row_ids`, one per chunk.
    chunk_ranges: Vec<(usize, usize)>,
    // Index of the next range in `chunk_ranges` to materialize.
    range_idx: usize,
    // Reusable gather scratch; `None` when no columns are gathered.
    gather_ctx: Option<MultiGatherContext>,
    // Most recently materialized batch; the yielded chunk borrows it.
    current_batch: Option<RecordBatch>,
    // Row-id slice for the current chunk, kept alive for the borrow.
    current_row_ids: Option<ArrayRef>,
    // Scratch buffer of projected columns, reused across chunks.
    columns_buf: Vec<ArrayRef>,
    // Numeric-array cache, present only when computations need it.
    numeric_arrays_cache: Option<NumericArrayMap>,
    phantom: PhantomData<P>,
    // Whether chunks carry their row-id window.
    emit_row_ids: bool,
    // Pre-resolved per-output-column plan (see ProjectionPlan).
    projection_plan: Vec<ProjectionPlan>,
}
354
355impl<'storage, P, S> RowStream for ScanRowStream<'storage, P, S>
356where
357 P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
358 S: ScanStorage<P>,
359{
360 type Columns<'a>
361 = ColumnSlices<'a>
362 where
363 Self: 'a;
364
365 fn schema(&self) -> &Arc<Schema> {
366 &self.schema
367 }
368
369 fn next_chunk<'a>(&'a mut self) -> LlkvResult<Option<RowChunk<'a, Self::Columns<'a>>>> {
370 while self.range_idx < self.chunk_ranges.len() {
371 let (start, end) = self.chunk_ranges[self.range_idx];
372 self.range_idx += 1;
373
374 let values = self.row_ids.values();
375 let window: &[RowId] = &values[start..end];
376
377 let unique_lfids = Arc::clone(&self.unique_lfids);
378 let projection_evals = Arc::clone(&self.projection_evals);
379 let unique_index = Arc::clone(&self.unique_index);
380 let numeric_fields = Arc::clone(&self.numeric_fields);
381 let requires_numeric = self.requires_numeric;
382 let null_policy = self.null_policy;
383 let schema = Arc::clone(&self.schema);
384 let projection_plan = self.projection_plan.clone();
385
386 let numeric_cache = self.numeric_arrays_cache.as_mut();
387 let batch_opt = materialize_row_window(
388 self.storage,
389 self.table_id,
390 unique_lfids.as_ref(),
391 projection_evals.as_ref(),
392 &projection_plan,
393 unique_index.as_ref(),
394 numeric_fields.as_ref(),
395 requires_numeric,
396 null_policy,
397 &schema,
398 window,
399 self.gather_ctx.as_mut(),
400 numeric_cache,
401 &mut self.columns_buf,
402 )?;
403
404 let Some(batch) = batch_opt else {
405 continue;
406 };
407
408 if batch.num_rows() == 0 {
409 continue;
410 }
411
412 let row_ids_ref = if self.emit_row_ids {
413 let row_id_slice = self.row_ids.slice(start, end - start);
414 self.current_row_ids = Some(Arc::new(row_id_slice) as ArrayRef);
415 self.current_row_ids
416 .as_ref()
417 .and_then(|arr| arr.as_any().downcast_ref::<UInt64Array>())
418 } else {
419 self.current_row_ids = None;
420 None
421 };
422 self.current_batch = Some(batch);
423
424 let batch_ref = self.current_batch.as_ref().expect("batch must be present");
425 let columns = batch_ref.columns();
426 let column_set = ColumnSlices::new(columns);
427
428 return Ok(Some(RowChunk {
429 row_ids: row_ids_ref,
430 columns: column_set,
431 visibility: None,
432 record_batch: batch_ref,
433 }));
434 }
435
436 Ok(None)
437 }
438}
439
impl<'storage, P, S> ScanRowStream<'storage, P, S>
where
    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
    S: ScanStorage<P>,
{
    /// Consumes the stream and returns its gather context (if one was
    /// created or supplied), allowing a later
    /// `RowStreamBuilder::with_gather_context` to reuse its buffers.
    pub fn into_gather_context(self) -> Option<MultiGatherContext> {
        self.gather_ctx
    }
}
449
/// Gathers the rows in `window` from `storage` and assembles them into a
/// `RecordBatch` shaped by `out_schema`.
///
/// Column projections Arc-clone the gathered arrays straight through per
/// `projection_plan`; computed projections are either synthesized per window
/// (literals and non-numeric casts), resolved from struct children
/// (`GetField`), or evaluated numerically against the arrays cached in
/// `numeric_cache`. Returns `Ok(None)` when the window is empty or the
/// gather yields no rows. `columns` is a caller-owned scratch buffer reused
/// across windows.
///
/// # Errors
/// Propagates storage gather failures, expression-evaluation failures, and
/// invalid `GetField` usage (non-struct base, unknown field name).
///
/// # Panics
/// Panics when `requires_numeric` is set but `numeric_cache` is `None`, or
/// when a computed projection needs numeric arrays that were never
/// populated — both indicate a planner/builder bug upstream.
#[allow(clippy::too_many_arguments)]
pub fn materialize_row_window<P, S>(
    storage: &S,
    table_id: TableId,
    unique_lfids: &[LogicalFieldId],
    projection_evals: &[ProjectionEval],
    projection_plan: &[ProjectionPlan],
    unique_index: &FxHashMap<LogicalFieldId, usize>,
    numeric_fields: &FxHashSet<FieldId>,
    requires_numeric: bool,
    null_policy: GatherNullPolicy,
    out_schema: &Arc<Schema>,
    window: &[RowId],
    gather_ctx: Option<&mut MultiGatherContext>,
    numeric_cache: Option<&mut NumericArrayMap>,
    columns: &mut Vec<ArrayRef>,
) -> LlkvResult<Option<RecordBatch>>
where
    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
    S: ScanStorage<P>,
{
    if window.is_empty() {
        return Ok(None);
    }

    let mut gathered_batch: Option<RecordBatch> = None;
    let mut numeric_arrays_holder: Option<&mut NumericArrayMap> = None;
    // With no stored columns to gather, the batch length is simply the
    // window length; otherwise gather from storage and, when needed, index
    // the gathered numeric arrays by field id for expression evaluation.
    let batch_len = if unique_lfids.is_empty() {
        if requires_numeric {
            let map = numeric_cache.expect("numeric cache missing for computed projections");
            map.clear();
            numeric_arrays_holder = Some(map);
        }
        window.len()
    } else {
        let batch = storage.gather_row_window_with_context(
            unique_lfids,
            window,
            null_policy,
            gather_ctx,
        )?;
        if batch.num_rows() == 0 {
            return Ok(None);
        }
        let batch_len = batch.num_rows();
        if requires_numeric {
            let map = numeric_cache.expect("numeric cache missing for computed projections");
            map.clear();
            // Gathered columns are positionally aligned with `unique_lfids`;
            // cache only those the numeric evaluator will actually read.
            for (lfid, array) in unique_lfids.iter().zip(batch.columns().iter()) {
                let fid = lfid.field_id();
                if numeric_fields.contains(&fid) {
                    map.insert(fid, array.clone());
                }
            }
            numeric_arrays_holder = Some(map);
        }
        gathered_batch = Some(batch);
        batch_len
    };

    if batch_len == 0 {
        return Ok(None);
    }

    // Empty slice stands in for "nothing gathered" so the projection loop
    // below has a single code path.
    let gathered_columns: &[ArrayRef] = if let Some(batch) = gathered_batch.as_ref() {
        batch.columns()
    } else {
        &[]
    };

    columns.clear();
    columns.reserve(projection_evals.len());
    // `idx` is the output-column position; it also indexes `out_schema`.
    for (idx, plan) in projection_plan.iter().enumerate() {
        match plan {
            ProjectionPlan::Column { source_idx } => {
                // Direct passthrough: Arc bump, no data copy.
                columns.push(Arc::clone(&gathered_columns[*source_idx]));
            }
            ProjectionPlan::Computed {
                eval_idx,
                passthrough_idx,
            } => {
                // A computed projection that reduces to a stored column.
                if let Some(arr_idx) = passthrough_idx {
                    columns.push(Arc::clone(&gathered_columns[*arr_idx]));
                    continue;
                }
                let info = match &projection_evals[*eval_idx] {
                    ProjectionEval::Computed(info) => info,
                    ProjectionEval::Column(_) => unreachable!("plan mismatch"),
                };
                let array: ArrayRef = match &info.expr {
                    // Pure literal: synthesize a constant column of the
                    // output type for the whole window.
                    ScalarExpr::Literal(_) => synthesize_computed_literal_array(
                        info,
                        out_schema.field(idx).data_type(),
                        batch_len,
                    )?,
                    // Cast with no numeric inputs behaves like a literal.
                    ScalarExpr::Cast { .. } if !computed_expr_requires_numeric(&info.expr) => {
                        synthesize_computed_literal_array(
                            info,
                            out_schema.field(idx).data_type(),
                            batch_len,
                        )?
                    }
                    ScalarExpr::GetField { base, field_name } => {
                        // Recursively resolves the base expression to a
                        // struct array, then selects the named child column.
                        fn eval_get_field(
                            expr: &ScalarExpr<FieldId>,
                            field_name: &str,
                            gathered_columns: &[ArrayRef],
                            unique_index: &FxHashMap<LogicalFieldId, usize>,
                            table_id: TableId,
                        ) -> LlkvResult<ArrayRef> {
                            let base_array = match expr {
                                ScalarExpr::Column(fid) => {
                                    let lfid = LogicalFieldId::for_user(table_id, *fid);
                                    let arr_idx = *unique_index.get(&lfid).ok_or_else(|| {
                                        llkv_result::Error::Internal(
                                            "field missing from unique arrays".into(),
                                        )
                                    })?;
                                    Arc::clone(&gathered_columns[arr_idx])
                                }
                                // Nested access (a.b.c): resolve the inner
                                // GetField first.
                                ScalarExpr::GetField {
                                    base: inner_base,
                                    field_name: inner_field,
                                } => eval_get_field(
                                    inner_base,
                                    inner_field,
                                    gathered_columns,
                                    unique_index,
                                    table_id,
                                )?,
                                _ => {
                                    return Err(llkv_result::Error::InvalidArgumentError(
                                        "GetField base must be a column or another GetField".into(),
                                    ));
                                }
                            };

                            let struct_array = base_array
                                .as_any()
                                .downcast_ref::<arrow::array::StructArray>()
                                .ok_or_else(|| {
                                    llkv_result::Error::InvalidArgumentError(
                                        "GetField can only be applied to struct types".into(),
                                    )
                                })?;

                            struct_array
                                .column_by_name(field_name)
                                .ok_or_else(|| {
                                    llkv_result::Error::InvalidArgumentError(format!(
                                        "Field '{}' not found in struct",
                                        field_name
                                    ))
                                })
                                .map(Arc::clone)
                        }

                        eval_get_field(base, field_name, gathered_columns, unique_index, table_id)?
                    }
                    // Everything else is a numeric expression evaluated over
                    // the cached numeric arrays.
                    _ => {
                        let numeric_arrays = numeric_arrays_holder
                            .as_ref()
                            .expect("numeric arrays should exist for computed projection");
                        ScalarEvaluator::evaluate_batch(&info.expr, batch_len, numeric_arrays)?
                    }
                };
                columns.push(array);
            }
        }
    }

    // Cloning the column Vec only bumps Arc refcounts; the scratch buffer
    // keeps its capacity for the next window.
    let batch = RecordBatch::try_new(Arc::clone(out_schema), columns.clone())?;
    Ok(Some(batch))
}