llkv_scan/
row_stream.rs

1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use arrow::array::{ArrayRef, RecordBatch, UInt64Array};
5use arrow::buffer::BooleanBuffer;
6use arrow::datatypes::Schema;
7use croaring::Treemap;
8use llkv_column_map::store::{GatherNullPolicy, MultiGatherContext};
9use llkv_compute::analysis::computed_expr_requires_numeric;
10use llkv_compute::eval::{NumericArrayMap as ComputeNumericArrayMap, ScalarEvaluator};
11use llkv_compute::projection::{ComputedLiteralInfo, synthesize_computed_literal_array};
12use llkv_expr::ScalarExpr;
13use llkv_result::Result as LlkvResult;
14use llkv_types::{FieldId, LogicalFieldId, RowId, TableId};
15use rustc_hash::{FxHashMap, FxHashSet};
16use simd_r_drive_entry_handle::EntryHandle;
17
18use crate::ScanStorage;
19
/// Numeric-array cache keyed by this crate's [`FieldId`], reusing the compute crate's map type.
pub type NumericArrayMap = ComputeNumericArrayMap<FieldId>;
21
/// The set of row ids a stream should visit: either a compressed roaring bitmap
/// or an explicit vector (preserving the caller-supplied order).
pub enum RowIdSource {
    /// Row ids held in a roaring treemap; iterated via `Treemap::iter` when materialized.
    Bitmap(Treemap),
    /// Row ids held as an explicit list, streamed in the given order.
    Vector(Vec<RowId>),
}
26
27impl From<Treemap> for RowIdSource {
28    fn from(bitmap: Treemap) -> Self {
29        RowIdSource::Bitmap(bitmap)
30    }
31}
32
33impl From<Vec<RowId>> for RowIdSource {
34    fn from(vector: Vec<RowId>) -> Self {
35        RowIdSource::Vector(vector)
36    }
37}
38
39impl From<&Treemap> for RowIdSource {
40    fn from(bitmap: &Treemap) -> Self {
41        RowIdSource::Vector(bitmap.iter().collect())
42    }
43}
44
45impl From<&[RowId]> for RowIdSource {
46    fn from(slice: &[RowId]) -> Self {
47        RowIdSource::Vector(slice.to_vec())
48    }
49}
50
51impl From<&Vec<RowId>> for RowIdSource {
52    fn from(vec: &Vec<RowId>) -> Self {
53        RowIdSource::Vector(vec.clone())
54    }
55}
56
/// Read-only view over the column arrays of a single streamed chunk.
pub trait ColumnSliceSet<'a> {
    /// Number of columns in the set.
    fn len(&self) -> usize;
    /// `true` when the set contains no columns.
    fn is_empty(&self) -> bool;
    /// Column at `idx`; implementations may panic on an out-of-bounds index.
    fn column(&self, idx: usize) -> &'a ArrayRef;
    /// All columns as a borrowed slice.
    fn columns(&self) -> &'a [ArrayRef];
}
63
/// Slice-backed implementation of [`ColumnSliceSet`] borrowing from a `RecordBatch`.
pub struct ColumnSlices<'a> {
    // Borrowed column arrays; typically `RecordBatch::columns()` of the current chunk.
    columns: &'a [ArrayRef],
}

