1use crate::error::Error::IllegalArgument;
19use crate::error::Result;
20use crate::row::InternalRow;
21use crate::row::datum::{Date, Time, TimestampLtz, TimestampNtz};
22use arrow::array::{Array, AsArray, BinaryArray, RecordBatch, StringArray};
23use arrow::datatypes::{
24 DataType as ArrowDataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int8Type,
25 Int16Type, Int32Type, Int64Type, Time32MillisecondType, Time32SecondType,
26 Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
27 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
28};
29use std::sync::Arc;
30
31#[derive(Clone)]
32pub struct ColumnarRow {
33 record_batch: Arc<RecordBatch>,
34 row_id: usize,
35}
36
37impl ColumnarRow {
38 pub fn new(batch: Arc<RecordBatch>) -> Self {
39 ColumnarRow {
40 record_batch: batch,
41 row_id: 0,
42 }
43 }
44
45 pub fn new_with_row_id(bach: Arc<RecordBatch>, row_id: usize) -> Self {
46 ColumnarRow {
47 record_batch: bach,
48 row_id,
49 }
50 }
51
52 pub fn set_row_id(&mut self, row_id: usize) {
53 self.row_id = row_id
54 }
55
56 pub fn get_row_id(&self) -> usize {
57 self.row_id
58 }
59
60 pub fn get_record_batch(&self) -> &RecordBatch {
61 &self.record_batch
62 }
63
64 fn column(&self, pos: usize) -> Result<&Arc<dyn Array>> {
65 self.record_batch
66 .columns()
67 .get(pos)
68 .ok_or_else(|| IllegalArgument {
69 message: format!(
70 "column index {pos} out of bounds (batch has {} columns)",
71 self.record_batch.num_columns()
72 ),
73 })
74 }
75
76 fn read_timestamp_from_arrow<T>(
79 &self,
80 pos: usize,
81 _precision: u32,
82 construct_compact: impl FnOnce(i64) -> T,
83 construct_with_nanos: impl FnOnce(i64, i32) -> Result<T>,
84 ) -> Result<T> {
85 let column = self.column(pos)?;
86
87 let (value, time_unit) = match column.data_type() {
89 ArrowDataType::Timestamp(TimeUnit::Second, _) => (
90 column
91 .as_primitive_opt::<TimestampSecondType>()
92 .ok_or_else(|| IllegalArgument {
93 message: format!("expected TimestampSecondArray at position {pos}"),
94 })?
95 .value(self.row_id),
96 TimeUnit::Second,
97 ),
98 ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => (
99 column
100 .as_primitive_opt::<TimestampMillisecondType>()
101 .ok_or_else(|| IllegalArgument {
102 message: format!("expected TimestampMillisecondArray at position {pos}"),
103 })?
104 .value(self.row_id),
105 TimeUnit::Millisecond,
106 ),
107 ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => (
108 column
109 .as_primitive_opt::<TimestampMicrosecondType>()
110 .ok_or_else(|| IllegalArgument {
111 message: format!("expected TimestampMicrosecondArray at position {pos}"),
112 })?
113 .value(self.row_id),
114 TimeUnit::Microsecond,
115 ),
116 ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => (
117 column
118 .as_primitive_opt::<TimestampNanosecondType>()
119 .ok_or_else(|| IllegalArgument {
120 message: format!("expected TimestampNanosecondArray at position {pos}"),
121 })?
122 .value(self.row_id),
123 TimeUnit::Nanosecond,
124 ),
125 other => {
126 return Err(IllegalArgument {
127 message: format!("expected Timestamp column at position {pos}, got {other:?}"),
128 });
129 }
130 };
131
132 let (millis, nanos) = match time_unit {
134 TimeUnit::Second => (value * 1000, 0),
135 TimeUnit::Millisecond => (value, 0),
136 TimeUnit::Microsecond => {
137 let millis = value.div_euclid(1000);
140 let nanos = (value.rem_euclid(1000) * 1000) as i32;
141 (millis, nanos)
142 }
143 TimeUnit::Nanosecond => {
144 let millis = value.div_euclid(1_000_000);
146 let nanos = value.rem_euclid(1_000_000) as i32;
147 (millis, nanos)
148 }
149 };
150
151 if nanos == 0 {
152 Ok(construct_compact(millis))
153 } else {
154 construct_with_nanos(millis, nanos)
155 }
156 }
157
158 fn read_date_from_arrow(&self, pos: usize) -> Result<i32> {
160 Ok(self
161 .column(pos)?
162 .as_primitive_opt::<Date32Type>()
163 .ok_or_else(|| IllegalArgument {
164 message: format!("expected Date32Array at position {pos}"),
165 })?
166 .value(self.row_id))
167 }
168
169 fn read_time_from_arrow(&self, pos: usize) -> Result<i32> {
171 let column = self.column(pos)?;
172
173 match column.data_type() {
174 ArrowDataType::Time32(TimeUnit::Second) => {
175 let value = column
176 .as_primitive_opt::<Time32SecondType>()
177 .ok_or_else(|| IllegalArgument {
178 message: format!("expected Time32SecondArray at position {pos}"),
179 })?
180 .value(self.row_id);
181 Ok(value * 1000) }
183 ArrowDataType::Time32(TimeUnit::Millisecond) => Ok(column
184 .as_primitive_opt::<Time32MillisecondType>()
185 .ok_or_else(|| IllegalArgument {
186 message: format!("expected Time32MillisecondArray at position {pos}"),
187 })?
188 .value(self.row_id)),
189 ArrowDataType::Time64(TimeUnit::Microsecond) => {
190 let value = column
191 .as_primitive_opt::<Time64MicrosecondType>()
192 .ok_or_else(|| IllegalArgument {
193 message: format!("expected Time64MicrosecondArray at position {pos}"),
194 })?
195 .value(self.row_id);
196 Ok((value / 1000) as i32) }
198 ArrowDataType::Time64(TimeUnit::Nanosecond) => {
199 let value = column
200 .as_primitive_opt::<Time64NanosecondType>()
201 .ok_or_else(|| IllegalArgument {
202 message: format!("expected Time64NanosecondArray at position {pos}"),
203 })?
204 .value(self.row_id);
205 Ok((value / 1_000_000) as i32) }
207 other => Err(IllegalArgument {
208 message: format!("expected Time column at position {pos}, got {other:?}"),
209 }),
210 }
211 }
212}
213
214impl InternalRow for ColumnarRow {
215 fn get_field_count(&self) -> usize {
216 self.record_batch.num_columns()
217 }
218
219 fn is_null_at(&self, pos: usize) -> Result<bool> {
220 Ok(self.column(pos)?.is_null(self.row_id))
221 }
222
223 fn get_boolean(&self, pos: usize) -> Result<bool> {
224 Ok(self
225 .column(pos)?
226 .as_boolean_opt()
227 .ok_or_else(|| IllegalArgument {
228 message: format!("expected boolean array at position {pos}"),
229 })?
230 .value(self.row_id))
231 }
232
233 fn get_byte(&self, pos: usize) -> Result<i8> {
234 Ok(self
235 .column(pos)?
236 .as_primitive_opt::<Int8Type>()
237 .ok_or_else(|| IllegalArgument {
238 message: format!("expected byte array at position {pos}"),
239 })?
240 .value(self.row_id))
241 }
242
243 fn get_short(&self, pos: usize) -> Result<i16> {
244 Ok(self
245 .column(pos)?
246 .as_primitive_opt::<Int16Type>()
247 .ok_or_else(|| IllegalArgument {
248 message: format!("expected short array at position {pos}"),
249 })?
250 .value(self.row_id))
251 }
252
253 fn get_int(&self, pos: usize) -> Result<i32> {
254 Ok(self
255 .column(pos)?
256 .as_primitive_opt::<Int32Type>()
257 .ok_or_else(|| IllegalArgument {
258 message: format!("expected int array at position {pos}"),
259 })?
260 .value(self.row_id))
261 }
262
263 fn get_long(&self, pos: usize) -> Result<i64> {
264 Ok(self
265 .column(pos)?
266 .as_primitive_opt::<Int64Type>()
267 .ok_or_else(|| IllegalArgument {
268 message: format!("expected long array at position {pos}"),
269 })?
270 .value(self.row_id))
271 }
272
273 fn get_float(&self, pos: usize) -> Result<f32> {
274 Ok(self
275 .column(pos)?
276 .as_primitive_opt::<Float32Type>()
277 .ok_or_else(|| IllegalArgument {
278 message: format!("expected float32 array at position {pos}"),
279 })?
280 .value(self.row_id))
281 }
282
283 fn get_double(&self, pos: usize) -> Result<f64> {
284 Ok(self
285 .column(pos)?
286 .as_primitive_opt::<Float64Type>()
287 .ok_or_else(|| IllegalArgument {
288 message: format!("expected float64 array at position {pos}"),
289 })?
290 .value(self.row_id))
291 }
292
293 fn get_char(&self, pos: usize, _length: usize) -> Result<&str> {
294 Ok(self
295 .column(pos)?
296 .as_any()
297 .downcast_ref::<StringArray>()
298 .ok_or_else(|| IllegalArgument {
299 message: format!("expected String array for char type at position {pos}"),
300 })?
301 .value(self.row_id))
302 }
303
304 fn get_string(&self, pos: usize) -> Result<&str> {
305 Ok(self
306 .column(pos)?
307 .as_any()
308 .downcast_ref::<StringArray>()
309 .ok_or_else(|| IllegalArgument {
310 message: format!("expected String array at position {pos}"),
311 })?
312 .value(self.row_id))
313 }
314
315 fn get_decimal(
316 &self,
317 pos: usize,
318 precision: usize,
319 scale: usize,
320 ) -> Result<crate::row::Decimal> {
321 use arrow::datatypes::DataType;
322
323 let column = self.column(pos)?;
324 let array = column
325 .as_primitive_opt::<Decimal128Type>()
326 .ok_or_else(|| IllegalArgument {
327 message: format!(
328 "expected Decimal128Array at column {pos}, found: {:?}",
329 column.data_type()
330 ),
331 })?;
332
333 debug_assert!(
335 !array.is_null(self.row_id),
336 "get_decimal called on null value at pos {} row {}",
337 pos,
338 self.row_id
339 );
340
341 let arrow_scale = match column.data_type() {
343 DataType::Decimal128(_p, s) => *s as i64,
344 dt => {
345 return Err(IllegalArgument {
346 message: format!(
347 "expected Decimal128 data type at column {pos}, found: {dt:?}"
348 ),
349 });
350 }
351 };
352
353 let i128_val = array.value(self.row_id);
354
355 crate::row::Decimal::from_arrow_decimal128(
357 i128_val,
358 arrow_scale,
359 precision as u32,
360 scale as u32,
361 )
362 }
363
364 fn get_date(&self, pos: usize) -> Result<Date> {
365 Ok(Date::new(self.read_date_from_arrow(pos)?))
366 }
367
368 fn get_time(&self, pos: usize) -> Result<Time> {
369 Ok(Time::new(self.read_time_from_arrow(pos)?))
370 }
371
372 fn get_timestamp_ntz(&self, pos: usize, precision: u32) -> Result<TimestampNtz> {
373 self.read_timestamp_from_arrow(
374 pos,
375 precision,
376 TimestampNtz::new,
377 TimestampNtz::from_millis_nanos,
378 )
379 }
380
381 fn get_timestamp_ltz(&self, pos: usize, precision: u32) -> Result<TimestampLtz> {
382 self.read_timestamp_from_arrow(
383 pos,
384 precision,
385 TimestampLtz::new,
386 TimestampLtz::from_millis_nanos,
387 )
388 }
389
390 fn get_binary(&self, pos: usize, _length: usize) -> Result<&[u8]> {
391 Ok(self
392 .column(pos)?
393 .as_fixed_size_binary_opt()
394 .ok_or_else(|| IllegalArgument {
395 message: format!("expected binary array at position {pos}"),
396 })?
397 .value(self.row_id))
398 }
399
400 fn get_bytes(&self, pos: usize) -> Result<&[u8]> {
401 Ok(self
402 .column(pos)?
403 .as_any()
404 .downcast_ref::<BinaryArray>()
405 .ok_or_else(|| IllegalArgument {
406 message: format!("expected bytes array at position {pos}"),
407 })?
408 .value(self.row_id))
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
        Int8Array, Int16Array, Int32Array, Int64Array, StringArray, Time32SecondArray,
        Time64NanosecondArray, TimestampMicrosecondArray,
    };
    use arrow::datatypes::{DataType, Field, Schema};

    #[test]
    fn columnar_row_reads_values() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("b", DataType::Boolean, false),
            Field::new("i8", DataType::Int8, false),
            Field::new("i16", DataType::Int16, false),
            Field::new("i32", DataType::Int32, false),
            Field::new("i64", DataType::Int64, false),
            Field::new("f32", DataType::Float32, false),
            Field::new("f64", DataType::Float64, false),
            Field::new("s", DataType::Utf8, false),
            Field::new("bin", DataType::Binary, false),
            Field::new("char", DataType::Utf8, false),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(BooleanArray::from(vec![true])),
                Arc::new(Int8Array::from(vec![1])),
                Arc::new(Int16Array::from(vec![2])),
                Arc::new(Int32Array::from(vec![3])),
                Arc::new(Int64Array::from(vec![4])),
                Arc::new(Float32Array::from(vec![1.25])),
                Arc::new(Float64Array::from(vec![2.5])),
                Arc::new(StringArray::from(vec!["hello"])),
                Arc::new(BinaryArray::from(vec![b"data".as_slice()])),
                Arc::new(StringArray::from(vec!["ab"])),
            ],
        )
        .expect("record batch");

        let mut row = ColumnarRow::new(Arc::new(batch));
        assert_eq!(row.get_field_count(), 10);
        assert!(row.get_boolean(0).unwrap());
        assert_eq!(row.get_byte(1).unwrap(), 1);
        assert_eq!(row.get_short(2).unwrap(), 2);
        assert_eq!(row.get_int(3).unwrap(), 3);
        assert_eq!(row.get_long(4).unwrap(), 4);
        assert_eq!(row.get_float(5).unwrap(), 1.25);
        assert_eq!(row.get_double(6).unwrap(), 2.5);
        assert_eq!(row.get_string(7).unwrap(), "hello");
        assert_eq!(row.get_bytes(8).unwrap(), b"data");
        assert_eq!(row.get_char(9, 2).unwrap(), "ab");
        row.set_row_id(0);
        assert_eq!(row.get_row_id(), 0);
    }

    #[test]
    fn columnar_row_reads_decimal() {
        use bigdecimal::{BigDecimal, num_bigint::BigInt};

        let schema = Arc::new(Schema::new(vec![
            Field::new("dec1", DataType::Decimal128(10, 2), false),
            Field::new("dec2", DataType::Decimal128(20, 5), false),
            Field::new("dec3", DataType::Decimal128(38, 10), false),
        ]));

        let dec1_val = 12345i128;
        let dec2_val = 1234567890i128;
        let dec3_val = 999999999999999999i128;
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(
                    Decimal128Array::from(vec![dec1_val])
                        .with_precision_and_scale(10, 2)
                        .unwrap(),
                ),
                Arc::new(
                    Decimal128Array::from(vec![dec2_val])
                        .with_precision_and_scale(20, 5)
                        .unwrap(),
                ),
                Arc::new(
                    Decimal128Array::from(vec![dec3_val])
                        .with_precision_and_scale(38, 10)
                        .unwrap(),
                ),
            ],
        )
        .expect("record batch");

        let row = ColumnarRow::new(Arc::new(batch));
        assert_eq!(row.get_field_count(), 3);

        assert_eq!(
            row.get_decimal(0, 10, 2).unwrap(),
            crate::row::Decimal::from_big_decimal(BigDecimal::new(BigInt::from(12345), 2), 10, 2)
                .unwrap()
        );
        assert_eq!(
            row.get_decimal(1, 20, 5).unwrap(),
            crate::row::Decimal::from_big_decimal(
                BigDecimal::new(BigInt::from(1234567890), 5),
                20,
                5
            )
            .unwrap()
        );
        assert_eq!(
            row.get_decimal(2, 38, 10).unwrap(),
            crate::row::Decimal::from_big_decimal(
                BigDecimal::new(BigInt::from(999999999999999999i128), 10),
                38,
                10
            )
            .unwrap()
        );
    }

    #[test]
    fn columnar_row_reports_nulls() {
        let schema = Arc::new(Schema::new(vec![Field::new("i32", DataType::Int32, true)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![Some(7), None]))],
        )
        .expect("record batch");

        let mut row = ColumnarRow::new(Arc::new(batch));
        assert!(!row.is_null_at(0).unwrap());
        assert_eq!(row.get_int(0).unwrap(), 7);

        row.set_row_id(1);
        assert_eq!(row.get_row_id(), 1);
        assert!(row.is_null_at(0).unwrap());
    }

    #[test]
    fn columnar_row_rejects_bad_access() {
        let schema = Arc::new(Schema::new(vec![Field::new("i32", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1]))])
            .expect("record batch");
        let row = ColumnarRow::new(Arc::new(batch));

        // Column index past the end of the batch.
        assert!(row.get_int(1).is_err());
        assert!(row.is_null_at(1).is_err());

        // Column exists but holds a different Arrow type.
        assert!(row.get_long(0).is_err());
        assert!(row.get_string(0).is_err());
        assert!(row.get_decimal(0, 10, 2).is_err());
        assert!(row.get_date(0).is_err());
        assert!(row.get_time(0).is_err());
        assert!(row.get_timestamp_ntz(0, 6).is_err());
        assert!(row.get_timestamp_ltz(0, 6).is_err());
    }

    #[test]
    fn columnar_row_reads_temporal_types() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("date", DataType::Date32, false),
            Field::new("time_s", DataType::Time32(TimeUnit::Second), false),
            Field::new("time_ns", DataType::Time64(TimeUnit::Nanosecond), false),
            Field::new(
                "ts_us",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Date32Array::from(vec![19_000])),
                Arc::new(Time32SecondArray::from(vec![3_600])),
                Arc::new(Time64NanosecondArray::from(vec![1_500_000_000])),
                // A whole number of milliseconds, so the compact constructor is used.
                Arc::new(TimestampMicrosecondArray::from(vec![1_000])),
            ],
        )
        .expect("record batch");

        let row = ColumnarRow::new(Arc::new(batch));
        assert!(row.get_date(0).is_ok());
        assert!(row.get_time(1).is_ok());
        assert!(row.get_time(2).is_ok());
        assert!(row.get_timestamp_ntz(3, 6).is_ok());
        assert!(row.get_timestamp_ltz(3, 6).is_ok());
    }
}