datafusion_spark/function/math/
width_bucket.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::function::error_utils::unsupported_data_types_exec_err;
22use arrow::array::{
23    Array, ArrayRef, DurationMicrosecondArray, Float64Array, IntervalMonthDayNanoArray,
24    IntervalYearMonthArray,
25};
26use arrow::datatypes::DataType;
27use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval};
28use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth};
29use datafusion_common::cast::{
30    as_duration_microsecond_array, as_float64_array, as_int32_array,
31    as_interval_mdn_array, as_interval_ym_array,
32};
33use datafusion_common::{exec_err, Result};
34use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
35use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
36use datafusion_functions::utils::make_scalar_function;
37
38use arrow::array::{Int32Array, Int32Builder};
39use arrow::datatypes::TimeUnit::Microsecond;
40use datafusion_expr::Volatility::Immutable;
41
42#[derive(Debug, PartialEq, Eq, Hash)]
43pub struct SparkWidthBucket {
44    signature: Signature,
45}
46
47impl Default for SparkWidthBucket {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl SparkWidthBucket {
54    pub fn new() -> Self {
55        Self {
56            signature: Signature::user_defined(Immutable),
57        }
58    }
59}
60
61impl ScalarUDFImpl for SparkWidthBucket {
62    fn as_any(&self) -> &dyn Any {
63        self
64    }
65
66    fn name(&self) -> &str {
67        "width_bucket"
68    }
69
70    fn signature(&self) -> &Signature {
71        &self.signature
72    }
73
74    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
75        Ok(Int32)
76    }
77
78    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
79        make_scalar_function(width_bucket_kern, vec![])(&args.args)
80    }
81
82    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
83        if input.len() == 1 {
84            let value = &input[0];
85            Ok(value.sort_properties)
86        } else {
87            Ok(SortProperties::default())
88        }
89    }
90
91    fn coerce_types(&self, types: &[DataType]) -> Result<Vec<DataType>> {
92        use DataType::*;
93
94        let (v, lo, hi, n) = (&types[0], &types[1], &types[2], &types[3]);
95
96        let is_num = |t: &DataType| {
97            matches!(
98                t,
99                Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
100            )
101        };
102
103        match (v, lo, hi, n) {
104            (a, b, c, &(Int8 | Int16 | Int32 | Int64))
105                if is_num(a) && is_num(b) && is_num(c) =>
106            {
107                Ok(vec![Float64, Float64, Float64, Int32])
108            }
109            (
110                &Duration(_),
111                &Duration(_),
112                &Duration(_),
113                &(Int8 | Int16 | Int32 | Int64),
114            ) => Ok(vec![
115                Duration(Microsecond),
116                Duration(Microsecond),
117                Duration(Microsecond),
118                Int32,
119            ]),
120            (
121                &Interval(MonthDayNano),
122                &Interval(MonthDayNano),
123                &Interval(MonthDayNano),
124                &(Int8 | Int16 | Int32 | Int64),
125            ) => Ok(vec![
126                Interval(MonthDayNano),
127                Interval(MonthDayNano),
128                Interval(MonthDayNano),
129                Int32,
130            ]),
131            (
132                &Interval(YearMonth),
133                &Interval(YearMonth),
134                &Interval(YearMonth),
135                &(Int8 | Int16 | Int32 | Int64),
136            ) => Ok(vec![
137                Interval(YearMonth),
138                Interval(YearMonth),
139                Interval(YearMonth),
140                Int32,
141            ]),
142
143            _ => exec_err!(
144                "width_bucket expects a numeric argument, got {} {} {} {}",
145                types[0],
146                types[1],
147                types[2],
148                types[3]
149            ),
150        }
151    }
152}
153
154fn width_bucket_kern(args: &[ArrayRef]) -> Result<ArrayRef> {
155    let [v, minv, maxv, nb] = args else {
156        return exec_err!(
157            "width_bucket expects exactly 4 argument, got {}",
158            args.len()
159        );
160    };
161
162    match v.data_type() {
163        Float64 => {
164            let v = as_float64_array(v)?;
165            let min = as_float64_array(minv)?;
166            let max = as_float64_array(maxv)?;
167            let n_bucket = as_int32_array(nb)?;
168            Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket)))
169        }
170        Duration(Microsecond) => {
171            let v = as_duration_microsecond_array(v)?;
172            let min = as_duration_microsecond_array(minv)?;
173            let max = as_duration_microsecond_array(maxv)?;
174            let n_bucket = as_int32_array(nb)?;
175            Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket)))
176        }
177        Interval(YearMonth) => {
178            let v = as_interval_ym_array(v)?;
179            let min = as_interval_ym_array(minv)?;
180            let max = as_interval_ym_array(maxv)?;
181            let n_bucket = as_int32_array(nb)?;
182            Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket)))
183        }
184        Interval(MonthDayNano) => {
185            let v = as_interval_mdn_array(v)?;
186            let min = as_interval_mdn_array(minv)?;
187            let max = as_interval_mdn_array(maxv)?;
188            let n_bucket = as_int32_array(nb)?;
189            Ok(Arc::new(width_bucket_interval_mdn_exact(v, min, max, n_bucket)))
190        }
191
192
193        other => Err(unsupported_data_types_exec_err(
194            "width_bucket",
195            "Float/Decimal OR Duration OR Interval(YearMonth) for first 3 args; Int for 4th",
196            &[
197                other.clone(),
198                minv.data_type().clone(),
199                maxv.data_type().clone(),
200                nb.data_type().clone(),
201            ],
202        )),
203    }
204}
205
/// Spark's `width_bucket` for a single row whose bounds and value are already `f64`.
///
/// Mirrors Spark's `WidthBucket.computeBucketNumber`:
/// * `None` (SQL NULL) when `buckets <= 0`, when `min == max`, or when the bounds
///   cannot be ordered;
/// * with `check_nan`, also `None` when `value` is NaN or either bound is NaN or
///   infinite. An infinite `value` is still valid and lands in bucket `0` or
///   `buckets + 1`, as in Spark;
/// * `Some(0)` for values before the range and `Some(buckets + 1)` for values at or
///   past its end, for both ascending (`min < max`) and descending ranges;
/// * otherwise `Some(k)` with `1 <= k <= buckets`.
///
/// `buckets == i32::MAX` also yields `None`, because `buckets + 1` would overflow the
/// `Int32` result. Spark likewise nulls out the maximum value of its bucket-count type.
fn width_bucket_value(
    value: f64,
    min: f64,
    max: f64,
    buckets: i32,
    check_nan: bool,
) -> Option<i32> {
    if buckets <= 0 || buckets == i32::MAX {
        return None;
    }
    if check_nan && (value.is_nan() || !min.is_finite() || !max.is_finite()) {
        return None;
    }

    let n = f64::from(buckets);
    let scaled = match min.partial_cmp(&max)? {
        std::cmp::Ordering::Equal => return None,
        std::cmp::Ordering::Less => {
            if value < min {
                return Some(0);
            }
            if value >= max {
                return Some(buckets + 1);
            }
            // Same operation order as Spark, so results agree to the last ulp.
            n * (value - min) / (max - min)
        }
        std::cmp::Ordering::Greater => {
            if value > min {
                return Some(0);
            }
            if value <= max {
                return Some(buckets + 1);
            }
            n * (min - value) / (min - max)
        }
    };

    // `as` truncates toward zero like Spark's `toLong` (saturating, NaN -> 0). The
    // clamp stops a rounding overshoot from escaping `[1, buckets + 1]`.
    Some((scaled as i32).saturating_add(1).clamp(1, buckets + 1))
}

