1use crate::error::{DbxError, DbxResult};
7use arrow::array::{
8 ArrayRef, BooleanBuilder, Float64Builder, Int32Builder, Int64Builder, StringBuilder,
9};
10use arrow::datatypes::{DataType, Schema};
11use arrow::record_batch::RecordBatch;
12use rayon::prelude::*;
13use std::sync::Arc;
14
/// A dynamically typed scalar holding one cell of a row.
///
/// The variants mirror the subset of Arrow `DataType`s supported by
/// `ColumnarStore` (see `ScalarValue::data_type`). `Null` is typeless and
/// is accepted for any column during validation in
/// `ColumnarStore::append_row`.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// Absence of a value; compatible with every column type.
    Null,
    Int32(i32),
    Int64(i64),
    Float64(f64),
    /// Owned UTF-8 string (maps to Arrow `Utf8`, i32 offsets).
    Utf8(String),
    Boolean(bool),
    /// Arbitrary bytes (maps to Arrow `Binary`, i32 offsets).
    Binary(Vec<u8>),
}
26
27impl ScalarValue {
28 pub fn data_type(&self) -> DataType {
30 match self {
31 ScalarValue::Null => DataType::Null,
32 ScalarValue::Int32(_) => DataType::Int32,
33 ScalarValue::Int64(_) => DataType::Int64,
34 ScalarValue::Float64(_) => DataType::Float64,
35 ScalarValue::Utf8(_) => DataType::Utf8,
36 ScalarValue::Boolean(_) => DataType::Boolean,
37 ScalarValue::Binary(_) => DataType::Binary,
38 }
39 }
40
41 pub fn from_array(array: &ArrayRef, idx: usize) -> crate::error::DbxResult<Self> {
43 use arrow::array::AsArray;
44 if array.is_null(idx) {
45 return Ok(ScalarValue::Null);
46 }
47 match array.data_type() {
48 DataType::Int32 => Ok(ScalarValue::Int32(
49 array
50 .as_primitive::<arrow::datatypes::Int32Type>()
51 .value(idx),
52 )),
53 DataType::Int64 => Ok(ScalarValue::Int64(
54 array
55 .as_primitive::<arrow::datatypes::Int64Type>()
56 .value(idx),
57 )),
58 DataType::Float64 => Ok(ScalarValue::Float64(
59 array
60 .as_primitive::<arrow::datatypes::Float64Type>()
61 .value(idx),
62 )),
63 DataType::Boolean => Ok(ScalarValue::Boolean(array.as_boolean().value(idx))),
64 DataType::Utf8 => Ok(ScalarValue::Utf8(
65 array.as_string::<i32>().value(idx).to_string(),
66 )),
67 DataType::Binary => Ok(ScalarValue::Binary(
68 array.as_binary::<i32>().value(idx).to_vec(),
69 )),
70 dt => Err(crate::error::DbxError::TypeMismatch {
71 expected: "Int32|Int64|Float64|Boolean|Utf8|Binary".to_string(),
72 actual: format!("{dt:?}"),
73 }),
74 }
75 }
76}
77
/// An in-memory, schema-validated row buffer convertible to Arrow
/// `RecordBatch`es.
///
/// NOTE(review): despite the name, rows are buffered row-major
/// (`Vec<Vec<ScalarValue>>`) and only transposed into columnar Arrow
/// arrays in `to_record_batch`.
pub struct ColumnarStore {
    // Arrow schema every appended row must conform to.
    schema: Arc<Schema>,
    // Row-major buffer; each inner Vec has exactly schema.fields().len()
    // entries (enforced by append_row).
    rows: Vec<Vec<ScalarValue>>,
}
85
86impl ColumnarStore {
87 pub fn new(schema: Arc<Schema>) -> Self {
89 Self {
90 schema,
91 rows: Vec::new(),
92 }
93 }
94
95 pub fn append_row(&mut self, values: &[ScalarValue]) -> DbxResult<()> {
97 let field_count = self.schema.fields().len();
98 if values.len() != field_count {
99 return Err(DbxError::Schema(format!(
100 "expected {field_count} columns, got {}",
101 values.len()
102 )));
103 }
104
105 for (i, (value, field)) in values.iter().zip(self.schema.fields()).enumerate() {
107 if !matches!(value, ScalarValue::Null) {
108 let expected = field.data_type();
109 let actual = value.data_type();
110 if *expected != actual {
111 return Err(DbxError::TypeMismatch {
112 expected: format!("column {i} ({}): {:?}", field.name(), expected),
113 actual: format!("{actual:?}"),
114 });
115 }
116 }
117 }
118
119 self.rows.push(values.to_vec());
120 Ok(())
121 }
122
123 pub fn to_record_batch(&self) -> DbxResult<RecordBatch> {
125 if self.rows.is_empty() {
126 return Ok(RecordBatch::new_empty(Arc::clone(&self.schema)));
127 }
128
129 let columns: Vec<ArrayRef> = self
131 .schema
132 .fields()
133 .par_iter()
134 .enumerate()
135 .map(|(col_idx, field)| self.build_column(col_idx, field.data_type()))
136 .collect::<DbxResult<_>>()?;
137
138 Ok(RecordBatch::try_new(Arc::clone(&self.schema), columns)?)
139 }
140
141 pub fn schema(&self) -> &Schema {
143 &self.schema
144 }
145
146 pub fn row_count(&self) -> usize {
148 self.rows.len()
149 }
150
151 pub fn clear(&mut self) {
153 self.rows.clear();
154 }
155
156 fn build_column(&self, col_idx: usize, data_type: &DataType) -> DbxResult<ArrayRef> {
158 match data_type {
159 DataType::Int32 => {
160 let mut builder = Int32Builder::with_capacity(self.rows.len());
161 for row in &self.rows {
162 match &row[col_idx] {
163 ScalarValue::Int32(v) => builder.append_value(*v),
164 ScalarValue::Null => builder.append_null(),
165 other => {
166 return Err(DbxError::TypeMismatch {
167 expected: "Int32".to_string(),
168 actual: format!("{other:?}"),
169 });
170 }
171 }
172 }
173 Ok(Arc::new(builder.finish()))
174 }
175 DataType::Int64 => {
176 let mut builder = Int64Builder::with_capacity(self.rows.len());
177 for row in &self.rows {
178 match &row[col_idx] {
179 ScalarValue::Int64(v) => builder.append_value(*v),
180 ScalarValue::Null => builder.append_null(),
181 other => {
182 return Err(DbxError::TypeMismatch {
183 expected: "Int64".to_string(),
184 actual: format!("{other:?}"),
185 });
186 }
187 }
188 }
189 Ok(Arc::new(builder.finish()))
190 }
191 DataType::Float64 => {
192 let mut builder = Float64Builder::with_capacity(self.rows.len());
193 for row in &self.rows {
194 match &row[col_idx] {
195 ScalarValue::Float64(v) => builder.append_value(*v),
196 ScalarValue::Null => builder.append_null(),
197 other => {
198 return Err(DbxError::TypeMismatch {
199 expected: "Float64".to_string(),
200 actual: format!("{other:?}"),
201 });
202 }
203 }
204 }
205 Ok(Arc::new(builder.finish()))
206 }
207 DataType::Utf8 => {
208 let mut builder = StringBuilder::with_capacity(self.rows.len(), 256);
209 for row in &self.rows {
210 match &row[col_idx] {
211 ScalarValue::Utf8(v) => builder.append_value(v),
212 ScalarValue::Null => builder.append_null(),
213 other => {
214 return Err(DbxError::TypeMismatch {
215 expected: "Utf8".to_string(),
216 actual: format!("{other:?}"),
217 });
218 }
219 }
220 }
221 Ok(Arc::new(builder.finish()))
222 }
223 DataType::Boolean => {
224 let mut builder = BooleanBuilder::with_capacity(self.rows.len());
225 for row in &self.rows {
226 match &row[col_idx] {
227 ScalarValue::Boolean(v) => builder.append_value(*v),
228 ScalarValue::Null => builder.append_null(),
229 other => {
230 return Err(DbxError::TypeMismatch {
231 expected: "Boolean".to_string(),
232 actual: format!("{other:?}"),
233 });
234 }
235 }
236 }
237 Ok(Arc::new(builder.finish()))
238 }
239 DataType::Binary => {
240 let mut builder = arrow::array::BinaryBuilder::with_capacity(self.rows.len(), 256);
241 for row in &self.rows {
242 match &row[col_idx] {
243 ScalarValue::Binary(v) => builder.append_value(v),
244 ScalarValue::Null => builder.append_null(),
245 other => {
246 return Err(DbxError::TypeMismatch {
247 expected: "Binary".to_string(),
248 actual: format!("{other:?}"),
249 });
250 }
251 }
252 }
253 Ok(Arc::new(builder.finish()))
254 }
255 dt => Err(DbxError::Schema(format!("unsupported data type: {dt:?}"))),
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, BooleanArray, Float64Array, Int32Array, Int64Array, StringArray};
    use arrow::datatypes::Field;

    /// Five-column schema shared by every test: two non-nullable scalars,
    /// two nullable ones, and a non-nullable boolean flag.
    fn test_schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, true),
            Field::new("score", DataType::Float64, true),
            Field::new("active", DataType::Boolean, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Downcasts column `idx` of `batch` to the concrete array type `T`.
    fn col<T: 'static>(batch: &RecordBatch, idx: usize) -> &T {
        batch
            .column(idx)
            .as_any()
            .downcast_ref::<T>()
            .expect("unexpected column type")
    }

    #[test]
    fn create_empty_store() {
        let store = ColumnarStore::new(test_schema());
        assert_eq!(store.row_count(), 0);

        // An empty store must still produce a batch with the full schema.
        let batch = store.to_record_batch().unwrap();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 5);
    }

    #[test]
    fn append_and_convert() {
        let mut store = ColumnarStore::new(test_schema());
        let people = [
            (1, "Alice", 30_i64, 95.5, true),
            (2, "Bob", 25_i64, 87.3, false),
        ];
        for (id, name, age, score, active) in people {
            let row = [
                ScalarValue::Int32(id),
                ScalarValue::Utf8(name.to_string()),
                ScalarValue::Int64(age),
                ScalarValue::Float64(score),
                ScalarValue::Boolean(active),
            ];
            store.append_row(&row).unwrap();
        }

        assert_eq!(store.row_count(), 2);
        let batch = store.to_record_batch().unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 5);

        let ids = col::<Int32Array>(&batch, 0);
        assert_eq!((ids.value(0), ids.value(1)), (1, 2));

        let names = col::<StringArray>(&batch, 1);
        assert_eq!((names.value(0), names.value(1)), ("Alice", "Bob"));

        let ages = col::<Int64Array>(&batch, 2);
        assert_eq!((ages.value(0), ages.value(1)), (30, 25));

        let scores = col::<Float64Array>(&batch, 3);
        assert!((scores.value(0) - 95.5).abs() < f64::EPSILON);

        let flags = col::<BooleanArray>(&batch, 4);
        assert!(flags.value(0));
        assert!(!flags.value(1));
    }

    #[test]
    fn null_handling() {
        let mut store = ColumnarStore::new(test_schema());
        // Nullable columns (age, score) take Null; the rest are concrete.
        let row = [
            ScalarValue::Int32(1),
            ScalarValue::Utf8("Alice".to_string()),
            ScalarValue::Null,
            ScalarValue::Null,
            ScalarValue::Boolean(true),
        ];
        store.append_row(&row).unwrap();

        let batch = store.to_record_batch().unwrap();
        assert!(col::<Int64Array>(&batch, 2).is_null(0));
        assert!(col::<Float64Array>(&batch, 3).is_null(0));
    }

    #[test]
    fn wrong_column_count_rejected() {
        let mut store = ColumnarStore::new(test_schema());
        // Only two of the five required columns supplied.
        let short_row = [ScalarValue::Int32(1), ScalarValue::Utf8("x".into())];
        assert!(store.append_row(&short_row).is_err());
    }

    #[test]
    fn type_mismatch_rejected() {
        let mut store = ColumnarStore::new(test_schema());
        // First column is declared Int32; a Utf8 there must be rejected.
        let bad_row = [
            ScalarValue::Utf8("wrong".into()),
            ScalarValue::Utf8("name".into()),
            ScalarValue::Int64(0),
            ScalarValue::Float64(0.0),
            ScalarValue::Boolean(false),
        ];
        assert!(store.append_row(&bad_row).is_err());
    }

    #[test]
    fn clear_rows() {
        let mut store = ColumnarStore::new(test_schema());
        let row = [
            ScalarValue::Int32(1),
            ScalarValue::Utf8("x".into()),
            ScalarValue::Int64(0),
            ScalarValue::Float64(0.0),
            ScalarValue::Boolean(false),
        ];
        store.append_row(&row).unwrap();
        assert_eq!(store.row_count(), 1);

        store.clear();
        assert_eq!(store.row_count(), 0);
    }

    #[test]
    fn schema_accessible() {
        let schema = test_schema();
        let store = ColumnarStore::new(Arc::clone(&schema));
        assert_eq!(store.schema().fields().len(), 5);
        assert_eq!(store.schema().field(0).name(), "id");
    }

    #[test]
    fn round_trip_1000_rows() {
        let mut store = ColumnarStore::new(test_schema());
        for i in 0..1000 {
            let row = [
                ScalarValue::Int32(i),
                ScalarValue::Utf8(format!("user_{i}")),
                ScalarValue::Int64(i64::from(i) * 2),
                ScalarValue::Float64(f64::from(i) * 1.5),
                ScalarValue::Boolean(i % 2 == 0),
            ];
            store.append_row(&row).unwrap();
        }

        let batch = store.to_record_batch().unwrap();
        assert_eq!(batch.num_rows(), 1000);

        let ids = col::<Int32Array>(&batch, 0);
        assert_eq!(ids.value(0), 0);
        assert_eq!(ids.value(999), 999);

        let names = col::<StringArray>(&batch, 1);
        assert_eq!(names.value(0), "user_0");
        assert_eq!(names.value(999), "user_999");
    }
}