impl<'a> ColumnSlices<'a> {
    /// Wraps a borrowed slice of column arrays.
    pub fn new(columns: &'a [ArrayRef]) -> Self {
        Self { columns }
    }
}
73
74impl<'a> ColumnSliceSet<'a> for ColumnSlices<'a> {
75    fn len(&self) -> usize {
76        self.columns.len()
77    }
78
79    fn is_empty(&self) -> bool {
80        self.columns.is_empty()
81    }
82
83    fn column(&self, idx: usize) -> &'a ArrayRef {
84        &self.columns[idx]
85    }
86
87    fn columns(&self) -> &'a [ArrayRef] {
88        self.columns
89    }
90}
91
/// One chunk of streamed rows: optional row ids, the projected columns, and the
/// backing Arrow batch they borrow from.
pub struct RowChunk<'a, C> {
    /// Row ids for this chunk; populated only when the stream was built with row-id emission.
    pub row_ids: Option<&'a UInt64Array>,
    /// Projected column data for this chunk.
    pub columns: C,
    /// Optional per-row visibility mask; `None` means all rows are visible.
    /// NOTE(review): `ScanRowStream` in this file always sets this to `None`.
    pub visibility: Option<&'a BooleanBuffer>,
    // Backing batch the column views borrow from; kept private so lifetimes stay honest.
    record_batch: &'a RecordBatch,
}
98
99impl<'a, C> RowChunk<'a, C> {
100    pub fn record_batch(&self) -> &'a RecordBatch {
101        self.record_batch
102    }
103
104    pub fn to_record_batch(&self) -> RecordBatch {
105        self.record_batch.clone()
106    }
107
108    pub fn row_ids(&self) -> Option<&'a UInt64Array> {
109        self.row_ids
110    }
111}
112
/// Pull-based stream of row chunks with a lifetime-generic column view per chunk.
pub trait RowStream {
    /// Column-set type borrowed from the stream for the duration of one chunk.
    type Columns<'a>: ColumnSliceSet<'a>
    where
        Self: 'a;

    /// Output schema shared by every chunk the stream yields.
    fn schema(&self) -> &Arc<Schema>;