/// Generates a row-wise `width_bucket` kernel `$name` over arrays of type `$arr_ty`.
///
/// * `$to_f64` — closure `(&$arr_ty, usize) -> f64` reading one element as `f64`;
/// * `$check_nan` — `true` when the converted values may be NaN or infinite (floating-point
///   input). Integer-backed inputs are always finite and pass `false`.
///
/// A row is NULL if any of its four inputs is NULL; otherwise it is
/// `width_bucket_value(value, min, max, n_bucket, $check_nan)`.
macro_rules! width_bucket_kernel_impl {
    ($name:ident, $arr_ty:ty, $to_f64:expr, $check_nan:expr) => {
        pub(crate) fn $name(
            v: &$arr_ty,
            min: &$arr_ty,
            max: &$arr_ty,
            n_bucket: &Int32Array,
        ) -> Int32Array {
            let to_f64 = $to_f64;
            let check_nan: bool = $check_nan;
            let len = v.len();
            let mut b = Int32Builder::with_capacity(len);

            for i in 0..len {
                if v.is_null(i) || min.is_null(i) || max.is_null(i) || n_bucket.is_null(i)
                {
                    b.append_null();
                    continue;
                }
                b.append_option(width_bucket_value(
                    to_f64(v, i),
                    to_f64(min, i),
                    to_f64(max, i),
                    n_bucket.value(i),
                    check_nan,
                ));
            }

            b.finish()
        }
    };
}
292
293width_bucket_kernel_impl!(
294    width_bucket_float64,
295    Float64Array,
296    |arr: &Float64Array, i: usize| arr.value(i),
297    true
298);
299
300width_bucket_kernel_impl!(
301    width_bucket_i64_as_float,
302    DurationMicrosecondArray,
303    |arr: &DurationMicrosecondArray, i: usize| arr.value(i) as f64,
304    false
305);
306
307width_bucket_kernel_impl!(
308    width_bucket_i32_as_float,
309    IntervalYearMonthArray,
310    |arr: &IntervalYearMonthArray, i: usize| arr.value(i) as f64,
311    false
312);
313const NS_PER_DAY_I128: i128 = 86_400_000_000_000;
314pub(crate) fn width_bucket_interval_mdn_exact(
315    v: &IntervalMonthDayNanoArray,
316    lo: &IntervalMonthDayNanoArray,
317    hi: &IntervalMonthDayNanoArray,
318    n: &Int32Array,
319) -> Int32Array {
320    let len = v.len();
321    let mut b = Int32Builder::with_capacity(len);
322
323    for i in 0..len {
324        if v.is_null(i) || lo.is_null(i) || hi.is_null(i) || n.is_null(i) {
325            b.append_null();
326            continue;
327        }
328        let buckets = n.value(i);
329        if buckets <= 0 {
330            b.append_null();
331            continue;
332        }
333
334        let x = v.value(i);
335        let l = lo.value(i);
336        let h = hi.value(i);
337
338        // asc/desc
339        // Values of IntervalMonthDayNano are compared using their binary representation, which can lead to surprising results.
340        let asc = (l.months, l.days, l.nanoseconds) < (h.months, h.days, h.nanoseconds);
341        if (l.months, l.days, l.nanoseconds) == (h.months, h.days, h.nanoseconds) {
342            b.append_null();
343            continue;
344        }
345
346        // ------------------- only month -------------------
347        if l.days == h.days && l.nanoseconds == h.nanoseconds && l.months != h.months {
348            let x_m = x.months as f64;
349            let l_m = l.months as f64;
350            let h_m = h.months as f64;
351
352            if asc {
353                if x_m < l_m {
354                    b.append_value(0);
355                    continue;
356                }
357                if x_m >= h_m {
358                    b.append_value(buckets + 1);
359                    continue;
360                }
361            } else {
362                if x_m > l_m {
363                    b.append_value(0);
364                    continue;
365                }
366                if x_m <= h_m {
367                    b.append_value(buckets + 1);
368                    continue;
369                }
370            }
371
372            let width = (h_m - l_m) / (buckets as f64);
373            if width == 0.0 || !width.is_finite() {
374                b.append_null();
375                continue;
376            }
377
378            let mut bucket = ((x_m - l_m) / width).floor() as i32 + 1;
379            if bucket < 1 {
380                bucket = 1;
381            }
382            if bucket > buckets + 1 {
383                bucket = buckets + 1;
384            }
385            b.append_value(bucket);
386            continue;
387        }
388
389        // ---------------  months equals -------------------
390        if l.months == h.months {
391            let base_days = l.days as i128;
392            let base_ns = l.nanoseconds as i128;
393
394            let xf = (x.days as i128 - base_days) * NS_PER_DAY_I128
395                + (x.nanoseconds as i128 - base_ns);
396            let hf = (h.days as i128 - base_days) * NS_PER_DAY_I128
397                + (h.nanoseconds as i128 - base_ns);
398
399            let x_f = xf as f64;
400            let l_f = 0.0;
401            let h_f = hf as f64;
402
403            if asc {
404                if x_f < l_f {
405                    b.append_value(0);
406                    continue;
407                }
408                if x_f >= h_f {
409                    b.append_value(buckets + 1);
410                    continue;
411                }
412            } else {
413                if x_f > l_f {
414                    b.append_value(0);
415                    continue;
416                }
417                if x_f <= h_f {
418                    b.append_value(buckets + 1);
419                    continue;
420                }
421            }
422
423            let width = (h_f - l_f) / (buckets as f64);
424            if width == 0.0 || !width.is_finite() {
425                b.append_null();
426                continue;
427            }
428
429            let mut bucket = ((x_f - l_f) / width).floor() as i32 + 1;
430            if bucket < 1 {
431                bucket = 1;
432            }
433            if bucket > buckets + 1 {
434                bucket = buckets + 1;
435            }
436            b.append_value(bucket);
437            continue;
438        }
439
440        b.append_null();
441    }
442
443    b.finish()
444}
445
#[cfg(test)]
mod tests {
    // Unit tests for `width_bucket_kern`, one section per supported input type:
    // Float64, Duration(Microsecond), Interval(YearMonth) and Interval(MonthDayNano).
    // Cases cover in-range buckets, the out-of-range results (0 and n + 1),
    // descending ranges (min > max), and rows that must come back NULL.
    use super::*;
    use std::sync::Arc;

