// dbx_core/storage/columnar.rs

1//! Columnar Store — Arrow RecordBatch management.
2//!
3//! Manages in-memory columnar data using Apache Arrow's `RecordBatch` format.
4//! Provides schema enforcement and row→column conversion.
5
6use crate::error::{DbxError, DbxResult};
7use arrow::array::{
8    ArrayRef, BooleanBuilder, Float64Builder, Int32Builder, Int64Builder, StringBuilder,
9};
10use arrow::datatypes::{DataType, Schema};
11use arrow::record_batch::RecordBatch;
12use rayon::prelude::*;
13use std::sync::Arc;
14
/// Represents a scalar value that can be stored in a column.
///
/// Each variant other than `Null` wraps a native Rust value; the matching Arrow
/// type for any variant is reported by `ScalarValue::data_type`.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// Absent value; `data_type` reports it as `DataType::Null`.
    Null,
    /// 32-bit signed integer.
    Int32(i32),
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit floating-point number.
    Float64(f64),
    /// Owned UTF-8 string.
    Utf8(String),
    /// Boolean flag.
    Boolean(bool),
    /// Owned byte buffer.
    Binary(Vec<u8>),
}
26
27impl ScalarValue {
28    /// Get the Arrow DataType for this value.
29    pub fn data_type(&self) -> DataType {
30        match self {
31            ScalarValue::Null => DataType::Null,
32            ScalarValue::Int32(_) => DataType::Int32,
33            ScalarValue::Int64(_) => DataType::Int64,
34            ScalarValue::Float64(_) => DataType::Float64,
35            ScalarValue::Utf8(_) => DataType::Utf8,
36            ScalarValue::Boolean(_) => DataType::Boolean,
37            ScalarValue::Binary(_) => DataType::Binary,
38        }
39    }
40
41    /// Extract a value from an Arrow array at the given index.
42    pub fn from_array(array: &ArrayRef, idx: usize) -> crate::error::DbxResult<Self> {
43        use arrow::array::AsArray;
44        if array.is_null(idx) {
45            return Ok(ScalarValue::Null);
46        }
47        match array.data_type() {
48            DataType::Int32 => Ok(ScalarValue::Int32(
49                array
50                    .as_primitive::<arrow::datatypes::Int32Type>()
51                    .value(idx),
52            )),
53            DataType::Int64 => Ok(ScalarValue::Int64(
54                array
55                    .as_primitive::<arrow::datatypes::Int64Type>()
56                    .value(idx),
57            )),
58            DataType::Float64 => Ok(ScalarValue::Float64(
59                array
60                    .as_primitive::<arrow::datatypes::Float64Type>()
61                    .value(idx),
62            )),
63            DataType::Boolean => Ok(ScalarValue::Boolean(array.as_boolean().value(idx))),
64            DataType::Utf8 => Ok(ScalarValue::Utf8(
65                array.as_string::<i32>().value(idx).to_string(),
66            )),
67            DataType::Binary => Ok(ScalarValue::Binary(
68                array.as_binary::<i32>().value(idx).to_vec(),
69            )),
70            dt => Err(crate::error::DbxError::TypeMismatch {
71                expected: "Int32|Int64|Float64|Boolean|Utf8|Binary".to_string(),
72                actual: format!("{dt:?}"),
73            }),
74        }
75    }
76}
77
78/// In-memory columnar store backed by Arrow RecordBatch.
79///
80/// Accumulates rows and converts them to columnar format on demand.
81pub struct ColumnarStore {
82    schema: Arc<Schema>,
83    rows: Vec<Vec<ScalarValue>>,
84}
85
86impl ColumnarStore {
87    /// Create a new ColumnarStore with the given schema.
88    pub fn new(schema: Arc<Schema>) -> Self {
89        Self {
90            schema,
91            rows: Vec::new(),
92        }
93    }
94
95    /// Append a row of values. Must match the schema's field count and types.
96    pub fn append_row(&mut self, values: &[ScalarValue]) -> DbxResult<()> {
97        let field_count = self.schema.fields().len();
98        if values.len() != field_count {
99            return Err(DbxError::Schema(format!(
100                "expected {field_count} columns, got {}",
101                values.len()
102            )));
103        }
104
105        // Type check each value against schema
106        for (i, (value, field)) in values.iter().zip(self.schema.fields()).enumerate() {
107            if !matches!(value, ScalarValue::Null) {
108                let expected = field.data_type();
109                let actual = value.data_type();
110                if *expected != actual {
111                    return Err(DbxError::TypeMismatch {
112                        expected: format!("column {i} ({}): {:?}", field.name(), expected),
113                        actual: format!("{actual:?}"),
114                    });
115                }
116            }
117        }
118
119        self.rows.push(values.to_vec());
120        Ok(())
121    }
122
123    /// Convert accumulated rows into an Arrow RecordBatch (Parallelized).
124    pub fn to_record_batch(&self) -> DbxResult<RecordBatch> {
125        if self.rows.is_empty() {
126            return Ok(RecordBatch::new_empty(Arc::clone(&self.schema)));
127        }
128
129        // Use Rayon for parallel column building
130        let columns: Vec<ArrayRef> = self
131            .schema
132            .fields()
133            .par_iter()
134            .enumerate()
135            .map(|(col_idx, field)| self.build_column(col_idx, field.data_type()))
136            .collect::<DbxResult<_>>()?;
137
138        Ok(RecordBatch::try_new(Arc::clone(&self.schema), columns)?)
139    }
140
141    /// Get the schema.
142    pub fn schema(&self) -> &Schema {
143        &self.schema
144    }
145
146    /// Get the number of accumulated rows.
147    pub fn row_count(&self) -> usize {
148        self.rows.len()
149    }
150
151    /// Clear all accumulated rows.
152    pub fn clear(&mut self) {
153        self.rows.clear();
154    }
155
156    /// Build a single column array from row data.
157    fn build_column(&self, col_idx: usize, data_type: &DataType) -> DbxResult<ArrayRef> {
158        match data_type {
159            DataType::Int32 => {
160                let mut builder = Int32Builder::with_capacity(self.rows.len());
161                for row in &self.rows {
162                    match &row[col_idx] {
163                        ScalarValue::Int32(v) => builder.append_value(*v),
164                        ScalarValue::Null => builder.append_null(),
165                        other => {
166                            return Err(DbxError::TypeMismatch {
167                                expected: "Int32".to_string(),
168                                actual: format!("{other:?}"),
169                            });
170                        }
171                    }
172                }
173                Ok(Arc::new(builder.finish()))
174            }
175            DataType::Int64 => {
176                let mut builder = Int64Builder::with_capacity(self.rows.len());
177                for row in &self.rows {
178                    match &row[col_idx] {
179                        ScalarValue::Int64(v) => builder.append_value(*v),
180                        ScalarValue::Null => builder.append_null(),
181                        other => {
182                            return Err(DbxError::TypeMismatch {
183                                expected: "Int64".to_string(),
184                                actual: format!("{other:?}"),
185                            });
186                        }
187                    }
188                }
189                Ok(Arc::new(builder.finish()))
190            }
191            DataType::Float64 => {
192                let mut builder = Float64Builder::with_capacity(self.rows.len());
193                for row in &self.rows {
194                    match &row[col_idx] {
195                        ScalarValue::Float64(v) => builder.append_value(*v),
196                        ScalarValue::Null => builder.append_null(),
197                        other => {
198                            return Err(DbxError::TypeMismatch {
199                                expected: "Float64".to_string(),
200                                actual: format!("{other:?}"),
201                            });
202                        }
203                    }
204                }
205                Ok(Arc::new(builder.finish()))
206            }
207            DataType::Utf8 => {
208                let mut builder = StringBuilder::with_capacity(self.rows.len(), 256);
209                for row in &self.rows {
210                    match &row[col_idx] {
211                        ScalarValue::Utf8(v) => builder.append_value(v),
212                        ScalarValue::Null => builder.append_null(),
213                        other => {
214                            return Err(DbxError::TypeMismatch {
215                                expected: "Utf8".to_string(),
216                                actual: format!("{other:?}"),
217                            });
218                        }
219                    }
220                }
221                Ok(Arc::new(builder.finish()))
222            }
223            DataType::Boolean => {
224                let mut builder = BooleanBuilder::with_capacity(self.rows.len());
225                for row in &self.rows {
226                    match &row[col_idx] {
227                        ScalarValue::Boolean(v) => builder.append_value(*v),
228                        ScalarValue::Null => builder.append_null(),
229                        other => {
230                            return Err(DbxError::TypeMismatch {
231                                expected: "Boolean".to_string(),
232                                actual: format!("{other:?}"),
233                            });
234                        }
235                    }
236                }
237                Ok(Arc::new(builder.finish()))
238            }
239            DataType::Binary => {
240                let mut builder = arrow::array::BinaryBuilder::with_capacity(self.rows.len(), 256);
241                for row in &self.rows {
242                    match &row[col_idx] {
243                        ScalarValue::Binary(v) => builder.append_value(v),
244                        ScalarValue::Null => builder.append_null(),
245                        other => {
246                            return Err(DbxError::TypeMismatch {
247                                expected: "Binary".to_string(),
248                                actual: format!("{other:?}"),
249                            });
250                        }
251                    }
252                }
253                Ok(Arc::new(builder.finish()))
254            }
255            dt => Err(DbxError::Schema(format!("unsupported data type: {dt:?}"))),
256        }
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, BooleanArray, Float64Array, Int32Array, Int64Array, StringArray};
    use arrow::datatypes::Field;

    /// Five-column schema shared by all tests: `age` and `score` are nullable,
    /// the remaining columns are not.
    fn test_schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, true),
            Field::new("score", DataType::Float64, true),
            Field::new("active", DataType::Boolean, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a fully populated row in the column order of `test_schema`.
    fn full_row(id: i32, name: &str, age: i64, score: f64, active: bool) -> [ScalarValue; 5] {
        [
            ScalarValue::Int32(id),
            ScalarValue::Utf8(name.to_string()),
            ScalarValue::Int64(age),
            ScalarValue::Float64(score),
            ScalarValue::Boolean(active),
        ]
    }

    /// Downcasts column `idx` of `batch` to the concrete array type `T`,
    /// panicking if the column has a different type.
    fn typed_column<T: Array + 'static>(batch: &RecordBatch, idx: usize) -> &T {
        batch.column(idx).as_any().downcast_ref::<T>().unwrap()
    }

    #[test]
    fn create_empty_store() {
        let store = ColumnarStore::new(test_schema());
        assert_eq!(store.row_count(), 0);

        let batch = store.to_record_batch().unwrap();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 5);
    }

    #[test]
    fn append_and_convert() {
        let mut store = ColumnarStore::new(test_schema());
        store
            .append_row(&full_row(1, "Alice", 30, 95.5, true))
            .unwrap();
        store
            .append_row(&full_row(2, "Bob", 25, 87.3, false))
            .unwrap();
        assert_eq!(store.row_count(), 2);

        let batch = store.to_record_batch().unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 5);

        // Check the contents of each column in schema order.
        let ids: &Int32Array = typed_column(&batch, 0);
        assert_eq!(ids.value(0), 1);
        assert_eq!(ids.value(1), 2);

        let names: &StringArray = typed_column(&batch, 1);
        assert_eq!(names.value(0), "Alice");
        assert_eq!(names.value(1), "Bob");

        let ages: &Int64Array = typed_column(&batch, 2);
        assert_eq!(ages.value(0), 30);
        assert_eq!(ages.value(1), 25);

        let scores: &Float64Array = typed_column(&batch, 3);
        assert!((scores.value(0) - 95.5).abs() < f64::EPSILON);

        let active: &BooleanArray = typed_column(&batch, 4);
        assert!(active.value(0));
        assert!(!active.value(1));
    }

    #[test]
    fn null_handling() {
        let mut store = ColumnarStore::new(test_schema());
        let row = [
            ScalarValue::Int32(1),
            ScalarValue::Utf8("Alice".to_string()),
            ScalarValue::Null, // nullable age
            ScalarValue::Null, // nullable score
            ScalarValue::Boolean(true),
        ];
        store.append_row(&row).unwrap();

        let batch = store.to_record_batch().unwrap();

        let ages: &Int64Array = typed_column(&batch, 2);
        assert!(ages.is_null(0));

        let scores: &Float64Array = typed_column(&batch, 3);
        assert!(scores.is_null(0));
    }

    #[test]
    fn wrong_column_count_rejected() {
        let mut store = ColumnarStore::new(test_schema());
        let too_short = [ScalarValue::Int32(1), ScalarValue::Utf8("x".into())];
        assert!(store.append_row(&too_short).is_err());
    }

    #[test]
    fn type_mismatch_rejected() {
        let mut store = ColumnarStore::new(test_schema());
        let bad_row = [
            ScalarValue::Utf8("wrong".into()), // should be Int32
            ScalarValue::Utf8("name".into()),
            ScalarValue::Int64(0),
            ScalarValue::Float64(0.0),
            ScalarValue::Boolean(false),
        ];
        assert!(store.append_row(&bad_row).is_err());
    }

    #[test]
    fn clear_rows() {
        let mut store = ColumnarStore::new(test_schema());
        store
            .append_row(&full_row(1, "x", 0, 0.0, false))
            .unwrap();
        assert_eq!(store.row_count(), 1);

        store.clear();
        assert_eq!(store.row_count(), 0);
    }

    #[test]
    fn schema_accessible() {
        let schema = test_schema();
        let store = ColumnarStore::new(Arc::clone(&schema));

        let exposed = store.schema();
        assert_eq!(exposed.fields().len(), 5);
        assert_eq!(exposed.field(0).name(), "id");
    }

    #[test]
    fn round_trip_1000_rows() {
        let mut store = ColumnarStore::new(test_schema());
        for i in 0..1000 {
            let row = full_row(
                i,
                &format!("user_{i}"),
                i as i64 * 2,
                i as f64 * 1.5,
                i % 2 == 0,
            );
            store.append_row(&row).unwrap();
        }

        let batch = store.to_record_batch().unwrap();
        assert_eq!(batch.num_rows(), 1000);

        let ids: &Int32Array = typed_column(&batch, 0);
        assert_eq!(ids.value(0), 0);
        assert_eq!(ids.value(999), 999);

        let names: &StringArray = typed_column(&batch, 1);
        assert_eq!(names.value(0), "user_0");
        assert_eq!(names.value(999), "user_999");
    }
}