parquet_variant_compute/
type_conversion.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
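//! Module for transforming a typed arrow `Array` to `VariantArray`, plus helpers for the reverse
//! direction: extracting Arrow native values (primitives, timestamps, decimals) from a `Variant`.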
19
20use arrow::array::ArrowNativeTypeOp;
21use arrow::compute::DecimalCast;
22use arrow::datatypes::{
23    self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type,
24    DecimalType,
25};
26use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16};
27
/// Options for controlling the behavior of `cast_to_variant_with_options`.
///
/// The [`Default`] configuration is strict (`strict: true`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CastOptions {
    /// Conversion-failure policy: `true` returns an error when a value fails to convert,
    /// `false` inserts a null in place of that value instead.
    pub strict: bool,
}

impl Default for CastOptions {
    /// Strict casting is the default: conversion failures are reported as errors.
    fn default() -> Self {
        CastOptions { strict: true }
    }
}
40
41/// Extension trait for Arrow primitive types that can extract their native value from a Variant
42pub(crate) trait PrimitiveFromVariant: ArrowPrimitiveType {
43    fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native>;
44}
45
46/// Extension trait for Arrow timestamp types that can extract their native value from a Variant
47/// We can't use [`PrimitiveFromVariant`] directly because we need _two_ implementations for each
48/// timestamp type -- the `NTZ` param here.
49pub(crate) trait TimestampFromVariant<const NTZ: bool>: ArrowTimestampType {
50    fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native>;
51}
52
/// Implements [`PrimitiveFromVariant`] for an Arrow primitive type.
///
/// * `impl_primitive_from_variant!(Type, method)` returns `variant.method()` as-is.
/// * `impl_primitive_from_variant!(Type, method, cast_fn)` additionally maps the extracted
///   value through `cast_fn`.
macro_rules! impl_primitive_from_variant {
    ($arrow_type:ty, $variant_method:ident) => {
        impl PrimitiveFromVariant for $arrow_type {
            fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native> {
                variant.$variant_method()
            }
        }
    };
    ($arrow_type:ty, $variant_method:ident, $cast_fn:expr) => {
        impl PrimitiveFromVariant for $arrow_type {
            fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native> {
                variant.$variant_method().map($cast_fn)
            }
        }
    };
}
65
/// Implements `TimestampFromVariant<{ ntz }>` for an Arrow timestamp type.
///
/// `$variant_method` pulls the timestamp out of the variant and `$cast_fn` turns it into the
/// Arrow native value; `None` from either step yields `None`. A trailing comma is accepted.
macro_rules! impl_timestamp_from_variant {
    ($timestamp_type:ty, $variant_method:ident, ntz=$ntz:ident, $cast_fn:expr $(,)?) => {
        impl TimestampFromVariant<{ $ntz }> for $timestamp_type {
            fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native> {
                let extracted = variant.$variant_method();
                extracted.and_then($cast_fn)
            }
        }
    };
}
75
76impl_primitive_from_variant!(datatypes::Int32Type, as_int32);
77impl_primitive_from_variant!(datatypes::Int16Type, as_int16);
78impl_primitive_from_variant!(datatypes::Int8Type, as_int8);
79impl_primitive_from_variant!(datatypes::Int64Type, as_int64);
80impl_primitive_from_variant!(datatypes::UInt8Type, as_u8);
81impl_primitive_from_variant!(datatypes::UInt16Type, as_u16);
82impl_primitive_from_variant!(datatypes::UInt32Type, as_u32);
83impl_primitive_from_variant!(datatypes::UInt64Type, as_u64);
84impl_primitive_from_variant!(datatypes::Float16Type, as_f16);
85impl_primitive_from_variant!(datatypes::Float32Type, as_f32);
86impl_primitive_from_variant!(datatypes::Float64Type, as_f64);
87impl_primitive_from_variant!(
88    datatypes::Date32Type,
89    as_naive_date,
90    datatypes::Date32Type::from_naive_date
91);
92impl_timestamp_from_variant!(
93    datatypes::TimestampMicrosecondType,
94    as_timestamp_ntz_micros,
95    ntz = true,
96    Self::make_value,
97);
98impl_timestamp_from_variant!(
99    datatypes::TimestampMicrosecondType,
100    as_timestamp_micros,
101    ntz = false,
102    |timestamp| Self::make_value(timestamp.naive_utc())
103);
104impl_timestamp_from_variant!(
105    datatypes::TimestampNanosecondType,
106    as_timestamp_ntz_nanos,
107    ntz = true,
108    Self::make_value
109);
110impl_timestamp_from_variant!(
111    datatypes::TimestampNanosecondType,
112    as_timestamp_nanos,
113    ntz = false,
114    |timestamp| Self::make_value(timestamp.naive_utc())
115);
116
117/// Returns the unscaled integer representation for Arrow decimal type `O`
118/// from a `Variant`.
119///
120/// - `precision` and `scale` specify the target Arrow decimal parameters
121/// - Integer variants (`Int8/16/32/64`) are treated as decimals with scale 0
122/// - Decimal variants (`Decimal4/8/16`) use their embedded precision and scale
123///
124/// The value is rescaled to (`precision`, `scale`) using `rescale_decimal` and
125/// returns `None` if it cannot fit the requested precision.
126pub(crate) fn variant_to_unscaled_decimal<O>(
127    variant: &Variant<'_, '_>,
128    precision: u8,
129    scale: i8,
130) -> Option<O::Native>
131where
132    O: DecimalType,
133    O::Native: DecimalCast,
134{
135    match variant {
136        Variant::Int8(i) => rescale_decimal::<Decimal32Type, O>(
137            *i as i32,
138            VariantDecimal4::MAX_PRECISION,
139            0,
140            precision,
141            scale,
142        ),
143        Variant::Int16(i) => rescale_decimal::<Decimal32Type, O>(
144            *i as i32,
145            VariantDecimal4::MAX_PRECISION,
146            0,
147            precision,
148            scale,
149        ),
150        Variant::Int32(i) => rescale_decimal::<Decimal32Type, O>(
151            *i,
152            VariantDecimal4::MAX_PRECISION,
153            0,
154            precision,
155            scale,
156        ),
157        Variant::Int64(i) => rescale_decimal::<Decimal64Type, O>(
158            *i,
159            VariantDecimal8::MAX_PRECISION,
160            0,
161            precision,
162            scale,
163        ),
164        Variant::Decimal4(d) => rescale_decimal::<Decimal32Type, O>(
165            d.integer(),
166            VariantDecimal4::MAX_PRECISION,
167            d.scale() as i8,
168            precision,
169            scale,
170        ),
171        Variant::Decimal8(d) => rescale_decimal::<Decimal64Type, O>(
172            d.integer(),
173            VariantDecimal8::MAX_PRECISION,
174            d.scale() as i8,
175            precision,
176            scale,
177        ),
178        Variant::Decimal16(d) => rescale_decimal::<Decimal128Type, O>(
179            d.integer(),
180            VariantDecimal16::MAX_PRECISION,
181            d.scale() as i8,
182            precision,
183            scale,
184        ),
185        _ => None,
186    }
187}
188
189/// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale)
190/// and return the scaled value if it fits the output precision. Similar to the implementation in
191/// decimal.rs in arrow-cast.
192pub(crate) fn rescale_decimal<I, O>(
193    value: I::Native,
194    input_precision: u8,
195    input_scale: i8,
196    output_precision: u8,
197    output_scale: i8,
198) -> Option<O::Native>
199where
200    I: DecimalType,
201    O: DecimalType,
202    I::Native: DecimalCast,
203    O::Native: DecimalCast,
204{
205    let delta_scale = output_scale - input_scale;
206
207    // Determine if the cast is infallible based on precision/scale math
208    let is_infallible_cast =
209        is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale);
210
211    let scaled = if delta_scale == 0 {
212        O::Native::from_decimal(value)
213    } else if delta_scale > 0 {
214        let mul = O::Native::from_decimal(10_i128)
215            .and_then(|t| t.pow_checked(delta_scale as u32).ok())?;
216        O::Native::from_decimal(value).and_then(|x| x.mul_checked(mul).ok())
217    } else {
218        // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the
219        // scale change divides out more digits than the input has precision and the result of the cast
220        // is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, the largest
221        // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values
222        // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even
223        // smaller results, which also round to zero. In that case, just return an array of zeros.
224        let delta_scale = delta_scale.unsigned_abs() as usize;
225        let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale) else {
226            return Some(O::Native::ZERO);
227        };
228        let div = max.add_wrapping(I::Native::ONE);
229        let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE));
230        let half_neg = half.neg_wrapping();
231
232        // div is >= 10 and so this cannot overflow
233        let d = value.div_wrapping(div);
234        let r = value.mod_wrapping(div);
235
236        // Round result
237        let adjusted = match value >= I::Native::ZERO {
238            true if r >= half => d.add_wrapping(I::Native::ONE),
239            false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
240            _ => d,
241        };
242        O::Native::from_decimal(adjusted)
243    };
244
245    scaled.filter(|v| is_infallible_cast || O::is_valid_decimal_precision(*v, output_precision))
246}
247
/// Returns true if casting from (input_precision, input_scale) to
/// (output_precision, output_scale) is infallible based on precision/scale math.
///
/// "Infallible" means every value with at most `input_precision` digits is guaranteed to fit in
/// `output_precision` digits once rescaled (and rounded). The caller can then skip the
/// per-value precision check.
fn is_infallible_decimal_cast(
    input_precision: u8,
    input_scale: i8,
    output_precision: u8,
    output_scale: i8,
) -> bool {
    // Do the arithmetic in i16. The difference of two i8 scales spans -255..=255, and a
    // precision plus that difference spans -255..=510; neither range fits in i8.
    let delta_scale = i16::from(output_scale) - i16::from(input_scale);
    let input_precision = i16::from(input_precision);
    let output_precision = i16::from(output_precision);
    if delta_scale >= 0 {
        // If the gain in precision (digits) is greater than the multiplication due to scaling,
        // every number will fit into the output type.
        // Example: starting with any number of precision 5 [xxxxx], an increase of scale by 3
        // has the following effect on the representation:
        // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
        // needs to provide at least 8 digits of precision.
        input_precision + delta_scale <= output_precision
    } else {
        // If the reduction of the input number through scaling (dividing) is greater than a
        // possible precision loss (plus a potential increase via rounding), every input number
        // will fit into the output type.
        // Example: starting with any number of precision 5 [xxxxx], a decrease of the scale by 3
        // has the following effect on the representation:
        // [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
        // The rounding may add an additional digit, so for the cast to be infallible,
        // the output type needs to have at least 3 digits of precision.
        // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
        // [99999] -> [99] + 1 = [100]; a cast to Decimal(2, 0) would not be possible.
        input_precision + delta_scale < output_precision
    }
}
281
/// Convert the value at a specific index in the given array into a `Variant`.
///
/// Expands to an expression yielding `Variant::Null` when `is_null($index)` is true for the
/// array, and otherwise `Variant::from($cast_fn(value($index)))`. `$array` is evaluated once.
/// `$index` is evaluated once for null slots and twice for non-null slots.
macro_rules! non_generic_conversion_single_value {
    ($array:expr, $cast_fn:expr, $index:expr) => {{
        let array = $array;
        match array.is_null($index) {
            true => Variant::Null,
            false => Variant::from($cast_fn(array.value($index))),
        }
    }};
}
pub(crate) use non_generic_conversion_single_value;
295
/// Convert the value at a specific index in the given array into a `Variant`.
///
/// `$input.$method::<$t>()` downcasts the generic `$input` array to its concrete array type
/// (e.g. `as_primitive::<T>()`). The element at `$index` is then converted with `$cast_fn`
/// by `non_generic_conversion_single_value!`.
macro_rules! generic_conversion_single_value {
    ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $index:expr) => {{
        let typed_array = $input.$method::<$t>();
        $crate::type_conversion::non_generic_conversion_single_value!(
            typed_array,
            $cast_fn,
            $index
        )
    }};
}
pub(crate) use generic_conversion_single_value;
309
/// Convert the value at a specific index in the given array into a `Variant`.
///
/// `$input` is downcast with `as_primitive::<$t>()`, and the element at `$index` is passed to
/// `Variant::from` unchanged.
macro_rules! primitive_conversion_single_value {
    ($t:ty, $input:expr, $index:expr) => {{
        $crate::type_conversion::generic_conversion_single_value!(
            $t,
            as_primitive,
            std::convert::identity,
            $input,
            $index
        )
    }};
}
pub(crate) use primitive_conversion_single_value;