1use arrow::array::ArrowNativeTypeOp;
21use arrow::compute::DecimalCast;
22use arrow::datatypes::{
23 self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type,
24 DecimalType,
25};
26use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16};
27
/// Options controlling how variant values are cast to Arrow types.
///
/// The default configuration is strict (`strict: true`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CastOptions {
    /// Whether casting is strict.
    ///
    /// NOTE(review): presumably a failed conversion is an error when `true` and becomes a
    /// null when `false` — confirm against the call sites.
    pub strict: bool,
}

impl Default for CastOptions {
    /// Returns strict cast options, i.e. `CastOptions { strict: true }`.
    fn default() -> CastOptions {
        CastOptions { strict: true }
    }
}
40
41pub(crate) trait PrimitiveFromVariant: ArrowPrimitiveType {
43 fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native>;
44}
45
46pub(crate) trait TimestampFromVariant<const NTZ: bool>: ArrowTimestampType {
50 fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native>;
51}
52
/// Implements [`PrimitiveFromVariant`] for `$target` by calling the `Variant` accessor
/// `$getter`, optionally mapping the extracted value through `$convert`.
macro_rules! impl_primitive_from_variant {
    ($target:ty, $getter:ident $(, $convert:expr)?) => {
        impl PrimitiveFromVariant for $target {
            fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native> {
                variant.$getter()$(.map($convert))?
            }
        }
    };
}
65
/// Implements [`TimestampFromVariant<NTZ>`] for `$target`. The value is read with the
/// `Variant` accessor `$getter` and then passed to the fallible conversion `$convert`, which
/// returns an `Option`.
macro_rules! impl_timestamp_from_variant {
    ($target:ty, $getter:ident, ntz=$is_ntz:ident, $convert:expr $(,)?) => {
        impl TimestampFromVariant<{ $is_ntz }> for $target {
            fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native> {
                let raw = variant.$getter();
                raw.and_then($convert)
            }
        }
    };
}
75
76impl_primitive_from_variant!(datatypes::Int32Type, as_int32);
77impl_primitive_from_variant!(datatypes::Int16Type, as_int16);
78impl_primitive_from_variant!(datatypes::Int8Type, as_int8);
79impl_primitive_from_variant!(datatypes::Int64Type, as_int64);
80impl_primitive_from_variant!(datatypes::UInt8Type, as_u8);
81impl_primitive_from_variant!(datatypes::UInt16Type, as_u16);
82impl_primitive_from_variant!(datatypes::UInt32Type, as_u32);
83impl_primitive_from_variant!(datatypes::UInt64Type, as_u64);
84impl_primitive_from_variant!(datatypes::Float16Type, as_f16);
85impl_primitive_from_variant!(datatypes::Float32Type, as_f32);
86impl_primitive_from_variant!(datatypes::Float64Type, as_f64);
87impl_primitive_from_variant!(
88 datatypes::Date32Type,
89 as_naive_date,
90 datatypes::Date32Type::from_naive_date
91);
92impl_timestamp_from_variant!(
93 datatypes::TimestampMicrosecondType,
94 as_timestamp_ntz_micros,
95 ntz = true,
96 Self::make_value,
97);
98impl_timestamp_from_variant!(
99 datatypes::TimestampMicrosecondType,
100 as_timestamp_micros,
101 ntz = false,
102 |timestamp| Self::make_value(timestamp.naive_utc())
103);
104impl_timestamp_from_variant!(
105 datatypes::TimestampNanosecondType,
106 as_timestamp_ntz_nanos,
107 ntz = true,
108 Self::make_value
109);
110impl_timestamp_from_variant!(
111 datatypes::TimestampNanosecondType,
112 as_timestamp_nanos,
113 ntz = false,
114 |timestamp| Self::make_value(timestamp.naive_utc())
115);
116
117pub(crate) fn variant_to_unscaled_decimal<O>(
127 variant: &Variant<'_, '_>,
128 precision: u8,
129 scale: i8,
130) -> Option<O::Native>
131where
132 O: DecimalType,
133 O::Native: DecimalCast,
134{
135 match variant {
136 Variant::Int8(i) => rescale_decimal::<Decimal32Type, O>(
137 *i as i32,
138 VariantDecimal4::MAX_PRECISION,
139 0,
140 precision,
141 scale,
142 ),
143 Variant::Int16(i) => rescale_decimal::<Decimal32Type, O>(
144 *i as i32,
145 VariantDecimal4::MAX_PRECISION,
146 0,
147 precision,
148 scale,
149 ),
150 Variant::Int32(i) => rescale_decimal::<Decimal32Type, O>(
151 *i,
152 VariantDecimal4::MAX_PRECISION,
153 0,
154 precision,
155 scale,
156 ),
157 Variant::Int64(i) => rescale_decimal::<Decimal64Type, O>(
158 *i,
159 VariantDecimal8::MAX_PRECISION,
160 0,
161 precision,
162 scale,
163 ),
164 Variant::Decimal4(d) => rescale_decimal::<Decimal32Type, O>(
165 d.integer(),
166 VariantDecimal4::MAX_PRECISION,
167 d.scale() as i8,
168 precision,
169 scale,
170 ),
171 Variant::Decimal8(d) => rescale_decimal::<Decimal64Type, O>(
172 d.integer(),
173 VariantDecimal8::MAX_PRECISION,
174 d.scale() as i8,
175 precision,
176 scale,
177 ),
178 Variant::Decimal16(d) => rescale_decimal::<Decimal128Type, O>(
179 d.integer(),
180 VariantDecimal16::MAX_PRECISION,
181 d.scale() as i8,
182 precision,
183 scale,
184 ),
185 _ => None,
186 }
187}
188
189pub(crate) fn rescale_decimal<I, O>(
193 value: I::Native,
194 input_precision: u8,
195 input_scale: i8,
196 output_precision: u8,
197 output_scale: i8,
198) -> Option<O::Native>
199where
200 I: DecimalType,
201 O: DecimalType,
202 I::Native: DecimalCast,
203 O::Native: DecimalCast,
204{
205 let delta_scale = output_scale - input_scale;
206
207 let is_infallible_cast =
209 is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale);
210
211 let scaled = if delta_scale == 0 {
212 O::Native::from_decimal(value)
213 } else if delta_scale > 0 {
214 let mul = O::Native::from_decimal(10_i128)
215 .and_then(|t| t.pow_checked(delta_scale as u32).ok())?;
216 O::Native::from_decimal(value).and_then(|x| x.mul_checked(mul).ok())
217 } else {
218 let delta_scale = delta_scale.unsigned_abs() as usize;
225 let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale) else {
226 return Some(O::Native::ZERO);
227 };
228 let div = max.add_wrapping(I::Native::ONE);
229 let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE));
230 let half_neg = half.neg_wrapping();
231
232 let d = value.div_wrapping(div);
234 let r = value.mod_wrapping(div);
235
236 let adjusted = match value >= I::Native::ZERO {
238 true if r >= half => d.add_wrapping(I::Native::ONE),
239 false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
240 _ => d,
241 };
242 O::Native::from_decimal(adjusted)
243 };
244
245 scaled.filter(|v| is_infallible_cast || O::is_valid_decimal_precision(*v, output_precision))
246}
247
/// Returns `true` if rescaling *any* decimal of `input_precision` digits from `input_scale`
/// to `output_scale` is guaranteed to fit in `output_precision` digits. In that case the
/// per-value precision check can be skipped.
fn is_infallible_decimal_cast(
    input_precision: u8,
    input_scale: i8,
    output_precision: u8,
    output_scale: i8,
) -> bool {
    // Widen to i16. Scale differences span -255..=255, and `u8 as i8` would wrap precisions
    // above 127, so i8 arithmetic could overflow or misbehave here.
    let delta_scale = i16::from(output_scale) - i16::from(input_scale);
    let input_precision = i16::from(input_precision);
    let output_precision = i16::from(output_precision);
    if delta_scale >= 0 {
        // Scaling up appends `delta_scale` zeros. For example, 5 digits scaled up by 3 gives
        // [xxxxx] -> [xxxxx000], which needs at least 8 digits of output precision.
        input_precision + delta_scale <= output_precision
    } else {
        // Scaling down drops `-delta_scale` digits, but rounding can carry into one extra
        // digit. For example, Decimal(5, 3) 99.999 -> Decimal(3, 0) gives 100. Hence the
        // strict inequality.
        input_precision + delta_scale < output_precision
    }
}
281
/// Converts the element at `$index` of `$array` into a `Variant`.
///
/// A null slot becomes `Variant::Null`. Otherwise `$array.value($index)` is passed through
/// `$cast_fn` and the result is converted with `Variant::from`.
macro_rules! non_generic_conversion_single_value {
    ($array:expr, $cast_fn:expr, $index:expr) => {{
        let source = $array;
        match source.is_null($index) {
            true => Variant::Null,
            false => {
                let converted = $cast_fn(source.value($index));
                Variant::from(converted)
            }
        }
    }};
}
pub(crate) use non_generic_conversion_single_value;
295
/// Downcasts `$input` with `$input.$method::<$t>()`, then converts the element at `$index`
/// into a `Variant` via `non_generic_conversion_single_value!`, applying `$cast_fn`.
macro_rules! generic_conversion_single_value {
    ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $index:expr) => {{
        let typed_array = $input.$method::<$t>();
        $crate::type_conversion::non_generic_conversion_single_value!(
            typed_array,
            $cast_fn,
            $index
        )
    }};
}
pub(crate) use generic_conversion_single_value;
309
/// Converts the element at `$index` of a primitive Arrow array into a `Variant`.
///
/// `$input` is downcast with `as_primitive::<$t>()` and the native value is passed through
/// unchanged.
macro_rules! primitive_conversion_single_value {
    ($t:ty, $input:expr, $index:expr) => {{
        $crate::type_conversion::non_generic_conversion_single_value!(
            $input.as_primitive::<$t>(),
            |native| native,
            $index
        )
    }};
}
pub(crate) use primitive_conversion_single_value;