Skip to main content

fluss/row/
column.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::error::Error::IllegalArgument;
19use crate::error::Result;
20use crate::row::InternalRow;
21use crate::row::datum::{Date, Time, TimestampLtz, TimestampNtz};
22use arrow::array::{Array, AsArray, BinaryArray, RecordBatch, StringArray};
23use arrow::datatypes::{
24    DataType as ArrowDataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int8Type,
25    Int16Type, Int32Type, Int64Type, Time32MillisecondType, Time32SecondType,
26    Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
27    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
28};
29use std::sync::Arc;
30
31#[derive(Clone)]
32pub struct ColumnarRow {
33    record_batch: Arc<RecordBatch>,
34    row_id: usize,
35}
36
37impl ColumnarRow {
38    pub fn new(batch: Arc<RecordBatch>) -> Self {
39        ColumnarRow {
40            record_batch: batch,
41            row_id: 0,
42        }
43    }
44
45    pub fn new_with_row_id(bach: Arc<RecordBatch>, row_id: usize) -> Self {
46        ColumnarRow {
47            record_batch: bach,
48            row_id,
49        }
50    }
51
52    pub fn set_row_id(&mut self, row_id: usize) {
53        self.row_id = row_id
54    }
55
56    pub fn get_row_id(&self) -> usize {
57        self.row_id
58    }
59
60    pub fn get_record_batch(&self) -> &RecordBatch {
61        &self.record_batch
62    }
63
64    fn column(&self, pos: usize) -> Result<&Arc<dyn Array>> {
65        self.record_batch
66            .columns()
67            .get(pos)
68            .ok_or_else(|| IllegalArgument {
69                message: format!(
70                    "column index {pos} out of bounds (batch has {} columns)",
71                    self.record_batch.num_columns()
72                ),
73            })
74    }
75
76    /// Generic helper to read timestamp from Arrow, handling all TimeUnit conversions.
77    /// Like Java, the precision parameter is ignored - conversion is determined by Arrow TimeUnit.
78    fn read_timestamp_from_arrow<T>(
79        &self,
80        pos: usize,
81        _precision: u32,
82        construct_compact: impl FnOnce(i64) -> T,
83        construct_with_nanos: impl FnOnce(i64, i32) -> Result<T>,
84    ) -> Result<T> {
85        let column = self.column(pos)?;
86
87        // Read value and time unit based on the actual Arrow timestamp type
88        let (value, time_unit) = match column.data_type() {
89            ArrowDataType::Timestamp(TimeUnit::Second, _) => (
90                column
91                    .as_primitive_opt::<TimestampSecondType>()
92                    .ok_or_else(|| IllegalArgument {
93                        message: format!("expected TimestampSecondArray at position {pos}"),
94                    })?
95                    .value(self.row_id),
96                TimeUnit::Second,
97            ),
98            ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => (
99                column
100                    .as_primitive_opt::<TimestampMillisecondType>()
101                    .ok_or_else(|| IllegalArgument {
102                        message: format!("expected TimestampMillisecondArray at position {pos}"),
103                    })?
104                    .value(self.row_id),
105                TimeUnit::Millisecond,
106            ),
107            ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => (
108                column
109                    .as_primitive_opt::<TimestampMicrosecondType>()
110                    .ok_or_else(|| IllegalArgument {
111                        message: format!("expected TimestampMicrosecondArray at position {pos}"),
112                    })?
113                    .value(self.row_id),
114                TimeUnit::Microsecond,
115            ),
116            ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => (
117                column
118                    .as_primitive_opt::<TimestampNanosecondType>()
119                    .ok_or_else(|| IllegalArgument {
120                        message: format!("expected TimestampNanosecondArray at position {pos}"),
121                    })?
122                    .value(self.row_id),
123                TimeUnit::Nanosecond,
124            ),
125            other => {
126                return Err(IllegalArgument {
127                    message: format!("expected Timestamp column at position {pos}, got {other:?}"),
128                });
129            }
130        };
131
132        // Convert based on Arrow TimeUnit
133        let (millis, nanos) = match time_unit {
134            TimeUnit::Second => (value * 1000, 0),
135            TimeUnit::Millisecond => (value, 0),
136            TimeUnit::Microsecond => {
137                // Use Euclidean division so that nanos is always non-negative,
138                // even for timestamps before the Unix epoch.
139                let millis = value.div_euclid(1000);
140                let nanos = (value.rem_euclid(1000) * 1000) as i32;
141                (millis, nanos)
142            }
143            TimeUnit::Nanosecond => {
144                // Use Euclidean division so that nanos is always in [0, 999_999].
145                let millis = value.div_euclid(1_000_000);
146                let nanos = value.rem_euclid(1_000_000) as i32;
147                (millis, nanos)
148            }
149        };
150
151        if nanos == 0 {
152            Ok(construct_compact(millis))
153        } else {
154            construct_with_nanos(millis, nanos)
155        }
156    }
157
158    /// Read date value from Arrow Date32Array
159    fn read_date_from_arrow(&self, pos: usize) -> Result<i32> {
160        Ok(self
161            .column(pos)?
162            .as_primitive_opt::<Date32Type>()
163            .ok_or_else(|| IllegalArgument {
164                message: format!("expected Date32Array at position {pos}"),
165            })?
166            .value(self.row_id))
167    }
168
169    /// Read time value from Arrow Time32/Time64 arrays, converting to milliseconds
170    fn read_time_from_arrow(&self, pos: usize) -> Result<i32> {
171        let column = self.column(pos)?;
172
173        match column.data_type() {
174            ArrowDataType::Time32(TimeUnit::Second) => {
175                let value = column
176                    .as_primitive_opt::<Time32SecondType>()
177                    .ok_or_else(|| IllegalArgument {
178                        message: format!("expected Time32SecondArray at position {pos}"),
179                    })?
180                    .value(self.row_id);
181                Ok(value * 1000) // Convert seconds to milliseconds
182            }
183            ArrowDataType::Time32(TimeUnit::Millisecond) => Ok(column
184                .as_primitive_opt::<Time32MillisecondType>()
185                .ok_or_else(|| IllegalArgument {
186                    message: format!("expected Time32MillisecondArray at position {pos}"),
187                })?
188                .value(self.row_id)),
189            ArrowDataType::Time64(TimeUnit::Microsecond) => {
190                let value = column
191                    .as_primitive_opt::<Time64MicrosecondType>()
192                    .ok_or_else(|| IllegalArgument {
193                        message: format!("expected Time64MicrosecondArray at position {pos}"),
194                    })?
195                    .value(self.row_id);
196                Ok((value / 1000) as i32) // Convert microseconds to milliseconds
197            }
198            ArrowDataType::Time64(TimeUnit::Nanosecond) => {
199                let value = column
200                    .as_primitive_opt::<Time64NanosecondType>()
201                    .ok_or_else(|| IllegalArgument {
202                        message: format!("expected Time64NanosecondArray at position {pos}"),
203                    })?
204                    .value(self.row_id);
205                Ok((value / 1_000_000) as i32) // Convert nanoseconds to milliseconds
206            }
207            other => Err(IllegalArgument {
208                message: format!("expected Time column at position {pos}, got {other:?}"),
209            }),
210        }
211    }
212}
213
214impl InternalRow for ColumnarRow {
215    fn get_field_count(&self) -> usize {
216        self.record_batch.num_columns()
217    }
218
219    fn is_null_at(&self, pos: usize) -> Result<bool> {
220        Ok(self.column(pos)?.is_null(self.row_id))
221    }
222
223    fn get_boolean(&self, pos: usize) -> Result<bool> {
224        Ok(self
225            .column(pos)?
226            .as_boolean_opt()
227            .ok_or_else(|| IllegalArgument {
228                message: format!("expected boolean array at position {pos}"),
229            })?
230            .value(self.row_id))
231    }
232
233    fn get_byte(&self, pos: usize) -> Result<i8> {
234        Ok(self
235            .column(pos)?
236            .as_primitive_opt::<Int8Type>()
237            .ok_or_else(|| IllegalArgument {
238                message: format!("expected byte array at position {pos}"),
239            })?
240            .value(self.row_id))
241    }
242
243    fn get_short(&self, pos: usize) -> Result<i16> {
244        Ok(self
245            .column(pos)?
246            .as_primitive_opt::<Int16Type>()
247            .ok_or_else(|| IllegalArgument {
248                message: format!("expected short array at position {pos}"),
249            })?
250            .value(self.row_id))
251    }
252
253    fn get_int(&self, pos: usize) -> Result<i32> {
254        Ok(self
255            .column(pos)?
256            .as_primitive_opt::<Int32Type>()
257            .ok_or_else(|| IllegalArgument {
258                message: format!("expected int array at position {pos}"),
259            })?
260            .value(self.row_id))
261    }
262
263    fn get_long(&self, pos: usize) -> Result<i64> {
264        Ok(self
265            .column(pos)?
266            .as_primitive_opt::<Int64Type>()
267            .ok_or_else(|| IllegalArgument {
268                message: format!("expected long array at position {pos}"),
269            })?
270            .value(self.row_id))
271    }
272
273    fn get_float(&self, pos: usize) -> Result<f32> {
274        Ok(self
275            .column(pos)?
276            .as_primitive_opt::<Float32Type>()
277            .ok_or_else(|| IllegalArgument {
278                message: format!("expected float32 array at position {pos}"),
279            })?
280            .value(self.row_id))
281    }
282
283    fn get_double(&self, pos: usize) -> Result<f64> {
284        Ok(self
285            .column(pos)?
286            .as_primitive_opt::<Float64Type>()
287            .ok_or_else(|| IllegalArgument {
288                message: format!("expected float64 array at position {pos}"),
289            })?
290            .value(self.row_id))
291    }
292
293    fn get_char(&self, pos: usize, _length: usize) -> Result<&str> {
294        Ok(self
295            .column(pos)?
296            .as_any()
297            .downcast_ref::<StringArray>()
298            .ok_or_else(|| IllegalArgument {
299                message: format!("expected String array for char type at position {pos}"),
300            })?
301            .value(self.row_id))
302    }
303
304    fn get_string(&self, pos: usize) -> Result<&str> {
305        Ok(self
306            .column(pos)?
307            .as_any()
308            .downcast_ref::<StringArray>()
309            .ok_or_else(|| IllegalArgument {
310                message: format!("expected String array at position {pos}"),
311            })?
312            .value(self.row_id))
313    }
314
315    fn get_decimal(
316        &self,
317        pos: usize,
318        precision: usize,
319        scale: usize,
320    ) -> Result<crate::row::Decimal> {
321        use arrow::datatypes::DataType;
322
323        let column = self.column(pos)?;
324        let array = column
325            .as_primitive_opt::<Decimal128Type>()
326            .ok_or_else(|| IllegalArgument {
327                message: format!(
328                    "expected Decimal128Array at column {pos}, found: {:?}",
329                    column.data_type()
330                ),
331            })?;
332
333        // Contract: caller must check is_null_at() before calling get_decimal.
334        debug_assert!(
335            !array.is_null(self.row_id),
336            "get_decimal called on null value at pos {} row {}",
337            pos,
338            self.row_id
339        );
340
341        // Read scale from Arrow column data type
342        let arrow_scale = match column.data_type() {
343            DataType::Decimal128(_p, s) => *s as i64,
344            dt => {
345                return Err(IllegalArgument {
346                    message: format!(
347                        "expected Decimal128 data type at column {pos}, found: {dt:?}"
348                    ),
349                });
350            }
351        };
352
353        let i128_val = array.value(self.row_id);
354
355        // Convert Arrow Decimal128 to Fluss Decimal (handles rescaling and validation)
356        crate::row::Decimal::from_arrow_decimal128(
357            i128_val,
358            arrow_scale,
359            precision as u32,
360            scale as u32,
361        )
362    }
363
364    fn get_date(&self, pos: usize) -> Result<Date> {
365        Ok(Date::new(self.read_date_from_arrow(pos)?))
366    }
367
368    fn get_time(&self, pos: usize) -> Result<Time> {
369        Ok(Time::new(self.read_time_from_arrow(pos)?))
370    }
371
372    fn get_timestamp_ntz(&self, pos: usize, precision: u32) -> Result<TimestampNtz> {
373        self.read_timestamp_from_arrow(
374            pos,
375            precision,
376            TimestampNtz::new,
377            TimestampNtz::from_millis_nanos,
378        )
379    }
380
381    fn get_timestamp_ltz(&self, pos: usize, precision: u32) -> Result<TimestampLtz> {
382        self.read_timestamp_from_arrow(
383            pos,
384            precision,
385            TimestampLtz::new,
386            TimestampLtz::from_millis_nanos,
387        )
388    }
389
390    fn get_binary(&self, pos: usize, _length: usize) -> Result<&[u8]> {
391        Ok(self
392            .column(pos)?
393            .as_fixed_size_binary_opt()
394            .ok_or_else(|| IllegalArgument {
395                message: format!("expected binary array at position {pos}"),
396            })?
397            .value(self.row_id))
398    }
399
400    fn get_bytes(&self, pos: usize) -> Result<&[u8]> {
401        Ok(self
402            .column(pos)?
403            .as_any()
404            .downcast_ref::<BinaryArray>()
405            .ok_or_else(|| IllegalArgument {
406                message: format!("expected bytes array at position {pos}"),
407            })?
408            .value(self.row_id))
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        BinaryArray, BooleanArray, Decimal128Array, Float32Array, Float64Array, Int8Array,
        Int16Array, Int32Array, Int64Array, StringArray, Time32SecondArray,
        Time64MicrosecondArray, TimestampMicrosecondArray,
    };
    use arrow::datatypes::{DataType, Field, Schema};

    /// Builds a row (positioned at row 0) over a batch holding one column.
    fn single_column_row(field: Field, column: Arc<dyn Array>) -> ColumnarRow {
        let schema = Arc::new(Schema::new(vec![field]));
        let batch = RecordBatch::try_new(schema, vec![column]).expect("record batch");
        ColumnarRow::new(Arc::new(batch))
    }

    #[test]
    fn columnar_row_reads_values() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("b", DataType::Boolean, false),
            Field::new("i8", DataType::Int8, false),
            Field::new("i16", DataType::Int16, false),
            Field::new("i32", DataType::Int32, false),
            Field::new("i64", DataType::Int64, false),
            Field::new("f32", DataType::Float32, false),
            Field::new("f64", DataType::Float64, false),
            Field::new("s", DataType::Utf8, false),
            Field::new("bin", DataType::Binary, false),
            Field::new("char", DataType::Utf8, false),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(BooleanArray::from(vec![true])),
                Arc::new(Int8Array::from(vec![1])),
                Arc::new(Int16Array::from(vec![2])),
                Arc::new(Int32Array::from(vec![3])),
                Arc::new(Int64Array::from(vec![4])),
                Arc::new(Float32Array::from(vec![1.25])),
                Arc::new(Float64Array::from(vec![2.5])),
                Arc::new(StringArray::from(vec!["hello"])),
                Arc::new(BinaryArray::from(vec![b"data".as_slice()])),
                Arc::new(StringArray::from(vec!["ab"])),
            ],
        )
        .expect("record batch");

        let row = ColumnarRow::new(Arc::new(batch));
        assert_eq!(row.get_row_id(), 0);
        assert_eq!(row.get_field_count(), 10);
        assert!(row.get_boolean(0).unwrap());
        assert_eq!(row.get_byte(1).unwrap(), 1);
        assert_eq!(row.get_short(2).unwrap(), 2);
        assert_eq!(row.get_int(3).unwrap(), 3);
        assert_eq!(row.get_long(4).unwrap(), 4);
        assert_eq!(row.get_float(5).unwrap(), 1.25);
        assert_eq!(row.get_double(6).unwrap(), 2.5);
        assert_eq!(row.get_string(7).unwrap(), "hello");
        assert_eq!(row.get_bytes(8).unwrap(), b"data");
        assert_eq!(row.get_char(9, 2).unwrap(), "ab");
    }

    #[test]
    fn columnar_row_follows_row_id() {
        let schema = Arc::new(Schema::new(vec![Field::new("i32", DataType::Int32, false)]));
        let batch = Arc::new(
            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![10, 20, 30]))])
                .expect("record batch"),
        );

        let mut row = ColumnarRow::new(Arc::clone(&batch));
        assert_eq!(row.get_int(0).unwrap(), 10);

        row.set_row_id(2);
        assert_eq!(row.get_row_id(), 2);
        assert_eq!(row.get_int(0).unwrap(), 30);

        let second = ColumnarRow::new_with_row_id(batch, 1);
        assert_eq!(second.get_row_id(), 1);
        assert_eq!(second.get_int(0).unwrap(), 20);
        assert_eq!(second.get_record_batch().num_rows(), 3);
    }

    #[test]
    fn columnar_row_reports_nulls() {
        let mut row = single_column_row(
            Field::new("i32", DataType::Int32, true),
            Arc::new(Int32Array::from(vec![Some(1), None])),
        );
        assert!(!row.is_null_at(0).unwrap());

        row.set_row_id(1);
        assert!(row.is_null_at(0).unwrap());
    }

    #[test]
    fn columnar_row_rejects_bad_position_and_type() {
        let row = single_column_row(
            Field::new("b", DataType::Boolean, false),
            Arc::new(BooleanArray::from(vec![true])),
        );

        // Column index past the end of the batch.
        assert!(row.get_boolean(1).is_err());
        assert!(row.is_null_at(1).is_err());

        // Existing column, wrong accessor for its Arrow type.
        assert!(row.get_int(0).is_err());
        assert!(row.get_string(0).is_err());
        assert!(row.get_bytes(0).is_err());
        assert!(row.read_time_from_arrow(0).is_err());
    }

    #[test]
    fn columnar_row_converts_time_units_to_millis() {
        let seconds = single_column_row(
            Field::new("t", DataType::Time32(TimeUnit::Second), false),
            Arc::new(Time32SecondArray::from(vec![5])),
        );
        assert_eq!(seconds.read_time_from_arrow(0).unwrap(), 5_000);

        let micros = single_column_row(
            Field::new("t", DataType::Time64(TimeUnit::Microsecond), false),
            Arc::new(Time64MicrosecondArray::from(vec![1_500_000i64])),
        );
        assert_eq!(micros.read_time_from_arrow(0).unwrap(), 1_500);
    }

    #[test]
    fn columnar_row_splits_pre_epoch_timestamp() {
        // One microsecond before the epoch: millis rounds down, nanos stays non-negative.
        let row = single_column_row(
            Field::new("ts", DataType::Timestamp(TimeUnit::Microsecond, None), false),
            Arc::new(TimestampMicrosecondArray::from(vec![-1i64])),
        );

        let parts = row
            .read_timestamp_from_arrow(0, 6, |ms| (ms, 0i32), |ms, ns| Ok((ms, ns)))
            .unwrap();
        assert_eq!(parts, (-1i64, 999_000i32));
    }

    #[test]
    fn columnar_row_reads_decimal() {
        use bigdecimal::{BigDecimal, num_bigint::BigInt};

        // Test with Decimal128
        let schema = Arc::new(Schema::new(vec![
            Field::new("dec1", DataType::Decimal128(10, 2), false),
            Field::new("dec2", DataType::Decimal128(20, 5), false),
            Field::new("dec3", DataType::Decimal128(38, 10), false),
        ]));

        // Create decimal values: 123.45, 12345.67890, large decimal
        let dec1_val = 12345i128; // 123.45 with scale 2
        let dec2_val = 1234567890i128; // 12345.67890 with scale 5
        let dec3_val = 999999999999999999i128; // Large value (18 nines) with scale 10

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(
                    Decimal128Array::from(vec![dec1_val])
                        .with_precision_and_scale(10, 2)
                        .unwrap(),
                ),
                Arc::new(
                    Decimal128Array::from(vec![dec2_val])
                        .with_precision_and_scale(20, 5)
                        .unwrap(),
                ),
                Arc::new(
                    Decimal128Array::from(vec![dec3_val])
                        .with_precision_and_scale(38, 10)
                        .unwrap(),
                ),
            ],
        )
        .expect("record batch");

        let row = ColumnarRow::new(Arc::new(batch));
        assert_eq!(row.get_field_count(), 3);

        // Verify decimal values
        assert_eq!(
            row.get_decimal(0, 10, 2).unwrap(),
            crate::row::Decimal::from_big_decimal(BigDecimal::new(BigInt::from(12345), 2), 10, 2)
                .unwrap()
        );
        assert_eq!(
            row.get_decimal(1, 20, 5).unwrap(),
            crate::row::Decimal::from_big_decimal(
                BigDecimal::new(BigInt::from(1234567890), 5),
                20,
                5
            )
            .unwrap()
        );
        assert_eq!(
            row.get_decimal(2, 38, 10).unwrap(),
            crate::row::Decimal::from_big_decimal(
                BigDecimal::new(BigInt::from(999999999999999999i128), 10),
                38,
                10
            )
            .unwrap()
        );
    }
}