    /// Advances to the next non-empty chunk; `Ok(None)` signals exhaustion.
    fn next_chunk<'a>(&'a mut self) -> LlkvResult<Option<RowChunk<'a, Self::Columns<'a>>>>;
}
122
/// Projection that reads a stored column directly (no expression evaluation).
#[derive(Clone)]
pub struct ColumnProjectionInfo {
    /// Fully-qualified column identity (table + field).
    pub logical_field_id: LogicalFieldId,
    /// Arrow data type of the projected column.
    pub data_type: arrow::datatypes::DataType,
    /// Name this column carries in the output schema.
    pub output_name: String,
}
129
/// Computed (expression-based) projection metadata, specialized to this crate's `FieldId`.
pub type ComputedProjectionInfo = ComputedLiteralInfo<FieldId>;
131
/// How one output column is produced: read straight from storage or computed.
#[derive(Clone)]
pub enum ProjectionEval {
    /// Pass a gathered column through unchanged.
    Column(ColumnProjectionInfo),
    /// Evaluate an expression (literal, cast, struct field access, or numeric compute).
    Computed(ComputedProjectionInfo),
}
137
/// Pre-resolved per-output-column plan: indices into the gathered arrays and the
/// projection evals, computed once in `build()` so the per-chunk hot path avoids map lookups.
#[derive(Clone)]
pub enum ProjectionPlan {
    /// Copy the gathered array at `source_idx` into the output.
    Column {
        source_idx: usize,
    },
    /// Evaluate `projection_evals[eval_idx]`; when `passthrough_idx` is set, reuse the
    /// gathered array at that index instead of evaluating the expression.
    Computed {
        eval_idx: usize,
        passthrough_idx: Option<usize>,
    },
}
148
/// Builder collecting everything a [`ScanRowStream`] needs; no storage work happens
/// until `build()` is called.
pub struct RowStreamBuilder<'storage, P, S>
where
    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
    S: ScanStorage<P>,
{
    // Storage backend rows are gathered from.
    storage: &'storage S,
    // Table whose user columns the projections reference.
    table_id: TableId,
    // Output schema of every produced batch.
    schema: Arc<Schema>,
    // Deduplicated logical fields to gather per chunk.
    unique_lfids: Arc<Vec<LogicalFieldId>>,
    // One entry per output column describing how it is produced.
    projection_evals: Arc<Vec<ProjectionEval>>,
    // Per-projection optional passthrough field (computed projections that are plain columns).
    _passthrough_fields: Arc<Vec<Option<FieldId>>>,
    // Maps a logical field id to its position in `unique_lfids` / gathered arrays.
    unique_index: Arc<FxHashMap<LogicalFieldId, usize>>,
    // Fields whose gathered arrays feed numeric expression evaluation.
    numeric_fields: Arc<FxHashSet<FieldId>>,
    // Whether any computed projection needs the numeric-array cache.
    requires_numeric: bool,
    // Null handling policy forwarded to the storage gather.
    null_policy: GatherNullPolicy,
    // Row ids to visit (bitmap or explicit vector).
    row_ids: RowIdSource,
    // Target number of rows per chunk.
    chunk_size: usize,
    // Optional pre-built gather context to reuse across scans.
    gather_ctx: Option<MultiGatherContext>,
    // Ties the builder to the pager type without storing one.
    phantom: PhantomData<P>,
    // Whether chunks should expose their row-id window.
    include_row_ids: bool,
}
170
171impl<'storage, P, S> RowStreamBuilder<'storage, P, S>
172where
173    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
174    S: ScanStorage<P>,
175{
176    #[allow(clippy::too_many_arguments)]
177    pub fn new(
178        storage: &'storage S,
179        table_id: TableId,
180        schema: Arc<Schema>,
181        unique_lfids: Arc<Vec<LogicalFieldId>>,
182        projection_evals: Arc<Vec<ProjectionEval>>,
183        passthrough_fields: Arc<Vec<Option<FieldId>>>,
184        unique_index: Arc<FxHashMap<LogicalFieldId, usize>>,
185        numeric_fields: Arc<FxHashSet<FieldId>>,
186        requires_numeric: bool,
187        null_policy: GatherNullPolicy,
188        row_ids: impl Into<RowIdSource>,
189        chunk_size: usize,
190        include_row_ids: bool,
191    ) -> Self {
192        Self {
193            storage,
194            table_id,
195            schema,
196            unique_lfids,
197            projection_evals,
198            _passthrough_fields: passthrough_fields,
199            unique_index,
200            numeric_fields,
201            requires_numeric,
202            null_policy,
203            row_ids: row_ids.into(),
204            chunk_size,
205            gather_ctx: None,
206            phantom: PhantomData,
207            include_row_ids,
208        }
209    }
210
211    pub fn with_gather_context(mut self, ctx: MultiGatherContext) -> Self {
212        self.gather_ctx = Some(ctx);
213        self
214    }
215
216    pub fn build(self) -> LlkvResult<ScanRowStream<'storage, P, S>> {
217        let RowStreamBuilder {
218            storage,
219            table_id,
220            schema,
221            unique_lfids,
222            projection_evals,
223            _passthrough_fields: passthrough_fields,
224            unique_index,
225            numeric_fields,
226            requires_numeric,
227            null_policy,
228            row_ids,
229            chunk_size,
230            mut gather_ctx,
231            include_row_ids,
232            ..
233        } = self;
234
235        let gather_ctx = if unique_lfids.is_empty() {
236            None
237        } else if let Some(ctx) = gather_ctx.take() {
238            Some(ctx)
239        } else {
240            Some(storage.prepare_gather_context(unique_lfids.as_ref())?)
241        };
242
243        let (row_id_array, total_rows) = match row_ids {
244            RowIdSource::Bitmap(bitmap) => {
245                let len = bitmap.cardinality();
246                let array = Arc::new(UInt64Array::from_iter_values(bitmap.iter()));
247                (array, len as usize)
248            }
249            RowIdSource::Vector(vector) => {
250                let len = vector.len();
251                let array = Arc::new(UInt64Array::from(vector));
252                (array, len)
253            }
254        };
255
256        let columns_buf = Vec::with_capacity(projection_evals.len());
257        let numeric_arrays_cache = if requires_numeric {
258            Some(FxHashMap::default())
259        } else {
260            None
261        };
262
263        let mut projection_plan = Vec::with_capacity(projection_evals.len());
264        for (idx, eval) in projection_evals.iter().enumerate() {
265            match eval {
266                ProjectionEval::Column(info) => {
267                    let arr_idx = *unique_index
268                        .get(&info.logical_field_id)
269                        .expect("logical field id missing from index");
270                    projection_plan.push(ProjectionPlan::Column {
271                        source_idx: arr_idx,
272                    });
273                }
274                ProjectionEval::Computed(_) => {
275                    let passthrough_idx = passthrough_fields[idx].as_ref().map(|fid| {
276                        let lfid = LogicalFieldId::for_user(table_id, *fid);
277                        *unique_index
278                            .get(&lfid)
279                            .expect("passthrough field missing from index")
280                    });
281                    projection_plan.push(ProjectionPlan::Computed {
282                        eval_idx: idx,
283                        passthrough_idx,
284                    });
285                }
286            }
287        }
288
289        let chunk_ranges = if total_rows == 0 {
290            Vec::new()
291        } else {
292            (0..total_rows)
293                .step_by(chunk_size.max(1))
294                .map(|start| {
295                    let end = (start + chunk_size).min(total_rows);
296                    (start, end)
297                })
298                .collect()
299        };
300
301        Ok(ScanRowStream {
302            storage,
303            table_id,
304            schema,
305            unique_lfids,
306            projection_evals,
307            _passthrough_fields: passthrough_fields,
308            unique_index,
309            numeric_fields,
310            requires_numeric,
311            null_policy,
312            row_ids: row_id_array,
313            chunk_ranges,
314            range_idx: 0,
315            gather_ctx,
316            current_batch: None,
317            current_row_ids: None,
318            columns_buf,
319            numeric_arrays_cache,
320            phantom: PhantomData,
321            emit_row_ids: include_row_ids,
322            projection_plan,
323        })
324    }
325}
326
/// Streaming scanner that materializes row windows chunk-by-chunk from storage.
/// Created by [`RowStreamBuilder::build`]; consumed via [`RowStream::next_chunk`].
pub struct ScanRowStream<'storage, P, S>
where
    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
    S: ScanStorage<P>,
{
    // Storage backend rows are gathered from.
    storage: &'storage S,
    // Table whose user columns the projections reference.
    table_id: TableId,
    // Output schema of every produced batch.
    schema: Arc<Schema>,
    // Deduplicated logical fields gathered per chunk.
    unique_lfids: Arc<Vec<LogicalFieldId>>,
    // One entry per output column describing how it is produced.
    projection_evals: Arc<Vec<ProjectionEval>>,
    // Retained from the builder; the plan below already encodes passthrough indices.
    _passthrough_fields: Arc<Vec<Option<FieldId>>>,
    // Maps a logical field id to its position in `unique_lfids` / gathered arrays.
    unique_index: Arc<FxHashMap<LogicalFieldId, usize>>,
    // Fields whose gathered arrays feed numeric expression evaluation.
    numeric_fields: Arc<FxHashSet<FieldId>>,
    // Whether any computed projection needs the numeric-array cache.
    requires_numeric: bool,
    // Null handling policy forwarded to the storage gather.
    null_policy: GatherNullPolicy,
    // All row ids to visit, materialized once up front.
    row_ids: Arc<UInt64Array>,
    // Half-open `(start, end)` index ranges into `row_ids`, one per chunk.
    chunk_ranges: Vec<(usize, usize)>,
    // Next entry of `chunk_ranges` to materialize.
    range_idx: usize,
    // Reusable gather buffers; `None` when no stored columns are fetched.
    gather_ctx: Option<MultiGatherContext>,
    // Batch backing the most recently returned chunk (keeps borrows alive).
    current_batch: Option<RecordBatch>,
    // Row-id slice backing the most recently returned chunk.
    current_row_ids: Option<ArrayRef>,
    // Scratch buffer of output columns, reused across chunks.
    columns_buf: Vec<ArrayRef>,
    // Reusable field-id -> array cache for numeric expression evaluation.
    numeric_arrays_cache: Option<NumericArrayMap>,
    // Ties the stream to the pager type without storing one.
    phantom: PhantomData<P>,
    // Whether chunks expose their row-id window.
    emit_row_ids: bool,
    // Pre-resolved per-output-column plan (see `ProjectionPlan`).
    projection_plan: Vec<ProjectionPlan>,
}
354
impl<'storage, P, S> RowStream for ScanRowStream<'storage, P, S>
where
    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
    S: ScanStorage<P>,
{
    type Columns<'a>
        = ColumnSlices<'a>
    where
        Self: 'a;

    fn schema(&self) -> &Arc<Schema> {
        &self.schema
    }

    /// Materializes row windows until one yields a non-empty batch, returning it as a
    /// [`RowChunk`] borrowing from `self`; `Ok(None)` once all chunk ranges are exhausted.
    fn next_chunk<'a>(&'a mut self) -> LlkvResult<Option<RowChunk<'a, Self::Columns<'a>>>> {
        while self.range_idx < self.chunk_ranges.len() {
            let (start, end) = self.chunk_ranges[self.range_idx];
            self.range_idx += 1;

            // Slice the row-id window covered by this chunk range.
            let values = self.row_ids.values();
            let window: &[RowId] = &values[start..end];

            // Clone the shared handles (and the plan) up front so the later `&mut self`
            // borrows of the gather context / numeric cache / column buffer don't conflict.
            let unique_lfids = Arc::clone(&self.unique_lfids);
            let projection_evals = Arc::clone(&self.projection_evals);
            let unique_index = Arc::clone(&self.unique_index);
            let numeric_fields = Arc::clone(&self.numeric_fields);
            let requires_numeric = self.requires_numeric;
            let null_policy = self.null_policy;
            let schema = Arc::clone(&self.schema);
            let projection_plan = self.projection_plan.clone();

            let numeric_cache = self.numeric_arrays_cache.as_mut();
            let batch_opt = materialize_row_window(
                self.storage,
                self.table_id,
                unique_lfids.as_ref(),
                projection_evals.as_ref(),
                &projection_plan,
                unique_index.as_ref(),
                numeric_fields.as_ref(),
                requires_numeric,
                null_policy,
                &schema,
                window,
                self.gather_ctx.as_mut(),
                numeric_cache,
                &mut self.columns_buf,
            )?;

            // Skip windows that produced no batch (e.g. empty window or empty gather).
            let Some(batch) = batch_opt else {
                continue;
            };

            if batch.num_rows() == 0 {
                continue;
            }

            // Stash the row-id slice on `self` so the returned chunk can borrow it.
            let row_ids_ref = if self.emit_row_ids {
                let row_id_slice = self.row_ids.slice(start, end - start);
                self.current_row_ids = Some(Arc::new(row_id_slice) as ArrayRef);
                self.current_row_ids
                    .as_ref()
                    .and_then(|arr| arr.as_any().downcast_ref::<UInt64Array>())
            } else {
                self.current_row_ids = None;
                None
            };
            // Stash the batch on `self` for the same reason: the chunk borrows it.
            self.current_batch = Some(batch);

            let batch_ref = self.current_batch.as_ref().expect("batch must be present");
            let columns = batch_ref.columns();
            let column_set = ColumnSlices::new(columns);

            return Ok(Some(RowChunk {
                row_ids: row_ids_ref,
                columns: column_set,
                // Visibility filtering is not applied by this stream.
                visibility: None,
                record_batch: batch_ref,
            }));
        }

        Ok(None)
    }
}
439
impl<'storage, P, S> ScanRowStream<'storage, P, S>
where
    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
    S: ScanStorage<P>,
{
    /// Consumes the stream and returns its gather context (if any) so callers can
    /// recycle the context's buffers for a subsequent scan via
    /// [`RowStreamBuilder::with_gather_context`].
    pub fn into_gather_context(self) -> Option<MultiGatherContext> {
        self.gather_ctx
    }
}
449
/// Materializes one window of row ids into an output `RecordBatch`.
///
/// Gathers the stored columns for `window` (unless `unique_lfids` is empty, in which
/// case only computed projections are produced), optionally fills the numeric-array
/// cache for expression evaluation, then assembles output columns according to
/// `projection_plan`.
///
/// Returns `Ok(None)` when the window is empty or the gather yields zero rows.
///
/// # Errors
/// Propagates storage gather failures, expression-evaluation failures, and invalid
/// `GetField` usage (non-column base or non-struct target).
///
/// # Panics
/// Panics (via `expect`) if `requires_numeric` is set but `numeric_cache` is `None`,
/// or if a non-passthrough computed projection needs numerics when none were cached —
/// both are caller invariants established by `RowStreamBuilder::build`.
#[allow(clippy::too_many_arguments)]
pub fn materialize_row_window<P, S>(
    storage: &S,
    table_id: TableId,
    unique_lfids: &[LogicalFieldId],
    projection_evals: &[ProjectionEval],
    projection_plan: &[ProjectionPlan],
    unique_index: &FxHashMap<LogicalFieldId, usize>,
    numeric_fields: &FxHashSet<FieldId>,
    requires_numeric: bool,
    null_policy: GatherNullPolicy,
    out_schema: &Arc<Schema>,
    window: &[RowId],
    gather_ctx: Option<&mut MultiGatherContext>,
    numeric_cache: Option<&mut NumericArrayMap>,
    columns: &mut Vec<ArrayRef>,
) -> LlkvResult<Option<RecordBatch>>
where
    P: llkv_storage::pager::Pager<Blob = EntryHandle> + Send + Sync,
    S: ScanStorage<P>,
{
    if window.is_empty() {
        return Ok(None);
    }

    let mut gathered_batch: Option<RecordBatch> = None;
    let mut numeric_arrays_holder: Option<&mut NumericArrayMap> = None;
    // Determine the row count for this window. With no stored columns, the window
    // length stands in; otherwise it comes from the gathered batch.
    let batch_len = if unique_lfids.is_empty() {
        if requires_numeric {
            let map = numeric_cache.expect("numeric cache missing for computed projections");
            map.clear();
            numeric_arrays_holder = Some(map);
        }
        window.len()
    } else {
        let batch = storage.gather_row_window_with_context(
            unique_lfids,
            window,
            null_policy,
            gather_ctx,
        )?;
        if batch.num_rows() == 0 {
            return Ok(None);
        }
        let batch_len = batch.num_rows();
        if requires_numeric {
            // Cache only the gathered arrays that numeric expressions reference.
            let map = numeric_cache.expect("numeric cache missing for computed projections");
            map.clear();
            for (lfid, array) in unique_lfids.iter().zip(batch.columns().iter()) {
                let fid = lfid.field_id();
                if numeric_fields.contains(&fid) {
                    map.insert(fid, array.clone());
                }
            }
            numeric_arrays_holder = Some(map);
        }
        gathered_batch = Some(batch);
        batch_len
    };

    if batch_len == 0 {
        return Ok(None);
    }

    // Gathered columns are indexed by `unique_index` positions; empty when no gather ran.
    let gathered_columns: &[ArrayRef] = if let Some(batch) = gathered_batch.as_ref() {
        batch.columns()
    } else {
        &[]
    };

    // Assemble output columns into the caller's reusable buffer.
    columns.clear();
    columns.reserve(projection_evals.len());
    for (idx, plan) in projection_plan.iter().enumerate() {
        match plan {
            ProjectionPlan::Column { source_idx } => {
                // Direct column projection: cheap Arc clone of the gathered array.
                columns.push(Arc::clone(&gathered_columns[*source_idx]));
            }
            ProjectionPlan::Computed {
                eval_idx,
                passthrough_idx,
            } => {
                // Computed projection that is really a plain column: reuse the gathered array.
                if let Some(arr_idx) = passthrough_idx {
                    columns.push(Arc::clone(&gathered_columns[*arr_idx]));
                    continue;
                }
                let info = match &projection_evals[*eval_idx] {
                    ProjectionEval::Computed(info) => info,
                    ProjectionEval::Column(_) => unreachable!("plan mismatch"),
                };
                let array: ArrayRef = match &info.expr {
                    // Pure literal: synthesize a constant array of `batch_len` rows.
                    ScalarExpr::Literal(_) => synthesize_computed_literal_array(
                        info,
                        out_schema.field(idx).data_type(),
                        batch_len,
                    )?,
                    // Cast of a non-numeric expression: also handled by literal synthesis.
                    ScalarExpr::Cast { .. } if !computed_expr_requires_numeric(&info.expr) => {
                        synthesize_computed_literal_array(
                            info,
                            out_schema.field(idx).data_type(),
                            batch_len,
                        )?
                    }
                    // Struct field access, possibly nested (GetField chains over a column base).
                    ScalarExpr::GetField { base, field_name } => {
                        // Recursively resolves the base to a struct array, then pulls
                        // the named child column out of it.
                        fn eval_get_field(
                            expr: &ScalarExpr<FieldId>,
                            field_name: &str,
                            gathered_columns: &[ArrayRef],
                            unique_index: &FxHashMap<LogicalFieldId, usize>,
                            table_id: TableId,
                        ) -> LlkvResult<ArrayRef> {
                            let base_array = match expr {
                                ScalarExpr::Column(fid) => {
                                    let lfid = LogicalFieldId::for_user(table_id, *fid);
                                    let arr_idx = *unique_index.get(&lfid).ok_or_else(|| {
                                        llkv_result::Error::Internal(
                                            "field missing from unique arrays".into(),
                                        )
                                    })?;
                                    Arc::clone(&gathered_columns[arr_idx])
                                }
                                ScalarExpr::GetField {
                                    base: inner_base,
                                    field_name: inner_field,
                                } => eval_get_field(
                                    inner_base,
                                    inner_field,
                                    gathered_columns,
                                    unique_index,
                                    table_id,
                                )?,
                                _ => {
                                    return Err(llkv_result::Error::InvalidArgumentError(
                                        "GetField base must be a column or another GetField".into(),
                                    ));
                                }
                            };

                            let struct_array = base_array
                                .as_any()
                                .downcast_ref::<arrow::array::StructArray>()
                                .ok_or_else(|| {
                                    llkv_result::Error::InvalidArgumentError(
                                        "GetField can only be applied to struct types".into(),
                                    )
                                })?;

                            struct_array
                                .column_by_name(field_name)
                                .ok_or_else(|| {
                                    llkv_result::Error::InvalidArgumentError(format!(
                                        "Field '{}' not found in struct",
                                        field_name
                                    ))
                                })
                                .map(Arc::clone)
                        }

                        eval_get_field(base, field_name, gathered_columns, unique_index, table_id)?
                    }
                    // Everything else: general numeric expression evaluation over the cache.
                    _ => {
                        let numeric_arrays = numeric_arrays_holder
                            .as_ref()
                            .expect("numeric arrays should exist for computed projection");
                        ScalarEvaluator::evaluate_batch(&info.expr, batch_len, numeric_arrays)?
                    }
                };
                columns.push(array);
            }
        }
    }

    // `columns.clone()` is cheap: each element is an Arc'd Arrow array.
    let batch = RecordBatch::try_new(Arc::clone(out_schema), columns.clone())?;
    Ok(Some(batch))
}