    use arrow::array::{
        ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array,
        IntervalYearMonthArray,
    };
    use arrow::datatypes::IntervalMonthDayNano;

    // --- Helpers -------------------------------------------------------------

    /// Int32 array of length `len` where every element is `val` (used for bucket counts).
    fn i32_array_all(len: usize, val: i32) -> Arc<Int32Array> {
        Arc::new(Int32Array::from(vec![val; len]))
    }

    /// Non-null Float64 array from `vals`.
    fn f64_array(vals: &[f64]) -> Arc<Float64Array> {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    /// Float64 array from `vals`, where `None` becomes a NULL slot.
    fn f64_array_opt(vals: &[Option<f64>]) -> Arc<Float64Array> {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    /// Duration(Microsecond) array from raw microsecond counts.
    fn dur_us_array(vals: &[i64]) -> Arc<DurationMicrosecondArray> {
        Arc::new(DurationMicrosecondArray::from(vals.to_vec()))
    }

    /// Interval(YearMonth) array from raw month counts.
    fn ym_array(vals: &[i32]) -> Arc<IntervalYearMonthArray> {
        Arc::new(IntervalYearMonthArray::from(vals.to_vec()))
    }

    /// Downcasts a kernel result to `Int32Array`; panics if it is another type.
    fn downcast_i32(arr: &ArrayRef) -> &Int32Array {
        arr.as_any().downcast_ref::<Int32Array>().unwrap()
    }

    /// Interval(MonthDayNano) array from `(months, days, nanoseconds)` tuples.
    fn mdn_array(vals: &[(i32, i32, i64)]) -> Arc<IntervalMonthDayNanoArray> {
        let data: Vec<IntervalMonthDayNano> = vals
            .iter()
            .map(|(m, d, ns)| IntervalMonthDayNano::new(*m, *d, *ns))
            .collect();
        Arc::new(IntervalMonthDayNanoArray::from(data))
    }

    // --- Float64 -------------------------------------------------------------

    #[test]
    fn test_width_bucket_f64_basic() {
        let v = f64_array(&[0.5, 1.0, 9.9, -1.0, 10.0]);
        let lo = f64_array(&[0.0, 0.0, 0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0, 10.0, 10.0]);
        let n = i32_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 2, 10, 0, 11]);
    }

    // Descending range (lo > hi): values above lo map to 0, values at or below hi to n + 1.
    #[test]
    fn test_width_bucket_f64_descending_range() {
        let v = f64_array(&[9.9, 10.0, 0.0, -0.1, 10.1]);
        let lo = f64_array(&[10.0; 5]);
        let hi = f64_array(&[0.0; 5]);
        let n = i32_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[1, 1, 11, 11, 0]);
    }
    // Ascending range: lo is inclusive (bucket 1), hi is exclusive (bucket n + 1).
    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_asc() {
        let v = f64_array(&[0.0, 9.999999999, 10.0]);
        let lo = f64_array(&[0.0; 3]);
        let hi = f64_array(&[10.0; 3]);
        let n = i32_array_all(3, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 10, 11]);
    }

    // Descending range: lo is inclusive (bucket 1), hi and beyond give n + 1.
    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_desc() {
        let v = f64_array(&[10.0, 0.0, -0.000001]);
        let lo = f64_array(&[10.0; 3]);
        let hi = f64_array(&[0.0; 3]);
        let n = i32_array_all(3, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 11, 11]);
    }

    // NULL results: non-positive bucket count, equal bounds, NaN value.
    #[test]
    fn test_width_bucket_f64_edge_cases() {
        let v = f64_array(&[1.0, 5.0, 9.0]);
        let lo = f64_array(&[0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0]);
        let n = Arc::new(Int32Array::from(vec![0, -1, 10]));
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert!(out.is_null(1));
        assert_eq!(out.value(2), 10);

        let v = f64_array(&[1.0]);
        let lo = f64_array(&[5.0]);
        let hi = f64_array(&[5.0]);
        let n = i32_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));

        let v = f64_array_opt(&[Some(f64::NAN)]);
        let lo = f64_array(&[0.0]);
        let hi = f64_array(&[10.0]);
        let n = i32_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    // A NULL in any argument makes that row NULL without affecting other rows.
    #[test]
    fn test_width_bucket_f64_nulls_propagate() {
        let v = f64_array_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]);
        let lo = f64_array(&[0.0; 4]);
        let hi = f64_array(&[10.0; 4]);
        let n = i32_array_all(4, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert_eq!(out.value(1), 2);
        assert_eq!(out.value(2), 3);
        assert_eq!(out.value(3), 4);

        let v = f64_array(&[1.0]);
        let lo = f64_array_opt(&[None]);
        let hi = f64_array(&[10.0]);
        let n = i32_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    // --- Duration(Microsecond) ----------------------------------------------

    #[test]
    fn test_width_bucket_duration_us() {
        let v = dur_us_array(&[1_000_000, 0, -1]);
        let lo = dur_us_array(&[0, 0, 0]);
        let hi = dur_us_array(&[2_000_000, 2_000_000, 2_000_000]);
        let n = i32_array_all(3, 2);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[2, 1, 0]);
    }

    #[test]
    fn test_width_bucket_duration_us_equal_bounds() {
        let v = dur_us_array(&[0]);
        let lo = dur_us_array(&[1]);
        let hi = dur_us_array(&[1]);
        let n = i32_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    // --- Interval(YearMonth) ------------------------------------------------

    #[test]
    fn test_width_bucket_interval_ym_basic() {
        let v = ym_array(&[0, 5, 11, 12, 13]);
        let lo = ym_array(&[0; 5]);
        let hi = ym_array(&[12; 5]);
        let n = i32_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_ym_desc() {
        let v = ym_array(&[11, 12, 0, -1, 13]);
        let lo = ym_array(&[12; 5]);
        let hi = ym_array(&[0; 5]);
        let n = i32_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[2, 1, 13, 13, 0]);
    }

    // --- Interval(MonthDayNano) --------------------------------------------

    #[test]
    fn test_width_bucket_interval_mdn_months_only_basic() {
        let v = mdn_array(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]);
        let lo = mdn_array(&[(0, 0, 0); 5]);
        let hi = mdn_array(&[(12, 0, 0); 5]);
        let n = i32_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_months_only_desc() {
        let v = mdn_array(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]);
        let lo = mdn_array(&[(12, 0, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i32_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        // Same pattern as the descending Interval(YearMonth) test.
        assert_eq!(out.values(), &[2, 1, 13, 13, 0]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_basic() {
        let v = mdn_array(&[
            (0, 0, 0),
            (0, 5, 0),
            (0, 9, 0),
            (0, 10, 0),
            (0, -1, 0),
            (0, 11, 0),
        ]);
        let lo = mdn_array(&[(0, 0, 0); 6]);
        let hi = mdn_array(&[(0, 10, 0); 6]);
        let n = i32_array_all(6, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        // x==hi -> n+1, x<lo -> 0, x>hi -> n+1
        assert_eq!(out.values(), &[1, 6, 10, 11, 0, 11]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc() {
        let v = mdn_array(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lo = mdn_array(&[(0, 10, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i32_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[2, 1, 11, 11, 0]);
    }
    // A value one nanosecond past 9 days (descending from 10 days) stays in bucket 1.
    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc_inside() {
        let v = mdn_array(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lo = mdn_array(&[(0, 10, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i32_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[1, 1, 11, 11, 0]);
    }

    // Bounds that differ in both months and days cannot be compared exactly -> NULL.
    #[test]
    fn test_width_bucket_interval_mdn_mixed_months_and_days_is_null() {
        let v = mdn_array(&[(0, 1, 0)]);
        let lo = mdn_array(&[(0, 0, 0)]);
        let hi = mdn_array(&[(1, 1, 0)]);
        let n = i32_array_all(1, 4);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_equal_bounds_is_null() {
        let v = mdn_array(&[(0, 0, 0)]);
        let lo = mdn_array(&[(1, 2, 3)]);
        let hi = mdn_array(&[(1, 2, 3)]); // lo == hi
        let n = i32_array_all(1, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_invalid_n_is_null() {
        let v = mdn_array(&[(0, 0, 0)]);
        let lo = mdn_array(&[(0, 0, 0)]);
        let hi = mdn_array(&[(0, 10, 0)]);
        let n = Arc::new(Int32Array::from(vec![0])); // n <= 0

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_nulls_propagate() {
        let v = Arc::new(IntervalMonthDayNanoArray::from(vec![
            None,
            Some(IntervalMonthDayNano::new(0, 5, 0)),
        ]));
        let lo = mdn_array(&[(0, 0, 0), (0, 0, 0)]);
        let hi = mdn_array(&[(0, 10, 0), (0, 10, 0)]);
        let n = i32_array_all(2, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert_eq!(out.value(1), 6);
    }

    // --- Errors --------------------------------------------------------------

    #[test]
    fn test_width_bucket_wrong_arg_count() {
        let v = f64_array(&[1.0]);
        let lo = f64_array(&[0.0]);
        let hi = f64_array(&[10.0]);
        let err = width_bucket_kern(&[v, lo, hi]).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects exactly 4"), "unexpected error: {msg}");
    }

    // An uncoerced Int32 value array is not one of the kernel's supported types.
    #[test]
    fn test_width_bucket_unsupported_type() {
        let v: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let lo = f64_array(&[0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0]);
        let n = i32_array_all(3, 10);

        let err = width_bucket_kern(&[v, lo, hi, n]).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("unsupported data types")
                || msg.contains("Float/Decimal OR Duration OR Interval(YearMonth)"),
            "unexpected error: {msg}"
        );
    }
}